//! GPU load balancing for distributing index-building work across multiple devices.
//!
//! This module provides:
//! - `GpuLoadBalancer`: runtime tracking of per-device workloads and selection of the
//!   least-loaded device for a new task.
//! - `WorkloadDistributor`: static splitting of a large index job into per-device
//!   contiguous chunks.
//!
//! # Pure Rust Policy
//!
//! No CUDA runtime calls are made here.  All load-balancing logic is Pure Rust and
//! operates on abstract device descriptors (`SimpleGpuDevice`).
use anyhow::{anyhow, Result};
use parking_lot::Mutex;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use tracing::{debug, info};

// ============================================================
// Device descriptor
// ============================================================

/// Lightweight descriptor of a GPU device used for load balancing decisions.
///
/// This is intentionally separate from `crate::gpu::GpuDevice` (which carries
/// CUDA-specific fields) so that the load balancer remains 100% Pure Rust.
///
/// Serializable via serde so device inventories can be persisted or exchanged.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct SimpleGpuDevice {
    /// Numeric device identifier (matches CUDA device ordinal when CUDA is enabled)
    pub id: u32,
    /// Human-readable name, e.g. "NVIDIA A100-80GB"
    pub name: String,
    /// Total GPU memory in megabytes
    pub memory_mb: u64,
    /// Number of CUDA streaming multiprocessors / compute units
    pub compute_units: u32,
}
41impl SimpleGpuDevice {
42    /// Create a new device descriptor.
43    pub fn new(id: u32, name: impl Into<String>, memory_mb: u64, compute_units: u32) -> Self {
44        Self {
45            id,
46            name: name.into(),
47            memory_mb,
48            compute_units,
49        }
50    }
51}

// ============================================================
// Per-device state (internal)
// ============================================================

/// Internal bookkeeping for one registered device: the descriptor plus the
/// amount of work currently assigned to it.
#[derive(Debug)]
struct DeviceState {
    device: SimpleGpuDevice,
    /// Currently allocated workload in megabytes
    current_workload_mb: u64,
}
64impl DeviceState {
65    fn new(device: SimpleGpuDevice) -> Self {
66        Self {
67            device,
68            current_workload_mb: 0,
69        }
70    }
71
72    /// Utilisation as a fraction [0.0, 1.0] of total device memory.
73    fn utilization(&self) -> f64 {
74        if self.device.memory_mb == 0 {
75            return 0.0;
76        }
77        (self.current_workload_mb as f64 / self.device.memory_mb as f64).min(1.0)
78    }
79}

// ============================================================
// GpuLoadBalancer
// ============================================================

