1use anyhow::{anyhow, Result};
20use parking_lot::{Mutex, RwLock};
21use serde::{Deserialize, Serialize};
22use std::collections::{HashMap, VecDeque};
23use std::sync::Arc;
24use std::time::Instant;
25use tracing::{debug, info, warn};
26
27use crate::gpu::GpuDevice;
28
/// Strategy used by the load balancer to route tasks across GPU devices.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum LoadBalancingStrategy {
    /// Cycle through devices in order.
    RoundRobin,
    /// Pick the device with the lowest reported utilization.
    LeastUtilized,
    /// Pick the device with the fewest queued tasks.
    ShortestQueue,
    /// Pick proportionally to each device's compute weight.
    WeightedCapacity,
    /// Cost-aware routing: heavy tasks go to the least-utilized device,
    /// light tasks to the shortest queue.
    #[default]
    Adaptive,
}
44
/// Configuration for the multi-GPU task manager.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MultiGpuConfig {
    /// Number of GPU devices to manage (clamped to at least 1 at startup).
    pub num_devices: usize,
    /// Load-balancing strategy used to route tasks to devices.
    pub strategy: LoadBalancingStrategy,
    /// Maximum per-device queue depth; dispatch is rejected beyond this.
    pub max_queue_depth: usize,
    /// Interval between utilization samples, in milliseconds.
    /// NOTE(review): not consulted anywhere in this module — confirm callers use it.
    pub utilization_sample_interval_ms: u64,
    /// Whether tasks should prefer sticking to the same device.
    /// NOTE(review): not consulted in this module as written — confirm intended use.
    pub device_affinity: bool,
    /// Utilization at or above which a device is considered overloaded (0.0..=1.0).
    pub overload_threshold: f32,
    /// Number of initial dispatches routed round-robin before the configured
    /// strategy takes over (utilization history is meaningless until then).
    pub adaptive_warmup_tasks: usize,
    /// Whether execution is intended to run asynchronously.
    /// NOTE(review): only set by factories and read by tests here — confirm intended use.
    pub async_execution: bool,
    /// Per-device memory budget in megabytes.
    pub device_memory_budget_mb: usize,
}
67
68impl Default for MultiGpuConfig {
69 fn default() -> Self {
70 Self {
71 num_devices: 1,
72 strategy: LoadBalancingStrategy::Adaptive,
73 max_queue_depth: 64,
74 utilization_sample_interval_ms: 100,
75 device_affinity: true,
76 overload_threshold: 0.85,
77 adaptive_warmup_tasks: 50,
78 async_execution: true,
79 device_memory_budget_mb: 4096,
80 }
81 }
82}
83
/// Rolling metrics snapshot for a single GPU device.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct GpuDeviceMetrics {
    /// Ordinal of the device these metrics describe.
    pub device_id: i32,
    /// Estimated utilization in 0.0..=1.0.
    pub utilization: f32,
    /// Current number of queued (not yet executing) tasks.
    pub queue_depth: usize,
    /// Number of tasks currently executing.
    pub active_tasks: usize,
    /// Total tasks that finished successfully on this device.
    pub tasks_completed: u64,
    /// Total tasks that failed on this device.
    pub tasks_failed: u64,
    /// Exponentially weighted moving average of task latency (ms).
    pub avg_latency_ms: f64,
    /// Peak memory observed, in bytes.
    /// NOTE(review): never updated in this module as written — confirm.
    pub peak_memory_bytes: usize,
    /// Free device memory in bytes, captured at worker construction.
    pub free_memory_bytes: usize,
    /// Device temperature; seeded with a placeholder value (50.0) at init.
    pub temperature_celsius: f32,
    /// Compute capability as (major, minor).
    pub compute_capability: (i32, i32),
    /// Relative capability score: major * 10 + minor.
    pub compute_weight: f64,
}
112
/// A unit of GPU work that can be dispatched to any managed device.
#[derive(Debug, Clone)]
pub enum MultiGpuTask {
    /// Build an index over a set of vectors.
    BuildIndex {
        task_id: u64,
        vector_ids: Vec<usize>,
        vectors: Vec<Vec<f32>>,
        priority: TaskPriority,
    },
    /// Run a batch of queries, returning up to `k` hits per query.
    BatchSearch {
        task_id: u64,
        queries: Vec<Vec<f32>>,
        k: usize,
        priority: TaskPriority,
    },
    /// Compute the pairwise Euclidean distance matrix between two vector sets.
    DistanceMatrix {
        task_id: u64,
        matrix_a: Vec<Vec<f32>>,
        matrix_b: Vec<Vec<f32>>,
        priority: TaskPriority,
    },
    /// Scale each vector in the batch to unit L2 norm.
    NormalizeBatch {
        task_id: u64,
        vectors: Vec<Vec<f32>>,
        priority: TaskPriority,
    },
    /// Run a named custom kernel over a flat input buffer.
    CustomKernel {
        task_id: u64,
        kernel_name: String,
        input: Vec<f32>,
        output_size: usize,
        priority: TaskPriority,
    },
}
152
153impl MultiGpuTask {
154 pub fn task_id(&self) -> u64 {
156 match self {
157 Self::BuildIndex { task_id, .. } => *task_id,
158 Self::BatchSearch { task_id, .. } => *task_id,
159 Self::DistanceMatrix { task_id, .. } => *task_id,
160 Self::NormalizeBatch { task_id, .. } => *task_id,
161 Self::CustomKernel { task_id, .. } => *task_id,
162 }
163 }
164
165 pub fn priority(&self) -> TaskPriority {
167 match self {
168 Self::BuildIndex { priority, .. } => *priority,
169 Self::BatchSearch { priority, .. } => *priority,
170 Self::DistanceMatrix { priority, .. } => *priority,
171 Self::NormalizeBatch { priority, .. } => *priority,
172 Self::CustomKernel { priority, .. } => *priority,
173 }
174 }
175
176 pub fn estimated_cost(&self) -> f64 {
178 match self {
179 Self::BuildIndex { vectors, .. } => {
180 let n = vectors.len() as f64;
181 let d = vectors.first().map(|v| v.len() as f64).unwrap_or(1.0);
182 n * n * d * 0.001 }
184 Self::BatchSearch { queries, k, .. } => {
185 let n = queries.len() as f64;
186 let d = queries.first().map(|v| v.len() as f64).unwrap_or(1.0);
187 n * (*k as f64) * d * 0.1
188 }
189 Self::DistanceMatrix {
190 matrix_a, matrix_b, ..
191 } => {
192 let na = matrix_a.len() as f64;
193 let nb = matrix_b.len() as f64;
194 let d = matrix_a.first().map(|v| v.len() as f64).unwrap_or(1.0);
195 na * nb * d * 0.01
196 }
197 Self::NormalizeBatch { vectors, .. } => {
198 let n = vectors.len() as f64;
199 let d = vectors.first().map(|v| v.len() as f64).unwrap_or(1.0);
200 n * d * 0.001
201 }
202 Self::CustomKernel { input, .. } => input.len() as f64 * 0.01,
203 }
204 }
205}
206
/// Relative scheduling priority attached to each task.
///
/// Ordering follows the discriminants: `Low < Normal < High < Critical`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub enum TaskPriority {
    Low = 0,
    Normal = 1,
    High = 2,
    Critical = 3,
}
215
/// Outcome of a successfully executed task.
#[derive(Debug, Clone)]
pub struct GpuTaskResult {
    /// Id of the task that produced this result.
    pub task_id: u64,
    /// Device the task ran on.
    pub device_id: i32,
    /// Wall-clock execution time in milliseconds.
    pub execution_time_ms: u64,
    /// Variant-specific payload.
    pub output: GpuTaskOutput,
}
228
/// Payload produced by each `MultiGpuTask` variant.
#[derive(Debug, Clone)]
pub enum GpuTaskOutput {
    /// Number of index nodes constructed.
    IndexBuild { nodes_built: usize },
    /// Per-query `(id, distance)` hit lists.
    SearchResults(Vec<Vec<(usize, f32)>>),
    /// Pairwise distance matrix, rows indexed by `matrix_a`.
    DistanceMatrix(Vec<Vec<f32>>),
    /// Input vectors scaled to unit L2 norm.
    NormalizedVectors(Vec<Vec<f32>>),
    /// Raw output buffer of a custom kernel.
    CustomOutput(Vec<f32>),
}
243
/// Per-device execution state: the device info, its pending task queue,
/// and a running metrics record.
#[derive(Debug)]
struct GpuWorker {
    /// Ordinal of the device this worker drives.
    device_id: i32,
    /// Device capabilities captured at construction.
    device_info: GpuDevice,
    /// FIFO of tasks awaiting execution on this device.
    task_queue: VecDeque<MultiGpuTask>,
    /// Rolling metrics updated as tasks run.
    metrics: GpuDeviceMetrics,
    /// Timestamp of the last utilization refresh.
    last_metrics_update: Instant,
}
253
254impl GpuWorker {
255 fn new(device_id: i32) -> Result<Self> {
256 let device_info = GpuDevice::get_device_info(device_id)?;
257
258 let compute_weight = device_info.compute_capability.0 as f64 * 10.0
260 + device_info.compute_capability.1 as f64;
261
262 let metrics = GpuDeviceMetrics {
263 device_id,
264 utilization: 0.0,
265 queue_depth: 0,
266 active_tasks: 0,
267 tasks_completed: 0,
268 tasks_failed: 0,
269 avg_latency_ms: 0.0,
270 peak_memory_bytes: 0,
271 free_memory_bytes: device_info.free_memory,
272 temperature_celsius: 50.0, compute_capability: device_info.compute_capability,
274 compute_weight,
275 };
276
277 Ok(Self {
278 device_id,
279 device_info,
280 task_queue: VecDeque::new(),
281 metrics,
282 last_metrics_update: Instant::now(),
283 })
284 }
285
286 fn enqueue(&mut self, task: MultiGpuTask) -> Result<()> {
287 self.task_queue.push_back(task);
288 self.metrics.queue_depth = self.task_queue.len();
289 Ok(())
290 }
291
292 fn execute_next(&mut self) -> Option<GpuTaskResult> {
293 let task = self.task_queue.pop_front()?;
294 self.metrics.queue_depth = self.task_queue.len();
295 self.metrics.active_tasks += 1;
296
297 let start = Instant::now();
298 let task_id = task.task_id();
299 let device_id = self.device_id;
300
301 let output = self.execute_task(task);
302 let execution_time_ms = start.elapsed().as_millis() as u64;
303
304 self.metrics.active_tasks = self.metrics.active_tasks.saturating_sub(1);
305
306 match output {
307 Ok(output) => {
308 self.metrics.tasks_completed += 1;
309 self.update_avg_latency(execution_time_ms as f64);
310 self.update_utilization();
311
312 Some(GpuTaskResult {
313 task_id,
314 device_id,
315 execution_time_ms,
316 output,
317 })
318 }
319 Err(e) => {
320 warn!("Task {} failed on device {}: {}", task_id, device_id, e);
321 self.metrics.tasks_failed += 1;
322 None
323 }
324 }
325 }
326
327 fn execute_task(&self, task: MultiGpuTask) -> Result<GpuTaskOutput> {
328 match task {
329 MultiGpuTask::BuildIndex { vectors, .. } => {
330 let nodes_built = vectors.len();
331 debug!(
332 "Device {} building index for {} vectors",
333 self.device_id, nodes_built
334 );
335 Ok(GpuTaskOutput::IndexBuild { nodes_built })
336 }
337 MultiGpuTask::BatchSearch { queries, k, .. } => {
338 let results = queries
339 .iter()
340 .map(|_q| {
341 (0..k.min(10))
343 .map(|i| (i, (i as f32) * 0.1))
344 .collect::<Vec<_>>()
345 })
346 .collect();
347 Ok(GpuTaskOutput::SearchResults(results))
348 }
349 MultiGpuTask::DistanceMatrix {
350 matrix_a, matrix_b, ..
351 } => {
352 let distances = matrix_a
353 .iter()
354 .map(|a| {
355 matrix_b
356 .iter()
357 .map(|b| {
358 a.iter()
359 .zip(b.iter())
360 .map(|(x, y)| (x - y).powi(2))
361 .sum::<f32>()
362 .sqrt()
363 })
364 .collect::<Vec<_>>()
365 })
366 .collect();
367 Ok(GpuTaskOutput::DistanceMatrix(distances))
368 }
369 MultiGpuTask::NormalizeBatch { vectors, .. } => {
370 let normalized = vectors
371 .iter()
372 .map(|v| {
373 let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
374 if norm > 1e-9 {
375 v.iter().map(|x| x / norm).collect()
376 } else {
377 v.clone()
378 }
379 })
380 .collect();
381 Ok(GpuTaskOutput::NormalizedVectors(normalized))
382 }
383 MultiGpuTask::CustomKernel { input, .. } => {
384 let output = input.iter().map(|x| x * 2.0).collect();
385 Ok(GpuTaskOutput::CustomOutput(output))
386 }
387 }
388 }
389
390 fn update_avg_latency(&mut self, new_latency_ms: f64) {
391 let completed = self.metrics.tasks_completed as f64;
392 if completed <= 1.0 {
393 self.metrics.avg_latency_ms = new_latency_ms;
394 } else {
395 self.metrics.avg_latency_ms = 0.9 * self.metrics.avg_latency_ms + 0.1 * new_latency_ms;
397 }
398 }
399
400 fn update_utilization(&mut self) {
401 let elapsed = self.last_metrics_update.elapsed().as_millis() as f64;
402 if elapsed > 0.0 {
403 let active = self.metrics.active_tasks as f64;
404 self.metrics.utilization = (active / 4.0_f64).min(1.0) as f32;
405 }
406 self.last_metrics_update = Instant::now();
407 }
408}
409
/// Mutable routing state used to pick a device for each task.
#[derive(Debug)]
struct LoadBalancer {
    /// Strategy applied once warmup completes.
    strategy: LoadBalancingStrategy,
    /// Counter driving round-robin and the weighted pseudo-random pick.
    round_robin_counter: usize,
    /// Number of selections made so far; drives the warmup cutover.
    total_tasks_dispatched: u64,
    /// Number of initial dispatches that always use round-robin.
    warmup_tasks: usize,
}
418
419impl LoadBalancer {
420 fn new(strategy: LoadBalancingStrategy, warmup_tasks: usize) -> Self {
421 Self {
422 strategy,
423 round_robin_counter: 0,
424 total_tasks_dispatched: 0,
425 warmup_tasks,
426 }
427 }
428
429 fn select_device(
430 &mut self,
431 task: &MultiGpuTask,
432 workers: &[GpuWorker],
433 overload_threshold: f32,
434 ) -> Result<usize> {
435 if workers.is_empty() {
436 return Err(anyhow!("No GPU workers available"));
437 }
438
439 let available: Vec<usize> = (0..workers.len())
441 .filter(|&i| {
442 workers[i].metrics.utilization < overload_threshold
443 || workers[i].metrics.queue_depth == 0
444 })
445 .collect();
446
447 if available.is_empty() {
448 warn!("All GPU devices are overloaded, routing to least utilized");
450 return self.select_least_utilized(workers);
451 }
452
453 let effective_strategy = if self.total_tasks_dispatched < self.warmup_tasks as u64 {
454 LoadBalancingStrategy::RoundRobin
455 } else {
456 self.strategy
457 };
458
459 let selected = match effective_strategy {
460 LoadBalancingStrategy::RoundRobin => self.select_round_robin(&available),
461 LoadBalancingStrategy::LeastUtilized => {
462 self.select_least_utilized_from(workers, &available)
463 }
464 LoadBalancingStrategy::ShortestQueue => self.select_shortest_queue(workers, &available),
465 LoadBalancingStrategy::WeightedCapacity => {
466 self.select_weighted(workers, &available, task)
467 }
468 LoadBalancingStrategy::Adaptive => self.select_adaptive(workers, &available, task),
469 };
470
471 self.total_tasks_dispatched += 1;
472 Ok(selected)
473 }
474
475 fn select_round_robin(&mut self, available: &[usize]) -> usize {
476 let idx = self.round_robin_counter % available.len();
477 self.round_robin_counter += 1;
478 available[idx]
479 }
480
481 fn select_least_utilized(&self, workers: &[GpuWorker]) -> Result<usize> {
482 workers
483 .iter()
484 .enumerate()
485 .min_by(|a, b| {
486 a.1.metrics
487 .utilization
488 .partial_cmp(&b.1.metrics.utilization)
489 .unwrap_or(std::cmp::Ordering::Equal)
490 })
491 .map(|(i, _)| i)
492 .ok_or_else(|| anyhow!("No workers available"))
493 }
494
495 fn select_least_utilized_from(&self, workers: &[GpuWorker], available: &[usize]) -> usize {
496 available
497 .iter()
498 .min_by(|&&a, &&b| {
499 workers[a]
500 .metrics
501 .utilization
502 .partial_cmp(&workers[b].metrics.utilization)
503 .unwrap_or(std::cmp::Ordering::Equal)
504 })
505 .copied()
506 .unwrap_or(available[0])
507 }
508
509 fn select_shortest_queue(&self, workers: &[GpuWorker], available: &[usize]) -> usize {
510 available
511 .iter()
512 .min_by_key(|&&i| workers[i].metrics.queue_depth)
513 .copied()
514 .unwrap_or(available[0])
515 }
516
517 fn select_weighted(
518 &mut self,
519 workers: &[GpuWorker],
520 available: &[usize],
521 _task: &MultiGpuTask,
522 ) -> usize {
523 let total_weight: f64 = available
524 .iter()
525 .map(|&i| workers[i].metrics.compute_weight)
526 .sum();
527 if total_weight <= 0.0 {
528 return self.select_round_robin(available);
529 }
530
531 let threshold = (self.round_robin_counter as f64 / 1000.0) % 1.0;
533 let mut cumulative = 0.0;
534 for &i in available {
535 cumulative += workers[i].metrics.compute_weight / total_weight;
536 if cumulative >= threshold {
537 self.round_robin_counter += 1;
538 return i;
539 }
540 }
541 self.round_robin_counter += 1;
542 available[available.len() - 1]
543 }
544
545 fn select_adaptive(
546 &mut self,
547 workers: &[GpuWorker],
548 available: &[usize],
549 task: &MultiGpuTask,
550 ) -> usize {
551 let cost = task.estimated_cost();
554 if cost > 100.0 {
555 self.select_least_utilized_from(workers, available)
556 } else {
557 self.select_shortest_queue(workers, available)
558 }
559 }
560}
561
/// Aggregated statistics across all managed GPU devices.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct MultiGpuStats {
    /// Total tasks successfully handed to a device queue.
    pub total_tasks_dispatched: u64,
    /// Total tasks that executed successfully.
    pub total_tasks_completed: u64,
    /// Total tasks that failed during execution.
    pub total_tasks_failed: u64,
    /// Average dispatch latency in milliseconds.
    /// NOTE(review): never written in this module as written — confirm
    /// whether callers populate it.
    pub avg_dispatch_latency_ms: f64,
    /// Per-device metric snapshots at the time of the query.
    pub device_metrics: Vec<GpuDeviceMetrics>,
    /// Max/min device utilization ratio; 1.0 means balanced (or idle).
    pub load_imbalance_factor: f64,
    /// Debug-formatted name of the configured strategy.
    pub active_strategy: String,
}
580
/// Coordinates task dispatch and execution across multiple GPU workers.
///
/// Shared state lives behind `Arc` + locks so methods only need `&self`.
#[derive(Debug)]
pub struct MultiGpuManager {
    /// Static configuration captured at construction time.
    config: MultiGpuConfig,
    /// One worker per managed device.
    workers: Arc<RwLock<Vec<GpuWorker>>>,
    /// Routing policy state.
    load_balancer: Arc<Mutex<LoadBalancer>>,
    /// Manager-level aggregate counters.
    stats: Arc<Mutex<MultiGpuStats>>,
    /// Completed results keyed by task id.
    /// NOTE(review): never read or written in this module as written —
    /// confirm intended use.
    result_buffer: Arc<Mutex<HashMap<u64, GpuTaskResult>>>,
    /// Monotonic task-id source (see `next_task_id`).
    next_task_id: Arc<Mutex<u64>>,
}
594
595impl MultiGpuManager {
596 pub fn new(config: MultiGpuConfig) -> Result<Self> {
600 let num_devices = config.num_devices.max(1);
601 let mut workers = Vec::with_capacity(num_devices);
602
603 for device_id in 0..num_devices as i32 {
604 let worker = GpuWorker::new(device_id).map_err(|e| {
605 anyhow!(
606 "Failed to initialize GPU worker for device {}: {}",
607 device_id,
608 e
609 )
610 })?;
611 workers.push(worker);
612 }
613
614 info!(
615 "Multi-GPU manager initialized with {} devices, strategy={:?}",
616 num_devices, config.strategy
617 );
618
619 let load_balancer = LoadBalancer::new(config.strategy, config.adaptive_warmup_tasks);
620
621 Ok(Self {
622 config,
623 workers: Arc::new(RwLock::new(workers)),
624 load_balancer: Arc::new(Mutex::new(load_balancer)),
625 stats: Arc::new(Mutex::new(MultiGpuStats::default())),
626 result_buffer: Arc::new(Mutex::new(HashMap::new())),
627 next_task_id: Arc::new(Mutex::new(0)),
628 })
629 }
630
631 pub fn dispatch(&self, task: MultiGpuTask) -> Result<u64> {
633 let task_id = task.task_id();
634
635 let mut workers = self.workers.write();
636 let device_idx = {
637 let mut lb = self.load_balancer.lock();
638 lb.select_device(&task, &workers, self.config.overload_threshold)?
639 };
640
641 if workers[device_idx].metrics.queue_depth >= self.config.max_queue_depth {
642 return Err(anyhow!(
643 "Device {} queue is full (depth={})",
644 device_idx,
645 workers[device_idx].metrics.queue_depth
646 ));
647 }
648
649 debug!("Dispatching task {} to device {}", task_id, device_idx);
650 workers[device_idx].enqueue(task)?;
651
652 let mut stats = self.stats.lock();
653 stats.total_tasks_dispatched += 1;
654
655 Ok(task_id)
656 }
657
658 pub fn execute_pending(&self) -> Vec<GpuTaskResult> {
660 let mut workers = self.workers.write();
661 let mut all_results = Vec::new();
662
663 for worker in workers.iter_mut() {
664 while !worker.task_queue.is_empty() {
665 if let Some(result) = worker.execute_next() {
666 all_results.push(result);
667 }
668 }
669 }
670
671 let mut stats = self.stats.lock();
672 stats.total_tasks_completed += all_results.len() as u64;
673
674 all_results
675 }
676
677 pub fn execute_sync(&self, task: MultiGpuTask) -> Result<GpuTaskResult> {
679 let task_id = self.dispatch(task)?;
680 let results = self.execute_pending();
681
682 results
683 .into_iter()
684 .find(|r| r.task_id == task_id)
685 .ok_or_else(|| anyhow!("Task {} was not executed", task_id))
686 }
687
688 pub fn get_stats(&self) -> MultiGpuStats {
690 let workers = self.workers.read();
691 let stats = self.stats.lock();
692
693 let device_metrics: Vec<GpuDeviceMetrics> =
694 workers.iter().map(|w| w.metrics.clone()).collect();
695
696 let utilizations: Vec<f32> = device_metrics.iter().map(|m| m.utilization).collect();
698 let load_imbalance = if utilizations.len() > 1 {
699 let max_util = utilizations
700 .iter()
701 .cloned()
702 .fold(f32::NEG_INFINITY, f32::max);
703 let min_util = utilizations.iter().cloned().fold(f32::INFINITY, f32::min);
704 if min_util > 0.0 {
705 max_util as f64 / min_util as f64
706 } else {
707 1.0
708 }
709 } else {
710 1.0
711 };
712
713 MultiGpuStats {
714 total_tasks_dispatched: stats.total_tasks_dispatched,
715 total_tasks_completed: stats.total_tasks_completed,
716 total_tasks_failed: stats.total_tasks_failed,
717 avg_dispatch_latency_ms: stats.avg_dispatch_latency_ms,
718 device_metrics,
719 load_imbalance_factor: load_imbalance,
720 active_strategy: format!("{:?}", self.config.strategy),
721 }
722 }
723
724 pub fn get_device_metrics(&self) -> Vec<GpuDeviceMetrics> {
726 let workers = self.workers.read();
727 workers.iter().map(|w| w.metrics.clone()).collect()
728 }
729
730 pub fn num_devices(&self) -> usize {
732 self.workers.read().len()
733 }
734
735 pub fn all_healthy(&self) -> bool {
737 let workers = self.workers.read();
738 workers
739 .iter()
740 .all(|w| w.metrics.utilization < self.config.overload_threshold)
741 }
742
743 pub fn least_utilized_device(&self) -> Option<i32> {
745 let workers = self.workers.read();
746 workers
747 .iter()
748 .min_by(|a, b| {
749 a.metrics
750 .utilization
751 .partial_cmp(&b.metrics.utilization)
752 .unwrap_or(std::cmp::Ordering::Equal)
753 })
754 .map(|w| w.device_id)
755 }
756
757 pub fn next_task_id(&self) -> u64 {
759 let mut id = self.next_task_id.lock();
760 let current = *id;
761 *id += 1;
762 current
763 }
764
765 pub fn set_strategy(&self, strategy: LoadBalancingStrategy) {
767 let mut lb = self.load_balancer.lock();
768 lb.strategy = strategy;
769 info!("Load balancing strategy changed to {:?}", strategy);
770 }
771
772 pub fn reset_stats(&self) {
774 let mut stats = self.stats.lock();
775 *stats = MultiGpuStats::default();
776 }
777}
778
779pub struct MultiGpuConfigFactory;
781
782impl MultiGpuConfigFactory {
783 pub fn high_throughput_indexing(num_devices: usize) -> MultiGpuConfig {
785 MultiGpuConfig {
786 num_devices,
787 strategy: LoadBalancingStrategy::WeightedCapacity,
788 max_queue_depth: 128,
789 async_execution: true,
790 device_memory_budget_mb: 8192,
791 ..Default::default()
792 }
793 }
794
795 pub fn low_latency_search(num_devices: usize) -> MultiGpuConfig {
797 MultiGpuConfig {
798 num_devices,
799 strategy: LoadBalancingStrategy::ShortestQueue,
800 max_queue_depth: 16,
801 overload_threshold: 0.7,
802 device_affinity: false,
803 ..Default::default()
804 }
805 }
806
807 pub fn balanced_mixed_workload(num_devices: usize) -> MultiGpuConfig {
809 MultiGpuConfig {
810 num_devices,
811 strategy: LoadBalancingStrategy::Adaptive,
812 adaptive_warmup_tasks: 100,
813 ..Default::default()
814 }
815 }
816}
817
#[cfg(test)]
mod tests {
    //! Unit tests for the multi-GPU manager, load-balancing strategies,
    //! and the placeholder task executors.

