1use crate::{GpuDevice, GpuError, Result};
39use parking_lot::Mutex;
40use std::sync::Arc;
41
/// Policy used by the scheduler to choose which device receives the next unit of work.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LoadBalanceStrategy {
    /// Cycle through the devices in index order, one per selection.
    RoundRobin,
    /// Choose the device whose current queue depth is smallest.
    LeastLoaded,
    /// Choose the device with the best ratio of configured weight to queue depth.
    WeightedCapacity,
    /// Choose the device with the best ratio of smoothed throughput to queue depth.
    AdaptiveThroughput,
}

impl Default for LoadBalanceStrategy {
    /// The default policy is [`LoadBalanceStrategy::LeastLoaded`].
    fn default() -> Self {
        LoadBalanceStrategy::LeastLoaded
    }
}
64
/// Counters and a smoothed throughput estimate for a single device.
#[derive(Debug, Clone, Default)]
pub struct DeviceStats {
    /// Number of frames handed to the device.
    pub frames_dispatched: u64,
    /// Number of frames that finished successfully.
    pub frames_completed: u64,
    /// Number of frames whose work reported an error.
    pub frames_failed: u64,
    /// Exponential moving average of throughput in frames per second.
    ///
    /// A value of exactly `0.0` means "no sample recorded yet".
    pub ema_throughput_fps: f64,
    /// Frames dispatched but not yet completed or failed.
    pub queue_depth: u64,
}

