// oxigdal_gpu_advanced/multi_gpu/device_manager.rs

use std::cmp::Reverse;

use wgpu::{Adapter, Backend, DeviceType};

use crate::error::{GpuAdvancedError, Result};

/// Snapshot of the optional capabilities reported by a single GPU adapter.
///
/// Usually built with [`DeviceCapabilities::from_adapter`]; each field is a plain
/// yes/no flag for one capability.
#[derive(Debug, Clone)]
pub struct DeviceCapabilities {
    /// Compute shader support.
    pub compute: bool,
    /// GPU timestamp queries (`wgpu::Features::TIMESTAMP_QUERY`).
    pub timestamp_query: bool,
    /// Pipeline statistics queries (`wgpu::Features::PIPELINE_STATISTICS_QUERY`).
    pub pipeline_statistics: bool,
    /// BC block-compressed textures (`wgpu::Features::TEXTURE_COMPRESSION_BC`).
    pub texture_compression_bc: bool,
    /// ETC2 compressed textures (`wgpu::Features::TEXTURE_COMPRESSION_ETC2`).
    pub texture_compression_etc2: bool,
    /// ASTC compressed textures (`wgpu::Features::TEXTURE_COMPRESSION_ASTC`).
    pub texture_compression_astc: bool,
    /// Non-zero `first_instance` in indirect draws (`wgpu::Features::INDIRECT_FIRST_INSTANCE`).
    pub indirect_first_instance: bool,
    /// 16-bit floats in shaders (`wgpu::Features::SHADER_F16`).
    pub shader_f16: bool,
    /// Push-constant support.
    pub push_constants: bool,
    /// Multi-draw indirect support.
    pub multi_draw_indirect: bool,
    /// Multi-draw indirect with a GPU-side draw count.
    pub multi_draw_indirect_count: bool,
}

33impl DeviceCapabilities {
34 pub fn from_adapter(adapter: &Adapter) -> Self {
36 let features = adapter.features();
37
38 Self {
39 compute: true, timestamp_query: features.contains(wgpu::Features::TIMESTAMP_QUERY),
41 pipeline_statistics: features.contains(wgpu::Features::PIPELINE_STATISTICS_QUERY),
42 texture_compression_bc: features.contains(wgpu::Features::TEXTURE_COMPRESSION_BC),
43 texture_compression_etc2: features.contains(wgpu::Features::TEXTURE_COMPRESSION_ETC2),
44 texture_compression_astc: features.contains(wgpu::Features::TEXTURE_COMPRESSION_ASTC),
45 indirect_first_instance: features.contains(wgpu::Features::INDIRECT_FIRST_INSTANCE),
46 shader_f16: features.contains(wgpu::Features::SHADER_F16),
47 push_constants: true,
50 multi_draw_indirect: features.contains(wgpu::Features::MULTI_DRAW_INDIRECT_COUNT),
52 multi_draw_indirect_count: features.contains(wgpu::Features::MULTI_DRAW_INDIRECT_COUNT),
53 }
54 }
55
56 pub fn supports_required_features(&self) -> bool {
58 self.compute
59 }
60
61 pub fn supports_optimal_features(&self) -> bool {
63 self.compute && self.timestamp_query && self.push_constants
64 }
65}
66
/// Coarse performance tier assigned to an adapter.
///
/// Variants are declared from slowest to fastest. The derived `Ord` therefore ranks
/// `Low < Medium < High < Extreme`, and comparisons such as "at least `High`" rely on
/// this declaration order, so keep it when adding variants.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum DevicePerformanceClass {
    /// Lowest tier.
    Low,
    /// Baseline tier.
    Medium,
    /// Above-baseline tier.
    High,
    /// Highest tier.
    Extreme,
}

