1use anyhow::{anyhow, Result};
15use parking_lot::Mutex;
16use serde::{Deserialize, Serialize};
17use std::collections::HashMap;
18use std::sync::Arc;
19use tracing::{debug, info};
20
/// Static description of a single GPU device as tracked by the load balancer.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct SimpleGpuDevice {
    /// Unique identifier; used as the lookup key throughout the balancer.
    pub id: u32,
    /// Human-readable device name (used for logging in this module).
    pub name: String,
    /// Total device memory in megabytes; doubles as the workload capacity.
    pub memory_mb: u64,
    /// Number of compute units; not consulted by the scheduling logic in this file.
    pub compute_units: u32,
}
40
41impl SimpleGpuDevice {
42 pub fn new(id: u32, name: impl Into<String>, memory_mb: u64, compute_units: u32) -> Self {
44 Self {
45 id,
46 name: name.into(),
47 memory_mb,
48 compute_units,
49 }
50 }
51}
52
/// Internal per-device bookkeeping: the device descriptor plus the amount of
/// workload currently recorded against it.
#[derive(Debug)]
struct DeviceState {
    device: SimpleGpuDevice,
    // Megabytes of workload currently attributed to this device.
    current_workload_mb: u64,
}
63
64impl DeviceState {
65 fn new(device: SimpleGpuDevice) -> Self {
66 Self {
67 device,
68 current_workload_mb: 0,
69 }
70 }
71
72 fn utilization(&self) -> f64 {
74 if self.device.memory_mb == 0 {
75 return 0.0;
76 }
77 (self.current_workload_mb as f64 / self.device.memory_mb as f64).min(1.0)
78 }
79}
80
/// Thread-safe, cloneable GPU load balancer.
///
/// Clones share the same underlying state through `Arc<Mutex<_>>`, so a clone
/// observes registrations and workload changes made through any other clone.
#[derive(Debug, Clone)]
pub struct GpuLoadBalancer {
    inner: Arc<Mutex<GpuLoadBalancerInner>>,
}
107
/// State guarded by the balancer's mutex.
#[derive(Debug)]
struct GpuLoadBalancerInner {
    // Device ids in registration order; gives deterministic iteration order
    // for selection and snapshots.
    device_order: Vec<u32>,
    // Per-device state keyed by device id.
    states: HashMap<u32, DeviceState>,
}
115
116impl GpuLoadBalancerInner {
117 fn new() -> Self {
118 Self {
119 device_order: Vec::new(),
120 states: HashMap::new(),
121 }
122 }
123}
124
125impl Default for GpuLoadBalancer {
126 fn default() -> Self {
127 Self::new()
128 }
129}
130
131impl GpuLoadBalancer {
132 pub fn new() -> Self {
134 Self {
135 inner: Arc::new(Mutex::new(GpuLoadBalancerInner::new())),
136 }
137 }
138
139 pub fn register_device(&self, device: SimpleGpuDevice) {
142 let mut g = self.inner.lock();
143 let id = device.id;
144 info!("Registering GPU device {} ({})", id, device.name);
145 if !g.device_order.contains(&id) {
146 g.device_order.push(id);
147 }
148 g.states.insert(id, DeviceState::new(device));
149 }
150
151 pub fn unregister_device(&self, device_id: u32) {
153 let mut g = self.inner.lock();
154 g.device_order.retain(|&x| x != device_id);
155 g.states.remove(&device_id);
156 debug!("Unregistered GPU device {}", device_id);
157 }
158
159 pub fn select_device(&self, workload_mb: u64) -> Option<u32> {
165 let g = self.inner.lock();
166 g.device_order
167 .iter()
168 .filter_map(|&id| g.states.get(&id).map(|s| (id, s)))
169 .filter(|(_, s)| {
170 s.device.memory_mb.saturating_sub(s.current_workload_mb) >= workload_mb
171 })
172 .min_by(|(_, a), (_, b)| {
173 a.utilization()
174 .partial_cmp(&b.utilization())
175 .unwrap_or(std::cmp::Ordering::Equal)
176 })
177 .map(|(id, _)| id)
178 }
179
180 pub fn record_workload(&self, device_id: u32, mb: u64) -> Result<()> {
184 let mut g = self.inner.lock();
185 let state = g
186 .states
187 .get_mut(&device_id)
188 .ok_or_else(|| anyhow!("Device {} not registered", device_id))?;
189 state.current_workload_mb += mb;
190 debug!(
191 "Device {}: workload {} MB (util {:.1}%)",
192 device_id,
193 state.current_workload_mb,
194 state.utilization() * 100.0
195 );
196 Ok(())
197 }
198
199 pub fn release_workload(&self, device_id: u32, mb: u64) -> Result<()> {
204 let mut g = self.inner.lock();
205 let state = g
206 .states
207 .get_mut(&device_id)
208 .ok_or_else(|| anyhow!("Device {} not registered", device_id))?;
209 state.current_workload_mb = state.current_workload_mb.saturating_sub(mb);
210 debug!(
211 "Device {}: released {} MB, now {} MB",
212 device_id, mb, state.current_workload_mb
213 );
214 Ok(())
215 }
216
217 pub fn utilization(&self, device_id: u32) -> Option<f64> {
221 let g = self.inner.lock();
222 g.states.get(&device_id).map(|s| s.utilization())
223 }
224
225 pub fn total_capacity_mb(&self) -> u64 {
227 let g = self.inner.lock();
228 g.states.values().map(|s| s.device.memory_mb).sum()
229 }
230
231 pub fn device_count(&self) -> usize {
233 self.inner.lock().states.len()
234 }
235
236 pub fn utilization_snapshot(&self) -> Vec<(u32, f64)> {
238 let g = self.inner.lock();
239 g.device_order
240 .iter()
241 .filter_map(|&id| g.states.get(&id).map(|s| (id, s.utilization())))
242 .collect()
243 }
244}
245
/// A half-open range `[start_idx, end_idx)` of vectors assigned to one device.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct WorkloadChunk {
    /// Device that should process this chunk.
    pub device_id: u32,
    /// Inclusive start index into the vector set.
    pub start_idx: usize,
    /// Exclusive end index into the vector set.
    pub end_idx: usize,
}
260
261impl WorkloadChunk {
262 pub fn len(&self) -> usize {
264 self.end_idx.saturating_sub(self.start_idx)
265 }
266
267 pub fn is_empty(&self) -> bool {
269 self.len() == 0
270 }
271}
272
/// Stateless helper that splits a vector workload across GPU devices, either
/// proportionally to device memory or evenly.
#[derive(Debug, Clone, Default)]
pub struct WorkloadDistributor;
284
285impl WorkloadDistributor {
286 pub fn new() -> Self {
288 Self
289 }
290
291 pub fn distribute(
301 &self,
302 total_vectors: usize,
303 devices: &[SimpleGpuDevice],
304 ) -> Result<Vec<WorkloadChunk>> {
305 let eligible: Vec<&SimpleGpuDevice> = devices.iter().filter(|d| d.memory_mb > 0).collect();
306
307 if eligible.is_empty() {
308 return Err(anyhow!(
309 "No eligible GPU devices (all have zero memory or list is empty)"
310 ));
311 }
312
313 let total_mem: u64 = eligible.iter().map(|d| d.memory_mb).sum();
314
315 let mut chunks: Vec<WorkloadChunk> = Vec::with_capacity(eligible.len());
316 let mut assigned = 0usize;
317
318 for (i, device) in eligible.iter().enumerate() {
319 let start_idx = assigned;
320 let end_idx = if i == eligible.len() - 1 {
321 total_vectors
323 } else {
324 let fraction = device.memory_mb as f64 / total_mem as f64;
325 let count = (total_vectors as f64 * fraction).round() as usize;
326 (assigned + count).min(total_vectors)
327 };
328
329 chunks.push(WorkloadChunk {
330 device_id: device.id,
331 start_idx,
332 end_idx,
333 });
334 assigned = end_idx;
335
336 if assigned >= total_vectors {
337 break;
338 }
339 }
340
341 Ok(chunks)
342 }
343
344 pub fn distribute_even(
349 &self,
350 total_vectors: usize,
351 devices: &[SimpleGpuDevice],
352 ) -> Result<Vec<WorkloadChunk>> {
353 if devices.is_empty() {
354 return Err(anyhow!("Cannot distribute across zero devices"));
355 }
356
357 let n = devices.len();
358 let base = total_vectors / n;
359 let remainder = total_vectors % n;
360
361 let mut chunks = Vec::with_capacity(n);
362 let mut start = 0;
363
364 for (i, device) in devices.iter().enumerate() {
365 let extra = if i < remainder { 1 } else { 0 };
366 let end = start + base + extra;
367 chunks.push(WorkloadChunk {
368 device_id: device.id,
369 start_idx: start,
370 end_idx: end,
371 });
372 start = end;
373 }
374
375 Ok(chunks)
376 }
377}
378
#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::Result;

    /// Convenience constructor: a device with 128 compute units and the
    /// given id and memory size.
    fn make_device(id: u32, mem_mb: u64) -> SimpleGpuDevice {
        SimpleGpuDevice::new(id, format!("GPU-{}", id), mem_mb, 128)
    }

    // ---- SimpleGpuDevice ----

    #[test]
    fn test_simple_gpu_device_fields() {
        let d = SimpleGpuDevice::new(0, "TestGPU", 8192, 128);
        assert_eq!(d.id, 0);
        assert_eq!(d.name, "TestGPU");
        assert_eq!(d.memory_mb, 8192);
        assert_eq!(d.compute_units, 128);
    }

    // ---- GpuLoadBalancer: registration & capacity ----

    #[test]
    fn test_register_device_count() {
        let lb = GpuLoadBalancer::new();
        lb.register_device(make_device(0, 8192));
        lb.register_device(make_device(1, 16384));
        assert_eq!(lb.device_count(), 2);
    }

    #[test]
    fn test_total_capacity_mb() {
        let lb = GpuLoadBalancer::new();
        lb.register_device(make_device(0, 4096));
        lb.register_device(make_device(1, 8192));
        assert_eq!(lb.total_capacity_mb(), 12288);
    }

    // ---- GpuLoadBalancer: device selection ----

    #[test]
    fn test_select_device_empty_returns_none() {
        let lb = GpuLoadBalancer::new();
        assert!(lb.select_device(100).is_none());
    }

    #[test]
    fn test_select_device_single() {
        let lb = GpuLoadBalancer::new();
        lb.register_device(make_device(0, 8192));
        let sel = lb.select_device(100);
        assert_eq!(sel, Some(0));
    }

    #[test]
    fn test_select_device_insufficient_memory() {
        let lb = GpuLoadBalancer::new();
        lb.register_device(make_device(0, 100));
        assert!(lb.select_device(200).is_none());
    }

    #[test]
    fn test_select_device_prefers_least_loaded() -> Result<()> {
        let lb = GpuLoadBalancer::new();
        lb.register_device(make_device(0, 8192));
        lb.register_device(make_device(1, 8192));

        lb.record_workload(0, 7000)?;

        let sel = lb.select_device(500);
        assert_eq!(sel, Some(1), "Should prefer the less-loaded device");
        Ok(())
    }

    // ---- GpuLoadBalancer: workload accounting ----

    #[test]
    fn test_record_and_release_workload() -> Result<()> {
        let lb = GpuLoadBalancer::new();
        lb.register_device(make_device(0, 8192));

        lb.record_workload(0, 2048)?;
        let u1 = lb.utilization(0).expect("utilization(0) was None");
        assert!(
            (u1 - 0.25).abs() < 1e-6,
            "Expected 25% utilisation, got {}",
            u1
        );

        lb.release_workload(0, 2048)?;
        let u2 = lb.utilization(0).expect("utilization(0) was None");
        assert!(u2 < 1e-9, "Expected 0% after release, got {}", u2);
        Ok(())
    }

    #[test]
    fn test_release_clamps_to_zero() -> Result<()> {
        let lb = GpuLoadBalancer::new();
        lb.register_device(make_device(0, 8192));
        lb.record_workload(0, 100)?;
        // Releasing more than was recorded must clamp, not underflow.
        lb.release_workload(0, 9999)?;
        let util = lb.utilization(0).expect("utilization(0) was None");
        assert_eq!(util, 0.0);
        Ok(())
    }

    #[test]
    fn test_record_unknown_device_errors() {
        let lb = GpuLoadBalancer::new();
        assert!(lb.record_workload(99, 100).is_err());
    }

    #[test]
    fn test_release_unknown_device_errors() {
        let lb = GpuLoadBalancer::new();
        assert!(lb.release_workload(99, 100).is_err());
    }

    #[test]
    fn test_utilization_unknown_device_none() {
        let lb = GpuLoadBalancer::new();
        assert!(lb.utilization(42).is_none());
    }

    #[test]
    fn test_utilization_snapshot() -> Result<()> {
        let lb = GpuLoadBalancer::new();
        lb.register_device(make_device(0, 8192));
        lb.register_device(make_device(1, 4096));
        lb.record_workload(0, 4096)?;
        let snap = lb.utilization_snapshot();
        assert_eq!(snap.len(), 2);
        let u0 = snap
            .iter()
            .find(|(id, _)| *id == 0)
            .map(|(_, u)| *u)
            .expect("device 0 not in snapshot");
        assert!((u0 - 0.5).abs() < 1e-6);
        Ok(())
    }

    #[test]
    fn test_unregister_device() {
        let lb = GpuLoadBalancer::new();
        lb.register_device(make_device(0, 8192));
        lb.register_device(make_device(1, 8192));
        lb.unregister_device(0);
        assert_eq!(lb.device_count(), 1);
        assert!(lb.utilization(0).is_none());
    }

    #[test]
    fn test_reregister_device_resets_workload() -> Result<()> {
        let lb = GpuLoadBalancer::new();
        lb.register_device(make_device(0, 8192));
        lb.record_workload(0, 4096)?;
        // Re-registering the same id replaces its state entirely.
        lb.register_device(make_device(0, 8192));
        let util = lb.utilization(0).expect("utilization(0) should be present");
        assert_eq!(util, 0.0);
        Ok(())
    }

    // ---- WorkloadChunk ----

    #[test]
    fn test_workload_chunk_len() {
        let chunk = WorkloadChunk {
            device_id: 0,
            start_idx: 10,
            end_idx: 50,
        };
        assert_eq!(chunk.len(), 40);
    }

    #[test]
    fn test_workload_chunk_is_empty() {
        let chunk = WorkloadChunk {
            device_id: 0,
            start_idx: 5,
            end_idx: 5,
        };
        assert!(chunk.is_empty());
    }

    // ---- WorkloadDistributor ----

    #[test]
    fn test_distribute_empty_devices_error() {
        let dist = WorkloadDistributor::new();
        assert!(dist.distribute(1000, &[]).is_err());
    }

    #[test]
    fn test_distribute_single_device() -> Result<()> {
        let dist = WorkloadDistributor::new();
        let devices = vec![make_device(0, 8192)];
        let chunks = dist.distribute(1000, &devices)?;
        assert_eq!(chunks.len(), 1);
        assert_eq!(chunks[0].start_idx, 0);
        assert_eq!(chunks[0].end_idx, 1000);
        Ok(())
    }

    #[test]
    fn test_distribute_covers_all_vectors() -> Result<()> {
        let dist = WorkloadDistributor::new();
        let devices = vec![make_device(0, 4096), make_device(1, 8192)];
        let chunks = dist.distribute(900, &devices)?;
        let covered: usize = chunks.iter().map(|c| c.len()).sum();
        assert_eq!(covered, 900, "All vectors must be covered");
        Ok(())
    }

    #[test]
    fn test_distribute_proportional_to_memory() -> Result<()> {
        let dist = WorkloadDistributor::new();
        let devices = vec![make_device(0, 1024), make_device(1, 3072)];
        let chunks = dist.distribute(1000, &devices)?;
        assert_eq!(chunks.len(), 2);
        let c0 = &chunks[0];
        let c1 = &chunks[1];
        assert!(
            c0.len() <= 300,
            "Device 0 should get ~25%, got {}",
            c0.len()
        );
        assert!(
            c1.len() >= 700,
            "Device 1 should get ~75%, got {}",
            c1.len()
        );
        assert_eq!(c0.start_idx, 0);
        assert_eq!(c1.end_idx, 1000);
        Ok(())
    }

    #[test]
    fn test_distribute_skips_zero_memory_device() -> Result<()> {
        let dist = WorkloadDistributor::new();
        let devices = vec![make_device(0, 0), make_device(1, 8192)];
        let chunks = dist.distribute(100, &devices)?;
        assert_eq!(chunks.len(), 1);
        assert_eq!(chunks[0].device_id, 1);
        Ok(())
    }

    #[test]
    fn test_distribute_even_basic() -> Result<()> {
        let dist = WorkloadDistributor::new();
        let devices = vec![
            make_device(0, 4096),
            make_device(1, 4096),
            make_device(2, 4096),
        ];
        let chunks = dist.distribute_even(9, &devices)?;
        assert_eq!(chunks.iter().map(|c| c.len()).sum::<usize>(), 9);
        for chunk in &chunks {
            assert_eq!(chunk.len(), 3);
        }
        Ok(())
    }

    #[test]
    fn test_distribute_even_with_remainder() -> Result<()> {
        let dist = WorkloadDistributor::new();
        let devices = vec![make_device(0, 4096), make_device(1, 4096)];
        let chunks = dist.distribute_even(7, &devices)?;
        assert_eq!(chunks.iter().map(|c| c.len()).sum::<usize>(), 7);
        // First device absorbs the remainder vector.
        assert_eq!(chunks[0].len(), 4);
        assert_eq!(chunks[1].len(), 3);
        Ok(())
    }

    #[test]
    fn test_distribute_even_empty_devices_error() {
        let dist = WorkloadDistributor::new();
        assert!(dist.distribute_even(100, &[]).is_err());
    }
}