// oxigdal_gpu_advanced/multi_gpu/load_balancer.rs
//! Load balancing and workload migration across multiple GPUs.
//!
//! This module provides advanced load balancing capabilities including:
//! - GPU utilization monitoring
//! - Workload migration between devices
//! - Data transfer cost estimation
//! - Multiple load balancing strategies

9use super::{GpuDevice, SelectionStrategy};
10use crate::error::{GpuAdvancedError, Result};
11use parking_lot::RwLock;
12use std::cmp::Ordering;
13use std::collections::VecDeque;
14use std::sync::Arc;
15use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering as AtomicOrdering};
16use std::time::{Duration, Instant};
17
/// Load balancer for distributing work across GPUs.
///
/// All mutable state is behind `Arc<RwLock<..>>` so a `LoadBalancer` can be
/// shared across threads; device selection only needs `&self`.
pub struct LoadBalancer {
    /// Available devices
    devices: Vec<Arc<GpuDevice>>,
    /// Selection strategy
    strategy: SelectionStrategy,
    /// Round-robin counter; only its value modulo the device count is used
    rr_counter: AtomicUsize,
    /// Load statistics (vectors indexed by device index)
    stats: Arc<RwLock<LoadStats>>,
    /// Migration configuration
    migration_config: Arc<RwLock<MigrationConfig>>,
    /// Migration history for adaptive decisions
    migration_history: Arc<RwLock<MigrationHistory>>,
    /// Workload tracker per device
    workload_tracker: Arc<RwLock<WorkloadTracker>>,
}
35
/// Load balancing statistics.
///
/// Each vector is indexed by device index and is sized to the device count
/// when the `LoadBalancer` is constructed.
#[derive(Debug, Clone, Default)]
pub struct LoadStats {
    /// Total tasks assigned per device
    pub tasks_per_device: Vec<usize>,
    /// Total execution time per device (microseconds)
    pub time_per_device: Vec<u64>,
    /// Current active tasks per device
    pub active_tasks: Vec<usize>,
    /// Memory usage per device (bytes)
    pub memory_per_device: Vec<u64>,
    /// Migration count per device (as source)
    pub migrations_from: Vec<usize>,
    /// Migration count per device (as destination)
    pub migrations_to: Vec<usize>,
}
52
/// Configuration for workload migration decisions.
///
/// `compute_weight` and `memory_weight` are combined linearly when computing
/// a device's overall load (see `LoadBalancer::get_device_loads`).
#[derive(Debug, Clone)]
pub struct MigrationConfig {
    /// Utilization threshold above which a GPU is considered overloaded (0.0 to 1.0)
    pub overload_threshold: f32,
    /// Utilization threshold below which a GPU is considered underutilized (0.0 to 1.0)
    pub underutilization_threshold: f32,
    /// Minimum utilization difference to trigger migration
    pub min_imbalance_threshold: f32,
    /// Base cost for data transfer (in arbitrary units representing time)
    pub transfer_cost_base: f64,
    /// Cost per byte transferred (in arbitrary units)
    pub transfer_cost_per_byte: f64,
    /// Minimum workload size to consider for migration (bytes)
    pub min_migration_size: u64,
    /// Maximum pending migrations per device
    pub max_pending_migrations: usize,
    /// Cooldown period between migrations for same device (seconds)
    pub migration_cooldown_secs: u64,
    /// Whether to enable predictive migration based on trends
    pub enable_predictive_migration: bool,
    /// History window size for trend analysis (also used as the sample and
    /// migration history capacity in `LoadBalancer::new`)
    pub history_window_size: usize,
    /// Weight for memory pressure in migration decisions (0.0 to 1.0)
    pub memory_weight: f32,
    /// Weight for compute utilization in migration decisions (0.0 to 1.0)
    pub compute_weight: f32,
}
81
82impl Default for MigrationConfig {
83    fn default() -> Self {
84        Self {
85            overload_threshold: 0.8,
86            underutilization_threshold: 0.3,
87            min_imbalance_threshold: 0.2,
88            transfer_cost_base: 1.0,
89            transfer_cost_per_byte: 0.000001, // 1 microsecond per megabyte
90            min_migration_size: 1024,         // 1 KB minimum
91            max_pending_migrations: 4,
92            migration_cooldown_secs: 5,
93            enable_predictive_migration: true,
94            history_window_size: 100,
95            memory_weight: 0.4,
96            compute_weight: 0.6,
97        }
98    }
99}
100
/// Represents a migratable workload.
#[derive(Debug, Clone)]
pub struct MigratableWorkload {
    /// Unique identifier for the workload
    pub id: u64,
    /// Source device index
    pub source_device: usize,
    /// Estimated memory footprint in bytes
    pub memory_size: u64,
    /// Estimated compute intensity (0.0 to 1.0)
    pub compute_intensity: f32,
    /// Priority level (higher = more important)
    pub priority: u32,
    /// Creation timestamp
    pub created_at: Instant,
    /// Whether this workload is currently being migrated; workloads with this
    /// flag set are excluded from migration candidates
    pub migrating: bool,
    /// Data dependencies (other workload IDs this depends on); a workload is
    /// only considered migratable when this list is empty
    pub dependencies: Vec<u64>,
}
121
122impl MigratableWorkload {
123    /// Create a new migratable workload
124    pub fn new(
125        id: u64,
126        source_device: usize,
127        memory_size: u64,
128        compute_intensity: f32,
129        priority: u32,
130    ) -> Self {
131        Self {
132            id,
133            source_device,
134            memory_size,
135            compute_intensity,
136            priority,
137            created_at: Instant::now(),
138            migrating: false,
139            dependencies: Vec::new(),
140        }
141    }
142
143    /// Add a dependency to this workload
144    pub fn with_dependency(mut self, dep_id: u64) -> Self {
145        self.dependencies.push(dep_id);
146        self
147    }
148
149    /// Calculate migration cost based on configuration
150    pub fn calculate_migration_cost(&self, config: &MigrationConfig) -> f64 {
151        config.transfer_cost_base
152            + (self.memory_size as f64 * config.transfer_cost_per_byte)
153            + (self.compute_intensity as f64 * 0.1) // Compute intensity penalty
154    }
155}
156
/// A planned migration operation.
#[derive(Debug, Clone)]
pub struct MigrationPlan {
    /// The workload to migrate
    pub workload: MigratableWorkload,
    /// Source device index (copied from the workload at plan creation)
    pub source_device: usize,
    /// Destination device index
    pub target_device: usize,
    /// Estimated transfer cost
    pub estimated_cost: f64,
    /// Expected benefit (load reduction on source)
    pub expected_benefit: f64,
    /// Net benefit (benefit - cost)
    pub net_benefit: f64,
    /// Plan creation timestamp
    pub created_at: Instant,
    /// Whether the plan is approved for execution; initialized to
    /// `net_benefit > 0.0` by `MigrationPlan::new`
    pub approved: bool,
}
177
178impl MigrationPlan {
179    /// Create a new migration plan
180    pub fn new(
181        workload: MigratableWorkload,
182        target_device: usize,
183        config: &MigrationConfig,
184    ) -> Self {
185        let source_device = workload.source_device;
186        let estimated_cost = workload.calculate_migration_cost(config);
187        let expected_benefit = workload.compute_intensity as f64 * 10.0; // Arbitrary benefit scale
188        let net_benefit = expected_benefit - estimated_cost;
189
190        Self {
191            workload,
192            source_device,
193            target_device,
194            estimated_cost,
195            expected_benefit,
196            net_benefit,
197            created_at: Instant::now(),
198            approved: net_benefit > 0.0,
199        }
200    }
201
202    /// Check if migration should proceed
203    pub fn should_migrate(&self) -> bool {
204        self.approved && self.net_benefit > 0.0
205    }
206}
207
/// Result of a migration operation.
#[derive(Debug, Clone)]
pub struct MigrationResult {
    /// Whether the migration succeeded
    pub success: bool,
    /// Source device index
    pub source_device: usize,
    /// Target device index
    pub target_device: usize,
    /// Workload ID that was migrated
    pub workload_id: u64,
    /// Actual transfer time
    pub transfer_time: Duration,
    /// Bytes transferred
    pub bytes_transferred: u64,
    /// Error message if failed; `None` on success
    pub error_message: Option<String>,
}
226
/// History of migrations for adaptive decisions.
///
/// NOTE(review): the derived `Default` yields `max_size == 0`, under which
/// `add_entry` pops before every push and the history retains at most one
/// entry. Prefer `MigrationHistory::new(n)` — confirm whether `Default` is
/// used anywhere before relying on it.
#[derive(Debug, Default)]
pub struct MigrationHistory {
    /// Recent migration results, oldest at the front
    entries: VecDeque<MigrationHistoryEntry>,
    /// Maximum history size (entries beyond this are dropped oldest-first)
    max_size: usize,
    /// Total successful migrations (counted over the full lifetime, not
    /// just the retained window)
    total_successful: usize,
    /// Total failed migrations (lifetime count)
    total_failed: usize,
}
239
/// Single entry in migration history.
#[derive(Debug, Clone)]
pub struct MigrationHistoryEntry {
    /// Timestamp of the migration
    pub timestamp: Instant,
    /// Source device
    pub source_device: usize,
    /// Target device
    pub target_device: usize,
    /// Whether it succeeded
    pub success: bool,
    /// Transfer time
    pub transfer_time: Duration,
    /// Bytes transferred
    pub bytes_transferred: u64,
}
256
257impl MigrationHistory {
258    /// Create a new migration history
259    pub fn new(max_size: usize) -> Self {
260        Self {
261            entries: VecDeque::with_capacity(max_size),
262            max_size,
263            total_successful: 0,
264            total_failed: 0,
265        }
266    }
267
268    /// Add an entry to the history
269    pub fn add_entry(&mut self, entry: MigrationHistoryEntry) {
270        if entry.success {
271            self.total_successful += 1;
272        } else {
273            self.total_failed += 1;
274        }
275
276        if self.entries.len() >= self.max_size {
277            self.entries.pop_front();
278        }
279        self.entries.push_back(entry);
280    }
281
282    /// Get success rate for migrations between specific devices
283    pub fn success_rate(&self, source: usize, target: usize) -> f64 {
284        let filtered: Vec<_> = self
285            .entries
286            .iter()
287            .filter(|e| e.source_device == source && e.target_device == target)
288            .collect();
289
290        if filtered.is_empty() {
291            return 1.0; // Assume success if no history
292        }
293
294        let successful = filtered.iter().filter(|e| e.success).count();
295        successful as f64 / filtered.len() as f64
296    }
297
298    /// Get average transfer time for migrations between devices
299    pub fn average_transfer_time(&self, source: usize, target: usize) -> Option<Duration> {
300        let filtered: Vec<_> = self
301            .entries
302            .iter()
303            .filter(|e| e.source_device == source && e.target_device == target && e.success)
304            .collect();
305
306        if filtered.is_empty() {
307            return None;
308        }
309
310        let total: Duration = filtered.iter().map(|e| e.transfer_time).sum();
311        Some(total / filtered.len() as u32)
312    }
313
314    /// Get total bytes transferred
315    pub fn total_bytes_transferred(&self) -> u64 {
316        self.entries.iter().map(|e| e.bytes_transferred).sum()
317    }
318
319    /// Get overall success rate
320    pub fn overall_success_rate(&self) -> f64 {
321        let total = self.total_successful + self.total_failed;
322        if total == 0 {
323            return 1.0;
324        }
325        self.total_successful as f64 / total as f64
326    }
327}
328
/// Tracks workload distribution over time.
///
/// All per-device vectors are indexed by device index and created with one
/// slot per device in `WorkloadTracker::new`.
#[derive(Debug)]
pub struct WorkloadTracker {
    /// Per-device utilization samples, oldest at the front of each deque
    utilization_samples: Vec<VecDeque<UtilizationSample>>,
    /// Per-device pending workloads
    pending_workloads: Vec<Vec<MigratableWorkload>>,
    /// Global workload counter used to hand out unique IDs
    next_workload_id: AtomicU64,
    /// Last rebalance timestamp per device (None if never rebalanced)
    last_rebalance: Vec<Option<Instant>>,
}
341
/// A single utilization sample.
#[derive(Debug, Clone)]
pub struct UtilizationSample {
    /// Sample timestamp
    pub timestamp: Instant,
    /// Compute utilization (0.0 to 1.0); this is the value used for trend
    /// analysis in `WorkloadTracker::utilization_trend`
    pub compute: f32,
    /// Memory utilization (0.0 to 1.0)
    pub memory: f32,
    /// Active task count
    pub active_tasks: usize,
}
354
impl WorkloadTracker {
    /// Create a new workload tracker for N devices.
    ///
    /// `history_size` is the intended per-device utilization sample capacity
    /// (see the note on `record_sample`).
    pub fn new(device_count: usize, history_size: usize) -> Self {
        let mut utilization_samples = Vec::with_capacity(device_count);
        let mut pending_workloads = Vec::with_capacity(device_count);
        let mut last_rebalance = Vec::with_capacity(device_count);

        for _ in 0..device_count {
            utilization_samples.push(VecDeque::with_capacity(history_size));
            pending_workloads.push(Vec::new());
            // Initialize to None - devices start without cooldown since no rebalancing has happened
            last_rebalance.push(None);
        }

        Self {
            utilization_samples,
            pending_workloads,
            next_workload_id: AtomicU64::new(0),
            last_rebalance,
        }
    }

