runmat_runtime/builtins/common/
gpu_helpers.rs1use runmat_accelerate_api::{AccelProvider, GpuTensorHandle, GpuTensorStorage, HostTensorView};
2use runmat_builtins::{ComplexTensor, Tensor, Value};
3
4use crate::build_runtime_error;
5
6pub async fn gather_tensor_async(
11 handle: &runmat_accelerate_api::GpuTensorHandle,
12) -> crate::BuiltinResult<Tensor> {
13 #[cfg(all(test, feature = "wgpu"))]
16 {
17 if handle.device_id != 0 {
18 let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
19 runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
20 );
21 }
22 }
23 let value = Value::GpuTensor(handle.clone());
24 let gathered = crate::dispatcher::gather_if_needed_async(&value).await?;
25 match gathered {
26 Value::Tensor(t) => Ok(t),
27 Value::Num(n) => Tensor::new(vec![n], vec![1, 1])
28 .map_err(|e| build_runtime_error(format!("gather: {e}")).build()),
29 Value::LogicalArray(la) => {
30 let data: Vec<f64> = la
31 .data
32 .iter()
33 .map(|&b| if b != 0 { 1.0 } else { 0.0 })
34 .collect();
35 Tensor::new(data, la.shape.clone())
36 .map_err(|e| build_runtime_error(format!("gather: {e}")).build())
37 }
38 other => {
39 Err(build_runtime_error(format!("gather: unexpected value kind {other:?}")).build())
40 }
41 }
42}
43
44pub async fn gather_value_async(value: &Value) -> crate::BuiltinResult<Value> {
46 crate::dispatcher::gather_if_needed_async(value).await
47}
48
49pub fn upload_complex_tensor(
52 provider: &dyn AccelProvider,
53 tensor: &ComplexTensor,
54) -> crate::BuiltinResult<GpuTensorHandle> {
55 let mut interleaved = Vec::with_capacity(tensor.data.len() * 2);
56 for &(re, im) in &tensor.data {
57 interleaved.push(re);
58 interleaved.push(im);
59 }
60 let view = HostTensorView {
61 data: &interleaved,
62 shape: &tensor.shape,
63 };
64 let handle = provider
65 .upload(&view)
66 .map_err(|e| build_runtime_error(format!("gpu upload: {e}")).build())?;
67 runmat_accelerate_api::set_handle_logical(&handle, false);
68 runmat_accelerate_api::set_handle_storage(&handle, GpuTensorStorage::ComplexInterleaved);
69 runmat_accelerate_api::set_handle_precision(&handle, provider.precision());
70 Ok(handle)
71}
72
73pub fn resident_gpu_value(handle: GpuTensorHandle) -> Value {
76 runmat_accelerate_api::mark_residency(&handle);
77 Value::GpuTensor(handle)
78}
79
80pub fn logical_gpu_value(handle: GpuTensorHandle) -> Value {
83 runmat_accelerate_api::set_handle_logical(&handle, true);
84 resident_gpu_value(handle)
85}
86
87pub fn complex_gpu_value(handle: GpuTensorHandle) -> Value {
89 runmat_accelerate_api::set_handle_logical(&handle, false);
90 runmat_accelerate_api::set_handle_storage(&handle, GpuTensorStorage::ComplexInterleaved);
91 resident_gpu_value(handle)
92}