ort_web/
tensor.rs

1use alloc::{boxed::Box, vec::Vec};
2use core::{ffi::c_void, slice};
3
4use js_sys::Uint8Array;
5use ort::{AsPointer, value::ValueTypeMarker};
6use wasm_bindgen::{JsCast, JsValue};
7
8use crate::{
9	Error,
10	binding::{self, DataType},
11	memory::MemoryInfo,
12	util::num_elements
13};
14
/// Magic bytes stored at the very start of every ort-web [`Tensor`].
///
/// Code that only holds an opaque value pointer compares its first four bytes against this constant to decide whether
/// the pointee is an ort-web tensor.
pub const TENSOR_SENTINEL: [u8; 4] = *b"\xFC\x86\xA5\x39";
16
/// Where the bytes backing a [`Tensor`] live, from Rust's point of view.
pub enum TensorData {
	/// The data is `byte_len` bytes of WASM linear memory starting at `ptr`, so Rust can access it right away.
	RustView { ptr: *mut c_void, byte_len: usize },
	/// The data lives outside WASM linear memory (for example a session output, or a tensor built from anything other
	/// than a Rust slice), so it has to be fetched before Rust can read it.
	///
	/// `buffer` is `None` until a Rust-side copy has been fetched, and holds that copy afterwards.
	External { buffer: Option<Box<[u8]>> }
}
24
25#[repr(C)]
26pub struct Tensor {
27	sentinel: [u8; 4],
28	pub js: binding::Tensor,
29	pub data: TensorData,
30	pub memory_info: MemoryInfo
31}
32
33impl Tensor {
34	pub unsafe fn from_ptr(dtype: binding::DataType, ptr: *mut c_void, byte_len: usize, dims: &[i32]) -> Result<Self, JsValue> {
35		let tensor = binding::Tensor::new_from_buffer(dtype, unsafe { buffer_from_ptr(dtype, ptr, byte_len) }, dims)?;
36		Ok(Self {
37			sentinel: TENSOR_SENTINEL,
38			memory_info: MemoryInfo { location: tensor.location() },
39			js: tensor,
40			data: TensorData::RustView { ptr, byte_len }
41		})
42	}
43
44	pub fn from_tensor(tensor: binding::Tensor) -> Self {
45		Self {
46			sentinel: TENSOR_SENTINEL,
47			memory_info: MemoryInfo { location: tensor.location() },
48			js: tensor,
49			data: TensorData::External { buffer: None }
50		}
51	}
52
53	pub async fn sync(&mut self, direction: SyncDirection) -> crate::Result<()> {
54		match direction {
55			SyncDirection::Rust => {
56				let data = self.js.get_data().await?;
57
58				// cast to some kind of typed array first, then convert to uint8array so we can properly copy
59				let generic_typed_array = Uint8Array::unchecked_from_js(data);
60				let bytes = Uint8Array::new_with_byte_offset_and_length(
61					&generic_typed_array.buffer(),
62					generic_typed_array.byte_offset(),
63					generic_typed_array.byte_length()
64				);
65				match &mut self.data {
66					TensorData::RustView { ptr, byte_len } => {
67						bytes.copy_to(unsafe { core::slice::from_raw_parts_mut(ptr.cast(), *byte_len) });
68					}
69					TensorData::External { buffer } => {
70						let buffer = match buffer {
71							Some(buffer) => buffer,
72							None => {
73								*buffer = Some(vec![0; generic_typed_array.byte_length() as usize].into_boxed_slice());
74								unsafe { buffer.as_mut().unwrap_unchecked() }
75							}
76						};
77						bytes.copy_to(buffer);
78					}
79				}
80			}
81			SyncDirection::Runtime => {
82				let Ok(generic_typed_array) = self.js.data().map(Uint8Array::unchecked_from_js) else {
83					// we have a download function, but no upload...
84					return Err(Error::new(
85						"Cannot synchronize Rust data to a runtime tensor that is not on the CPU; modify the WebGPU/WebGL buffer directly."
86					));
87				};
88				let bytes = Uint8Array::new_with_byte_offset_and_length(
89					&generic_typed_array.buffer(),
90					generic_typed_array.byte_offset(),
91					generic_typed_array.byte_length()
92				);
93				bytes.copy_from(match &self.data {
94					TensorData::RustView { ptr, byte_len } => unsafe { core::slice::from_raw_parts(ptr.cast(), *byte_len) },
95					TensorData::External { buffer } => {
96						let Some(buffer) = buffer else {
97							return Ok(());
98						};
99						&*buffer
100					}
101				});
102			}
103		}
104		Ok(())
105	}
106}
107
108pub fn create_buffer(dtype: binding::DataType, shape: &[i32]) -> JsValue {
109	let numel = num_elements(shape) as u32;
110	match dtype {
111		binding::DataType::Bool | binding::DataType::Uint8 => js_sys::Uint8Array::new_with_length(numel).into(),
112		binding::DataType::Int8 => js_sys::Int8Array::new_with_length(numel).into(),
113		binding::DataType::Uint16 => js_sys::Uint16Array::new_with_length(numel).into(),
114		binding::DataType::Int16 => js_sys::Int16Array::new_with_length(numel).into(),
115		binding::DataType::Uint32 => js_sys::Uint32Array::new_with_length(numel).into(),
116		binding::DataType::Int32 => js_sys::Int32Array::new_with_length(numel).into(),
117		binding::DataType::Uint64 => js_sys::BigUint64Array::new_with_length(numel).into(),
118		binding::DataType::Int64 => js_sys::BigInt64Array::new_with_length(numel).into(),
119		binding::DataType::Float32 => js_sys::Float32Array::new_with_length(numel).into(),
120		binding::DataType::Float64 => js_sys::Float64Array::new_with_length(numel).into(),
121		binding::DataType::Int4 | binding::DataType::Uint4 | binding::DataType::Float16 | binding::DataType::String => unimplemented!(),
122		binding::DataType::__Invalid => unreachable!()
123	}
124}
125
126pub unsafe fn buffer_from_ptr(dtype: binding::DataType, ptr: *mut c_void, byte_len: usize) -> JsValue {
127	match dtype {
128		binding::DataType::Bool | binding::DataType::Uint8 => unsafe { js_sys::Uint8Array::view(slice::from_raw_parts(ptr.cast(), byte_len)) }.into(),
129		binding::DataType::Int8 => unsafe { js_sys::Int8Array::view(slice::from_raw_parts(ptr.cast(), byte_len)) }.into(),
130		binding::DataType::Uint16 => unsafe { js_sys::Uint16Array::view(slice::from_raw_parts(ptr.cast(), byte_len / 2)) }.into(),
131		binding::DataType::Int16 => unsafe { js_sys::Int16Array::view(slice::from_raw_parts(ptr.cast(), byte_len / 2)) }.into(),
132		binding::DataType::Uint32 => unsafe { js_sys::Uint32Array::view(slice::from_raw_parts(ptr.cast(), byte_len / 4)) }.into(),
133		binding::DataType::Int32 => unsafe { js_sys::Int32Array::view(slice::from_raw_parts(ptr.cast(), byte_len / 4)) }.into(),
134		binding::DataType::Uint64 => unsafe { js_sys::BigUint64Array::view(slice::from_raw_parts(ptr.cast(), byte_len / 8)) }.into(),
135		binding::DataType::Int64 => unsafe { js_sys::BigInt64Array::view(slice::from_raw_parts(ptr.cast(), byte_len / 8)) }.into(),
136		binding::DataType::Float32 => unsafe { js_sys::Float32Array::view(slice::from_raw_parts(ptr.cast(), byte_len / 4)) }.into(),
137		binding::DataType::Float64 => unsafe { js_sys::Float64Array::view(slice::from_raw_parts(ptr.cast(), byte_len / 8)) }.into(),
138		binding::DataType::Int4 | binding::DataType::Uint4 | binding::DataType::Float16 | binding::DataType::String => unimplemented!(),
139		binding::DataType::__Invalid => unreachable!()
140	}
141}
142
143pub fn dtype_to_onnx(dtype: binding::DataType) -> ort_sys::ONNXTensorElementDataType {
144	match dtype {
145		binding::DataType::String => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING,
146		binding::DataType::Bool => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL,
147		binding::DataType::Uint8 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8,
148		binding::DataType::Int8 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8,
149		binding::DataType::Uint16 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16,
150		binding::DataType::Int16 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16,
151		binding::DataType::Uint32 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32,
152		binding::DataType::Int32 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32,
153		binding::DataType::Uint64 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64,
154		binding::DataType::Int64 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64,
155		binding::DataType::Float16 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16,
156		binding::DataType::Float32 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT,
157		binding::DataType::Float64 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE,
158		binding::DataType::Int4 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4,
159		binding::DataType::Uint4 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT4,
160		binding::DataType::__Invalid => unreachable!()
161	}
162}
163
164pub fn onnx_to_dtype(dtype: ort_sys::ONNXTensorElementDataType) -> Option<binding::DataType> {
165	match dtype {
166		ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING => Some(binding::DataType::String),
167		ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL => Some(binding::DataType::Bool),
168		ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8 => Some(binding::DataType::Uint8),
169		ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8 => Some(binding::DataType::Int8),
170		ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16 => Some(binding::DataType::Uint16),
171		ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16 => Some(binding::DataType::Int16),
172		ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32 => Some(binding::DataType::Uint32),
173		ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32 => Some(binding::DataType::Int32),
174		ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64 => Some(binding::DataType::Uint64),
175		ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64 => Some(binding::DataType::Int64),
176		ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16 => Some(binding::DataType::Float16),
177		ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT => Some(binding::DataType::Float32),
178		ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE => Some(binding::DataType::Float64),
179		ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4 => Some(binding::DataType::Int4),
180		ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT4 => Some(binding::DataType::Uint4),
181		_ => None
182	}
183}
184
185pub struct TypeInfo {
186	pub dtype: ort_sys::ONNXTensorElementDataType,
187	pub shape: Vec<i32>
188}
189
190impl TypeInfo {
191	pub fn new_sys_from_tensor(tensor: &Tensor) -> *mut ort_sys::OrtTypeInfo {
192		Self::new_sys(tensor.js.dtype(), tensor.js.dims())
193	}
194
195	pub fn new_sys_from_value_metadata(metadata: &binding::ValueMetadata) -> *mut ort_sys::OrtTypeInfo {
196		Self::new_sys(
197			metadata.r#type.unwrap(),
198			metadata
199				.shape
200				.as_ref()
201				.unwrap()
202				.iter()
203				.map(|el| match el {
204					binding::ShapeElement::Value(v) => *v as i32,
205					binding::ShapeElement::Named(_) => -1
206				})
207				.collect()
208		)
209	}
210
211	pub fn new_sys(dtype: DataType, shape: Vec<i32>) -> *mut ort_sys::OrtTypeInfo {
212		(Box::leak(Box::new(Self { dtype: dtype_to_onnx(dtype), shape })) as *mut TypeInfo).cast()
213	}
214
215	pub unsafe fn consume_sys(ptr: *mut ort_sys::OrtTypeInfo) -> Box<TypeInfo> {
216		unsafe { Box::from_raw(ptr.cast::<TypeInfo>()) }
217	}
218}
219
/// Which side of a tensor receives the other side's data during synchronization.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SyncDirection {
	/// Pull data from the device/runtime into memory Rust code can read.
	Rust,
	/// Push data written by Rust code into the device/runtime so the runtime sees it.
	Runtime
}
227
228pub trait ValueExt {
229	crate::private_trait!();
230
231	/// Synchronize data between Rust & the runtime.
232	///
233	/// See the [top-level documentation][crate] for more information on synchronization.
234	#[allow(async_fn_in_trait)]
235	async fn sync(&mut self, direction: SyncDirection) -> crate::Result<()>;
236}
237
238impl<T: ValueTypeMarker> ValueExt for ort::value::Value<T> {
239	crate::private_impl!();
240
241	async fn sync(&mut self, direction: SyncDirection) -> crate::Result<()> {
242		let ptr = self.ptr_mut();
243		// definitely safe regardless of what backend is used since it's highly improbable that a backend's tensor would be
244		// smaller than 4 bytes (which is pointer size on wasm32)
245		let sentinel: [u8; 4] = unsafe { core::ptr::read(ptr.cast()) };
246		if sentinel != TENSOR_SENTINEL {
247			return Err(Error::new("Cannot synchronize Value that was not created by ort-web"));
248		}
249
250		let tensor: &mut Tensor = unsafe { &mut *ptr.cast() };
251		tensor.sync(direction).await
252	}
253}