use runmat_accelerate_api::GpuTensorHandle;
use runmat_builtins::{Tensor, Value};
use crate::build_runtime_error;
/// Gathers the value behind `handle` and returns it as a dense [`Tensor`].
///
/// The handle is cloned into a `Value::GpuTensor` and passed to
/// `crate::dispatcher::gather_if_needed_async`. The gathered value is then
/// normalised:
///
/// * `Value::Tensor` is returned unchanged.
/// * `Value::Num` becomes a `1 x 1` tensor holding that number.
/// * `Value::LogicalArray` becomes a tensor of the same shape in which every
///   non-zero element is `1.0` and every zero element is `0.0`.
///
/// # Errors
///
/// Propagates any error from the dispatcher. Returns a runtime error prefixed
/// with `gather:` when the gathered value is of any other kind, or when
/// `Tensor::new` rejects the data/shape pair.
pub async fn gather_tensor_async(
    handle: &runmat_accelerate_api::GpuTensorHandle,
) -> crate::BuiltinResult<Tensor> {
    // Test builds with the `wgpu` feature only: register the wgpu provider
    // when the handle reports a non-zero device id. The registration result
    // is deliberately discarded.
    #[cfg(all(test, feature = "wgpu"))]
    {
        if handle.device_id != 0 {
            let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
                runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
            );
        }
    }

    let wrapped = Value::GpuTensor(handle.clone());

    // Reduce every convertible result to a (data, shape) pair so the tensor is
    // built, and its error mapped, in exactly one place below.
    let (data, shape) = match crate::dispatcher::gather_if_needed_async(&wrapped).await? {
        Value::Tensor(tensor) => return Ok(tensor),
        Value::Num(scalar) => (vec![scalar], vec![1, 1]),
        Value::LogicalArray(logical) => {
            let as_floats = logical
                .data
                .iter()
                .map(|&flag| if flag != 0 { 1.0 } else { 0.0 })
                .collect::<Vec<f64>>();
            (as_floats, logical.shape.clone())
        }
        other => {
            return Err(
                build_runtime_error(format!("gather: unexpected value kind {other:?}")).build(),
            );
        }
    };

    Tensor::new(data, shape).map_err(|e| build_runtime_error(format!("gather: {e}")).build())
}
/// Forwards `value` to `crate::dispatcher::gather_if_needed_async` and returns
/// whatever value it produces.
///
/// # Errors
///
/// Returns the dispatcher's error unchanged.
pub async fn gather_value_async(value: &Value) -> crate::BuiltinResult<Value> {
    let gathered = crate::dispatcher::gather_if_needed_async(value).await?;
    Ok(gathered)
}
/// Wraps `handle` in a `Value::GpuTensor` after passing it to
/// `runmat_accelerate_api::mark_residency`.
///
/// NOTE(review): `mark_residency` presumably records that the handle's data
/// should be treated as device-resident — confirm against the accelerate API.
pub fn resident_gpu_value(handle: GpuTensorHandle) -> Value {
    let tracked = handle;
    runmat_accelerate_api::mark_residency(&tracked);
    Value::GpuTensor(tracked)
}
/// Flags `handle` as logical via `runmat_accelerate_api::set_handle_logical`
/// and then wraps it with [`resident_gpu_value`].
///
/// NOTE(review): the flag presumably makes later consumers treat the tensor's
/// contents as logical (boolean) data — confirm against the accelerate API.
pub fn logical_gpu_value(handle: GpuTensorHandle) -> Value {
    let is_logical = true;
    runmat_accelerate_api::set_handle_logical(&handle, is_logical);
    resident_gpu_value(handle)
}