use crate::{UnifiedGpuError, UnifiedGpuResult};
use std::collections::HashMap;
use std::sync::{
    atomic::{AtomicU32, AtomicU64, AtomicUsize, Ordering},
    Arc, Mutex,
};
use std::time::{Duration, Instant};
use tokio::sync::{Notify, RwLock};

/// Unique identifier for a GPU device managed by this module.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct DeviceId(pub usize);

/// Static capability profile of a single GPU device.
#[derive(Debug, Clone)]
pub struct DeviceCapabilities {
    /// Number of compute units (SMs / CUs / EUs, depending on vendor).
    pub compute_units: u32,
    /// Estimated peak memory bandwidth in GB/s.
    pub memory_bandwidth_gb_s: f32,
    /// Estimated peak throughput in FLOPS.
    pub peak_flops: f64,
    /// Estimated total device memory in GB.
    pub memory_size_gb: f32,
    /// Vendor architecture family.
    pub architecture: GpuArchitecture,
    /// Maximum compute workgroup size per dimension (x, y, z).
    pub max_workgroup_size: (u32, u32, u32),
    /// Maximum workgroup (shared) memory in bytes.
    pub shared_memory_per_workgroup: u32,
}

/// GPU architecture family, derived from the adapter's vendor ID.
#[derive(Debug, Clone, PartialEq)]
pub enum GpuArchitecture {
    Nvidia { compute_capability: (u32, u32) },
    Amd { gcn_generation: u32 },
    Intel { generation: String },
    Apple { gpu_family: u32 },
    Unknown,
}

/// A single GPU device together with its live utilization metrics.
#[derive(Debug)]
pub struct GpuDevice {
    pub id: DeviceId,
    pub device: Arc<wgpu::Device>,
    pub queue: Arc<wgpu::Queue>,
    pub adapter_info: wgpu::AdapterInfo,
    pub capabilities: DeviceCapabilities,

    /// Current load as a percentage (0.0-100.0), stored as `f32` bits.
    pub current_load: Arc<AtomicU32>,
    /// Current memory usage tracked for this device.
    pub memory_usage: Arc<AtomicU64>,
    pub total_operations: Arc<AtomicUsize>,
    pub error_count: Arc<AtomicUsize>,
    pub last_activity: Arc<Mutex<Instant>>,

    /// Cleared when the device is considered unhealthy and should not receive
    /// new work.
    pub is_healthy: Arc<std::sync::atomic::AtomicBool>,
}

impl GpuDevice {
    /// Wraps a wgpu device/queue pair and assesses its capabilities.
    pub async fn new(
        id: DeviceId,
        adapter: &wgpu::Adapter,
        device: wgpu::Device,
        queue: wgpu::Queue,
    ) -> UnifiedGpuResult<Self> {
        let adapter_info = adapter.get_info();
        let capabilities = Self::assess_capabilities(adapter, &adapter_info).await?;

        Ok(Self {
            id,
            device: Arc::new(device),
            queue: Arc::new(queue),
            adapter_info,
            capabilities,
            current_load: Arc::new(AtomicU32::new(0.0_f32.to_bits())),
            memory_usage: Arc::new(AtomicU64::new(0)),
            total_operations: Arc::new(AtomicUsize::new(0)),
            error_count: Arc::new(AtomicUsize::new(0)),
            last_activity: Arc::new(Mutex::new(Instant::now())),
            is_healthy: Arc::new(std::sync::atomic::AtomicBool::new(true)),
        })
    }

    /// Derives a capability profile from the adapter's reported limits and
    /// vendor ID. Values that wgpu does not expose (bandwidth, FLOPS, memory
    /// size) are conservative per-architecture estimates.
    async fn assess_capabilities(
        adapter: &wgpu::Adapter,
        adapter_info: &wgpu::AdapterInfo,
    ) -> UnifiedGpuResult<DeviceCapabilities> {
        let limits = adapter.limits();

        // Identify the architecture from the PCI vendor ID. Generation details
        // are coarse defaults, not queried from the driver.
        let architecture = match adapter_info.vendor {
            0x10DE => GpuArchitecture::Nvidia {
                compute_capability: (8, 0),
            },
            0x1002 | 0x1022 => GpuArchitecture::Amd { gcn_generation: 5 },
            0x8086 => GpuArchitecture::Intel {
                generation: "Gen12".to_string(),
            },
            _ => GpuArchitecture::Unknown,
        };

        // Estimated memory bandwidth in GB/s per architecture family.
        let memory_bandwidth_gb_s = match &architecture {
            GpuArchitecture::Nvidia { .. } => 500.0,
            GpuArchitecture::Amd { .. } => 512.0,
            GpuArchitecture::Intel { .. } => 100.0,
            GpuArchitecture::Apple { .. } => 400.0,
            GpuArchitecture::Unknown => 200.0,
        };

        // Estimated compute units and peak FLOPS per architecture family.
        let (compute_units, peak_flops) = match &architecture {
            GpuArchitecture::Nvidia { .. } => (68, 30e12),
            GpuArchitecture::Amd { .. } => (72, 20e12),
            GpuArchitecture::Intel { .. } => (32, 15e12),
            GpuArchitecture::Apple { .. } => (64, 25e12),
            GpuArchitecture::Unknown => (32, 10e12),
        };

        Ok(DeviceCapabilities {
            compute_units,
            memory_bandwidth_gb_s,
            peak_flops,
            // wgpu does not report total device memory; assume a conservative 8 GB.
            memory_size_gb: 8.0,
            architecture,
            max_workgroup_size: (
                limits.max_compute_workgroup_size_x,
                limits.max_compute_workgroup_size_y,
                limits.max_compute_workgroup_size_z,
            ),
            shared_memory_per_workgroup: limits.max_compute_workgroup_storage_size,
        })
    }

    /// Records the device's current load (as a percentage) and refreshes the
    /// last-activity timestamp.
    pub fn update_load(&self, load_percent: f32) {
        self.current_load
            .store(load_percent.to_bits(), Ordering::Relaxed);
        if let Ok(mut last_activity) = self.last_activity.lock() {
            *last_activity = Instant::now();
        }
    }

    /// Returns the current load as a percentage.
    pub fn current_load(&self) -> f32 {
        f32::from_bits(self.current_load.load(Ordering::Relaxed))
    }

    /// A device is available when it is healthy and below 90% load.
    pub fn is_available(&self) -> bool {
        self.is_healthy.load(Ordering::Relaxed) && self.current_load() < 90.0
    }

    /// Scores the device for a given operation type, weighting compute and
    /// memory bandwidth and discounting by the current load.
    pub fn performance_score(&self, operation_type: &str) -> f32 {
        let base_score = match operation_type {
            "matrix_multiply" => self.capabilities.peak_flops as f32 / 1e12,
            "memory_intensive" => self.capabilities.memory_bandwidth_gb_s,
            _ => {
                (self.capabilities.peak_flops as f32 / 1e12) * 0.5
                    + self.capabilities.memory_bandwidth_gb_s * 0.5
            }
        };

        // A fully loaded device scores zero.
        let load_factor = 1.0 - (self.current_load() / 100.0);
        base_score * load_factor
    }
}

/// A unit of work to be distributed across the available devices.
#[derive(Debug, Clone)]
pub struct Workload {
    pub operation_type: String,
    pub data_size: usize,
    pub memory_requirement_mb: f32,
    pub compute_intensity: ComputeIntensity,
    pub parallelizable: bool,
    pub synchronization_required: bool,
}

/// Rough classification of how compute-bound a workload is.
#[derive(Debug, Clone)]
pub enum ComputeIntensity {
    Light,
    Moderate,
    Heavy,
    Extreme,
}

/// The portion of a [`Workload`] assigned to a single device.
#[derive(Debug, Clone)]
pub struct DeviceWorkload {
    pub device_id: DeviceId,
    pub workload_fraction: f32,
    pub data_range: (usize, usize),
    pub estimated_completion_ms: f32,
    pub memory_requirement_mb: f32,
}

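/// Distributes workloads across the registered GPU devices according to a
/// configurable [`LoadBalancingStrategy`].
///
/// A minimal usage sketch (marked `ignore` because it needs a real adapter;
/// `device` and `workload` are placeholders for a discovered [`GpuDevice`]
/// and a caller-supplied [`Workload`]):
///
/// ```ignore
/// let balancer = IntelligentLoadBalancer::new(LoadBalancingStrategy::Adaptive);
/// balancer.add_device(device).await;
/// let plan = balancer.distribute_workload(&workload).await?;
/// for assignment in &plan {
///     println!("{:?} -> range {:?}", assignment.device_id, assignment.data_range);
/// }
/// ```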
pub struct IntelligentLoadBalancer {
    /// Registered devices, keyed by their identifiers.
    devices: Arc<RwLock<HashMap<DeviceId, Arc<GpuDevice>>>>,
    balancing_strategy: LoadBalancingStrategy,
    /// Per-operation performance history used by the adaptive strategy.
    performance_history: Arc<Mutex<HashMap<String, Vec<PerformanceRecord>>>>,
}

/// Strategy used to split a workload across devices.
#[derive(Debug, Clone, Copy)]
pub enum LoadBalancingStrategy {
    /// Split the work evenly across all available devices.
    Balanced,
    /// Split the work in proportion to each device's performance score.
    CapabilityAware,
    /// Only consider devices with enough memory, then split by capability.
    MemoryAware,
    /// Optimize for the lowest end-to-end latency.
    LatencyOptimized,
    /// Use recorded performance history when available.
    Adaptive,
}

/// A single measurement of how a device performed on an operation.
#[derive(Debug, Clone)]
pub struct PerformanceRecord {
    pub device_id: DeviceId,
    pub operation_type: String,
    pub data_size: usize,
    pub completion_time_ms: f32,
    pub throughput_gops: f32,
    pub memory_bandwidth_utilized_gb_s: f32,
    pub timestamp: Instant,
}

impl IntelligentLoadBalancer {
    /// Creates an empty load balancer using the given strategy.
    pub fn new(strategy: LoadBalancingStrategy) -> Self {
        Self {
            devices: Arc::new(RwLock::new(HashMap::new())),
            balancing_strategy: strategy,
            performance_history: Arc::new(Mutex::new(HashMap::new())),
        }
    }

    /// Registers a device with the balancer.
    pub async fn add_device(&self, device: Arc<GpuDevice>) {
        let mut devices = self.devices.write().await;
        devices.insert(device.id, device);
    }

    /// Removes a device from the balancer.
    pub async fn remove_device(&self, device_id: DeviceId) {
        let mut devices = self.devices.write().await;
        devices.remove(&device_id);
    }

    /// Splits a workload into per-device assignments using the configured
    /// strategy. Fails if no healthy, non-saturated device is available.
    pub async fn distribute_workload(
        &self,
        workload: &Workload,
    ) -> UnifiedGpuResult<Vec<DeviceWorkload>> {
        let devices = self.devices.read().await;
        let available_devices: Vec<&Arc<GpuDevice>> = devices
            .values()
            .filter(|device| device.is_available())
            .collect();

        if available_devices.is_empty() {
            return Err(UnifiedGpuError::InvalidOperation(
                "No available devices for workload distribution".into(),
            ));
        }

        match self.balancing_strategy {
            LoadBalancingStrategy::Balanced => {
                self.distribute_balanced(&available_devices, workload)
            }
            LoadBalancingStrategy::CapabilityAware => {
                self.distribute_capability_aware(&available_devices, workload)
            }
            LoadBalancingStrategy::MemoryAware => {
                self.distribute_memory_aware(&available_devices, workload)
            }
            LoadBalancingStrategy::LatencyOptimized => {
                self.distribute_latency_optimized(&available_devices, workload)
            }
            LoadBalancingStrategy::Adaptive => {
                self.distribute_adaptive(&available_devices, workload).await
            }
        }
    }

    /// Splits the data evenly across all devices; the last device absorbs the
    /// remainder from integer division.
    fn distribute_balanced(
        &self,
        devices: &[&Arc<GpuDevice>],
        workload: &Workload,
    ) -> UnifiedGpuResult<Vec<DeviceWorkload>> {
        let num_devices = devices.len();
        let work_per_device = 1.0 / num_devices as f32;
        let data_per_device = workload.data_size / num_devices;

        let mut assignments = Vec::new();
        for (i, device) in devices.iter().enumerate() {
            let start = i * data_per_device;
            let end = if i == num_devices - 1 {
                workload.data_size
            } else {
                (i + 1) * data_per_device
            };

            assignments.push(DeviceWorkload {
                device_id: device.id,
                workload_fraction: work_per_device,
                data_range: (start, end),
                // Flat placeholder estimate.
                estimated_completion_ms: 100.0,
                memory_requirement_mb: workload.memory_requirement_mb / num_devices as f32,
            });
        }

        Ok(assignments)
    }

    /// Splits the data in proportion to each device's performance score for
    /// the workload's operation type.
    fn distribute_capability_aware(
        &self,
        devices: &[&Arc<GpuDevice>],
        workload: &Workload,
    ) -> UnifiedGpuResult<Vec<DeviceWorkload>> {
        let scores: Vec<f32> = devices
            .iter()
            .map(|device| device.performance_score(&workload.operation_type))
            .collect();

        let total_score: f32 = scores.iter().sum();

        let mut assignments = Vec::new();
        let mut data_offset = 0;

        for (i, (device, &score)) in devices.iter().zip(scores.iter()).enumerate() {
            let fraction = score / total_score;
            let data_chunk_size = (workload.data_size as f32 * fraction) as usize;

            // The last device takes whatever remains after rounding.
            let end = if i == devices.len() - 1 {
                workload.data_size
            } else {
                data_offset + data_chunk_size
            };

            assignments.push(DeviceWorkload {
                device_id: device.id,
                workload_fraction: fraction,
                data_range: (data_offset, end),
                estimated_completion_ms: 100.0 / fraction,
                memory_requirement_mb: workload.memory_requirement_mb * fraction,
            });

            data_offset = end;
        }

        Ok(assignments)
    }

    /// Restricts distribution to devices whose estimated memory can hold the
    /// workload, then splits by capability.
    fn distribute_memory_aware(
        &self,
        devices: &[&Arc<GpuDevice>],
        workload: &Workload,
    ) -> UnifiedGpuResult<Vec<DeviceWorkload>> {
        let viable_devices: Vec<&Arc<GpuDevice>> = devices
            .iter()
            .filter(|device| {
                let required_memory_gb = workload.memory_requirement_mb / 1024.0;
                device.capabilities.memory_size_gb >= required_memory_gb
            })
            .copied()
            .collect();

        if viable_devices.is_empty() {
            return Err(UnifiedGpuError::InvalidOperation(
                "No devices with sufficient memory for workload".into(),
            ));
        }

        self.distribute_capability_aware(&viable_devices, workload)
    }

    /// Latency-optimized distribution; currently falls back to the
    /// capability-aware split.
    fn distribute_latency_optimized(
        &self,
        devices: &[&Arc<GpuDevice>],
        workload: &Workload,
    ) -> UnifiedGpuResult<Vec<DeviceWorkload>> {
        self.distribute_capability_aware(devices, workload)
    }

    /// Uses recorded performance history for this operation type when
    /// available; otherwise falls back to the capability-aware split.
    async fn distribute_adaptive(
        &self,
        devices: &[&Arc<GpuDevice>],
        workload: &Workload,
    ) -> UnifiedGpuResult<Vec<DeviceWorkload>> {
        if let Ok(history) = self.performance_history.lock() {
            if let Some(records) = history.get(&workload.operation_type) {
                return self.distribute_based_on_history(devices, workload, records);
            }
        }

        self.distribute_capability_aware(devices, workload)
    }

    /// Predicts per-device throughput from recent history (or the static
    /// performance score when no history exists) and splits the data in
    /// proportion to those predictions.
    fn distribute_based_on_history(
        &self,
        devices: &[&Arc<GpuDevice>],
        workload: &Workload,
        history: &[PerformanceRecord],
    ) -> UnifiedGpuResult<Vec<DeviceWorkload>> {
        let mut device_predictions = HashMap::new();

        for device in devices {
            let device_history: Vec<_> = history
                .iter()
                .filter(|record| record.device_id == device.id)
                .collect();

            let predicted_throughput = if device_history.is_empty() {
                device.performance_score(&workload.operation_type)
            } else {
                // Average throughput over the most recent (up to 10) records.
                device_history
                    .iter()
                    .rev()
                    .take(10)
                    .map(|record| record.throughput_gops)
                    .sum::<f32>()
                    / device_history.len().min(10) as f32
            };

            device_predictions.insert(device.id, predicted_throughput);
        }

        let total_predicted: f32 = device_predictions.values().sum();

        let mut assignments = Vec::new();
        let mut data_offset = 0;

        for (i, device) in devices.iter().enumerate() {
            let predicted = device_predictions[&device.id];
            let fraction = predicted / total_predicted;
            let data_chunk_size = (workload.data_size as f32 * fraction) as usize;

            // The last device takes whatever remains after rounding.
            let end = if i == devices.len() - 1 {
                workload.data_size
            } else {
                data_offset + data_chunk_size
            };

            assignments.push(DeviceWorkload {
                device_id: device.id,
                workload_fraction: fraction,
                data_range: (data_offset, end),
                estimated_completion_ms: 100.0 / fraction,
                memory_requirement_mb: workload.memory_requirement_mb * fraction,
            });

            data_offset = end;
        }

        Ok(assignments)
    }

    /// Appends a performance record, keeping at most 1000 records per
    /// operation type.
    pub async fn record_performance(&self, record: PerformanceRecord) {
        if let Ok(mut history) = self.performance_history.lock() {
            let operation_history = history.entry(record.operation_type.clone()).or_default();

            operation_history.push(record);

            // Bound the history so it cannot grow without limit.
            if operation_history.len() > 1000 {
                operation_history.remove(0);
            }
        }
    }

    /// Aggregates the recorded history for an operation type into summary
    /// statistics, if any records exist.
    pub async fn get_performance_stats(&self, operation_type: &str) -> Option<PerformanceStats> {
        if let Ok(history) = self.performance_history.lock() {
            if let Some(records) = history.get(operation_type) {
                if records.is_empty() {
                    return None;
                }

                let completion_times: Vec<f32> =
                    records.iter().map(|r| r.completion_time_ms).collect();
                let throughputs: Vec<f32> = records.iter().map(|r| r.throughput_gops).collect();

                let avg_completion_time =
                    completion_times.iter().sum::<f32>() / completion_times.len() as f32;
                let avg_throughput = throughputs.iter().sum::<f32>() / throughputs.len() as f32;

                return Some(PerformanceStats {
                    operation_type: operation_type.to_string(),
                    avg_completion_time_ms: avg_completion_time,
                    avg_throughput_gops: avg_throughput,
                    total_operations: records.len(),
                    best_device_id: records
                        .iter()
                        .max_by(|a, b| a.throughput_gops.partial_cmp(&b.throughput_gops).unwrap())
                        .map(|r| r.device_id),
                });
            }
        }
        None
    }
}

#[derive(Debug, Clone)]
pub struct PerformanceStats {
    pub operation_type: String,
    pub avg_completion_time_ms: f32,
    pub avg_throughput_gops: f32,
    pub total_operations: usize,
    pub best_device_id: Option<DeviceId>,
}

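/// Tracks in-flight multi-device workloads and gathers their per-device
/// results.
///
/// Typical flow (sketch): `submit_workload` registers the per-device
/// assignments, each device reports its output via `mark_device_completed`,
/// and the caller collects the combined results with `wait_for_completion`.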
pub struct WorkloadCoordinator {
    active_workloads: Arc<Mutex<HashMap<String, ActiveWorkload>>>,
    #[allow(dead_code)]
    synchronization_manager: SynchronizationManager,
}

/// A workload that has been submitted but not yet fully completed.
#[derive(Debug)]
pub struct ActiveWorkload {
    pub id: String,
    pub device_assignments: Vec<DeviceWorkload>,
    pub completion_status: Vec<bool>,
    pub results: Vec<Option<Vec<u8>>>,
    pub start_time: Instant,
}

/// Manages named barriers used to synchronize devices mid-workload.
pub struct SynchronizationManager {
    barriers: Arc<Mutex<HashMap<String, MultiGpuBarrier>>>,
}

/// A barrier that completes once `device_count` devices have reached it.
pub struct MultiGpuBarrier {
    pub barrier_id: String,
    pub device_count: usize,
    pub completed_devices: Arc<AtomicUsize>,
    pub completion_notifier: Arc<Notify>,
    pub timeout: Duration,
}

impl WorkloadCoordinator {
    /// Creates a coordinator with no active workloads.
    pub fn new() -> Self {
        Self {
            active_workloads: Arc::new(Mutex::new(HashMap::new())),
            synchronization_manager: SynchronizationManager::new(),
        }
    }

    /// Registers a workload and its per-device assignments for tracking.
    pub async fn submit_workload(
        &self,
        workload_id: String,
        assignments: Vec<DeviceWorkload>,
    ) -> UnifiedGpuResult<()> {
        let device_count = assignments.len();

        let active_workload = ActiveWorkload {
            id: workload_id.clone(),
            device_assignments: assignments,
            completion_status: vec![false; device_count],
            results: vec![None; device_count],
            start_time: Instant::now(),
        };

        if let Ok(mut workloads) = self.active_workloads.lock() {
            workloads.insert(workload_id, active_workload);
        }

        Ok(())
    }

    /// Polls until every device assigned to the workload has reported a
    /// result, or until the timeout elapses.
    pub async fn wait_for_completion(
        &self,
        workload_id: &str,
        timeout: Duration,
    ) -> UnifiedGpuResult<Vec<Vec<u8>>> {
        let start = Instant::now();

        loop {
            if start.elapsed() > timeout {
                return Err(UnifiedGpuError::InvalidOperation(
                    "Workload completion timeout".into(),
                ));
            }

            if let Ok(workloads) = self.active_workloads.lock() {
                if let Some(workload) = workloads.get(workload_id) {
                    if workload
                        .completion_status
                        .iter()
                        .all(|&completed| completed)
                    {
                        // Collect the per-device results in assignment order.
                        let results: Vec<Vec<u8>> = workload
                            .results
                            .iter()
                            .filter_map(|result| result.as_ref())
                            .cloned()
                            .collect();
                        return Ok(results);
                    }
                }
            }

            // The lock has been released; yield before polling again.
            tokio::time::sleep(Duration::from_millis(10)).await;
        }
    }

    /// Records a device's result for a workload and marks that device's
    /// portion as complete.
    pub async fn mark_device_completed(
        &self,
        workload_id: &str,
        device_id: DeviceId,
        result: Vec<u8>,
    ) -> UnifiedGpuResult<()> {
        if let Ok(mut workloads) = self.active_workloads.lock() {
            if let Some(workload) = workloads.get_mut(workload_id) {
                // Find the assignment belonging to this device.
                for (i, assignment) in workload.device_assignments.iter().enumerate() {
                    if assignment.device_id == device_id {
                        workload.completion_status[i] = true;
                        workload.results[i] = Some(result);
                        break;
                    }
                }
            }
        }
        Ok(())
    }
}

impl SynchronizationManager {
    /// Creates a manager with no barriers.
    pub fn new() -> Self {
        Self {
            barriers: Arc::new(Mutex::new(HashMap::new())),
        }
    }

    /// Creates a named barrier that completes once `device_count` devices
    /// have called [`wait_barrier`](Self::wait_barrier).
    pub async fn create_barrier(
        &self,
        barrier_id: String,
        device_count: usize,
        timeout: Duration,
    ) -> UnifiedGpuResult<()> {
        let barrier = MultiGpuBarrier {
            barrier_id: barrier_id.clone(),
            device_count,
            completed_devices: Arc::new(AtomicUsize::new(0)),
            completion_notifier: Arc::new(Notify::new()),
            timeout,
        };

        if let Ok(mut barriers) = self.barriers.lock() {
            barriers.insert(barrier_id, barrier);
        }

        Ok(())
    }

    /// Signals that a device has reached the barrier, then waits until every
    /// participating device has arrived or the barrier's timeout elapses.
    pub async fn wait_barrier(
        &self,
        barrier_id: &str,
        _device_id: DeviceId,
    ) -> UnifiedGpuResult<()> {
        let (notifier, timeout) = {
            if let Ok(barriers) = self.barriers.lock() {
                if let Some(barrier) = barriers.get(barrier_id) {
                    let completed = barrier.completed_devices.fetch_add(1, Ordering::SeqCst) + 1;

                    if completed >= barrier.device_count {
                        // The last device to arrive releases every waiting device.
                        barrier.completion_notifier.notify_waiters();
                        return Ok(());
                    }

                    (Arc::clone(&barrier.completion_notifier), barrier.timeout)
                } else {
                    return Err(UnifiedGpuError::InvalidOperation(format!(
                        "Barrier {} not found",
                        barrier_id
                    )));
                }
            } else {
                return Err(UnifiedGpuError::InvalidOperation(
                    "Failed to access barriers".into(),
                ));
            }
        };

        // Not the last device: wait for the final arrival, bounded by the
        // barrier's configured timeout.
        tokio::time::timeout(timeout, notifier.notified())
            .await
            .map_err(|_| UnifiedGpuError::InvalidOperation("Barrier wait timeout".into()))?;

        Ok(())
    }
}

impl Default for WorkloadCoordinator {
    fn default() -> Self {
        Self::new()
    }
}

impl Default for SynchronizationManager {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_device_id_creation() {
        let device_id = DeviceId(0);
        assert_eq!(device_id.0, 0);
    }

    #[test]
    fn test_workload_creation() {
        let workload = Workload {
            operation_type: "test_operation".to_string(),
            data_size: 1000,
            memory_requirement_mb: 100.0,
            compute_intensity: ComputeIntensity::Moderate,
            parallelizable: true,
            synchronization_required: false,
        };

        assert_eq!(workload.data_size, 1000);
        assert_eq!(workload.memory_requirement_mb, 100.0);
    }

    #[test]
    fn test_load_balancer_creation() {
        let balancer = IntelligentLoadBalancer::new(LoadBalancingStrategy::Balanced);
        assert!(matches!(
            balancer.balancing_strategy,
            LoadBalancingStrategy::Balanced
        ));
    }
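
    // Illustrative sketches: these only exercise balancer paths that need no
    // real GPU device (the empty-device error and the performance-history
    // bookkeeping).
    #[tokio::test]
    async fn test_distribute_workload_without_devices_fails() {
        let balancer = IntelligentLoadBalancer::new(LoadBalancingStrategy::Balanced);
        let workload = Workload {
            operation_type: "test_operation".to_string(),
            data_size: 1000,
            memory_requirement_mb: 100.0,
            compute_intensity: ComputeIntensity::Light,
            parallelizable: true,
            synchronization_required: false,
        };

        // With no registered devices there is nothing to distribute to.
        let result = balancer.distribute_workload(&workload).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_performance_history_round_trip() {
        let balancer = IntelligentLoadBalancer::new(LoadBalancingStrategy::Adaptive);
        let record = PerformanceRecord {
            device_id: DeviceId(0),
            operation_type: "matrix_multiply".to_string(),
            data_size: 1024,
            completion_time_ms: 5.0,
            throughput_gops: 200.0,
            memory_bandwidth_utilized_gb_s: 300.0,
            timestamp: Instant::now(),
        };

        balancer.record_performance(record).await;

        let stats = balancer
            .get_performance_stats("matrix_multiply")
            .await
            .expect("stats should exist after recording a sample");
        assert_eq!(stats.total_operations, 1);
        assert_eq!(stats.best_device_id, Some(DeviceId(0)));
    }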

    #[tokio::test]
    async fn test_workload_coordinator() {
        let coordinator = WorkloadCoordinator::new();

        let assignments = vec![
            DeviceWorkload {
                device_id: DeviceId(0),
                workload_fraction: 0.5,
                data_range: (0, 500),
                estimated_completion_ms: 100.0,
                memory_requirement_mb: 50.0,
            },
            DeviceWorkload {
                device_id: DeviceId(1),
                workload_fraction: 0.5,
                data_range: (500, 1000),
                estimated_completion_ms: 100.0,
                memory_requirement_mb: 50.0,
            },
        ];

        let result = coordinator
            .submit_workload("test_workload".to_string(), assignments)
            .await;
        assert!(result.is_ok());
    }
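
    // Further sketches, runnable without GPU hardware: a full completion
    // round trip through the coordinator, and a single-device barrier.
    #[tokio::test]
    async fn test_workload_completion_round_trip() {
        let coordinator = WorkloadCoordinator::new();

        let assignments = vec![DeviceWorkload {
            device_id: DeviceId(0),
            workload_fraction: 1.0,
            data_range: (0, 1000),
            estimated_completion_ms: 100.0,
            memory_requirement_mb: 50.0,
        }];

        assert!(coordinator
            .submit_workload("round_trip".to_string(), assignments)
            .await
            .is_ok());

        // Report a result for the single assigned device, then collect it.
        assert!(coordinator
            .mark_device_completed("round_trip", DeviceId(0), vec![1, 2, 3])
            .await
            .is_ok());

        let results = coordinator
            .wait_for_completion("round_trip", Duration::from_secs(1))
            .await;
        assert_eq!(results.ok(), Some(vec![vec![1, 2, 3]]));
    }

    #[tokio::test]
    async fn test_single_device_barrier() {
        let manager = SynchronizationManager::new();
        assert!(manager
            .create_barrier("sync".to_string(), 1, Duration::from_secs(1))
            .await
            .is_ok());

        // With a device count of one, the first arrival completes the barrier
        // immediately instead of waiting on the notifier.
        assert!(manager.wait_barrier("sync", DeviceId(0)).await.is_ok());
    }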
}