    /// Generate a new workload ID.
    ///
    /// `Relaxed` is sufficient: only uniqueness of the returned value matters,
    /// not ordering relative to other memory operations.
    pub fn next_workload_id(&self) -> u64 {
        self.next_workload_id.fetch_add(1, AtomicOrdering::Relaxed)
    }

    /// Record a utilization sample for a device.
    ///
    /// Out-of-range `device_index` is silently ignored.
    ///
    /// NOTE(review): the eviction bound is `samples.capacity()`, but
    /// `VecDeque::with_capacity` may allocate more than requested, so the
    /// retained history can exceed the `history_size` given to `new`.
    /// Storing the intended limit in the struct would make the bound exact.
    pub fn record_sample(&mut self, device_index: usize, sample: UtilizationSample) {
        if let Some(samples) = self.utilization_samples.get_mut(device_index) {
            if samples.len() >= samples.capacity() {
                samples.pop_front();
            }
            samples.push_back(sample);
        }
    }

    /// Get average (compute, memory) utilization for a device over the most
    /// recent `window` samples.
    ///
    /// Returns `None` for an unknown device or when no samples exist.
    pub fn average_utilization(&self, device_index: usize, window: usize) -> Option<(f32, f32)> {
        let samples = self.utilization_samples.get(device_index)?;
        if samples.is_empty() {
            return None;
        }

        // Clamp the window to the available sample count.
        let take_count = window.min(samples.len());
        let recent: Vec<_> = samples.iter().rev().take(take_count).collect();

        let avg_compute = recent.iter().map(|s| s.compute).sum::<f32>() / take_count as f32;
        let avg_memory = recent.iter().map(|s| s.memory).sum::<f32>() / take_count as f32;

        Some((avg_compute, avg_memory))
    }

