1use std::sync::OnceLock;
22use std::sync::atomic::{AtomicU8, Ordering};
23
/// `BACKEND_PREF` value: `wgpu_device` serves the device built from `default_backends()`.
const PREF_DEFAULT: u8 = 0x00;
/// `BACKEND_PREF` value: `wgpu_device` serves the Vulkan-only device.
const PREF_VULKAN: u8 = 0x01;

/// Process-wide backend preference. In this section it is read by `wgpu_device` and written
/// by `select_vulkan_backend`; it starts out as `PREF_DEFAULT`.
static BACKEND_PREF: AtomicU8 = AtomicU8::new(PREF_DEFAULT);
28
/// A wgpu instance, adapter, device and queue bundled together with the adapter's name and
/// backend. In this section it is constructed by `WgpuDevice::new_with_backends`.
///
/// NOTE(review): Rust drops fields in declaration order (instance, adapter, device, queue).
/// Do not reorder the fields without confirming the new drop order is acceptable to wgpu.
pub struct WgpuDevice {
    /// Instance the adapter was requested from.
    pub instance: wgpu::Instance,
    /// Adapter requested with `PowerPreference::HighPerformance` and no compatible surface.
    pub adapter: wgpu::Adapter,
    /// Logical device opened on `adapter`.
    pub device: wgpu::Device,
    /// Queue returned together with `device` by `request_device`.
    pub queue: wgpu::Queue,
    /// Adapter name, taken from `Adapter::get_info`.
    pub name: String,
    /// Backend the adapter runs on, taken from `Adapter::get_info`.
    pub backend: wgpu::Backend,
}
40
41impl WgpuDevice {
42 fn new_with_backends(backends: wgpu::Backends) -> Option<Self> {
43 let instance = wgpu::Instance::new(wgpu::InstanceDescriptor {
44 backends,
45 flags: wgpu::InstanceFlags::default(),
46 backend_options: wgpu::BackendOptions::default(),
47 memory_budget_thresholds: wgpu::MemoryBudgetThresholds::default(),
48 display: None,
49 });
50 let adapter = pollster::block_on(instance.request_adapter(&wgpu::RequestAdapterOptions {
51 power_preference: wgpu::PowerPreference::HighPerformance,
52 compatible_surface: None,
53 force_fallback_adapter: false,
54 }))
55 .ok()?;
56
57 let info = adapter.get_info();
58 let limits = adapter.limits();
59 let adapter_feats = adapter.features();
60 let mut required_features = wgpu::Features::empty();
61 if adapter_feats.contains(wgpu::Features::SHADER_F16) {
62 required_features |= wgpu::Features::SHADER_F16;
63 }
64 if adapter_feats.contains(wgpu::Features::EXPERIMENTAL_COOPERATIVE_MATRIX) {
65 required_features |= wgpu::Features::EXPERIMENTAL_COOPERATIVE_MATRIX;
66 }
67 if adapter_feats.contains(wgpu::Features::SUBGROUP) {
68 required_features |= wgpu::Features::SUBGROUP;
69 }
70
71 let (device, queue) =
72 match pollster::block_on(adapter.request_device(&wgpu::DeviceDescriptor {
73 label: Some("rlx-wgpu device"),
74 required_features,
75 required_limits: limits,
76 memory_hints: wgpu::MemoryHints::Performance,
77 experimental_features: unsafe { wgpu::ExperimentalFeatures::enabled() },
78 trace: wgpu::Trace::Off,
79 })) {
80 Ok(p) => p,
81 Err(e) => {
82 eprintln!("rlx-wgpu request_device failed: {e}");
83 return None;
84 }
85 };
86
87 Some(Self {
88 instance,
89 adapter,
90 device,
91 queue,
92 name: info.name,
93 backend: info.backend,
94 })
95 }
96
97 fn new_default() -> Option<Self> {
98 Self::new_with_backends(default_backends())
99 }
100}
101
102fn default_backends() -> wgpu::Backends {
103 if let Some(b) = wgpu::Backends::from_env() {
104 return b;
105 }
106 #[cfg(target_os = "windows")]
107 {
108 wgpu::Backends::DX12 | wgpu::Backends::VULKAN
110 }
111 #[cfg(target_os = "linux")]
112 {
113 wgpu::Backends::VULKAN
115 }
116 #[cfg(target_os = "macos")]
117 {
118 wgpu::Backends::METAL | wgpu::Backends::VULKAN
119 }
120 #[cfg(not(any(target_os = "windows", target_os = "linux", target_os = "macos")))]
121 {
122 wgpu::Backends::all()
123 }
124}
125
// SAFETY: NOTE(review) — no justification for these impls is given in this section.
// They are needed because `WgpuDevice` is stored in `static OnceLock`s, and a static
// `OnceLock<T>` requires `T: Send + Sync`.
// wgpu's `Instance`/`Adapter`/`Device`/`Queue` are believed to already be `Send + Sync` on
// native targets. If so, these impls are presumably only load-bearing on targets where they
// are not (e.g. wasm without wgpu's `fragile-send-sync-non-atomic-wasm` feature).
// Confirm soundness on those targets before relying on cross-thread use.
unsafe impl Send for WgpuDevice {}
unsafe impl Sync for WgpuDevice {}
129
130fn default_device() -> Option<&'static WgpuDevice> {
131 static DEVICE: OnceLock<Option<WgpuDevice>> = OnceLock::new();
132 DEVICE.get_or_init(WgpuDevice::new_default).as_ref()
133}
134
135fn vulkan_device() -> Option<&'static WgpuDevice> {
136 static DEVICE: OnceLock<Option<WgpuDevice>> = OnceLock::new();
137 DEVICE
138 .get_or_init(|| WgpuDevice::new_with_backends(wgpu::Backends::VULKAN))
139 .as_ref()
140}
141
142pub fn select_vulkan_backend() {
146 BACKEND_PREF.store(PREF_VULKAN, Ordering::SeqCst);
147}
148
149pub fn is_vulkan_available() -> bool {
151 vulkan_device().is_some()
152}
153
154pub fn wgpu_device() -> Option<&'static WgpuDevice> {
157 if BACKEND_PREF.load(Ordering::SeqCst) == PREF_VULKAN {
158 vulkan_device()
159 } else {
160 default_device()
161 }
162}
163
164pub fn adapter_name() -> Option<String> {
166 wgpu_device().map(|d| d.name.clone())
167}
168
169pub fn coop_discrete_backend() -> bool {
171 wgpu_device()
172 .map(|d| matches!(d.backend, wgpu::Backend::Vulkan | wgpu::Backend::Dx12))
173 .unwrap_or(false)
174}
175
176pub fn coop_f32_8x8_supported() -> bool {
179 let dev = match wgpu_device() {
180 Some(d) => d,
181 None => return false,
182 };
183 dev.adapter.cooperative_matrix_properties().iter().any(|p| {
184 p.m_size == 8
185 && p.n_size == 8
186 && p.k_size == 8
187 && p.ab_type == wgpu::CooperativeScalarType::F32
188 && p.cr_type == wgpu::CooperativeScalarType::F32
189 })
190}
191
192pub fn coop_f16_16x16_supported() -> bool {
195 let dev = match wgpu_device() {
196 Some(d) => d,
197 None => return false,
198 };
199 dev.adapter.cooperative_matrix_properties().iter().any(|p| {
200 p.m_size == 16
201 && p.n_size == 16
202 && p.k_size == 16
203 && p.ab_type == wgpu::CooperativeScalarType::F16
204 && (p.cr_type == wgpu::CooperativeScalarType::F16
205 || p.cr_type == wgpu::CooperativeScalarType::F32)
206 })
207}
208
209pub fn coop_f16_16x16_f32_acc_supported() -> bool {
211 let dev = match wgpu_device() {
212 Some(d) => d,
213 None => return false,
214 };
215 dev.adapter.cooperative_matrix_properties().iter().any(|p| {
216 p.m_size == 16
217 && p.n_size == 16
218 && p.k_size == 16
219 && p.ab_type == wgpu::CooperativeScalarType::F16
220 && p.cr_type == wgpu::CooperativeScalarType::F32
221 })
222}