// ort-web 0.2.1+1.24
//
// ONNX Runtime on the web 🌐 - An alternative backend for ort
// Documentation
use alloc::{string::String, vec::Vec};
use std::collections::HashMap;

use js_sys::{JsString, Object, Reflect, Uint8Array};
use serde::{Deserialize, Serialize};
use wasm_bindgen::prelude::*;

use crate::{binding::DataType, tensor::Tensor};

/// Execution mode value for [`SessionOptions::execution_mode`].
///
/// Serialized by serde as a lowercase string (`"sequential"` or `"parallel"`).
///
/// Derives `PartialEq`/`Eq` like the other option enums in this module so values can be compared directly.
#[derive(Serialize, Debug, Clone, Copy, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum ExecutionMode {
	/// Serialized as `"sequential"`.
	Sequential,
	/// Serialized as `"parallel"`.
	Parallel
}

/// Graph optimization level value for [`SessionOptions::graph_optimization_level`].
///
/// Serialized by serde as a lowercase string (`"disabled"`, `"basic"`, `"layout"`, `"extended"` or `"all"`).
///
/// Derives `PartialEq`/`Eq` like the other option enums in this module so values can be compared directly.
#[derive(Serialize, Debug, Clone, Copy, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum GraphOptimizationLevel {
	/// Serialized as `"disabled"`.
	Disabled,
	/// Serialized as `"basic"`.
	Basic,
	/// Serialized as `"layout"`.
	Layout,
	/// Serialized as `"extended"`.
	Extended,
	/// Serialized as `"all"`.
	All
}

/// Device type value for the `device_type` field of [`ExecutionProvider::WebNN`].
///
/// Serialized by serde as a lowercase string (`"cpu"`, `"gpu"` or `"npu"`).
#[derive(Serialize, Debug, Clone, Copy, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum WebNNDeviceType {
	/// Serialized as `"cpu"`.
	CPU,
	/// Serialized as `"gpu"`.
	GPU,
	/// Serialized as `"npu"`.
	NPU
}

/// Power preference value for the `power_preference` field of [`ExecutionProvider::WebNN`].
///
/// Serialized by serde as a kebab-case string (`"default"`, `"high-performance"` or `"low-power"`).
#[derive(Serialize, Debug, Clone, Copy, PartialEq, Eq)]
#[serde(rename_all = "kebab-case")]
pub enum WebNNPowerPreference {
	/// Serialized as `"default"`.
	Default,
	/// Serialized as `"high-performance"`.
	HighPerformance,
	/// Serialized as `"low-power"`.
	LowPower
}

/// Layout value for the `preferred_layout` field of [`ExecutionProvider::WebGPU`].
///
/// Serialized by serde as an upper-case string (`"NHWC"` or `"NCHW"`).
#[derive(Serialize, Debug, Clone, Copy, PartialEq, Eq)]
#[serde(rename_all = "UPPERCASE")]
pub enum WebGPUPreferredLayout {
	/// Serialized as `"NHWC"`.
	NHWC,
	/// Serialized as `"NCHW"`.
	NCHW
}

/// An entry of [`SessionOptions::execution_providers`].
///
/// Serialized as an internally tagged object: the lowercased variant name is stored under the `name` key
/// (`"wasm"`, `"webgl"`, `"webnn"`, `"webgpu"`), and the fields of struct variants are emitted next to it using
/// camelCase keys.
#[derive(Serialize, Debug, Clone)]
#[serde(tag = "name", rename_all = "lowercase")]
pub enum ExecutionProvider {
	/// Serialized with `name` set to `"wasm"`; carries no options.
	WASM,
	/// Serialized with `name` set to `"webgl"`; carries no options.
	WebGL,
	/// Serialized with `name` set to `"webnn"`; the fields become `deviceType`, `numThreads` and `powerPreference`.
	#[serde(rename_all = "camelCase")]
	WebNN {
		device_type: Option<WebNNDeviceType>,
		num_threads: Option<u32>,
		power_preference: Option<WebNNPowerPreference>
	},
	/// Serialized with `name` set to `"webgpu"`; the field becomes `preferredLayout`.
	#[serde(rename_all = "camelCase")]
	WebGPU {
		preferred_layout: Option<WebGPUPreferredLayout>
	}
}

