1use super::{GpuDevice, SelectionStrategy};
10use crate::error::{GpuAdvancedError, Result};
11use parking_lot::RwLock;
12use std::cmp::Ordering;
13use std::collections::VecDeque;
14use std::sync::Arc;
15use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering as AtomicOrdering};
16use std::time::{Duration, Instant};
17
/// Distributes work across a set of GPU devices and, when enabled, migrates
/// registered workloads between devices to even out load.
pub struct LoadBalancer {
    /// Devices managed by this balancer; vector position is the device index
    /// used by all per-device stats/tracker vectors.
    devices: Vec<Arc<GpuDevice>>,
    /// Placement strategy used by `select_device`.
    strategy: SelectionStrategy,
    /// Monotonic counter backing round-robin selection.
    rr_counter: AtomicUsize,
    /// Per-device task, timing, memory and migration counters.
    stats: Arc<RwLock<LoadStats>>,
    /// Tunables governing migration decisions.
    migration_config: Arc<RwLock<MigrationConfig>>,
    /// Bounded log of past migration outcomes, used to price transfers.
    migration_history: Arc<RwLock<MigrationHistory>>,
    /// Per-device utilization samples and pending migratable workloads.
    workload_tracker: Arc<RwLock<WorkloadTracker>>,
}
35
/// Aggregate counters, one slot per device (indexed by device index).
#[derive(Debug, Clone, Default)]
pub struct LoadStats {
    /// Total tasks ever started on each device.
    pub tasks_per_device: Vec<usize>,
    /// Cumulative task time per device, in microseconds.
    pub time_per_device: Vec<u64>,
    /// Tasks currently running on each device.
    pub active_tasks: Vec<usize>,
    /// Memory attributed to each device, in bytes.
    pub memory_per_device: Vec<u64>,
    /// Number of migrations that moved work *off* each device.
    pub migrations_from: Vec<usize>,
    /// Number of migrations that moved work *onto* each device.
    pub migrations_to: Vec<usize>,
}
52
/// Tunables for the workload-migration policy.
#[derive(Debug, Clone)]
pub struct MigrationConfig {
    /// Combined load above which a device counts as overloaded (0.0..=1.0).
    pub overload_threshold: f32,
    /// Combined load below which a device counts as underutilized (0.0..=1.0).
    pub underutilization_threshold: f32,
    /// Minimum load gap between source and target before a migration is
    /// considered worthwhile.
    pub min_imbalance_threshold: f32,
    /// Fixed cost (abstract units) charged for any migration.
    pub transfer_cost_base: f64,
    /// Additional cost per byte of workload memory moved.
    pub transfer_cost_per_byte: f64,
    /// Workloads smaller than this many bytes are never migrated.
    pub min_migration_size: u64,
    /// Upper bound on migrations performed in one rebalance pass.
    pub max_pending_migrations: usize,
    /// Seconds a device is exempt from further rebalancing after taking part
    /// in a migration.
    pub migration_cooldown_secs: u64,
    /// When true, devices with a rising utilization trend are migrated away
    /// from before they become overloaded.
    pub enable_predictive_migration: bool,
    /// Number of utilization samples / history entries retained per device.
    pub history_window_size: usize,
    /// Weight of memory utilization in the combined load score.
    pub memory_weight: f32,
    /// Weight of compute utilization in the combined load score.
    pub compute_weight: f32,
}
81
82impl Default for MigrationConfig {
83 fn default() -> Self {
84 Self {
85 overload_threshold: 0.8,
86 underutilization_threshold: 0.3,
87 min_imbalance_threshold: 0.2,
88 transfer_cost_base: 1.0,
89 transfer_cost_per_byte: 0.000001, min_migration_size: 1024, max_pending_migrations: 4,
92 migration_cooldown_secs: 5,
93 enable_predictive_migration: true,
94 history_window_size: 100,
95 memory_weight: 0.4,
96 compute_weight: 0.6,
97 }
98 }
99}
100
/// A unit of work that can be moved between devices by the balancer.
#[derive(Debug, Clone)]
pub struct MigratableWorkload {
    /// Balancer-assigned unique id.
    pub id: u64,
    /// Index of the device currently hosting this workload.
    pub source_device: usize,
    /// Size of the workload's resident data, in bytes.
    pub memory_size: u64,
    /// Relative compute demand; used for cost/benefit estimates.
    pub compute_intensity: f32,
    /// Scheduling priority; higher values are migrated first.
    pub priority: u32,
    /// When the workload was registered.
    pub created_at: Instant,
    /// True while a migration of this workload is in flight.
    pub migrating: bool,
    /// Ids of workloads this one depends on; non-empty dependencies exclude
    /// it from migration candidates.
    pub dependencies: Vec<u64>,
}
121
122impl MigratableWorkload {
123 pub fn new(
125 id: u64,
126 source_device: usize,
127 memory_size: u64,
128 compute_intensity: f32,
129 priority: u32,
130 ) -> Self {
131 Self {
132 id,
133 source_device,
134 memory_size,
135 compute_intensity,
136 priority,
137 created_at: Instant::now(),
138 migrating: false,
139 dependencies: Vec::new(),
140 }
141 }
142
143 pub fn with_dependency(mut self, dep_id: u64) -> Self {
145 self.dependencies.push(dep_id);
146 self
147 }
148
149 pub fn calculate_migration_cost(&self, config: &MigrationConfig) -> f64 {
151 config.transfer_cost_base
152 + (self.memory_size as f64 * config.transfer_cost_per_byte)
153 + (self.compute_intensity as f64 * 0.1) }
155}
156
/// A proposed move of one workload from its source device to a target.
#[derive(Debug, Clone)]
pub struct MigrationPlan {
    /// The workload to be moved.
    pub workload: MigratableWorkload,
    /// Device the workload currently lives on.
    pub source_device: usize,
    /// Device the workload would be moved to.
    pub target_device: usize,
    /// Predicted cost of the transfer (abstract units).
    pub estimated_cost: f64,
    /// Predicted benefit of relieving the source device.
    pub expected_benefit: f64,
    /// `expected_benefit - estimated_cost`.
    pub net_benefit: f64,
    /// When the plan was created.
    pub created_at: Instant,
    /// Set at construction: true when `net_benefit` is positive.
    pub approved: bool,
}
177
178impl MigrationPlan {
179 pub fn new(
181 workload: MigratableWorkload,
182 target_device: usize,
183 config: &MigrationConfig,
184 ) -> Self {
185 let source_device = workload.source_device;
186 let estimated_cost = workload.calculate_migration_cost(config);
187 let expected_benefit = workload.compute_intensity as f64 * 10.0; let net_benefit = expected_benefit - estimated_cost;
189
190 Self {
191 workload,
192 source_device,
193 target_device,
194 estimated_cost,
195 expected_benefit,
196 net_benefit,
197 created_at: Instant::now(),
198 approved: net_benefit > 0.0,
199 }
200 }
201
202 pub fn should_migrate(&self) -> bool {
204 self.approved && self.net_benefit > 0.0
205 }
206}
207
/// Outcome of one attempted migration.
#[derive(Debug, Clone)]
pub struct MigrationResult {
    /// True when the workload was actually moved.
    pub success: bool,
    /// Device the workload was moved from.
    pub source_device: usize,
    /// Device the workload was moved to.
    pub target_device: usize,
    /// Id of the migrated workload.
    pub workload_id: u64,
    /// Wall-clock duration of the migration; zero on early failure.
    pub transfer_time: Duration,
    /// Bytes moved; zero on failure.
    pub bytes_transferred: u64,
    /// Human-readable reason when `success` is false.
    pub error_message: Option<String>,
}
226
/// Bounded log of migration outcomes plus lifetime success/failure totals.
#[derive(Debug, Default)]
pub struct MigrationHistory {
    /// Most recent entries, oldest first; capped at `max_size`.
    entries: VecDeque<MigrationHistoryEntry>,
    /// Maximum number of entries retained in `entries`.
    max_size: usize,
    /// Lifetime count of successful migrations (includes evicted entries).
    total_successful: usize,
    /// Lifetime count of failed migrations (includes evicted entries).
    total_failed: usize,
}
239
/// One recorded migration attempt.
#[derive(Debug, Clone)]
pub struct MigrationHistoryEntry {
    /// When the migration finished.
    pub timestamp: Instant,
    /// Device the workload moved from.
    pub source_device: usize,
    /// Device the workload moved to.
    pub target_device: usize,
    /// Whether the migration completed successfully.
    pub success: bool,
    /// Wall-clock duration of the transfer.
    pub transfer_time: Duration,
    /// Bytes moved during the transfer.
    pub bytes_transferred: u64,
}
256
257impl MigrationHistory {
258 pub fn new(max_size: usize) -> Self {
260 Self {
261 entries: VecDeque::with_capacity(max_size),
262 max_size,
263 total_successful: 0,
264 total_failed: 0,
265 }
266 }
267
268 pub fn add_entry(&mut self, entry: MigrationHistoryEntry) {
270 if entry.success {
271 self.total_successful += 1;
272 } else {
273 self.total_failed += 1;
274 }
275
276 if self.entries.len() >= self.max_size {
277 self.entries.pop_front();
278 }
279 self.entries.push_back(entry);
280 }
281
282 pub fn success_rate(&self, source: usize, target: usize) -> f64 {
284 let filtered: Vec<_> = self
285 .entries
286 .iter()
287 .filter(|e| e.source_device == source && e.target_device == target)
288 .collect();
289
290 if filtered.is_empty() {
291 return 1.0; }
293
294 let successful = filtered.iter().filter(|e| e.success).count();
295 successful as f64 / filtered.len() as f64
296 }
297
298 pub fn average_transfer_time(&self, source: usize, target: usize) -> Option<Duration> {
300 let filtered: Vec<_> = self
301 .entries
302 .iter()
303 .filter(|e| e.source_device == source && e.target_device == target && e.success)
304 .collect();
305
306 if filtered.is_empty() {
307 return None;
308 }
309
310 let total: Duration = filtered.iter().map(|e| e.transfer_time).sum();
311 Some(total / filtered.len() as u32)
312 }
313
314 pub fn total_bytes_transferred(&self) -> u64 {
316 self.entries.iter().map(|e| e.bytes_transferred).sum()
317 }
318
319 pub fn overall_success_rate(&self) -> f64 {
321 let total = self.total_successful + self.total_failed;
322 if total == 0 {
323 return 1.0;
324 }
325 self.total_successful as f64 / total as f64
326 }
327}
328
/// Per-device bookkeeping: utilization samples, pending migratable workloads
/// and cooldown timestamps. All vectors are indexed by device index.
#[derive(Debug)]
pub struct WorkloadTracker {
    /// Ring of recent utilization samples per device.
    utilization_samples: Vec<VecDeque<UtilizationSample>>,
    /// Workloads currently registered on each device.
    pending_workloads: Vec<Vec<MigratableWorkload>>,
    /// Source of unique workload ids.
    next_workload_id: AtomicU64,
    /// When each device last took part in a rebalance (for cooldown checks).
    last_rebalance: Vec<Option<Instant>>,
}
341
/// A point-in-time utilization reading for one device.
#[derive(Debug, Clone)]
pub struct UtilizationSample {
    /// When the sample was taken.
    pub timestamp: Instant,
    /// Compute utilization at sample time (fraction).
    pub compute: f32,
    /// Memory utilization at sample time (fraction).
    pub memory: f32,
    /// Tasks active on the device at sample time.
    pub active_tasks: usize,
}
354
impl WorkloadTracker {
    /// Creates per-device sample buffers, pending-workload lists and
    /// rebalance timestamps for `device_count` devices.
    ///
    /// `history_size` is the requested sample window per device.
    /// NOTE(review): `record_sample` evicts based on `VecDeque::capacity()`,
    /// which the allocator may round up above `history_size`, so the
    /// effective window can exceed the requested size — confirm acceptable.
    pub fn new(device_count: usize, history_size: usize) -> Self {
        let mut utilization_samples = Vec::with_capacity(device_count);
        let mut pending_workloads = Vec::with_capacity(device_count);
        let mut last_rebalance = Vec::with_capacity(device_count);

        for _ in 0..device_count {
            utilization_samples.push(VecDeque::with_capacity(history_size));
            pending_workloads.push(Vec::new());
            last_rebalance.push(None);
        }

        Self {
            utilization_samples,
            pending_workloads,
            next_workload_id: AtomicU64::new(0),
            last_rebalance,
        }
    }

