// oxigdal_gpu_advanced/multi_gpu/device_manager.rs
//! GPU device management and capabilities detection.

use std::cmp::Reverse;

use crate::error::{GpuAdvancedError, Result};
use wgpu::{Adapter, Backend, DeviceType};

/// Feature capabilities detected for a single GPU adapter.
///
/// Populated via [`DeviceCapabilities::from_adapter`]; each flag mirrors a
/// `wgpu::Features` bit, or a capability that is always available in core WGPU.
#[derive(Debug, Clone)]
pub struct DeviceCapabilities {
    /// Supports compute shaders (always `true`: compute is core WGPU).
    pub compute: bool,
    /// Supports timestamp queries (`Features::TIMESTAMP_QUERY`).
    pub timestamp_query: bool,
    /// Supports pipeline statistics queries (`Features::PIPELINE_STATISTICS_QUERY`).
    pub pipeline_statistics: bool,
    /// Supports BC texture compression.
    pub texture_compression_bc: bool,
    /// Supports ETC2 texture compression.
    pub texture_compression_etc2: bool,
    /// Supports ASTC texture compression.
    pub texture_compression_astc: bool,
    /// Supports a non-zero first instance in indirect draws.
    pub indirect_first_instance: bool,
    /// Supports 16-bit floats in shaders (`Features::SHADER_F16`).
    pub shader_f16: bool,
    /// Supports push constants (reported as always available; see `from_adapter`).
    pub push_constants: bool,
    /// Supports multi draw indirect (keyed off `MULTI_DRAW_INDIRECT_COUNT`; see `from_adapter`).
    pub multi_draw_indirect: bool,
    /// Supports multi draw indirect count.
    pub multi_draw_indirect_count: bool,
}

33impl DeviceCapabilities {
34    /// Detect capabilities from adapter
35    pub fn from_adapter(adapter: &Adapter) -> Self {
36        let features = adapter.features();
37
38        Self {
39            compute: true, // Always available in WGPU
40            timestamp_query: features.contains(wgpu::Features::TIMESTAMP_QUERY),
41            pipeline_statistics: features.contains(wgpu::Features::PIPELINE_STATISTICS_QUERY),
42            texture_compression_bc: features.contains(wgpu::Features::TEXTURE_COMPRESSION_BC),
43            texture_compression_etc2: features.contains(wgpu::Features::TEXTURE_COMPRESSION_ETC2),
44            texture_compression_astc: features.contains(wgpu::Features::TEXTURE_COMPRESSION_ASTC),
45            indirect_first_instance: features.contains(wgpu::Features::INDIRECT_FIRST_INSTANCE),
46            shader_f16: features.contains(wgpu::Features::SHADER_F16),
47            // Note: PUSH_CONSTANTS feature was removed in newer WGPU versions
48            // Push constants are now part of the core API
49            push_constants: true,
50            // MULTI_DRAW_INDIRECT was replaced with MULTI_DRAW_INDIRECT_COUNT
51            multi_draw_indirect: features.contains(wgpu::Features::MULTI_DRAW_INDIRECT_COUNT),
52            multi_draw_indirect_count: features.contains(wgpu::Features::MULTI_DRAW_INDIRECT_COUNT),
53        }
54    }
55
56    /// Check if device supports all required features
57    pub fn supports_required_features(&self) -> bool {
58        self.compute
59    }
60
61    /// Check if device supports optimal features
62    pub fn supports_optimal_features(&self) -> bool {
63        self.compute && self.timestamp_query && self.push_constants
64    }
65}
66
/// Coarse performance tier for a GPU device.
///
/// Variant order matters: `Ord` is derived, so `Low < Medium < High < Extreme`,
/// which minimum-performance filtering relies on.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum DevicePerformanceClass {
    /// Low-end device (integrated GPU, mobile, or CPU fallback).
    Low,
    /// Mid-range device (mainstream discrete GPU).
    Medium,
    /// High-end device (high-performance discrete GPU).
    High,
    /// Extreme performance (datacenter GPU).
    Extreme,
}