/// Distributes GPU work across multiple devices using a least-loaded strategy.
///
/// All mutating operations are thread-safe via an internal `Mutex`.  Cloning a
/// `GpuLoadBalancer` is cheap: clones share the same underlying state (`Arc`).
///
/// # Example
/// ```
/// use oxirs_vec::gpu::{GpuLoadBalancer, SimpleGpuDevice};
///
/// let balancer = GpuLoadBalancer::new();
/// balancer.register_device(SimpleGpuDevice::new(0, "GPU-0", 8192, 128));
/// balancer.register_device(SimpleGpuDevice::new(1, "GPU-1", 16384, 256));
///
/// if let Some(id) = balancer.select_device(512) {
///     balancer.record_workload(id, 512).unwrap();
///     // ... do GPU work ...
///     balancer.release_workload(id, 512).unwrap();
/// }
/// ```
#[derive(Debug, Clone)]
pub struct GpuLoadBalancer {
    inner: Arc<Mutex<GpuLoadBalancerInner>>,
}
/// Shared mutable state behind `GpuLoadBalancer`'s mutex.
#[derive(Debug)]
struct GpuLoadBalancerInner {
    /// Ordered list of registered device IDs (insertion order)
    device_order: Vec<u32>,
    /// Per-device state keyed by device ID
    states: HashMap<u32, DeviceState>,
}
116impl GpuLoadBalancerInner {
117    fn new() -> Self {
118        Self {
119            device_order: Vec::new(),
120            states: HashMap::new(),
121        }
122    }
123}
125impl Default for GpuLoadBalancer {
126    fn default() -> Self {
127        Self::new()
128    }
129}
131impl GpuLoadBalancer {
132    /// Create an empty load balancer with no registered devices.
133    pub fn new() -> Self {
134        Self {
135            inner: Arc::new(Mutex::new(GpuLoadBalancerInner::new())),
136        }
137    }
138
139    /// Register a GPU device.  If a device with the same `id` already exists it is
140    /// replaced (workload is reset to zero).
141    pub fn register_device(&self, device: SimpleGpuDevice) {
142        let mut g = self.inner.lock();
143        let id = device.id;
144        info!("Registering GPU device {} ({})", id, device.name);
145        if !g.device_order.contains(&id) {
146            g.device_order.push(id);
147        }
148        g.states.insert(id, DeviceState::new(device));
149    }
150
151    /// Remove a device from the balancer.
152    pub fn unregister_device(&self, device_id: u32) {
153        let mut g = self.inner.lock();
154        g.device_order.retain(|&x| x != device_id);
155        g.states.remove(&device_id);
156        debug!("Unregistered GPU device {}", device_id);
157    }
158
159    /// Select the device best suited to handle `workload_mb` megabytes of new work.
160    ///
161    /// Returns the `id` of the device with the lowest current utilisation that has
162    /// enough free memory to accept the workload, or `None` if no suitable device
163    /// exists or no devices are registered.
164    pub fn select_device(&self, workload_mb: u64) -> Option<u32> {
165        let g = self.inner.lock();
166        g.device_order
167            .iter()
168            .filter_map(|&id| g.states.get(&id).map(|s| (id, s)))
169            .filter(|(_, s)| {
170                s.device.memory_mb.saturating_sub(s.current_workload_mb) >= workload_mb
171            })
172            .min_by(|(_, a), (_, b)| {
173                a.utilization()
174                    .partial_cmp(&b.utilization())
175                    .unwrap_or(std::cmp::Ordering::Equal)
176            })
177            .map(|(id, _)| id)
178    }
179
180    /// Record `mb` megabytes of additional workload on `device_id`.
181    ///
182    /// Returns an error if `device_id` is not registered.
183    pub fn record_workload(&self, device_id: u32, mb: u64) -> Result<()> {
184        let mut g = self.inner.lock();
185        let state = g
186            .states
187            .get_mut(&device_id)
188            .ok_or_else(|| anyhow!("Device {} not registered", device_id))?;
189        state.current_workload_mb += mb;
190        debug!(
191            "Device {}: workload {} MB (util {:.1}%)",
192            device_id,
193            state.current_workload_mb,
194            state.utilization() * 100.0
195        );
196        Ok(())
197    }
198
199    /// Release `mb` megabytes of workload from `device_id`.
200    ///
201    /// Clamps to zero to prevent underflow.  Returns an error if the device is
202    /// not registered.
203    pub fn release_workload(&self, device_id: u32, mb: u64) -> Result<()> {
204        let mut g = self.inner.lock();
205        let state = g
206            .states
207            .get_mut(&device_id)
208            .ok_or_else(|| anyhow!("Device {} not registered", device_id))?;
209        state.current_workload_mb = state.current_workload_mb.saturating_sub(mb);
210        debug!(
211            "Device {}: released {} MB, now {} MB",
212            device_id, mb, state.current_workload_mb
213        );
214        Ok(())
215    }
216
217    /// Utilisation of `device_id` as a fraction in `[0.0, 1.0]`.
218    ///
219    /// Returns `None` if the device is not registered.
220    pub fn utilization(&self, device_id: u32) -> Option<f64> {
221        let g = self.inner.lock();
222        g.states.get(&device_id).map(|s| s.utilization())
223    }
224
225    /// Sum of memory across all registered devices in MB.
226    pub fn total_capacity_mb(&self) -> u64 {
227        let g = self.inner.lock();
228        g.states.values().map(|s| s.device.memory_mb).sum()
229    }
230
231    /// Number of registered devices.
232    pub fn device_count(&self) -> usize {
233        self.inner.lock().states.len()
234    }
235
236    /// Returns a snapshot of device IDs and their current utilisation.
237    pub fn utilization_snapshot(&self) -> Vec<(u32, f64)> {
238        let g = self.inner.lock();
239        g.device_order
240            .iter()
241            .filter_map(|&id| g.states.get(&id).map(|s| (id, s.utilization())))
242            .collect()
243    }
244}

// ============================================================
// WorkloadChunk
// ============================================================

