amari_gpu/multi_gpu.rs

//! Multi-GPU workload distribution and coordination infrastructure
//!
//! This module provides multi-GPU support for the Amari library: intelligent
//! workload distribution, load balancing, and performance optimization across
//! multiple GPU devices.
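//!
//! # Example
//!
//! A minimal usage sketch (marked `ignore`, so it is not compiled as a
//! doctest). It assumes this module is exported as `amari_gpu::multi_gpu`
//! and that GPU devices discovered through wgpu have already been registered
//! via [`IntelligentLoadBalancer::add_device`].
//!
//! ```ignore
//! use amari_gpu::multi_gpu::{ComputeIntensity, IntelligentLoadBalancer, Workload};
//!
//! async fn distribute_example(balancer: &IntelligentLoadBalancer) {
//!     // Describe the work to be split across the registered devices.
//!     let workload = Workload {
//!         operation_type: "matrix_multiply".to_string(),
//!         data_size: 1_000_000,
//!         memory_requirement_mb: 256.0,
//!         compute_intensity: ComputeIntensity::Heavy,
//!         parallelizable: true,
//!         synchronization_required: false,
//!     };
//!
//!     // Each assignment covers a disjoint slice of the input data.
//!     let assignments = balancer
//!         .distribute_workload(&workload)
//!         .await
//!         .expect("at least one healthy device is registered");
//!     for assignment in &assignments {
//!         println!(
//!             "device {:?} handles {:?}",
//!             assignment.device_id, assignment.data_range
//!         );
//!     }
//! }
//! ```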

use crate::{UnifiedGpuError, UnifiedGpuResult};
use std::collections::HashMap;
use std::sync::{
    atomic::{AtomicU32, AtomicU64, AtomicUsize, Ordering},
    Arc, Mutex,
};
use std::time::{Duration, Instant};
use tokio::sync::{Notify, RwLock};

/// Unique identifier for GPU devices in the multi-GPU system
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct DeviceId(pub usize);

/// GPU device capabilities and characteristics
#[derive(Debug, Clone)]
pub struct DeviceCapabilities {
    /// Number of compute units (shader cores, streaming multiprocessors, etc.)
    pub compute_units: u32,
    /// Memory bandwidth in GB/s
    pub memory_bandwidth_gb_s: f32,
    /// Peak floating-point operations per second
    pub peak_flops: f64,
    /// Total device memory in GB
    pub memory_size_gb: f32,
    /// GPU architecture family
    pub architecture: GpuArchitecture,
    /// Maximum workgroup size
    pub max_workgroup_size: (u32, u32, u32),
    /// Shared memory per workgroup in bytes
    pub shared_memory_per_workgroup: u32,
}

/// GPU architecture family, used for coarse capability estimates
#[derive(Debug, Clone, PartialEq)]
pub enum GpuArchitecture {
    Nvidia { compute_capability: (u32, u32) },
    Amd { gcn_generation: u32 },
    Intel { generation: String },
    Apple { gpu_family: u32 },
    Unknown,
}

/// Individual GPU device in the multi-GPU system
#[derive(Debug)]
pub struct GpuDevice {
    pub id: DeviceId,
    pub device: Arc<wgpu::Device>,
    pub queue: Arc<wgpu::Queue>,
    pub adapter_info: wgpu::AdapterInfo,
    pub capabilities: DeviceCapabilities,

    // Runtime metrics
    pub current_load: Arc<AtomicU32>, // Stores f32 load as u32 bits
    pub memory_usage: Arc<AtomicU64>,
    pub total_operations: Arc<AtomicUsize>,
    pub error_count: Arc<AtomicUsize>,
    pub last_activity: Arc<Mutex<Instant>>,

    // Device health monitoring
    pub is_healthy: Arc<std::sync::atomic::AtomicBool>,
}

impl GpuDevice {
    /// Create a new GPU device from WebGPU adapter and device
    pub async fn new(
        id: DeviceId,
        adapter: &wgpu::Adapter,
        device: wgpu::Device,
        queue: wgpu::Queue,
    ) -> UnifiedGpuResult<Self> {
        let adapter_info = adapter.get_info();
        let capabilities = Self::assess_capabilities(adapter, &adapter_info).await?;

        Ok(Self {
            id,
            device: Arc::new(device),
            queue: Arc::new(queue),
            adapter_info,
            capabilities,
            current_load: Arc::new(AtomicU32::new(0.0_f32.to_bits())),
            memory_usage: Arc::new(AtomicU64::new(0)),
            total_operations: Arc::new(AtomicUsize::new(0)),
            error_count: Arc::new(AtomicUsize::new(0)),
            last_activity: Arc::new(Mutex::new(Instant::now())),
            is_healthy: Arc::new(std::sync::atomic::AtomicBool::new(true)),
        })
    }

    /// Assess device capabilities from adapter info and limits using heuristic estimates
    async fn assess_capabilities(
        adapter: &wgpu::Adapter,
        adapter_info: &wgpu::AdapterInfo,
    ) -> UnifiedGpuResult<DeviceCapabilities> {
        let limits = adapter.limits();

        // Estimate capabilities based on adapter info and limits
        let architecture = match adapter_info.vendor {
            0x10DE => GpuArchitecture::Nvidia {
                compute_capability: (8, 0),
            }, // Estimate
            0x1002 | 0x1022 => GpuArchitecture::Amd { gcn_generation: 5 },
            0x8086 => GpuArchitecture::Intel {
                generation: "Gen12".to_string(),
            },
            _ => GpuArchitecture::Unknown,
        };

        // Estimate memory bandwidth (simplified heuristics)
        let memory_bandwidth_gb_s = match &architecture {
            GpuArchitecture::Nvidia { .. } => 500.0, // RTX 3080-class estimate
            GpuArchitecture::Amd { .. } => 512.0,    // RX 6800 XT-class estimate
            GpuArchitecture::Intel { .. } => 100.0,  // Intel Arc estimate
            GpuArchitecture::Apple { .. } => 400.0,  // M1 Ultra estimate
            GpuArchitecture::Unknown => 200.0,       // Conservative estimate
        };

        // Estimate compute units and FLOPS
        let (compute_units, peak_flops) = match &architecture {
            GpuArchitecture::Nvidia { .. } => (68, 30e12), // RTX 3080 estimate
            GpuArchitecture::Amd { .. } => (72, 20e12),    // RX 6800 XT estimate
            GpuArchitecture::Intel { .. } => (32, 15e12),  // Intel Arc estimate
            GpuArchitecture::Apple { .. } => (64, 25e12),  // M1 Ultra estimate
            GpuArchitecture::Unknown => (32, 10e12),       // Conservative estimate
        };

        Ok(DeviceCapabilities {
            compute_units,
            memory_bandwidth_gb_s,
            peak_flops,
            memory_size_gb: 8.0, // Default estimate - would need actual query
            architecture,
            max_workgroup_size: (
                limits.max_compute_workgroup_size_x,
                limits.max_compute_workgroup_size_y,
                limits.max_compute_workgroup_size_z,
            ),
            shared_memory_per_workgroup: limits.max_compute_workgroup_storage_size,
        })
    }

    /// Update device load metrics
    pub fn update_load(&self, load_percent: f32) {
        self.current_load
            .store(load_percent.to_bits(), Ordering::Relaxed);
        if let Ok(mut last_activity) = self.last_activity.lock() {
            *last_activity = Instant::now();
        }
    }

    /// Get current device load percentage
    pub fn current_load(&self) -> f32 {
        f32::from_bits(self.current_load.load(Ordering::Relaxed))
    }

    /// Check if device is currently available for work
    pub fn is_available(&self) -> bool {
        self.is_healthy.load(Ordering::Relaxed) && self.current_load() < 90.0
    }

    /// Get device performance score for workload assignment
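    ///
    /// The score weights peak compute for compute-bound operations, memory
    /// bandwidth for memory-bound ones, and is scaled down by the current
    /// load. For example, for `"matrix_multiply"` a device with a 30 TFLOPS
    /// peak estimate running at 40% load scores `30.0 * 0.6 = 18.0`.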
    pub fn performance_score(&self, operation_type: &str) -> f32 {
        let base_score = match operation_type {
            "matrix_multiply" => self.capabilities.peak_flops as f32 / 1e12,
            "memory_intensive" => self.capabilities.memory_bandwidth_gb_s,
            _ => {
                (self.capabilities.peak_flops as f32 / 1e12) * 0.5
                    + self.capabilities.memory_bandwidth_gb_s * 0.5
            }
        };

        // Adjust for current load
        let load_factor = 1.0 - (self.current_load() / 100.0);
        base_score * load_factor
    }
}

/// Workload definition for distribution across multiple GPUs
#[derive(Debug, Clone)]
pub struct Workload {
    pub operation_type: String,
    pub data_size: usize,
    pub memory_requirement_mb: f32,
    pub compute_intensity: ComputeIntensity,
    pub parallelizable: bool,
    pub synchronization_required: bool,
}

/// Relative arithmetic intensity of a workload
#[derive(Debug, Clone)]
pub enum ComputeIntensity {
    Light,    // Memory-bound operations
    Moderate, // Balanced compute/memory
    Heavy,    // Compute-bound operations
    Extreme,  // Very high arithmetic intensity
}

/// Device-specific workload assignment
#[derive(Debug, Clone)]
pub struct DeviceWorkload {
    pub device_id: DeviceId,
    pub workload_fraction: f32,
    pub data_range: (usize, usize),
    pub estimated_completion_ms: f32,
    pub memory_requirement_mb: f32,
}

/// Intelligent load balancer for multi-GPU workload distribution
pub struct IntelligentLoadBalancer {
    devices: Arc<RwLock<HashMap<DeviceId, Arc<GpuDevice>>>>,
    balancing_strategy: LoadBalancingStrategy,
    performance_history: Arc<Mutex<HashMap<String, Vec<PerformanceRecord>>>>,
}

#[derive(Debug, Clone, Copy)]
pub enum LoadBalancingStrategy {
    /// Distribute work equally across all devices
    Balanced,
    /// Distribute based on device capabilities
    CapabilityAware,
    /// Optimize for memory constraints
    MemoryAware,
    /// Minimize total completion time
    LatencyOptimized,
    /// Adaptive distribution informed by historical performance data
    Adaptive,
}

/// Measured performance of a single operation on a specific device
#[derive(Debug, Clone)]
pub struct PerformanceRecord {
    pub device_id: DeviceId,
    pub operation_type: String,
    pub data_size: usize,
    pub completion_time_ms: f32,
    pub throughput_gops: f32,
    pub memory_bandwidth_utilized_gb_s: f32,
    pub timestamp: Instant,
}

impl IntelligentLoadBalancer {
    /// Create a new intelligent load balancer
    pub fn new(strategy: LoadBalancingStrategy) -> Self {
        Self {
            devices: Arc::new(RwLock::new(HashMap::new())),
            balancing_strategy: strategy,
            performance_history: Arc::new(Mutex::new(HashMap::new())),
        }
    }

    /// Add a device to the load balancer
    pub async fn add_device(&self, device: Arc<GpuDevice>) {
        let mut devices = self.devices.write().await;
        devices.insert(device.id, device);
    }

    /// Remove a device from the load balancer
    pub async fn remove_device(&self, device_id: DeviceId) {
        let mut devices = self.devices.write().await;
        devices.remove(&device_id);
    }

    /// Distribute workload across available devices
    pub async fn distribute_workload(
        &self,
        workload: &Workload,
    ) -> UnifiedGpuResult<Vec<DeviceWorkload>> {
        let devices = self.devices.read().await;
        let available_devices: Vec<&Arc<GpuDevice>> = devices
            .values()
            .filter(|device| device.is_available())
            .collect();

        if available_devices.is_empty() {
            return Err(UnifiedGpuError::InvalidOperation(
                "No available devices for workload distribution".into(),
            ));
        }

        match self.balancing_strategy {
            LoadBalancingStrategy::Balanced => {
                self.distribute_balanced(&available_devices, workload)
            }
            LoadBalancingStrategy::CapabilityAware => {
                self.distribute_capability_aware(&available_devices, workload)
            }
            LoadBalancingStrategy::MemoryAware => {
                self.distribute_memory_aware(&available_devices, workload)
            }
            LoadBalancingStrategy::LatencyOptimized => {
                self.distribute_latency_optimized(&available_devices, workload)
            }
            LoadBalancingStrategy::Adaptive => {
                self.distribute_adaptive(&available_devices, workload).await
            }
        }
    }

    /// Balanced distribution - equal work per device
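    ///
    /// For example, 1000 elements split across 3 devices yields the ranges
    /// (0, 333), (333, 666), and (666, 1000); the last device absorbs any
    /// remainder from integer division.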
    fn distribute_balanced(
        &self,
        devices: &[&Arc<GpuDevice>],
        workload: &Workload,
    ) -> UnifiedGpuResult<Vec<DeviceWorkload>> {
        let num_devices = devices.len();
        let work_per_device = 1.0 / num_devices as f32;
        let data_per_device = workload.data_size / num_devices;

        let mut assignments = Vec::new();
        for (i, device) in devices.iter().enumerate() {
            let start = i * data_per_device;
            let end = if i == num_devices - 1 {
                workload.data_size
            } else {
                (i + 1) * data_per_device
            };

            assignments.push(DeviceWorkload {
                device_id: device.id,
                workload_fraction: work_per_device,
                data_range: (start, end),
                estimated_completion_ms: 100.0, // Placeholder estimation
                memory_requirement_mb: workload.memory_requirement_mb / num_devices as f32,
            });
        }

        Ok(assignments)
    }

    /// Capability-aware distribution - weight by device performance
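    ///
    /// Each device receives a fraction of the data proportional to its
    /// performance score: with two available devices scoring 3.0 and 1.0,
    /// the first is assigned 75% of the data range and the second 25%.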
    fn distribute_capability_aware(
        &self,
        devices: &[&Arc<GpuDevice>],
        workload: &Workload,
    ) -> UnifiedGpuResult<Vec<DeviceWorkload>> {
        // Calculate performance scores for each device
        let scores: Vec<f32> = devices
            .iter()
            .map(|device| device.performance_score(&workload.operation_type))
            .collect();

        let total_score: f32 = scores.iter().sum();

        let mut assignments = Vec::new();
        let mut data_offset = 0;

        for (i, (device, &score)) in devices.iter().zip(scores.iter()).enumerate() {
            let fraction = score / total_score;
            let data_chunk_size = (workload.data_size as f32 * fraction) as usize;

            let end = if i == devices.len() - 1 {
                workload.data_size
            } else {
                data_offset + data_chunk_size
            };

            assignments.push(DeviceWorkload {
                device_id: device.id,
                workload_fraction: fraction,
                data_range: (data_offset, end),
                estimated_completion_ms: 100.0 / fraction, // Inverse of capability
                memory_requirement_mb: workload.memory_requirement_mb * fraction,
            });

            data_offset = end;
        }

        Ok(assignments)
    }

    /// Memory-aware distribution - consider memory constraints
    fn distribute_memory_aware(
        &self,
        devices: &[&Arc<GpuDevice>],
        workload: &Workload,
    ) -> UnifiedGpuResult<Vec<DeviceWorkload>> {
        // Filter devices that can handle the memory requirement
        let viable_devices: Vec<&Arc<GpuDevice>> = devices
            .iter()
            .filter(|device| {
                let required_memory_gb = workload.memory_requirement_mb / 1024.0;
                device.capabilities.memory_size_gb >= required_memory_gb
            })
            .copied()
            .collect();

        if viable_devices.is_empty() {
            return Err(UnifiedGpuError::InvalidOperation(
                "No devices with sufficient memory for workload".into(),
            ));
        }

        // Use capability-aware distribution among viable devices
        self.distribute_capability_aware(&viable_devices, workload)
    }

    /// Latency-optimized distribution - minimize total completion time
    fn distribute_latency_optimized(
        &self,
        devices: &[&Arc<GpuDevice>],
        workload: &Workload,
    ) -> UnifiedGpuResult<Vec<DeviceWorkload>> {
        // For now, use capability-aware as a proxy for latency optimization
        // In practice, this would use more sophisticated scheduling algorithms
        self.distribute_capability_aware(devices, workload)
    }

    /// Adaptive distribution informed by historical performance data
    async fn distribute_adaptive(
        &self,
        devices: &[&Arc<GpuDevice>],
        workload: &Workload,
    ) -> UnifiedGpuResult<Vec<DeviceWorkload>> {
        // Check performance history for similar workloads
        if let Ok(history) = self.performance_history.lock() {
            if let Some(records) = history.get(&workload.operation_type) {
                // Use historical data to inform distribution
                return self.distribute_based_on_history(devices, workload, records);
            }
        }

        // Fallback to capability-aware if no historical data
        self.distribute_capability_aware(devices, workload)
    }

    /// Distribution based on historical performance data
    fn distribute_based_on_history(
        &self,
        devices: &[&Arc<GpuDevice>],
        workload: &Workload,
        history: &[PerformanceRecord],
    ) -> UnifiedGpuResult<Vec<DeviceWorkload>> {
        // Calculate performance predictions based on historical data
        let mut device_predictions = HashMap::new();

        for device in devices {
            let device_history: Vec<_> = history
                .iter()
                .filter(|record| record.device_id == device.id)
                .collect();

            let predicted_throughput = if device_history.is_empty() {
                device.performance_score(&workload.operation_type)
            } else {
                // Simple average of throughput over the most recent operations
                let recent_throughput: f32 = device_history
                    .iter()
                    .rev()
                    .take(10) // Last 10 operations
                    .map(|record| record.throughput_gops)
                    .sum::<f32>()
                    / device_history.len().min(10) as f32;
                recent_throughput
            };

            device_predictions.insert(device.id, predicted_throughput);
        }

        // Distribute based on predicted performance
        let total_predicted: f32 = device_predictions.values().sum();

        let mut assignments = Vec::new();
        let mut data_offset = 0;

        for (i, device) in devices.iter().enumerate() {
            let predicted = device_predictions[&device.id];
            let fraction = predicted / total_predicted;
            let data_chunk_size = (workload.data_size as f32 * fraction) as usize;

            let end = if i == devices.len() - 1 {
                workload.data_size
            } else {
                data_offset + data_chunk_size
            };

            assignments.push(DeviceWorkload {
                device_id: device.id,
                workload_fraction: fraction,
                data_range: (data_offset, end),
                estimated_completion_ms: 100.0 / fraction,
                memory_requirement_mb: workload.memory_requirement_mb * fraction,
            });

            data_offset = end;
        }

        Ok(assignments)
    }

    /// Record performance data for adaptive learning
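    ///
    /// A minimal sketch of the record/query round trip (marked `ignore`
    /// because it assumes this module is exported as `amari_gpu::multi_gpu`):
    ///
    /// ```ignore
    /// use amari_gpu::multi_gpu::{
    ///     DeviceId, IntelligentLoadBalancer, LoadBalancingStrategy, PerformanceRecord,
    /// };
    /// use std::time::Instant;
    ///
    /// async fn record_example() {
    ///     let balancer = IntelligentLoadBalancer::new(LoadBalancingStrategy::Adaptive);
    ///     balancer
    ///         .record_performance(PerformanceRecord {
    ///             device_id: DeviceId(0),
    ///             operation_type: "matrix_multiply".to_string(),
    ///             data_size: 1_000_000,
    ///             completion_time_ms: 12.5,
    ///             throughput_gops: 80.0,
    ///             memory_bandwidth_utilized_gb_s: 350.0,
    ///             timestamp: Instant::now(),
    ///         })
    ///         .await;
    ///
    ///     let stats = balancer.get_performance_stats("matrix_multiply").await;
    ///     assert_eq!(stats.map(|s| s.total_operations), Some(1));
    /// }
    /// ```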
    pub async fn record_performance(&self, record: PerformanceRecord) {
        if let Ok(mut history) = self.performance_history.lock() {
            let operation_history = history
                .entry(record.operation_type.clone())
                .or_insert_with(Vec::new);

            operation_history.push(record);

            // Keep only recent history (last 1000 records)
            if operation_history.len() > 1000 {
                operation_history.remove(0);
            }
        }
    }

    /// Get performance statistics for an operation type
    pub async fn get_performance_stats(&self, operation_type: &str) -> Option<PerformanceStats> {
        if let Ok(history) = self.performance_history.lock() {
            if let Some(records) = history.get(operation_type) {
                if records.is_empty() {
                    return None;
                }

                let completion_times: Vec<f32> =
                    records.iter().map(|r| r.completion_time_ms).collect();
                let throughputs: Vec<f32> = records.iter().map(|r| r.throughput_gops).collect();

                let avg_completion_time =
                    completion_times.iter().sum::<f32>() / completion_times.len() as f32;
                let avg_throughput = throughputs.iter().sum::<f32>() / throughputs.len() as f32;

                return Some(PerformanceStats {
                    operation_type: operation_type.to_string(),
                    avg_completion_time_ms: avg_completion_time,
                    avg_throughput_gops: avg_throughput,
                    total_operations: records.len(),
                    best_device_id: records
                        .iter()
                        .max_by(|a, b| a.throughput_gops.partial_cmp(&b.throughput_gops).unwrap())
                        .map(|r| r.device_id),
                });
            }
        }
        None
    }
}

/// Aggregated performance statistics for one operation type
#[derive(Debug, Clone)]
pub struct PerformanceStats {
    pub operation_type: String,
    pub avg_completion_time_ms: f32,
    pub avg_throughput_gops: f32,
    pub total_operations: usize,
    pub best_device_id: Option<DeviceId>,
}

/// Multi-GPU workload coordinator for synchronization and result aggregation
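///
/// A minimal sketch of the submit / mark-completed / aggregate flow (marked
/// `ignore` because it assumes this module is exported as
/// `amari_gpu::multi_gpu`):
///
/// ```ignore
/// use amari_gpu::multi_gpu::{DeviceId, DeviceWorkload, WorkloadCoordinator};
/// use std::time::Duration;
///
/// async fn coordinate_example() {
///     let coordinator = WorkloadCoordinator::new();
///     let assignment = DeviceWorkload {
///         device_id: DeviceId(0),
///         workload_fraction: 1.0,
///         data_range: (0, 1024),
///         estimated_completion_ms: 10.0,
///         memory_requirement_mb: 4.0,
///     };
///     coordinator
///         .submit_workload("demo".to_string(), vec![assignment])
///         .await
///         .unwrap();
///
///     // Normally invoked from each device's execution path once its slice is done.
///     coordinator
///         .mark_device_completed("demo", DeviceId(0), vec![0u8; 1024])
///         .await
///         .unwrap();
///
///     let results = coordinator
///         .wait_for_completion("demo", Duration::from_secs(1))
///         .await
///         .unwrap();
///     assert_eq!(results.len(), 1);
/// }
/// ```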
pub struct WorkloadCoordinator {
    active_workloads: Arc<Mutex<HashMap<String, ActiveWorkload>>>,
    #[allow(dead_code)]
    synchronization_manager: SynchronizationManager,
}

/// State of a workload currently executing across one or more devices
#[derive(Debug)]
pub struct ActiveWorkload {
    pub id: String,
    pub device_assignments: Vec<DeviceWorkload>,
    pub completion_status: Vec<bool>,
    pub results: Vec<Option<Vec<u8>>>,
    pub start_time: Instant,
}

/// Synchronization manager for multi-GPU operations
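///
/// A minimal sketch of barrier creation and waiting (marked `ignore` because
/// it assumes this module is exported as `amari_gpu::multi_gpu`). Joining both
/// waits in a single task keeps the sketch self-contained; in practice each
/// device's execution task calls [`SynchronizationManager::wait_barrier`]
/// independently.
///
/// ```ignore
/// use amari_gpu::multi_gpu::{DeviceId, SynchronizationManager};
/// use std::time::Duration;
///
/// async fn barrier_example() {
///     let sync = SynchronizationManager::new();
///     sync.create_barrier("stage-1".to_string(), 2, Duration::from_secs(5))
///         .await
///         .unwrap();
///
///     // The barrier releases once both participants have arrived.
///     let (a, b) = tokio::join!(
///         sync.wait_barrier("stage-1", DeviceId(0)),
///         sync.wait_barrier("stage-1", DeviceId(1)),
///     );
///     assert!(a.is_ok() && b.is_ok());
/// }
/// ```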
pub struct SynchronizationManager {
    barriers: Arc<Mutex<HashMap<String, MultiGpuBarrier>>>,
}

/// Multi-GPU barrier for synchronizing operations across devices
pub struct MultiGpuBarrier {
    pub barrier_id: String,
    pub device_count: usize,
    pub completed_devices: Arc<AtomicUsize>,
    pub completion_notifier: Arc<Notify>,
    pub timeout: Duration,
}

impl WorkloadCoordinator {
    /// Create a new workload coordinator
    pub fn new() -> Self {
        Self {
            active_workloads: Arc::new(Mutex::new(HashMap::new())),
            synchronization_manager: SynchronizationManager::new(),
        }
    }

    /// Submit a workload for execution across multiple devices
    pub async fn submit_workload(
        &self,
        workload_id: String,
        assignments: Vec<DeviceWorkload>,
    ) -> UnifiedGpuResult<()> {
        let device_count = assignments.len();

        let active_workload = ActiveWorkload {
            id: workload_id.clone(),
            device_assignments: assignments,
            completion_status: vec![false; device_count],
            results: vec![None; device_count],
            start_time: Instant::now(),
        };

        if let Ok(mut workloads) = self.active_workloads.lock() {
            workloads.insert(workload_id, active_workload);
        }

        Ok(())
    }

    /// Wait for workload completion and aggregate results
    pub async fn wait_for_completion(
        &self,
        workload_id: &str,
        timeout: Duration,
    ) -> UnifiedGpuResult<Vec<Vec<u8>>> {
        let start = Instant::now();

        loop {
            if start.elapsed() > timeout {
                return Err(UnifiedGpuError::InvalidOperation(
                    "Workload completion timeout".into(),
                ));
            }

            // Check completion status
            if let Ok(workloads) = self.active_workloads.lock() {
                if let Some(workload) = workloads.get(workload_id) {
                    if workload
                        .completion_status
                        .iter()
                        .all(|&completed| completed)
                    {
                        // All devices completed - aggregate results
                        let results: Vec<Vec<u8>> = workload
                            .results
                            .iter()
                            .filter_map(|result| result.as_ref())
                            .cloned()
                            .collect();
                        return Ok(results);
                    }
                }
            }

            // Brief sleep before checking again
            tokio::time::sleep(Duration::from_millis(10)).await;
        }
    }

    /// Mark device as completed for a workload
    pub async fn mark_device_completed(
        &self,
        workload_id: &str,
        device_id: DeviceId,
        result: Vec<u8>,
    ) -> UnifiedGpuResult<()> {
        if let Ok(mut workloads) = self.active_workloads.lock() {
            if let Some(workload) = workloads.get_mut(workload_id) {
                // Find device index and mark as completed
                for (i, assignment) in workload.device_assignments.iter().enumerate() {
                    if assignment.device_id == device_id {
                        workload.completion_status[i] = true;
                        workload.results[i] = Some(result);
                        break;
                    }
                }
            }
        }
        Ok(())
    }
}

impl SynchronizationManager {
    /// Create a new synchronization manager
    pub fn new() -> Self {
        Self {
            barriers: Arc::new(Mutex::new(HashMap::new())),
        }
    }

    /// Create a barrier for synchronizing multiple devices
    pub async fn create_barrier(
        &self,
        barrier_id: String,
        device_count: usize,
        timeout: Duration,
    ) -> UnifiedGpuResult<()> {
        let barrier = MultiGpuBarrier {
            barrier_id: barrier_id.clone(),
            device_count,
            completed_devices: Arc::new(AtomicUsize::new(0)),
            completion_notifier: Arc::new(Notify::new()),
            timeout,
        };

        if let Ok(mut barriers) = self.barriers.lock() {
            barriers.insert(barrier_id, barrier);
        }

        Ok(())
    }

    /// Wait for all devices to reach the barrier
    pub async fn wait_barrier(
        &self,
        barrier_id: &str,
        _device_id: DeviceId,
    ) -> UnifiedGpuResult<()> {
        let (notifier, timeout) = {
            if let Ok(barriers) = self.barriers.lock() {
                if let Some(barrier) = barriers.get(barrier_id) {
                    let completed = barrier.completed_devices.fetch_add(1, Ordering::SeqCst) + 1;

                    if completed >= barrier.device_count {
                        // Last device to reach barrier - notify all
                        barrier.completion_notifier.notify_waiters();
                        return Ok(());
                    }

                    (
                        Arc::clone(&barrier.completion_notifier),
                        barrier.timeout,
                    )
                } else {
                    return Err(UnifiedGpuError::InvalidOperation(format!(
                        "Barrier {} not found",
                        barrier_id
                    )));
                }
            } else {
                return Err(UnifiedGpuError::InvalidOperation(
                    "Failed to access barriers".into(),
                ));
            }
        };

        // Wait for notification, bounded by the barrier's configured timeout
        tokio::time::timeout(timeout, notifier.notified())
            .await
            .map_err(|_| UnifiedGpuError::InvalidOperation("Barrier wait timeout".into()))?;

        Ok(())
    }
}

impl Default for WorkloadCoordinator {
    fn default() -> Self {
        Self::new()
    }
}

impl Default for SynchronizationManager {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_device_id_creation() {
        let device_id = DeviceId(0);
        assert_eq!(device_id.0, 0);
    }

    #[test]
    fn test_workload_creation() {
        let workload = Workload {
            operation_type: "test_operation".to_string(),
            data_size: 1000,
            memory_requirement_mb: 100.0,
            compute_intensity: ComputeIntensity::Moderate,
            parallelizable: true,
            synchronization_required: false,
        };

        assert_eq!(workload.data_size, 1000);
        assert_eq!(workload.memory_requirement_mb, 100.0);
    }

    #[test]
    fn test_load_balancer_creation() {
        let balancer = IntelligentLoadBalancer::new(LoadBalancingStrategy::Balanced);
        assert!(matches!(
            balancer.balancing_strategy,
            LoadBalancingStrategy::Balanced
        ));
    }

    #[tokio::test]
    async fn test_workload_coordinator() {
        let coordinator = WorkloadCoordinator::new();

        let assignments = vec![
            DeviceWorkload {
                device_id: DeviceId(0),
                workload_fraction: 0.5,
                data_range: (0, 500),
                estimated_completion_ms: 100.0,
                memory_requirement_mb: 50.0,
            },
            DeviceWorkload {
                device_id: DeviceId(1),
                workload_fraction: 0.5,
                data_range: (500, 1000),
                estimated_completion_ms: 100.0,
                memory_requirement_mb: 50.0,
            },
        ];

        let result = coordinator
            .submit_workload("test_workload".to_string(), assignments)
            .await;
        assert!(result.is_ok());
    }
}