80impl DevicePerformanceClass {
81    /// Classify device based on characteristics
82    pub fn classify(device_type: DeviceType, _backend: Backend, limits: &wgpu::Limits) -> Self {
83        // Simple heuristic based on device type and limits
84        match device_type {
85            DeviceType::DiscreteGpu => {
86                // Check memory and compute capabilities
87                if limits.max_compute_workgroup_size_x >= 1024
88                    && limits.max_buffer_size >= 4_000_000_000
89                {
90                    Self::Extreme
91                } else if limits.max_compute_workgroup_size_x >= 512 {
92                    Self::High
93                } else {
94                    Self::Medium
95                }
96            }
97            DeviceType::IntegratedGpu => {
98                if limits.max_compute_workgroup_size_x >= 512 {
99                    Self::Medium
100                } else {
101                    Self::Low
102                }
103            }
104            DeviceType::VirtualGpu => Self::Medium,
105            DeviceType::Cpu | DeviceType::Other => Self::Low,
106        }
107    }
108
109    /// Get workload multiplier for this performance class
110    pub fn workload_multiplier(&self) -> f32 {
111        match self {
112            Self::Extreme => 4.0,
113            Self::High => 2.0,
114            Self::Medium => 1.0,
115            Self::Low => 0.5,
116        }
117    }
118}
119
/// Criteria for selecting GPU adapters.
///
/// Build with [`DeviceFilter::new`] plus the `with_*` builder methods and apply
/// with [`DeviceFilter::matches`]. A `None` / empty field means "no constraint".
#[derive(Debug, Clone)]
pub struct DeviceFilter {
    /// Minimum performance class the adapter must reach.
    pub min_performance: Option<DevicePerformanceClass>,
    /// Backend the adapter must use (despite the name, `matches` treats this
    /// as a hard requirement, not a preference).
    pub preferred_backend: Option<Backend>,
    /// Device type the adapter must report.
    pub required_type: Option<DeviceType>,
    /// Minimum memory in bytes, compared against `Limits::max_buffer_size`.
    pub min_memory: Option<u64>,
    /// Feature set the adapter must support in full.
    pub required_features: wgpu::Features,
}

135impl Default for DeviceFilter {
136    fn default() -> Self {
137        Self {
138            min_performance: None,
139            preferred_backend: None,
140            required_type: None,
141            min_memory: None,
142            required_features: wgpu::Features::empty(),
143        }
144    }
145}
146
147impl DeviceFilter {
148    /// Create a new device filter
149    pub fn new() -> Self {
150        Self::default()
151    }
152
153    /// Set minimum performance class
154    pub fn with_min_performance(mut self, perf: DevicePerformanceClass) -> Self {
155        self.min_performance = Some(perf);
156        self
157    }
158
159    /// Set preferred backend
160    pub fn with_preferred_backend(mut self, backend: Backend) -> Self {
161        self.preferred_backend = Some(backend);
162        self
163    }
164
165    /// Set required device type
166    pub fn with_required_type(mut self, device_type: DeviceType) -> Self {
167        self.required_type = Some(device_type);
168        self
169    }
170
171    /// Set minimum memory
172    pub fn with_min_memory(mut self, memory: u64) -> Self {
173        self.min_memory = Some(memory);
174        self
175    }
176
177    /// Set required features
178    pub fn with_required_features(mut self, features: wgpu::Features) -> Self {
179        self.required_features = features;
180        self
181    }
182
183    /// Check if adapter matches filter
184    pub fn matches(&self, adapter: &Adapter) -> bool {
185        let info = adapter.get_info();
186        let limits = adapter.limits();
187
188        // Check device type
189        if let Some(req_type) = self.required_type {
190            if info.device_type != req_type {
191                return false;
192            }
193        }
194
195        // Check backend
196        if let Some(pref_backend) = self.preferred_backend {
197            if info.backend != pref_backend {
198                return false;
199            }
200        }
201
202        // Check performance class
203        if let Some(min_perf) = self.min_performance {
204            let perf_class =
205                DevicePerformanceClass::classify(info.device_type, info.backend, &limits);
206            if perf_class < min_perf {
207                return false;
208            }
209        }
210
211        // Check minimum memory
212        if let Some(min_mem) = self.min_memory {
213            if limits.max_buffer_size < min_mem {
214                return false;
215            }
216        }
217
218        // Check features
219        let features = adapter.features();
220        if !features.contains(self.required_features) {
221            return false;
222        }
223
224        true
225    }
226}
227
/// Device manager for GPU enumeration and selection.
///
/// Owns a `wgpu::Instance` configured for all backends; every adapter query
/// goes through that instance.
pub struct DeviceManager {
    /// WGPU instance used for adapter enumeration.
    instance: wgpu::Instance,
}

