Skip to main content

oxirs_vec/gpu/
multi_gpu.rs

1//! Multi-GPU load balancing for distributed vector index operations
2//!
3//! This module provides round-robin and workload-aware distribution of
4//! vector search and index building tasks across multiple GPU devices.
5//!
6//! # Architecture
7//!
8//! The multi-GPU system consists of:
9//! - `MultiGpuManager`: Central coordinator managing all GPU workers
10//! - `GpuWorker`: Per-device worker with its own queue and metrics
11//! - `LoadBalancer`: Strategy-based dispatcher (round-robin or workload-aware)
12//! - `MultiGpuTask`: Task type enum for different GPU operations
13//!
14//! # Feature Gating
15//!
16//! All CUDA runtime interactions are gated with `#[cfg(feature = "cuda")]`.
17//! The load balancing logic itself is Pure Rust.
18
19use anyhow::{anyhow, Result};
20use parking_lot::{Mutex, RwLock};
21use serde::{Deserialize, Serialize};
22use std::collections::{HashMap, VecDeque};
23use std::sync::Arc;
24use std::time::Instant;
25use tracing::{debug, info, warn};
26
27use crate::gpu::GpuDevice;
28
/// Load balancing strategy for multi-GPU distribution
///
/// `Adaptive` is the default; it picks least-utilized routing for
/// expensive tasks and shortest-queue routing for cheap ones (see
/// `LoadBalancer::select_adaptive`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum LoadBalancingStrategy {
    /// Simple round-robin distribution across devices
    RoundRobin,
    /// Route to device with lowest current utilization
    LeastUtilized,
    /// Route to device with shortest queue depth
    ShortestQueue,
    /// Weighted routing based on device compute capability
    WeightedCapacity,
    /// Adaptive: switches between least-utilized and shortest-queue
    /// based on the task's estimated cost
    #[default]
    Adaptive,
}
44
/// Configuration for multi-GPU manager
///
/// Construct with `Default::default()` or one of the
/// `MultiGpuConfigFactory` presets.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MultiGpuConfig {
    /// Number of GPU devices to use; workers are created for device IDs
    /// `0..num_devices` (a value of 0 is treated as 1)
    pub num_devices: usize,
    /// Load balancing strategy
    pub strategy: LoadBalancingStrategy,
    /// Maximum queue depth per device before rejecting tasks
    pub max_queue_depth: usize,
    /// Interval for utilization sampling (ms)
    // NOTE(review): not read by any code visible in this module — confirm usage
    pub utilization_sample_interval_ms: u64,
    /// Enable device affinity (prefer same device for related tasks)
    // NOTE(review): not consulted by the load balancer in this module — confirm usage
    pub device_affinity: bool,
    /// Threshold above which a device is considered overloaded (0.0-1.0)
    pub overload_threshold: f32,
    /// Number of initial dispatches during which round-robin is forced
    /// before the configured strategy takes over
    pub adaptive_warmup_tasks: usize,
    /// Enable async task execution across devices
    // NOTE(review): execution in this module is synchronous; confirm where this flag applies
    pub async_execution: bool,
    /// Per-device memory budget in MB
    // NOTE(review): not enforced by any code visible in this module — confirm usage
    pub device_memory_budget_mb: usize,
}
67
68impl Default for MultiGpuConfig {
69    fn default() -> Self {
70        Self {
71            num_devices: 1,
72            strategy: LoadBalancingStrategy::Adaptive,
73            max_queue_depth: 64,
74            utilization_sample_interval_ms: 100,
75            device_affinity: true,
76            overload_threshold: 0.85,
77            adaptive_warmup_tasks: 50,
78            async_execution: true,
79            device_memory_budget_mb: 4096,
80        }
81    }
82}
83
/// Real-time metrics for a single GPU device
///
/// Updated by `GpuWorker` as tasks are enqueued and executed; read by the
/// load balancer to make routing decisions.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct GpuDeviceMetrics {
    /// Device ID
    pub device_id: i32,
    /// Current utilization (0.0 - 1.0); modeled from the active task count,
    /// not sampled from hardware
    pub utilization: f32,
    /// Number of tasks currently in queue
    pub queue_depth: usize,
    /// Number of tasks currently executing
    pub active_tasks: usize,
    /// Total tasks completed
    pub tasks_completed: u64,
    /// Total tasks failed
    pub tasks_failed: u64,
    /// Average task latency (ms), maintained as an exponential moving average
    pub avg_latency_ms: f64,
    /// Peak memory usage (bytes)
    // NOTE(review): never written after initialization in this module — confirm
    pub peak_memory_bytes: usize,
    /// Free memory (bytes); snapshot taken at worker construction
    pub free_memory_bytes: usize,
    /// Device temperature (Celsius); a simulated constant, not a hardware reading
    pub temperature_celsius: f32,
    /// Device compute capability as (major, minor)
    pub compute_capability: (i32, i32),
    /// Relative compute weight for weighted routing (major * 10 + minor)
    pub compute_weight: f64,
}
112
/// A task that can be dispatched to a GPU device
///
/// Every variant carries a caller-assigned `task_id` (see
/// `MultiGpuManager::next_task_id`) and a `TaskPriority`.
#[derive(Debug, Clone)]
pub enum MultiGpuTask {
    /// Build HNSW index for a batch of vectors
    BuildIndex {
        task_id: u64,
        /// IDs identifying the vectors being indexed
        vector_ids: Vec<usize>,
        vectors: Vec<Vec<f32>>,
        priority: TaskPriority,
    },
    /// Perform KNN search for a query batch
    BatchSearch {
        task_id: u64,
        queries: Vec<Vec<f32>>,
        /// Number of nearest neighbors requested per query
        k: usize,
        priority: TaskPriority,
    },
    /// Compute pairwise distance matrix between rows of `matrix_a` and `matrix_b`
    DistanceMatrix {
        task_id: u64,
        matrix_a: Vec<Vec<f32>>,
        matrix_b: Vec<Vec<f32>>,
        priority: TaskPriority,
    },
    /// Vector normalization batch (L2-normalize each vector)
    NormalizeBatch {
        task_id: u64,
        vectors: Vec<Vec<f32>>,
        priority: TaskPriority,
    },
    /// Custom kernel execution
    CustomKernel {
        task_id: u64,
        /// Name of the kernel to run
        // NOTE(review): not consulted by the simulated executor in this module
        kernel_name: String,
        input: Vec<f32>,
        /// Expected length of the output buffer
        // NOTE(review): not consulted by the simulated executor in this module
        output_size: usize,
        priority: TaskPriority,
    },
}
152
153impl MultiGpuTask {
154    /// Get the task ID
155    pub fn task_id(&self) -> u64 {
156        match self {
157            Self::BuildIndex { task_id, .. } => *task_id,
158            Self::BatchSearch { task_id, .. } => *task_id,
159            Self::DistanceMatrix { task_id, .. } => *task_id,
160            Self::NormalizeBatch { task_id, .. } => *task_id,
161            Self::CustomKernel { task_id, .. } => *task_id,
162        }
163    }
164
165    /// Get the task priority
166    pub fn priority(&self) -> TaskPriority {
167        match self {
168            Self::BuildIndex { priority, .. } => *priority,
169            Self::BatchSearch { priority, .. } => *priority,
170            Self::DistanceMatrix { priority, .. } => *priority,
171            Self::NormalizeBatch { priority, .. } => *priority,
172            Self::CustomKernel { priority, .. } => *priority,
173        }
174    }
175
176    /// Estimate computational cost (relative units)
177    pub fn estimated_cost(&self) -> f64 {
178        match self {
179            Self::BuildIndex { vectors, .. } => {
180                let n = vectors.len() as f64;
181                let d = vectors.first().map(|v| v.len() as f64).unwrap_or(1.0);
182                n * n * d * 0.001 // O(n^2 * d) for naive build
183            }
184            Self::BatchSearch { queries, k, .. } => {
185                let n = queries.len() as f64;
186                let d = queries.first().map(|v| v.len() as f64).unwrap_or(1.0);
187                n * (*k as f64) * d * 0.1
188            }
189            Self::DistanceMatrix {
190                matrix_a, matrix_b, ..
191            } => {
192                let na = matrix_a.len() as f64;
193                let nb = matrix_b.len() as f64;
194                let d = matrix_a.first().map(|v| v.len() as f64).unwrap_or(1.0);
195                na * nb * d * 0.01
196            }
197            Self::NormalizeBatch { vectors, .. } => {
198                let n = vectors.len() as f64;
199                let d = vectors.first().map(|v| v.len() as f64).unwrap_or(1.0);
200                n * d * 0.001
201            }
202            Self::CustomKernel { input, .. } => input.len() as f64 * 0.01,
203        }
204    }
205}
206
/// Task priority level
///
/// Ordered `Low < Normal < High < Critical` via the derived `Ord`.
// NOTE(review): priority is carried on every task but no scheduling path
// visible in this module consults it — confirm it is honored elsewhere.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub enum TaskPriority {
    Low = 0,
    Normal = 1,
    High = 2,
    Critical = 3,
}
215
/// Result of a GPU task execution
///
/// Produced by a worker for each successfully executed task; failed tasks
/// produce no result and are only reflected in the device metrics.
#[derive(Debug, Clone)]
pub struct GpuTaskResult {
    /// Task ID this result belongs to
    pub task_id: u64,
    /// Device that executed the task
    pub device_id: i32,
    /// Wall-clock execution time in milliseconds
    pub execution_time_ms: u64,
    /// Output data (semantics depend on task type)
    pub output: GpuTaskOutput,
}
228
/// Output data for different task types
#[derive(Debug, Clone)]
pub enum GpuTaskOutput {
    /// Index build result: number of nodes inserted into the index
    IndexBuild { nodes_built: usize },
    /// Batch search results: one `Vec<(neighbor_id, distance)>` per query,
    /// in query order
    SearchResults(Vec<Vec<(usize, f32)>>),
    /// Pairwise distance matrix: entry `[i][j]` is the distance from
    /// `matrix_a[i]` to `matrix_b[j]`
    DistanceMatrix(Vec<Vec<f32>>),
    /// Normalized vectors (unit L2 norm; near-zero vectors returned unchanged)
    NormalizedVectors(Vec<Vec<f32>>),
    /// Custom kernel output
    CustomOutput(Vec<f32>),
}
243
/// Per-device worker state
///
/// Owns a FIFO task queue and the live metrics for one GPU device. Not
/// internally synchronized; `MultiGpuManager` guards the worker pool with
/// an `RwLock`.
#[derive(Debug)]
struct GpuWorker {
    // CUDA-style device ordinal this worker is bound to
    device_id: i32,
    // Static device properties queried at construction time
    device_info: GpuDevice,
    // Pending tasks, executed front-to-back
    task_queue: VecDeque<MultiGpuTask>,
    // Live metrics reported to the load balancer
    metrics: GpuDeviceMetrics,
    // Timestamp of the last utilization sample
    last_metrics_update: Instant,
}
253
254impl GpuWorker {
255    fn new(device_id: i32) -> Result<Self> {
256        let device_info = GpuDevice::get_device_info(device_id)?;
257
258        // Compute relative weight based on compute capability
259        let compute_weight = device_info.compute_capability.0 as f64 * 10.0
260            + device_info.compute_capability.1 as f64;
261
262        let metrics = GpuDeviceMetrics {
263            device_id,
264            utilization: 0.0,
265            queue_depth: 0,
266            active_tasks: 0,
267            tasks_completed: 0,
268            tasks_failed: 0,
269            avg_latency_ms: 0.0,
270            peak_memory_bytes: 0,
271            free_memory_bytes: device_info.free_memory,
272            temperature_celsius: 50.0, // Simulated idle temperature
273            compute_capability: device_info.compute_capability,
274            compute_weight,
275        };
276
277        Ok(Self {
278            device_id,
279            device_info,
280            task_queue: VecDeque::new(),
281            metrics,
282            last_metrics_update: Instant::now(),
283        })
284    }
285
286    fn enqueue(&mut self, task: MultiGpuTask) -> Result<()> {
287        self.task_queue.push_back(task);
288        self.metrics.queue_depth = self.task_queue.len();
289        Ok(())
290    }
291
292    fn execute_next(&mut self) -> Option<GpuTaskResult> {
293        let task = self.task_queue.pop_front()?;
294        self.metrics.queue_depth = self.task_queue.len();
295        self.metrics.active_tasks += 1;
296
297        let start = Instant::now();
298        let task_id = task.task_id();
299        let device_id = self.device_id;
300
301        let output = self.execute_task(task);
302        let execution_time_ms = start.elapsed().as_millis() as u64;
303
304        self.metrics.active_tasks = self.metrics.active_tasks.saturating_sub(1);
305
306        match output {
307            Ok(output) => {
308                self.metrics.tasks_completed += 1;
309                self.update_avg_latency(execution_time_ms as f64);
310                self.update_utilization();
311
312                Some(GpuTaskResult {
313                    task_id,
314                    device_id,
315                    execution_time_ms,
316                    output,
317                })
318            }
319            Err(e) => {
320                warn!("Task {} failed on device {}: {}", task_id, device_id, e);
321                self.metrics.tasks_failed += 1;
322                None
323            }
324        }
325    }
326
327    fn execute_task(&self, task: MultiGpuTask) -> Result<GpuTaskOutput> {
328        match task {
329            MultiGpuTask::BuildIndex { vectors, .. } => {
330                let nodes_built = vectors.len();
331                debug!(
332                    "Device {} building index for {} vectors",
333                    self.device_id, nodes_built
334                );
335                Ok(GpuTaskOutput::IndexBuild { nodes_built })
336            }
337            MultiGpuTask::BatchSearch { queries, k, .. } => {
338                let results = queries
339                    .iter()
340                    .map(|_q| {
341                        // Simulated search results
342                        (0..k.min(10))
343                            .map(|i| (i, (i as f32) * 0.1))
344                            .collect::<Vec<_>>()
345                    })
346                    .collect();
347                Ok(GpuTaskOutput::SearchResults(results))
348            }
349            MultiGpuTask::DistanceMatrix {
350                matrix_a, matrix_b, ..
351            } => {
352                let distances = matrix_a
353                    .iter()
354                    .map(|a| {
355                        matrix_b
356                            .iter()
357                            .map(|b| {
358                                a.iter()
359                                    .zip(b.iter())
360                                    .map(|(x, y)| (x - y).powi(2))
361                                    .sum::<f32>()
362                                    .sqrt()
363                            })
364                            .collect::<Vec<_>>()
365                    })
366                    .collect();
367                Ok(GpuTaskOutput::DistanceMatrix(distances))
368            }
369            MultiGpuTask::NormalizeBatch { vectors, .. } => {
370                let normalized = vectors
371                    .iter()
372                    .map(|v| {
373                        let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
374                        if norm > 1e-9 {
375                            v.iter().map(|x| x / norm).collect()
376                        } else {
377                            v.clone()
378                        }
379                    })
380                    .collect();
381                Ok(GpuTaskOutput::NormalizedVectors(normalized))
382            }
383            MultiGpuTask::CustomKernel { input, .. } => {
384                let output = input.iter().map(|x| x * 2.0).collect();
385                Ok(GpuTaskOutput::CustomOutput(output))
386            }
387        }
388    }
389
390    fn update_avg_latency(&mut self, new_latency_ms: f64) {
391        let completed = self.metrics.tasks_completed as f64;
392        if completed <= 1.0 {
393            self.metrics.avg_latency_ms = new_latency_ms;
394        } else {
395            // Exponential moving average
396            self.metrics.avg_latency_ms = 0.9 * self.metrics.avg_latency_ms + 0.1 * new_latency_ms;
397        }
398    }
399
400    fn update_utilization(&mut self) {
401        let elapsed = self.last_metrics_update.elapsed().as_millis() as f64;
402        if elapsed > 0.0 {
403            let active = self.metrics.active_tasks as f64;
404            self.metrics.utilization = (active / 4.0_f64).min(1.0) as f32;
405        }
406        self.last_metrics_update = Instant::now();
407    }
408}
409
/// Multi-GPU load balancer implementation
///
/// Holds mutable dispatch state; guarded by a `Mutex` inside
/// `MultiGpuManager`.
#[derive(Debug)]
struct LoadBalancer {
    // Strategy honored once the warmup phase completes
    strategy: LoadBalancingStrategy,
    // Counter driving round-robin and the weighted-selection threshold
    round_robin_counter: usize,
    // Tasks routed so far (drives the warmup cutoff)
    total_tasks_dispatched: u64,
    // Number of initial dispatches that always use round-robin
    warmup_tasks: usize,
}
418
419impl LoadBalancer {
420    fn new(strategy: LoadBalancingStrategy, warmup_tasks: usize) -> Self {
421        Self {
422            strategy,
423            round_robin_counter: 0,
424            total_tasks_dispatched: 0,
425            warmup_tasks,
426        }
427    }
428
429    fn select_device(
430        &mut self,
431        task: &MultiGpuTask,
432        workers: &[GpuWorker],
433        overload_threshold: f32,
434    ) -> Result<usize> {
435        if workers.is_empty() {
436            return Err(anyhow!("No GPU workers available"));
437        }
438
439        // Filter out overloaded devices
440        let available: Vec<usize> = (0..workers.len())
441            .filter(|&i| {
442                workers[i].metrics.utilization < overload_threshold
443                    || workers[i].metrics.queue_depth == 0
444            })
445            .collect();
446
447        if available.is_empty() {
448            // Fall back to least utilized even if overloaded
449            warn!("All GPU devices are overloaded, routing to least utilized");
450            return self.select_least_utilized(workers);
451        }
452
453        let effective_strategy = if self.total_tasks_dispatched < self.warmup_tasks as u64 {
454            LoadBalancingStrategy::RoundRobin
455        } else {
456            self.strategy
457        };
458
459        let selected = match effective_strategy {
460            LoadBalancingStrategy::RoundRobin => self.select_round_robin(&available),
461            LoadBalancingStrategy::LeastUtilized => {
462                self.select_least_utilized_from(workers, &available)
463            }
464            LoadBalancingStrategy::ShortestQueue => self.select_shortest_queue(workers, &available),
465            LoadBalancingStrategy::WeightedCapacity => {
466                self.select_weighted(workers, &available, task)
467            }
468            LoadBalancingStrategy::Adaptive => self.select_adaptive(workers, &available, task),
469        };
470
471        self.total_tasks_dispatched += 1;
472        Ok(selected)
473    }
474
475    fn select_round_robin(&mut self, available: &[usize]) -> usize {
476        let idx = self.round_robin_counter % available.len();
477        self.round_robin_counter += 1;
478        available[idx]
479    }
480
481    fn select_least_utilized(&self, workers: &[GpuWorker]) -> Result<usize> {
482        workers
483            .iter()
484            .enumerate()
485            .min_by(|a, b| {
486                a.1.metrics
487                    .utilization
488                    .partial_cmp(&b.1.metrics.utilization)
489                    .unwrap_or(std::cmp::Ordering::Equal)
490            })
491            .map(|(i, _)| i)
492            .ok_or_else(|| anyhow!("No workers available"))
493    }
494
495    fn select_least_utilized_from(&self, workers: &[GpuWorker], available: &[usize]) -> usize {
496        available
497            .iter()
498            .min_by(|&&a, &&b| {
499                workers[a]
500                    .metrics
501                    .utilization
502                    .partial_cmp(&workers[b].metrics.utilization)
503                    .unwrap_or(std::cmp::Ordering::Equal)
504            })
505            .copied()
506            .unwrap_or(available[0])
507    }
508
509    fn select_shortest_queue(&self, workers: &[GpuWorker], available: &[usize]) -> usize {
510        available
511            .iter()
512            .min_by_key(|&&i| workers[i].metrics.queue_depth)
513            .copied()
514            .unwrap_or(available[0])
515    }
516
517    fn select_weighted(
518        &mut self,
519        workers: &[GpuWorker],
520        available: &[usize],
521        _task: &MultiGpuTask,
522    ) -> usize {
523        let total_weight: f64 = available
524            .iter()
525            .map(|&i| workers[i].metrics.compute_weight)
526            .sum();
527        if total_weight <= 0.0 {
528            return self.select_round_robin(available);
529        }
530
531        // Weighted random selection using deterministic counter
532        let threshold = (self.round_robin_counter as f64 / 1000.0) % 1.0;
533        let mut cumulative = 0.0;
534        for &i in available {
535            cumulative += workers[i].metrics.compute_weight / total_weight;
536            if cumulative >= threshold {
537                self.round_robin_counter += 1;
538                return i;
539            }
540        }
541        self.round_robin_counter += 1;
542        available[available.len() - 1]
543    }
544
545    fn select_adaptive(
546        &mut self,
547        workers: &[GpuWorker],
548        available: &[usize],
549        task: &MultiGpuTask,
550    ) -> usize {
551        // For high-cost tasks, use least-utilized
552        // For low-cost tasks, use shortest-queue
553        let cost = task.estimated_cost();
554        if cost > 100.0 {
555            self.select_least_utilized_from(workers, available)
556        } else {
557            self.select_shortest_queue(workers, available)
558        }
559    }
560}
561
/// Statistics for the multi-GPU manager
///
/// Produced as a point-in-time snapshot by `MultiGpuManager::get_stats`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct MultiGpuStats {
    /// Total tasks dispatched across all devices
    pub total_tasks_dispatched: u64,
    /// Total tasks completed
    pub total_tasks_completed: u64,
    /// Total tasks failed
    // NOTE(review): no code path visible in this module writes this field,
    // even though per-device `tasks_failed` counters exist — verify it is
    // populated somewhere.
    pub total_tasks_failed: u64,
    /// Average dispatch latency (ms)
    // NOTE(review): no code path visible in this module writes this field — confirm
    pub avg_dispatch_latency_ms: f64,
    /// Per-device metrics
    pub device_metrics: Vec<GpuDeviceMetrics>,
    /// Load imbalance factor: max utilization / min utilization
    /// (1.0 = perfectly balanced, a single device, or a zero minimum)
    pub load_imbalance_factor: f64,
    /// Current active strategy (Debug-formatted `LoadBalancingStrategy`)
    pub active_strategy: String,
}
580
/// Central multi-GPU manager
///
/// Manages a pool of GPU workers and dispatches tasks using the configured
/// load balancing strategy. All state is behind `Arc`-wrapped locks, so
/// methods take `&self` and the handle can be shared across threads.
#[derive(Debug)]
pub struct MultiGpuManager {
    // Configuration captured at construction; not mutated afterwards
    config: MultiGpuConfig,
    // One worker per device; index corresponds to the device ordinal
    workers: Arc<RwLock<Vec<GpuWorker>>>,
    // Strategy state (round-robin counter, warmup progress)
    load_balancer: Arc<Mutex<LoadBalancer>>,
    // Aggregate counters reported by `get_stats`
    stats: Arc<Mutex<MultiGpuStats>>,
    // Completed results keyed by task ID
    // NOTE(review): never read or written in this module — confirm intended use
    result_buffer: Arc<Mutex<HashMap<u64, GpuTaskResult>>>,
    // Monotonic counter backing `next_task_id()`
    next_task_id: Arc<Mutex<u64>>,
}
594
595impl MultiGpuManager {
596    /// Create a new multi-GPU manager
597    ///
598    /// Initializes workers for each device ID from 0 to `num_devices-1`.
599    pub fn new(config: MultiGpuConfig) -> Result<Self> {
600        let num_devices = config.num_devices.max(1);
601        let mut workers = Vec::with_capacity(num_devices);
602
603        for device_id in 0..num_devices as i32 {
604            let worker = GpuWorker::new(device_id).map_err(|e| {
605                anyhow!(
606                    "Failed to initialize GPU worker for device {}: {}",
607                    device_id,
608                    e
609                )
610            })?;
611            workers.push(worker);
612        }
613
614        info!(
615            "Multi-GPU manager initialized with {} devices, strategy={:?}",
616            num_devices, config.strategy
617        );
618
619        let load_balancer = LoadBalancer::new(config.strategy, config.adaptive_warmup_tasks);
620
621        Ok(Self {
622            config,
623            workers: Arc::new(RwLock::new(workers)),
624            load_balancer: Arc::new(Mutex::new(load_balancer)),
625            stats: Arc::new(Mutex::new(MultiGpuStats::default())),
626            result_buffer: Arc::new(Mutex::new(HashMap::new())),
627            next_task_id: Arc::new(Mutex::new(0)),
628        })
629    }
630
631    /// Dispatch a task to the most appropriate GPU device
632    pub fn dispatch(&self, task: MultiGpuTask) -> Result<u64> {
633        let task_id = task.task_id();
634
635        let mut workers = self.workers.write();
636        let device_idx = {
637            let mut lb = self.load_balancer.lock();
638            lb.select_device(&task, &workers, self.config.overload_threshold)?
639        };
640
641        if workers[device_idx].metrics.queue_depth >= self.config.max_queue_depth {
642            return Err(anyhow!(
643                "Device {} queue is full (depth={})",
644                device_idx,
645                workers[device_idx].metrics.queue_depth
646            ));
647        }
648
649        debug!("Dispatching task {} to device {}", task_id, device_idx);
650        workers[device_idx].enqueue(task)?;
651
652        let mut stats = self.stats.lock();
653        stats.total_tasks_dispatched += 1;
654
655        Ok(task_id)
656    }
657
658    /// Execute all pending tasks on all devices and collect results
659    pub fn execute_pending(&self) -> Vec<GpuTaskResult> {
660        let mut workers = self.workers.write();
661        let mut all_results = Vec::new();
662
663        for worker in workers.iter_mut() {
664            while !worker.task_queue.is_empty() {
665                if let Some(result) = worker.execute_next() {
666                    all_results.push(result);
667                }
668            }
669        }
670
671        let mut stats = self.stats.lock();
672        stats.total_tasks_completed += all_results.len() as u64;
673
674        all_results
675    }
676
677    /// Dispatch and immediately execute a task, returning the result
678    pub fn execute_sync(&self, task: MultiGpuTask) -> Result<GpuTaskResult> {
679        let task_id = self.dispatch(task)?;
680        let results = self.execute_pending();
681
682        results
683            .into_iter()
684            .find(|r| r.task_id == task_id)
685            .ok_or_else(|| anyhow!("Task {} was not executed", task_id))
686    }
687
688    /// Get aggregate statistics for all devices
689    pub fn get_stats(&self) -> MultiGpuStats {
690        let workers = self.workers.read();
691        let stats = self.stats.lock();
692
693        let device_metrics: Vec<GpuDeviceMetrics> =
694            workers.iter().map(|w| w.metrics.clone()).collect();
695
696        // Calculate load imbalance factor
697        let utilizations: Vec<f32> = device_metrics.iter().map(|m| m.utilization).collect();
698        let load_imbalance = if utilizations.len() > 1 {
699            let max_util = utilizations
700                .iter()
701                .cloned()
702                .fold(f32::NEG_INFINITY, f32::max);
703            let min_util = utilizations.iter().cloned().fold(f32::INFINITY, f32::min);
704            if min_util > 0.0 {
705                max_util as f64 / min_util as f64
706            } else {
707                1.0
708            }
709        } else {
710            1.0
711        };
712
713        MultiGpuStats {
714            total_tasks_dispatched: stats.total_tasks_dispatched,
715            total_tasks_completed: stats.total_tasks_completed,
716            total_tasks_failed: stats.total_tasks_failed,
717            avg_dispatch_latency_ms: stats.avg_dispatch_latency_ms,
718            device_metrics,
719            load_imbalance_factor: load_imbalance,
720            active_strategy: format!("{:?}", self.config.strategy),
721        }
722    }
723
724    /// Get per-device metrics
725    pub fn get_device_metrics(&self) -> Vec<GpuDeviceMetrics> {
726        let workers = self.workers.read();
727        workers.iter().map(|w| w.metrics.clone()).collect()
728    }
729
730    /// Get the number of active GPU devices
731    pub fn num_devices(&self) -> usize {
732        self.workers.read().len()
733    }
734
735    /// Check if all devices are healthy (not overloaded)
736    pub fn all_healthy(&self) -> bool {
737        let workers = self.workers.read();
738        workers
739            .iter()
740            .all(|w| w.metrics.utilization < self.config.overload_threshold)
741    }
742
743    /// Get the least utilized device ID
744    pub fn least_utilized_device(&self) -> Option<i32> {
745        let workers = self.workers.read();
746        workers
747            .iter()
748            .min_by(|a, b| {
749                a.metrics
750                    .utilization
751                    .partial_cmp(&b.metrics.utilization)
752                    .unwrap_or(std::cmp::Ordering::Equal)
753            })
754            .map(|w| w.device_id)
755    }
756
757    /// Generate a unique task ID
758    pub fn next_task_id(&self) -> u64 {
759        let mut id = self.next_task_id.lock();
760        let current = *id;
761        *id += 1;
762        current
763    }
764
765    /// Set the load balancing strategy at runtime
766    pub fn set_strategy(&self, strategy: LoadBalancingStrategy) {
767        let mut lb = self.load_balancer.lock();
768        lb.strategy = strategy;
769        info!("Load balancing strategy changed to {:?}", strategy);
770    }
771
772    /// Reset statistics
773    pub fn reset_stats(&self) {
774        let mut stats = self.stats.lock();
775        *stats = MultiGpuStats::default();
776    }
777}
778
/// Factory for creating multi-GPU configurations for common scenarios
///
/// Stateless: all presets are associated functions on this unit struct.
pub struct MultiGpuConfigFactory;
781
782impl MultiGpuConfigFactory {
783    /// Configuration optimized for high-throughput indexing
784    pub fn high_throughput_indexing(num_devices: usize) -> MultiGpuConfig {
785        MultiGpuConfig {
786            num_devices,
787            strategy: LoadBalancingStrategy::WeightedCapacity,
788            max_queue_depth: 128,
789            async_execution: true,
790            device_memory_budget_mb: 8192,
791            ..Default::default()
792        }
793    }
794
795    /// Configuration optimized for low-latency search
796    pub fn low_latency_search(num_devices: usize) -> MultiGpuConfig {
797        MultiGpuConfig {
798            num_devices,
799            strategy: LoadBalancingStrategy::ShortestQueue,
800            max_queue_depth: 16,
801            overload_threshold: 0.7,
802            device_affinity: false,
803            ..Default::default()
804        }
805    }
806
807    /// Configuration optimized for balanced mixed workloads
808    pub fn balanced_mixed_workload(num_devices: usize) -> MultiGpuConfig {
809        MultiGpuConfig {
810            num_devices,
811            strategy: LoadBalancingStrategy::Adaptive,
812            adaptive_warmup_tasks: 100,
813            ..Default::default()
814        }
815    }
816}
817
818#[cfg(test)]
819mod tests {
820    use super::*;
821
822    fn make_batch_search_task(id: u64, n_queries: usize, dim: usize) -> MultiGpuTask {
823        let queries = (0..n_queries)
824            .map(|i| (0..dim).map(|j| (i + j) as f32 * 0.1).collect())
825            .collect();
826        MultiGpuTask::BatchSearch {
827            task_id: id,
828            queries,
829            k: 10,
830            priority: TaskPriority::Normal,
831        }
832    }
833
834    fn make_build_index_task(id: u64, n_vectors: usize, dim: usize) -> MultiGpuTask {
835        let vectors: Vec<Vec<f32>> = (0..n_vectors)
836            .map(|i| (0..dim).map(|j| (i + j) as f32 * 0.1).collect())
837            .collect();
838        let vector_ids: Vec<usize> = (0..n_vectors).collect();
839        MultiGpuTask::BuildIndex {
840            task_id: id,
841            vector_ids,
842            vectors,
843            priority: TaskPriority::Normal,
844        }
845    }
846
847    #[test]
848    fn test_multi_gpu_config_default() {
849        let config = MultiGpuConfig::default();
850        assert_eq!(config.num_devices, 1);
851        assert_eq!(config.strategy, LoadBalancingStrategy::Adaptive);
852        assert!(config.async_execution);
853    }
854
855    #[test]
856    fn test_multi_gpu_manager_creation() {
857        let config = MultiGpuConfig {
858            num_devices: 2,
859            ..Default::default()
860        };
861        let manager = MultiGpuManager::new(config);
862        assert!(manager.is_ok(), "Manager creation should succeed");
863        let manager = manager.unwrap();
864        assert_eq!(manager.num_devices(), 2);
865    }
866
867    #[test]
868    fn test_single_device_dispatch_and_execute() {
869        let config = MultiGpuConfig {
870            num_devices: 1,
871            ..Default::default()
872        };
873        let manager = MultiGpuManager::new(config).unwrap();
874
875        let task = make_batch_search_task(0, 5, 8);
876        let task_id = manager.dispatch(task).unwrap();
877        assert_eq!(task_id, 0);
878
879        let results = manager.execute_pending();
880        assert_eq!(results.len(), 1);
881        assert_eq!(results[0].task_id, 0);
882    }
883
884    #[test]
885    fn test_round_robin_distribution() {
886        let config = MultiGpuConfig {
887            num_devices: 3,
888            strategy: LoadBalancingStrategy::RoundRobin,
889            ..Default::default()
890        };
891        let manager = MultiGpuManager::new(config).unwrap();
892
893        // Dispatch 6 tasks - should distribute 2 each to 3 devices
894        for i in 0..6u64 {
895            let task = make_batch_search_task(i, 2, 4);
896            manager.dispatch(task).unwrap();
897        }
898
899        // Execute all
900        let results = manager.execute_pending();
901        assert_eq!(results.len(), 6);
902    }
903
904    #[test]
905    fn test_execute_sync() {
906        let config = MultiGpuConfig {
907            num_devices: 1,
908            ..Default::default()
909        };
910        let manager = MultiGpuManager::new(config).unwrap();
911
912        let task = make_batch_search_task(42, 3, 8);
913        let result = manager.execute_sync(task).unwrap();
914
915        assert_eq!(result.task_id, 42);
916        assert_eq!(result.device_id, 0);
917        matches!(result.output, GpuTaskOutput::SearchResults(_));
918    }
919
920    #[test]
921    fn test_distance_matrix_task() {
922        let config = MultiGpuConfig {
923            num_devices: 1,
924            ..Default::default()
925        };
926        let manager = MultiGpuManager::new(config).unwrap();
927
928        let task = MultiGpuTask::DistanceMatrix {
929            task_id: 1,
930            matrix_a: vec![vec![1.0, 0.0], vec![0.0, 1.0]],
931            matrix_b: vec![vec![1.0, 0.0], vec![0.0, 1.0]],
932            priority: TaskPriority::Normal,
933        };
934
935        let result = manager.execute_sync(task).unwrap();
936        match result.output {
937            GpuTaskOutput::DistanceMatrix(m) => {
938                assert_eq!(m.len(), 2);
939                assert_eq!(m[0].len(), 2);
940                // Distance from [1,0] to [1,0] should be 0
941                assert!(m[0][0].abs() < 1e-5, "Self-distance should be 0");
942                // Distance from [1,0] to [0,1] should be sqrt(2)
943                assert!((m[0][1] - 2.0_f32.sqrt()).abs() < 1e-4);
944            }
945            _ => panic!("Expected DistanceMatrix output"),
946        }
947    }
948
949    #[test]
950    fn test_normalize_batch_task() {
951        let config = MultiGpuConfig {
952            num_devices: 1,
953            ..Default::default()
954        };
955        let manager = MultiGpuManager::new(config).unwrap();
956
957        let task = MultiGpuTask::NormalizeBatch {
958            task_id: 2,
959            vectors: vec![vec![3.0, 4.0], vec![1.0, 0.0]],
960            priority: TaskPriority::Normal,
961        };
962
963        let result = manager.execute_sync(task).unwrap();
964        match result.output {
965            GpuTaskOutput::NormalizedVectors(vecs) => {
966                assert_eq!(vecs.len(), 2);
967                // First vector [3,4] normalized = [0.6, 0.8] (norm=5)
968                let norm0: f32 = vecs[0].iter().map(|x| x * x).sum::<f32>().sqrt();
969                assert!(
970                    (norm0 - 1.0).abs() < 1e-5,
971                    "Norm should be 1.0, got {}",
972                    norm0
973                );
974                // Second vector [1,0] already unit norm
975                assert!((vecs[1][0] - 1.0).abs() < 1e-5);
976            }
977            _ => panic!("Expected NormalizedVectors output"),
978        }
979    }
980
981    #[test]
982    fn test_build_index_task() {
983        let config = MultiGpuConfig {
984            num_devices: 1,
985            ..Default::default()
986        };
987        let manager = MultiGpuManager::new(config).unwrap();
988
989        let task = make_build_index_task(3, 100, 16);
990        let result = manager.execute_sync(task).unwrap();
991
992        match result.output {
993            GpuTaskOutput::IndexBuild { nodes_built } => {
994                assert_eq!(nodes_built, 100);
995            }
996            _ => panic!("Expected IndexBuild output"),
997        }
998    }
999
1000    #[test]
1001    fn test_custom_kernel_task() {
1002        let config = MultiGpuConfig {
1003            num_devices: 1,
1004            ..Default::default()
1005        };
1006        let manager = MultiGpuManager::new(config).unwrap();
1007
1008        let task = MultiGpuTask::CustomKernel {
1009            task_id: 4,
1010            kernel_name: "scale_by_2".to_string(),
1011            input: vec![1.0, 2.0, 3.0],
1012            output_size: 3,
1013            priority: TaskPriority::High,
1014        };
1015
1016        let result = manager.execute_sync(task).unwrap();
1017        match result.output {
1018            GpuTaskOutput::CustomOutput(out) => {
1019                assert_eq!(out, vec![2.0, 4.0, 6.0]);
1020            }
1021            _ => panic!("Expected CustomOutput"),
1022        }
1023    }
1024
1025    #[test]
1026    fn test_task_priority_ordering() {
1027        assert!(TaskPriority::Critical > TaskPriority::High);
1028        assert!(TaskPriority::High > TaskPriority::Normal);
1029        assert!(TaskPriority::Normal > TaskPriority::Low);
1030    }
1031
1032    #[test]
1033    fn test_task_estimated_cost() {
1034        let build_task = make_build_index_task(0, 100, 16);
1035        let search_task = make_batch_search_task(1, 10, 16);
1036
1037        // Build tasks should generally be more expensive than search
1038        assert!(build_task.estimated_cost() > 0.0);
1039        assert!(search_task.estimated_cost() > 0.0);
1040    }
1041
1042    #[test]
1043    fn test_get_stats() {
1044        let config = MultiGpuConfig {
1045            num_devices: 2,
1046            ..Default::default()
1047        };
1048        let manager = MultiGpuManager::new(config).unwrap();
1049
1050        let task1 = make_batch_search_task(0, 5, 4);
1051        let task2 = make_batch_search_task(1, 5, 4);
1052
1053        manager.dispatch(task1).unwrap();
1054        manager.dispatch(task2).unwrap();
1055        manager.execute_pending();
1056
1057        let stats = manager.get_stats();
1058        assert_eq!(stats.total_tasks_dispatched, 2);
1059        assert_eq!(stats.total_tasks_completed, 2);
1060        assert_eq!(stats.device_metrics.len(), 2);
1061    }
1062
1063    #[test]
1064    fn test_least_utilized_device() {
1065        let config = MultiGpuConfig {
1066            num_devices: 3,
1067            ..Default::default()
1068        };
1069        let manager = MultiGpuManager::new(config).unwrap();
1070        let device = manager.least_utilized_device();
1071        assert!(device.is_some());
1072        assert!((0..3).contains(&device.unwrap()));
1073    }
1074
1075    #[test]
1076    fn test_set_strategy_runtime() {
1077        let config = MultiGpuConfig {
1078            num_devices: 2,
1079            strategy: LoadBalancingStrategy::RoundRobin,
1080            ..Default::default()
1081        };
1082        let manager = MultiGpuManager::new(config).unwrap();
1083        manager.set_strategy(LoadBalancingStrategy::ShortestQueue);
1084        // Should not panic
1085    }
1086
1087    #[test]
1088    fn test_max_queue_depth_rejection() {
1089        let config = MultiGpuConfig {
1090            num_devices: 1,
1091            max_queue_depth: 2,
1092            ..Default::default()
1093        };
1094        let manager = MultiGpuManager::new(config).unwrap();
1095
1096        // Fill up the queue
1097        manager.dispatch(make_batch_search_task(0, 1, 4)).unwrap();
1098        manager.dispatch(make_batch_search_task(1, 1, 4)).unwrap();
1099
1100        // Third task should fail (queue full)
1101        let result = manager.dispatch(make_batch_search_task(2, 1, 4));
1102        assert!(result.is_err(), "Should reject task when queue is full");
1103    }
1104
1105    #[test]
1106    fn test_config_factory_high_throughput() {
1107        let config = MultiGpuConfigFactory::high_throughput_indexing(4);
1108        assert_eq!(config.num_devices, 4);
1109        assert_eq!(config.strategy, LoadBalancingStrategy::WeightedCapacity);
1110        assert_eq!(config.max_queue_depth, 128);
1111    }
1112
1113    #[test]
1114    fn test_config_factory_low_latency() {
1115        let config = MultiGpuConfigFactory::low_latency_search(2);
1116        assert_eq!(config.num_devices, 2);
1117        assert_eq!(config.strategy, LoadBalancingStrategy::ShortestQueue);
1118        assert!(!config.device_affinity);
1119    }
1120
1121    #[test]
1122    fn test_config_factory_balanced() {
1123        let config = MultiGpuConfigFactory::balanced_mixed_workload(4);
1124        assert_eq!(config.num_devices, 4);
1125        assert_eq!(config.strategy, LoadBalancingStrategy::Adaptive);
1126    }
1127
1128    #[test]
1129    fn test_all_healthy_check() {
1130        let config = MultiGpuConfig {
1131            num_devices: 2,
1132            ..Default::default()
1133        };
1134        let manager = MultiGpuManager::new(config).unwrap();
1135        // Initially all devices should be healthy (utilization = 0)
1136        assert!(manager.all_healthy());
1137    }
1138
1139    #[test]
1140    fn test_reset_stats() {
1141        let config = MultiGpuConfig {
1142            num_devices: 1,
1143            ..Default::default()
1144        };
1145        let manager = MultiGpuManager::new(config).unwrap();
1146
1147        manager.dispatch(make_batch_search_task(0, 1, 4)).unwrap();
1148        manager.execute_pending();
1149
1150        let stats_before = manager.get_stats();
1151        assert!(stats_before.total_tasks_dispatched > 0);
1152
1153        manager.reset_stats();
1154        let stats_after = manager.get_stats();
1155        assert_eq!(stats_after.total_tasks_dispatched, 0);
1156    }
1157
1158    #[test]
1159    fn test_next_task_id_monotonic() {
1160        let config = MultiGpuConfig {
1161            num_devices: 1,
1162            ..Default::default()
1163        };
1164        let manager = MultiGpuManager::new(config).unwrap();
1165
1166        let id0 = manager.next_task_id();
1167        let id1 = manager.next_task_id();
1168        let id2 = manager.next_task_id();
1169
1170        assert!(id1 > id0);
1171        assert!(id2 > id1);
1172    }
1173
1174    #[test]
1175    fn test_least_utilized_strategy_dispatch() {
1176        let config = MultiGpuConfig {
1177            num_devices: 2,
1178            strategy: LoadBalancingStrategy::LeastUtilized,
1179            ..Default::default()
1180        };
1181        let manager = MultiGpuManager::new(config).unwrap();
1182
1183        for i in 0..4u64 {
1184            manager.dispatch(make_batch_search_task(i, 2, 4)).unwrap();
1185        }
1186        let results = manager.execute_pending();
1187        assert_eq!(results.len(), 4);
1188    }
1189
1190    #[test]
1191    fn test_shortest_queue_strategy_dispatch() {
1192        let config = MultiGpuConfig {
1193            num_devices: 2,
1194            strategy: LoadBalancingStrategy::ShortestQueue,
1195            ..Default::default()
1196        };
1197        let manager = MultiGpuManager::new(config).unwrap();
1198
1199        for i in 0..6u64 {
1200            manager.dispatch(make_batch_search_task(i, 2, 4)).unwrap();
1201        }
1202        let results = manager.execute_pending();
1203        assert_eq!(results.len(), 6);
1204    }
1205
1206    #[test]
1207    fn test_load_imbalance_factor() {
1208        let config = MultiGpuConfig {
1209            num_devices: 2,
1210            ..Default::default()
1211        };
1212        let manager = MultiGpuManager::new(config).unwrap();
1213        let stats = manager.get_stats();
1214        // With zero utilization on all devices, imbalance should be 1.0
1215        assert!(stats.load_imbalance_factor >= 1.0);
1216    }
1217
1218    #[test]
1219    fn test_device_metrics_structure() {
1220        let config = MultiGpuConfig {
1221            num_devices: 2,
1222            ..Default::default()
1223        };
1224        let manager = MultiGpuManager::new(config).unwrap();
1225        let metrics = manager.get_device_metrics();
1226
1227        assert_eq!(metrics.len(), 2);
1228        for (i, m) in metrics.iter().enumerate() {
1229            assert_eq!(m.device_id, i as i32);
1230            assert!(m.compute_weight > 0.0);
1231        }
1232    }
1233}