1use anyhow::{anyhow, Result};
15use parking_lot::Mutex;
16use serde::{Deserialize, Serialize};
17use std::collections::HashMap;
18use std::sync::Arc;
19use tracing::{debug, info};
20
/// Static description of a single GPU device as tracked by the load balancer.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct SimpleGpuDevice {
    /// Unique device identifier; used as the key throughout the balancer.
    pub id: u32,
    /// Human-readable device name (used only for logging in this module).
    pub name: String,
    /// Total device memory in megabytes; acts as the capacity for balancing.
    pub memory_mb: u64,
    /// Number of compute units (informational; not consulted by the
    /// balancing logic in this module).
    pub compute_units: u32,
}
40
41impl SimpleGpuDevice {
42 pub fn new(id: u32, name: impl Into<String>, memory_mb: u64, compute_units: u32) -> Self {
44 Self {
45 id,
46 name: name.into(),
47 memory_mb,
48 compute_units,
49 }
50 }
51}
52
/// Per-device mutable state kept by the balancer.
#[derive(Debug)]
struct DeviceState {
    /// Static description of the device.
    device: SimpleGpuDevice,
    /// Memory currently claimed by recorded workloads, in MB.
    current_workload_mb: u64,
}
63
64impl DeviceState {
65 fn new(device: SimpleGpuDevice) -> Self {
66 Self {
67 device,
68 current_workload_mb: 0,
69 }
70 }
71
72 fn utilization(&self) -> f64 {
74 if self.device.memory_mb == 0 {
75 return 0.0;
76 }
77 (self.current_workload_mb as f64 / self.device.memory_mb as f64).min(1.0)
78 }
79}
80
/// Thread-safe, cloneable GPU load balancer. Clones share the same
/// underlying state through the `Arc<Mutex<_>>`.
#[derive(Debug, Clone)]
pub struct GpuLoadBalancer {
    // Shared state; all public methods lock this for the duration of the call.
    inner: Arc<Mutex<GpuLoadBalancerInner>>,
}
107
/// State behind the balancer's mutex.
#[derive(Debug)]
struct GpuLoadBalancerInner {
    /// Device ids in registration order; drives iteration order for
    /// `select_device` and `utilization_snapshot`.
    device_order: Vec<u32>,
    /// Per-device state keyed by device id.
    states: HashMap<u32, DeviceState>,
}
115
116impl GpuLoadBalancerInner {
117 fn new() -> Self {
118 Self {
119 device_order: Vec::new(),
120 states: HashMap::new(),
121 }
122 }
123}
124
125impl Default for GpuLoadBalancer {
126 fn default() -> Self {
127 Self::new()
128 }
129}
130
131impl GpuLoadBalancer {
132 pub fn new() -> Self {
134 Self {
135 inner: Arc::new(Mutex::new(GpuLoadBalancerInner::new())),
136 }
137 }
138
139 pub fn register_device(&self, device: SimpleGpuDevice) {
142 let mut g = self.inner.lock();
143 let id = device.id;
144 info!("Registering GPU device {} ({})", id, device.name);
145 if !g.device_order.contains(&id) {
146 g.device_order.push(id);
147 }
148 g.states.insert(id, DeviceState::new(device));
149 }
150
151 pub fn unregister_device(&self, device_id: u32) {
153 let mut g = self.inner.lock();
154 g.device_order.retain(|&x| x != device_id);
155 g.states.remove(&device_id);
156 debug!("Unregistered GPU device {}", device_id);
157 }
158
159 pub fn select_device(&self, workload_mb: u64) -> Option<u32> {
165 let g = self.inner.lock();
166 g.device_order
167 .iter()
168 .filter_map(|&id| g.states.get(&id).map(|s| (id, s)))
169 .filter(|(_, s)| {
170 s.device.memory_mb.saturating_sub(s.current_workload_mb) >= workload_mb
171 })
172 .min_by(|(_, a), (_, b)| {
173 a.utilization()
174 .partial_cmp(&b.utilization())
175 .unwrap_or(std::cmp::Ordering::Equal)
176 })
177 .map(|(id, _)| id)
178 }
179
180 pub fn record_workload(&self, device_id: u32, mb: u64) -> Result<()> {
184 let mut g = self.inner.lock();
185 let state = g
186 .states
187 .get_mut(&device_id)
188 .ok_or_else(|| anyhow!("Device {} not registered", device_id))?;
189 state.current_workload_mb += mb;
190 debug!(
191 "Device {}: workload {} MB (util {:.1}%)",
192 device_id,
193 state.current_workload_mb,
194 state.utilization() * 100.0
195 );
196 Ok(())
197 }
198
199 pub fn release_workload(&self, device_id: u32, mb: u64) -> Result<()> {
204 let mut g = self.inner.lock();
205 let state = g
206 .states
207 .get_mut(&device_id)
208 .ok_or_else(|| anyhow!("Device {} not registered", device_id))?;
209 state.current_workload_mb = state.current_workload_mb.saturating_sub(mb);
210 debug!(
211 "Device {}: released {} MB, now {} MB",
212 device_id, mb, state.current_workload_mb
213 );
214 Ok(())
215 }
216
217 pub fn utilization(&self, device_id: u32) -> Option<f64> {
221 let g = self.inner.lock();
222 g.states.get(&device_id).map(|s| s.utilization())
223 }
224
225 pub fn total_capacity_mb(&self) -> u64 {
227 let g = self.inner.lock();
228 g.states.values().map(|s| s.device.memory_mb).sum()
229 }
230
231 pub fn device_count(&self) -> usize {
233 self.inner.lock().states.len()
234 }
235
236 pub fn utilization_snapshot(&self) -> Vec<(u32, f64)> {
238 let g = self.inner.lock();
239 g.device_order
240 .iter()
241 .filter_map(|&id| g.states.get(&id).map(|s| (id, s.utilization())))
242 .collect()
243 }
244}
245
/// Half-open index range `[start_idx, end_idx)` of vectors assigned to one device.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct WorkloadChunk {
    /// Device the range is assigned to.
    pub device_id: u32,
    /// First vector index in the chunk (inclusive).
    pub start_idx: usize,
    /// One past the last vector index in the chunk (exclusive).
    pub end_idx: usize,
}
260
261impl WorkloadChunk {
262 pub fn len(&self) -> usize {
264 self.end_idx.saturating_sub(self.start_idx)
265 }
266
267 pub fn is_empty(&self) -> bool {
269 self.len() == 0
270 }
271}
272
/// Stateless helper that splits a vector workload across GPU devices.
#[derive(Debug, Clone, Default)]
pub struct WorkloadDistributor;
284
285impl WorkloadDistributor {
286 pub fn new() -> Self {
288 Self
289 }
290
291 pub fn distribute(
301 &self,
302 total_vectors: usize,
303 devices: &[SimpleGpuDevice],
304 ) -> Result<Vec<WorkloadChunk>> {
305 let eligible: Vec<&SimpleGpuDevice> = devices.iter().filter(|d| d.memory_mb > 0).collect();
306
307 if eligible.is_empty() {
308 return Err(anyhow!(
309 "No eligible GPU devices (all have zero memory or list is empty)"
310 ));
311 }
312
313 let total_mem: u64 = eligible.iter().map(|d| d.memory_mb).sum();
314
315 let mut chunks: Vec<WorkloadChunk> = Vec::with_capacity(eligible.len());
316 let mut assigned = 0usize;
317
318 for (i, device) in eligible.iter().enumerate() {
319 let start_idx = assigned;
320 let end_idx = if i == eligible.len() - 1 {
321 total_vectors
323 } else {
324 let fraction = device.memory_mb as f64 / total_mem as f64;
325 let count = (total_vectors as f64 * fraction).round() as usize;
326 (assigned + count).min(total_vectors)
327 };
328
329 chunks.push(WorkloadChunk {
330 device_id: device.id,
331 start_idx,
332 end_idx,
333 });
334 assigned = end_idx;
335
336 if assigned >= total_vectors {
337 break;
338 }
339 }
340
341 Ok(chunks)
342 }
343
344 pub fn distribute_even(
349 &self,
350 total_vectors: usize,
351 devices: &[SimpleGpuDevice],
352 ) -> Result<Vec<WorkloadChunk>> {
353 if devices.is_empty() {
354 return Err(anyhow!("Cannot distribute across zero devices"));
355 }
356
357 let n = devices.len();
358 let base = total_vectors / n;
359 let remainder = total_vectors % n;
360
361 let mut chunks = Vec::with_capacity(n);
362 let mut start = 0;
363
364 for (i, device) in devices.iter().enumerate() {
365 let extra = if i < remainder { 1 } else { 0 };
366 let end = start + base + extra;
367 chunks.push(WorkloadChunk {
368 device_id: device.id,
369 start_idx: start,
370 end_idx: end,
371 });
372 start = end;
373 }
374
375 Ok(chunks)
376 }
377}
378
#[cfg(test)]
mod tests {
    use super::*;