234impl DeviceManager {
235    /// Create a new device manager
236    pub fn new() -> Self {
237        let instance = wgpu::Instance::new(wgpu::InstanceDescriptor {
238            backends: wgpu::Backends::all(),
239            ..wgpu::InstanceDescriptor::new_without_display_handle()
240        });
241
242        Self { instance }
243    }
244
245    /// Enumerate all adapters
246    pub async fn enumerate_adapters(&self) -> Vec<Adapter> {
247        self.instance
248            .enumerate_adapters(wgpu::Backends::all())
249            .await
250    }
251
252    /// Enumerate adapters matching filter
253    pub async fn enumerate_filtered(&self, filter: &DeviceFilter) -> Vec<Adapter> {
254        self.enumerate_adapters()
255            .await
256            .into_iter()
257            .filter(|adapter| filter.matches(adapter))
258            .collect()
259    }
260
261    /// Get best adapter matching filter
262    pub async fn get_best_adapter(&self, filter: &DeviceFilter) -> Result<Adapter> {
263        let mut adapters = self.enumerate_filtered(filter).await;
264
265        if adapters.is_empty() {
266            return Err(GpuAdvancedError::GpuNotFound(
267                "No adapter matching filter".to_string(),
268            ));
269        }
270
271        // Sort by performance class and backend
272        adapters.sort_by(|a, b| {
273            let info_a = a.get_info();
274            let info_b = b.get_info();
275            let limits_a = a.limits();
276            let limits_b = b.limits();
277
278            let perf_a =
279                DevicePerformanceClass::classify(info_a.device_type, info_a.backend, &limits_a);
280            let perf_b =
281                DevicePerformanceClass::classify(info_b.device_type, info_b.backend, &limits_b);
282
283            perf_b.cmp(&perf_a)
284        });
285
286        adapters
287            .into_iter()
288            .next()
289            .ok_or_else(|| GpuAdvancedError::GpuNotFound("No adapter available".to_string()))
290    }
291
292    /// Print device information
293    pub async fn print_device_info(&self) {
294        let adapters = self.enumerate_adapters().await;
295        println!("Found {} GPU adapter(s):", adapters.len());
296
297        for (i, adapter) in adapters.iter().enumerate() {
298            let info = adapter.get_info();
299            let limits = adapter.limits();
300            let features = adapter.features();
301            let caps = DeviceCapabilities::from_adapter(adapter);
302            let perf_class =
303                DevicePerformanceClass::classify(info.device_type, info.backend, &limits);
304
305            println!("\n  Adapter {}:", i);
306            println!("    Name: {}", info.name);
307            println!("    Backend: {:?}", info.backend);
308            println!("    Device Type: {:?}", info.device_type);
309            println!("    Performance Class: {:?}", perf_class);
310            println!("    Max Buffer Size: {} bytes", limits.max_buffer_size);
311            println!(
312                "    Max Texture 2D: {}x{}",
313                limits.max_texture_dimension_2d, limits.max_texture_dimension_2d
314            );
315            println!(
316                "    Max Workgroup Size: {}x{}x{}",
317                limits.max_compute_workgroup_size_x,
318                limits.max_compute_workgroup_size_y,
319                limits.max_compute_workgroup_size_z
320            );
321            println!("    Features: {:?}", features.bits());
322            println!("    Timestamp Query: {}", caps.timestamp_query);
323            println!("    Push Constants: {}", caps.push_constants);
324        }
325    }
326}
327
328impl Default for DeviceManager {
329    fn default() -> Self {
330        Self::new()
331    }
332}
333
#[cfg(test)]
mod tests {
    use super::*;

    // Smoke test: construction and enumeration must not panic.
    #[tokio::test]
    async fn test_device_manager_creation() {
        let mgr = DeviceManager::new();
        let found = mgr.enumerate_adapters().await;
        println!("Found {} adapters", found.len());
    }

    // A permissive filter (minimum class Low) should run without panicking.
    #[tokio::test]
    async fn test_device_filter() {
        let permissive = DeviceFilter::new().with_min_performance(DevicePerformanceClass::Low);

        let mgr = DeviceManager::new();
        let found = mgr.enumerate_filtered(&permissive).await;
        println!("Found {} filtered adapters", found.len());
    }

    // Classification should produce a tier for every visible adapter.
    #[tokio::test]
    async fn test_performance_classification() {
        let mgr = DeviceManager::new();
        for adapter in mgr.enumerate_adapters().await {
            let info = adapter.get_info();
            let limits = adapter.limits();
            let perf_class =
                DevicePerformanceClass::classify(info.device_type, info.backend, &limits);
            println!("{}: {:?}", info.name, perf_class);
        }
    }
}