// oxirs_vec/gpu/load_balancer.rs
1//! GPU load balancing for distributing index-building work across multiple devices.
2//!
3//! This module provides:
4//! - `GpuLoadBalancer`: runtime tracking of per-device workloads and selection of the
5//!   least-loaded device for a new task.
6//! - `WorkloadDistributor`: static splitting of a large index job into per-device
7//!   contiguous chunks.
8//!
9//! # Pure Rust Policy
10//!
11//! No CUDA runtime calls are made here.  All load-balancing logic is Pure Rust and
12//! operates on abstract device descriptors (`SimpleGpuDevice`).
13
14use anyhow::{anyhow, Result};
15use parking_lot::Mutex;
16use serde::{Deserialize, Serialize};
17use std::collections::HashMap;
18use std::sync::Arc;
19use tracing::{debug, info};
20
21// ============================================================
22// Device descriptor
23// ============================================================
24
/// Lightweight descriptor of a GPU device used for load balancing decisions.
///
/// This is intentionally separate from `crate::gpu::GpuDevice` (which carries
/// CUDA-specific fields) so that the load balancer remains 100% Pure Rust.
///
/// `memory_mb` is the quantity the balancer reasons about: it bounds how much
/// workload `GpuLoadBalancer` will place on the device, and it weights the
/// proportional split performed by `WorkloadDistributor::distribute`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct SimpleGpuDevice {
    /// Numeric device identifier (matches CUDA device ordinal when CUDA is enabled)
    pub id: u32,
    /// Human-readable name, e.g. "NVIDIA A100-80GB"
    pub name: String,
    /// Total GPU memory in megabytes
    pub memory_mb: u64,
    /// Number of CUDA streaming multiprocessors / compute units
    pub compute_units: u32,
}
40
41impl SimpleGpuDevice {
42    /// Create a new device descriptor.
43    pub fn new(id: u32, name: impl Into<String>, memory_mb: u64, compute_units: u32) -> Self {
44        Self {
45            id,
46            name: name.into(),
47            memory_mb,
48            compute_units,
49        }
50    }
51}
52
53// ============================================================
54// Per-device state (internal)
55// ============================================================
56
/// Internal per-device bookkeeping: the static descriptor plus the amount of
/// workload currently assigned to the device by the balancer.
#[derive(Debug)]
struct DeviceState {
    /// Static descriptor captured at registration time.
    device: SimpleGpuDevice,
    /// Currently allocated workload in megabytes
    current_workload_mb: u64,
}
63
64impl DeviceState {
65    fn new(device: SimpleGpuDevice) -> Self {
66        Self {
67            device,
68            current_workload_mb: 0,
69        }
70    }
71
72    /// Utilisation as a fraction [0.0, 1.0] of total device memory.
73    fn utilization(&self) -> f64 {
74        if self.device.memory_mb == 0 {
75            return 0.0;
76        }
77        (self.current_workload_mb as f64 / self.device.memory_mb as f64).min(1.0)
78    }
79}
80
81// ============================================================
82// GpuLoadBalancer
83// ============================================================
84
/// Distributes GPU work across multiple devices using a least-loaded strategy.
///
/// All mutating operations are thread-safe via an internal `Mutex`, and the
/// balancer is cheaply cloneable: clones share the same state through an `Arc`.
///
/// # Example
/// ```
/// use oxirs_vec::gpu::{GpuLoadBalancer, SimpleGpuDevice};
///
/// let balancer = GpuLoadBalancer::new();
/// balancer.register_device(SimpleGpuDevice::new(0, "GPU-0", 8192, 128));
/// balancer.register_device(SimpleGpuDevice::new(1, "GPU-1", 16384, 256));
///
/// if let Some(id) = balancer.select_device(512) {
///     balancer.record_workload(id, 512).unwrap();
///     // ... do GPU work ...
///     balancer.release_workload(id, 512).unwrap();
/// }
/// ```
#[derive(Debug, Clone)]
pub struct GpuLoadBalancer {
    // Shared mutable state; `Clone` hands out another handle to the same state.
    inner: Arc<Mutex<GpuLoadBalancerInner>>,
}
107
/// State behind the balancer's mutex.  `device_order` preserves registration
/// order so that `select_device` / `utilization_snapshot` iterate devices
/// deterministically, while `states` provides O(1) lookup by ID.
#[derive(Debug)]
struct GpuLoadBalancerInner {
    /// Ordered list of registered device IDs (insertion order)
    device_order: Vec<u32>,
    /// Per-device state keyed by device ID
    states: HashMap<u32, DeviceState>,
}
115
116impl GpuLoadBalancerInner {
117    fn new() -> Self {
118        Self {
119            device_order: Vec::new(),
120            states: HashMap::new(),
121        }
122    }
123}
124
impl Default for GpuLoadBalancer {
    /// Equivalent to [`GpuLoadBalancer::new`]: a balancer with no devices.
    fn default() -> Self {
        Self::new()
    }
}
130
131impl GpuLoadBalancer {
132    /// Create an empty load balancer with no registered devices.
133    pub fn new() -> Self {
134        Self {
135            inner: Arc::new(Mutex::new(GpuLoadBalancerInner::new())),
136        }
137    }
138
139    /// Register a GPU device.  If a device with the same `id` already exists it is
140    /// replaced (workload is reset to zero).
141    pub fn register_device(&self, device: SimpleGpuDevice) {
142        let mut g = self.inner.lock();
143        let id = device.id;
144        info!("Registering GPU device {} ({})", id, device.name);
145        if !g.device_order.contains(&id) {
146            g.device_order.push(id);
147        }
148        g.states.insert(id, DeviceState::new(device));
149    }
150
151    /// Remove a device from the balancer.
152    pub fn unregister_device(&self, device_id: u32) {
153        let mut g = self.inner.lock();
154        g.device_order.retain(|&x| x != device_id);
155        g.states.remove(&device_id);
156        debug!("Unregistered GPU device {}", device_id);
157    }
158
159    /// Select the device best suited to handle `workload_mb` megabytes of new work.
160    ///
161    /// Returns the `id` of the device with the lowest current utilisation that has
162    /// enough free memory to accept the workload, or `None` if no suitable device
163    /// exists or no devices are registered.
164    pub fn select_device(&self, workload_mb: u64) -> Option<u32> {
165        let g = self.inner.lock();
166        g.device_order
167            .iter()
168            .filter_map(|&id| g.states.get(&id).map(|s| (id, s)))
169            .filter(|(_, s)| {
170                s.device.memory_mb.saturating_sub(s.current_workload_mb) >= workload_mb
171            })
172            .min_by(|(_, a), (_, b)| {
173                a.utilization()
174                    .partial_cmp(&b.utilization())
175                    .unwrap_or(std::cmp::Ordering::Equal)
176            })
177            .map(|(id, _)| id)
178    }
179
180    /// Record `mb` megabytes of additional workload on `device_id`.
181    ///
182    /// Returns an error if `device_id` is not registered.
183    pub fn record_workload(&self, device_id: u32, mb: u64) -> Result<()> {
184        let mut g = self.inner.lock();
185        let state = g
186            .states
187            .get_mut(&device_id)
188            .ok_or_else(|| anyhow!("Device {} not registered", device_id))?;
189        state.current_workload_mb += mb;
190        debug!(
191            "Device {}: workload {} MB (util {:.1}%)",
192            device_id,
193            state.current_workload_mb,
194            state.utilization() * 100.0
195        );
196        Ok(())
197    }
198
199    /// Release `mb` megabytes of workload from `device_id`.
200    ///
201    /// Clamps to zero to prevent underflow.  Returns an error if the device is
202    /// not registered.
203    pub fn release_workload(&self, device_id: u32, mb: u64) -> Result<()> {
204        let mut g = self.inner.lock();
205        let state = g
206            .states
207            .get_mut(&device_id)
208            .ok_or_else(|| anyhow!("Device {} not registered", device_id))?;
209        state.current_workload_mb = state.current_workload_mb.saturating_sub(mb);
210        debug!(
211            "Device {}: released {} MB, now {} MB",
212            device_id, mb, state.current_workload_mb
213        );
214        Ok(())
215    }
216
217    /// Utilisation of `device_id` as a fraction in `[0.0, 1.0]`.
218    ///
219    /// Returns `None` if the device is not registered.
220    pub fn utilization(&self, device_id: u32) -> Option<f64> {
221        let g = self.inner.lock();
222        g.states.get(&device_id).map(|s| s.utilization())
223    }
224
225    /// Sum of memory across all registered devices in MB.
226    pub fn total_capacity_mb(&self) -> u64 {
227        let g = self.inner.lock();
228        g.states.values().map(|s| s.device.memory_mb).sum()
229    }
230
231    /// Number of registered devices.
232    pub fn device_count(&self) -> usize {
233        self.inner.lock().states.len()
234    }
235
236    /// Returns a snapshot of device IDs and their current utilisation.
237    pub fn utilization_snapshot(&self) -> Vec<(u32, f64)> {
238        let g = self.inner.lock();
239        g.device_order
240            .iter()
241            .filter_map(|&id| g.states.get(&id).map(|s| (id, s.utilization())))
242            .collect()
243    }
244}
245
246// ============================================================
247// WorkloadChunk
248// ============================================================
249
/// A contiguous slice of a vector dataset assigned to a specific GPU.
///
/// The range is half-open: indices `start_idx..end_idx` belong to this chunk.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct WorkloadChunk {
    /// ID of the GPU device responsible for this chunk
    pub device_id: u32,
    /// Start index (inclusive) in the source vector array
    pub start_idx: usize,
    /// End index (exclusive) in the source vector array
    pub end_idx: usize,
}
260
261impl WorkloadChunk {
262    /// Number of vectors in this chunk.
263    pub fn len(&self) -> usize {
264        self.end_idx.saturating_sub(self.start_idx)
265    }
266
267    /// Returns `true` if the chunk covers no vectors.
268    pub fn is_empty(&self) -> bool {
269        self.len() == 0
270    }
271}
272
273// ============================================================
274// WorkloadDistributor
275// ============================================================
276
/// Splits a large vector index job across multiple GPU devices proportionally to
/// their memory capacity.
///
/// The distributor is stateless (a zero-sized type): call `distribute` as many
/// times as needed without side effects.
#[derive(Debug, Clone, Default)]
pub struct WorkloadDistributor;
284
285impl WorkloadDistributor {
286    /// Create a new distributor.
287    pub fn new() -> Self {
288        Self
289    }
290
291    /// Distribute `total_vectors` vectors across `devices` proportionally to each
292    /// device's `memory_mb`.
293    ///
294    /// Returns one [`WorkloadChunk`] per device (in device order).  Devices with
295    /// zero memory are skipped.  Returns an error if `devices` is empty or all
296    /// devices have zero memory.
297    ///
298    /// The last chunk absorbs any rounding remainder so that every vector is
299    /// covered exactly once.
300    pub fn distribute(
301        &self,
302        total_vectors: usize,
303        devices: &[SimpleGpuDevice],
304    ) -> Result<Vec<WorkloadChunk>> {
305        let eligible: Vec<&SimpleGpuDevice> = devices.iter().filter(|d| d.memory_mb > 0).collect();
306
307        if eligible.is_empty() {
308            return Err(anyhow!(
309                "No eligible GPU devices (all have zero memory or list is empty)"
310            ));
311        }
312
313        let total_mem: u64 = eligible.iter().map(|d| d.memory_mb).sum();
314
315        let mut chunks: Vec<WorkloadChunk> = Vec::with_capacity(eligible.len());
316        let mut assigned = 0usize;
317
318        for (i, device) in eligible.iter().enumerate() {
319            let start_idx = assigned;
320            let end_idx = if i == eligible.len() - 1 {
321                // Last device gets remaining vectors (absorbs rounding error)
322                total_vectors
323            } else {
324                let fraction = device.memory_mb as f64 / total_mem as f64;
325                let count = (total_vectors as f64 * fraction).round() as usize;
326                (assigned + count).min(total_vectors)
327            };
328
329            chunks.push(WorkloadChunk {
330                device_id: device.id,
331                start_idx,
332                end_idx,
333            });
334            assigned = end_idx;
335
336            if assigned >= total_vectors {
337                break;
338            }
339        }
340
341        Ok(chunks)
342    }
343
344    /// Distribute evenly (round-robin, ignoring memory ratios).
345    ///
346    /// Useful when all devices are homogeneous.  Returns an error if `devices` is
347    /// empty.
348    pub fn distribute_even(
349        &self,
350        total_vectors: usize,
351        devices: &[SimpleGpuDevice],
352    ) -> Result<Vec<WorkloadChunk>> {
353        if devices.is_empty() {
354            return Err(anyhow!("Cannot distribute across zero devices"));
355        }
356
357        let n = devices.len();
358        let base = total_vectors / n;
359        let remainder = total_vectors % n;
360
361        let mut chunks = Vec::with_capacity(n);
362        let mut start = 0;
363
364        for (i, device) in devices.iter().enumerate() {
365            let extra = if i < remainder { 1 } else { 0 };
366            let end = start + base + extra;
367            chunks.push(WorkloadChunk {
368                device_id: device.id,
369                start_idx: start,
370                end_idx: end,
371            });
372            start = end;
373        }
374
375        Ok(chunks)
376    }
377}
378
379// ============================================================
380// Tests
381// ============================================================
382
#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::Result;

    /// Test helper: a device with a generated name and a fixed compute-unit count.
    fn make_device(id: u32, mem_mb: u64) -> SimpleGpuDevice {
        SimpleGpuDevice::new(id, format!("GPU-{}", id), mem_mb, 128)
    }

    // ---- SimpleGpuDevice ----

    #[test]
    fn test_simple_gpu_device_fields() {
        let device = SimpleGpuDevice::new(0, "TestGPU", 8192, 128);
        assert_eq!(device.id, 0);
        assert_eq!(device.name, "TestGPU");
        assert_eq!(device.memory_mb, 8192);
        assert_eq!(device.compute_units, 128);
    }

    // ---- GpuLoadBalancer ----

    #[test]
    fn test_register_device_count() {
        let balancer = GpuLoadBalancer::new();
        for (id, mem) in [(0u32, 8192u64), (1, 16384)] {
            balancer.register_device(make_device(id, mem));
        }
        assert_eq!(balancer.device_count(), 2);
    }

    #[test]
    fn test_total_capacity_mb() {
        let balancer = GpuLoadBalancer::new();
        balancer.register_device(make_device(0, 4096));
        balancer.register_device(make_device(1, 8192));
        assert_eq!(balancer.total_capacity_mb(), 12288);
    }

    #[test]
    fn test_select_device_empty_returns_none() {
        assert!(GpuLoadBalancer::new().select_device(100).is_none());
    }

    #[test]
    fn test_select_device_single() {
        let balancer = GpuLoadBalancer::new();
        balancer.register_device(make_device(0, 8192));
        assert_eq!(balancer.select_device(100), Some(0));
    }

    #[test]
    fn test_select_device_insufficient_memory() {
        let balancer = GpuLoadBalancer::new();
        // Only 100 MB of capacity; a 200 MB request cannot be placed.
        balancer.register_device(make_device(0, 100));
        assert!(balancer.select_device(200).is_none());
    }

    #[test]
    fn test_select_device_prefers_least_loaded() -> Result<()> {
        let balancer = GpuLoadBalancer::new();
        balancer.register_device(make_device(0, 8192));
        balancer.register_device(make_device(1, 8192));

        // Saturate device 0 so device 1 becomes the better choice.
        balancer.record_workload(0, 7000)?;

        assert_eq!(
            balancer.select_device(500),
            Some(1),
            "Should prefer the less-loaded device"
        );
        Ok(())
    }

    #[test]
    fn test_record_and_release_workload() -> Result<()> {
        let balancer = GpuLoadBalancer::new();
        balancer.register_device(make_device(0, 8192));

        balancer.record_workload(0, 2048)?;
        let after_record = balancer.utilization(0).expect("utilization(0) was None");
        assert!(
            (after_record - 0.25).abs() < 1e-6,
            "Expected 25% utilisation, got {}",
            after_record
        );

        balancer.release_workload(0, 2048)?;
        let after_release = balancer.utilization(0).expect("utilization(0) was None");
        assert!(
            after_release < 1e-9,
            "Expected 0% after release, got {}",
            after_release
        );
        Ok(())
    }

    #[test]
    fn test_release_clamps_to_zero() -> Result<()> {
        let balancer = GpuLoadBalancer::new();
        balancer.register_device(make_device(0, 8192));
        balancer.record_workload(0, 100)?;
        // Releasing more than was recorded must clamp at zero, not underflow.
        balancer.release_workload(0, 9999)?;
        let util = balancer.utilization(0).expect("utilization(0) was None");
        assert_eq!(util, 0.0);
        Ok(())
    }

    #[test]
    fn test_record_unknown_device_errors() {
        assert!(GpuLoadBalancer::new().record_workload(99, 100).is_err());
    }

    #[test]
    fn test_release_unknown_device_errors() {
        assert!(GpuLoadBalancer::new().release_workload(99, 100).is_err());
    }

    #[test]
    fn test_utilization_unknown_device_none() {
        assert!(GpuLoadBalancer::new().utilization(42).is_none());
    }

    #[test]
    fn test_utilization_snapshot() -> Result<()> {
        let balancer = GpuLoadBalancer::new();
        balancer.register_device(make_device(0, 8192));
        balancer.register_device(make_device(1, 4096));
        balancer.record_workload(0, 4096)?;

        let snapshot = balancer.utilization_snapshot();
        assert_eq!(snapshot.len(), 2);

        let util0 = snapshot
            .iter()
            .find(|(id, _)| *id == 0)
            .map(|(_, util)| *util)
            .expect("device 0 not in snapshot");
        assert!((util0 - 0.5).abs() < 1e-6);
        Ok(())
    }

    #[test]
    fn test_unregister_device() {
        let balancer = GpuLoadBalancer::new();
        balancer.register_device(make_device(0, 8192));
        balancer.register_device(make_device(1, 8192));
        balancer.unregister_device(0);
        assert_eq!(balancer.device_count(), 1);
        assert!(balancer.utilization(0).is_none());
    }

    #[test]
    fn test_reregister_device_resets_workload() -> Result<()> {
        let balancer = GpuLoadBalancer::new();
        balancer.register_device(make_device(0, 8192));
        balancer.record_workload(0, 4096)?;
        // Registering the same id again replaces the entry and zeroes workload.
        balancer.register_device(make_device(0, 8192));
        let util = balancer.utilization(0).expect("utilization(0) should be present");
        assert_eq!(util, 0.0);
        Ok(())
    }

    // ---- WorkloadChunk ----

    #[test]
    fn test_workload_chunk_len() {
        let chunk = WorkloadChunk {
            device_id: 0,
            start_idx: 10,
            end_idx: 50,
        };
        assert_eq!(chunk.len(), 40);
    }

    #[test]
    fn test_workload_chunk_is_empty() {
        let chunk = WorkloadChunk {
            device_id: 0,
            start_idx: 5,
            end_idx: 5,
        };
        assert!(chunk.is_empty());
    }

    // ---- WorkloadDistributor ----

    #[test]
    fn test_distribute_empty_devices_error() {
        assert!(WorkloadDistributor::new().distribute(1000, &[]).is_err());
    }

    #[test]
    fn test_distribute_single_device() -> Result<()> {
        let chunks = WorkloadDistributor::new().distribute(1000, &[make_device(0, 8192)])?;
        assert_eq!(chunks.len(), 1);
        assert_eq!(chunks[0].start_idx, 0);
        assert_eq!(chunks[0].end_idx, 1000);
        Ok(())
    }

    #[test]
    fn test_distribute_covers_all_vectors() -> Result<()> {
        let devices = [make_device(0, 4096), make_device(1, 8192)];
        let chunks = WorkloadDistributor::new().distribute(900, &devices)?;
        let covered: usize = chunks.iter().map(WorkloadChunk::len).sum();
        assert_eq!(covered, 900, "All vectors must be covered");
        Ok(())
    }

    #[test]
    fn test_distribute_proportional_to_memory() -> Result<()> {
        // Device 0: 1 GB, device 1: 3 GB => roughly a 25 / 75 split.
        let devices = [make_device(0, 1024), make_device(1, 3072)];
        let chunks = WorkloadDistributor::new().distribute(1000, &devices)?;
        assert_eq!(chunks.len(), 2);
        let (first, second) = (&chunks[0], &chunks[1]);
        assert!(
            first.len() <= 300,
            "Device 0 should get ~25%, got {}",
            first.len()
        );
        assert!(
            second.len() >= 700,
            "Device 1 should get ~75%, got {}",
            second.len()
        );
        assert_eq!(first.start_idx, 0);
        assert_eq!(second.end_idx, 1000);
        Ok(())
    }

    #[test]
    fn test_distribute_skips_zero_memory_device() -> Result<()> {
        let devices = [make_device(0, 0), make_device(1, 8192)];
        let chunks = WorkloadDistributor::new().distribute(100, &devices)?;
        // The zero-memory device is ineligible; only device 1 receives a chunk.
        assert_eq!(chunks.len(), 1);
        assert_eq!(chunks[0].device_id, 1);
        Ok(())
    }

    #[test]
    fn test_distribute_even_basic() -> Result<()> {
        let devices = [
            make_device(0, 4096),
            make_device(1, 4096),
            make_device(2, 4096),
        ];
        let chunks = WorkloadDistributor::new().distribute_even(9, &devices)?;
        assert_eq!(chunks.iter().map(WorkloadChunk::len).sum::<usize>(), 9);
        assert!(chunks.iter().all(|chunk| chunk.len() == 3));
        Ok(())
    }

    #[test]
    fn test_distribute_even_with_remainder() -> Result<()> {
        let devices = [make_device(0, 4096), make_device(1, 4096)];
        let chunks = WorkloadDistributor::new().distribute_even(7, &devices)?;
        assert_eq!(chunks.iter().map(WorkloadChunk::len).sum::<usize>(), 7);
        // The remainder goes to the earliest devices: 4 then 3.
        assert_eq!(chunks[0].len(), 4);
        assert_eq!(chunks[1].len(), 3);
        Ok(())
    }

    #[test]
    fn test_distribute_even_empty_devices_error() {
        assert!(WorkloadDistributor::new().distribute_even(100, &[]).is_err());
    }
}