1use alloc::string::{String, ToString};
2
3use super::{ExecutionProvider, ExecutionProviderOptions, RegisterError};
4use crate::{error::Result, session::builder::SessionBuilder};
5
/// Layout preference forwarded to the WebGPU execution provider.
///
/// Consumed by [`WebGPU::with_preferred_layout`], which stores the variant's string form under
/// the `ep.webgpuexecutionprovider.preferredLayout` provider option.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PreferredLayout {
    /// Serialized as `"NCHW"`.
    NCHW,
    /// Serialized as `"NHWC"`.
    NHWC
}

impl PreferredLayout {
    /// Returns the exact string written into the provider option for this layout.
    ///
    /// Marked `#[must_use]` to match `ValidationMode::as_str`: the call has no side effects, so
    /// discarding the result is always a mistake.
    #[must_use]
    pub(crate) fn as_str(&self) -> &'static str {
        match self {
            Self::NCHW => "NCHW",
            Self::NHWC => "NHWC"
        }
    }
}
20
/// Dawn backend selection forwarded to the WebGPU execution provider.
///
/// Consumed by [`WebGPU::with_dawn_backend_type`], which stores the variant's string form under
/// the `ep.webgpuexecutionprovider.dawnBackendType` provider option.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DawnBackendType {
    /// Serialized as `"Vulkan"`.
    Vulkan,
    /// Serialized as `"D3D12"`.
    D3D12
}

impl DawnBackendType {
    /// Returns the exact string written into the provider option for this backend.
    ///
    /// Marked `#[must_use]` to match `ValidationMode::as_str`: the call has no side effects, so
    /// discarding the result is always a mistake.
    #[must_use]
    pub(crate) fn as_str(&self) -> &'static str {
        match self {
            Self::Vulkan => "Vulkan",
            Self::D3D12 => "D3D12"
        }
    }
}
35
/// Buffer cache mode forwarded to the WebGPU execution provider.
///
/// The same set of modes is accepted by the storage, uniform, query-resolve and default buffer
/// cache options (see the `WebGPU::with_*_buffer_cache_mode` builder methods).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BufferCacheMode {
    /// Serialized as `"disabled"`.
    Disabled,
    /// Serialized as `"lazyRelease"`.
    LazyRelease,
    /// Serialized as `"simple"`.
    Simple,
    /// Serialized as `"bucket"`.
    Bucket
}

impl BufferCacheMode {
    /// Returns the exact string written into the provider option for this cache mode.
    ///
    /// Marked `#[must_use]` to match `ValidationMode::as_str`: the call has no side effects, so
    /// discarding the result is always a mistake.
    #[must_use]
    pub(crate) fn as_str(&self) -> &'static str {
        match self {
            Self::Disabled => "disabled",
            Self::LazyRelease => "lazyRelease",
            Self::Simple => "simple",
            Self::Bucket => "bucket"
        }
    }
}
54
/// Validation mode forwarded to the WebGPU execution provider.
///
/// Consumed by [`WebGPU::with_validation_mode`], which stores the variant's string form under
/// the `ep.webgpuexecutionprovider.validationMode` provider option.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ValidationMode {
    /// Serialized as `"disabled"`.
    Disabled,
    /// Serialized as `"wgpuOnly"`.
    WgpuOnly,
    /// Serialized as `"basic"`.
    Basic,
    /// Serialized as `"full"`.
    Full
}

