runmat_runtime/builtins/common/
gpu_helpers.rs

1use runmat_accelerate_api::GpuTensorHandle;
2use runmat_builtins::{Tensor, Value};
3
4/// Download a GPU tensor handle to host memory, returning a dense `Tensor`.
5///
6/// This helper routes through the dispatcher so residency hooks and provider
7/// semantics stay consistent with the rest of the runtime.
8pub fn gather_tensor(handle: &runmat_accelerate_api::GpuTensorHandle) -> Result<Tensor, String> {
9    // Ensure the correct provider is active for WGPU-backed handles when tests run in parallel.
10    // This mirrors the guard used in test_support::gather.
11    #[cfg(all(test, feature = "wgpu"))]
12    {
13        if handle.device_id != 0 {
14            let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
15                runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
16            );
17        }
18    }
19    let value = Value::GpuTensor(handle.clone());
20    match crate::dispatcher::gather_if_needed(&value)? {
21        Value::Tensor(t) => Ok(t),
22        Value::Num(n) => Tensor::new(vec![n], vec![1, 1]).map_err(|e| format!("gather: {e}")),
23        Value::LogicalArray(la) => {
24            let data: Vec<f64> = la
25                .data
26                .iter()
27                .map(|&b| if b != 0 { 1.0 } else { 0.0 })
28                .collect();
29            Tensor::new(data, la.shape.clone()).map_err(|e| format!("gather: {e}"))
30        }
31        other => Err(format!("gather: unexpected value kind {other:?}")),
32    }
33}
34
35/// Gather an arbitrary value, returning a host-side `Value`.
36pub fn gather_value(value: &Value) -> Result<Value, String> {
37    crate::dispatcher::gather_if_needed(value)
38}
39
40/// Wrap a GPU tensor handle as a logical gpuArray value, recording metadata so that
41/// predicates like `islogical` can inspect the handle without downloading it.
42pub fn logical_gpu_value(handle: GpuTensorHandle) -> Value {
43    runmat_accelerate_api::set_handle_logical(&handle, true);
44    Value::GpuTensor(handle)
45}