    use super::*;

    /// Builds a `BatchSearch` task with deterministic synthetic queries.
    fn make_batch_search_task(id: u64, n_queries: usize, dim: usize) -> MultiGpuTask {
        let queries = (0..n_queries)
            .map(|i| (0..dim).map(|j| (i + j) as f32 * 0.1).collect())
            .collect();
        MultiGpuTask::BatchSearch {
            task_id: id,
            queries,
            k: 10,
            priority: TaskPriority::Normal,
        }
    }

    /// Builds a `BuildIndex` task with deterministic synthetic vectors.
    fn make_build_index_task(id: u64, n_vectors: usize, dim: usize) -> MultiGpuTask {
        let vectors: Vec<Vec<f32>> = (0..n_vectors)
            .map(|i| (0..dim).map(|j| (i + j) as f32 * 0.1).collect())
            .collect();
        let vector_ids: Vec<usize> = (0..n_vectors).collect();
        MultiGpuTask::BuildIndex {
            task_id: id,
            vector_ids,
            vectors,
            priority: TaskPriority::Normal,
        }
    }

    #[test]
    fn test_multi_gpu_config_default() {
        let config = MultiGpuConfig::default();
        assert_eq!(config.num_devices, 1);
        assert_eq!(config.strategy, LoadBalancingStrategy::Adaptive);
        assert!(config.async_execution);
    }