/// A contiguous slice of a vector dataset assigned to a specific GPU.
///
/// Produced by [`WorkloadDistributor`]; the half-open range is
/// `start_idx..end_idx`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct WorkloadChunk {
    /// ID of the GPU device responsible for this chunk
    pub device_id: u32,
    /// Start index (inclusive) in the source vector array
    pub start_idx: usize,
    /// End index (exclusive) in the source vector array
    pub end_idx: usize,
}
261impl WorkloadChunk {
262    /// Number of vectors in this chunk.
263    pub fn len(&self) -> usize {
264        self.end_idx.saturating_sub(self.start_idx)
265    }
266
267    /// Returns `true` if the chunk covers no vectors.
268    pub fn is_empty(&self) -> bool {
269        self.len() == 0
270    }
271}

// ============================================================
// WorkloadDistributor
// ============================================================

/// Splits a large vector index job across multiple GPU devices proportionally to
/// their memory capacity.
///
/// The distributor is stateless: call `distribute` as many times as needed
/// without side effects.
#[derive(Debug, Clone, Default)]
pub struct WorkloadDistributor;
285impl WorkloadDistributor {
286    /// Create a new distributor.
287    pub fn new() -> Self {
288        Self
289    }
290
291    /// Distribute `total_vectors` vectors across `devices` proportionally to each
292    /// device's `memory_mb`.
293    ///
294    /// Returns one [`WorkloadChunk`] per device (in device order).  Devices with
295    /// zero memory are skipped.  Returns an error if `devices` is empty or all
296    /// devices have zero memory.
297    ///
298    /// The last chunk absorbs any rounding remainder so that every vector is
299    /// covered exactly once.
300    pub fn distribute(
301        &self,
302        total_vectors: usize,
303        devices: &[SimpleGpuDevice],
304    ) -> Result<Vec<WorkloadChunk>> {
305        let eligible: Vec<&SimpleGpuDevice> = devices.iter().filter(|d| d.memory_mb > 0).collect();
306
307        if eligible.is_empty() {
308            return Err(anyhow!(
309                "No eligible GPU devices (all have zero memory or list is empty)"
310            ));
311        }
312
313        let total_mem: u64 = eligible.iter().map(|d| d.memory_mb).sum();
314
315        let mut chunks: Vec<WorkloadChunk> = Vec::with_capacity(eligible.len());
316        let mut assigned = 0usize;
317
318        for (i, device) in eligible.iter().enumerate() {
319            let start_idx = assigned;
320            let end_idx = if i == eligible.len() - 1 {
321                // Last device gets remaining vectors (absorbs rounding error)
322                total_vectors
323            } else {
324                let fraction = device.memory_mb as f64 / total_mem as f64;
325                let count = (total_vectors as f64 * fraction).round() as usize;
326                (assigned + count).min(total_vectors)
327            };
328
329            chunks.push(WorkloadChunk {
330                device_id: device.id,
331                start_idx,
332                end_idx,
333            });
334            assigned = end_idx;
335
336            if assigned >= total_vectors {
337                break;
338            }
339        }
340
341        Ok(chunks)
342    }
343
344    /// Distribute evenly (round-robin, ignoring memory ratios).
345    ///
346    /// Useful when all devices are homogeneous.  Returns an error if `devices` is
347    /// empty.
348    pub fn distribute_even(
349        &self,
350        total_vectors: usize,
351        devices: &[SimpleGpuDevice],
352    ) -> Result<Vec<WorkloadChunk>> {
353        if devices.is_empty() {
354            return Err(anyhow!("Cannot distribute across zero devices"));
355        }
356
357        let n = devices.len();
358        let base = total_vectors / n;
359        let remainder = total_vectors % n;
360
361        let mut chunks = Vec::with_capacity(n);
362        let mut start = 0;
363
364        for (i, device) in devices.iter().enumerate() {
365            let extra = if i < remainder { 1 } else { 0 };
366            let end = start + base + extra;
367            chunks.push(WorkloadChunk {
368                device_id: device.id,
369                start_idx: start,
370                end_idx: end,
371            });
372            start = end;
373        }
374
375        Ok(chunks)
376    }
377}

// ============================================================
// Tests
// ============================================================

#[cfg(test)]
mod tests {
    use super::*;

    /// Helper: build a device with the given id and memory and 128 compute units.
    fn make_device(id: u32, mem_mb: u64) -> SimpleGpuDevice {
        SimpleGpuDevice::new(id, format!("GPU-{}", id), mem_mb, 128)
    }

    // ---- SimpleGpuDevice ----