/// Options passed to `InferenceSession::create_from_uri` / `InferenceSession::create_from_bytes`.
///
/// The struct is serialized to a JavaScript object whose keys are the camelCase forms of the field names
/// (e.g. `enable_cpu_mem_arena` becomes `enableCpuMemArena`), except where a field overrides its key explicitly.
/// Every field is optional; the default value leaves all of them unset.
#[derive(Serialize, Default, Clone)]
#[serde(rename_all = "camelCase")]
pub struct SessionOptions {
	pub enable_cpu_mem_arena: Option<bool>,
	pub enable_graph_capture: Option<bool>,
	pub enable_mem_pattern: Option<bool>,
	pub enable_profiling: Option<bool>,
	pub execution_mode: Option<ExecutionMode>,
	/// List of execution providers, serialized as an array of tagged objects (see [`ExecutionProvider`]).
	pub execution_providers: Option<Vec<ExecutionProvider>>,
	/// Free-form string key/value pairs, serialized under the `extra` key.
	pub extra: Option<HashMap<String, String>>,
	/// Map from dimension name to the size it should be overridden with.
	///
	/// The key expected by onnxruntime's JavaScript `SessionOptions` is the plural `freeDimensionOverrides`; the
	/// plain camelCase form of this field name (`freeDimensionOverride`) is not a recognized option and would be
	/// ignored, so the serialized key is set explicitly while the Rust field name is kept for compatibility.
	#[serde(rename = "freeDimensionOverrides")]
	pub free_dimension_override: Option<HashMap<String, i32>>,
	pub graph_optimization_level: Option<GraphOptimizationLevel>,
	pub inter_op_num_threads: Option<u32>,
	pub intra_op_num_threads: Option<u32>,
	pub log_id: Option<String>,
	pub log_severity_level: Option<u8>,
	pub log_verbosity_level: Option<u16>
}

impl SessionOptions {
	pub(crate) fn to_value(&self) -> Result<JsValue, serde_wasm_bindgen::Error> {
		self.serialize(&serde_wasm_bindgen::Serializer::new().serialize_maps_as_objects(true))
	}
}

// Raw wasm-bindgen imports for the JavaScript `ort.InferenceSession` class. Items with a `_raw` suffix are
// private and wrapped by the typed methods in `impl InferenceSession`.
#[wasm_bindgen]
extern "C" {
	/// Handle to a JavaScript `ort.InferenceSession` object.
	#[wasm_bindgen(js_namespace = ort)]
	pub type InferenceSession;

	// Both bindings below call the same static JS function, `ort.InferenceSession.create`; they differ only in
	// the type of the first argument. `catch` turns a thrown exception / rejected promise into `Err`.
	#[wasm_bindgen(catch, js_namespace = ort, static_method_of = InferenceSession, js_name = create)]
	async fn create_from_uri_raw(uri: &str, options: JsValue) -> Result<InferenceSession, JsValue>;
	#[wasm_bindgen(catch, js_namespace = ort, static_method_of = InferenceSession, js_name = create)]
	async fn create_from_bytes_raw(buffer: &Uint8Array, options: JsValue) -> Result<InferenceSession, JsValue>;

	/// Calls the session's JS `startProfiling()` method; a thrown exception is returned as `Err`.
	#[wasm_bindgen(catch, structural, method, js_name = startProfiling)]
	pub fn start_profiling(this: &InferenceSession) -> Result<(), JsValue>;
	/// Calls the session's JS `endProfiling()` method; a thrown exception is returned as `Err`.
	#[wasm_bindgen(catch, structural, method, js_name = endProfiling)]
	pub fn end_profiling(this: &InferenceSession) -> Result<(), JsValue>;
	/// Calls the session's JS `release()` method and awaits its result; a failure is returned as `Err`.
	#[wasm_bindgen(catch, structural, method, js_name = release)]
	pub async fn release(this: &InferenceSession) -> Result<(), JsValue>;

	// Property getters reading `inputMetadata`, `outputMetadata`, `inputNames` and `outputNames` from the
	// session object. The metadata entries are left as untyped `JsValue`s here.
	#[wasm_bindgen(structural, method, getter, js_name = inputMetadata)]
	fn input_metadata_raw(this: &InferenceSession) -> Vec<JsValue>;
	#[wasm_bindgen(structural, method, getter, js_name = outputMetadata)]
	fn output_metadata_raw(this: &InferenceSession) -> Vec<JsValue>;
	#[wasm_bindgen(structural, method, getter, js_name = inputNames)]
	fn input_names_raw(this: &InferenceSession) -> Vec<JsString>;
	#[wasm_bindgen(structural, method, getter, js_name = outputNames)]
	fn output_names_raw(this: &InferenceSession) -> Vec<JsString>;

	// Both bindings below call the session's JS `run` method, with and without the second (`fetches`) argument.
	#[wasm_bindgen(catch, structural, method, js_name = run)]
	async fn run_raw(this: &InferenceSession, feeds: JsValue) -> Result<JsValue, JsValue>;
	#[wasm_bindgen(catch, structural, method, js_name = run)]
	async fn run_with_fetches_raw(this: &InferenceSession, feeds: JsValue, fetches: JsValue) -> Result<JsValue, JsValue>;
}