80impl DevicePerformanceClass {
81 pub fn classify(device_type: DeviceType, _backend: Backend, limits: &wgpu::Limits) -> Self {
83 match device_type {
85 DeviceType::DiscreteGpu => {
86 if limits.max_compute_workgroup_size_x >= 1024
88 && limits.max_buffer_size >= 4_000_000_000
89 {
90 Self::Extreme
91 } else if limits.max_compute_workgroup_size_x >= 512 {
92 Self::High
93 } else {
94 Self::Medium
95 }
96 }
97 DeviceType::IntegratedGpu => {
98 if limits.max_compute_workgroup_size_x >= 512 {
99 Self::Medium
100 } else {
101 Self::Low
102 }
103 }
104 DeviceType::VirtualGpu => Self::Medium,
105 DeviceType::Cpu | DeviceType::Other => Self::Low,
106 }
107 }
108
109 pub fn workload_multiplier(&self) -> f32 {
111 match self {
112 Self::Extreme => 4.0,
113 Self::High => 2.0,
114 Self::Medium => 1.0,
115 Self::Low => 0.5,
116 }
117 }
118}
119
120#[derive(Debug, Clone)]
122pub struct DeviceFilter {
123 pub min_performance: Option<DevicePerformanceClass>,
125 pub preferred_backend: Option<Backend>,
127 pub required_type: Option<DeviceType>,
129 pub min_memory: Option<u64>,
131 pub required_features: wgpu::Features,
133}
134
135impl Default for DeviceFilter {
136 fn default() -> Self {
137 Self {
138 min_performance: None,
139 preferred_backend: None,
140 required_type: None,
141 min_memory: None,
142 required_features: wgpu::Features::empty(),
143 }
144 }
145}
146
147impl DeviceFilter {
148 pub fn new() -> Self {
150 Self::default()
151 }
152
153 pub fn with_min_performance(mut self, perf: DevicePerformanceClass) -> Self {
155 self.min_performance = Some(perf);
156 self
157 }
158
159 pub fn with_preferred_backend(mut self, backend: Backend) -> Self {
161 self.preferred_backend = Some(backend);
162 self
163 }
164
165 pub fn with_required_type(mut self, device_type: DeviceType) -> Self {
167 self.required_type = Some(device_type);
168 self
169 }
170
171 pub fn with_min_memory(mut self, memory: u64) -> Self {
173 self.min_memory = Some(memory);
174 self
175 }
176
177 pub fn with_required_features(mut self, features: wgpu::Features) -> Self {
179 self.required_features = features;
180 self
181 }
182
183 pub fn matches(&self, adapter: &Adapter) -> bool {
185 let info = adapter.get_info();
186 let limits = adapter.limits();
187
188 if let Some(req_type) = self.required_type {
190 if info.device_type != req_type {
191 return false;
192 }
193 }
194
195 if let Some(pref_backend) = self.preferred_backend {
197 if info.backend != pref_backend {
198 return false;
199 }
200 }
201
202 if let Some(min_perf) = self.min_performance {
204 let perf_class =
205 DevicePerformanceClass::classify(info.device_type, info.backend, &limits);
206 if perf_class < min_perf {
207 return false;
208 }
209 }
210
211 if let Some(min_mem) = self.min_memory {
213 if limits.max_buffer_size < min_mem {
214 return false;
215 }
216 }
217
218 let features = adapter.features();
220 if !features.contains(self.required_features) {
221 return false;
222 }
223
224 true
225 }
226}
227
228pub struct DeviceManager {
230 instance: wgpu::Instance,
232}
233
234impl DeviceManager {
235 pub fn new() -> Self {
237 let instance = wgpu::Instance::new(wgpu::InstanceDescriptor {
238 backends: wgpu::Backends::all(),
239 ..wgpu::InstanceDescriptor::new_without_display_handle()
240 });
241
242 Self { instance }
243 }
244
245 pub async fn enumerate_adapters(&self) -> Vec<Adapter> {
247 self.instance
248 .enumerate_adapters(wgpu::Backends::all())
249 .await
250 }
251
252 pub async fn enumerate_filtered(&self, filter: &DeviceFilter) -> Vec<Adapter> {
254 self.enumerate_adapters()
255 .await
256 .into_iter()
257 .filter(|adapter| filter.matches(adapter))
258 .collect()
259 }
260
261 pub async fn get_best_adapter(&self, filter: &DeviceFilter) -> Result<Adapter> {
263 let mut adapters = self.enumerate_filtered(filter).await;
264
265 if adapters.is_empty() {
266 return Err(GpuAdvancedError::GpuNotFound(
267 "No adapter matching filter".to_string(),
268 ));
269 }
270
271 adapters.sort_by(|a, b| {
273 let info_a = a.get_info();
274 let info_b = b.get_info();
275 let limits_a = a.limits();
276 let limits_b = b.limits();
277
278 let perf_a =
279 DevicePerformanceClass::classify(info_a.device_type, info_a.backend, &limits_a);
280 let perf_b =
281 DevicePerformanceClass::classify(info_b.device_type, info_b.backend, &limits_b);
282
283 perf_b.cmp(&perf_a)
284 });
285
286 adapters
287 .into_iter()
288 .next()
289 .ok_or_else(|| GpuAdvancedError::GpuNotFound("No adapter available".to_string()))
290 }
291
292 pub async fn print_device_info(&self) {
294 let adapters = self.enumerate_adapters().await;
295 println!("Found {} GPU adapter(s):", adapters.len());
296
297 for (i, adapter) in adapters.iter().enumerate() {
298 let info = adapter.get_info();
299 let limits = adapter.limits();
300 let features = adapter.features();
301 let caps = DeviceCapabilities::from_adapter(adapter);
302 let perf_class =
303 DevicePerformanceClass::classify(info.device_type, info.backend, &limits);
304
305 println!("\n Adapter {}:", i);
306 println!(" Name: {}", info.name);
307 println!(" Backend: {:?}", info.backend);
308 println!(" Device Type: {:?}", info.device_type);
309 println!(" Performance Class: {:?}", perf_class);
310 println!(" Max Buffer Size: {} bytes", limits.max_buffer_size);
311 println!(
312 " Max Texture 2D: {}x{}",
313 limits.max_texture_dimension_2d, limits.max_texture_dimension_2d
314 );
315 println!(
316 " Max Workgroup Size: {}x{}x{}",
317 limits.max_compute_workgroup_size_x,
318 limits.max_compute_workgroup_size_y,
319 limits.max_compute_workgroup_size_z
320 );
321 println!(" Features: {:?}", features.bits());
322 println!(" Timestamp Query: {}", caps.timestamp_query);
323 println!(" Push Constants: {}", caps.push_constants);
324 }
325 }
326}
327
328impl Default for DeviceManager {
329 fn default() -> Self {
330 Self::new()
331 }
332}
333
#[cfg(test)]
mod tests {
    use super::*;

