Skip to main content

runmat_runtime/builtins/common/
gpu_helpers.rs

1use runmat_accelerate_api::{AccelProvider, GpuTensorHandle, GpuTensorStorage, HostTensorView};
2use runmat_builtins::{ComplexTensor, Tensor, Value};
3
4use crate::build_runtime_error;
5
6/// Download a GPU tensor handle to host memory, returning a dense `Tensor`.
7///
8/// This helper routes through the dispatcher so residency hooks and provider
9/// semantics stay consistent with the rest of the runtime.
10pub async fn gather_tensor_async(
11    handle: &runmat_accelerate_api::GpuTensorHandle,
12) -> crate::BuiltinResult<Tensor> {
13    // Ensure the correct provider is active for WGPU-backed handles when tests run in parallel.
14    // This mirrors the guard used in test_support::gather.
15    #[cfg(all(test, feature = "wgpu"))]
16    {
17        if handle.device_id != 0 {
18            let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
19                runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
20            );
21        }
22    }
23    let value = Value::GpuTensor(handle.clone());
24    let gathered = crate::dispatcher::gather_if_needed_async(&value).await?;
25    match gathered {
26        Value::Tensor(t) => Ok(t),
27        Value::Num(n) => Tensor::new(vec![n], vec![1, 1])
28            .map_err(|e| build_runtime_error(format!("gather: {e}")).build()),
29        Value::LogicalArray(la) => {
30            let data: Vec<f64> = la
31                .data
32                .iter()
33                .map(|&b| if b != 0 { 1.0 } else { 0.0 })
34                .collect();
35            Tensor::new(data, la.shape.clone())
36                .map_err(|e| build_runtime_error(format!("gather: {e}")).build())
37        }
38        other => {
39            Err(build_runtime_error(format!("gather: unexpected value kind {other:?}")).build())
40        }
41    }
42}
43
44/// Gather an arbitrary value, returning a host-side `Value`.
45pub async fn gather_value_async(value: &Value) -> crate::BuiltinResult<Value> {
46    crate::dispatcher::gather_if_needed_async(value).await
47}
48
49/// Upload a host complex tensor as an interleaved GPU buffer and record complex
50/// storage metadata on the returned handle.
51pub fn upload_complex_tensor(
52    provider: &dyn AccelProvider,
53    tensor: &ComplexTensor,
54) -> crate::BuiltinResult<GpuTensorHandle> {
55    let mut interleaved = Vec::with_capacity(tensor.data.len() * 2);
56    for &(re, im) in &tensor.data {
57        interleaved.push(re);
58        interleaved.push(im);
59    }
60    let view = HostTensorView {
61        data: &interleaved,
62        shape: &tensor.shape,
63    };
64    let handle = provider
65        .upload(&view)
66        .map_err(|e| build_runtime_error(format!("gpu upload: {e}")).build())?;
67    runmat_accelerate_api::set_handle_logical(&handle, false);
68    runmat_accelerate_api::set_handle_storage(&handle, GpuTensorStorage::ComplexInterleaved);
69    runmat_accelerate_api::set_handle_precision(&handle, provider.precision());
70    Ok(handle)
71}
72
73/// Wrap a GPU tensor handle, marking it as resident for downstream fusion-aware
74/// consumers and tests.
75pub fn resident_gpu_value(handle: GpuTensorHandle) -> Value {
76    runmat_accelerate_api::mark_residency(&handle);
77    Value::GpuTensor(handle)
78}
79
80/// Wrap a GPU tensor handle as a logical gpuArray value, recording metadata so that
81/// predicates like `islogical` can inspect the handle without downloading it.
82pub fn logical_gpu_value(handle: GpuTensorHandle) -> Value {
83    runmat_accelerate_api::set_handle_logical(&handle, true);
84    resident_gpu_value(handle)
85}
86
87/// Wrap a GPU tensor handle as a complex gpuArray value.
88pub fn complex_gpu_value(handle: GpuTensorHandle) -> Value {
89    runmat_accelerate_api::set_handle_logical(&handle, false);
90    runmat_accelerate_api::set_handle_storage(&handle, GpuTensorStorage::ComplexInterleaved);
91    resident_gpu_value(handle)
92}