ort 2.0.0-rc.12

A safe Rust wrapper for ONNX Runtime 1.24 - Optimize and accelerate machine learning inference & training
Documentation
use alloc::string::{String, ToString};

use super::{ExecutionProvider, ExecutionProviderOptions, RegisterError};
use crate::{error::Result, session::builder::SessionBuilder};

/// Value for the WebGPU execution provider's `preferredLayout` option
/// (see [`WebGPU::with_preferred_layout`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PreferredLayout {
	/// Passed to the provider as the string `"NCHW"`.
	NCHW,
	/// Passed to the provider as the string `"NHWC"`.
	NHWC
}

impl PreferredLayout {
	/// Returns the exact option-value string this variant is serialized as.
	pub(crate) fn as_str(&self) -> &'static str {
		match *self {
			Self::NCHW => "NCHW",
			Self::NHWC => "NHWC"
		}
	}
}

/// Value for the WebGPU execution provider's `dawnBackendType` option
/// (see [`WebGPU::with_dawn_backend_type`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DawnBackendType {
	/// Passed to the provider as the string `"Vulkan"`.
	Vulkan,
	/// Passed to the provider as the string `"D3D12"`.
	D3D12
}

impl DawnBackendType {
	/// Returns the exact option-value string this variant is serialized as.
	pub(crate) fn as_str(&self) -> &'static str {
		match *self {
			Self::Vulkan => "Vulkan",
			Self::D3D12 => "D3D12"
		}
	}
}

/// Cache mode accepted by the WebGPU execution provider's `storageBufferCacheMode`,
/// `uniformBufferCacheMode`, `queryResolveBufferCacheMode` and `defaultBufferCacheMode` options.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BufferCacheMode {
	/// Passed to the provider as the string `"disabled"`.
	Disabled,
	/// Passed to the provider as the string `"lazyRelease"`.
	LazyRelease,
	/// Passed to the provider as the string `"simple"`.
	Simple,
	/// Passed to the provider as the string `"bucket"`.
	Bucket
}

impl BufferCacheMode {
	/// Returns the exact option-value string this variant is serialized as.
	pub(crate) fn as_str(&self) -> &'static str {
		match *self {
			Self::Disabled => "disabled",
			Self::LazyRelease => "lazyRelease",
			Self::Simple => "simple",
			Self::Bucket => "bucket"
		}
	}
}

/// Value for the WebGPU execution provider's `validationMode` option
/// (see [`WebGPU::with_validation_mode`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ValidationMode {
	/// Passed to the provider as the string `"disabled"`.
	Disabled,
	/// Passed to the provider as the string `"wgpuOnly"`.
	WgpuOnly,
	/// Passed to the provider as the string `"basic"`.
	Basic,
	/// Passed to the provider as the string `"full"`.
	Full
}

impl ValidationMode {
	/// Returns the exact option-value string this variant is serialized as.
	#[must_use]
	pub fn as_str(&self) -> &'static str {
		match *self {
			Self::Disabled => "disabled",
			Self::WgpuOnly => "wgpuOnly",
			Self::Basic => "basic",
			Self::Full => "full"
		}
	}
}

/// Configuration for the WebGPU execution provider.
///
/// Create one with `WebGPU::default()`, then chain the `with_*` builder methods; each of them
/// records one `ep.webgpuexecutionprovider.*` key/value pair. The recorded options are handed to
/// ONNX Runtime when the provider is registered on a session builder.
#[derive(Debug, Default, Clone)]
pub struct WebGPU {
	// Key/value provider options written by the `with_*` methods and converted with `to_ffi()`
	// at registration time.
	options: ExecutionProviderOptions
}

// NOTE(review): `impl_ep!` is defined in the parent module and its expansion is not visible here;
// presumably it generates the shared execution-provider boilerplate for `WebGPU` — confirm there.
super::impl_ep!(arbitrary; WebGPU);

impl WebGPU {
	/// Sets the `ep.webgpuexecutionprovider.preferredLayout` option to the string form of `layout`.
	#[must_use]
	pub fn with_preferred_layout(mut self, layout: PreferredLayout) -> Self {
		let value = layout.as_str();
		self.options.set("ep.webgpuexecutionprovider.preferredLayout", value);
		self
	}

	/// Sets the `ep.webgpuexecutionprovider.enableGraphCapture` option to `"1"` when `enable` is
	/// `true` and to `"0"` otherwise.
	#[must_use]
	pub fn with_enable_graph_capture(mut self, enable: bool) -> Self {
		let flag = if enable { "1" } else { "0" };
		self.options.set("ep.webgpuexecutionprovider.enableGraphCapture", flag);
		self
	}

	/// Sets the `ep.webgpuexecutionprovider.dawnProcTable` option to `table`, passed through verbatim.
	#[must_use]
	pub fn with_dawn_proc_table(mut self, table: String) -> Self {
		let value = table;
		self.options.set("ep.webgpuexecutionprovider.dawnProcTable", value);
		self
	}

	/// Sets the `ep.webgpuexecutionprovider.dawnBackendType` option to the string form of
	/// `backend_type`.
	#[must_use]
	pub fn with_dawn_backend_type(mut self, backend_type: DawnBackendType) -> Self {
		let value = backend_type.as_str();
		self.options.set("ep.webgpuexecutionprovider.dawnBackendType", value);
		self
	}

