use alloc::{boxed::Box, vec::Vec};
use core::{ffi::c_void, slice};
use js_sys::Uint8Array;
use ort::{AsPointer, value::ValueTypeMarker};
use wasm_bindgen::{JsCast, JsValue};
use crate::{
Error,
binding::{self, DataType},
memory::MemoryInfo,
util::num_elements
};
/// Marker bytes written into the first field of every [`Tensor`] built by
/// `Tensor::from_ptr` and `Tensor::from_tensor`.
///
/// `ValueExt::sync` reads the first four bytes behind a value's pointer and
/// compares them with this constant before treating the pointer as a `Tensor`.
pub const TENSOR_SENTINEL: [u8; 4] = [0xFC, 0x86, 0xA5, 0x39];
/// Where the Rust-side bytes of a [`Tensor`] live.
pub enum TensorData {
/// A raw pointer to the tensor bytes plus their length in bytes.
/// `Tensor::from_ptr` stores the same pointer that it exposes to the runtime
/// through `buffer_from_ptr`.
RustView { ptr: *mut c_void, byte_len: usize },
/// A boxed byte buffer held by this value. It starts out as `None`
/// (`Tensor::from_tensor`) and is allocated the first time `Tensor::sync` runs
/// with `SyncDirection::Rust`.
External { buffer: Option<Box<[u8]>> }
}
/// A runtime (`binding::Tensor`) tensor paired with its Rust-side data.
///
/// `#[repr(C)]` keeps the fields in declaration order, so `sentinel` occupies
/// the first four bytes of the struct; `ValueExt::sync` depends on this when it
/// inspects the bytes behind a value pointer.
#[repr(C)]
pub struct Tensor {
// Always `TENSOR_SENTINEL`; set by both constructors.
sentinel: [u8; 4],
/// Handle to the tensor on the JavaScript side.
pub js: binding::Tensor,
/// Location of the Rust-side copy of the tensor's bytes.
pub data: TensorData,
/// Built from `js.location()` when the tensor is constructed.
pub memory_info: MemoryInfo
}
impl Tensor {
pub unsafe fn from_ptr(dtype: binding::DataType, ptr: *mut c_void, byte_len: usize, dims: &[i32]) -> Result<Self, JsValue> {
let tensor = binding::Tensor::new_from_buffer(dtype, unsafe { buffer_from_ptr(dtype, ptr, byte_len) }, dims)?;
Ok(Self {
sentinel: TENSOR_SENTINEL,
memory_info: MemoryInfo { location: tensor.location() },
js: tensor,
data: TensorData::RustView { ptr, byte_len }
})
}
pub fn from_tensor(tensor: binding::Tensor) -> Self {
Self {
sentinel: TENSOR_SENTINEL,
memory_info: MemoryInfo { location: tensor.location() },
js: tensor,
data: TensorData::External { buffer: None }
}
}
pub async fn sync(&mut self, direction: SyncDirection) -> crate::Result<()> {
match direction {
SyncDirection::Rust => {
let data = self.js.get_data().await?;
let generic_typed_array = Uint8Array::unchecked_from_js(data);
let bytes = Uint8Array::new_with_byte_offset_and_length(
&generic_typed_array.buffer(),
generic_typed_array.byte_offset(),
generic_typed_array.byte_length()
);
match &mut self.data {
TensorData::RustView { ptr, byte_len } => {
bytes.copy_to(unsafe { core::slice::from_raw_parts_mut(ptr.cast(), *byte_len) });
}
TensorData::External { buffer } => {
let buffer = match buffer {
Some(buffer) => buffer,
None => {
*buffer = Some(vec![0; generic_typed_array.byte_length() as usize].into_boxed_slice());
unsafe { buffer.as_mut().unwrap_unchecked() }
}
};
bytes.copy_to(buffer);
}
}
}
SyncDirection::Runtime => {
let Ok(generic_typed_array) = self.js.data().map(Uint8Array::unchecked_from_js) else {
return Err(Error::new(
"Cannot synchronize Rust data to a runtime tensor that is not on the CPU; modify the WebGPU/WebGL buffer directly."
));
};
let bytes = Uint8Array::new_with_byte_offset_and_length(
&generic_typed_array.buffer(),
generic_typed_array.byte_offset(),
generic_typed_array.byte_length()
);
bytes.copy_from(match &self.data {
TensorData::RustView { ptr, byte_len } => unsafe { core::slice::from_raw_parts(ptr.cast(), *byte_len) },
TensorData::External { buffer } => {
let Some(buffer) = buffer else {
return Ok(());
};
&*buffer
}
});
}
}
Ok(())
}
}
pub fn create_buffer(dtype: binding::DataType, shape: &[i32]) -> JsValue {
let numel = num_elements(shape) as u32;
match dtype {
binding::DataType::Bool | binding::DataType::Uint8 => js_sys::Uint8Array::new_with_length(numel).into(),
binding::DataType::Int8 => js_sys::Int8Array::new_with_length(numel).into(),
binding::DataType::Uint16 => js_sys::Uint16Array::new_with_length(numel).into(),
binding::DataType::Int16 => js_sys::Int16Array::new_with_length(numel).into(),
binding::DataType::Uint32 => js_sys::Uint32Array::new_with_length(numel).into(),
binding::DataType::Int32 => js_sys::Int32Array::new_with_length(numel).into(),
binding::DataType::Uint64 => js_sys::BigUint64Array::new_with_length(numel).into(),
binding::DataType::Int64 => js_sys::BigInt64Array::new_with_length(numel).into(),
binding::DataType::Float32 => js_sys::Float32Array::new_with_length(numel).into(),
binding::DataType::Float64 => js_sys::Float64Array::new_with_length(numel).into(),
binding::DataType::Int4 | binding::DataType::Uint4 | binding::DataType::Float16 | binding::DataType::String => unimplemented!(),
binding::DataType::__Invalid => unreachable!()
}
}
pub unsafe fn buffer_from_ptr(dtype: binding::DataType, ptr: *mut c_void, byte_len: usize) -> JsValue {
match dtype {
binding::DataType::Bool | binding::DataType::Uint8 => unsafe { js_sys::Uint8Array::view(slice::from_raw_parts(ptr.cast(), byte_len)) }.into(),
binding::DataType::Int8 => unsafe { js_sys::Int8Array::view(slice::from_raw_parts(ptr.cast(), byte_len)) }.into(),
binding::DataType::Uint16 => unsafe { js_sys::Uint16Array::view(slice::from_raw_parts(ptr.cast(), byte_len / 2)) }.into(),
binding::DataType::Int16 => unsafe { js_sys::Int16Array::view(slice::from_raw_parts(ptr.cast(), byte_len / 2)) }.into(),
binding::DataType::Uint32 => unsafe { js_sys::Uint32Array::view(slice::from_raw_parts(ptr.cast(), byte_len / 4)) }.into(),
binding::DataType::Int32 => unsafe { js_sys::Int32Array::view(slice::from_raw_parts(ptr.cast(), byte_len / 4)) }.into(),
binding::DataType::Uint64 => unsafe { js_sys::BigUint64Array::view(slice::from_raw_parts(ptr.cast(), byte_len / 8)) }.into(),
binding::DataType::Int64 => unsafe { js_sys::BigInt64Array::view(slice::from_raw_parts(ptr.cast(), byte_len / 8)) }.into(),
binding::DataType::Float32 => unsafe { js_sys::Float32Array::view(slice::from_raw_parts(ptr.cast(), byte_len / 4)) }.into(),
binding::DataType::Float64 => unsafe { js_sys::Float64Array::view(slice::from_raw_parts(ptr.cast(), byte_len / 8)) }.into(),
binding::DataType::Int4 | binding::DataType::Uint4 | binding::DataType::Float16 | binding::DataType::String => unimplemented!(),
binding::DataType::__Invalid => unreachable!()
}
}
pub fn dtype_to_onnx(dtype: binding::DataType) -> ort_sys::ONNXTensorElementDataType {
match dtype {
binding::DataType::String => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING,
binding::DataType::Bool => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL,
binding::DataType::Uint8 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8,
binding::DataType::Int8 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8,
binding::DataType::Uint16 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16,
binding::DataType::Int16 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16,
binding::DataType::Uint32 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32,
binding::DataType::Int32 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32,
binding::DataType::Uint64 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64,
binding::DataType::Int64 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64,
binding::DataType::Float16 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16,
binding::DataType::Float32 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT,
binding::DataType::Float64 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE,
binding::DataType::Int4 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4,
binding::DataType::Uint4 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT4,
binding::DataType::__Invalid => unreachable!()
}
}
pub fn onnx_to_dtype(dtype: ort_sys::ONNXTensorElementDataType) -> Option<binding::DataType> {
match dtype {
ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING => Some(binding::DataType::String),
ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL => Some(binding::DataType::Bool),
ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8 => Some(binding::DataType::Uint8),
ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8 => Some(binding::DataType::Int8),
ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16 => Some(binding::DataType::Uint16),
ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16 => Some(binding::DataType::Int16),
ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32 => Some(binding::DataType::Uint32),
ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32 => Some(binding::DataType::Int32),
ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64 => Some(binding::DataType::Uint64),
ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64 => Some(binding::DataType::Int64),
ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16 => Some(binding::DataType::Float16),
ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT => Some(binding::DataType::Float32),
ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE => Some(binding::DataType::Float64),
ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4 => Some(binding::DataType::Int4),
ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT4 => Some(binding::DataType::Uint4),
_ => None
}
}
/// Element type and shape of a tensor.
///
/// `TypeInfo::new_sys` heap-allocates one of these and returns it as an opaque
/// `*mut ort_sys::OrtTypeInfo`; `TypeInfo::consume_sys` turns such a pointer
/// back into a `Box<TypeInfo>`.
pub struct TypeInfo {
/// ONNX element type, as produced by `dtype_to_onnx`.
pub dtype: ort_sys::ONNXTensorElementDataType,
/// Dimension sizes. `new_sys_from_value_metadata` stores `-1` for
/// `ShapeElement::Named` dimensions.
pub shape: Vec<i32>
}
impl TypeInfo {
pub fn new_sys_from_tensor(tensor: &Tensor) -> *mut ort_sys::OrtTypeInfo {
Self::new_sys(tensor.js.dtype(), tensor.js.dims())
}
pub fn new_sys_from_value_metadata(metadata: &binding::ValueMetadata) -> *mut ort_sys::OrtTypeInfo {
Self::new_sys(
metadata.r#type.unwrap(),
metadata
.shape
.as_ref()
.unwrap()
.iter()
.map(|el| match el {
binding::ShapeElement::Value(v) => *v as i32,
binding::ShapeElement::Named(_) => -1
})
.collect()
)
}
pub fn new_sys(dtype: DataType, shape: Vec<i32>) -> *mut ort_sys::OrtTypeInfo {
(Box::leak(Box::new(Self { dtype: dtype_to_onnx(dtype), shape })) as *mut TypeInfo).cast()
}
pub unsafe fn consume_sys(ptr: *mut ort_sys::OrtTypeInfo) -> Box<TypeInfo> {
unsafe { Box::from_raw(ptr.cast::<TypeInfo>()) }
}
}
/// Selects which side of a tensor is overwritten by `sync`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SyncDirection {
/// Copy the runtime tensor's data into the Rust-side memory.
Rust,
/// Copy the Rust-side data into the runtime tensor's data buffer.
Runtime
}
/// Extension trait that adds explicit data synchronization to `ort` values.
pub trait ValueExt {
// NOTE(review): presumably seals the trait so it cannot be implemented outside
// this crate — confirm against the macro's definition.
crate::private_trait!();
/// Copies the value's tensor data in the given `direction`.
///
/// The implementation in this file returns an error when the value does not
/// start with `TENSOR_SENTINEL`, i.e. was not created by this crate.
// Silences the `async_fn_in_trait` lint for this declaration.
#[allow(async_fn_in_trait)]
async fn sync(&mut self, direction: SyncDirection) -> crate::Result<()>;
}
impl<T: ValueTypeMarker> ValueExt for ort::value::Value<T> {
crate::private_impl!();
async fn sync(&mut self, direction: SyncDirection) -> crate::Result<()> {
let ptr = self.ptr_mut();
let sentinel: [u8; 4] = unsafe { core::ptr::read(ptr.cast()) };
if sentinel != TENSOR_SENTINEL {
return Err(Error::new("Cannot synchronize Value that was not created by ort-web"));
}
let tensor: &mut Tensor = unsafe { &mut *ptr.cast() };
tensor.sync(direction).await
}
}