impl InferenceSession {
	pub async fn create_from_uri(uri: &str, options: &SessionOptions) -> Result<InferenceSession, JsValue> {
		InferenceSession::create_from_uri_raw(uri, options.to_value()?).await
	}
	pub async fn create_from_bytes(buffer: &Uint8Array, options: &SessionOptions) -> Result<InferenceSession, JsValue> {
		InferenceSession::create_from_bytes_raw(buffer, options.to_value()?).await
	}

	pub fn input_names(&self) -> Vec<String> {
		self.input_names_raw().into_iter().map(String::from).collect()
	}
	pub fn output_names(&self) -> Vec<String> {
		self.output_names_raw().into_iter().map(String::from).collect()
	}

	pub fn input_len(&self) -> usize {
		self.input_names_raw().len()
	}
	pub fn output_len(&self) -> usize {
		self.output_names_raw().len()
	}

	pub fn input_metadata(&self) -> Vec<ValueMetadata> {
		self.input_metadata_raw()
			.into_iter()
			.map(|x| serde_wasm_bindgen::from_value(x))
			.collect::<Result<Vec<_>, serde_wasm_bindgen::Error>>()
			.unwrap()
	}

	pub fn output_metadata(&self) -> Vec<ValueMetadata> {
		self.output_metadata_raw()
			.into_iter()
			.map(|x| serde_wasm_bindgen::from_value(x))
			.collect::<Result<Vec<_>, serde_wasm_bindgen::Error>>()
			.unwrap()
	}

	pub async fn run(&self, feeds: impl Iterator<Item = (&str, &Tensor)>) -> Result<Vec<(String, Tensor)>, JsValue> {
		let feeds_value = Object::new();
		for (name, tensor) in feeds {
			Reflect::set(&feeds_value, &JsValue::from_str(name), &tensor.js)?;
		}
		Self::to_outputs(self.run_raw(feeds_value.into()).await?)
	}

	pub async fn run_with_fetches(
		&self,
		feeds: impl Iterator<Item = (&str, &Tensor)>,
		fetches: impl Iterator<Item = (&str, Option<&Tensor>)>
	) -> Result<Vec<(String, Tensor)>, JsValue> {
		let feeds_value = Object::new();
		for (name, tensor) in feeds {
			Reflect::set(&feeds_value, &JsValue::from_str(name), &tensor.js)?;
		}
		let fetches_value = Object::new();
		for (name, tensor) in fetches {
			let null = JsValue::null();
			Reflect::set(
				&fetches_value,
				&JsValue::from_str(name),
				match tensor {
					Some(tensor) => &tensor.js,
					None => &null
				}
			)?;
		}
		Self::to_outputs(self.run_with_fetches_raw(feeds_value.into(), fetches_value.into()).await?)
	}

	fn to_outputs(value: JsValue) -> Result<Vec<(String, Tensor)>, JsValue> {
		Ok(Reflect::own_keys(&value)?
			.to_vec()
			.into_iter()
			.filter_map(|c| {
				c.dyn_ref::<JsString>().map(String::from).and_then(|k| {
					Reflect::get(&value, &c)
						.map(super::Tensor::unchecked_from_js)
						.ok()
						.map(|v| (k, Tensor::from_tensor(v)))
				})
			})
			.collect())
	}
}

/// A single entry of [`ValueMetadata::shape`].
///
/// Deserialized untagged: serde tries the variants in declaration order, so a string value becomes
/// [`ShapeElement::Named`] and an integer value becomes [`ShapeElement::Value`].
#[derive(Deserialize, Debug)]
#[serde(untagged)]
pub enum ShapeElement {
	/// A dimension given as a string.
	Named(String),
	/// A dimension given as an integer.
	Value(i64)
}

/// Metadata describing one session input or output, as returned by `InferenceSession::input_metadata` and
/// `InferenceSession::output_metadata`.
///
/// Deserialized from a JavaScript object using camelCase keys (`isTensor`, `name`, `shape`, `type`).
#[derive(Deserialize, Debug)]
#[serde(rename_all = "camelCase")]
pub struct ValueMetadata {
	/// Read from the `isTensor` key.
	pub is_tensor: bool,
	/// Read from the `name` key.
	pub name: String,
	/// Read from the `shape` key; `None` when the key is absent.
	pub shape: Option<Vec<ShapeElement>>,
	/// Read from the `type` key; `None` when the key is absent.
	pub r#type: Option<DataType>
}