// oxigdal_gpu/context.rs
//! GPU context management for OxiGDAL.
//!
//! This module handles WGPU device initialization, adapter selection,
//! and resource management for GPU-accelerated operations.

use std::sync::Arc;

use crate::error::{GpuError, GpuResult};
use tracing::{debug, info};
use wgpu::{
    Adapter, AdapterInfo, Backend, Backends, Device, DeviceDescriptor, Features, Instance,
    InstanceDescriptor, Limits, PowerPreference, Queue, RequestAdapterOptions,
};

/// Which GPU backend family to prefer when selecting an adapter.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BackendPreference {
    /// Vulkan: cross-platform, typically the best performer on Linux/Windows.
    Vulkan,
    /// Metal: Apple's native API, best on macOS/iOS.
    Metal,
    /// Direct3D 12: best on Windows.
    DX12,
    /// WebGPU: for browser (wasm) environments.
    WebGPU,
    /// Let wgpu pick from the primary backends for the platform.
    Auto,
    /// Consider every backend wgpu was built with.
    All,
}

31impl BackendPreference {
32    /// Convert to WGPU backends flags.
33    pub fn to_backends(&self) -> Backends {
34        match self {
35            Self::Vulkan => Backends::VULKAN,
36            Self::Metal => Backends::METAL,
37            Self::DX12 => Backends::DX12,
38            Self::WebGPU => Backends::BROWSER_WEBGPU,
39            Self::Auto => Backends::PRIMARY,
40            Self::All => Backends::all(),
41        }
42    }
43
44    /// Get platform-specific default backend.
45    pub fn platform_default() -> Self {
46        #[cfg(target_os = "macos")]
47        return Self::Metal;
48
49        #[cfg(target_os = "windows")]
50        return Self::DX12;
51
52        #[cfg(target_os = "linux")]
53        return Self::Vulkan;
54
55        #[cfg(target_arch = "wasm32")]
56        return Self::WebGPU;
57
58        #[cfg(not(any(
59            target_os = "macos",
60            target_os = "windows",
61            target_os = "linux",
62            target_arch = "wasm32"
63        )))]
64        return Self::Auto;
65    }
66}
67
/// Power/performance trade-off used when picking between GPUs.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GpuPowerPreference {
    /// Favor low power draw (typically the integrated GPU).
    LowPower,
    /// Favor raw performance (typically the discrete GPU).
    HighPerformance,
    /// No preference; defer to the system default.
    Default,
}