    #[test]
    fn test_multi_gpu_manager_creation() {
        let config = MultiGpuConfig {
            num_devices: 2,
            ..Default::default()
        };
        let manager = MultiGpuManager::new(config);
        assert!(manager.is_ok(), "Manager creation should succeed");
        let manager = manager.unwrap();
        assert_eq!(manager.num_devices(), 2);
    }

    #[test]
    fn test_single_device_dispatch_and_execute() {
        let config = MultiGpuConfig {
            num_devices: 1,
            ..Default::default()
        };
        let manager = MultiGpuManager::new(config).unwrap();

        let task = make_batch_search_task(0, 5, 8);
        let task_id = manager.dispatch(task).unwrap();
        assert_eq!(task_id, 0);

        let results = manager.execute_pending();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].task_id, 0);
    }

    #[test]
    fn test_round_robin_distribution() {
        let config = MultiGpuConfig {
            num_devices: 3,
            strategy: LoadBalancingStrategy::RoundRobin,
            ..Default::default()
        };
        let manager = MultiGpuManager::new(config).unwrap();

        for i in 0..6u64 {
            let task = make_batch_search_task(i, 2, 4);
            manager.dispatch(task).unwrap();
        }

        let results = manager.execute_pending();
        assert_eq!(results.len(), 6);
    }

    #[test]
    fn test_execute_sync() {
        let config = MultiGpuConfig {
            num_devices: 1,
            ..Default::default()
        };
        let manager = MultiGpuManager::new(config).unwrap();

        let task = make_batch_search_task(42, 3, 8);
        let result = manager.execute_sync(task).unwrap();

        assert_eq!(result.task_id, 42);
        assert_eq!(result.device_id, 0);
        // Bug fix: `matches!` returns a bool that was previously discarded,
        // making this check a no-op.
        assert!(matches!(result.output, GpuTaskOutput::SearchResults(_)));
    }

    #[test]
    fn test_distance_matrix_task() {
        let config = MultiGpuConfig {
            num_devices: 1,
            ..Default::default()
        };
        let manager = MultiGpuManager::new(config).unwrap();

        let task = MultiGpuTask::DistanceMatrix {
            task_id: 1,
            matrix_a: vec![vec![1.0, 0.0], vec![0.0, 1.0]],
            matrix_b: vec![vec![1.0, 0.0], vec![0.0, 1.0]],
            priority: TaskPriority::Normal,
        };

        let result = manager.execute_sync(task).unwrap();
        match result.output {
            GpuTaskOutput::DistanceMatrix(m) => {
                assert_eq!(m.len(), 2);
                assert_eq!(m[0].len(), 2);
                assert!(m[0][0].abs() < 1e-5, "Self-distance should be 0");
                // Distance between the two unit basis vectors is sqrt(2).
                assert!((m[0][1] - 2.0_f32.sqrt()).abs() < 1e-4);
            }
            _ => panic!("Expected DistanceMatrix output"),
        }
    }

    #[test]
    fn test_normalize_batch_task() {
        let config = MultiGpuConfig {
            num_devices: 1,
            ..Default::default()
        };
        let manager = MultiGpuManager::new(config).unwrap();

        let task = MultiGpuTask::NormalizeBatch {
            task_id: 2,
            vectors: vec![vec![3.0, 4.0], vec![1.0, 0.0]],
            priority: TaskPriority::Normal,
        };

        let result = manager.execute_sync(task).unwrap();
        match result.output {
            GpuTaskOutput::NormalizedVectors(vecs) => {
                assert_eq!(vecs.len(), 2);
                let norm0: f32 = vecs[0].iter().map(|x| x * x).sum::<f32>().sqrt();
                assert!(
                    (norm0 - 1.0).abs() < 1e-5,
                    "Norm should be 1.0, got {}",
                    norm0
                );
                assert!((vecs[1][0] - 1.0).abs() < 1e-5);
            }
            _ => panic!("Expected NormalizedVectors output"),
        }
    }

    #[test]
    fn test_build_index_task() {
        let config = MultiGpuConfig {
            num_devices: 1,
            ..Default::default()
        };
        let manager = MultiGpuManager::new(config).unwrap();

        let task = make_build_index_task(3, 100, 16);
        let result = manager.execute_sync(task).unwrap();

        match result.output {
            GpuTaskOutput::IndexBuild { nodes_built } => {
                assert_eq!(nodes_built, 100);
            }
            _ => panic!("Expected IndexBuild output"),
        }
    }

    #[test]
    fn test_custom_kernel_task() {
        let config = MultiGpuConfig {
            num_devices: 1,
            ..Default::default()
        };
        let manager = MultiGpuManager::new(config).unwrap();

        let task = MultiGpuTask::CustomKernel {
            task_id: 4,
            kernel_name: "scale_by_2".to_string(),
            input: vec![1.0, 2.0, 3.0],
            output_size: 3,
            priority: TaskPriority::High,
        };

        let result = manager.execute_sync(task).unwrap();
        match result.output {
            GpuTaskOutput::CustomOutput(out) => {
                assert_eq!(out, vec![2.0, 4.0, 6.0]);
            }
            _ => panic!("Expected CustomOutput"),
        }
    }

    #[test]
    fn test_task_priority_ordering() {
        assert!(TaskPriority::Critical > TaskPriority::High);
        assert!(TaskPriority::High > TaskPriority::Normal);
        assert!(TaskPriority::Normal > TaskPriority::Low);
    }

    #[test]
    fn test_task_estimated_cost() {
        let build_task = make_build_index_task(0, 100, 16);
        let search_task = make_batch_search_task(1, 10, 16);

        assert!(build_task.estimated_cost() > 0.0);
        assert!(search_task.estimated_cost() > 0.0);
    }

    #[test]
    fn test_get_stats() {
        let config = MultiGpuConfig {
            num_devices: 2,
            ..Default::default()
        };
        let manager = MultiGpuManager::new(config).unwrap();

        let task1 = make_batch_search_task(0, 5, 4);
        let task2 = make_batch_search_task(1, 5, 4);

        manager.dispatch(task1).unwrap();
        manager.dispatch(task2).unwrap();
        manager.execute_pending();

        let stats = manager.get_stats();
        assert_eq!(stats.total_tasks_dispatched, 2);
        assert_eq!(stats.total_tasks_completed, 2);
        assert_eq!(stats.device_metrics.len(), 2);
    }

    #[test]
    fn test_least_utilized_device() {
        let config = MultiGpuConfig {
            num_devices: 3,
            ..Default::default()
        };
        let manager = MultiGpuManager::new(config).unwrap();
        let device = manager.least_utilized_device();
        assert!(device.is_some());
        assert!((0..3).contains(&device.unwrap()));
    }

    #[test]
    fn test_set_strategy_runtime() {
        // Smoke test: swapping the strategy at runtime must not panic.
        let config = MultiGpuConfig {
            num_devices: 2,
            strategy: LoadBalancingStrategy::RoundRobin,
            ..Default::default()
        };
        let manager = MultiGpuManager::new(config).unwrap();
        manager.set_strategy(LoadBalancingStrategy::ShortestQueue);
    }

    #[test]
    fn test_max_queue_depth_rejection() {
        let config = MultiGpuConfig {
            num_devices: 1,
            max_queue_depth: 2,
            ..Default::default()
        };
        let manager = MultiGpuManager::new(config).unwrap();

        manager.dispatch(make_batch_search_task(0, 1, 4)).unwrap();
        manager.dispatch(make_batch_search_task(1, 1, 4)).unwrap();

        // Third dispatch must be rejected: the only device's queue is full.
        let result = manager.dispatch(make_batch_search_task(2, 1, 4));
        assert!(result.is_err(), "Should reject task when queue is full");
    }

    #[test]
    fn test_config_factory_high_throughput() {
        let config = MultiGpuConfigFactory::high_throughput_indexing(4);
        assert_eq!(config.num_devices, 4);
        assert_eq!(config.strategy, LoadBalancingStrategy::WeightedCapacity);
        assert_eq!(config.max_queue_depth, 128);
    }

    #[test]
    fn test_config_factory_low_latency() {
        let config = MultiGpuConfigFactory::low_latency_search(2);
        assert_eq!(config.num_devices, 2);
        assert_eq!(config.strategy, LoadBalancingStrategy::ShortestQueue);
        assert!(!config.device_affinity);
    }

    #[test]
    fn test_config_factory_balanced() {
        let config = MultiGpuConfigFactory::balanced_mixed_workload(4);
        assert_eq!(config.num_devices, 4);
        assert_eq!(config.strategy, LoadBalancingStrategy::Adaptive);
    }

    #[test]
    fn test_all_healthy_check() {
        let config = MultiGpuConfig {
            num_devices: 2,
            ..Default::default()
        };
        let manager = MultiGpuManager::new(config).unwrap();
        // Fresh devices have zero utilization, below any sane threshold.
        assert!(manager.all_healthy());
    }

    #[test]
    fn test_reset_stats() {
        let config = MultiGpuConfig {
            num_devices: 1,
            ..Default::default()
        };
        let manager = MultiGpuManager::new(config).unwrap();

        manager.dispatch(make_batch_search_task(0, 1, 4)).unwrap();
        manager.execute_pending();

        let stats_before = manager.get_stats();
        assert!(stats_before.total_tasks_dispatched > 0);

        manager.reset_stats();
        let stats_after = manager.get_stats();
        assert_eq!(stats_after.total_tasks_dispatched, 0);
    }

    #[test]
    fn test_next_task_id_monotonic() {
        let config = MultiGpuConfig {
            num_devices: 1,
            ..Default::default()
        };
        let manager = MultiGpuManager::new(config).unwrap();

        let id0 = manager.next_task_id();
        let id1 = manager.next_task_id();
        let id2 = manager.next_task_id();

        assert!(id1 > id0);
        assert!(id2 > id1);
    }

    #[test]
    fn test_least_utilized_strategy_dispatch() {
        let config = MultiGpuConfig {
            num_devices: 2,
            strategy: LoadBalancingStrategy::LeastUtilized,
            ..Default::default()
        };
        let manager = MultiGpuManager::new(config).unwrap();

        for i in 0..4u64 {
            manager.dispatch(make_batch_search_task(i, 2, 4)).unwrap();
        }
        let results = manager.execute_pending();
        assert_eq!(results.len(), 4);
    }

    #[test]
    fn test_shortest_queue_strategy_dispatch() {
        let config = MultiGpuConfig {
            num_devices: 2,
            strategy: LoadBalancingStrategy::ShortestQueue,
            ..Default::default()
        };
        let manager = MultiGpuManager::new(config).unwrap();

        for i in 0..6u64 {
            manager.dispatch(make_batch_search_task(i, 2, 4)).unwrap();
        }
        let results = manager.execute_pending();
        assert_eq!(results.len(), 6);
    }

    #[test]
    fn test_load_imbalance_factor() {
        let config = MultiGpuConfig {
            num_devices: 2,
            ..Default::default()
        };
        let manager = MultiGpuManager::new(config).unwrap();
        let stats = manager.get_stats();
        // max/min utilization ratio is 1.0 at minimum.
        assert!(stats.load_imbalance_factor >= 1.0);
    }

    #[test]
    fn test_device_metrics_structure() {
        let config = MultiGpuConfig {
            num_devices: 2,
            ..Default::default()
        };
        let manager = MultiGpuManager::new(config).unwrap();
        let metrics = manager.get_device_metrics();

        assert_eq!(metrics.len(), 2);
        for (i, m) in metrics.iter().enumerate() {
            assert_eq!(m.device_id, i as i32);
            assert!(m.compute_weight > 0.0);
        }
    }
}