impl DeviceStats {
    /// Folds a new throughput sample (`fps`) into the moving average.
    ///
    /// The first sample seeds the average directly; later samples are blended
    /// with a smoothing factor of `0.1`, i.e. `new = 0.1 * fps + 0.9 * old`.
    ///
    /// Non-finite samples (`NaN`, `±∞`) are ignored: a single such value would
    /// otherwise make the average non-finite forever, because every subsequent
    /// blend keeps the poisoned term.
    pub fn update_ema(&mut self, fps: f64) {
        const ALPHA: f64 = 0.1;

        if !fps.is_finite() {
            return;
        }

        self.ema_throughput_fps = if self.ema_throughput_fps == 0.0 {
            // No sample yet: start from the observed value instead of blending with zero.
            fps
        } else {
            ALPHA * fps + (1.0 - ALPHA) * self.ema_throughput_fps
        };
    }
}
95
96pub struct DeviceSlot {
102 pub device: Arc<GpuDevice>,
104 pub weight: f32,
106 pub stats: Mutex<DeviceStats>,
108 pub index: usize,
110}
111
112impl DeviceSlot {
113 #[must_use]
115 pub fn new(device: Arc<GpuDevice>, index: usize, weight: f32) -> Self {
116 Self {
117 device,
118 weight: weight.max(0.01),
119 stats: Mutex::new(DeviceStats::default()),
120 index,
121 }
122 }
123
124 pub fn on_dispatch(&self) {
126 let mut s = self.stats.lock();
127 s.frames_dispatched += 1;
128 s.queue_depth += 1;
129 }
130
131 pub fn on_complete(&self, latency_secs: f64) {
133 let mut s = self.stats.lock();
134 s.frames_completed += 1;
135 s.queue_depth = s.queue_depth.saturating_sub(1);
136 if latency_secs > 0.0 {
137 s.update_ema(1.0 / latency_secs);
138 }
139 }
140
141 pub fn on_failure(&self) {
143 let mut s = self.stats.lock();
144 s.frames_failed += 1;
145 s.queue_depth = s.queue_depth.saturating_sub(1);
146 }
147
148 #[must_use]
150 pub fn queue_depth(&self) -> u64 {
151 self.stats.lock().queue_depth
152 }
153
154 #[must_use]
156 pub fn ema_throughput(&self) -> f64 {
157 self.stats.lock().ema_throughput_fps
158 }
159}
160
161pub struct MultiGpuScheduler {
172 slots: Vec<DeviceSlot>,
173 strategy: LoadBalanceStrategy,
174 rr_counter: Mutex<usize>,
175}
176
177impl MultiGpuScheduler {
178 pub fn new(devices: Vec<(Arc<GpuDevice>, f32)>, strategy: LoadBalanceStrategy) -> Result<Self> {
184 if devices.is_empty() {
185 return Err(GpuError::NotSupported(
186 "MultiGpuScheduler requires at least one device".to_string(),
187 ));
188 }
189 let slots = devices
190 .into_iter()
191 .enumerate()
192 .map(|(i, (dev, w))| DeviceSlot::new(dev, i, w))
193 .collect();
194 Ok(Self {
195 slots,
196 strategy,
197 rr_counter: Mutex::new(0),
198 })
199 }
200
201 pub fn equal_weight(devices: Vec<Arc<GpuDevice>>) -> Result<Self> {
208 Self::new(
209 devices.into_iter().map(|d| (d, 1.0)).collect(),
210 LoadBalanceStrategy::default(),
211 )
212 }
213
214 #[must_use]
216 pub fn device_count(&self) -> usize {
217 self.slots.len()
218 }
219
220 #[must_use]
223 pub fn select_device(&self) -> usize {
224 match self.strategy {
225 LoadBalanceStrategy::RoundRobin => self.select_round_robin(),
226 LoadBalanceStrategy::LeastLoaded => self.select_least_loaded(),
227 LoadBalanceStrategy::WeightedCapacity => self.select_weighted(),
228 LoadBalanceStrategy::AdaptiveThroughput => self.select_adaptive(),
229 }
230 }
231
232 pub fn dispatch<F, T>(&self, work_fn: F) -> Result<(T, usize)>
240 where
241 F: FnOnce(&GpuDevice) -> Result<T>,
242 {
243 let slot_idx = self.select_device();
244 let slot = &self.slots[slot_idx];
245
246 slot.on_dispatch();
247
248 let start = std::time::Instant::now();
249 match work_fn(&slot.device) {
250 Ok(result) => {
251 let elapsed = start.elapsed().as_secs_f64();
252 slot.on_complete(elapsed);
253 Ok((result, slot_idx))
254 }
255 Err(e) => {
256 slot.on_failure();
257 Err(e)
258 }
259 }
260 }
261
262 #[must_use]
264 pub fn device_stats(&self) -> Vec<DeviceStats> {
265 self.slots.iter().map(|s| s.stats.lock().clone()).collect()
266 }
267
268 #[must_use]
270 pub fn total_dispatched(&self) -> u64 {
271 self.slots
272 .iter()
273 .map(|s| s.stats.lock().frames_dispatched)
274 .sum()
275 }
276
277 #[must_use]
279 pub fn total_completed(&self) -> u64 {
280 self.slots
281 .iter()
282 .map(|s| s.stats.lock().frames_completed)
283 .sum()
284 }
285
286 #[must_use]
290 pub fn slot(&self, index: usize) -> Option<&DeviceSlot> {
291 self.slots.get(index)
292 }
293
294 fn select_round_robin(&self) -> usize {
297 let mut counter = self.rr_counter.lock();
298 let idx = *counter % self.slots.len();
299 *counter = counter.wrapping_add(1);
300 idx
301 }
302
303 fn select_least_loaded(&self) -> usize {
304 self.slots
305 .iter()
306 .enumerate()
307 .min_by_key(|(_, s)| s.queue_depth())
308 .map(|(i, _)| i)
309 .unwrap_or(0)
310 }
311
312 fn select_weighted(&self) -> usize {
313 let total_weight: f32 = self.slots.iter().map(|s| s.weight).sum();
316 if total_weight <= 0.0 {
317 return 0;
318 }
319
320 let mut best_idx = 0;
324 let mut best_score = f32::NEG_INFINITY;
325 for (i, slot) in self.slots.iter().enumerate() {
326 let depth = slot.queue_depth() as f32 + 1.0;
327 let score = slot.weight / (total_weight * depth);
328 if score > best_score {
329 best_score = score;
330 best_idx = i;
331 }
332 }
333 best_idx
334 }
335
336 fn select_adaptive(&self) -> usize {
337 self.slots
339 .iter()
340 .enumerate()
341 .max_by(|(_, a), (_, b)| {
342 let score_a = a.ema_throughput() / (a.queue_depth() as f64 + 1.0);
343 let score_b = b.ema_throughput() / (b.queue_depth() as f64 + 1.0);
344 score_a
345 .partial_cmp(&score_b)
346 .unwrap_or(std::cmp::Ordering::Equal)
347 })
348 .map(|(i, _)| i)
349 .unwrap_or(0)
350 }
351}
352
353pub fn distribute_frames<P, T, F>(
364 scheduler: &MultiGpuScheduler,
365 frames: &[P],
366 work_fn: F,
367) -> Vec<Result<T>>
368where
369 P: Send + Sync,
370 T: Send,
371 F: Fn(&GpuDevice, &P) -> Result<T> + Send + Sync,
372{
373 frames
374 .iter()
375 .map(|frame| {
376 scheduler
377 .dispatch(|dev| work_fn(dev, frame))
378 .map(|(result, _)| result)
379 })
380 .collect()
381}
382
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a scheduler over `n` fallback devices, each with weight `1.0`.
    fn make_scheduler(n: usize, strategy: LoadBalanceStrategy) -> MultiGpuScheduler {
        let mut devices: Vec<(Arc<GpuDevice>, f32)> = Vec::with_capacity(n);
        for _ in 0..n {
            let dev = GpuDevice::new_fallback().expect("CPU fallback device unavailable in test");
            devices.push((Arc::new(dev), 1.0));
        }
        MultiGpuScheduler::new(devices, strategy).expect("scheduler creation failed")
    }

    /// An empty device list must be rejected.
    #[test]
    fn test_empty_device_list_is_error() {
        let outcome = MultiGpuScheduler::new(Vec::new(), LoadBalanceStrategy::RoundRobin);
        assert!(outcome.is_err());
    }

    /// With a single device every selection is index 0.
    #[test]
    fn test_single_device_always_selected() {
        let scheduler = make_scheduler(1, LoadBalanceStrategy::RoundRobin);
        let picks: Vec<usize> = (0..5).map(|_| scheduler.select_device()).collect();
        assert_eq!(picks, vec![0usize; 5]);
    }

    /// Round-robin visits the devices in index order and wraps around.
    #[test]
    fn test_round_robin_cycles() {
        let scheduler = make_scheduler(3, LoadBalanceStrategy::RoundRobin);
        for &expected in &[0usize, 1, 2, 0, 1, 2] {
            assert_eq!(scheduler.select_device(), expected);
        }
    }

    /// Least-loaded picks the only device with nothing in flight.
    #[test]
    fn test_least_loaded_prefers_idle() {
        let scheduler = make_scheduler(3, LoadBalanceStrategy::LeastLoaded);
        // Slot 0 gets two in-flight frames, slot 1 gets one, slot 2 stays idle.
        for &(slot_idx, in_flight) in &[(0usize, 2u32), (1, 1)] {
            for _ in 0..in_flight {
                scheduler.slots[slot_idx].on_dispatch();
            }
        }
        assert_eq!(scheduler.select_device(), 2);
    }

    /// A successful dispatch bumps both the dispatched and completed totals.
    #[test]
    fn test_dispatch_records_stats() {
        let scheduler = make_scheduler(1, LoadBalanceStrategy::RoundRobin);
        let _outcome = scheduler.dispatch(|_dev| Ok::<u32, crate::GpuError>(42));
        assert_eq!(scheduler.total_dispatched(), 1);
        assert_eq!(scheduler.total_completed(), 1);
    }

    /// A failing dispatch is counted and leaves nothing in flight.
    #[test]
    fn test_dispatch_failure_recorded() {
        let scheduler = make_scheduler(1, LoadBalanceStrategy::RoundRobin);
        let _outcome = scheduler.dispatch(|_dev| {
            Err::<u32, crate::GpuError>(GpuError::NotSupported("test".to_string()))
        });
        let snapshot = scheduler.device_stats();
        assert_eq!(snapshot[0].frames_failed, 1);
        assert_eq!(snapshot[0].queue_depth, 0);
    }

    /// `device_count` reports the number of devices supplied.
    #[test]
    fn test_device_count() {
        let scheduler = make_scheduler(4, LoadBalanceStrategy::LeastLoaded);
        assert_eq!(scheduler.device_count(), 4);
    }

    /// The dispatched total sums over all devices.
    #[test]
    fn test_total_dispatched_sum() {
        let scheduler = make_scheduler(3, LoadBalanceStrategy::RoundRobin);
        (0..9).for_each(|_| {
            let _ = scheduler.dispatch(|_| Ok::<(), _>(()));
        });
        assert_eq!(scheduler.total_dispatched(), 9);
    }

    /// With all devices idle, the heaviest weight wins.
    #[test]
    fn test_weighted_selects_highest_weight() {
        let devices: Vec<(Arc<GpuDevice>, f32)> = [1.0_f32, 1.0, 10.0]
            .iter()
            .map(|&weight| {
                let dev = GpuDevice::new_fallback().expect("CPU fallback unavailable in test");
                (Arc::new(dev), weight)
            })
            .collect();
        let scheduler = MultiGpuScheduler::new(devices, LoadBalanceStrategy::WeightedCapacity)
            .expect("create weighted scheduler");
        assert_eq!(scheduler.select_device(), 2);
    }

    /// The device with the lower recorded latency is preferred.
    #[test]
    fn test_adaptive_prefers_high_throughput() {
        let scheduler = make_scheduler(3, LoadBalanceStrategy::AdaptiveThroughput);
        // Slot 1 finishes a frame in 1 ms, slot 0 in 100 ms; slot 2 has no sample.
        for &(slot_idx, latency_secs) in &[(1usize, 0.001_f64), (0, 0.1)] {
            scheduler.slots[slot_idx].on_dispatch();
            scheduler.slots[slot_idx].on_complete(latency_secs);
        }
        assert_eq!(scheduler.select_device(), 1);
    }

    /// Results come back in the same order as the input frames.
    #[test]
    fn test_distribute_frames_returns_results_in_order() {
        let scheduler = make_scheduler(2, LoadBalanceStrategy::RoundRobin);
        let frames: Vec<u32> = (1..=6).collect();
        let outcomes = distribute_frames(&scheduler, &frames, |_dev, &frame| Ok(frame * 2));
        let mut doubled: Vec<u32> = Vec::with_capacity(outcomes.len());
        for outcome in outcomes {
            doubled.push(outcome.expect("frame result"));
        }
        assert_eq!(doubled, vec![2, 4, 6, 8, 10, 12]);
    }

    /// Two round-robin dispatches land on two different devices.
    #[test]
    fn test_device_stats_snapshot() {
        let scheduler = make_scheduler(2, LoadBalanceStrategy::RoundRobin);
        for _ in 0..2 {
            let _ = scheduler.dispatch(|_| Ok::<(), _>(()));
        }
        let snapshot = scheduler.device_stats();
        assert_eq!(snapshot.len(), 2);
        assert_eq!(snapshot[0].frames_dispatched, 1);
        assert_eq!(snapshot[1].frames_dispatched, 1);
    }

    /// The first sample seeds the average; the second is blended at 10 %.
    #[test]
    fn test_device_ema_update() {
        let mut stats = DeviceStats::default();
        stats.update_ema(100.0);
        assert!((stats.ema_throughput_fps - 100.0).abs() < 1e-6);
        stats.update_ema(50.0);
        // 0.1 * 50 + 0.9 * 100 = 95
        assert!((stats.ema_throughput_fps - 95.0).abs() < 1e-6);
    }
}