    /// Helper: builds a device with the given id/memory and 128 compute units.
    fn make_device(id: u32, mem_mb: u64) -> SimpleGpuDevice {
        SimpleGpuDevice::new(id, format!("GPU-{}", id), mem_mb, 128)
    }

    // --- SimpleGpuDevice ---

    #[test]
    fn test_simple_gpu_device_fields() {
        let d = SimpleGpuDevice::new(0, "TestGPU", 8192, 128);
        assert_eq!(d.id, 0);
        assert_eq!(d.name, "TestGPU");
        assert_eq!(d.memory_mb, 8192);
        assert_eq!(d.compute_units, 128);
    }

    // --- GpuLoadBalancer: registration & capacity ---

    #[test]
    fn test_register_device_count() {
        let lb = GpuLoadBalancer::new();
        lb.register_device(make_device(0, 8192));
        lb.register_device(make_device(1, 16384));
        assert_eq!(lb.device_count(), 2);
    }

    #[test]
    fn test_total_capacity_mb() {
        let lb = GpuLoadBalancer::new();
        lb.register_device(make_device(0, 4096));
        lb.register_device(make_device(1, 8192));
        assert_eq!(lb.total_capacity_mb(), 12288);
    }

    // --- GpuLoadBalancer: selection ---

    #[test]
    fn test_select_device_empty_returns_none() {
        let lb = GpuLoadBalancer::new();
        assert!(lb.select_device(100).is_none());
    }