    /// Get compute-utilization trend (positive = increasing, negative =
    /// decreasing) over the most recent `window` samples.
    ///
    /// Returns `None` for an unknown device or with fewer than two samples.
    pub fn utilization_trend(&self, device_index: usize, window: usize) -> Option<f32> {
        let samples = self.utilization_samples.get(device_index)?;
        if samples.len() < 2 {
            return None;
        }

        let take_count = window.min(samples.len());
        // Get recent samples in chronological order (oldest first, newest last)
        // Skip older samples and take the most recent ones
        let skip_count = samples.len().saturating_sub(take_count);
        let recent: Vec<_> = samples.iter().skip(skip_count).collect();

        if recent.len() < 2 {
            return None;
        }

        // Simple linear regression slope
        // x = 0 is oldest, x = n-1 is newest
        // Positive slope means utilization is increasing over time
        let n = recent.len() as f32;
        let mut sum_x = 0.0f32;
        let mut sum_y = 0.0f32;
        let mut sum_xy = 0.0f32;
        let mut sum_xx = 0.0f32;

        for (i, sample) in recent.iter().enumerate() {
            let x = i as f32;
            let y = sample.compute;
            sum_x += x;
            sum_y += y;
            sum_xy += x * y;
            sum_xx += x * x;
        }

        // Least-squares slope: (n·Σxy − Σx·Σy) / (n·Σxx − (Σx)²).
        let denominator = n * sum_xx - sum_x * sum_x;
        if denominator.abs() < f32::EPSILON {
            // Degenerate x spread — report a flat trend rather than dividing by ~0.
            return Some(0.0);
        }

        Some((n * sum_xy - sum_x * sum_y) / denominator)
    }

    /// Add a pending workload to a device; out-of-range indices are ignored.
    pub fn add_workload(&mut self, device_index: usize, workload: MigratableWorkload) {
        if let Some(workloads) = self.pending_workloads.get_mut(device_index) {
            workloads.push(workload);
        }
    }

    /// Remove a workload by ID, returning it if it was pending on that device.
    pub fn remove_workload(
        &mut self,
        device_index: usize,
        workload_id: u64,
    ) -> Option<MigratableWorkload> {
        if let Some(workloads) = self.pending_workloads.get_mut(device_index) {
            if let Some(pos) = workloads.iter().position(|w| w.id == workload_id) {
                return Some(workloads.remove(pos));
            }
        }
        None
    }

    /// Get migratable workloads from a device: those not already migrating
    /// and with an empty dependency list.
    pub fn get_migratable_workloads(&self, device_index: usize) -> Vec<&MigratableWorkload> {
        self.pending_workloads
            .get(device_index)
            .map(|workloads| {
                workloads
                    .iter()
                    .filter(|w| !w.migrating && w.dependencies.is_empty())
                    .collect()
            })
            .unwrap_or_default()
    }

    /// Stamp the device's last-rebalance time with "now", starting its
    /// cooldown period.
    pub fn update_rebalance_time(&mut self, device_index: usize) {
        if let Some(time) = self.last_rebalance.get_mut(device_index) {
            *time = Some(Instant::now());
        }
    }

    /// Check if device is in cooldown period.
    ///
    /// Returns false if no rebalancing has ever happened on this device
    /// (or if the device index is unknown).
    pub fn is_in_cooldown(&self, device_index: usize, cooldown_secs: u64) -> bool {
        self.last_rebalance
            .get(device_index)
            .and_then(|opt| opt.as_ref())
            .map(|t| t.elapsed().as_secs() < cooldown_secs)
            .unwrap_or(false)
    }