    /// Returns a fresh, monotonically increasing workload id.
    pub fn next_workload_id(&self) -> u64 {
        self.next_workload_id.fetch_add(1, AtomicOrdering::Relaxed)
    }

    /// Appends `sample` to the device's ring, evicting the oldest sample once
    /// the deque is full. Out-of-range device indices are ignored.
    pub fn record_sample(&mut self, device_index: usize, sample: UtilizationSample) {
        if let Some(samples) = self.utilization_samples.get_mut(device_index) {
            // Capacity-based eviction keeps the deque from reallocating.
            if samples.len() >= samples.capacity() {
                samples.pop_front();
            }
            samples.push_back(sample);
        }
    }

    /// Mean (compute, memory) utilization over the most recent `window`
    /// samples. `None` when the device is unknown or has no samples.
    pub fn average_utilization(&self, device_index: usize, window: usize) -> Option<(f32, f32)> {
        let samples = self.utilization_samples.get(device_index)?;
        if samples.is_empty() {
            return None;
        }

        let take_count = window.min(samples.len());
        // `rev()` walks newest-first so `take` selects the latest samples.
        let recent: Vec<_> = samples.iter().rev().take(take_count).collect();

        let avg_compute = recent.iter().map(|s| s.compute).sum::<f32>() / take_count as f32;
        let avg_memory = recent.iter().map(|s| s.memory).sum::<f32>() / take_count as f32;

        Some((avg_compute, avg_memory))
    }