impl ValidationMode {
    /// Returns the exact string written into the provider option for this validation mode.
    #[must_use]
    pub fn as_str(&self) -> &'static str {
        match *self {
            Self::Disabled => "disabled",
            Self::WgpuOnly => "wgpuOnly",
            Self::Basic => "basic",
            Self::Full => "full"
        }
    }
}
74
75#[derive(Debug, Default, Clone)]
76pub struct WebGPU {
77 options: ExecutionProviderOptions
78}
79
80super::impl_ep!(arbitrary; WebGPU);
81
82impl WebGPU {
83 #[must_use]
84 pub fn with_preferred_layout(mut self, layout: PreferredLayout) -> Self {
85 self.options.set("ep.webgpuexecutionprovider.preferredLayout", layout.as_str());
86 self
87 }
88
89 #[must_use]
90 pub fn with_enable_graph_capture(mut self, enable: bool) -> Self {
91 self.options
92 .set("ep.webgpuexecutionprovider.enableGraphCapture", if enable { "1" } else { "0" });
93 self
94 }
95
96 #[must_use]
97 pub fn with_dawn_proc_table(mut self, table: String) -> Self {
98 self.options.set("ep.webgpuexecutionprovider.dawnProcTable", table);
99 self
100 }
101
102 #[must_use]
103 pub fn with_dawn_backend_type(mut self, backend_type: DawnBackendType) -> Self {
104 self.options.set("ep.webgpuexecutionprovider.dawnBackendType", backend_type.as_str());
105 self
106 }
107
108 #[must_use]
109 pub fn with_device_id(mut self, id: i32) -> Self {
110 self.options.set("ep.webgpuexecutionprovider.deviceId", id.to_string());
111 self
112 }
113
114 #[must_use]
115 pub fn with_storage_buffer_cache_mode(mut self, mode: BufferCacheMode) -> Self {
116 self.options.set("ep.webgpuexecutionprovider.storageBufferCacheMode", mode.as_str());
117 self
118 }
119
120 #[must_use]
121 pub fn with_uniform_buffer_cache_mode(mut self, mode: BufferCacheMode) -> Self {
122 self.options.set("ep.webgpuexecutionprovider.uniformBufferCacheMode", mode.as_str());
123 self
124 }
125
126 #[must_use]
127 pub fn with_query_resolve_buffer_cache_mode(mut self, mode: BufferCacheMode) -> Self {
128 self.options.set("ep.webgpuexecutionprovider.queryResolveBufferCacheMode", mode.as_str());
129 self
130 }
131
132 #[must_use]
133 pub fn with_default_buffer_cache_mode(mut self, mode: BufferCacheMode) -> Self {
134 self.options.set("ep.webgpuexecutionprovider.defaultBufferCacheMode", mode.as_str());
135 self
136 }
137
138 #[must_use]
139 pub fn with_validation_mode(mut self, mode: ValidationMode) -> Self {
140 self.options.set("ep.webgpuexecutionprovider.validationMode", mode.as_str());
141 self
142 }
143
144 #[must_use]
145 pub fn with_force_cpu_node_names(mut self, names: String) -> Self {
146 self.options.set("ep.webgpuexecutionprovider.forceCpuNodeNames", names);
147 self
148 }
149
150 #[must_use]
151 pub fn with_enable_pix_capture(mut self, enable: bool) -> Self {
152 self.options
153 .set("ep.webgpuexecutionprovider.enablePIXCapture", if enable { "1" } else { "0" });
154 self
155 }
156}
157
158impl ExecutionProvider for WebGPU {
159 fn name(&self) -> &'static str {
160 "WebGpuExecutionProvider"
161 }
162
163 fn supported_by_platform(&self) -> bool {
164 cfg!(any(target_os = "windows", target_os = "linux", target_arch = "wasm32"))
165 }
166
167 #[allow(unused, unreachable_code)]
168 fn register(&self, session_builder: &mut SessionBuilder) -> Result<(), RegisterError> {
169 #[cfg(any(target_arch = "wasm32", feature = "load-dynamic", feature = "webgpu"))]
170 {
171 use crate::{AsPointer, ortsys};
172
173 let ffi_options = self.options.to_ffi();
174 ortsys![unsafe SessionOptionsAppendExecutionProvider(
175 session_builder.ptr_mut(),
176 c"WebGPU".as_ptr().cast::<core::ffi::c_char>(), ffi_options.key_ptrs(),
178 ffi_options.value_ptrs(),
179 ffi_options.len(),
180 )?];
181 return Ok(());
182 }
183
184 Err(RegisterError::MissingFeature)
185 }
186}