    /// Get pending workload count for a device (0 for unknown indices).
    pub fn pending_count(&self, device_index: usize) -> usize {
        self.pending_workloads
            .get(device_index)
            .map(|w| w.len())
            .unwrap_or(0)
    }
}
511
/// Device load information for balancing decisions.
#[derive(Debug, Clone)]
pub struct DeviceLoad {
    /// Device index
    pub device_index: usize,
    /// Current compute utilization (0.0 to 1.0)
    pub compute_utilization: f32,
    /// Current memory utilization (0.0 to 1.0)
    pub memory_utilization: f32,
    /// Combined load score: compute and memory utilization weighted by the
    /// active `MigrationConfig` weights
    pub combined_load: f32,
    /// Active task count
    pub active_tasks: usize,
    /// Pending workload count
    pub pending_workloads: usize,
    /// Device score (higher = better for new work)
    pub score: f32,
    /// Utilization trend (positive = increasing); 0.0 when no trend data
    pub trend: f32,
}
532
533impl DeviceLoad {
534    /// Check if device is overloaded
535    pub fn is_overloaded(&self, config: &MigrationConfig) -> bool {
536        self.combined_load > config.overload_threshold
537    }
538
539    /// Check if device is underutilized
540    pub fn is_underutilized(&self, config: &MigrationConfig) -> bool {
541        self.combined_load < config.underutilization_threshold
542    }
543}
544
545impl LoadBalancer {
546    /// Create a new load balancer
547    pub fn new(devices: Vec<Arc<GpuDevice>>, strategy: SelectionStrategy) -> Self {
548        let device_count = devices.len();
549        let stats = LoadStats {
550            tasks_per_device: vec![0; device_count],
551            time_per_device: vec![0; device_count],
552            active_tasks: vec![0; device_count],
553            memory_per_device: vec![0; device_count],
554            migrations_from: vec![0; device_count],
555            migrations_to: vec![0; device_count],
556        };
557
558        let config = MigrationConfig::default();
559        let tracker = WorkloadTracker::new(device_count, config.history_window_size);
560        let history = MigrationHistory::new(config.history_window_size);
561
562        Self {
563            devices,
564            strategy,
565            rr_counter: AtomicUsize::new(0),
566            stats: Arc::new(RwLock::new(stats)),
567            migration_config: Arc::new(RwLock::new(config)),
568            migration_history: Arc::new(RwLock::new(history)),
569            workload_tracker: Arc::new(RwLock::new(tracker)),
570        }
571    }
572
573    /// Get the migration configuration
574    pub fn migration_config(&self) -> MigrationConfig {
575        self.migration_config.read().clone()
576    }
577
578    /// Update migration configuration
579    pub fn set_migration_config(&self, config: MigrationConfig) {
580        *self.migration_config.write() = config;
581    }
582
    /// Select a device using the configured strategy.
    ///
    /// # Errors
    /// Returns `GpuAdvancedError::GpuNotFound` when no devices are registered.
    pub fn select_device(&self) -> Result<Arc<GpuDevice>> {
        if self.devices.is_empty() {
            return Err(GpuAdvancedError::GpuNotFound(
                "No devices available".to_string(),
            ));
        }

        // Dispatch to the strategy-specific selector; the selectors assume
        // `devices` is non-empty (checked above).
        match self.strategy {
            SelectionStrategy::RoundRobin => self.select_round_robin(),
            SelectionStrategy::LeastLoaded => self.select_least_loaded(),
            SelectionStrategy::BestScore => self.select_best_score(),
            SelectionStrategy::Affinity => self.select_affinity(),
        }
    }
598
599    /// Round-robin selection
600    fn select_round_robin(&self) -> Result<Arc<GpuDevice>> {
601        let index = self.rr_counter.fetch_add(1, AtomicOrdering::Relaxed) % self.devices.len();
602        self.devices
603            .get(index)
604            .cloned()
605            .ok_or(GpuAdvancedError::InvalidGpuIndex {
606                index,
607                total: self.devices.len(),
608            })
609    }
610
611    /// Select least loaded device
612    fn select_least_loaded(&self) -> Result<Arc<GpuDevice>> {
613        let stats = self.stats.read();
614
615        let (index, _) = self
616            .devices
617            .iter()
618            .enumerate()
619            .map(|(i, device)| {
620                let active_tasks = stats.active_tasks.get(i).copied().unwrap_or(0);
621                let workload = device.get_workload();
622                let load = (active_tasks as f32) + workload;
623                (i, load)
624            })
625            .min_by(|(_, load_a), (_, load_b)| {
626                load_a.partial_cmp(load_b).unwrap_or(Ordering::Equal)
627            })
628            .ok_or_else(|| {
629                GpuAdvancedError::LoadBalancingError("No device available".to_string())
630            })?;
631
632        self.devices
633            .get(index)
634            .cloned()
635            .ok_or(GpuAdvancedError::InvalidGpuIndex {
636                index,
637                total: self.devices.len(),
638            })
639    }
640
641    /// Select device with best score
642    fn select_best_score(&self) -> Result<Arc<GpuDevice>> {
643        let (index, _) = self
644            .devices
645            .iter()
646            .enumerate()
647            .map(|(i, device)| (i, device.get_score()))
648            .max_by(|(_, score_a), (_, score_b)| {
649                score_a.partial_cmp(score_b).unwrap_or(Ordering::Equal)
650            })
651            .ok_or_else(|| {
652                GpuAdvancedError::LoadBalancingError("No device available".to_string())
653            })?;
654
655        self.devices
656            .get(index)
657            .cloned()
658            .ok_or(GpuAdvancedError::InvalidGpuIndex {
659                index,
660                total: self.devices.len(),
661            })
662    }
663
664    /// Select device using affinity (prefers previously used device)
665    fn select_affinity(&self) -> Result<Arc<GpuDevice>> {
666        // For now, use thread-local affinity based on thread ID
667        let thread_id = std::thread::current().id();
668        let hash = {
669            use std::collections::hash_map::DefaultHasher;
670            use std::hash::{Hash, Hasher};
671            let mut hasher = DefaultHasher::new();
672            thread_id.hash(&mut hasher);
673            hasher.finish()
674        };
675
676        let index = (hash as usize) % self.devices.len();
677        self.devices
678            .get(index)
679            .cloned()
680            .ok_or(GpuAdvancedError::InvalidGpuIndex {
681                index,
682                total: self.devices.len(),
683            })
684    }
685
686    /// Select device using weighted strategy based on device performance
687    pub fn select_weighted(&self) -> Result<Arc<GpuDevice>> {
688        if self.devices.is_empty() {
689            return Err(GpuAdvancedError::GpuNotFound(
690                "No devices available".to_string(),
691            ));
692        }
693
694        let config = self.migration_config.read();
695
696        // Calculate weighted scores for each device
697        let mut best_index = 0;
698        let mut best_score = f32::MIN;
699
700        for (i, device) in self.devices.iter().enumerate() {
701            let compute_util = device.get_workload();
702            let memory_usage = device.get_memory_usage();
703            let max_memory = device.info.max_buffer_size;
704            let memory_util = if max_memory > 0 {
705                memory_usage as f32 / max_memory as f32
706            } else {
707                0.0
708            };
709
710            // Weighted combination of factors
711            let availability =
712                1.0 - (compute_util * config.compute_weight + memory_util * config.memory_weight);
713            let type_bonus = device.get_score();
714            let score = availability * type_bonus;
715
716            if score > best_score {
717                best_score = score;
718                best_index = i;
719            }
720        }
721
722        self.devices
723            .get(best_index)
724            .cloned()
725            .ok_or(GpuAdvancedError::InvalidGpuIndex {
726                index: best_index,
727                total: self.devices.len(),
728            })
729    }
730
731    /// Get current load information for all devices
732    pub fn get_device_loads(&self) -> Vec<DeviceLoad> {
733        let config = self.migration_config.read();
734        let tracker = self.workload_tracker.read();
735        let stats = self.stats.read();
736
737        self.devices
738            .iter()
739            .enumerate()
740            .map(|(i, device)| {
741                let compute_utilization = device.get_workload();
742                let memory_usage = device.get_memory_usage();
743                let max_memory = device.info.max_buffer_size;
744                let memory_utilization = if max_memory > 0 {
745                    memory_usage as f32 / max_memory as f32
746                } else {
747                    0.0
748                };
749
750                let combined_load = compute_utilization * config.compute_weight
751                    + memory_utilization * config.memory_weight;
752
753                let trend = tracker.utilization_trend(i, 10).unwrap_or(0.0);
754
755                DeviceLoad {
756                    device_index: i,
757                    compute_utilization,
758                    memory_utilization,
759                    combined_load,
760                    active_tasks: stats.active_tasks.get(i).copied().unwrap_or(0),
761                    pending_workloads: tracker.pending_count(i),
762                    score: device.get_score(),
763                    trend,
764                }
765            })
766            .collect()
767    }
768
769    /// Identify overloaded devices
770    pub fn identify_overloaded_devices(&self) -> Vec<DeviceLoad> {
771        let config = self.migration_config.read();
772        self.get_device_loads()
773            .into_iter()
774            .filter(|load| load.is_overloaded(&config))
775            .collect()
776    }
777
778    /// Identify underutilized devices
779    pub fn identify_underutilized_devices(&self) -> Vec<DeviceLoad> {
780        let config = self.migration_config.read();
781        self.get_device_loads()
782            .into_iter()
783            .filter(|load| load.is_underutilized(&config))
784            .collect()
785    }
786
787    /// Check if load is imbalanced (requires rebalancing)
788    pub fn is_imbalanced(&self) -> bool {
789        let loads = self.get_device_loads();
790        if loads.len() < 2 {
791            return false;
792        }
793
794        let config = self.migration_config.read();
795
796        // Find max and min load
797        let max_load = loads
798            .iter()
799            .map(|l| l.combined_load)
800            .fold(f32::MIN, f32::max);
801        let min_load = loads
802            .iter()
803            .map(|l| l.combined_load)
804            .fold(f32::MAX, f32::min);
805
806        (max_load - min_load) > config.min_imbalance_threshold
807    }
808
809    /// Calculate data transfer cost between two devices
810    pub fn calculate_transfer_cost(
811        &self,
812        source_device: usize,
813        target_device: usize,
814        data_size: u64,
815    ) -> Result<f64> {
816        if source_device >= self.devices.len() || target_device >= self.devices.len() {
817            return Err(GpuAdvancedError::InvalidGpuIndex {
818                index: source_device.max(target_device),
819                total: self.devices.len(),
820            });
821        }
822
823        let config = self.migration_config.read();
824        let history = self.migration_history.read();
825
826        // Base cost from configuration
827        let mut cost =
828            config.transfer_cost_base + (data_size as f64 * config.transfer_cost_per_byte);
829
830        // Adjust based on historical transfer times
831        if let Some(avg_time) = history.average_transfer_time(source_device, target_device) {
832            // Scale cost based on historical performance
833            let time_factor = avg_time.as_secs_f64();
834            cost *= 1.0 + time_factor;
835        }
836
837        // Adjust for historical success rate
838        let success_rate = history.success_rate(source_device, target_device);
839        if success_rate < 1.0 {
840            // Increase cost for unreliable transfers
841            cost *= 1.0 + (1.0 - success_rate) * 0.5;
842        }
843
844        Ok(cost)
845    }
846
847    /// Create a migration plan for a workload
848    pub fn create_migration_plan(
849        &self,
850        workload: MigratableWorkload,
851        target_device: usize,
852    ) -> Result<MigrationPlan> {
853        if target_device >= self.devices.len() {
854            return Err(GpuAdvancedError::InvalidGpuIndex {
855                index: target_device,
856                total: self.devices.len(),
857            });
858        }
859
860        let config = self.migration_config.read();
861        let plan = MigrationPlan::new(workload, target_device, &config);
862
863        Ok(plan)
864    }
865
    /// Find best migration target for an overloaded device.
    ///
    /// Returns `Ok(None)` when no suitable target exists (no underutilized
    /// device outside its cooldown, or the load difference is too small to
    /// justify moving work).
    ///
    /// # Errors
    /// Returns `GpuAdvancedError::InvalidGpuIndex` when `source_device` is
    /// not a known device index.
    pub fn find_migration_target(&self, source_device: usize) -> Result<Option<usize>> {
        // Snapshot loads before taking the locks below; `get_device_loads`
        // acquires (and releases) its own read locks internally.
        let loads = self.get_device_loads();
        let config = self.migration_config.read();
        let tracker = self.workload_tracker.read();

        // Find the source load
        let source_load = loads
            .iter()
            .find(|l| l.device_index == source_device)
            .ok_or(GpuAdvancedError::InvalidGpuIndex {
                index: source_device,
                total: self.devices.len(),
            })?;

        // Find candidate targets (underutilized devices not in cooldown)
        let mut candidates: Vec<_> = loads
            .iter()
            .filter(|l| {
                l.device_index != source_device
                    && l.is_underutilized(&config)
                    && !tracker.is_in_cooldown(l.device_index, config.migration_cooldown_secs)
            })
            .collect();

        if candidates.is_empty() {
            return Ok(None);
        }

        // Sort by combined load (ascending) and score (descending);
        // non-comparable (NaN) loads fall back to the score comparison.
        candidates.sort_by(|a, b| match a.combined_load.partial_cmp(&b.combined_load) {
            Some(Ordering::Equal) | None => {
                b.score.partial_cmp(&a.score).unwrap_or(Ordering::Equal)
            }
            Some(ordering) => ordering,
        });

        // Return the best candidate only if migration would meaningfully
        // improve balance (difference must exceed the configured threshold).
        if let Some(best) = candidates.first() {
            let load_diff = source_load.combined_load - best.combined_load;
            if load_diff > config.min_imbalance_threshold {
                return Ok(Some(best.device_index));
            }
        }

        Ok(None)
    }
913
914    /// Select workload to migrate from an overloaded device
915    pub fn select_workload_for_migration(
916        &self,
917        source_device: usize,
918    ) -> Option<MigratableWorkload> {
919        let config = self.migration_config.read();
920        let tracker = self.workload_tracker.read();
921
922        let migratable = tracker.get_migratable_workloads(source_device);
923
924        // Filter by minimum size and sort by priority and compute intensity
925        let mut candidates: Vec<_> = migratable
926            .into_iter()
927            .filter(|w| w.memory_size >= config.min_migration_size)
928            .collect();
929
930        candidates.sort_by(|a, b| {
931            // Prefer higher priority and higher compute intensity
932            match b.priority.cmp(&a.priority) {
933                Ordering::Equal => b
934                    .compute_intensity
935                    .partial_cmp(&a.compute_intensity)
936                    .unwrap_or(Ordering::Equal),
937                other => other,
938            }
939        });
940
941        candidates.first().map(|w| (*w).clone())
942    }
943
944    /// Execute a migration (simulated - actual data transfer would use sync module)
945    pub fn execute_migration(&self, plan: &MigrationPlan) -> Result<MigrationResult> {
946        if !plan.should_migrate() {
947            return Ok(MigrationResult {
948                success: false,
949                source_device: plan.source_device,
950                target_device: plan.target_device,
951                workload_id: plan.workload.id,
952                transfer_time: Duration::ZERO,
953                bytes_transferred: 0,
954                error_message: Some("Migration not approved".to_string()),
955            });
956        }
957
958        let start = Instant::now();
959
960        // Update workload tracker
961        {
962            let mut tracker = self.workload_tracker.write();
963
964            // Remove from source
965            if tracker
966                .remove_workload(plan.source_device, plan.workload.id)
967                .is_none()
968            {
969                return Ok(MigrationResult {
970                    success: false,
971                    source_device: plan.source_device,
972                    target_device: plan.target_device,
973                    workload_id: plan.workload.id,
974                    transfer_time: Duration::ZERO,
975                    bytes_transferred: 0,
976                    error_message: Some("Workload not found on source device".to_string()),
977                });
978            }
979
980            // Add to target with updated source
981            let mut migrated = plan.workload.clone();
982            migrated.source_device = plan.target_device;
983            tracker.add_workload(plan.target_device, migrated);
984
985            // Update rebalance times
986            tracker.update_rebalance_time(plan.source_device);
987            tracker.update_rebalance_time(plan.target_device);
988        }
989
990        // Update statistics
991        {
992            let mut stats = self.stats.write();
993            if let Some(from) = stats.migrations_from.get_mut(plan.source_device) {
994                *from = from.saturating_add(1);
995            }
996            if let Some(to) = stats.migrations_to.get_mut(plan.target_device) {
997                *to = to.saturating_add(1);
998            }
999        }
1000
1001        let transfer_time = start.elapsed();
1002
1003        // Record in history
1004        {
1005            let mut history = self.migration_history.write();
1006            history.add_entry(MigrationHistoryEntry {
1007                timestamp: Instant::now(),
1008                source_device: plan.source_device,
1009                target_device: plan.target_device,
1010                success: true,
1011                transfer_time,
1012                bytes_transferred: plan.workload.memory_size,
1013            });
1014        }
1015
1016        Ok(MigrationResult {
1017            success: true,
1018            source_device: plan.source_device,
1019            target_device: plan.target_device,
1020            workload_id: plan.workload.id,
1021            transfer_time,
1022            bytes_transferred: plan.workload.memory_size,
1023            error_message: None,
1024        })
1025    }
1026
1027    /// Rebalance workloads across devices
1028    ///
1029    /// This method implements the core workload migration logic:
1030    /// 1. Monitor GPU utilization across all devices
1031    /// 2. Identify overloaded and underutilized GPUs
1032    /// 3. Calculate transfer costs for potential migrations
1033    /// 4. Execute migrations that improve overall balance
1034    pub fn rebalance(&self) -> Result<Vec<MigrationResult>> {
1035        // Check if rebalancing is needed
1036        if !self.is_imbalanced() {
1037            return Ok(Vec::new());
1038        }
1039
1040        let mut results = Vec::new();
1041
1042        // Sample current utilization for all devices
1043        self.sample_utilization();
1044
1045        // Identify overloaded devices
1046        let overloaded = self.identify_overloaded_devices();
1047        if overloaded.is_empty() {
1048            return Ok(results);
1049        }
1050
1051        // Process each overloaded device
1052        for source_load in overloaded {
1053            // Check migration limit
1054            if results.len() >= self.migration_config.read().max_pending_migrations {
1055                break;
1056            }
1057
1058            // Find a suitable target
1059            let target = match self.find_migration_target(source_load.device_index)? {
1060                Some(t) => t,
1061                None => continue,
1062            };
1063
1064            // Select workload to migrate
1065            let workload = match self.select_workload_for_migration(source_load.device_index) {
1066                Some(w) => w,
1067                None => continue,
1068            };
1069
1070            // Create and execute migration plan
1071            let plan = self.create_migration_plan(workload, target)?;
1072            if plan.should_migrate() {
1073                let result = self.execute_migration(&plan)?;
1074                results.push(result);
1075            }
1076        }
1077
1078        // Handle predictive migration if enabled
1079        let config = self.migration_config.read();
1080        if config.enable_predictive_migration {
1081            drop(config);
1082            self.handle_predictive_migrations(&mut results)?;
1083        }
1084
1085        Ok(results)
1086    }
1087
1088    /// Sample current utilization for all devices
1089    fn sample_utilization(&self) {
1090        let stats = self.stats.read();
1091        let mut tracker = self.workload_tracker.write();
1092
1093        for (i, device) in self.devices.iter().enumerate() {
1094            let compute = device.get_workload();
1095            let memory_usage = device.get_memory_usage();
1096            let max_memory = device.info.max_buffer_size;
1097            let memory = if max_memory > 0 {
1098                memory_usage as f32 / max_memory as f32
1099            } else {
1100                0.0
1101            };
1102
1103            tracker.record_sample(
1104                i,
1105                UtilizationSample {
1106                    timestamp: Instant::now(),
1107                    compute,
1108                    memory,
1109                    active_tasks: stats.active_tasks.get(i).copied().unwrap_or(0),
1110                },
1111            );
1112        }
1113    }
1114
1115    /// Handle predictive migrations based on utilization trends
1116    fn handle_predictive_migrations(&self, results: &mut Vec<MigrationResult>) -> Result<()> {
1117        // Collect device indices that need predictive migration
1118        let candidates: Vec<(usize, f32, f32)> = {
1119            let config = self.migration_config.read();
1120            let tracker = self.workload_tracker.read();
1121            let device_loads = self.get_device_loads();
1122
1123            self.devices
1124                .iter()
1125                .enumerate()
1126                .filter_map(|(i, _device)| {
1127                    // Check utilization trend
1128                    let trend = tracker.utilization_trend(i, 20)?;
1129
1130                    // Find load for this device
1131                    let load = device_loads.iter().find(|l| l.device_index == i)?;
1132
1133                    // If trend is strongly increasing and device is moderately loaded
1134                    if trend > 0.05
1135                        && load.combined_load > 0.5
1136                        && load.combined_load < config.overload_threshold
1137                    {
1138                        Some((i, trend, load.combined_load))
1139                    } else {
1140                        None
1141                    }
1142                })
1143                .collect()
1144        }; // Locks are released here
1145
1146        // Process candidates outside the lock
1147        let max_migrations = self.migration_config.read().max_pending_migrations;
1148        for (device_index, _trend, _combined_load) in candidates {
1149            // Check if we've hit the migration limit
1150            if results.len() >= max_migrations {
1151                break;
1152            }
1153
1154            // Preemptively migrate to prevent overload
1155            if let Some(target) = self.find_migration_target(device_index)? {
1156                if let Some(workload) = self.select_workload_for_migration(device_index) {
1157                    let plan = self.create_migration_plan(workload, target)?;
1158                    if plan.should_migrate() {
1159                        let result = self.execute_migration(&plan)?;
1160                        results.push(result);
1161                    }
1162                }
1163            }
1164        }
1165
1166        Ok(())
1167    }
1168
1169    /// Register a new workload on a device
1170    pub fn register_workload(
1171        &self,
1172        device_index: usize,
1173        memory_size: u64,
1174        compute_intensity: f32,
1175        priority: u32,
1176    ) -> Result<u64> {
1177        if device_index >= self.devices.len() {
1178            return Err(GpuAdvancedError::InvalidGpuIndex {
1179                index: device_index,
1180                total: self.devices.len(),
1181            });
1182        }
1183
1184        let mut tracker = self.workload_tracker.write();
1185        let workload_id = tracker.next_workload_id();
1186
1187        let workload = MigratableWorkload::new(
1188            workload_id,
1189            device_index,
1190            memory_size,
1191            compute_intensity,
1192            priority,
1193        );
1194
1195        tracker.add_workload(device_index, workload);
1196
1197        Ok(workload_id)
1198    }
1199
1200    /// Unregister a workload (completed or cancelled)
1201    pub fn unregister_workload(&self, device_index: usize, workload_id: u64) -> Result<()> {
1202        if device_index >= self.devices.len() {
1203            return Err(GpuAdvancedError::InvalidGpuIndex {
1204                index: device_index,
1205                total: self.devices.len(),
1206            });
1207        }
1208
1209        let mut tracker = self.workload_tracker.write();
1210        tracker.remove_workload(device_index, workload_id);
1211
1212        Ok(())
1213    }
1214
1215    /// Mark task started on device
1216    pub fn task_started(&self, device_index: usize) {
1217        let mut stats = self.stats.write();
1218        if let Some(count) = stats.tasks_per_device.get_mut(device_index) {
1219            *count = count.saturating_add(1);
1220        }
1221        if let Some(active) = stats.active_tasks.get_mut(device_index) {
1222            *active = active.saturating_add(1);
1223        }
1224    }
1225
1226    /// Mark task completed on device
1227    pub fn task_completed(&self, device_index: usize, duration_us: u64) {
1228        let mut stats = self.stats.write();
1229        if let Some(active) = stats.active_tasks.get_mut(device_index) {
1230            *active = active.saturating_sub(1);
1231        }
1232        if let Some(time) = stats.time_per_device.get_mut(device_index) {
1233            *time = time.saturating_add(duration_us);
1234        }
1235    }
1236
1237    /// Get load statistics
1238    pub fn get_stats(&self) -> LoadStats {
1239        self.stats.read().clone()
1240    }
1241
1242    /// Print load statistics
1243    pub fn print_stats(&self) {
1244        let stats = self.stats.read();
1245        println!("\nLoad Balancer Statistics:");
1246        println!("  Strategy: {:?}", self.strategy);
1247
1248        for (i, device) in self.devices.iter().enumerate() {
1249            let tasks = stats.tasks_per_device.get(i).copied().unwrap_or(0);
1250            let time_us = stats.time_per_device.get(i).copied().unwrap_or(0);
1251            let active = stats.active_tasks.get(i).copied().unwrap_or(0);
1252            let avg_time_us = if tasks > 0 {
1253                time_us / (tasks as u64)
1254            } else {
1255                0
1256            };
1257
1258            let migrations_from = stats.migrations_from.get(i).copied().unwrap_or(0);
1259            let migrations_to = stats.migrations_to.get(i).copied().unwrap_or(0);
1260
1261            println!("\n  GPU {}: {}", i, device.info.name);
1262            println!("    Total tasks: {}", tasks);
1263            println!("    Active tasks: {}", active);
1264            println!("    Total time: {} ms", time_us / 1000);
1265            println!("    Avg task time: {} us", avg_time_us);
1266            println!(
1267                "    Current workload: {:.1}%",
1268                device.get_workload() * 100.0
1269            );
1270            println!("    Migrations from: {}", migrations_from);
1271            println!("    Migrations to: {}", migrations_to);
1272        }
1273    }
1274
1275    /// Reset statistics
1276    pub fn reset_stats(&self) {
1277        let mut stats = self.stats.write();
1278        let device_count = self.devices.len();
1279        stats.tasks_per_device = vec![0; device_count];
1280        stats.time_per_device = vec![0; device_count];
1281        stats.active_tasks = vec![0; device_count];
1282        stats.memory_per_device = vec![0; device_count];
1283        stats.migrations_from = vec![0; device_count];
1284        stats.migrations_to = vec![0; device_count];
1285    }
1286
1287    /// Get device utilization (0.0 to 1.0)
1288    pub fn get_device_utilization(&self, device_index: usize) -> f32 {
1289        self.devices
1290            .get(device_index)
1291            .map(|device| device.get_workload())
1292            .unwrap_or(0.0)
1293    }
1294
1295    /// Get overall cluster utilization (0.0 to 1.0)
1296    pub fn get_cluster_utilization(&self) -> f32 {
1297        if self.devices.is_empty() {
1298            return 0.0;
1299        }
1300
1301        let total_utilization: f32 = self
1302            .devices
1303            .iter()
1304            .map(|device| device.get_workload())
1305            .sum();
1306
1307        total_utilization / (self.devices.len() as f32)
1308    }
1309
1310    /// Suggest optimal device for next task
1311    pub fn suggest_device(&self, estimated_memory: u64) -> Result<Arc<GpuDevice>> {
1312        // Filter devices with enough memory
1313        let candidates: Vec<_> = self
1314            .devices
1315            .iter()
1316            .filter(|device| {
1317                let memory_usage = device.get_memory_usage();
1318                let max_memory = device.info.max_buffer_size;
1319                (max_memory - memory_usage) >= estimated_memory
1320            })
1321            .collect();
1322
1323        if candidates.is_empty() {
1324            return Err(GpuAdvancedError::GpuNotFound(
1325                "No device with enough memory".to_string(),
1326            ));
1327        }
1328
1329        // Select based on strategy
1330        self.select_device()
1331    }
1332
1333    /// Get migration history statistics
1334    pub fn get_migration_stats(&self) -> (usize, usize, f64) {
1335        let history = self.migration_history.read();
1336        (
1337            history.total_successful,
1338            history.total_failed,
1339            history.overall_success_rate(),
1340        )
1341    }
1342
    /// Get device count
    ///
    /// Number of GPU devices this balancer distributes work across.
    pub fn device_count(&self) -> usize {
        self.devices.len()
    }
1347}
1348
#[cfg(test)]
mod tests {
    use super::*;

