Skip to main content

ort/ep/
webgpu.rs

1use alloc::string::{String, ToString};
2
3use super::{ExecutionProvider, ExecutionProviderOptions, RegisterError};
4use crate::{error::Result, session::builder::SessionBuilder};
5
/// Value for the WebGPU execution provider's `preferredLayout` option.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PreferredLayout {
	/// Sent to the provider as the string `"NCHW"`.
	NCHW,
	/// Sent to the provider as the string `"NHWC"`.
	NHWC
}

impl PreferredLayout {
	/// Returns the exact option string that represents this layout.
	pub(crate) fn as_str(&self) -> &'static str {
		match *self {
			Self::NCHW => "NCHW",
			Self::NHWC => "NHWC"
		}
	}
}
20
/// Value for the WebGPU execution provider's `dawnBackendType` option.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DawnBackendType {
	/// Sent to the provider as the string `"Vulkan"`.
	Vulkan,
	/// Sent to the provider as the string `"D3D12"`.
	D3D12
}

impl DawnBackendType {
	/// Returns the exact option string that represents this backend.
	pub(crate) fn as_str(&self) -> &'static str {
		match *self {
			Self::Vulkan => "Vulkan",
			Self::D3D12 => "D3D12"
		}
	}
}
35
/// Value for the WebGPU execution provider's `*BufferCacheMode` options
/// (storage, uniform, query-resolve and default buffers each take one).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BufferCacheMode {
	/// Sent to the provider as the string `"disabled"`.
	Disabled,
	/// Sent to the provider as the string `"lazyRelease"`.
	LazyRelease,
	/// Sent to the provider as the string `"simple"`.
	Simple,
	/// Sent to the provider as the string `"bucket"`.
	Bucket
}

impl BufferCacheMode {
	/// Returns the exact option string that represents this cache mode.
	pub(crate) fn as_str(&self) -> &'static str {
		match *self {
			Self::Disabled => "disabled",
			Self::LazyRelease => "lazyRelease",
			Self::Simple => "simple",
			Self::Bucket => "bucket"
		}
	}
}
54
/// Value for the WebGPU execution provider's `validationMode` option.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ValidationMode {
	/// Sent to the provider as the string `"disabled"`.
	Disabled,
	/// Sent to the provider as the string `"wgpuOnly"`.
	WgpuOnly,
	/// Sent to the provider as the string `"basic"`.
	Basic,
	/// Sent to the provider as the string `"full"`.
	Full
}

impl ValidationMode {
	/// Returns the exact option string that represents this validation mode.
	#[must_use]
	pub fn as_str(&self) -> &'static str {
		match *self {
			Self::Disabled => "disabled",
			Self::WgpuOnly => "wgpuOnly",
			Self::Basic => "basic",
			Self::Full => "full"
		}
	}
}
74
75#[derive(Debug, Default, Clone)]
76pub struct WebGPU {
77	options: ExecutionProviderOptions
78}
79
80super::impl_ep!(arbitrary; WebGPU);
81
82impl WebGPU {
83	#[must_use]
84	pub fn with_preferred_layout(mut self, layout: PreferredLayout) -> Self {
85		self.options.set("ep.webgpuexecutionprovider.preferredLayout", layout.as_str());
86		self
87	}
88
89	#[must_use]
90	pub fn with_enable_graph_capture(mut self, enable: bool) -> Self {
91		self.options
92			.set("ep.webgpuexecutionprovider.enableGraphCapture", if enable { "1" } else { "0" });
93		self
94	}
95
96	#[must_use]
97	pub fn with_dawn_proc_table(mut self, table: String) -> Self {
98		self.options.set("ep.webgpuexecutionprovider.dawnProcTable", table);
99		self
100	}
101
102	#[must_use]
103	pub fn with_dawn_backend_type(mut self, backend_type: DawnBackendType) -> Self {
104		self.options.set("ep.webgpuexecutionprovider.dawnBackendType", backend_type.as_str());
105		self
106	}
107
108	#[must_use]
109	pub fn with_device_id(mut self, id: i32) -> Self {
110		self.options.set("ep.webgpuexecutionprovider.deviceId", id.to_string());
111		self
112	}
113
114	#[must_use]
115	pub fn with_storage_buffer_cache_mode(mut self, mode: BufferCacheMode) -> Self {
116		self.options.set("ep.webgpuexecutionprovider.storageBufferCacheMode", mode.as_str());
117		self
118	}
119
120	#[must_use]
121	pub fn with_uniform_buffer_cache_mode(mut self, mode: BufferCacheMode) -> Self {
122		self.options.set("ep.webgpuexecutionprovider.uniformBufferCacheMode", mode.as_str());
123		self
124	}
125
126	#[must_use]
127	pub fn with_query_resolve_buffer_cache_mode(mut self, mode: BufferCacheMode) -> Self {
128		self.options.set("ep.webgpuexecutionprovider.queryResolveBufferCacheMode", mode.as_str());
129		self
130	}
131
132	#[must_use]
133	pub fn with_default_buffer_cache_mode(mut self, mode: BufferCacheMode) -> Self {
134		self.options.set("ep.webgpuexecutionprovider.defaultBufferCacheMode", mode.as_str());
135		self
136	}
137
138	#[must_use]
139	pub fn with_validation_mode(mut self, mode: ValidationMode) -> Self {
140		self.options.set("ep.webgpuexecutionprovider.validationMode", mode.as_str());
141		self
142	}
143
144	#[must_use]
145	pub fn with_force_cpu_node_names(mut self, names: String) -> Self {
146		self.options.set("ep.webgpuexecutionprovider.forceCpuNodeNames", names);
147		self
148	}
149
150	#[must_use]
151	pub fn with_enable_pix_capture(mut self, enable: bool) -> Self {
152		self.options
153			.set("ep.webgpuexecutionprovider.enablePIXCapture", if enable { "1" } else { "0" });
154		self
155	}
156}
157
158impl ExecutionProvider for WebGPU {
159	fn name(&self) -> &'static str {
160		"WebGpuExecutionProvider"
161	}
162
163	fn supported_by_platform(&self) -> bool {
164		cfg!(any(target_os = "windows", target_os = "linux", target_arch = "wasm32"))
165	}
166
167	#[allow(unused, unreachable_code)]
168	fn register(&self, session_builder: &mut SessionBuilder) -> Result<(), RegisterError> {
169		#[cfg(any(target_arch = "wasm32", feature = "load-dynamic", feature = "webgpu"))]
170		{
171			use crate::{AsPointer, ortsys};
172
173			let ffi_options = self.options.to_ffi();
174			ortsys![unsafe SessionOptionsAppendExecutionProvider(
175				session_builder.ptr_mut(),
176				c"WebGPU".as_ptr().cast::<core::ffi::c_char>(), // much consistency
177				ffi_options.key_ptrs(),
178				ffi_options.value_ptrs(),
179				ffi_options.len(),
180			)?];
181			return Ok(());
182		}
183
184		Err(RegisterError::MissingFeature)
185	}
186}