runmat_runtime/builtins/common/
gpu_helpers.rs1use runmat_accelerate_api::GpuTensorHandle;
2use runmat_builtins::{Tensor, Value};
3
4pub fn gather_tensor(handle: &runmat_accelerate_api::GpuTensorHandle) -> Result<Tensor, String> {
9 #[cfg(all(test, feature = "wgpu"))]
12 {
13 if handle.device_id != 0 {
14 let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
15 runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
16 );
17 }
18 }
19 let value = Value::GpuTensor(handle.clone());
20 match crate::dispatcher::gather_if_needed(&value)? {
21 Value::Tensor(t) => Ok(t),
22 Value::Num(n) => Tensor::new(vec![n], vec![1, 1]).map_err(|e| format!("gather: {e}")),
23 Value::LogicalArray(la) => {
24 let data: Vec<f64> = la
25 .data
26 .iter()
27 .map(|&b| if b != 0 { 1.0 } else { 0.0 })
28 .collect();
29 Tensor::new(data, la.shape.clone()).map_err(|e| format!("gather: {e}"))
30 }
31 other => Err(format!("gather: unexpected value kind {other:?}")),
32 }
33}
34
35pub fn gather_value(value: &Value) -> Result<Value, String> {
37 crate::dispatcher::gather_if_needed(value)
38}
39
40pub fn logical_gpu_value(handle: GpuTensorHandle) -> Value {
43 runmat_accelerate_api::set_handle_logical(&handle, true);
44 Value::GpuTensor(handle)
45}