    /// Least-squares slope of compute utilization versus sample index over
    /// the last `window` samples: positive means load is rising. `None` when
    /// fewer than two samples are available.
    pub fn utilization_trend(&self, device_index: usize, window: usize) -> Option<f32> {
        let samples = self.utilization_samples.get(device_index)?;
        if samples.len() < 2 {
            return None;
        }

        let take_count = window.min(samples.len());
        // Skip the oldest samples so only the trailing window remains,
        // in chronological order (required for a meaningful slope).
        let skip_count = samples.len().saturating_sub(take_count);
        let recent: Vec<_> = samples.iter().skip(skip_count).collect();

        if recent.len() < 2 {
            return None;
        }

        // Standard simple-linear-regression accumulators:
        // slope = (n*Σxy − Σx·Σy) / (n*Σx² − (Σx)²), x = sample index.
        let n = recent.len() as f32;
        let mut sum_x = 0.0f32;
        let mut sum_y = 0.0f32;
        let mut sum_xy = 0.0f32;
        let mut sum_xx = 0.0f32;

        for (i, sample) in recent.iter().enumerate() {
            let x = i as f32;
            let y = sample.compute;
            sum_x += x;
            sum_y += y;
            sum_xy += x * y;
            sum_xx += x * x;
        }

        let denominator = n * sum_xx - sum_x * sum_x;
        // Degenerate x-spread (shouldn't happen with n >= 2) — report flat.
        if denominator.abs() < f32::EPSILON {
            return Some(0.0);
        }

        Some((n * sum_xy - sum_x * sum_y) / denominator)
    }

    /// Registers `workload` as pending on `device_index`; out-of-range
    /// indices are ignored.
    pub fn add_workload(&mut self, device_index: usize, workload: MigratableWorkload) {
        if let Some(workloads) = self.pending_workloads.get_mut(device_index) {
            workloads.push(workload);
        }
    }

    /// Removes and returns the workload with `workload_id` from the device,
    /// or `None` when the device or workload is unknown.
    pub fn remove_workload(
        &mut self,
        device_index: usize,
        workload_id: u64,
    ) -> Option<MigratableWorkload> {
        if let Some(workloads) = self.pending_workloads.get_mut(device_index) {
            if let Some(pos) = workloads.iter().position(|w| w.id == workload_id) {
                return Some(workloads.remove(pos));
            }
        }
        None
    }

    /// Workloads on the device that are eligible to move: not already
    /// migrating and with no outstanding dependencies.
    pub fn get_migratable_workloads(&self, device_index: usize) -> Vec<&MigratableWorkload> {
        self.pending_workloads
            .get(device_index)
            .map(|workloads| {
                workloads
                    .iter()
                    .filter(|w| !w.migrating && w.dependencies.is_empty())
                    .collect()
            })
            .unwrap_or_default()
    }

    /// Marks the device as having just taken part in a rebalance (starts its
    /// cooldown window).
    pub fn update_rebalance_time(&mut self, device_index: usize) {
        if let Some(time) = self.last_rebalance.get_mut(device_index) {
            *time = Some(Instant::now());
        }
    }

    /// True when the device rebalanced less than `cooldown_secs` ago.
    /// Unknown or never-rebalanced devices are not in cooldown.
    pub fn is_in_cooldown(&self, device_index: usize, cooldown_secs: u64) -> bool {
        self.last_rebalance
            .get(device_index)
            .and_then(|opt| opt.as_ref())
            .map(|t| t.elapsed().as_secs() < cooldown_secs)
            .unwrap_or(false)
    }