    // LoadStats::default() starts with empty per-device vectors; sizing
    // happens elsewhere (e.g. reset_stats).
    #[test]
    fn test_load_stats() {
        let stats = LoadStats::default();
        assert_eq!(stats.tasks_per_device.len(), 0);
    }

    #[test]
    fn test_selection_strategy() {
        // Test that strategies are copy
        let strategy = SelectionStrategy::RoundRobin;
        let _strategy2 = strategy;
        // This compiles, proving Copy trait works
    }

    // Default thresholds must be sane: overload in (0, 1] and strictly
    // above the underutilization threshold.
    #[test]
    fn test_migration_config_default() {
        let config = MigrationConfig::default();
        assert!(config.overload_threshold > 0.0);
        assert!(config.overload_threshold <= 1.0);
        assert!(config.underutilization_threshold >= 0.0);
        assert!(config.underutilization_threshold < config.overload_threshold);
    }

    // A freshly created workload carries its constructor arguments and is
    // not marked as migrating; with_dependency appends one dependency.
    #[test]
    fn test_migratable_workload() {
        let workload = MigratableWorkload::new(1, 0, 1024 * 1024, 0.5, 10);
        assert_eq!(workload.id, 1);
        assert_eq!(workload.source_device, 0);
        assert_eq!(workload.memory_size, 1024 * 1024);
        assert!(!workload.migrating);

        let workload_with_dep = workload.with_dependency(0);
        assert_eq!(workload_with_dep.dependencies.len(), 1);
    }

