use super::super::{WgpuClient, WgpuRuntime};
use crate::dtype::DType;
use crate::tensor::{Layout, Storage, Tensor};
/// Fetches the buffer behind a GPU allocation handle, early-returning an error if absent.
///
/// Expands to `get_buffer($ptr)` and unwraps the returned `Option`. On `None`, it
/// propagates `Error::Internal("Failed to get <name> buffer from GPU allocation")`
/// with `?`, so the enclosing function must return a `Result` whose error type
/// accepts `Error` via `From`.
///
/// The names `get_buffer` and `Error` are resolved at the call site, not at the
/// definition site. Both must therefore be in scope wherever this macro is invoked.
///
/// * `$ptr`  - the allocation handle passed to `get_buffer`.
/// * `$name` - a `Display` value naming the buffer in the error message; it is
///   only evaluated when the lookup fails.
macro_rules! get_buffer_or_err {
    ($ptr:expr, $name:expr) => {
        get_buffer($ptr).ok_or_else(|| {
            let message = format!("Failed to get {} buffer from GPU allocation", $name);
            Error::Internal(message)
        })?
    };
}
// Re-export by path so sibling modules can `use` the macro instead of relying on textual order.
pub(crate) use get_buffer_or_err;
impl WgpuClient {
    /// Builds a contiguous tensor whose storage is created from the raw handle `ptr`.
    ///
    /// The element count handed to `Storage::from_ptr` is the product of `shape`.
    /// An empty `shape` (a scalar) yields a single element. The returned tensor
    /// pairs that storage with `Layout::contiguous(shape)`.
    ///
    /// # Safety
    ///
    /// The caller must satisfy every requirement of `Storage::<WgpuRuntime>::from_ptr`
    /// for the arguments `(ptr, len, dtype, device)`, where `len` is the element
    /// count described above. NOTE(review): this presumably means `ptr` refers to a
    /// live allocation on `device` holding at least `len` elements of `dtype`, with
    /// ownership compatible with the returned storage. Confirm against `from_ptr`.
    ///
    /// # Panics
    ///
    /// Panics if the product of the dimensions in `shape` overflows `usize`.
    pub(crate) unsafe fn tensor_from_raw(
        ptr: u64,
        shape: &[usize],
        dtype: DType,
        device: &super::super::WgpuDevice,
    ) -> Tensor<WgpuRuntime> {
        // The empty product is 1, so scalars need no special case. A checked product
        // stops release builds from silently wrapping to a bogus element count.
        let len = shape
            .iter()
            .try_fold(1usize, |acc, &dim| acc.checked_mul(dim))
            .expect("tensor_from_raw: shape element count overflows usize");
        // SAFETY: the caller guarantees (see `# Safety`) that `ptr`, `len`, `dtype`
        // and `device` meet the requirements of `Storage::from_ptr`.
        let storage = unsafe { Storage::<WgpuRuntime>::from_ptr(ptr, len, dtype, device) };
        let layout = Layout::contiguous(shape);
        Tensor::from_parts(storage, layout)
    }
}