    #[test]
    fn test_simple_gpu_device_fields() {
        let d = SimpleGpuDevice::new(0, "TestGPU", 8192, 128);
        assert_eq!(d.id, 0);
        assert_eq!(d.name, "TestGPU");
        assert_eq!(d.memory_mb, 8192);
        assert_eq!(d.compute_units, 128);
    }

    // ---- GpuLoadBalancer ----

    #[test]
    fn test_register_device_count() {
        let lb = GpuLoadBalancer::new();
        lb.register_device(make_device(0, 8192));
        lb.register_device(make_device(1, 16384));
        assert_eq!(lb.device_count(), 2);
    }

    #[test]
    fn test_total_capacity_mb() {
        let lb = GpuLoadBalancer::new();
        lb.register_device(make_device(0, 4096));
        lb.register_device(make_device(1, 8192));
        assert_eq!(lb.total_capacity_mb(), 12288);
    }

    #[test]
    fn test_select_device_empty_returns_none() {
        let lb = GpuLoadBalancer::new();
        assert!(lb.select_device(100).is_none());
    }

    #[test]
    fn test_select_device_single() {
        let lb = GpuLoadBalancer::new();
        lb.register_device(make_device(0, 8192));
        let sel = lb.select_device(100);
        assert_eq!(sel, Some(0));
    }

    #[test]
    fn test_select_device_insufficient_memory() {
        let lb = GpuLoadBalancer::new();
        lb.register_device(make_device(0, 100)); // only 100 MB
        // Requesting 200 MB should yield None
        assert!(lb.select_device(200).is_none());
    }

    #[test]
    fn test_select_device_prefers_least_loaded() {
        let lb = GpuLoadBalancer::new();
        lb.register_device(make_device(0, 8192));
        lb.register_device(make_device(1, 8192));

        // Load device 0 heavily
        lb.record_workload(0, 7000).unwrap();

        // Device 1 should be selected
        let sel = lb.select_device(500);
        assert_eq!(sel, Some(1), "Should prefer the less-loaded device");
    }

    #[test]
    fn test_record_and_release_workload() {
        let lb = GpuLoadBalancer::new();
        lb.register_device(make_device(0, 8192));

        // 2048 / 8192 = 25% utilisation
        lb.record_workload(0, 2048).unwrap();
        let u1 = lb.utilization(0).unwrap();
        assert!(
            (u1 - 0.25).abs() < 1e-6,
            "Expected 25% utilisation, got {}",
            u1
        );

        lb.release_workload(0, 2048).unwrap();
        let u2 = lb.utilization(0).unwrap();
        assert!(u2 < 1e-9, "Expected 0% after release, got {}", u2);
    }

    #[test]
    fn test_release_clamps_to_zero() {
        let lb = GpuLoadBalancer::new();
        lb.register_device(make_device(0, 8192));
        lb.record_workload(0, 100).unwrap();
        // Release more than recorded — should not underflow
        lb.release_workload(0, 9999).unwrap();
        assert_eq!(lb.utilization(0).unwrap(), 0.0);
    }

    #[test]
    fn test_record_unknown_device_errors() {
        let lb = GpuLoadBalancer::new();
        assert!(lb.record_workload(99, 100).is_err());
    }

    #[test]
    fn test_release_unknown_device_errors() {
        let lb = GpuLoadBalancer::new();
        assert!(lb.release_workload(99, 100).is_err());
    }

    #[test]
    fn test_utilization_unknown_device_none() {
        let lb = GpuLoadBalancer::new();
        assert!(lb.utilization(42).is_none());
    }

    #[test]
    fn test_utilization_snapshot() {
        let lb = GpuLoadBalancer::new();
        lb.register_device(make_device(0, 8192));
        lb.register_device(make_device(1, 4096));
        lb.record_workload(0, 4096).unwrap();
        let snap = lb.utilization_snapshot();
        assert_eq!(snap.len(), 2);
        // Device 0: 4096 / 8192 = 50%
        let u0 = snap
            .iter()
            .find(|(id, _)| *id == 0)
            .map(|(_, u)| *u)
            .unwrap();
        assert!((u0 - 0.5).abs() < 1e-6);
    }

    #[test]
    fn test_unregister_device() {
        let lb = GpuLoadBalancer::new();
        lb.register_device(make_device(0, 8192));
        lb.register_device(make_device(1, 8192));
        lb.unregister_device(0);
        assert_eq!(lb.device_count(), 1);
        assert!(lb.utilization(0).is_none());
    }