    // Migration cost always includes the base transfer cost plus a
    // size-dependent component, so it exceeds the base.
    #[test]
    fn test_migration_cost_calculation() {
        let config = MigrationConfig::default();
        let workload = MigratableWorkload::new(1, 0, 1024 * 1024, 0.5, 10);

        let cost = workload.calculate_migration_cost(&config);
        assert!(cost > config.transfer_cost_base);
    }

    // A plan records the workload's source device, the requested target,
    // and a positive estimated cost.
    #[test]
    fn test_migration_plan() {
        let config = MigrationConfig::default();
        let workload = MigratableWorkload::new(1, 0, 1024 * 1024, 0.8, 10);
        let plan = MigrationPlan::new(workload, 1, &config);

        assert_eq!(plan.source_device, 0);
        assert_eq!(plan.target_device, 1);
        assert!(plan.estimated_cost > 0.0);
    }

    // History counters and overall success rate track added entries:
    // 1 success -> 1.0, then 1 success + 1 failure -> 0.5.
    #[test]
    fn test_migration_history() {
        let mut history = MigrationHistory::new(10);

        history.add_entry(MigrationHistoryEntry {
            timestamp: Instant::now(),
            source_device: 0,
            target_device: 1,
            success: true,
            transfer_time: Duration::from_millis(10),
            bytes_transferred: 1024,
        });

        assert_eq!(history.total_successful, 1);
        assert_eq!(history.total_failed, 0);
        assert!((history.overall_success_rate() - 1.0).abs() < f64::EPSILON);

        history.add_entry(MigrationHistoryEntry {
            timestamp: Instant::now(),
            source_device: 0,
            target_device: 1,
            success: false,
            transfer_time: Duration::from_millis(5),
            bytes_transferred: 0,
        });

        assert_eq!(history.total_failed, 1);
        assert!((history.overall_success_rate() - 0.5).abs() < f64::EPSILON);
    }