    #[test]
    fn test_select_device_single() {
        let lb = GpuLoadBalancer::new();
        lb.register_device(make_device(0, 8192));
        let sel = lb.select_device(100);
        assert_eq!(sel, Some(0));
    }

    #[test]
    fn test_select_device_insufficient_memory() {
        let lb = GpuLoadBalancer::new();
        lb.register_device(make_device(0, 100));
        assert!(lb.select_device(200).is_none());
    }

    #[test]
    fn test_select_device_prefers_least_loaded() {
        let lb = GpuLoadBalancer::new();
        lb.register_device(make_device(0, 8192));
        lb.register_device(make_device(1, 8192));

        lb.record_workload(0, 7000).unwrap();

        let sel = lb.select_device(500);
        assert_eq!(sel, Some(1), "Should prefer the less-loaded device");
    }

    // --- GpuLoadBalancer: workload accounting ---

    #[test]
    fn test_record_and_release_workload() {
        let lb = GpuLoadBalancer::new();
        lb.register_device(make_device(0, 8192));

        lb.record_workload(0, 2048).unwrap();
        let u1 = lb.utilization(0).unwrap();
        assert!(
            (u1 - 0.25).abs() < 1e-6,
            "Expected 25% utilisation, got {}",
            u1
        );

        lb.release_workload(0, 2048).unwrap();
        let u2 = lb.utilization(0).unwrap();
        assert!(u2 < 1e-9, "Expected 0% after release, got {}", u2);
    }

    #[test]
    fn test_release_clamps_to_zero() {
        let lb = GpuLoadBalancer::new();
        lb.register_device(make_device(0, 8192));
        lb.record_workload(0, 100).unwrap();
        // Releasing more than was recorded must saturate at zero, not underflow.
        lb.release_workload(0, 9999).unwrap();
        assert_eq!(lb.utilization(0).unwrap(), 0.0);
    }

    #[test]
    fn test_record_unknown_device_errors() {
        let lb = GpuLoadBalancer::new();
        assert!(lb.record_workload(99, 100).is_err());
    }

    #[test]
    fn test_release_unknown_device_errors() {
        let lb = GpuLoadBalancer::new();
        assert!(lb.release_workload(99, 100).is_err());
    }

    #[test]
    fn test_utilization_unknown_device_none() {
        let lb = GpuLoadBalancer::new();
        assert!(lb.utilization(42).is_none());
    }

    #[test]
    fn test_utilization_snapshot() {
        let lb = GpuLoadBalancer::new();
        lb.register_device(make_device(0, 8192));
        lb.register_device(make_device(1, 4096));
        lb.record_workload(0, 4096).unwrap();
        let snap = lb.utilization_snapshot();
        assert_eq!(snap.len(), 2);
        let u0 = snap
            .iter()
            .find(|(id, _)| *id == 0)
            .map(|(_, u)| *u)
            .unwrap();
        assert!((u0 - 0.5).abs() < 1e-6);
    }