    #[test]
    fn test_reregister_device_resets_workload() {
        let lb = GpuLoadBalancer::new();
        lb.register_device(make_device(0, 8192));
        lb.record_workload(0, 4096).unwrap();
        // Re-register same device — workload should reset
        lb.register_device(make_device(0, 8192));
        assert_eq!(lb.utilization(0).unwrap(), 0.0);
    }

    // ---- WorkloadChunk ----

    #[test]
    fn test_workload_chunk_len() {
        let chunk = WorkloadChunk {
            device_id: 0,
            start_idx: 10,
            end_idx: 50,
        };
        assert_eq!(chunk.len(), 40);
    }

    #[test]
    fn test_workload_chunk_is_empty() {
        let chunk = WorkloadChunk {
            device_id: 0,
            start_idx: 5,
            end_idx: 5,
        };
        assert!(chunk.is_empty());
    }

    // ---- WorkloadDistributor ----

    #[test]
    fn test_distribute_empty_devices_error() {
        let dist = WorkloadDistributor::new();
        assert!(dist.distribute(1000, &[]).is_err());
    }

    #[test]
    fn test_distribute_single_device() {
        let dist = WorkloadDistributor::new();
        let devices = vec![make_device(0, 8192)];
        let chunks = dist.distribute(1000, &devices).unwrap();
        assert_eq!(chunks.len(), 1);
        assert_eq!(chunks[0].start_idx, 0);
        assert_eq!(chunks[0].end_idx, 1000);
    }

    #[test]
    fn test_distribute_covers_all_vectors() {
        let dist = WorkloadDistributor::new();
        let devices = vec![make_device(0, 4096), make_device(1, 8192)];
        let chunks = dist.distribute(900, &devices).unwrap();
        let covered: usize = chunks.iter().map(|c| c.len()).sum();
        assert_eq!(covered, 900, "All vectors must be covered");
    }

    #[test]
    fn test_distribute_proportional_to_memory() {
        let dist = WorkloadDistributor::new();
        // Device 0: 1 GB, device 1: 3 GB => 25 / 75 split
        let devices = vec![make_device(0, 1024), make_device(1, 3072)];
        let chunks = dist.distribute(1000, &devices).unwrap();
        assert_eq!(chunks.len(), 2);
        // Device 0 should get ~250 vectors
        let c0 = &chunks[0];
        let c1 = &chunks[1];
        assert!(
            c0.len() <= 300,
            "Device 0 should get ~25%, got {}",
            c0.len()
        );
        assert!(
            c1.len() >= 700,
            "Device 1 should get ~75%, got {}",
            c1.len()
        );
        assert_eq!(c0.start_idx, 0);
        assert_eq!(c1.end_idx, 1000);
    }

    #[test]
    fn test_distribute_skips_zero_memory_device() {
        let dist = WorkloadDistributor::new();
        let devices = vec![make_device(0, 0), make_device(1, 8192)];
        let chunks = dist.distribute(100, &devices).unwrap();
        // Device 0 is skipped; only device 1
        assert_eq!(chunks.len(), 1);
        assert_eq!(chunks[0].device_id, 1);
    }

    #[test]
    fn test_distribute_even_basic() {
        let dist = WorkloadDistributor::new();
        let devices = vec![
            make_device(0, 4096),
            make_device(1, 4096),
            make_device(2, 4096),
        ];
        let chunks = dist.distribute_even(9, &devices).unwrap();
        assert_eq!(chunks.iter().map(|c| c.len()).sum::<usize>(), 9);
        // 9 vectors over 3 devices: each gets exactly 3
        for chunk in &chunks {
            assert_eq!(chunk.len(), 3);
        }
    }

    #[test]
    fn test_distribute_even_with_remainder() {
        let dist = WorkloadDistributor::new();
        let devices = vec![make_device(0, 4096), make_device(1, 4096)];
        let chunks = dist.distribute_even(7, &devices).unwrap();
        assert_eq!(chunks.iter().map(|c| c.len()).sum::<usize>(), 7);
        // First device gets 4, second gets 3
        assert_eq!(chunks[0].len(), 4);
        assert_eq!(chunks[1].len(), 3);
    }

    #[test]
    fn test_distribute_even_empty_devices_error() {
        let dist = WorkloadDistributor::new();
        assert!(dist.distribute_even(100, &[]).is_err());
    }
}