    // Workload ids are unique, and add/remove is reflected in the
    // per-device pending count.
    #[test]
    fn test_workload_tracker() {
        let mut tracker = WorkloadTracker::new(2, 100);

        let id1 = tracker.next_workload_id();
        let id2 = tracker.next_workload_id();
        assert_ne!(id1, id2);

        let workload = MigratableWorkload::new(id1, 0, 1024, 0.5, 10);
        tracker.add_workload(0, workload);
        assert_eq!(tracker.pending_count(0), 1);

        let removed = tracker.remove_workload(0, id1);
        assert!(removed.is_some());
        assert_eq!(tracker.pending_count(0), 0);
    }

    // Recording monotonically increasing samples yields positive averages
    // and a positive utilization trend.
    #[test]
    fn test_utilization_sample() {
        let mut tracker = WorkloadTracker::new(2, 100);

        for i in 0..10 {
            tracker.record_sample(
                0,
                UtilizationSample {
                    timestamp: Instant::now(),
                    compute: 0.1 * (i as f32),
                    memory: 0.05 * (i as f32),
                    active_tasks: i,
                },
            );
        }

        let (avg_compute, avg_memory) = tracker
            .average_utilization(0, 5)
            .expect("Should have samples");
        assert!(avg_compute > 0.0);
        assert!(avg_memory > 0.0);

        let trend = tracker.utilization_trend(0, 10).expect("Should have trend");
        assert!(trend > 0.0); // Increasing trend
    }