79impl From<GpuPowerPreference> for PowerPreference {
80    fn from(pref: GpuPowerPreference) -> Self {
81        match pref {
82            GpuPowerPreference::LowPower => PowerPreference::LowPower,
83            GpuPowerPreference::HighPerformance => PowerPreference::HighPerformance,
84            GpuPowerPreference::Default => PowerPreference::None,
85        }
86    }
87}
88
89/// Configuration for GPU context initialization.
90#[derive(Debug, Clone)]
91pub struct GpuContextConfig {
92    /// Backend preference for adapter selection.
93    pub backend: BackendPreference,
94    /// Power preference for GPU selection.
95    pub power_preference: GpuPowerPreference,
96    /// Required GPU features.
97    pub required_features: Features,
98    /// Required GPU limits.
99    pub required_limits: Option<Limits>,
100    /// Enable debug mode (validation layers).
101    pub debug: bool,
102    /// Label for the device (for debugging).
103    pub label: Option<String>,
104}
105
106impl Default for GpuContextConfig {
107    fn default() -> Self {
108        Self {
109            backend: BackendPreference::platform_default(),
110            power_preference: GpuPowerPreference::HighPerformance,
111            required_features: Features::empty(),
112            required_limits: None,
113            debug: cfg!(debug_assertions),
114            label: Some("OxiGDAL GPU Context".to_string()),
115        }
116    }
117}
118
119impl GpuContextConfig {
120    /// Create a new GPU context configuration.
121    pub fn new() -> Self {
122        Self::default()
123    }
124
125    /// Set backend preference.
126    pub fn with_backend(mut self, backend: BackendPreference) -> Self {
127        self.backend = backend;
128        self
129    }
130
131    /// Set power preference.
132    pub fn with_power_preference(mut self, power: GpuPowerPreference) -> Self {
133        self.power_preference = power;
134        self
135    }
136
137    /// Set required features.
138    pub fn with_features(mut self, features: Features) -> Self {
139        self.required_features = features;
140        self
141    }
142
143    /// Set required limits.
144    pub fn with_limits(mut self, limits: Limits) -> Self {
145        self.required_limits = Some(limits);
146        self
147    }
148
149    /// Enable debug mode.
150    pub fn with_debug(mut self, debug: bool) -> Self {
151        self.debug = debug;
152        self
153    }
154
155    /// Set device label.
156    pub fn with_label(mut self, label: impl Into<String>) -> Self {
157        self.label = Some(label.into());
158        self
159    }
160}
161
162/// GPU context holding device and queue.
163///
164/// This is the main entry point for GPU operations. It manages the WGPU
165/// device and queue, and provides methods for creating buffers, pipelines,
166/// and executing compute shaders.
167#[derive(Clone)]
168pub struct GpuContext {
169    /// WGPU instance.
170    instance: Arc<Instance>,
171    /// WGPU adapter.
172    adapter: Arc<Adapter>,
173    /// WGPU device.
174    device: Arc<Device>,
175    /// WGPU queue.
176    queue: Arc<Queue>,
177    /// Adapter information.
178    adapter_info: AdapterInfo,
179    /// Device limits.
180    limits: Limits,
181}
182
183impl GpuContext {
184    /// Create a new GPU context with default configuration.
185    ///
186    /// # Errors
187    ///
188    /// Returns an error if no suitable GPU adapter is found or device
189    /// request fails.
190    pub async fn new() -> GpuResult<Self> {
191        Self::with_config(GpuContextConfig::default()).await
192    }
193
194    /// Create a new GPU context with custom configuration.
195    ///
196    /// # Errors
197    ///
198    /// Returns an error if no suitable GPU adapter is found or device
199    /// request fails.
200    pub async fn with_config(config: GpuContextConfig) -> GpuResult<Self> {
201        info!(
202            "Initializing GPU context with backend: {:?}",
203            config.backend
204        );
205
206        // Create WGPU instance
207        let instance = Instance::new(InstanceDescriptor {
208            backends: config.backend.to_backends(),
209            ..InstanceDescriptor::new_without_display_handle()
210        });
211
212        // Request adapter
213        let adapter = Self::request_adapter(&instance, &config).await?;
214        let adapter_info = adapter.get_info();
215
216        info!(
217            "Selected GPU adapter: {} ({:?})",
218            adapter_info.name, adapter_info.backend
219        );
220        debug!("Adapter info: {:?}", adapter_info);
221
222        // Get adapter limits
223        let adapter_limits = adapter.limits();
224        let limits = config
225            .required_limits
226            .unwrap_or_else(|| Self::default_limits(&adapter_limits));
227
228        // Validate limits
229        if !Self::validate_limits(&limits, &adapter_limits) {
230            return Err(GpuError::device_request(format!(
231                "Requested limits exceed adapter capabilities: \
232                 max_compute_workgroup_size_x: {} (adapter: {})",
233                limits.max_compute_workgroup_size_x, adapter_limits.max_compute_workgroup_size_x
234            )));
235        }
236
237        // Request device
238        let (device, queue) = adapter
239            .request_device(&DeviceDescriptor {
240                label: config.label.as_deref(),
241                required_features: config.required_features,
242                required_limits: limits.clone(),
243                memory_hints: Default::default(),
244                experimental_features: Default::default(),
245                trace: Default::default(),
246            })
247            .await
248            .map_err(|e| GpuError::device_request(e.to_string()))?;
249
250        info!("GPU device created successfully");
251        debug!("Device limits: {:?}", limits);
252
253        Ok(Self {
254            instance: Arc::new(instance),
255            adapter: Arc::new(adapter),
256            device: Arc::new(device),
257            queue: Arc::new(queue),
258            adapter_info,
259            limits,
260        })
261    }
262
263    /// Request a suitable GPU adapter.
264    async fn request_adapter(instance: &Instance, config: &GpuContextConfig) -> GpuResult<Adapter> {
265        let adapter = instance
266            .request_adapter(&RequestAdapterOptions {
267                power_preference: config.power_preference.into(),
268                force_fallback_adapter: false,
269                compatible_surface: None,
270            })
271            .await;
272
273        adapter.map_err(|_| {
274            let backends = match config.backend {
275                BackendPreference::Auto => "Auto (PRIMARY)".to_string(),
276                BackendPreference::All => "All".to_string(),
277                backend => format!("{backend:?}"),
278            };
279            GpuError::no_adapter(backends)
280        })
281    }
282
283    /// Get default limits based on adapter capabilities.
284    fn default_limits(adapter_limits: &Limits) -> Limits {
285        Limits {
286            max_compute_workgroup_size_x: adapter_limits.max_compute_workgroup_size_x.min(256),
287            max_compute_workgroup_size_y: adapter_limits.max_compute_workgroup_size_y.min(256),
288            max_compute_workgroup_size_z: adapter_limits.max_compute_workgroup_size_z.min(64),
289            max_compute_invocations_per_workgroup: adapter_limits
290                .max_compute_invocations_per_workgroup
291                .min(256),
292            max_compute_workgroups_per_dimension: adapter_limits
293                .max_compute_workgroups_per_dimension,
294            ..Default::default()
295        }
296    }
297
298    /// Validate that requested limits don't exceed adapter capabilities.
299    fn validate_limits(requested: &Limits, adapter: &Limits) -> bool {
300        requested.max_compute_workgroup_size_x <= adapter.max_compute_workgroup_size_x
301            && requested.max_compute_workgroup_size_y <= adapter.max_compute_workgroup_size_y
302            && requested.max_compute_workgroup_size_z <= adapter.max_compute_workgroup_size_z
303            && requested.max_compute_invocations_per_workgroup
304                <= adapter.max_compute_invocations_per_workgroup
305    }
306
307    /// Get the WGPU device.
308    pub fn device(&self) -> &Device {
309        &self.device
310    }
311
312    /// Get the WGPU queue.
313    pub fn queue(&self) -> &Queue {
314        &self.queue
315    }
316
317    /// Get the WGPU adapter.
318    pub fn adapter(&self) -> &Adapter {
319        &self.adapter
320    }
321
322    /// Get the WGPU instance.
323    pub fn instance(&self) -> &Instance {
324        &self.instance
325    }
326
327    /// Get adapter information.
328    pub fn adapter_info(&self) -> &AdapterInfo {
329        &self.adapter_info
330    }
331
332    /// Get device limits.
333    pub fn limits(&self) -> &Limits {
334        &self.limits
335    }
336
337    /// Get the backend being used.
338    pub fn backend(&self) -> Backend {
339        self.adapter_info.backend
340    }
341
342    /// Check if the device supports a specific feature.
343    pub fn supports_feature(&self, feature: Features) -> bool {
344        self.device.features().contains(feature)
345    }
346
347    /// Get maximum workgroup size for compute shaders.
348    pub fn max_workgroup_size(&self) -> (u32, u32, u32) {
349        (
350            self.limits.max_compute_workgroup_size_x,
351            self.limits.max_compute_workgroup_size_y,
352            self.limits.max_compute_workgroup_size_z,
353        )
354    }
355
356    /// Poll the device for completed operations.
357    ///
358    /// This should be called periodically to process GPU operations.
359    pub fn poll(&self, _wait: bool) {
360        // wgpu 28 doesn't have explicit poll control, device polls automatically
361        // This method is kept for API compatibility
362    }
363
364    /// Check if the device is still valid.
365    pub fn is_valid(&self) -> bool {
366        // Try to create a small buffer as a health check
367        self.device.create_buffer(&wgpu::BufferDescriptor {
368            label: Some("health_check"),
369            size: 4,
370            usage: wgpu::BufferUsages::UNIFORM,
371            mapped_at_creation: false,
372        });
373        true
374    }
375}
376
377impl std::fmt::Debug for GpuContext {
378    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
379        f.debug_struct("GpuContext")
380            .field("adapter", &self.adapter_info.name)
381            .field("backend", &self.adapter_info.backend)
382            .field("device_type", &self.adapter_info.device_type)
383            .field("limits", &self.limits)
384            .finish()
385    }
386}
387
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test: context creation either succeeds and reports valid,
    /// or fails gracefully when no GPU is present (expected in CI).
    #[tokio::test]
    async fn test_gpu_context_creation() {
        match GpuContext::new().await {
            Err(e) => println!("GPU not available (expected in CI): {}", e),
            Ok(ctx) => {
                println!("GPU Context created: {:?}", ctx);
                assert!(ctx.is_valid());
            }
        }
    }

    /// The platform-default backend should be usable for context creation
    /// when a GPU exists; otherwise the error path is exercised.
    #[tokio::test]
    async fn test_backend_preference() {
        let config = GpuContextConfig::new()
            .with_backend(BackendPreference::platform_default())
            .with_power_preference(GpuPowerPreference::HighPerformance);

        match GpuContext::with_config(config).await {
            Err(e) => println!("GPU not available: {}", e),
            Ok(ctx) => println!("Backend: {:?}", ctx.backend()),
        }
    }

    /// Each explicit preference maps to its matching wgpu backend flag.
    #[test]
    fn test_backend_conversion() {
        let cases = [
            (BackendPreference::Vulkan, Backends::VULKAN),
            (BackendPreference::Metal, Backends::METAL),
            (BackendPreference::DX12, Backends::DX12),
        ];
        for (pref, expected) in cases {
            assert_eq!(pref.to_backends(), expected);
        }
    }

    /// The platform default matches the compile-time target OS.
    #[test]
    fn test_platform_default() {
        let default = BackendPreference::platform_default();
        println!("Platform default backend: {:?}", default);

        #[cfg(target_os = "macos")]
        assert_eq!(default, BackendPreference::Metal);

        #[cfg(target_os = "windows")]
        assert_eq!(default, BackendPreference::DX12);

        #[cfg(target_os = "linux")]
        assert_eq!(default, BackendPreference::Vulkan);
    }
}