use std::sync::Arc;

use tracing::{debug, info};
use wgpu::{
    Adapter, AdapterInfo, Backend, Backends, Device, DeviceDescriptor, Features, Instance,
    InstanceDescriptor, Limits, PowerPreference, Queue, RequestAdapterOptions,
};

use crate::error::{GpuError, GpuResult};
13
/// Preferred graphics backend to use when creating a [`GpuContext`].
///
/// Converted to wgpu [`Backends`] bitflags via [`BackendPreference::to_backends`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BackendPreference {
    /// Vulkan backend (Linux, Windows, Android).
    Vulkan,
    /// Metal backend (macOS, iOS).
    Metal,
    /// Direct3D 12 backend (Windows).
    DX12,
    /// Browser WebGPU backend (wasm32 targets).
    WebGPU,
    /// Let wgpu choose among its primary backends.
    Auto,
    /// Enable every backend wgpu supports.
    All,
}
30
31impl BackendPreference {
32 pub fn to_backends(&self) -> Backends {
34 match self {
35 Self::Vulkan => Backends::VULKAN,
36 Self::Metal => Backends::METAL,
37 Self::DX12 => Backends::DX12,
38 Self::WebGPU => Backends::BROWSER_WEBGPU,
39 Self::Auto => Backends::PRIMARY,
40 Self::All => Backends::all(),
41 }
42 }
43
44 pub fn platform_default() -> Self {
46 #[cfg(target_os = "macos")]
47 return Self::Metal;
48
49 #[cfg(target_os = "windows")]
50 return Self::DX12;
51
52 #[cfg(target_os = "linux")]
53 return Self::Vulkan;
54
55 #[cfg(target_arch = "wasm32")]
56 return Self::WebGPU;
57
58 #[cfg(not(any(
59 target_os = "macos",
60 target_os = "windows",
61 target_os = "linux",
62 target_arch = "wasm32"
63 )))]
64 return Self::Auto;
65 }
66}
67
/// Power/performance trade-off requested when selecting a GPU adapter.
///
/// Converted into wgpu's [`PowerPreference`] via the `From` impl below.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GpuPowerPreference {
    /// Prefer the adapter that uses the least power (e.g. integrated GPU).
    LowPower,
    /// Prefer the fastest adapter (e.g. discrete GPU).
    HighPerformance,
    /// No preference; maps to `PowerPreference::None`.
    Default,
}
78
79impl From<GpuPowerPreference> for PowerPreference {
80 fn from(pref: GpuPowerPreference) -> Self {
81 match pref {
82 GpuPowerPreference::LowPower => PowerPreference::LowPower,
83 GpuPowerPreference::HighPerformance => PowerPreference::HighPerformance,
84 GpuPowerPreference::Default => PowerPreference::None,
85 }
86 }
87}
88
/// Configuration used to build a [`GpuContext`].
///
/// Construct with [`GpuContextConfig::new`] / `Default` and customize via the
/// `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct GpuContextConfig {
    /// Which graphics backend(s) to enumerate.
    pub backend: BackendPreference,
    /// Adapter power/performance preference.
    pub power_preference: GpuPowerPreference,
    /// Features the device must support.
    pub required_features: Features,
    /// Explicit device limits; when `None`, conservative defaults are derived
    /// from the adapter's limits.
    pub required_limits: Option<Limits>,
    /// Enable debug behavior (defaults to `cfg!(debug_assertions)`).
    pub debug: bool,
    /// Label attached to the created device (for debugging/tooling).
    pub label: Option<String>,
}
105
106impl Default for GpuContextConfig {
107 fn default() -> Self {
108 Self {
109 backend: BackendPreference::platform_default(),
110 power_preference: GpuPowerPreference::HighPerformance,
111 required_features: Features::empty(),
112 required_limits: None,
113 debug: cfg!(debug_assertions),
114 label: Some("OxiGDAL GPU Context".to_string()),
115 }
116 }
117}
118
119impl GpuContextConfig {
120 pub fn new() -> Self {
122 Self::default()
123 }
124
125 pub fn with_backend(mut self, backend: BackendPreference) -> Self {
127 self.backend = backend;
128 self
129 }
130
131 pub fn with_power_preference(mut self, power: GpuPowerPreference) -> Self {
133 self.power_preference = power;
134 self
135 }
136
137 pub fn with_features(mut self, features: Features) -> Self {
139 self.required_features = features;
140 self
141 }
142
143 pub fn with_limits(mut self, limits: Limits) -> Self {
145 self.required_limits = Some(limits);
146 self
147 }
148
149 pub fn with_debug(mut self, debug: bool) -> Self {
151 self.debug = debug;
152 self
153 }
154
155 pub fn with_label(mut self, label: impl Into<String>) -> Self {
157 self.label = Some(label.into());
158 self
159 }
160}
161
/// Shared handle to an initialized GPU: instance, adapter, device and queue.
///
/// All fields are `Arc`-wrapped, so `Clone` is cheap (reference-count bumps)
/// and clones share the same underlying device.
#[derive(Clone)]
pub struct GpuContext {
    // wgpu instance that enumerated the adapter.
    instance: Arc<Instance>,
    // Physical adapter the device was created from.
    adapter: Arc<Adapter>,
    // Logical device used to create resources and pipelines.
    device: Arc<Device>,
    // Queue for submitting command buffers.
    queue: Arc<Queue>,
    // Cached adapter info (name, backend, device type).
    adapter_info: AdapterInfo,
    // The limits the device was created with (not the adapter's raw limits).
    limits: Limits,
}
182
183impl GpuContext {
184 pub async fn new() -> GpuResult<Self> {
191 Self::with_config(GpuContextConfig::default()).await
192 }
193
194 pub async fn with_config(config: GpuContextConfig) -> GpuResult<Self> {
201 info!(
202 "Initializing GPU context with backend: {:?}",
203 config.backend
204 );
205
206 let instance = Instance::new(InstanceDescriptor {
208 backends: config.backend.to_backends(),
209 ..InstanceDescriptor::new_without_display_handle()
210 });
211
212 let adapter = Self::request_adapter(&instance, &config).await?;
214 let adapter_info = adapter.get_info();
215
216 info!(
217 "Selected GPU adapter: {} ({:?})",
218 adapter_info.name, adapter_info.backend
219 );
220 debug!("Adapter info: {:?}", adapter_info);
221
222 let adapter_limits = adapter.limits();
224 let limits = config
225 .required_limits
226 .unwrap_or_else(|| Self::default_limits(&adapter_limits));
227
228 if !Self::validate_limits(&limits, &adapter_limits) {
230 return Err(GpuError::device_request(format!(
231 "Requested limits exceed adapter capabilities: \
232 max_compute_workgroup_size_x: {} (adapter: {})",
233 limits.max_compute_workgroup_size_x, adapter_limits.max_compute_workgroup_size_x
234 )));
235 }
236
237 let (device, queue) = adapter
239 .request_device(&DeviceDescriptor {
240 label: config.label.as_deref(),
241 required_features: config.required_features,
242 required_limits: limits.clone(),
243 memory_hints: Default::default(),
244 experimental_features: Default::default(),
245 trace: Default::default(),
246 })
247 .await
248 .map_err(|e| GpuError::device_request(e.to_string()))?;
249
250 info!("GPU device created successfully");
251 debug!("Device limits: {:?}", limits);
252
253 Ok(Self {
254 instance: Arc::new(instance),
255 adapter: Arc::new(adapter),
256 device: Arc::new(device),
257 queue: Arc::new(queue),
258 adapter_info,
259 limits,
260 })
261 }
262
263 async fn request_adapter(instance: &Instance, config: &GpuContextConfig) -> GpuResult<Adapter> {
265 let adapter = instance
266 .request_adapter(&RequestAdapterOptions {
267 power_preference: config.power_preference.into(),
268 force_fallback_adapter: false,
269 compatible_surface: None,
270 })
271 .await;
272
273 adapter.map_err(|_| {
274 let backends = match config.backend {
275 BackendPreference::Auto => "Auto (PRIMARY)".to_string(),
276 BackendPreference::All => "All".to_string(),
277 backend => format!("{backend:?}"),
278 };
279 GpuError::no_adapter(backends)
280 })
281 }
282
283 fn default_limits(adapter_limits: &Limits) -> Limits {
285 Limits {
286 max_compute_workgroup_size_x: adapter_limits.max_compute_workgroup_size_x.min(256),
287 max_compute_workgroup_size_y: adapter_limits.max_compute_workgroup_size_y.min(256),
288 max_compute_workgroup_size_z: adapter_limits.max_compute_workgroup_size_z.min(64),
289 max_compute_invocations_per_workgroup: adapter_limits
290 .max_compute_invocations_per_workgroup
291 .min(256),
292 max_compute_workgroups_per_dimension: adapter_limits
293 .max_compute_workgroups_per_dimension,
294 ..Default::default()
295 }
296 }
297
298 fn validate_limits(requested: &Limits, adapter: &Limits) -> bool {
300 requested.max_compute_workgroup_size_x <= adapter.max_compute_workgroup_size_x
301 && requested.max_compute_workgroup_size_y <= adapter.max_compute_workgroup_size_y
302 && requested.max_compute_workgroup_size_z <= adapter.max_compute_workgroup_size_z
303 && requested.max_compute_invocations_per_workgroup
304 <= adapter.max_compute_invocations_per_workgroup
305 }
306
307 pub fn device(&self) -> &Device {
309 &self.device
310 }
311
312 pub fn queue(&self) -> &Queue {
314 &self.queue
315 }
316
317 pub fn adapter(&self) -> &Adapter {
319 &self.adapter
320 }
321
322 pub fn instance(&self) -> &Instance {
324 &self.instance
325 }
326
327 pub fn adapter_info(&self) -> &AdapterInfo {
329 &self.adapter_info
330 }
331
332 pub fn limits(&self) -> &Limits {
334 &self.limits
335 }
336
337 pub fn backend(&self) -> Backend {
339 self.adapter_info.backend
340 }
341
342 pub fn supports_feature(&self, feature: Features) -> bool {
344 self.device.features().contains(feature)
345 }
346
347 pub fn max_workgroup_size(&self) -> (u32, u32, u32) {
349 (
350 self.limits.max_compute_workgroup_size_x,
351 self.limits.max_compute_workgroup_size_y,
352 self.limits.max_compute_workgroup_size_z,
353 )
354 }
355
356 pub fn poll(&self, _wait: bool) {
360 }
363
364 pub fn is_valid(&self) -> bool {
366 self.device.create_buffer(&wgpu::BufferDescriptor {
368 label: Some("health_check"),
369 size: 4,
370 usage: wgpu::BufferUsages::UNIFORM,
371 mapped_at_creation: false,
372 });
373 true
374 }
375}
376
377impl std::fmt::Debug for GpuContext {
378 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
379 f.debug_struct("GpuContext")
380 .field("adapter", &self.adapter_info.name)
381 .field("backend", &self.adapter_info.backend)
382 .field("device_type", &self.adapter_info.device_type)
383 .field("limits", &self.limits)
384 .finish()
385 }
386}
387
#[cfg(test)]
mod tests {
    use super::*;