    // Overload/underutilization classification against default thresholds:
    // a 0.85 combined load is overloaded, a 0.1 load is underutilized.
    #[test]
    fn test_device_load() {
        let config = MigrationConfig::default();

        let load = DeviceLoad {
            device_index: 0,
            compute_utilization: 0.9,
            memory_utilization: 0.5,
            combined_load: 0.85,
            active_tasks: 5,
            pending_workloads: 3,
            score: 0.7,
            trend: 0.1,
        };

        assert!(load.is_overloaded(&config));
        assert!(!load.is_underutilized(&config));

        let underutilized_load = DeviceLoad {
            device_index: 1,
            compute_utilization: 0.1,
            memory_utilization: 0.1,
            combined_load: 0.1,
            active_tasks: 0,
            pending_workloads: 0,
            score: 0.9,
            trend: -0.05,
        };

        assert!(!underutilized_load.is_overloaded(&config));
        assert!(underutilized_load.is_underutilized(&config));
    }

    // average_transfer_time is per (source, target) pair: 10ms and 20ms
    // average to 15ms for 0->1, while the unseen 1->0 pair has no average.
    #[test]
    fn test_migration_history_average_time() {
        let mut history = MigrationHistory::new(10);

        history.add_entry(MigrationHistoryEntry {
            timestamp: Instant::now(),
            source_device: 0,
            target_device: 1,
            success: true,
            transfer_time: Duration::from_millis(10),
            bytes_transferred: 1024,
        });

        history.add_entry(MigrationHistoryEntry {
            timestamp: Instant::now(),
            source_device: 0,
            target_device: 1,
            success: true,
            transfer_time: Duration::from_millis(20),
            bytes_transferred: 2048,
        });

        let avg = history
            .average_transfer_time(0, 1)
            .expect("Should have average");
        assert_eq!(avg, Duration::from_millis(15));

        assert!(history.average_transfer_time(1, 0).is_none());
    }

    // Pair success rate: defaults to 1.0 with no data, then reflects
    // recorded outcomes (3 successes + 1 failure -> 0.75).
    #[test]
    fn test_migration_history_success_rate() {
        let mut history = MigrationHistory::new(10);

        // No entries - assume success
        assert!((history.success_rate(0, 1) - 1.0).abs() < f64::EPSILON);

        // Add entries
        for _ in 0..3 {
            history.add_entry(MigrationHistoryEntry {
                timestamp: Instant::now(),
                source_device: 0,
                target_device: 1,
                success: true,
                transfer_time: Duration::from_millis(10),
                bytes_transferred: 1024,
            });
        }

        history.add_entry(MigrationHistoryEntry {
            timestamp: Instant::now(),
            source_device: 0,
            target_device: 1,
            success: false,
            transfer_time: Duration::from_millis(5),
            bytes_transferred: 0,
        });

        let rate = history.success_rate(0, 1);
        assert!((rate - 0.75).abs() < f64::EPSILON);
    }

    // Cooldown starts when a rebalance is recorded and is measured against
    // a caller-supplied window in seconds (0 seconds never blocks).
    #[test]
    fn test_workload_tracker_cooldown() {
        let mut tracker = WorkloadTracker::new(2, 100);

        // Initially not in cooldown
        assert!(!tracker.is_in_cooldown(0, 1));

        // Update rebalance time
        tracker.update_rebalance_time(0);

        // Now in cooldown
        assert!(tracker.is_in_cooldown(0, 1));

        // Wait and check (using 0 seconds should always pass)
        assert!(!tracker.is_in_cooldown(0, 0));
    }

    // Workloads that are mid-migration or have dependencies are excluded
    // from the migratable set.
    #[test]
    fn test_workload_tracker_migratable() {
        let mut tracker = WorkloadTracker::new(2, 100);

        let workload1 = MigratableWorkload::new(0, 0, 1024, 0.5, 10);
        let mut workload2 = MigratableWorkload::new(1, 0, 2048, 0.7, 5);
        workload2.migrating = true;
        let workload3 = MigratableWorkload::new(2, 0, 4096, 0.3, 15).with_dependency(0);

        tracker.add_workload(0, workload1);
        tracker.add_workload(0, workload2);
        tracker.add_workload(0, workload3);

        let migratable = tracker.get_migratable_workloads(0);

        // Only workload1 should be migratable (workload2 is migrating, workload3 has dependency)
        assert_eq!(migratable.len(), 1);
        assert_eq!(migratable[0].id, 0);
    }

    // Trend sign follows the sample direction: increasing samples give a
    // positive trend, decreasing samples a negative one.
    #[test]
    fn test_utilization_trend_calculation() {
        let mut tracker = WorkloadTracker::new(1, 100);

        // Add increasing samples
        for i in 0..20 {
            tracker.record_sample(
                0,
                UtilizationSample {
                    timestamp: Instant::now(),
                    compute: 0.05 * (i as f32),
                    memory: 0.02 * (i as f32),
                    active_tasks: i,
                },
            );
        }

        let trend = tracker
            .utilization_trend(0, 20)
            .expect("Should compute trend");
        assert!(
            trend > 0.0,
            "Trend should be positive for increasing samples"
        );

        // Add decreasing samples
        let mut tracker2 = WorkloadTracker::new(1, 100);
        for i in 0..20 {
            tracker2.record_sample(
                0,
                UtilizationSample {
                    timestamp: Instant::now(),
                    compute: 1.0 - 0.05 * (i as f32),
                    memory: 0.5 - 0.02 * (i as f32),
                    active_tasks: 20 - i,
                },
            );
        }

        let trend2 = tracker2
            .utilization_trend(0, 20)
            .expect("Should compute trend");
        assert!(
            trend2 < 0.0,
            "Trend should be negative for decreasing samples"
        );
    }
}