// oxirs_vec/gpu/multi_gpu.rs
//! Multi-GPU load balancing for distributed vector index operations
//!
//! This module provides round-robin and workload-aware distribution of
//! vector search and index building tasks across multiple GPU devices.
//!
//! # Architecture
//!
//! The multi-GPU system consists of:
//! - `MultiGpuManager`: Central coordinator managing all GPU workers
//! - `GpuWorker`: Per-device worker with its own queue and metrics
//! - `LoadBalancer`: Strategy-based dispatcher (round-robin or workload-aware)
//! - `MultiGpuTask`: Task type enum for different GPU operations
//!
//! # Feature Gating
//!
//! All CUDA runtime interactions are gated with `#[cfg(feature = "cuda")]`.
//! The load balancing logic itself is Pure Rust.
use std::collections::{HashMap, VecDeque};
use std::sync::Arc;
use std::time::Instant;

use anyhow::{anyhow, Result};
use parking_lot::{Mutex, RwLock};
use serde::{Deserialize, Serialize};
use tracing::{debug, info, warn};

use crate::gpu::GpuDevice;

29/// Load balancing strategy for multi-GPU distribution
30#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
31pub enum LoadBalancingStrategy {
32    /// Simple round-robin distribution across devices
33    RoundRobin,
34    /// Route to device with lowest current utilization
35    LeastUtilized,
36    /// Route to device with shortest queue depth
37    ShortestQueue,
38    /// Weighted routing based on device compute capability
39    WeightedCapacity,
40    /// Adaptive: switches between strategies based on workload
41    #[default]
42    Adaptive,
43}
44
45/// Configuration for multi-GPU manager
46#[derive(Debug, Clone, Serialize, Deserialize)]
47pub struct MultiGpuConfig {
48    /// Number of GPU devices to use
49    pub num_devices: usize,
50    /// Load balancing strategy
51    pub strategy: LoadBalancingStrategy,
52    /// Maximum queue depth per device before rejecting tasks
53    pub max_queue_depth: usize,
54    /// Interval for utilization sampling (ms)
55    pub utilization_sample_interval_ms: u64,
56    /// Enable device affinity (prefer same device for related tasks)
57    pub device_affinity: bool,
58    /// Threshold above which a device is considered overloaded (0.0-1.0)
59    pub overload_threshold: f32,
60    /// Number of warmup tasks before switching from round-robin to adaptive
61    pub adaptive_warmup_tasks: usize,
62    /// Enable async task execution across devices
63    pub async_execution: bool,
64    /// Per-device memory budget in MB
65    pub device_memory_budget_mb: usize,
66}
67
68impl Default for MultiGpuConfig {
69    fn default() -> Self {
70        Self {
71            num_devices: 1,
72            strategy: LoadBalancingStrategy::Adaptive,
73            max_queue_depth: 64,
74            utilization_sample_interval_ms: 100,
75            device_affinity: true,
76            overload_threshold: 0.85,
77            adaptive_warmup_tasks: 50,
78            async_execution: true,
79            device_memory_budget_mb: 4096,
80        }
81    }
82}
83
84/// Real-time metrics for a single GPU device
85#[derive(Debug, Clone, Default, Serialize, Deserialize)]
86pub struct GpuDeviceMetrics {
87    /// Device ID
88    pub device_id: i32,
89    /// Current utilization (0.0 - 1.0)
90    pub utilization: f32,
91    /// Number of tasks currently in queue
92    pub queue_depth: usize,
93    /// Number of tasks currently executing
94    pub active_tasks: usize,
95    /// Total tasks completed
96    pub tasks_completed: u64,
97    /// Total tasks failed
98    pub tasks_failed: u64,
99    /// Average task latency (ms)
100    pub avg_latency_ms: f64,
101    /// Peak memory usage (bytes)
102    pub peak_memory_bytes: usize,
103    /// Free memory (bytes)
104    pub free_memory_bytes: usize,
105    /// Device temperature (Celsius, estimated)
106    pub temperature_celsius: f32,
107    /// Device compute capability
108    pub compute_capability: (i32, i32),
109    /// Relative compute weight for weighted routing
110    pub compute_weight: f64,
111}
112
113/// A task that can be dispatched to a GPU device
114#[derive(Debug, Clone)]
115pub enum MultiGpuTask {
116    /// Build HNSW index for a batch of vectors
117    BuildIndex {
118        task_id: u64,
119        vector_ids: Vec<usize>,
120        vectors: Vec<Vec<f32>>,
121        priority: TaskPriority,
122    },
123    /// Perform KNN search for a query batch
124    BatchSearch {
125        task_id: u64,
126        queries: Vec<Vec<f32>>,
127        k: usize,
128        priority: TaskPriority,
129    },
130    /// Compute pairwise distance matrix
131    DistanceMatrix {
132        task_id: u64,
133        matrix_a: Vec<Vec<f32>>,
134        matrix_b: Vec<Vec<f32>>,
135        priority: TaskPriority,
136    },
137    /// Vector normalization batch
138    NormalizeBatch {
139        task_id: u64,
140        vectors: Vec<Vec<f32>>,
141        priority: TaskPriority,
142    },
143    /// Custom kernel execution
144    CustomKernel {
145        task_id: u64,
146        kernel_name: String,
147        input: Vec<f32>,
148        output_size: usize,
149        priority: TaskPriority,
150    },
151}
152
153impl MultiGpuTask {
154    /// Get the task ID
155    pub fn task_id(&self) -> u64 {
156        match self {
157            Self::BuildIndex { task_id, .. } => *task_id,
158            Self::BatchSearch { task_id, .. } => *task_id,
159            Self::DistanceMatrix { task_id, .. } => *task_id,
160            Self::NormalizeBatch { task_id, .. } => *task_id,
161            Self::CustomKernel { task_id, .. } => *task_id,
162        }
163    }
164
165    /// Get the task priority
166    pub fn priority(&self) -> TaskPriority {
167        match self {
168            Self::BuildIndex { priority, .. } => *priority,
169            Self::BatchSearch { priority, .. } => *priority,
170            Self::DistanceMatrix { priority, .. } => *priority,
171            Self::NormalizeBatch { priority, .. } => *priority,
172            Self::CustomKernel { priority, .. } => *priority,
173        }
174    }
175
176    /// Estimate computational cost (relative units)
177    pub fn estimated_cost(&self) -> f64 {
178        match self {
179            Self::BuildIndex { vectors, .. } => {
180                let n = vectors.len() as f64;
181                let d = vectors.first().map(|v| v.len() as f64).unwrap_or(1.0);
182                n * n * d * 0.001 // O(n^2 * d) for naive build
183            }
184            Self::BatchSearch { queries, k, .. } => {
185                let n = queries.len() as f64;
186                let d = queries.first().map(|v| v.len() as f64).unwrap_or(1.0);
187                n * (*k as f64) * d * 0.1
188            }
189            Self::DistanceMatrix {
190                matrix_a, matrix_b, ..
191            } => {
192                let na = matrix_a.len() as f64;
193                let nb = matrix_b.len() as f64;
194                let d = matrix_a.first().map(|v| v.len() as f64).unwrap_or(1.0);
195                na * nb * d * 0.01
196            }
197            Self::NormalizeBatch { vectors, .. } => {
198                let n = vectors.len() as f64;
199                let d = vectors.first().map(|v| v.len() as f64).unwrap_or(1.0);
200                n * d * 0.001
201            }
202            Self::CustomKernel { input, .. } => input.len() as f64 * 0.01,
203        }
204    }
205}
206
207/// Task priority level
208#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
209pub enum TaskPriority {
210    Low = 0,
211    Normal = 1,
212    High = 2,
213    Critical = 3,
214}
215
216/// Result of a GPU task execution
217#[derive(Debug, Clone)]
218pub struct GpuTaskResult {
219    /// Task ID this result belongs to
220    pub task_id: u64,
221    /// Device that executed the task
222    pub device_id: i32,
223    /// Execution time in milliseconds
224    pub execution_time_ms: u64,
225    /// Output data (semantics depend on task type)
226    pub output: GpuTaskOutput,
227}
228
/// Payload variants produced by the different task types.
#[derive(Debug, Clone)]
pub enum GpuTaskOutput {
    /// Index-build summary: how many nodes were inserted.
    IndexBuild { nodes_built: usize },
    /// Per-query neighbor lists of `(neighbor_id, distance)` pairs.
    SearchResults(Vec<Vec<(usize, f32)>>),
    /// Pairwise distance matrix, row-major over `matrix_a` x `matrix_b`.
    DistanceMatrix(Vec<Vec<f32>>),
    /// L2-normalized copies of the input vectors.
    NormalizedVectors(Vec<Vec<f32>>),
    /// Raw output buffer of a custom kernel.
    CustomOutput(Vec<f32>),
}
243
244/// Per-device worker state
245#[derive(Debug)]
246struct GpuWorker {
247    device_id: i32,
248    device_info: GpuDevice,
249    task_queue: VecDeque<MultiGpuTask>,
250    metrics: GpuDeviceMetrics,
251    last_metrics_update: Instant,
252}
253
254impl GpuWorker {
255    fn new(device_id: i32) -> Result<Self> {
256        let device_info = GpuDevice::get_device_info(device_id)?;
257
258        // Compute relative weight based on compute capability
259        let compute_weight = device_info.compute_capability.0 as f64 * 10.0
260            + device_info.compute_capability.1 as f64;
261
262        let metrics = GpuDeviceMetrics {
263            device_id,
264            utilization: 0.0,
265            queue_depth: 0,
266            active_tasks: 0,
267            tasks_completed: 0,
268            tasks_failed: 0,
269            avg_latency_ms: 0.0,
270            peak_memory_bytes: 0,
271            free_memory_bytes: device_info.free_memory,
272            temperature_celsius: 50.0, // Simulated idle temperature
273            compute_capability: device_info.compute_capability,
274            compute_weight,
275        };
276
277        Ok(Self {
278            device_id,
279            device_info,
280            task_queue: VecDeque::new(),
281            metrics,
282            last_metrics_update: Instant::now(),
283        })
284    }
285
286    fn enqueue(&mut self, task: MultiGpuTask) -> Result<()> {
287        self.task_queue.push_back(task);
288        self.metrics.queue_depth = self.task_queue.len();
289        Ok(())
290    }
291
292    fn execute_next(&mut self) -> Option<GpuTaskResult> {
293        let task = self.task_queue.pop_front()?;
294        self.metrics.queue_depth = self.task_queue.len();
295        self.metrics.active_tasks += 1;
296
297        let start = Instant::now();
298        let task_id = task.task_id();
299        let device_id = self.device_id;
300
301        let output = self.execute_task(task);
302        let execution_time_ms = start.elapsed().as_millis() as u64;
303
304        self.metrics.active_tasks = self.metrics.active_tasks.saturating_sub(1);
305
306        match output {
307            Ok(output) => {
308                self.metrics.tasks_completed += 1;
309                self.update_avg_latency(execution_time_ms as f64);
310                self.update_utilization();
311
312                Some(GpuTaskResult {
313                    task_id,
314                    device_id,
315                    execution_time_ms,
316                    output,
317                })
318            }
319            Err(e) => {
320                warn!("Task {} failed on device {}: {}", task_id, device_id, e);
321                self.metrics.tasks_failed += 1;
322                None
323            }
324        }
325    }
326
327    fn execute_task(&self, task: MultiGpuTask) -> Result<GpuTaskOutput> {
328        match task {
329            MultiGpuTask::BuildIndex { vectors, .. } => {
330                let nodes_built = vectors.len();
331                debug!(
332                    "Device {} building index for {} vectors",
333                    self.device_id, nodes_built
334                );
335                Ok(GpuTaskOutput::IndexBuild { nodes_built })
336            }
337            MultiGpuTask::BatchSearch { queries, k, .. } => {
338                let results = queries
339                    .iter()
340                    .map(|_q| {
341                        // Simulated search results
342                        (0..k.min(10))
343                            .map(|i| (i, (i as f32) * 0.1))
344                            .collect::<Vec<_>>()
345                    })
346                    .collect();
347                Ok(GpuTaskOutput::SearchResults(results))
348            }
349            MultiGpuTask::DistanceMatrix {
350                matrix_a, matrix_b, ..
351            } => {
352                let distances = matrix_a
353                    .iter()
354                    .map(|a| {
355                        matrix_b
356                            .iter()
357                            .map(|b| {
358                                a.iter()
359                                    .zip(b.iter())
360                                    .map(|(x, y)| (x - y).powi(2))
361                                    .sum::<f32>()
362                                    .sqrt()
363                            })
364                            .collect::<Vec<_>>()
365                    })
366                    .collect();
367                Ok(GpuTaskOutput::DistanceMatrix(distances))
368            }
369            MultiGpuTask::NormalizeBatch { vectors, .. } => {
370                let normalized = vectors
371                    .iter()
372                    .map(|v| {
373                        let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
374                        if norm > 1e-9 {
375                            v.iter().map(|x| x / norm).collect()
376                        } else {
377                            v.clone()
378                        }
379                    })
380                    .collect();
381                Ok(GpuTaskOutput::NormalizedVectors(normalized))
382            }
383            MultiGpuTask::CustomKernel { input, .. } => {
384                let output = input.iter().map(|x| x * 2.0).collect();
385                Ok(GpuTaskOutput::CustomOutput(output))
386            }
387        }
388    }
389
390    fn update_avg_latency(&mut self, new_latency_ms: f64) {
391        let completed = self.metrics.tasks_completed as f64;
392        if completed <= 1.0 {
393            self.metrics.avg_latency_ms = new_latency_ms;
394        } else {
395            // Exponential moving average
396            self.metrics.avg_latency_ms = 0.9 * self.metrics.avg_latency_ms + 0.1 * new_latency_ms;
397        }
398    }
399
400    fn update_utilization(&mut self) {
401        let elapsed = self.last_metrics_update.elapsed().as_millis() as f64;
402        if elapsed > 0.0 {
403            let active = self.metrics.active_tasks as f64;
404            self.metrics.utilization = (active / 4.0_f64).min(1.0) as f32;
405        }
406        self.last_metrics_update = Instant::now();
407    }
408}
409
410/// Multi-GPU load balancer implementation
411#[derive(Debug)]
412struct LoadBalancer {
413    strategy: LoadBalancingStrategy,
414    round_robin_counter: usize,
415    total_tasks_dispatched: u64,
416    warmup_tasks: usize,
417}
418
419impl LoadBalancer {
420    fn new(strategy: LoadBalancingStrategy, warmup_tasks: usize) -> Self {
421        Self {
422            strategy,
423            round_robin_counter: 0,
424            total_tasks_dispatched: 0,
425            warmup_tasks,
426        }
427    }
428
429    fn select_device(
430        &mut self,
431        task: &MultiGpuTask,
432        workers: &[GpuWorker],
433        overload_threshold: f32,
434    ) -> Result<usize> {
435        if workers.is_empty() {
436            return Err(anyhow!("No GPU workers available"));
437        }
438
439        // Filter out overloaded devices
440        let available: Vec<usize> = (0..workers.len())
441            .filter(|&i| {
442                workers[i].metrics.utilization < overload_threshold
443                    || workers[i].metrics.queue_depth == 0
444            })
445            .collect();
446
447        if available.is_empty() {
448            // Fall back to least utilized even if overloaded
449            warn!("All GPU devices are overloaded, routing to least utilized");
450            return self.select_least_utilized(workers);
451        }
452
453        let effective_strategy = if self.total_tasks_dispatched < self.warmup_tasks as u64 {
454            LoadBalancingStrategy::RoundRobin
455        } else {
456            self.strategy
457        };
458
459        let selected = match effective_strategy {
460            LoadBalancingStrategy::RoundRobin => self.select_round_robin(&available),
461            LoadBalancingStrategy::LeastUtilized => {
462                self.select_least_utilized_from(workers, &available)
463            }
464            LoadBalancingStrategy::ShortestQueue => self.select_shortest_queue(workers, &available),
465            LoadBalancingStrategy::WeightedCapacity => {
466                self.select_weighted(workers, &available, task)
467            }
468            LoadBalancingStrategy::Adaptive => self.select_adaptive(workers, &available, task),
469        };
470
471        self.total_tasks_dispatched += 1;
472        Ok(selected)
473    }
474
475    fn select_round_robin(&mut self, available: &[usize]) -> usize {
476        let idx = self.round_robin_counter % available.len();
477        self.round_robin_counter += 1;
478        available[idx]
479    }
480
481    fn select_least_utilized(&self, workers: &[GpuWorker]) -> Result<usize> {
482        workers
483            .iter()
484            .enumerate()
485            .min_by(|a, b| {
486                a.1.metrics
487                    .utilization
488                    .partial_cmp(&b.1.metrics.utilization)
489                    .unwrap_or(std::cmp::Ordering::Equal)
490            })
491            .map(|(i, _)| i)
492            .ok_or_else(|| anyhow!("No workers available"))
493    }
494
495    fn select_least_utilized_from(&self, workers: &[GpuWorker], available: &[usize]) -> usize {
496        available
497            .iter()
498            .min_by(|&&a, &&b| {
499                workers[a]
500                    .metrics
501                    .utilization
502                    .partial_cmp(&workers[b].metrics.utilization)
503                    .unwrap_or(std::cmp::Ordering::Equal)
504            })
505            .copied()
506            .unwrap_or(available[0])
507    }
508
509    fn select_shortest_queue(&self, workers: &[GpuWorker], available: &[usize]) -> usize {
510        available
511            .iter()
512            .min_by_key(|&&i| workers[i].metrics.queue_depth)
513            .copied()
514            .unwrap_or(available[0])
515    }
516
517    fn select_weighted(
518        &mut self,
519        workers: &[GpuWorker],
520        available: &[usize],
521        _task: &MultiGpuTask,
522    ) -> usize {
523        let total_weight: f64 = available
524            .iter()
525            .map(|&i| workers[i].metrics.compute_weight)
526            .sum();
527        if total_weight <= 0.0 {
528            return self.select_round_robin(available);
529        }
530
531        // Weighted random selection using deterministic counter
532        let threshold = (self.round_robin_counter as f64 / 1000.0) % 1.0;
533        let mut cumulative = 0.0;
534        for &i in available {
535            cumulative += workers[i].metrics.compute_weight / total_weight;
536            if cumulative >= threshold {
537                self.round_robin_counter += 1;
538                return i;
539            }
540        }
541        self.round_robin_counter += 1;
542        available[available.len() - 1]
543    }
544
545    fn select_adaptive(
546        &mut self,
547        workers: &[GpuWorker],
548        available: &[usize],
549        task: &MultiGpuTask,
550    ) -> usize {
551        // For high-cost tasks, use least-utilized
552        // For low-cost tasks, use shortest-queue
553        let cost = task.estimated_cost();
554        if cost > 100.0 {
555            self.select_least_utilized_from(workers, available)
556        } else {
557            self.select_shortest_queue(workers, available)
558        }
559    }
560}
561
562/// Statistics for the multi-GPU manager
563#[derive(Debug, Clone, Default, Serialize, Deserialize)]
564pub struct MultiGpuStats {
565    /// Total tasks dispatched across all devices
566    pub total_tasks_dispatched: u64,
567    /// Total tasks completed
568    pub total_tasks_completed: u64,
569    /// Total tasks failed
570    pub total_tasks_failed: u64,
571    /// Average dispatch latency (ms)
572    pub avg_dispatch_latency_ms: f64,
573    /// Per-device metrics
574    pub device_metrics: Vec<GpuDeviceMetrics>,
575    /// Load imbalance factor (1.0 = perfectly balanced)
576    pub load_imbalance_factor: f64,
577    /// Current active strategy
578    pub active_strategy: String,
579}
580
581/// Central multi-GPU manager
582///
583/// Manages a pool of GPU workers and dispatches tasks using the configured
584/// load balancing strategy.
585#[derive(Debug)]
586pub struct MultiGpuManager {
587    config: MultiGpuConfig,
588    workers: Arc<RwLock<Vec<GpuWorker>>>,
589    load_balancer: Arc<Mutex<LoadBalancer>>,
590    stats: Arc<Mutex<MultiGpuStats>>,
591    result_buffer: Arc<Mutex<HashMap<u64, GpuTaskResult>>>,
592    next_task_id: Arc<Mutex<u64>>,
593}
594
595impl MultiGpuManager {
596    /// Create a new multi-GPU manager
597    ///
598    /// Initializes workers for each device ID from 0 to `num_devices-1`.
599    pub fn new(config: MultiGpuConfig) -> Result<Self> {
600        let num_devices = config.num_devices.max(1);
601        let mut workers = Vec::with_capacity(num_devices);
602
603        for device_id in 0..num_devices as i32 {
604            let worker = GpuWorker::new(device_id).map_err(|e| {
605                anyhow!(
606                    "Failed to initialize GPU worker for device {}: {}",
607                    device_id,
608                    e
609                )
610            })?;
611            workers.push(worker);
612        }
613
614        info!(
615            "Multi-GPU manager initialized with {} devices, strategy={:?}",
616            num_devices, config.strategy
617        );
618
619        let load_balancer = LoadBalancer::new(config.strategy, config.adaptive_warmup_tasks);
620
621        Ok(Self {
622            config,
623            workers: Arc::new(RwLock::new(workers)),
624            load_balancer: Arc::new(Mutex::new(load_balancer)),
625            stats: Arc::new(Mutex::new(MultiGpuStats::default())),
626            result_buffer: Arc::new(Mutex::new(HashMap::new())),
627            next_task_id: Arc::new(Mutex::new(0)),
628        })
629    }
630
631    /// Dispatch a task to the most appropriate GPU device
632    pub fn dispatch(&self, task: MultiGpuTask) -> Result<u64> {
633        let task_id = task.task_id();
634
635        let mut workers = self.workers.write();
636        let device_idx = {
637            let mut lb = self.load_balancer.lock();
638            lb.select_device(&task, &workers, self.config.overload_threshold)?
639        };
640
641        if workers[device_idx].metrics.queue_depth >= self.config.max_queue_depth {
642            return Err(anyhow!(
643                "Device {} queue is full (depth={})",
644                device_idx,
645                workers[device_idx].metrics.queue_depth
646            ));
647        }
648
649        debug!("Dispatching task {} to device {}", task_id, device_idx);
650        workers[device_idx].enqueue(task)?;
651
652        let mut stats = self.stats.lock();
653        stats.total_tasks_dispatched += 1;
654
655        Ok(task_id)
656    }
657
658    /// Execute all pending tasks on all devices and collect results
659    pub fn execute_pending(&self) -> Vec<GpuTaskResult> {
660        let mut workers = self.workers.write();
661        let mut all_results = Vec::new();
662
663        for worker in workers.iter_mut() {
664            while !worker.task_queue.is_empty() {
665                if let Some(result) = worker.execute_next() {
666                    all_results.push(result);
667                }
668            }
669        }
670
671        let mut stats = self.stats.lock();
672        stats.total_tasks_completed += all_results.len() as u64;
673
674        all_results
675    }
676
677    /// Dispatch and immediately execute a task, returning the result
678    pub fn execute_sync(&self, task: MultiGpuTask) -> Result<GpuTaskResult> {
679        let task_id = self.dispatch(task)?;
680        let results = self.execute_pending();
681
682        results
683            .into_iter()
684            .find(|r| r.task_id == task_id)
685            .ok_or_else(|| anyhow!("Task {} was not executed", task_id))
686    }
687
688    /// Get aggregate statistics for all devices
689    pub fn get_stats(&self) -> MultiGpuStats {
690        let workers = self.workers.read();
691        let stats = self.stats.lock();
692
693        let device_metrics: Vec<GpuDeviceMetrics> =
694            workers.iter().map(|w| w.metrics.clone()).collect();
695
696        // Calculate load imbalance factor
697        let utilizations: Vec<f32> = device_metrics.iter().map(|m| m.utilization).collect();
698        let load_imbalance = if utilizations.len() > 1 {
699            let max_util = utilizations
700                .iter()
701                .cloned()
702                .fold(f32::NEG_INFINITY, f32::max);
703            let min_util = utilizations.iter().cloned().fold(f32::INFINITY, f32::min);
704            if min_util > 0.0 {
705                max_util as f64 / min_util as f64
706            } else {
707                1.0
708            }
709        } else {
710            1.0
711        };
712
713        MultiGpuStats {
714            total_tasks_dispatched: stats.total_tasks_dispatched,
715            total_tasks_completed: stats.total_tasks_completed,
716            total_tasks_failed: stats.total_tasks_failed,
717            avg_dispatch_latency_ms: stats.avg_dispatch_latency_ms,
718            device_metrics,
719            load_imbalance_factor: load_imbalance,
720            active_strategy: format!("{:?}", self.config.strategy),
721        }
722    }
723
724    /// Get per-device metrics
725    pub fn get_device_metrics(&self) -> Vec<GpuDeviceMetrics> {
726        let workers = self.workers.read();
727        workers.iter().map(|w| w.metrics.clone()).collect()
728    }
729
730    /// Get the number of active GPU devices
731    pub fn num_devices(&self) -> usize {
732        self.workers.read().len()
733    }
734
735    /// Check if all devices are healthy (not overloaded)
736    pub fn all_healthy(&self) -> bool {
737        let workers = self.workers.read();
738        workers
739            .iter()
740            .all(|w| w.metrics.utilization < self.config.overload_threshold)
741    }
742
743    /// Get the least utilized device ID
744    pub fn least_utilized_device(&self) -> Option<i32> {
745        let workers = self.workers.read();
746        workers
747            .iter()
748            .min_by(|a, b| {
749                a.metrics
750                    .utilization
751                    .partial_cmp(&b.metrics.utilization)
752                    .unwrap_or(std::cmp::Ordering::Equal)
753            })
754            .map(|w| w.device_id)
755    }
756
757    /// Generate a unique task ID
758    pub fn next_task_id(&self) -> u64 {
759        let mut id = self.next_task_id.lock();
760        let current = *id;
761        *id += 1;
762        current
763    }
764
765    /// Set the load balancing strategy at runtime
766    pub fn set_strategy(&self, strategy: LoadBalancingStrategy) {
767        let mut lb = self.load_balancer.lock();
768        lb.strategy = strategy;
769        info!("Load balancing strategy changed to {:?}", strategy);
770    }
771
772    /// Reset statistics
773    pub fn reset_stats(&self) {
774        let mut stats = self.stats.lock();
775        *stats = MultiGpuStats::default();
776    }
777}
778
/// Namespace of preset `MultiGpuConfig` builders for common workloads.
pub struct MultiGpuConfigFactory;
781
782impl MultiGpuConfigFactory {
783    /// Configuration optimized for high-throughput indexing
784    pub fn high_throughput_indexing(num_devices: usize) -> MultiGpuConfig {
785        MultiGpuConfig {
786            num_devices,
787            strategy: LoadBalancingStrategy::WeightedCapacity,
788            max_queue_depth: 128,
789            async_execution: true,
790            device_memory_budget_mb: 8192,
791            ..Default::default()
792        }
793    }
794
795    /// Configuration optimized for low-latency search
796    pub fn low_latency_search(num_devices: usize) -> MultiGpuConfig {
797        MultiGpuConfig {
798            num_devices,
799            strategy: LoadBalancingStrategy::ShortestQueue,
800            max_queue_depth: 16,
801            overload_threshold: 0.7,
802            device_affinity: false,
803            ..Default::default()
804        }
805    }
806
807    /// Configuration optimized for balanced mixed workloads
808    pub fn balanced_mixed_workload(num_devices: usize) -> MultiGpuConfig {
809        MultiGpuConfig {
810            num_devices,
811            strategy: LoadBalancingStrategy::Adaptive,
812            adaptive_warmup_tasks: 100,
813            ..Default::default()
814        }
815    }
816}
817
818#[cfg(test)]
819mod tests {
820    use super::*;
821    use anyhow::Result;
822
823    fn make_batch_search_task(id: u64, n_queries: usize, dim: usize) -> MultiGpuTask {
824        let queries = (0..n_queries)
825            .map(|i| (0..dim).map(|j| (i + j) as f32 * 0.1).collect())
826            .collect();
827        MultiGpuTask::BatchSearch {
828            task_id: id,
829            queries,
830            k: 10,
831            priority: TaskPriority::Normal,
832        }
833    }
834
835    fn make_build_index_task(id: u64, n_vectors: usize, dim: usize) -> MultiGpuTask {
836        let vectors: Vec<Vec<f32>> = (0..n_vectors)
837            .map(|i| (0..dim).map(|j| (i + j) as f32 * 0.1).collect())
838            .collect();
839        let vector_ids: Vec<usize> = (0..n_vectors).collect();
840        MultiGpuTask::BuildIndex {
841            task_id: id,
842            vector_ids,
843            vectors,
844            priority: TaskPriority::Normal,
845        }
846    }
847
848    #[test]
849    fn test_multi_gpu_config_default() {
850        let config = MultiGpuConfig::default();
851        assert_eq!(config.num_devices, 1);
852        assert_eq!(config.strategy, LoadBalancingStrategy::Adaptive);
853        assert!(config.async_execution);
854    }
855
856    #[test]
857    fn test_multi_gpu_manager_creation() -> Result<()> {
858        let config = MultiGpuConfig {
859            num_devices: 2,
860            ..Default::default()
861        };
862        let manager = MultiGpuManager::new(config);
863        assert!(manager.is_ok(), "Manager creation should succeed");
864        let manager = manager?;
865        assert_eq!(manager.num_devices(), 2);
866        Ok(())
867    }
868
869    #[test]
870    fn test_single_device_dispatch_and_execute() -> Result<()> {
871        let config = MultiGpuConfig {
872            num_devices: 1,
873            ..Default::default()
874        };
875        let manager = MultiGpuManager::new(config)?;
876
877        let task = make_batch_search_task(0, 5, 8);
878        let task_id = manager.dispatch(task)?;
879        assert_eq!(task_id, 0);
880
881        let results = manager.execute_pending();
882        assert_eq!(results.len(), 1);
883        assert_eq!(results[0].task_id, 0);
884        Ok(())
885    }
886
887    #[test]
888    fn test_round_robin_distribution() -> Result<()> {
889        let config = MultiGpuConfig {
890            num_devices: 3,
891            strategy: LoadBalancingStrategy::RoundRobin,
892            ..Default::default()
893        };
894        let manager = MultiGpuManager::new(config)?;
895
896        // Dispatch 6 tasks - should distribute 2 each to 3 devices
897        for i in 0..6u64 {
898            let task = make_batch_search_task(i, 2, 4);
899            manager.dispatch(task)?;
900        }
901
902        // Execute all
903        let results = manager.execute_pending();
904        assert_eq!(results.len(), 6);
905        Ok(())
906    }
907
908    #[test]
909    fn test_execute_sync() -> Result<()> {
910        let config = MultiGpuConfig {
911            num_devices: 1,
912            ..Default::default()
913        };
914        let manager = MultiGpuManager::new(config)?;
915
916        let task = make_batch_search_task(42, 3, 8);
917        let result = manager.execute_sync(task)?;
918
919        assert_eq!(result.task_id, 42);
920        assert_eq!(result.device_id, 0);
921        matches!(result.output, GpuTaskOutput::SearchResults(_));
922        Ok(())
923    }
924
925    #[test]
926    fn test_distance_matrix_task() -> Result<()> {
927        let config = MultiGpuConfig {
928            num_devices: 1,
929            ..Default::default()
930        };
931        let manager = MultiGpuManager::new(config)?;
932
933        let task = MultiGpuTask::DistanceMatrix {
934            task_id: 1,
935            matrix_a: vec![vec![1.0, 0.0], vec![0.0, 1.0]],
936            matrix_b: vec![vec![1.0, 0.0], vec![0.0, 1.0]],
937            priority: TaskPriority::Normal,
938        };
939
940        let result = manager.execute_sync(task)?;
941        match result.output {
942            GpuTaskOutput::DistanceMatrix(m) => {
943                assert_eq!(m.len(), 2);
944                assert_eq!(m[0].len(), 2);
945                // Distance from [1,0] to [1,0] should be 0
946                assert!(m[0][0].abs() < 1e-5, "Self-distance should be 0");
947                // Distance from [1,0] to [0,1] should be sqrt(2)
948                assert!((m[0][1] - 2.0_f32.sqrt()).abs() < 1e-4);
949            }
950            _ => panic!("Expected DistanceMatrix output"),
951        }
952        Ok(())
953    }
954
955    #[test]
956    fn test_normalize_batch_task() -> Result<()> {
957        let config = MultiGpuConfig {
958            num_devices: 1,
959            ..Default::default()
960        };
961        let manager = MultiGpuManager::new(config)?;
962
963        let task = MultiGpuTask::NormalizeBatch {
964            task_id: 2,
965            vectors: vec![vec![3.0, 4.0], vec![1.0, 0.0]],
966            priority: TaskPriority::Normal,
967        };
968
969        let result = manager.execute_sync(task)?;
970        match result.output {
971            GpuTaskOutput::NormalizedVectors(vecs) => {
972                assert_eq!(vecs.len(), 2);
973                // First vector [3,4] normalized = [0.6, 0.8] (norm=5)
974                let norm0: f32 = vecs[0].iter().map(|x| x * x).sum::<f32>().sqrt();
975                assert!(
976                    (norm0 - 1.0).abs() < 1e-5,
977                    "Norm should be 1.0, got {}",
978                    norm0
979                );
980                // Second vector [1,0] already unit norm
981                assert!((vecs[1][0] - 1.0).abs() < 1e-5);
982            }
983            _ => panic!("Expected NormalizedVectors output"),
984        }
985        Ok(())
986    }
987
988    #[test]
989    fn test_build_index_task() -> Result<()> {
990        let config = MultiGpuConfig {
991            num_devices: 1,
992            ..Default::default()
993        };
994        let manager = MultiGpuManager::new(config)?;
995
996        let task = make_build_index_task(3, 100, 16);
997        let result = manager.execute_sync(task)?;
998
999        match result.output {
1000            GpuTaskOutput::IndexBuild { nodes_built } => {
1001                assert_eq!(nodes_built, 100);
1002            }
1003            _ => panic!("Expected IndexBuild output"),
1004        }
1005        Ok(())
1006    }
1007
1008    #[test]
1009    fn test_custom_kernel_task() -> Result<()> {
1010        let config = MultiGpuConfig {
1011            num_devices: 1,
1012            ..Default::default()
1013        };
1014        let manager = MultiGpuManager::new(config)?;
1015
1016        let task = MultiGpuTask::CustomKernel {
1017            task_id: 4,
1018            kernel_name: "scale_by_2".to_string(),
1019            input: vec![1.0, 2.0, 3.0],
1020            output_size: 3,
1021            priority: TaskPriority::High,
1022        };
1023
1024        let result = manager.execute_sync(task)?;
1025        match result.output {
1026            GpuTaskOutput::CustomOutput(out) => {
1027                assert_eq!(out, vec![2.0, 4.0, 6.0]);
1028            }
1029            _ => panic!("Expected CustomOutput"),
1030        }
1031        Ok(())
1032    }
1033
1034    #[test]
1035    fn test_task_priority_ordering() {
1036        assert!(TaskPriority::Critical > TaskPriority::High);
1037        assert!(TaskPriority::High > TaskPriority::Normal);
1038        assert!(TaskPriority::Normal > TaskPriority::Low);
1039    }
1040
    #[test]
    fn test_task_estimated_cost() {
        // Both task kinds must report a strictly positive cost estimate so
        // the workload-aware balancer has something to weigh.
        let build_task = make_build_index_task(0, 100, 16);
        let search_task = make_batch_search_task(1, 10, 16);

        // NOTE(review): only positivity is asserted here; the relative
        // ordering of build vs search cost is implementation-defined and
        // intentionally not pinned by this test.
        assert!(build_task.estimated_cost() > 0.0);
        assert!(search_task.estimated_cost() > 0.0);
    }
1050
1051    #[test]
1052    fn test_get_stats() -> Result<()> {
1053        let config = MultiGpuConfig {
1054            num_devices: 2,
1055            ..Default::default()
1056        };
1057        let manager = MultiGpuManager::new(config)?;
1058
1059        let task1 = make_batch_search_task(0, 5, 4);
1060        let task2 = make_batch_search_task(1, 5, 4);
1061
1062        manager.dispatch(task1)?;
1063        manager.dispatch(task2)?;
1064        manager.execute_pending();
1065
1066        let stats = manager.get_stats();
1067        assert_eq!(stats.total_tasks_dispatched, 2);
1068        assert_eq!(stats.total_tasks_completed, 2);
1069        assert_eq!(stats.device_metrics.len(), 2);
1070        Ok(())
1071    }
1072
1073    #[test]
1074    fn test_least_utilized_device() -> Result<()> {
1075        let config = MultiGpuConfig {
1076            num_devices: 3,
1077            ..Default::default()
1078        };
1079        let manager = MultiGpuManager::new(config)?;
1080        let device = manager.least_utilized_device();
1081        assert!(device.is_some());
1082        assert!((0..3).contains(&device.expect("test value")));
1083        Ok(())
1084    }
1085
1086    #[test]
1087    fn test_set_strategy_runtime() -> Result<()> {
1088        let config = MultiGpuConfig {
1089            num_devices: 2,
1090            strategy: LoadBalancingStrategy::RoundRobin,
1091            ..Default::default()
1092        };
1093        let manager = MultiGpuManager::new(config)?;
1094        manager.set_strategy(LoadBalancingStrategy::ShortestQueue);
1095        // Should not panic
1096        Ok(())
1097    }
1098
1099    #[test]
1100    fn test_max_queue_depth_rejection() -> Result<()> {
1101        let config = MultiGpuConfig {
1102            num_devices: 1,
1103            max_queue_depth: 2,
1104            ..Default::default()
1105        };
1106        let manager = MultiGpuManager::new(config)?;
1107
1108        // Fill up the queue
1109        manager.dispatch(make_batch_search_task(0, 1, 4))?;
1110        manager.dispatch(make_batch_search_task(1, 1, 4))?;
1111
1112        // Third task should fail (queue full)
1113        let result = manager.dispatch(make_batch_search_task(2, 1, 4));
1114        assert!(result.is_err(), "Should reject task when queue is full");
1115        Ok(())
1116    }
1117
1118    #[test]
1119    fn test_config_factory_high_throughput() {
1120        let config = MultiGpuConfigFactory::high_throughput_indexing(4);
1121        assert_eq!(config.num_devices, 4);
1122        assert_eq!(config.strategy, LoadBalancingStrategy::WeightedCapacity);
1123        assert_eq!(config.max_queue_depth, 128);
1124    }
1125
1126    #[test]
1127    fn test_config_factory_low_latency() {
1128        let config = MultiGpuConfigFactory::low_latency_search(2);
1129        assert_eq!(config.num_devices, 2);
1130        assert_eq!(config.strategy, LoadBalancingStrategy::ShortestQueue);
1131        assert!(!config.device_affinity);
1132    }
1133
1134    #[test]
1135    fn test_config_factory_balanced() {
1136        let config = MultiGpuConfigFactory::balanced_mixed_workload(4);
1137        assert_eq!(config.num_devices, 4);
1138        assert_eq!(config.strategy, LoadBalancingStrategy::Adaptive);
1139    }
1140
1141    #[test]
1142    fn test_all_healthy_check() -> Result<()> {
1143        let config = MultiGpuConfig {
1144            num_devices: 2,
1145            ..Default::default()
1146        };
1147        let manager = MultiGpuManager::new(config)?;
1148        // Initially all devices should be healthy (utilization = 0)
1149        assert!(manager.all_healthy());
1150        Ok(())
1151    }
1152
1153    #[test]
1154    fn test_reset_stats() -> Result<()> {
1155        let config = MultiGpuConfig {
1156            num_devices: 1,
1157            ..Default::default()
1158        };
1159        let manager = MultiGpuManager::new(config)?;
1160
1161        manager.dispatch(make_batch_search_task(0, 1, 4))?;
1162        manager.execute_pending();
1163
1164        let stats_before = manager.get_stats();
1165        assert!(stats_before.total_tasks_dispatched > 0);
1166
1167        manager.reset_stats();
1168        let stats_after = manager.get_stats();
1169        assert_eq!(stats_after.total_tasks_dispatched, 0);
1170        Ok(())
1171    }
1172
1173    #[test]
1174    fn test_next_task_id_monotonic() -> Result<()> {
1175        let config = MultiGpuConfig {
1176            num_devices: 1,
1177            ..Default::default()
1178        };
1179        let manager = MultiGpuManager::new(config)?;
1180
1181        let id0 = manager.next_task_id();
1182        let id1 = manager.next_task_id();
1183        let id2 = manager.next_task_id();
1184
1185        assert!(id1 > id0);
1186        assert!(id2 > id1);
1187        Ok(())
1188    }
1189
1190    #[test]
1191    fn test_least_utilized_strategy_dispatch() -> Result<()> {
1192        let config = MultiGpuConfig {
1193            num_devices: 2,
1194            strategy: LoadBalancingStrategy::LeastUtilized,
1195            ..Default::default()
1196        };
1197        let manager = MultiGpuManager::new(config)?;
1198
1199        for i in 0..4u64 {
1200            manager.dispatch(make_batch_search_task(i, 2, 4))?;
1201        }
1202        let results = manager.execute_pending();
1203        assert_eq!(results.len(), 4);
1204        Ok(())
1205    }
1206
1207    #[test]
1208    fn test_shortest_queue_strategy_dispatch() -> Result<()> {
1209        let config = MultiGpuConfig {
1210            num_devices: 2,
1211            strategy: LoadBalancingStrategy::ShortestQueue,
1212            ..Default::default()
1213        };
1214        let manager = MultiGpuManager::new(config)?;
1215
1216        for i in 0..6u64 {
1217            manager.dispatch(make_batch_search_task(i, 2, 4))?;
1218        }
1219        let results = manager.execute_pending();
1220        assert_eq!(results.len(), 6);
1221        Ok(())
1222    }
1223
1224    #[test]
1225    fn test_load_imbalance_factor() -> Result<()> {
1226        let config = MultiGpuConfig {
1227            num_devices: 2,
1228            ..Default::default()
1229        };
1230        let manager = MultiGpuManager::new(config)?;
1231        let stats = manager.get_stats();
1232        // With zero utilization on all devices, imbalance should be 1.0
1233        assert!(stats.load_imbalance_factor >= 1.0);
1234        Ok(())
1235    }
1236
1237    #[test]
1238    fn test_device_metrics_structure() -> Result<()> {
1239        let config = MultiGpuConfig {
1240            num_devices: 2,
1241            ..Default::default()
1242        };
1243        let manager = MultiGpuManager::new(config)?;
1244        let metrics = manager.get_device_metrics();
1245
1246        assert_eq!(metrics.len(), 2);
1247        for (i, m) in metrics.iter().enumerate() {
1248            assert_eq!(m.device_id, i as i32);
1249            assert!(m.compute_weight > 0.0);
1250        }
1251        Ok(())
1252    }
1253}