Skip to main content

oximedia_gpu/
multi_gpu.rs

1//! Multi-GPU load balancing with automatic frame distribution.
2//!
3//! This module provides a `MultiGpuScheduler` that distributes frames across
4//! all available GPU devices, performing automatic load balancing based on
5//! measured per-device throughput and real-time queue depth.
6//!
7//! # Architecture
8//!
9//! ```text
10//!                       ┌───────────────────────┐
11//!                       │  MultiGpuScheduler    │
12//!                       │  ─────────────────    │
13//!                       │  • device pool        │
14//!                       │  • load balancer      │
15//!                       │  • frame dispatcher   │
16//!                       └──────────┬────────────┘
17//!               ┌──────────────────┼────────────────────┐
18//!         ┌─────▼─────┐     ┌──────▼──────┐     ┌───────▼──────┐
19//!         │  GPU 0    │     │   GPU 1     │     │   GPU n …    │
20//!         │  worker   │     │   worker    │     │   worker     │
21//!         └───────────┘     └─────────────┘     └──────────────┘
22//! ```
23//!
24//! # Load-Balancing Strategies
25//!
26//! | Strategy           | Description                                            |
27//! |-------------------|--------------------------------------------------------|
28//! | `RoundRobin`       | Distribute frames in strict order across devices.      |
29//! | `LeastLoaded`      | Always assign to the device with the fewest queued frames. |
30//! | `WeightedCapacity` | Assign proportionally to a static device-weight table. |
31//! | `AdaptiveThroughput` | Track measured throughput and route to fastest device. |
32//!
33//! # Status
34//!
35//! GPU command dispatch is a stub (returns CPU-only results).  The scheduling
36//! logic and statistics are fully functional.
37
38use crate::{GpuDevice, GpuError, Result};
39use parking_lot::Mutex;
40use std::sync::Arc;
41
42// ─────────────────────────────────────────────────────────────────────────────
43// Load-balancing strategies
44// ─────────────────────────────────────────────────────────────────────────────
45
/// Strategy used by [`MultiGpuScheduler`] to assign work to devices.
///
/// The default strategy is [`LoadBalanceStrategy::LeastLoaded`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum LoadBalanceStrategy {
    /// Strict round-robin assignment across all available devices.
    RoundRobin,
    /// Always assign to the device with the smallest pending queue depth.
    #[default]
    LeastLoaded,
    /// Favor devices according to their static `weight` in [`DeviceSlot`],
    /// discounted by how many frames they currently have pending.
    WeightedCapacity,
    /// Dynamically measure throughput and prefer the fastest device.
    AdaptiveThroughput,
}
64
65// ─────────────────────────────────────────────────────────────────────────────
66// Per-device statistics
67// ─────────────────────────────────────────────────────────────────────────────
68
/// Runtime statistics for a single device slot.
#[derive(Debug, Clone, Default)]
pub struct DeviceStats {
    /// Total frames dispatched to this device.
    pub frames_dispatched: u64,
    /// Total frames completed (successfully processed).
    pub frames_completed: u64,
    /// Total frames that failed.
    pub frames_failed: u64,
    /// Exponential moving-average throughput (frames / second).
    ///
    /// A value of `0.0` means "no sample recorded yet".
    pub ema_throughput_fps: f64,
    /// Current pending queue depth (dispatched but not yet completed).
    pub queue_depth: u64,
}