    #[test]
    fn test_unregister_device() {
        let lb = GpuLoadBalancer::new();
        lb.register_device(make_device(0, 8192));
        lb.register_device(make_device(1, 8192));
        lb.unregister_device(0);
        assert_eq!(lb.device_count(), 1);
        assert!(lb.utilization(0).is_none());
    }

    #[test]
    fn test_reregister_device_resets_workload() {
        let lb = GpuLoadBalancer::new();
        lb.register_device(make_device(0, 8192));
        lb.record_workload(0, 4096).unwrap();
        // Re-registering the same id replaces the state, dropping the workload.
        lb.register_device(make_device(0, 8192));
        assert_eq!(lb.utilization(0).unwrap(), 0.0);
    }

    // --- WorkloadChunk ---

    #[test]
    fn test_workload_chunk_len() {
        let chunk = WorkloadChunk {
            device_id: 0,
            start_idx: 10,
            end_idx: 50,
        };
        assert_eq!(chunk.len(), 40);
    }

    #[test]
    fn test_workload_chunk_is_empty() {
        let chunk = WorkloadChunk {
            device_id: 0,
            start_idx: 5,
            end_idx: 5,
        };
        assert!(chunk.is_empty());
    }

    // --- WorkloadDistributor: proportional ---

    #[test]
    fn test_distribute_empty_devices_error() {
        let dist = WorkloadDistributor::new();
        assert!(dist.distribute(1000, &[]).is_err());
    }

    #[test]
    fn test_distribute_single_device() {
        let dist = WorkloadDistributor::new();
        let devices = vec![make_device(0, 8192)];
        let chunks = dist.distribute(1000, &devices).unwrap();
        assert_eq!(chunks.len(), 1);
        assert_eq!(chunks[0].start_idx, 0);
        assert_eq!(chunks[0].end_idx, 1000);
    }

    #[test]
    fn test_distribute_covers_all_vectors() {
        let dist = WorkloadDistributor::new();
        let devices = vec![make_device(0, 4096), make_device(1, 8192)];
        let chunks = dist.distribute(900, &devices).unwrap();
        let covered: usize = chunks.iter().map(|c| c.len()).sum();
        assert_eq!(covered, 900, "All vectors must be covered");
    }

    #[test]
    fn test_distribute_proportional_to_memory() {
        let dist = WorkloadDistributor::new();
        // 25% / 75% memory split.
        let devices = vec![make_device(0, 1024), make_device(1, 3072)];
        let chunks = dist.distribute(1000, &devices).unwrap();
        assert_eq!(chunks.len(), 2);
        let c0 = &chunks[0];
        let c1 = &chunks[1];
        assert!(
            c0.len() <= 300,
            "Device 0 should get ~25%, got {}",
            c0.len()
        );
        assert!(
            c1.len() >= 700,
            "Device 1 should get ~75%, got {}",
            c1.len()
        );
        assert_eq!(c0.start_idx, 0);
        assert_eq!(c1.end_idx, 1000);
    }

    #[test]
    fn test_distribute_skips_zero_memory_device() {
        let dist = WorkloadDistributor::new();
        let devices = vec![make_device(0, 0), make_device(1, 8192)];
        let chunks = dist.distribute(100, &devices).unwrap();
        assert_eq!(chunks.len(), 1);
        assert_eq!(chunks[0].device_id, 1);
    }

    // --- WorkloadDistributor: even ---

    #[test]
    fn test_distribute_even_basic() {
        let dist = WorkloadDistributor::new();
        let devices = vec![
            make_device(0, 4096),
            make_device(1, 4096),
            make_device(2, 4096),
        ];
        let chunks = dist.distribute_even(9, &devices).unwrap();
        assert_eq!(chunks.iter().map(|c| c.len()).sum::<usize>(), 9);
        for chunk in &chunks {
            assert_eq!(chunk.len(), 3);
        }
    }

    #[test]
    fn test_distribute_even_with_remainder() {
        let dist = WorkloadDistributor::new();
        let devices = vec![make_device(0, 4096), make_device(1, 4096)];
        let chunks = dist.distribute_even(7, &devices).unwrap();
        assert_eq!(chunks.iter().map(|c| c.len()).sum::<usize>(), 7);
        // Remainder goes to the earliest chunks.
        assert_eq!(chunks[0].len(), 4);
        assert_eq!(chunks[1].len(), 3);
    }

    #[test]
    fn test_distribute_even_empty_devices_error() {
        let dist = WorkloadDistributor::new();
        assert!(dist.distribute_even(100, &[]).is_err());
    }
}