use runmat_accelerate_api::{AccelProvider, GpuTensorHandle, GpuTensorStorage, HostTensorView};
use runmat_builtins::{ComplexTensor, Tensor, Value};
use crate::build_runtime_error;
/// Gathers the tensor referenced by `handle` and returns it as a host [`Tensor`].
///
/// The handle is cloned into a [`Value::GpuTensor`] and passed to
/// `crate::dispatcher::gather_if_needed_async`. The gathered value is then
/// normalised:
///
/// * [`Value::Tensor`] is returned unchanged.
/// * [`Value::Num`] becomes a tensor with shape `[1, 1]`.
/// * [`Value::LogicalArray`] becomes a tensor of the same shape in which every
///   non-zero element maps to `1.0` and every zero element to `0.0`.
///
/// # Errors
///
/// Propagates any error from the dispatcher gather. Returns a runtime error
/// prefixed with `gather:` when `Tensor::new` rejects the data/shape pair, or
/// when the gathered value is of any other kind.
pub async fn gather_tensor_async(
    handle: &runmat_accelerate_api::GpuTensorHandle,
) -> crate::BuiltinResult<Tensor> {
    // Test builds with the `wgpu` feature only: a non-zero device id triggers a
    // best-effort registration of the wgpu provider with default options. The
    // registration result is deliberately discarded.
    // NOTE(review): presumably this lets tests gather handles owned by the wgpu
    // backend without explicit setup — confirm.
    #[cfg(all(test, feature = "wgpu"))]
    {
        if handle.device_id != 0 {
            let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
                runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
            );
        }
    }

    let wrapped = Value::GpuTensor(handle.clone());

    // Reduce every convertible variant to a `(data, shape)` pair so the tensor
    // is constructed (and its error mapped) in exactly one place below.
    let (data, shape) = match crate::dispatcher::gather_if_needed_async(&wrapped).await? {
        Value::Tensor(tensor) => return Ok(tensor),
        Value::Num(scalar) => (vec![scalar], vec![1, 1]),
        Value::LogicalArray(mask) => {
            let as_floats: Vec<f64> = mask
                .data
                .iter()
                .map(|&bit| if bit != 0 { 1.0 } else { 0.0 })
                .collect();
            (as_floats, mask.shape.clone())
        }
        other => {
            return Err(
                build_runtime_error(format!("gather: unexpected value kind {other:?}")).build(),
            );
        }
    };

    Tensor::new(data, shape).map_err(|e| build_runtime_error(format!("gather: {e}")).build())
}
pub async fn gather_value_async(value: &Value) -> crate::BuiltinResult<Value> {
crate::dispatcher::gather_if_needed_async(value).await
}
/// Uploads a host [`ComplexTensor`] through `provider` and returns the
/// resulting handle.
///
/// Each `(re, im)` element of `tensor.data` is flattened into a single buffer
/// as `re, im, re, im, ...`. That buffer is uploaded together with the tensor's
/// own `shape`. After a successful upload the handle is marked as non-logical,
/// its storage is set to [`GpuTensorStorage::ComplexInterleaved`], and its
/// precision is set to `provider.precision()`.
///
/// # Errors
///
/// Returns a runtime error prefixed with `gpu upload:` when `provider.upload`
/// fails.
pub fn upload_complex_tensor(
    provider: &dyn AccelProvider,
    tensor: &ComplexTensor,
) -> crate::BuiltinResult<GpuTensorHandle> {
    // Two scalars per complex element; reserve the full buffer up front.
    let mut packed = Vec::with_capacity(tensor.data.len() * 2);
    packed.extend(tensor.data.iter().flat_map(|&(re, im)| [re, im]));

    let host_view = HostTensorView {
        data: &packed,
        shape: &tensor.shape,
    };

    let uploaded = match provider.upload(&host_view) {
        Ok(handle) => handle,
        Err(e) => return Err(build_runtime_error(format!("gpu upload: {e}")).build()),
    };

    // Record handle metadata in the same order as before: logical flag,
    // storage layout, then precision.
    runmat_accelerate_api::set_handle_logical(&uploaded, false);
    runmat_accelerate_api::set_handle_storage(&uploaded, GpuTensorStorage::ComplexInterleaved);
    runmat_accelerate_api::set_handle_precision(&uploaded, provider.precision());

    Ok(uploaded)
}
pub fn resident_gpu_value(handle: GpuTensorHandle) -> Value {
runmat_accelerate_api::mark_residency(&handle);
Value::GpuTensor(handle)
}
pub fn logical_gpu_value(handle: GpuTensorHandle) -> Value {
runmat_accelerate_api::set_handle_logical(&handle, true);
resident_gpu_value(handle)
}
pub fn complex_gpu_value(handle: GpuTensorHandle) -> Value {
runmat_accelerate_api::set_handle_logical(&handle, false);
runmat_accelerate_api::set_handle_storage(&handle, GpuTensorStorage::ComplexInterleaved);
resident_gpu_value(handle)
}