    /// Default limits with the two fields `classify` inspects overridden.
    fn limits(workgroup_x: u32, max_buffer_size: u64) -> wgpu::Limits {
        let mut limits = wgpu::Limits::default();
        limits.max_compute_workgroup_size_x = workgroup_x;
        limits.max_buffer_size = max_buffer_size;
        limits
    }

    #[tokio::test]
    async fn test_device_manager_creation() {
        let manager = DeviceManager::new();
        let adapters = manager.enumerate_adapters().await;
        println!("Found {} adapters", adapters.len());
    }

    #[tokio::test]
    async fn test_device_filter() {
        let filter = DeviceFilter::new().with_min_performance(DevicePerformanceClass::Low);

        let manager = DeviceManager::new();
        let all = manager.enumerate_adapters().await;
        let adapters = manager.enumerate_filtered(&filter).await;
        println!("Found {} filtered adapters", adapters.len());

        // `Low` is the bottom tier and no features are required, so nothing is rejected.
        assert_eq!(adapters.len(), all.len());
    }

    #[tokio::test]
    async fn test_performance_classification() {
        let manager = DeviceManager::new();
        let adapters = manager.enumerate_adapters().await;

        for adapter in adapters {
            let info = adapter.get_info();
            let limits = adapter.limits();
            let perf_class =
                DevicePerformanceClass::classify(info.device_type, info.backend, &limits);
            println!("{}: {:?}", info.name, perf_class);

            // A filter whose floor is the adapter's own tier must accept it.
            let filter = DeviceFilter::new().with_min_performance(perf_class);
            assert!(filter.matches(&adapter));
        }
    }

