// oxigdal_gpu_advanced/multi_gpu/mod.rs
1//! Multi-GPU orchestration and management.
2//!
3//! This module provides advanced multi-GPU support including:
4//! - Automatic GPU detection and selection
5//! - Load balancing across GPUs
6//! - Work stealing between GPUs
7//! - GPU affinity and pinning
8
9pub mod affinity;
10pub mod device_manager;
11pub mod load_balancer;
12pub mod sync;
13pub mod work_queue;
14
15use crate::error::{GpuAdvancedError, Result};
16use dashmap::DashMap;
17use parking_lot::RwLock;
18use std::sync::Arc;
19use wgpu::{Adapter, Device, Queue};
20
21/// GPU device information
22#[derive(Debug, Clone)]
23pub struct GpuDeviceInfo {
24    /// Device index
25    pub index: usize,
26    /// Device name
27    pub name: String,
28    /// Backend type (Vulkan, Metal, DX12, etc.)
29    pub backend: wgpu::Backend,
30    /// Device type (DiscreteGpu, IntegratedGpu, VirtualGpu, Cpu)
31    pub device_type: wgpu::DeviceType,
32    /// Maximum buffer size
33    pub max_buffer_size: u64,
34    /// Maximum texture dimension 1D
35    pub max_texture_dimension_1d: u32,
36    /// Maximum texture dimension 2D
37    pub max_texture_dimension_2d: u32,
38    /// Maximum texture dimension 3D
39    pub max_texture_dimension_3d: u32,
40    /// Maximum compute workgroup size X
41    pub max_compute_workgroup_size_x: u32,
42    /// Maximum compute workgroup size Y
43    pub max_compute_workgroup_size_y: u32,
44    /// Maximum compute workgroup size Z
45    pub max_compute_workgroup_size_z: u32,
46    /// Maximum compute workgroups per dimension
47    pub max_compute_workgroups_per_dimension: u32,
48    /// Maximum bind groups
49    pub max_bind_groups: u32,
50    /// Memory size (estimated)
51    pub memory_size: Option<u64>,
52}
53
/// GPU device with associated resources.
///
/// Bundles the wgpu handles for one adapter together with two manually
/// maintained bookkeeping counters (`memory_usage`, `workload`) that the
/// load balancer reads when selecting a device. The counters are only as
/// accurate as the callers that update them via `update_memory_usage` /
/// `set_workload`.
pub struct GpuDevice {
    /// Device info snapshot captured at creation time
    pub info: GpuDeviceInfo,
    /// WGPU adapter
    pub adapter: Arc<Adapter>,
    /// WGPU device
    pub device: Arc<Device>,
    /// WGPU queue
    pub queue: Arc<Queue>,
    /// Current tracked memory usage in bytes (caller-maintained, not queried
    /// from the driver)
    pub memory_usage: Arc<RwLock<u64>>,
    /// Current workload, clamped to 0.0..=1.0 by `set_workload`
    pub workload: Arc<RwLock<f32>>,
}
69
70impl GpuDevice {
71    /// Create a new GPU device
72    pub fn new(index: usize, adapter: Adapter, device: Device, queue: Queue) -> Result<Self> {
73        let info = adapter.get_info();
74        let limits = device.limits();
75
76        let device_info = GpuDeviceInfo {
77            index,
78            name: info.name.clone(),
79            backend: info.backend,
80            device_type: info.device_type,
81            max_buffer_size: limits.max_buffer_size,
82            max_texture_dimension_1d: limits.max_texture_dimension_1d,
83            max_texture_dimension_2d: limits.max_texture_dimension_2d,
84            max_texture_dimension_3d: limits.max_texture_dimension_3d,
85            max_compute_workgroup_size_x: limits.max_compute_workgroup_size_x,
86            max_compute_workgroup_size_y: limits.max_compute_workgroup_size_y,
87            max_compute_workgroup_size_z: limits.max_compute_workgroup_size_z,
88            max_compute_workgroups_per_dimension: limits.max_compute_workgroups_per_dimension,
89            max_bind_groups: limits.max_bind_groups,
90            memory_size: None, // Could be estimated from limits
91        };
92
93        Ok(Self {
94            info: device_info,
95            adapter: Arc::new(adapter),
96            device: Arc::new(device),
97            queue: Arc::new(queue),
98            memory_usage: Arc::new(RwLock::new(0)),
99            workload: Arc::new(RwLock::new(0.0)),
100        })
101    }
102
103    /// Get current memory usage
104    pub fn get_memory_usage(&self) -> u64 {
105        *self.memory_usage.read()
106    }
107
108    /// Update memory usage
109    pub fn update_memory_usage(&self, delta: i64) {
110        let mut usage = self.memory_usage.write();
111        if delta >= 0 {
112            *usage = usage.saturating_add(delta as u64);
113        } else {
114            *usage = usage.saturating_sub((-delta) as u64);
115        }
116    }
117
118    /// Get current workload
119    pub fn get_workload(&self) -> f32 {
120        *self.workload.read()
121    }
122
123    /// Set workload
124    pub fn set_workload(&self, workload: f32) {
125        *self.workload.write() = workload.clamp(0.0, 1.0);
126    }
127
128    /// Check if device is available (low workload)
129    pub fn is_available(&self) -> bool {
130        self.get_workload() < 0.8
131    }
132
133    /// Get device score for selection (higher is better)
134    pub fn get_score(&self) -> f32 {
135        // Score based on device type and current workload
136        let type_score = match self.info.device_type {
137            wgpu::DeviceType::DiscreteGpu => 1.0,
138            wgpu::DeviceType::IntegratedGpu => 0.7,
139            wgpu::DeviceType::VirtualGpu => 0.5,
140            wgpu::DeviceType::Cpu => 0.3,
141            wgpu::DeviceType::Other => 0.1,
142        };
143
144        let workload = self.get_workload();
145        type_score * (1.0 - workload)
146    }
147}
148
/// Multi-GPU manager.
///
/// Owns the enumerated devices, one work queue per device, and the load
/// balancer that picks a device for each submitted task.
pub struct MultiGpuManager {
    /// Available GPU devices, indexed by `GpuDeviceInfo::index`
    devices: Vec<Arc<GpuDevice>>,
    /// Device selection strategy (reserved for future dynamic strategy switching)
    #[allow(dead_code)]
    strategy: SelectionStrategy,
    /// Work queues per device, keyed by device index
    work_queues: DashMap<usize, Arc<work_queue::WorkQueue>>,
    /// Load balancer used by `select_gpu` / `submit_work`
    load_balancer: Arc<load_balancer::LoadBalancer>,
}
161
/// Device selection strategy used by the load balancer.
///
/// Defaults to [`SelectionStrategy::RoundRobin`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum SelectionStrategy {
    /// Round-robin selection
    #[default]
    RoundRobin,
    /// Select least loaded device
    LeastLoaded,
    /// Select device with highest score
    BestScore,
    /// Affinity-based selection
    Affinity,
}
174
175impl MultiGpuManager {
176    /// Create a new multi-GPU manager
177    pub async fn new(strategy: SelectionStrategy) -> Result<Self> {
178        let devices = Self::enumerate_devices().await?;
179
180        if devices.is_empty() {
181            return Err(GpuAdvancedError::GpuNotFound(
182                "No compatible GPU devices found".to_string(),
183            ));
184        }
185
186        let work_queues = DashMap::new();
187        for device in &devices {
188            work_queues.insert(
189                device.info.index,
190                Arc::new(work_queue::WorkQueue::new(device.clone())),
191            );
192        }
193
194        let load_balancer = Arc::new(load_balancer::LoadBalancer::new(devices.clone(), strategy));
195
196        Ok(Self {
197            devices,
198            strategy,
199            work_queues,
200            load_balancer,
201        })
202    }
203
204    /// Enumerate all available GPU devices
205    async fn enumerate_devices() -> Result<Vec<Arc<GpuDevice>>> {
206        let instance = wgpu::Instance::new(wgpu::InstanceDescriptor {
207            backends: wgpu::Backends::all(),
208            ..wgpu::InstanceDescriptor::new_without_display_handle()
209        });
210
211        let mut devices = Vec::new();
212        let mut index = 0;
213
214        // Try each backend separately to avoid potential hangs
215        for _backend in &[
216            wgpu::Backends::VULKAN,
217            wgpu::Backends::METAL,
218            wgpu::Backends::DX12,
219            wgpu::Backends::GL,
220        ] {
221            if let Ok(adapter) = instance
222                .request_adapter(&wgpu::RequestAdapterOptions {
223                    power_preference: wgpu::PowerPreference::HighPerformance,
224                    force_fallback_adapter: false,
225                    compatible_surface: None,
226                })
227                .await
228            {
229                let info = adapter.get_info();
230
231                // Skip CPU adapters by default
232                if info.device_type == wgpu::DeviceType::Cpu {
233                    continue;
234                }
235
236                // Skip if we already have this adapter (avoid duplicates)
237                if devices.iter().any(|d: &Arc<GpuDevice>| {
238                    let d_info = &d.info;
239                    d_info.name == info.name && d_info.backend == info.backend
240                }) {
241                    continue;
242                }
243
244                // Request device
245                let (device, queue) = match adapter
246                    .request_device(&wgpu::DeviceDescriptor {
247                        label: Some(&format!("GPU Device {}", index)),
248                        required_features: wgpu::Features::empty(),
249                        required_limits: wgpu::Limits::default(),
250                        memory_hints: wgpu::MemoryHints::Performance,
251                        experimental_features: wgpu::ExperimentalFeatures::disabled(),
252                        trace: wgpu::Trace::Off,
253                    })
254                    .await
255                {
256                    Ok((device, queue)) => (device, queue),
257                    Err(e) => {
258                        tracing::warn!("Failed to request device {}: {}", index, e);
259                        continue;
260                    }
261                };
262
263                let gpu_device = GpuDevice::new(index, adapter, device, queue)?;
264                devices.push(Arc::new(gpu_device));
265                index += 1;
266            }
267        }
268
269        Ok(devices)
270    }
271
272    /// Get total number of GPUs
273    pub fn gpu_count(&self) -> usize {
274        self.devices.len()
275    }
276
277    /// Get GPU by index
278    pub fn get_gpu(&self, index: usize) -> Result<Arc<GpuDevice>> {
279        self.devices
280            .get(index)
281            .cloned()
282            .ok_or(GpuAdvancedError::InvalidGpuIndex {
283                index,
284                total: self.devices.len(),
285            })
286    }
287
288    /// Get all GPUs
289    pub fn get_all_gpus(&self) -> &[Arc<GpuDevice>] {
290        &self.devices
291    }
292
293    /// Select best GPU for a task
294    pub fn select_gpu(&self) -> Result<Arc<GpuDevice>> {
295        self.load_balancer.select_device()
296    }
297
298    /// Select GPU with specific requirements
299    pub fn select_gpu_with_requirements(
300        &self,
301        min_memory: Option<u64>,
302        preferred_type: Option<wgpu::DeviceType>,
303    ) -> Result<Arc<GpuDevice>> {
304        let mut candidates: Vec<_> = self
305            .devices
306            .iter()
307            .filter(|device| {
308                if let Some(min_mem) = min_memory {
309                    if let Some(mem_size) = device.info.memory_size {
310                        if mem_size < min_mem {
311                            return false;
312                        }
313                    }
314                }
315
316                if let Some(pref_type) = preferred_type {
317                    if device.info.device_type != pref_type {
318                        return false;
319                    }
320                }
321
322                device.is_available()
323            })
324            .collect();
325
326        if candidates.is_empty() {
327            return Err(GpuAdvancedError::GpuNotFound(
328                "No GPU matching requirements".to_string(),
329            ));
330        }
331
332        // Sort by score
333        candidates.sort_by(|a, b| {
334            b.get_score()
335                .partial_cmp(&a.get_score())
336                .unwrap_or(std::cmp::Ordering::Equal)
337        });
338
339        Ok(candidates[0].clone())
340    }
341
342    /// Get work queue for a GPU
343    pub fn get_work_queue(&self, index: usize) -> Result<Arc<work_queue::WorkQueue>> {
344        self.work_queues
345            .get(&index)
346            .map(|q| q.clone())
347            .ok_or(GpuAdvancedError::InvalidGpuIndex {
348                index,
349                total: self.devices.len(),
350            })
351    }
352
353    /// Submit work to best available GPU
354    pub async fn submit_work<F, T>(&self, work: F) -> Result<T>
355    where
356        F: FnOnce(&GpuDevice) -> Result<T> + Send + 'static,
357        T: Send + 'static,
358    {
359        let device = self.select_gpu()?;
360        let queue = self.get_work_queue(device.info.index)?;
361        queue.submit_work(work).await
362    }
363
364    /// Get load balancer
365    pub fn get_load_balancer(&self) -> Arc<load_balancer::LoadBalancer> {
366        self.load_balancer.clone()
367    }
368
369    /// Print GPU information
370    pub fn print_gpu_info(&self) {
371        println!("Multi-GPU Manager - {} devices found", self.devices.len());
372        for device in &self.devices {
373            println!(
374                "  GPU {}: {} ({:?}, {:?})",
375                device.info.index, device.info.name, device.info.backend, device.info.device_type
376            );
377            println!("    Max buffer size: {} bytes", device.info.max_buffer_size);
378            println!(
379                "    Max texture 2D: {}x{}",
380                device.info.max_texture_dimension_2d, device.info.max_texture_dimension_2d
381            );
382            println!(
383                "    Max workgroup size: {}x{}x{}",
384                device.info.max_compute_workgroup_size_x,
385                device.info.max_compute_workgroup_size_y,
386                device.info.max_compute_workgroup_size_z
387            );
388            println!(
389                "    Current workload: {:.1}%",
390                device.get_workload() * 100.0
391            );
392            println!("    Memory usage: {} bytes", device.get_memory_usage());
393        }
394    }
395}
396
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_multi_gpu_manager_creation() {
        // Manager creation may legitimately fail on machines without a GPU
        // (e.g. CI); that outcome is reported, not asserted against.
        match MultiGpuManager::new(SelectionStrategy::LeastLoaded).await {
            Ok(manager) => {
                assert!(manager.gpu_count() > 0);
                manager.print_gpu_info();
            }
            Err(e) => println!("No GPU available: {}", e),
        }
    }

    #[tokio::test]
    async fn test_gpu_selection() {
        // Skip silently when no GPU can be initialized.
        let Ok(manager) = MultiGpuManager::new(SelectionStrategy::BestScore).await else {
            return;
        };

        let gpu = manager.select_gpu();
        assert!(gpu.is_ok());

        let gpu = gpu.unwrap();
        println!("Selected GPU: {}", gpu.info.name);
        assert!(gpu.get_score() >= 0.0);
    }
}