ort 2.0.0-rc.12

A safe Rust wrapper for ONNX Runtime 1.24 - Optimize and accelerate machine learning inference & training
Documentation
use alloc::string::ToString;

use super::{ExecutionProvider, ExecutionProviderOptions, RegisterError};
use crate::{AsPointer, error::Result, ortsys, session::builder::SessionBuilder};

/// Power-usage hint for the WebNN execution provider.
///
/// The string returned by [`PowerPreference::as_str`] is what gets stored as the
/// `powerPreference` provider option (see `WebNN::with_power_preference`).
/// [`PowerPreference::Default`] is this type's [`Default`] value.
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PowerPreference {
	/// Serialized as `"default"`; also the [`Default`] value of this type.
	#[default]
	Default,
	/// Serialized as `"high-performance"`.
	HighPerformance,
	/// Serialized as `"low-power"`.
	LowPower
}

impl PowerPreference {
	/// Returns the option-value string for this preference.
	#[must_use]
	pub fn as_str(&self) -> &'static str {
		match self {
			PowerPreference::Default => "default",
			PowerPreference::HighPerformance => "high-performance",
			PowerPreference::LowPower => "low-power"
		}
	}
}

/// Device class requested from the WebNN execution provider.
///
/// The string returned by [`DeviceType::as_str`] is what gets stored as the
/// `deviceType` provider option (see `WebNN::with_device_type`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DeviceType {
	/// Serialized as `"cpu"`.
	CPU,
	/// Serialized as `"gpu"`.
	GPU,
	/// Serialized as `"npu"`.
	NPU
}

impl DeviceType {
	/// Returns the option-value string for this device type.
	#[must_use]
	pub fn as_str(&self) -> &'static str {
		match self {
			DeviceType::CPU => "cpu",
			DeviceType::GPU => "gpu",
			DeviceType::NPU => "npu"
		}
	}
}

/// Configuration for the WebNN execution provider.
///
/// Construct with `WebNN::default()` and chain the `with_*` methods; each one
/// records a single provider option and hands back the updated value.
#[derive(Clone, Debug, Default)]
pub struct WebNN {
	/// Key/value provider options accumulated by the builder methods; they are
	/// converted and passed to ONNX Runtime when the provider is registered.
	options: ExecutionProviderOptions
}

impl WebNN {
	/// Records `device_type` (in its [`DeviceType::as_str`] form) as the `deviceType` option.
	#[must_use]
	pub fn with_device_type(mut self, device_type: DeviceType) -> Self {
		let value = device_type.as_str();
		self.options.set("deviceType", value);
		self
	}

	/// Records `pref` (in its [`PowerPreference::as_str`] form) as the `powerPreference` option.
	#[must_use]
	pub fn with_power_preference(mut self, pref: PowerPreference) -> Self {
		let value = pref.as_str();
		self.options.set("powerPreference", value);
		self
	}

	/// Records `threads`, formatted as a decimal string, as the `numThreads` option.
	#[must_use]
	pub fn with_threads(mut self, threads: u32) -> Self {
		let value = threads.to_string();
		self.options.set("numThreads", value);
		self
	}
}

// Invokes the parent module's `impl_ep!` macro for `WebNN`; presumably this
// generates boilerplate shared by all execution providers.
// NOTE(review): the meaning of the `arbitrary` mode is defined by that macro
// (not visible here) — confirm there.
super::impl_ep!(arbitrary; WebNN);

impl ExecutionProvider for WebNN {
	/// Returns this provider's identifier, `"WebNNExecutionProvider"`.
	fn name(&self) -> &'static str {
		"WebNNExecutionProvider"
	}

	/// Reports support only when compiling for the `wasm32` target architecture.
	fn supported_by_platform(&self) -> bool {
		cfg!(target_arch = "wasm32")
	}

	/// Appends the `"WebNN"` provider, together with every option recorded by the
	/// builder methods, to the session options behind `session_builder`.
	///
	/// # Errors
	/// Returns an error if the ONNX Runtime `SessionOptionsAppendExecutionProvider`
	/// call fails (propagated with `?`).
	// NOTE(review): the reason for `allow(unused)` is not visible here — presumably
	// some build configurations leave a binding unused; confirm.
	#[allow(unused)]
	fn register(&self, session_builder: &mut SessionBuilder) -> Result<(), RegisterError> {
		// A `c"..."` literal is 'static, so this pointer outlives the FFI call.
		let provider_name = c"WebNN".as_ptr().cast::<core::ffi::c_char>();
		// Keep the converted options alive until the call below has returned,
		// since only pointers into them are passed across FFI.
		let converted = self.options.to_ffi();
		let sys_options = session_builder.ptr_mut();
		ortsys![unsafe SessionOptionsAppendExecutionProvider(
			sys_options,
			provider_name,
			converted.key_ptrs(),
			converted.value_ptrs(),
			converted.len(),
		)?];
		Ok(())
	}
}