    #[tokio::test]
    async fn test_best_adapter_has_top_class() {
        let manager = DeviceManager::new();
        let filter = DeviceFilter::new();
        let adapters = manager.enumerate_adapters().await;

        let class_of = |adapter: &Adapter| {
            let info = adapter.get_info();
            DevicePerformanceClass::classify(info.device_type, info.backend, &adapter.limits())
        };

        match manager.get_best_adapter(&filter).await {
            Ok(best) => {
                let best_class = class_of(&best);
                assert!(adapters.iter().all(|adapter| class_of(adapter) <= best_class));
            }
            Err(_) => assert!(adapters.is_empty()),
        }
    }

    #[test]
    fn test_classify_thresholds() {
        let classify = |device_type, limits: &wgpu::Limits| {
            DevicePerformanceClass::classify(device_type, Backend::Vulkan, limits)
        };

        let huge = limits(1024, 8_000_000_000);
        let wide_small_buffers = limits(1024, 1 << 30);
        let mid = limits(512, 1 << 30);
        let narrow = limits(256, 1 << 30);

        assert_eq!(classify(DeviceType::DiscreteGpu, &huge), DevicePerformanceClass::Extreme);
        assert_eq!(
            classify(DeviceType::DiscreteGpu, &wide_small_buffers),
            DevicePerformanceClass::High
        );
        assert_eq!(classify(DeviceType::DiscreteGpu, &mid), DevicePerformanceClass::High);
        assert_eq!(classify(DeviceType::DiscreteGpu, &narrow), DevicePerformanceClass::Medium);
        assert_eq!(classify(DeviceType::IntegratedGpu, &mid), DevicePerformanceClass::Medium);
        assert_eq!(classify(DeviceType::IntegratedGpu, &narrow), DevicePerformanceClass::Low);
        assert_eq!(classify(DeviceType::VirtualGpu, &narrow), DevicePerformanceClass::Medium);
        assert_eq!(classify(DeviceType::Cpu, &huge), DevicePerformanceClass::Low);
        assert_eq!(classify(DeviceType::Other, &huge), DevicePerformanceClass::Low);
    }

    #[test]
    fn test_class_ordering_and_multipliers() {
        let classes = [
            DevicePerformanceClass::Low,
            DevicePerformanceClass::Medium,
            DevicePerformanceClass::High,
            DevicePerformanceClass::Extreme,
        ];
        for pair in classes.windows(2) {
            assert!(pair[0] < pair[1]);
            assert!(pair[0].workload_multiplier() < pair[1].workload_multiplier());
        }
        assert!((DevicePerformanceClass::Medium.workload_multiplier() - 1.0).abs() < f32::EPSILON);
    }

    #[test]
    fn test_filter_builder_and_default() {
        let unconstrained = DeviceFilter::default();
        assert!(unconstrained.min_performance.is_none());
        assert!(unconstrained.preferred_backend.is_none());
        assert!(unconstrained.required_type.is_none());
        assert!(unconstrained.min_memory.is_none());
        assert!(unconstrained.required_features.is_empty());

        let filter = DeviceFilter::new()
            .with_min_performance(DevicePerformanceClass::High)
            .with_preferred_backend(Backend::Vulkan)
            .with_required_type(DeviceType::DiscreteGpu)
            .with_min_memory(1 << 30)
            .with_required_features(wgpu::Features::TIMESTAMP_QUERY);
        assert_eq!(filter.min_performance, Some(DevicePerformanceClass::High));
        assert_eq!(filter.preferred_backend, Some(Backend::Vulkan));
        assert_eq!(filter.required_type, Some(DeviceType::DiscreteGpu));
        assert_eq!(filter.min_memory, Some(1 << 30));
        assert_eq!(filter.required_features, wgpu::Features::TIMESTAMP_QUERY);
    }
}