	/// Sets the `ep.webgpuexecutionprovider.deviceId` option to the decimal representation of `id`.
	#[must_use]
	pub fn with_device_id(mut self, id: i32) -> Self {
		let value = id.to_string();
		self.options.set("ep.webgpuexecutionprovider.deviceId", value);
		self
	}

	/// Sets the `ep.webgpuexecutionprovider.storageBufferCacheMode` option to the string form of
	/// `mode`.
	#[must_use]
	pub fn with_storage_buffer_cache_mode(mut self, mode: BufferCacheMode) -> Self {
		let value = mode.as_str();
		self.options.set("ep.webgpuexecutionprovider.storageBufferCacheMode", value);
		self
	}

	/// Sets the `ep.webgpuexecutionprovider.uniformBufferCacheMode` option to the string form of
	/// `mode`.
	#[must_use]
	pub fn with_uniform_buffer_cache_mode(mut self, mode: BufferCacheMode) -> Self {
		let value = mode.as_str();
		self.options.set("ep.webgpuexecutionprovider.uniformBufferCacheMode", value);
		self
	}

	/// Sets the `ep.webgpuexecutionprovider.queryResolveBufferCacheMode` option to the string form
	/// of `mode`.
	#[must_use]
	pub fn with_query_resolve_buffer_cache_mode(mut self, mode: BufferCacheMode) -> Self {
		let value = mode.as_str();
		self.options.set("ep.webgpuexecutionprovider.queryResolveBufferCacheMode", value);
		self
	}

	/// Sets the `ep.webgpuexecutionprovider.defaultBufferCacheMode` option to the string form of
	/// `mode`.
	#[must_use]
	pub fn with_default_buffer_cache_mode(mut self, mode: BufferCacheMode) -> Self {
		let value = mode.as_str();
		self.options.set("ep.webgpuexecutionprovider.defaultBufferCacheMode", value);
		self
	}

	/// Sets the `ep.webgpuexecutionprovider.validationMode` option to the string form of `mode`.
	#[must_use]
	pub fn with_validation_mode(mut self, mode: ValidationMode) -> Self {
		let value = mode.as_str();
		self.options.set("ep.webgpuexecutionprovider.validationMode", value);
		self
	}

	/// Sets the `ep.webgpuexecutionprovider.forceCpuNodeNames` option to `names`, passed through
	/// verbatim.
	#[must_use]
	pub fn with_force_cpu_node_names(mut self, names: String) -> Self {
		let value = names;
		self.options.set("ep.webgpuexecutionprovider.forceCpuNodeNames", value);
		self
	}

	/// Sets the `ep.webgpuexecutionprovider.enablePIXCapture` option to `"1"` when `enable` is
	/// `true` and to `"0"` otherwise.
	#[must_use]
	pub fn with_enable_pix_capture(mut self, enable: bool) -> Self {
		let flag = if enable { "1" } else { "0" };
		self.options.set("ep.webgpuexecutionprovider.enablePIXCapture", flag);
		self
	}
}

impl ExecutionProvider for WebGPU {
	/// Returns this provider's name string, `"WebGpuExecutionProvider"`.
	fn name(&self) -> &'static str {
		"WebGpuExecutionProvider"
	}

	/// Reports whether the compilation target is one this provider is declared for.
	///
	/// Decided at compile time via `cfg!`: `true` when the target OS is Windows or Linux, or the
	/// target architecture is `wasm32`; `false` on every other target.
	fn supported_by_platform(&self) -> bool {
		cfg!(any(target_os = "windows", target_os = "linux", target_arch = "wasm32"))
	}

	/// Appends the WebGPU execution provider, with the options recorded on `self`, to
	/// `session_builder`.
	///
	/// # Errors
	///
	/// Returns [`RegisterError::MissingFeature`] when the crate is built for a non-`wasm32` target
	/// with neither the `load-dynamic` nor the `webgpu` feature enabled. Otherwise, a failure of the
	/// `SessionOptionsAppendExecutionProvider` call is propagated by the `?` inside `ortsys!`.
	// The lints are silenced because exactly one of the two paths below survives cfg-stripping:
	// with the block compiled in, the trailing `Err(..)` is unreachable; with it compiled out,
	// `session_builder` is unused.
	#[allow(unused, unreachable_code)]
	fn register(&self, session_builder: &mut SessionBuilder) -> Result<(), RegisterError> {
		#[cfg(any(target_arch = "wasm32", feature = "load-dynamic", feature = "webgpu"))]
		{
			use crate::{AsPointer, ortsys};

			// Keep the converted options in a local so they stay alive for the whole FFI call.
			// NOTE(review): presumably `key_ptrs()`/`value_ptrs()` borrow from `ffi_options` — confirm.
			let ffi_options = self.options.to_ffi();
			ortsys![unsafe SessionOptionsAppendExecutionProvider(
				session_builder.ptr_mut(),
				c"WebGPU".as_ptr().cast::<core::ffi::c_char>(), // registered as "WebGPU", unlike `name()` which returns "WebGpuExecutionProvider"
				ffi_options.key_ptrs(),
				ffi_options.value_ptrs(),
				ffi_options.len(),
			)?];
			return Ok(());
		}

		// Reached only when the cfg-gated block above is compiled out.
		Err(RegisterError::MissingFeature)
	}
}