1use alloc::{boxed::Box, vec::Vec};
2use core::{ffi::c_void, slice};
3
4use js_sys::Uint8Array;
5use ort::{AsPointer, value::ValueTypeMarker};
6use wasm_bindgen::{JsCast, JsValue};
7
8use crate::{
9 Error,
10 binding::{self, DataType},
11 memory::MemoryInfo,
12 util::num_elements
13};
14
/// Magic bytes stored in the first field of every [`Tensor`].
///
/// They let an untyped value pointer be recognised as one created by this crate before it is
/// cast back to `Tensor`.
pub const TENSOR_SENTINEL: [u8; 4] = *b"\xFC\x86\xA5\x39";
16
/// Where the Rust-side bytes of a [`Tensor`] live.
pub enum TensorData {
	/// The JS tensor's data is a typed-array view over `byte_len` bytes of Rust memory at `ptr`.
	RustView {
		ptr: *mut c_void,
		byte_len: usize
	},
	/// The JS tensor's data is not Rust memory.
	///
	/// `buffer` holds a Rust-side copy that `Tensor::sync` fills when syncing towards Rust. It is
	/// `None` until that first happens.
	External {
		buffer: Option<Box<[u8]>>
	}
}
24
/// Backend tensor object; `ort` value pointers created by this crate point at one of these.
///
/// `#[repr(C)]` keeps the fields in declaration order, so `sentinel` is guaranteed to sit at
/// offset 0. `ValueExt::sync` relies on that: it reads the first four bytes behind a value pointer
/// and compares them with [`TENSOR_SENTINEL`] before casting to `Tensor`. Do not reorder fields.
#[repr(C)]
pub struct Tensor {
	/// Always [`TENSOR_SENTINEL`]; identifies this struct behind an untyped pointer.
	sentinel: [u8; 4],
	/// The JS-side tensor object.
	pub js: binding::Tensor,
	/// Rust-side storage backing (or mirroring) the JS tensor's data.
	pub data: TensorData,
	/// Location reported by the JS tensor when this wrapper was constructed.
	pub memory_info: MemoryInfo
}
32
33impl Tensor {
34 pub unsafe fn from_ptr(dtype: binding::DataType, ptr: *mut c_void, byte_len: usize, dims: &[i32]) -> Result<Self, JsValue> {
35 let tensor = binding::Tensor::new_from_buffer(dtype, unsafe { buffer_from_ptr(dtype, ptr, byte_len) }, dims)?;
36 Ok(Self {
37 sentinel: TENSOR_SENTINEL,
38 memory_info: MemoryInfo { location: tensor.location() },
39 js: tensor,
40 data: TensorData::RustView { ptr, byte_len }
41 })
42 }
43
44 pub fn from_tensor(tensor: binding::Tensor) -> Self {
45 Self {
46 sentinel: TENSOR_SENTINEL,
47 memory_info: MemoryInfo { location: tensor.location() },
48 js: tensor,
49 data: TensorData::External { buffer: None }
50 }
51 }
52
53 pub async fn sync(&mut self, direction: SyncDirection) -> crate::Result<()> {
54 match direction {
55 SyncDirection::Rust => {
56 let data = self.js.get_data().await?;
57
58 let generic_typed_array = Uint8Array::unchecked_from_js(data);
60 let bytes = Uint8Array::new_with_byte_offset_and_length(
61 &generic_typed_array.buffer(),
62 generic_typed_array.byte_offset(),
63 generic_typed_array.byte_length()
64 );
65 match &mut self.data {
66 TensorData::RustView { ptr, byte_len } => {
67 bytes.copy_to(unsafe { core::slice::from_raw_parts_mut(ptr.cast(), *byte_len) });
68 }
69 TensorData::External { buffer } => {
70 let buffer = match buffer {
71 Some(buffer) => buffer,
72 None => {
73 *buffer = Some(vec![0; generic_typed_array.byte_length() as usize].into_boxed_slice());
74 unsafe { buffer.as_mut().unwrap_unchecked() }
75 }
76 };
77 bytes.copy_to(buffer);
78 }
79 }
80 }
81 SyncDirection::Runtime => {
82 let Ok(generic_typed_array) = self.js.data().map(Uint8Array::unchecked_from_js) else {
83 return Err(Error::new(
85 "Cannot synchronize Rust data to a runtime tensor that is not on the CPU; modify the WebGPU/WebGL buffer directly."
86 ));
87 };
88 let bytes = Uint8Array::new_with_byte_offset_and_length(
89 &generic_typed_array.buffer(),
90 generic_typed_array.byte_offset(),
91 generic_typed_array.byte_length()
92 );
93 bytes.copy_from(match &self.data {
94 TensorData::RustView { ptr, byte_len } => unsafe { core::slice::from_raw_parts(ptr.cast(), *byte_len) },
95 TensorData::External { buffer } => {
96 let Some(buffer) = buffer else {
97 return Ok(());
98 };
99 &*buffer
100 }
101 });
102 }
103 }
104 Ok(())
105 }
106}
107
108pub fn create_buffer(dtype: binding::DataType, shape: &[i32]) -> JsValue {
109 let numel = num_elements(shape) as u32;
110 match dtype {
111 binding::DataType::Bool | binding::DataType::Uint8 => js_sys::Uint8Array::new_with_length(numel).into(),
112 binding::DataType::Int8 => js_sys::Int8Array::new_with_length(numel).into(),
113 binding::DataType::Uint16 => js_sys::Uint16Array::new_with_length(numel).into(),
114 binding::DataType::Int16 => js_sys::Int16Array::new_with_length(numel).into(),
115 binding::DataType::Uint32 => js_sys::Uint32Array::new_with_length(numel).into(),
116 binding::DataType::Int32 => js_sys::Int32Array::new_with_length(numel).into(),
117 binding::DataType::Uint64 => js_sys::BigUint64Array::new_with_length(numel).into(),
118 binding::DataType::Int64 => js_sys::BigInt64Array::new_with_length(numel).into(),
119 binding::DataType::Float32 => js_sys::Float32Array::new_with_length(numel).into(),
120 binding::DataType::Float64 => js_sys::Float64Array::new_with_length(numel).into(),
121 binding::DataType::Int4 | binding::DataType::Uint4 | binding::DataType::Float16 | binding::DataType::String => unimplemented!(),
122 binding::DataType::__Invalid => unreachable!()
123 }
124}
125
126pub unsafe fn buffer_from_ptr(dtype: binding::DataType, ptr: *mut c_void, byte_len: usize) -> JsValue {
127 match dtype {
128 binding::DataType::Bool | binding::DataType::Uint8 => unsafe { js_sys::Uint8Array::view(slice::from_raw_parts(ptr.cast(), byte_len)) }.into(),
129 binding::DataType::Int8 => unsafe { js_sys::Int8Array::view(slice::from_raw_parts(ptr.cast(), byte_len)) }.into(),
130 binding::DataType::Uint16 => unsafe { js_sys::Uint16Array::view(slice::from_raw_parts(ptr.cast(), byte_len / 2)) }.into(),
131 binding::DataType::Int16 => unsafe { js_sys::Int16Array::view(slice::from_raw_parts(ptr.cast(), byte_len / 2)) }.into(),
132 binding::DataType::Uint32 => unsafe { js_sys::Uint32Array::view(slice::from_raw_parts(ptr.cast(), byte_len / 4)) }.into(),
133 binding::DataType::Int32 => unsafe { js_sys::Int32Array::view(slice::from_raw_parts(ptr.cast(), byte_len / 4)) }.into(),
134 binding::DataType::Uint64 => unsafe { js_sys::BigUint64Array::view(slice::from_raw_parts(ptr.cast(), byte_len / 8)) }.into(),
135 binding::DataType::Int64 => unsafe { js_sys::BigInt64Array::view(slice::from_raw_parts(ptr.cast(), byte_len / 8)) }.into(),
136 binding::DataType::Float32 => unsafe { js_sys::Float32Array::view(slice::from_raw_parts(ptr.cast(), byte_len / 4)) }.into(),
137 binding::DataType::Float64 => unsafe { js_sys::Float64Array::view(slice::from_raw_parts(ptr.cast(), byte_len / 8)) }.into(),
138 binding::DataType::Int4 | binding::DataType::Uint4 | binding::DataType::Float16 | binding::DataType::String => unimplemented!(),
139 binding::DataType::__Invalid => unreachable!()
140 }
141}
142
143pub fn dtype_to_onnx(dtype: binding::DataType) -> ort_sys::ONNXTensorElementDataType {
144 match dtype {
145 binding::DataType::String => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING,
146 binding::DataType::Bool => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL,
147 binding::DataType::Uint8 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8,
148 binding::DataType::Int8 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8,
149 binding::DataType::Uint16 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16,
150 binding::DataType::Int16 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16,
151 binding::DataType::Uint32 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32,
152 binding::DataType::Int32 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32,
153 binding::DataType::Uint64 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64,
154 binding::DataType::Int64 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64,
155 binding::DataType::Float16 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16,
156 binding::DataType::Float32 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT,
157 binding::DataType::Float64 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE,
158 binding::DataType::Int4 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4,
159 binding::DataType::Uint4 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT4,
160 binding::DataType::__Invalid => unreachable!()
161 }
162}
163
164pub fn onnx_to_dtype(dtype: ort_sys::ONNXTensorElementDataType) -> Option<binding::DataType> {
165 match dtype {
166 ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING => Some(binding::DataType::String),
167 ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL => Some(binding::DataType::Bool),
168 ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8 => Some(binding::DataType::Uint8),
169 ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8 => Some(binding::DataType::Int8),
170 ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16 => Some(binding::DataType::Uint16),
171 ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16 => Some(binding::DataType::Int16),
172 ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32 => Some(binding::DataType::Uint32),
173 ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32 => Some(binding::DataType::Int32),
174 ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64 => Some(binding::DataType::Uint64),
175 ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64 => Some(binding::DataType::Int64),
176 ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16 => Some(binding::DataType::Float16),
177 ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT => Some(binding::DataType::Float32),
178 ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE => Some(binding::DataType::Float64),
179 ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4 => Some(binding::DataType::Int4),
180 ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT4 => Some(binding::DataType::Uint4),
181 _ => None
182 }
183}
184
/// Tensor type information handed out as an opaque `ort_sys::OrtTypeInfo` pointer.
///
/// Instances are heap-allocated by `TypeInfo::new_sys` (and its `new_sys_from_*` wrappers) and
/// reclaimed by `TypeInfo::consume_sys`.
pub struct TypeInfo {
	/// ONNX element type of the tensor.
	pub dtype: ort_sys::ONNXTensorElementDataType,
	/// Tensor dimensions; named (symbolic) dimensions from value metadata are stored as `-1`.
	pub shape: Vec<i32>
}
189
190impl TypeInfo {
191 pub fn new_sys_from_tensor(tensor: &Tensor) -> *mut ort_sys::OrtTypeInfo {
192 Self::new_sys(tensor.js.dtype(), tensor.js.dims())
193 }
194
195 pub fn new_sys_from_value_metadata(metadata: &binding::ValueMetadata) -> *mut ort_sys::OrtTypeInfo {
196 Self::new_sys(
197 metadata.r#type.unwrap(),
198 metadata
199 .shape
200 .as_ref()
201 .unwrap()
202 .iter()
203 .map(|el| match el {
204 binding::ShapeElement::Value(v) => *v as i32,
205 binding::ShapeElement::Named(_) => -1
206 })
207 .collect()
208 )
209 }
210
211 pub fn new_sys(dtype: DataType, shape: Vec<i32>) -> *mut ort_sys::OrtTypeInfo {
212 (Box::leak(Box::new(Self { dtype: dtype_to_onnx(dtype), shape })) as *mut TypeInfo).cast()
213 }
214
215 pub unsafe fn consume_sys(ptr: *mut ort_sys::OrtTypeInfo) -> Box<TypeInfo> {
216 unsafe { Box::from_raw(ptr.cast::<TypeInfo>()) }
217 }
218}
219
/// Which side of a tensor receives data during a sync.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SyncDirection {
	/// Copy the JS runtime tensor's data into Rust-side memory.
	Rust,
	/// Copy Rust-side memory into the JS runtime tensor's data.
	Runtime
}
227
/// Extension methods for `ort` values that are backed by this crate's [`Tensor`]s.
pub trait ValueExt {
	// NOTE(review): presumably seals the trait against outside implementations; the macro is defined elsewhere.
	crate::private_trait!();

	/// Copies this value's data in the given `direction` (see [`SyncDirection`]).
	///
	/// # Errors
	///
	/// The implementation in this file fails if the value was not created by this crate, or if
	/// the underlying tensor sync fails.
	// `async fn` in a public trait triggers `async_fn_in_trait`, because callers cannot add bounds
	// such as `Send` to the returned future; that limitation is accepted here.
	#[allow(async_fn_in_trait)]
	async fn sync(&mut self, direction: SyncDirection) -> crate::Result<()>;
}
237
238impl<T: ValueTypeMarker> ValueExt for ort::value::Value<T> {
239 crate::private_impl!();
240
241 async fn sync(&mut self, direction: SyncDirection) -> crate::Result<()> {
242 let ptr = self.ptr_mut();
243 let sentinel: [u8; 4] = unsafe { core::ptr::read(ptr.cast()) };
246 if sentinel != TENSOR_SENTINEL {
247 return Err(Error::new("Cannot synchronize Value that was not created by ort-web"));
248 }
249
250 let tensor: &mut Tensor = unsafe { &mut *ptr.cast() };
251 tensor.sync(direction).await
252 }
253}