    // Context creation is best-effort: CI machines typically have no GPU, so
    // failure to acquire an adapter is logged rather than asserted.
    #[tokio::test]
    async fn test_gpu_context_creation() {
        match GpuContext::new().await {
            Err(e) => {
                println!("GPU not available (expected in CI): {}", e);
            }
            Ok(ctx) => {
                println!("GPU Context created: {:?}", ctx);
                assert!(ctx.is_valid());
            }
        }
    }

    #[tokio::test]
    async fn test_backend_preference() {
        let config = GpuContextConfig::new()
            .with_power_preference(GpuPowerPreference::HighPerformance)
            .with_backend(BackendPreference::platform_default());

        match GpuContext::with_config(config).await {
            Ok(ctx) => {
                println!("Backend: {:?}", ctx.backend());
            }
            Err(e) => {
                println!("GPU not available: {}", e);
            }
        }
    }

    // Table-driven check of the preference -> Backends bitflag mapping.
    #[test]
    fn test_backend_conversion() {
        let cases = [
            (BackendPreference::Vulkan, Backends::VULKAN),
            (BackendPreference::Metal, Backends::METAL),
            (BackendPreference::DX12, Backends::DX12),
        ];
        for (pref, expected) in cases {
            assert_eq!(pref.to_backends(), expected);
        }
    }

    #[test]
    fn test_platform_default() {
        let default = BackendPreference::platform_default();
        println!("Platform default backend: {:?}", default);

        #[cfg(target_os = "macos")]
        assert_eq!(default, BackendPreference::Metal);

        #[cfg(target_os = "windows")]
        assert_eq!(default, BackendPreference::DX12);

        #[cfg(target_os = "linux")]
        assert_eq!(default, BackendPreference::Vulkan);
    }
}