impl DeviceStats {
    /// Fold a new throughput measurement (`fps`) into the moving average.
    ///
    /// The first valid sample seeds the average directly; every later sample
    /// is blended in with a smoothing factor of `0.1`.
    ///
    /// Samples that are NaN, infinite, or negative are ignored: blending any
    /// of them in would leave the average permanently NaN/infinite (or
    /// negative), which in turn corrupts every comparison made on it.
    pub fn update_ema(&mut self, fps: f64) {
        const ALPHA: f64 = 0.1;

        if !fps.is_finite() || fps < 0.0 {
            return;
        }

        // `0.0` is the "unseeded" marker.  The field is public, so also
        // reseed if a caller stored a non-finite value in it directly.
        let unseeded = self.ema_throughput_fps == 0.0 || !self.ema_throughput_fps.is_finite();
        if unseeded {
            self.ema_throughput_fps = fps;
        } else {
            self.ema_throughput_fps = ALPHA * fps + (1.0 - ALPHA) * self.ema_throughput_fps;
        }
    }
}
95
96// ─────────────────────────────────────────────────────────────────────────────
97// Device slot
98// ─────────────────────────────────────────────────────────────────────────────
99
100/// A device slot held by the multi-GPU scheduler.
101pub struct DeviceSlot {
102    /// The GPU device.
103    pub device: Arc<GpuDevice>,
104    /// Static capacity weight (used by [`LoadBalanceStrategy::WeightedCapacity`]).
105    pub weight: f32,
106    /// Device-level statistics (protected by a mutex for multi-threaded access).
107    pub stats: Mutex<DeviceStats>,
108    /// Unique index assigned by the scheduler.
109    pub index: usize,
110}
111
112impl DeviceSlot {
113    /// Create a new device slot.
114    #[must_use]
115    pub fn new(device: Arc<GpuDevice>, index: usize, weight: f32) -> Self {
116        Self {
117            device,
118            weight: weight.max(0.01),
119            stats: Mutex::new(DeviceStats::default()),
120            index,
121        }
122    }
123
124    /// Record a dispatched frame.
125    pub fn on_dispatch(&self) {
126        let mut s = self.stats.lock();
127        s.frames_dispatched += 1;
128        s.queue_depth += 1;
129    }
130
131    /// Record a completed frame with the measured latency in seconds.
132    pub fn on_complete(&self, latency_secs: f64) {
133        let mut s = self.stats.lock();
134        s.frames_completed += 1;
135        s.queue_depth = s.queue_depth.saturating_sub(1);
136        if latency_secs > 0.0 {
137            s.update_ema(1.0 / latency_secs);
138        }
139    }
140
141    /// Record a failed frame.
142    pub fn on_failure(&self) {
143        let mut s = self.stats.lock();
144        s.frames_failed += 1;
145        s.queue_depth = s.queue_depth.saturating_sub(1);
146    }
147
148    /// Current queue depth (lock-free snapshot).
149    #[must_use]
150    pub fn queue_depth(&self) -> u64 {
151        self.stats.lock().queue_depth
152    }
153
154    /// Current EMA throughput (lock-free snapshot).
155    #[must_use]
156    pub fn ema_throughput(&self) -> f64 {
157        self.stats.lock().ema_throughput_fps
158    }
159}
160
161// ─────────────────────────────────────────────────────────────────────────────
162// MultiGpuScheduler
163// ─────────────────────────────────────────────────────────────────────────────
164
165/// Multi-GPU frame scheduler with configurable load-balancing.
166///
167/// # Thread Safety
168///
169/// `MultiGpuScheduler` is `Send + Sync` and can be shared across threads via
170/// `Arc`.  The internal round-robin counter is protected by a [`Mutex`].
171pub struct MultiGpuScheduler {
172    slots: Vec<DeviceSlot>,
173    strategy: LoadBalanceStrategy,
174    rr_counter: Mutex<usize>,
175}
176
177impl MultiGpuScheduler {
178    /// Create a new scheduler from a list of `(device, weight)` pairs.
179    ///
180    /// # Errors
181    ///
182    /// Returns `GpuError::NotSupported` if `devices` is empty.
183    pub fn new(devices: Vec<(Arc<GpuDevice>, f32)>, strategy: LoadBalanceStrategy) -> Result<Self> {
184        if devices.is_empty() {
185            return Err(GpuError::NotSupported(
186                "MultiGpuScheduler requires at least one device".to_string(),
187            ));
188        }
189        let slots = devices
190            .into_iter()
191            .enumerate()
192            .map(|(i, (dev, w))| DeviceSlot::new(dev, i, w))
193            .collect();
194        Ok(Self {
195            slots,
196            strategy,
197            rr_counter: Mutex::new(0),
198        })
199    }
200
201    /// Create a scheduler from a list of devices with equal weights using the
202    /// default `LeastLoaded` strategy.
203    ///
204    /// # Errors
205    ///
206    /// Returns an error if `devices` is empty.
207    pub fn equal_weight(devices: Vec<Arc<GpuDevice>>) -> Result<Self> {
208        Self::new(
209            devices.into_iter().map(|d| (d, 1.0)).collect(),
210            LoadBalanceStrategy::default(),
211        )
212    }
213
214    /// Number of devices managed by this scheduler.
215    #[must_use]
216    pub fn device_count(&self) -> usize {
217        self.slots.len()
218    }
219
220    /// Select the best device slot index for the next frame according to the
221    /// current strategy.
222    #[must_use]
223    pub fn select_device(&self) -> usize {
224        match self.strategy {
225            LoadBalanceStrategy::RoundRobin => self.select_round_robin(),
226            LoadBalanceStrategy::LeastLoaded => self.select_least_loaded(),
227            LoadBalanceStrategy::WeightedCapacity => self.select_weighted(),
228            LoadBalanceStrategy::AdaptiveThroughput => self.select_adaptive(),
229        }
230    }
231
232    /// Dispatch a frame to the best available device.
233    ///
234    /// Returns the index of the selected device slot.
235    ///
236    /// The `work_fn` closure receives the selected `GpuDevice` and performs
237    /// the actual GPU work.  On success the measured latency (in seconds) is
238    /// reported via `on_complete`; on failure `on_failure` is called.
239    pub fn dispatch<F, T>(&self, work_fn: F) -> Result<(T, usize)>
240    where
241        F: FnOnce(&GpuDevice) -> Result<T>,
242    {
243        let slot_idx = self.select_device();
244        let slot = &self.slots[slot_idx];
245
246        slot.on_dispatch();
247
248        let start = std::time::Instant::now();
249        match work_fn(&slot.device) {
250            Ok(result) => {
251                let elapsed = start.elapsed().as_secs_f64();
252                slot.on_complete(elapsed);
253                Ok((result, slot_idx))
254            }
255            Err(e) => {
256                slot.on_failure();
257                Err(e)
258            }
259        }
260    }
261
262    /// Get a snapshot of per-device statistics.
263    #[must_use]
264    pub fn device_stats(&self) -> Vec<DeviceStats> {
265        self.slots.iter().map(|s| s.stats.lock().clone()).collect()
266    }
267
268    /// Total frames dispatched across all devices.
269    #[must_use]
270    pub fn total_dispatched(&self) -> u64 {
271        self.slots
272            .iter()
273            .map(|s| s.stats.lock().frames_dispatched)
274            .sum()
275    }
276
277    /// Total frames completed (successfully) across all devices.
278    #[must_use]
279    pub fn total_completed(&self) -> u64 {
280        self.slots
281            .iter()
282            .map(|s| s.stats.lock().frames_completed)
283            .sum()
284    }
285
286    /// Get a reference to the device slot at `index`.
287    ///
288    /// Returns `None` if the index is out of range.
289    #[must_use]
290    pub fn slot(&self, index: usize) -> Option<&DeviceSlot> {
291        self.slots.get(index)
292    }
293
294    // ── Selection algorithms ─────────────────────────────────────────────────
295
296    fn select_round_robin(&self) -> usize {
297        let mut counter = self.rr_counter.lock();
298        let idx = *counter % self.slots.len();
299        *counter = counter.wrapping_add(1);
300        idx
301    }
302
303    fn select_least_loaded(&self) -> usize {
304        self.slots
305            .iter()
306            .enumerate()
307            .min_by_key(|(_, s)| s.queue_depth())
308            .map(|(i, _)| i)
309            .unwrap_or(0)
310    }
311
312    fn select_weighted(&self) -> usize {
313        // Weighted random selection: pick a threshold in [0, total_weight) and
314        // walk the slots accumulating weights.
315        let total_weight: f32 = self.slots.iter().map(|s| s.weight).sum();
316        if total_weight <= 0.0 {
317            return 0;
318        }
319
320        // Use a simple deterministic approximation (no randomness required for
321        // deterministic scheduling): find the slot whose cumulative weight
322        // share is largest relative to its queue depth.
323        let mut best_idx = 0;
324        let mut best_score = f32::NEG_INFINITY;
325        for (i, slot) in self.slots.iter().enumerate() {
326            let depth = slot.queue_depth() as f32 + 1.0;
327            let score = slot.weight / (total_weight * depth);
328            if score > best_score {
329                best_score = score;
330                best_idx = i;
331            }
332        }
333        best_idx
334    }
335
336    fn select_adaptive(&self) -> usize {
337        // Prefer devices with the highest EMA throughput; break ties by queue depth.
338        self.slots
339            .iter()
340            .enumerate()
341            .max_by(|(_, a), (_, b)| {
342                let score_a = a.ema_throughput() / (a.queue_depth() as f64 + 1.0);
343                let score_b = b.ema_throughput() / (b.queue_depth() as f64 + 1.0);
344                score_a
345                    .partial_cmp(&score_b)
346                    .unwrap_or(std::cmp::Ordering::Equal)
347            })
348            .map(|(i, _)| i)
349            .unwrap_or(0)
350    }
351}
352
353// ─────────────────────────────────────────────────────────────────────────────
354// Frame distribution helper
355// ─────────────────────────────────────────────────────────────────────────────
356
357/// High-level helper that distributes a batch of frames across devices.
358///
359/// `frames` is a slice of input payloads; `work_fn` is called once per frame
360/// with the selected device and the frame payload.
361///
362/// Returns a `Vec<Result<T>>` in the same order as `frames`.
363pub fn distribute_frames<P, T, F>(
364    scheduler: &MultiGpuScheduler,
365    frames: &[P],
366    work_fn: F,
367) -> Vec<Result<T>>
368where
369    P: Send + Sync,
370    T: Send,
371    F: Fn(&GpuDevice, &P) -> Result<T> + Send + Sync,
372{
373    frames
374        .iter()
375        .map(|frame| {
376            scheduler
377                .dispatch(|dev| work_fn(dev, frame))
378                .map(|(result, _)| result)
379        })
380        .collect()
381}
382
383// ─────────────────────────────────────────────────────────────────────────────
384// Tests
385// ─────────────────────────────────────────────────────────────────────────────
386
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a scheduler over `n` CPU-fallback devices, all with weight 1.0.
    fn make_scheduler(n: usize, strategy: LoadBalanceStrategy) -> MultiGpuScheduler {
        let mut pool: Vec<(Arc<GpuDevice>, f32)> = Vec::with_capacity(n);
        for _ in 0..n {
            let fallback =
                GpuDevice::new_fallback().expect("CPU fallback device unavailable in test");
            pool.push((Arc::new(fallback), 1.0));
        }
        MultiGpuScheduler::new(pool, strategy).expect("scheduler creation failed")
    }

    /// An empty device list must be rejected.
    #[test]
    fn test_empty_device_list_is_error() {
        let outcome = MultiGpuScheduler::new(Vec::new(), LoadBalanceStrategy::RoundRobin);
        assert!(outcome.is_err());
    }

    /// With one device there is only one possible answer.
    #[test]
    fn test_single_device_always_selected() {
        let scheduler = make_scheduler(1, LoadBalanceStrategy::RoundRobin);
        let mut remaining = 5;
        while remaining > 0 {
            assert_eq!(scheduler.select_device(), 0);
            remaining -= 1;
        }
    }

    /// Round-robin walks the slots in order and wraps around.
    #[test]
    fn test_round_robin_cycles() {
        let scheduler = make_scheduler(3, LoadBalanceStrategy::RoundRobin);
        let mut picks: Vec<usize> = Vec::new();
        for _ in 0..6 {
            picks.push(scheduler.select_device());
        }
        assert_eq!(picks, vec![0, 1, 2, 0, 1, 2]);
    }

    /// Least-loaded picks the only slot with nothing pending.
    #[test]
    fn test_least_loaded_prefers_idle() {
        let scheduler = make_scheduler(3, LoadBalanceStrategy::LeastLoaded);
        // Two pending frames on slot 0, one on slot 1, none on slot 2.
        for &busy in &[0usize, 0, 1] {
            scheduler.slots[busy].on_dispatch();
        }
        assert_eq!(scheduler.select_device(), 2);
    }

    /// A successful dispatch counts as both dispatched and completed.
    #[test]
    fn test_dispatch_records_stats() {
        let scheduler = make_scheduler(1, LoadBalanceStrategy::RoundRobin);
        let _outcome = scheduler.dispatch(|_dev| Ok::<u32, crate::GpuError>(42));
        assert_eq!(scheduler.total_dispatched(), 1);
        assert_eq!(scheduler.total_completed(), 1);
    }

    /// A failing dispatch is counted and releases its queue slot.
    #[test]
    fn test_dispatch_failure_recorded() {
        let scheduler = make_scheduler(1, LoadBalanceStrategy::RoundRobin);
        let _outcome = scheduler.dispatch(|_dev| {
            Err::<u32, crate::GpuError>(GpuError::NotSupported("test".to_string()))
        });
        let snapshot = scheduler.device_stats();
        let first = &snapshot[0];
        assert_eq!(first.frames_failed, 1);
        assert_eq!(first.queue_depth, 0);
    }

    /// `device_count` reports how many devices were supplied.
    #[test]
    fn test_device_count() {
        let scheduler = make_scheduler(4, LoadBalanceStrategy::LeastLoaded);
        assert_eq!(scheduler.device_count(), 4);
    }

    /// The dispatch total is the sum over every slot.
    #[test]
    fn test_total_dispatched_sum() {
        let scheduler = make_scheduler(3, LoadBalanceStrategy::RoundRobin);
        (0..9).for_each(|_| {
            let _ = scheduler.dispatch(|_| Ok::<(), _>(()));
        });
        assert_eq!(scheduler.total_dispatched(), 9);
    }

    /// With every queue empty, the heaviest device wins.
    #[test]
    fn test_weighted_selects_highest_weight() {
        let fallback =
            || Arc::new(GpuDevice::new_fallback().expect("CPU fallback unavailable in test"));
        // Slot 2 carries ten times the weight of the others.
        let pool: Vec<(Arc<GpuDevice>, f32)> = [1.0f32, 1.0, 10.0]
            .iter()
            .map(|&weight| (fallback(), weight))
            .collect();
        let scheduler = MultiGpuScheduler::new(pool, LoadBalanceStrategy::WeightedCapacity)
            .expect("create weighted scheduler");
        assert_eq!(scheduler.select_device(), 2);
    }

    /// The device with the higher measured throughput is preferred.
    #[test]
    fn test_adaptive_prefers_high_throughput() {
        let scheduler = make_scheduler(3, LoadBalanceStrategy::AdaptiveThroughput);
        // Slot 1 finishes a frame in 1 ms (1000 fps); slot 0 needs 100 ms (10 fps).
        for &(idx, latency_secs) in &[(1usize, 0.001f64), (0, 0.1)] {
            let slot = &scheduler.slots[idx];
            slot.on_dispatch();
            slot.on_complete(latency_secs);
        }
        assert_eq!(scheduler.select_device(), 1);
    }

    /// Results come back in the same order as the input frames.
    #[test]
    fn test_distribute_frames_returns_results_in_order() {
        let scheduler = make_scheduler(2, LoadBalanceStrategy::RoundRobin);
        let inputs = vec![1u32, 2, 3, 4, 5, 6];
        let mut doubled: Vec<u32> = Vec::new();
        for outcome in distribute_frames(&scheduler, &inputs, |_dev, &frame| Ok(frame * 2)) {
            doubled.push(outcome.expect("frame result"));
        }
        assert_eq!(doubled, vec![2, 4, 6, 8, 10, 12]);
    }

    /// The stats snapshot has one entry per slot with per-slot counters.
    #[test]
    fn test_device_stats_snapshot() {
        let scheduler = make_scheduler(2, LoadBalanceStrategy::RoundRobin);
        for _ in 0..2 {
            let _ = scheduler.dispatch(|_| Ok::<(), _>(()));
        }
        let snapshot = scheduler.device_stats();
        assert_eq!(snapshot.len(), 2);
        // Round-robin sends the first frame to slot 0 and the second to slot 1.
        for per_slot in &snapshot {
            assert_eq!(per_slot.frames_dispatched, 1);
        }
    }

    /// The first sample seeds the average; the second is blended with alpha = 0.1.
    #[test]
    fn test_device_ema_update() {
        let mut stats = DeviceStats::default();
        stats.update_ema(100.0);
        assert!((stats.ema_throughput_fps - 100.0).abs() < 1e-6);
        stats.update_ema(50.0);
        // 0.1 * 50 + 0.9 * 100 = 95
        assert!((stats.ema_throughput_fps - 95.0).abs() < 1e-6);
    }
}