    /// Number of workloads currently registered on the device (0 for unknown
    /// indices).
    pub fn pending_count(&self, device_index: usize) -> usize {
        self.pending_workloads
            .get(device_index)
            .map(|w| w.len())
            .unwrap_or(0)
    }
}
511
/// Snapshot of one device's load, produced by `LoadBalancer::get_device_loads`.
#[derive(Debug, Clone)]
pub struct DeviceLoad {
    /// Index of the device this snapshot describes.
    pub device_index: usize,
    /// Compute utilization fraction.
    pub compute_utilization: f32,
    /// Memory utilization fraction.
    pub memory_utilization: f32,
    /// Weighted blend of compute and memory utilization.
    pub combined_load: f32,
    /// Tasks active on the device at snapshot time.
    pub active_tasks: usize,
    /// Workloads registered on the device at snapshot time.
    pub pending_workloads: usize,
    /// Device quality score (higher is preferred).
    pub score: f32,
    /// Recent utilization slope; positive means load is rising.
    pub trend: f32,
}
532
533impl DeviceLoad {
534 pub fn is_overloaded(&self, config: &MigrationConfig) -> bool {
536 self.combined_load > config.overload_threshold
537 }
538
539 pub fn is_underutilized(&self, config: &MigrationConfig) -> bool {
541 self.combined_load < config.underutilization_threshold
542 }
543}
544
545impl LoadBalancer {
546 pub fn new(devices: Vec<Arc<GpuDevice>>, strategy: SelectionStrategy) -> Self {
548 let device_count = devices.len();
549 let stats = LoadStats {
550 tasks_per_device: vec![0; device_count],
551 time_per_device: vec![0; device_count],
552 active_tasks: vec![0; device_count],
553 memory_per_device: vec![0; device_count],
554 migrations_from: vec![0; device_count],
555 migrations_to: vec![0; device_count],
556 };
557
558 let config = MigrationConfig::default();
559 let tracker = WorkloadTracker::new(device_count, config.history_window_size);
560 let history = MigrationHistory::new(config.history_window_size);
561
562 Self {
563 devices,
564 strategy,
565 rr_counter: AtomicUsize::new(0),
566 stats: Arc::new(RwLock::new(stats)),
567 migration_config: Arc::new(RwLock::new(config)),
568 migration_history: Arc::new(RwLock::new(history)),
569 workload_tracker: Arc::new(RwLock::new(tracker)),
570 }
571 }
572
573 pub fn migration_config(&self) -> MigrationConfig {
575 self.migration_config.read().clone()
576 }
577
578 pub fn set_migration_config(&self, config: MigrationConfig) {
580 *self.migration_config.write() = config;
581 }
582
583 pub fn select_device(&self) -> Result<Arc<GpuDevice>> {
585 if self.devices.is_empty() {
586 return Err(GpuAdvancedError::GpuNotFound(
587 "No devices available".to_string(),
588 ));
589 }
590
591 match self.strategy {
592 SelectionStrategy::RoundRobin => self.select_round_robin(),
593 SelectionStrategy::LeastLoaded => self.select_least_loaded(),
594 SelectionStrategy::BestScore => self.select_best_score(),
595 SelectionStrategy::Affinity => self.select_affinity(),
596 }
597 }
598
599 fn select_round_robin(&self) -> Result<Arc<GpuDevice>> {
601 let index = self.rr_counter.fetch_add(1, AtomicOrdering::Relaxed) % self.devices.len();
602 self.devices
603 .get(index)
604 .cloned()
605 .ok_or(GpuAdvancedError::InvalidGpuIndex {
606 index,
607 total: self.devices.len(),
608 })
609 }
610
611 fn select_least_loaded(&self) -> Result<Arc<GpuDevice>> {
613 let stats = self.stats.read();
614
615 let (index, _) = self
616 .devices
617 .iter()
618 .enumerate()
619 .map(|(i, device)| {
620 let active_tasks = stats.active_tasks.get(i).copied().unwrap_or(0);
621 let workload = device.get_workload();
622 let load = (active_tasks as f32) + workload;
623 (i, load)
624 })
625 .min_by(|(_, load_a), (_, load_b)| {
626 load_a.partial_cmp(load_b).unwrap_or(Ordering::Equal)
627 })
628 .ok_or_else(|| {
629 GpuAdvancedError::LoadBalancingError("No device available".to_string())
630 })?;
631
632 self.devices
633 .get(index)
634 .cloned()
635 .ok_or(GpuAdvancedError::InvalidGpuIndex {
636 index,
637 total: self.devices.len(),
638 })
639 }
640
641 fn select_best_score(&self) -> Result<Arc<GpuDevice>> {
643 let (index, _) = self
644 .devices
645 .iter()
646 .enumerate()
647 .map(|(i, device)| (i, device.get_score()))
648 .max_by(|(_, score_a), (_, score_b)| {
649 score_a.partial_cmp(score_b).unwrap_or(Ordering::Equal)
650 })
651 .ok_or_else(|| {
652 GpuAdvancedError::LoadBalancingError("No device available".to_string())
653 })?;
654
655 self.devices
656 .get(index)
657 .cloned()
658 .ok_or(GpuAdvancedError::InvalidGpuIndex {
659 index,
660 total: self.devices.len(),
661 })
662 }
663
664 fn select_affinity(&self) -> Result<Arc<GpuDevice>> {
666 let thread_id = std::thread::current().id();
668 let hash = {
669 use std::collections::hash_map::DefaultHasher;
670 use std::hash::{Hash, Hasher};
671 let mut hasher = DefaultHasher::new();
672 thread_id.hash(&mut hasher);
673 hasher.finish()
674 };
675
676 let index = (hash as usize) % self.devices.len();
677 self.devices
678 .get(index)
679 .cloned()
680 .ok_or(GpuAdvancedError::InvalidGpuIndex {
681 index,
682 total: self.devices.len(),
683 })
684 }
685
686 pub fn select_weighted(&self) -> Result<Arc<GpuDevice>> {
688 if self.devices.is_empty() {
689 return Err(GpuAdvancedError::GpuNotFound(
690 "No devices available".to_string(),
691 ));
692 }
693
694 let config = self.migration_config.read();
695
696 let mut best_index = 0;
698 let mut best_score = f32::MIN;
699
700 for (i, device) in self.devices.iter().enumerate() {
701 let compute_util = device.get_workload();
702 let memory_usage = device.get_memory_usage();
703 let max_memory = device.info.max_buffer_size;
704 let memory_util = if max_memory > 0 {
705 memory_usage as f32 / max_memory as f32
706 } else {
707 0.0
708 };
709
710 let availability =
712 1.0 - (compute_util * config.compute_weight + memory_util * config.memory_weight);
713 let type_bonus = device.get_score();
714 let score = availability * type_bonus;
715
716 if score > best_score {
717 best_score = score;
718 best_index = i;
719 }
720 }
721
722 self.devices
723 .get(best_index)
724 .cloned()
725 .ok_or(GpuAdvancedError::InvalidGpuIndex {
726 index: best_index,
727 total: self.devices.len(),
728 })
729 }
730
731 pub fn get_device_loads(&self) -> Vec<DeviceLoad> {
733 let config = self.migration_config.read();
734 let tracker = self.workload_tracker.read();
735 let stats = self.stats.read();
736
737 self.devices
738 .iter()
739 .enumerate()
740 .map(|(i, device)| {
741 let compute_utilization = device.get_workload();
742 let memory_usage = device.get_memory_usage();
743 let max_memory = device.info.max_buffer_size;
744 let memory_utilization = if max_memory > 0 {
745 memory_usage as f32 / max_memory as f32
746 } else {
747 0.0
748 };
749
750 let combined_load = compute_utilization * config.compute_weight
751 + memory_utilization * config.memory_weight;
752
753 let trend = tracker.utilization_trend(i, 10).unwrap_or(0.0);
754
755 DeviceLoad {
756 device_index: i,
757 compute_utilization,
758 memory_utilization,
759 combined_load,
760 active_tasks: stats.active_tasks.get(i).copied().unwrap_or(0),
761 pending_workloads: tracker.pending_count(i),
762 score: device.get_score(),
763 trend,
764 }
765 })
766 .collect()
767 }
768
769 pub fn identify_overloaded_devices(&self) -> Vec<DeviceLoad> {
771 let config = self.migration_config.read();
772 self.get_device_loads()
773 .into_iter()
774 .filter(|load| load.is_overloaded(&config))
775 .collect()
776 }
777
778 pub fn identify_underutilized_devices(&self) -> Vec<DeviceLoad> {
780 let config = self.migration_config.read();
781 self.get_device_loads()
782 .into_iter()
783 .filter(|load| load.is_underutilized(&config))
784 .collect()
785 }
786
787 pub fn is_imbalanced(&self) -> bool {
789 let loads = self.get_device_loads();
790 if loads.len() < 2 {
791 return false;
792 }
793
794 let config = self.migration_config.read();
795
796 let max_load = loads
798 .iter()
799 .map(|l| l.combined_load)
800 .fold(f32::MIN, f32::max);
801 let min_load = loads
802 .iter()
803 .map(|l| l.combined_load)
804 .fold(f32::MAX, f32::min);
805
806 (max_load - min_load) > config.min_imbalance_threshold
807 }
808
809 pub fn calculate_transfer_cost(
811 &self,
812 source_device: usize,
813 target_device: usize,
814 data_size: u64,
815 ) -> Result<f64> {
816 if source_device >= self.devices.len() || target_device >= self.devices.len() {
817 return Err(GpuAdvancedError::InvalidGpuIndex {
818 index: source_device.max(target_device),
819 total: self.devices.len(),
820 });
821 }
822
823 let config = self.migration_config.read();
824 let history = self.migration_history.read();
825
826 let mut cost =
828 config.transfer_cost_base + (data_size as f64 * config.transfer_cost_per_byte);
829
830 if let Some(avg_time) = history.average_transfer_time(source_device, target_device) {
832 let time_factor = avg_time.as_secs_f64();
834 cost *= 1.0 + time_factor;
835 }
836
837 let success_rate = history.success_rate(source_device, target_device);
839 if success_rate < 1.0 {
840 cost *= 1.0 + (1.0 - success_rate) * 0.5;
842 }
843
844 Ok(cost)
845 }
846
847 pub fn create_migration_plan(
849 &self,
850 workload: MigratableWorkload,
851 target_device: usize,
852 ) -> Result<MigrationPlan> {
853 if target_device >= self.devices.len() {
854 return Err(GpuAdvancedError::InvalidGpuIndex {
855 index: target_device,
856 total: self.devices.len(),
857 });
858 }
859
860 let config = self.migration_config.read();
861 let plan = MigrationPlan::new(workload, target_device, &config);
862
863 Ok(plan)
864 }
865
866 pub fn find_migration_target(&self, source_device: usize) -> Result<Option<usize>> {
868 let loads = self.get_device_loads();
869 let config = self.migration_config.read();
870 let tracker = self.workload_tracker.read();
871
872 let source_load = loads
874 .iter()
875 .find(|l| l.device_index == source_device)
876 .ok_or(GpuAdvancedError::InvalidGpuIndex {
877 index: source_device,
878 total: self.devices.len(),
879 })?;
880
881 let mut candidates: Vec<_> = loads
883 .iter()
884 .filter(|l| {
885 l.device_index != source_device
886 && l.is_underutilized(&config)
887 && !tracker.is_in_cooldown(l.device_index, config.migration_cooldown_secs)
888 })
889 .collect();
890
891 if candidates.is_empty() {
892 return Ok(None);
893 }
894
895 candidates.sort_by(|a, b| match a.combined_load.partial_cmp(&b.combined_load) {
897 Some(Ordering::Equal) | None => {
898 b.score.partial_cmp(&a.score).unwrap_or(Ordering::Equal)
899 }
900 Some(ordering) => ordering,
901 });
902
903 if let Some(best) = candidates.first() {
905 let load_diff = source_load.combined_load - best.combined_load;
906 if load_diff > config.min_imbalance_threshold {
907 return Ok(Some(best.device_index));
908 }
909 }
910
911 Ok(None)
912 }
913
914 pub fn select_workload_for_migration(
916 &self,
917 source_device: usize,
918 ) -> Option<MigratableWorkload> {
919 let config = self.migration_config.read();
920 let tracker = self.workload_tracker.read();
921
922 let migratable = tracker.get_migratable_workloads(source_device);
923
924 let mut candidates: Vec<_> = migratable
926 .into_iter()
927 .filter(|w| w.memory_size >= config.min_migration_size)
928 .collect();
929
930 candidates.sort_by(|a, b| {
931 match b.priority.cmp(&a.priority) {
933 Ordering::Equal => b
934 .compute_intensity
935 .partial_cmp(&a.compute_intensity)
936 .unwrap_or(Ordering::Equal),
937 other => other,
938 }
939 });
940
941 candidates.first().map(|w| (*w).clone())
942 }
943
944 pub fn execute_migration(&self, plan: &MigrationPlan) -> Result<MigrationResult> {
946 if !plan.should_migrate() {
947 return Ok(MigrationResult {
948 success: false,
949 source_device: plan.source_device,
950 target_device: plan.target_device,
951 workload_id: plan.workload.id,
952 transfer_time: Duration::ZERO,
953 bytes_transferred: 0,
954 error_message: Some("Migration not approved".to_string()),
955 });
956 }
957
958 let start = Instant::now();
959
960 {
962 let mut tracker = self.workload_tracker.write();
963
964 if tracker
966 .remove_workload(plan.source_device, plan.workload.id)
967 .is_none()
968 {
969 return Ok(MigrationResult {
970 success: false,
971 source_device: plan.source_device,
972 target_device: plan.target_device,
973 workload_id: plan.workload.id,
974 transfer_time: Duration::ZERO,
975 bytes_transferred: 0,
976 error_message: Some("Workload not found on source device".to_string()),
977 });
978 }
979
980 let mut migrated = plan.workload.clone();
982 migrated.source_device = plan.target_device;
983 tracker.add_workload(plan.target_device, migrated);
984
985 tracker.update_rebalance_time(plan.source_device);
987 tracker.update_rebalance_time(plan.target_device);
988 }
989
990 {
992 let mut stats = self.stats.write();
993 if let Some(from) = stats.migrations_from.get_mut(plan.source_device) {
994 *from = from.saturating_add(1);
995 }
996 if let Some(to) = stats.migrations_to.get_mut(plan.target_device) {
997 *to = to.saturating_add(1);
998 }
999 }
1000
1001 let transfer_time = start.elapsed();
1002
1003 {
1005 let mut history = self.migration_history.write();
1006 history.add_entry(MigrationHistoryEntry {
1007 timestamp: Instant::now(),
1008 source_device: plan.source_device,
1009 target_device: plan.target_device,
1010 success: true,
1011 transfer_time,
1012 bytes_transferred: plan.workload.memory_size,
1013 });
1014 }
1015
1016 Ok(MigrationResult {
1017 success: true,
1018 source_device: plan.source_device,
1019 target_device: plan.target_device,
1020 workload_id: plan.workload.id,
1021 transfer_time,
1022 bytes_transferred: plan.workload.memory_size,
1023 error_message: None,
1024 })
1025 }
1026
1027 pub fn rebalance(&self) -> Result<Vec<MigrationResult>> {
1035 if !self.is_imbalanced() {
1037 return Ok(Vec::new());
1038 }
1039
1040 let mut results = Vec::new();
1041
1042 self.sample_utilization();
1044
1045 let overloaded = self.identify_overloaded_devices();
1047 if overloaded.is_empty() {
1048 return Ok(results);
1049 }
1050
1051 for source_load in overloaded {
1053 if results.len() >= self.migration_config.read().max_pending_migrations {
1055 break;
1056 }
1057
1058 let target = match self.find_migration_target(source_load.device_index)? {
1060 Some(t) => t,
1061 None => continue,
1062 };
1063
1064 let workload = match self.select_workload_for_migration(source_load.device_index) {
1066 Some(w) => w,
1067 None => continue,
1068 };
1069
1070 let plan = self.create_migration_plan(workload, target)?;
1072 if plan.should_migrate() {
1073 let result = self.execute_migration(&plan)?;
1074 results.push(result);
1075 }
1076 }
1077
1078 let config = self.migration_config.read();
1080 if config.enable_predictive_migration {
1081 drop(config);
1082 self.handle_predictive_migrations(&mut results)?;
1083 }
1084
1085 Ok(results)
1086 }
1087
1088 fn sample_utilization(&self) {
1090 let stats = self.stats.read();
1091 let mut tracker = self.workload_tracker.write();
1092
1093 for (i, device) in self.devices.iter().enumerate() {
1094 let compute = device.get_workload();
1095 let memory_usage = device.get_memory_usage();
1096 let max_memory = device.info.max_buffer_size;
1097 let memory = if max_memory > 0 {
1098 memory_usage as f32 / max_memory as f32
1099 } else {
1100 0.0
1101 };
1102
1103 tracker.record_sample(
1104 i,
1105 UtilizationSample {
1106 timestamp: Instant::now(),
1107 compute,
1108 memory,
1109 active_tasks: stats.active_tasks.get(i).copied().unwrap_or(0),
1110 },
1111 );
1112 }
1113 }
1114
1115 fn handle_predictive_migrations(&self, results: &mut Vec<MigrationResult>) -> Result<()> {
1117 let candidates: Vec<(usize, f32, f32)> = {
1119 let config = self.migration_config.read();
1120 let tracker = self.workload_tracker.read();
1121 let device_loads = self.get_device_loads();
1122
1123 self.devices
1124 .iter()
1125 .enumerate()
1126 .filter_map(|(i, _device)| {
1127 let trend = tracker.utilization_trend(i, 20)?;
1129
1130 let load = device_loads.iter().find(|l| l.device_index == i)?;
1132
1133 if trend > 0.05
1135 && load.combined_load > 0.5
1136 && load.combined_load < config.overload_threshold
1137 {
1138 Some((i, trend, load.combined_load))
1139 } else {
1140 None
1141 }
1142 })
1143 .collect()
1144 }; let max_migrations = self.migration_config.read().max_pending_migrations;
1148 for (device_index, _trend, _combined_load) in candidates {
1149 if results.len() >= max_migrations {
1151 break;
1152 }
1153
1154 if let Some(target) = self.find_migration_target(device_index)? {
1156 if let Some(workload) = self.select_workload_for_migration(device_index) {
1157 let plan = self.create_migration_plan(workload, target)?;
1158 if plan.should_migrate() {
1159 let result = self.execute_migration(&plan)?;
1160 results.push(result);
1161 }
1162 }
1163 }
1164 }
1165
1166 Ok(())
1167 }
1168
1169 pub fn register_workload(
1171 &self,
1172 device_index: usize,
1173 memory_size: u64,
1174 compute_intensity: f32,
1175 priority: u32,
1176 ) -> Result<u64> {
1177 if device_index >= self.devices.len() {
1178 return Err(GpuAdvancedError::InvalidGpuIndex {
1179 index: device_index,
1180 total: self.devices.len(),
1181 });
1182 }
1183
1184 let mut tracker = self.workload_tracker.write();
1185 let workload_id = tracker.next_workload_id();
1186
1187 let workload = MigratableWorkload::new(
1188 workload_id,
1189 device_index,
1190 memory_size,
1191 compute_intensity,
1192 priority,
1193 );
1194
1195 tracker.add_workload(device_index, workload);
1196
1197 Ok(workload_id)
1198 }
1199
1200 pub fn unregister_workload(&self, device_index: usize, workload_id: u64) -> Result<()> {
1202 if device_index >= self.devices.len() {
1203 return Err(GpuAdvancedError::InvalidGpuIndex {
1204 index: device_index,
1205 total: self.devices.len(),
1206 });
1207 }
1208
1209 let mut tracker = self.workload_tracker.write();
1210 tracker.remove_workload(device_index, workload_id);
1211
1212 Ok(())
1213 }
1214
1215 pub fn task_started(&self, device_index: usize) {
1217 let mut stats = self.stats.write();
1218 if let Some(count) = stats.tasks_per_device.get_mut(device_index) {
1219 *count = count.saturating_add(1);
1220 }
1221 if let Some(active) = stats.active_tasks.get_mut(device_index) {
1222 *active = active.saturating_add(1);
1223 }
1224 }
1225
1226 pub fn task_completed(&self, device_index: usize, duration_us: u64) {
1228 let mut stats = self.stats.write();
1229 if let Some(active) = stats.active_tasks.get_mut(device_index) {
1230 *active = active.saturating_sub(1);
1231 }
1232 if let Some(time) = stats.time_per_device.get_mut(device_index) {
1233 *time = time.saturating_add(duration_us);
1234 }
1235 }
1236
1237 pub fn get_stats(&self) -> LoadStats {
1239 self.stats.read().clone()
1240 }
1241
1242 pub fn print_stats(&self) {
1244 let stats = self.stats.read();
1245 println!("\nLoad Balancer Statistics:");
1246 println!(" Strategy: {:?}", self.strategy);
1247
1248 for (i, device) in self.devices.iter().enumerate() {
1249 let tasks = stats.tasks_per_device.get(i).copied().unwrap_or(0);
1250 let time_us = stats.time_per_device.get(i).copied().unwrap_or(0);
1251 let active = stats.active_tasks.get(i).copied().unwrap_or(0);
1252 let avg_time_us = if tasks > 0 {
1253 time_us / (tasks as u64)
1254 } else {
1255 0
1256 };
1257
1258 let migrations_from = stats.migrations_from.get(i).copied().unwrap_or(0);
1259 let migrations_to = stats.migrations_to.get(i).copied().unwrap_or(0);
1260
1261 println!("\n GPU {}: {}", i, device.info.name);
1262 println!(" Total tasks: {}", tasks);
1263 println!(" Active tasks: {}", active);
1264 println!(" Total time: {} ms", time_us / 1000);
1265 println!(" Avg task time: {} us", avg_time_us);
1266 println!(
1267 " Current workload: {:.1}%",
1268 device.get_workload() * 100.0
1269 );
1270 println!(" Migrations from: {}", migrations_from);
1271 println!(" Migrations to: {}", migrations_to);
1272 }
1273 }
1274
1275 pub fn reset_stats(&self) {
1277 let mut stats = self.stats.write();
1278 let device_count = self.devices.len();
1279 stats.tasks_per_device = vec![0; device_count];
1280 stats.time_per_device = vec![0; device_count];
1281 stats.active_tasks = vec![0; device_count];
1282 stats.memory_per_device = vec![0; device_count];
1283 stats.migrations_from = vec![0; device_count];
1284 stats.migrations_to = vec![0; device_count];
1285 }
1286
1287 pub fn get_device_utilization(&self, device_index: usize) -> f32 {
1289 self.devices
1290 .get(device_index)
1291 .map(|device| device.get_workload())
1292 .unwrap_or(0.0)
1293 }
1294
1295 pub fn get_cluster_utilization(&self) -> f32 {
1297 if self.devices.is_empty() {
1298 return 0.0;
1299 }
1300
1301 let total_utilization: f32 = self
1302 .devices
1303 .iter()
1304 .map(|device| device.get_workload())
1305 .sum();
1306
1307 total_utilization / (self.devices.len() as f32)
1308 }
1309
1310 pub fn suggest_device(&self, estimated_memory: u64) -> Result<Arc<GpuDevice>> {
1312 let candidates: Vec<_> = self
1314 .devices
1315 .iter()
1316 .filter(|device| {
1317 let memory_usage = device.get_memory_usage();
1318 let max_memory = device.info.max_buffer_size;
1319 (max_memory - memory_usage) >= estimated_memory
1320 })
1321 .collect();
1322
1323 if candidates.is_empty() {
1324 return Err(GpuAdvancedError::GpuNotFound(
1325 "No device with enough memory".to_string(),
1326 ));
1327 }
1328
1329 self.select_device()
1331 }
1332
1333 pub fn get_migration_stats(&self) -> (usize, usize, f64) {
1335 let history = self.migration_history.read();
1336 (
1337 history.total_successful,
1338 history.total_failed,
1339 history.overall_success_rate(),
1340 )
1341 }
1342
1343 pub fn device_count(&self) -> usize {
1345 self.devices.len()
1346 }
1347}
1348
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `MigrationHistoryEntry` timestamped "now" — shared by the
    /// history-related tests below.
    fn entry(
        src: usize,
        dst: usize,
        success: bool,
        millis: u64,
        bytes: u64,
    ) -> MigrationHistoryEntry {
        MigrationHistoryEntry {
            timestamp: Instant::now(),
            source_device: src,
            target_device: dst,
            success,
            transfer_time: Duration::from_millis(millis),
            bytes_transferred: bytes,
        }
    }

    #[test]
    fn test_load_stats() {
        // A default LoadStats carries no per-device vectors yet.
        let stats = LoadStats::default();
        assert!(stats.tasks_per_device.is_empty());
    }

    #[test]
    fn test_selection_strategy() {
        // The strategy value must be freely movable/copyable.
        let strategy = SelectionStrategy::RoundRobin;
        let _duplicate = strategy;
    }

    #[test]
    fn test_migration_config_default() {
        let cfg = MigrationConfig::default();
        // Thresholds must be sane fractions with under < over.
        assert!(cfg.overload_threshold > 0.0);
        assert!(cfg.overload_threshold <= 1.0);
        assert!(cfg.underutilization_threshold >= 0.0);
        assert!(cfg.underutilization_threshold < cfg.overload_threshold);
    }

    #[test]
    fn test_migratable_workload() {
        let w = MigratableWorkload::new(1, 0, 1024 * 1024, 0.5, 10);
        assert_eq!(w.id, 1);
        assert_eq!(w.source_device, 0);
        assert_eq!(w.memory_size, 1024 * 1024);
        assert!(!w.migrating);

        // Attaching a dependency yields exactly one entry.
        let with_dep = w.with_dependency(0);
        assert_eq!(with_dep.dependencies.len(), 1);
    }

    #[test]
    fn test_migration_cost_calculation() {
        let cfg = MigrationConfig::default();
        let w = MigratableWorkload::new(1, 0, 1024 * 1024, 0.5, 10);

        // Per-byte cost on a 1 MiB payload must push the total above base.
        assert!(w.calculate_migration_cost(&cfg) > cfg.transfer_cost_base);
    }

    #[test]
    fn test_migration_plan() {
        let cfg = MigrationConfig::default();
        let w = MigratableWorkload::new(1, 0, 1024 * 1024, 0.8, 10);
        let plan = MigrationPlan::new(w, 1, &cfg);

        assert_eq!(plan.source_device, 0);
        assert_eq!(plan.target_device, 1);
        assert!(plan.estimated_cost > 0.0);
    }

    #[test]
    fn test_migration_history() {
        let mut history = MigrationHistory::new(10);

        // One success -> 100% success rate.
        history.add_entry(entry(0, 1, true, 10, 1024));
        assert_eq!(history.total_successful, 1);
        assert_eq!(history.total_failed, 0);
        assert!((history.overall_success_rate() - 1.0).abs() < f64::EPSILON);

        // One failure added -> 50%.
        history.add_entry(entry(0, 1, false, 5, 0));
        assert_eq!(history.total_failed, 1);
        assert!((history.overall_success_rate() - 0.5).abs() < f64::EPSILON);
    }

    #[test]
    fn test_workload_tracker() {
        let mut tracker = WorkloadTracker::new(2, 100);

        // Generated ids must be unique.
        let first = tracker.next_workload_id();
        let second = tracker.next_workload_id();
        assert_ne!(first, second);

        // Add then remove a workload; the pending count tracks both.
        tracker.add_workload(0, MigratableWorkload::new(first, 0, 1024, 0.5, 10));
        assert_eq!(tracker.pending_count(0), 1);

        assert!(tracker.remove_workload(0, first).is_some());
        assert_eq!(tracker.pending_count(0), 0);
    }

    #[test]
    fn test_utilization_sample() {
        let mut tracker = WorkloadTracker::new(2, 100);

        // Feed ten strictly increasing samples.
        for i in 0..10 {
            tracker.record_sample(
                0,
                UtilizationSample {
                    timestamp: Instant::now(),
                    compute: 0.1 * (i as f32),
                    memory: 0.05 * (i as f32),
                    active_tasks: i,
                },
            );
        }

        let (compute_avg, memory_avg) = tracker
            .average_utilization(0, 5)
            .expect("Should have samples");
        assert!(compute_avg > 0.0);
        assert!(memory_avg > 0.0);

        // Increasing samples should register as an upward trend.
        let trend = tracker.utilization_trend(0, 10).expect("Should have trend");
        assert!(trend > 0.0);
    }

    #[test]
    fn test_device_load() {
        let cfg = MigrationConfig::default();

        // A heavily loaded device sits above the overload threshold.
        let hot = DeviceLoad {
            device_index: 0,
            compute_utilization: 0.9,
            memory_utilization: 0.5,
            combined_load: 0.85,
            active_tasks: 5,
            pending_workloads: 3,
            score: 0.7,
            trend: 0.1,
        };
        assert!(hot.is_overloaded(&cfg));
        assert!(!hot.is_underutilized(&cfg));

        // An idle device sits below the underutilization threshold.
        let idle = DeviceLoad {
            device_index: 1,
            compute_utilization: 0.1,
            memory_utilization: 0.1,
            combined_load: 0.1,
            active_tasks: 0,
            pending_workloads: 0,
            score: 0.9,
            trend: -0.05,
        };
        assert!(!idle.is_overloaded(&cfg));
        assert!(idle.is_underutilized(&cfg));
    }

    #[test]
    fn test_migration_history_average_time() {
        let mut history = MigrationHistory::new(10);

        history.add_entry(entry(0, 1, true, 10, 1024));
        history.add_entry(entry(0, 1, true, 20, 2048));

        // Mean of 10 ms and 20 ms on the recorded route.
        let avg = history
            .average_transfer_time(0, 1)
            .expect("Should have average");
        assert_eq!(avg, Duration::from_millis(15));

        // The reverse route has no recorded transfers.
        assert!(history.average_transfer_time(1, 0).is_none());
    }

    #[test]
    fn test_migration_history_success_rate() {
        let mut history = MigrationHistory::new(10);

        // A route with no entries is treated as fully successful.
        assert!((history.success_rate(0, 1) - 1.0).abs() < f64::EPSILON);

        // Three successes and one failure -> 75%.
        for _ in 0..3 {
            history.add_entry(entry(0, 1, true, 10, 1024));
        }
        history.add_entry(entry(0, 1, false, 5, 0));

        let rate = history.success_rate(0, 1);
        assert!((rate - 0.75).abs() < f64::EPSILON);
    }

    #[test]
    fn test_workload_tracker_cooldown() {
        let mut tracker = WorkloadTracker::new(2, 100);

        // Nothing has been rebalanced yet.
        assert!(!tracker.is_in_cooldown(0, 1));

        tracker.update_rebalance_time(0);

        // Freshly rebalanced device is cooling down for a 1s window...
        assert!(tracker.is_in_cooldown(0, 1));
        // ...but a zero-second cooldown expires immediately.
        assert!(!tracker.is_in_cooldown(0, 0));
    }

    #[test]
    fn test_workload_tracker_migratable() {
        let mut tracker = WorkloadTracker::new(2, 100);

        // Plain workload: eligible for migration.
        tracker.add_workload(0, MigratableWorkload::new(0, 0, 1024, 0.5, 10));

        // Already mid-migration: excluded.
        let mut in_flight = MigratableWorkload::new(1, 0, 2048, 0.7, 5);
        in_flight.migrating = true;
        tracker.add_workload(0, in_flight);

        // Carries a dependency: excluded.
        tracker.add_workload(0, MigratableWorkload::new(2, 0, 4096, 0.3, 15).with_dependency(0));

        let eligible = tracker.get_migratable_workloads(0);
        assert_eq!(eligible.len(), 1);
        assert_eq!(eligible[0].id, 0);
    }

    #[test]
    fn test_utilization_trend_calculation() {
        // Monotonically increasing samples -> positive trend.
        let mut rising = WorkloadTracker::new(1, 100);
        for i in 0..20 {
            rising.record_sample(
                0,
                UtilizationSample {
                    timestamp: Instant::now(),
                    compute: 0.05 * (i as f32),
                    memory: 0.02 * (i as f32),
                    active_tasks: i,
                },
            );
        }
        let up = rising.utilization_trend(0, 20).expect("Should compute trend");
        assert!(up > 0.0, "Trend should be positive for increasing samples");

        // Monotonically decreasing samples -> negative trend.
        let mut falling = WorkloadTracker::new(1, 100);
        for i in 0..20 {
            falling.record_sample(
                0,
                UtilizationSample {
                    timestamp: Instant::now(),
                    compute: 1.0 - 0.05 * (i as f32),
                    memory: 0.5 - 0.02 * (i as f32),
                    active_tasks: 20 - i,
                },
            );
        }
        let down = falling
            .utilization_trend(0, 20)
            .expect("Should compute trend");
        assert!(down < 0.0, "Trend should be negative for decreasing samples");
    }
}