use crate::{
    multi_gpu::{
        DeviceId, GpuDevice, IntelligentLoadBalancer, LoadBalancingStrategy, Workload,
        WorkloadCoordinator,
    },
    GpuError,
};
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Instant;
use thiserror::Error;
use tokio::sync::RwLock;
use wgpu::util::DeviceExt;

#[derive(Error, Debug)]
pub enum UnifiedGpuError {
    #[error("GPU error: {0}")]
    Gpu(#[from] GpuError),

    #[error("Shader compilation failed: {0}")]
    ShaderCompilation(String),

    #[error("Buffer size mismatch: expected {expected}, got {actual}")]
    BufferSizeMismatch { expected: usize, actual: usize },

    #[error("Invalid operation: {0}")]
    InvalidOperation(String),

    #[error("Memory allocation failed: {0}")]
    MemoryAllocation(String),
}

pub type UnifiedGpuResult<T> = Result<T, UnifiedGpuError>;

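/// Types that can be offloaded to the GPU.
///
/// Implementors describe how to upload a value into a `wgpu::Buffer`, how to
/// reconstruct a value from a buffer that has been read back, and how to run
/// a named compute operation with the given parameters.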
pub trait GpuAccelerated<T> {
    fn to_gpu_buffer(&self, context: &GpuContext) -> UnifiedGpuResult<wgpu::Buffer>;

    fn from_gpu_buffer(buffer: &wgpu::Buffer, context: &GpuContext) -> UnifiedGpuResult<T>;

    fn gpu_operation(
        &self,
        operation: &str,
        context: &GpuContext,
        params: &GpuOperationParams,
    ) -> UnifiedGpuResult<T>;
}

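/// Parameters for a GPU compute operation: named scalar and array values,
/// a batch size, and the workgroup dimensions to use at dispatch time.
/// See `test_gpu_operation_params` at the bottom of this file for a
/// construction example.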
#[derive(Debug, Clone)]
pub struct GpuOperationParams {
    pub params: HashMap<String, GpuParam>,
    pub batch_size: usize,
    pub workgroup_size: (u32, u32, u32),
}

#[derive(Debug, Clone)]
pub enum GpuParam {
    Float(f32),
    Double(f64),
    Integer(i32),
    UnsignedInteger(u32),
    Buffer(String),
    Array(Vec<f32>),
}

impl Default for GpuOperationParams {
    fn default() -> Self {
        Self {
            params: HashMap::new(),
            batch_size: 1,
            workgroup_size: (1, 1, 1),
        }
    }
}

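/// A self-contained GPU context owning a `wgpu` device and queue, plus a
/// compute-pipeline cache keyed by shader name.
///
/// A minimal buffer round trip, as a sketch (error handling elided; the
/// fence is `ignore` because running it requires GPU hardware):
///
/// ```ignore
/// let ctx = GpuContext::new().await?;
/// let buffer = ctx.create_buffer_with_data(
///     "input",
///     &[1.0f32; 64],
///     wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
/// );
/// let readback: Vec<f32> = ctx.read_buffer(&buffer, 64 * 4).await?;
/// assert_eq!(readback.len(), 64);
/// ```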
pub struct GpuContext {
    pub device: wgpu::Device,
    pub queue: wgpu::Queue,
    shader_cache: HashMap<String, wgpu::ComputePipeline>,
    #[allow(dead_code)]
    buffer_pool: GpuBufferPool,
}

impl GpuContext {
    pub async fn new() -> UnifiedGpuResult<Self> {
        let instance = wgpu::Instance::default();

        let adapter = instance
            .request_adapter(&wgpu::RequestAdapterOptions {
                power_preference: wgpu::PowerPreference::HighPerformance,
                compatible_surface: None,
                force_fallback_adapter: false,
            })
            .await
            .ok_or_else(|| {
                UnifiedGpuError::Gpu(GpuError::InitializationError(
                    "No GPU adapter found".to_string(),
                ))
            })?;

        let (device, queue) = adapter
            .request_device(
                &wgpu::DeviceDescriptor {
                    label: Some("Amari Unified GPU Device"),
                    required_features: wgpu::Features::empty(),
                    required_limits: wgpu::Limits::default(),
                },
                None,
            )
            .await
            .map_err(|e| UnifiedGpuError::Gpu(GpuError::InitializationError(e.to_string())))?;

        Ok(Self {
            device,
            queue,
            shader_cache: HashMap::new(),
            buffer_pool: GpuBufferPool::new(),
        })
    }

    pub fn get_compute_pipeline(
        &mut self,
        shader_key: &str,
        shader_source: &str,
        bind_group_layout: &wgpu::BindGroupLayout,
    ) -> UnifiedGpuResult<&wgpu::ComputePipeline> {
        if !self.shader_cache.contains_key(shader_key) {
            let shader_module = self
                .device
                .create_shader_module(wgpu::ShaderModuleDescriptor {
                    label: Some(&format!("{} Shader", shader_key)),
                    source: wgpu::ShaderSource::Wgsl(shader_source.into()),
                });

            let pipeline_layout =
                self.device
                    .create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
                        label: Some(&format!("{} Pipeline Layout", shader_key)),
                        bind_group_layouts: &[bind_group_layout],
                        push_constant_ranges: &[],
                    });

            let compute_pipeline =
                self.device
                    .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
                        label: Some(&format!("{} Pipeline", shader_key)),
                        layout: Some(&pipeline_layout),
                        module: &shader_module,
                        entry_point: "main",
                    });

            self.shader_cache
                .insert(shader_key.to_string(), compute_pipeline);
        }

        Ok(self
            .shader_cache
            .get(shader_key)
            .expect("Pipeline should exist"))
    }

    pub fn create_buffer_with_data<T: bytemuck::Pod>(
        &self,
        label: &str,
        data: &[T],
        usage: wgpu::BufferUsages,
    ) -> wgpu::Buffer {
        self.device
            .create_buffer_init(&wgpu::util::BufferInitDescriptor {
                label: Some(label),
                contents: bytemuck::cast_slice(data),
                usage,
            })
    }

    pub fn create_buffer(&self, label: &str, size: u64, usage: wgpu::BufferUsages) -> wgpu::Buffer {
        self.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some(label),
            size,
            usage,
            mapped_at_creation: false,
        })
    }

    pub fn execute_compute(
        &self,
        pipeline: &wgpu::ComputePipeline,
        bind_group: &wgpu::BindGroup,
        workgroup_count: (u32, u32, u32),
    ) {
        let mut encoder = self
            .device
            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
                label: Some("Compute Encoder"),
            });

        {
            let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
                label: Some("Compute Pass"),
                timestamp_writes: None,
            });
            compute_pass.set_pipeline(pipeline);
            compute_pass.set_bind_group(0, bind_group, &[]);
            compute_pass.dispatch_workgroups(
                workgroup_count.0,
                workgroup_count.1,
                workgroup_count.2,
            );
        }

        self.queue.submit([encoder.finish()]);
    }

    pub async fn read_buffer<T: bytemuck::Pod + Clone>(
        &self,
        buffer: &wgpu::Buffer,
        size: u64,
    ) -> UnifiedGpuResult<Vec<T>> {
        let staging_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("Staging Buffer"),
            size,
            usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
            mapped_at_creation: false,
        });

        let mut encoder = self
            .device
            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
                label: Some("Copy Encoder"),
            });

        encoder.copy_buffer_to_buffer(buffer, 0, &staging_buffer, 0, size);
        self.queue.submit([encoder.finish()]);

        let buffer_slice = staging_buffer.slice(..);
        let (tx, rx) = futures::channel::oneshot::channel();
        buffer_slice.map_async(wgpu::MapMode::Read, move |result| {
            tx.send(result).ok();
        });

        self.device.poll(wgpu::Maintain::Wait);

        rx.await
            .map_err(|_| UnifiedGpuError::InvalidOperation("Buffer read timeout".to_string()))?
            .map_err(|e| UnifiedGpuError::InvalidOperation(format!("Buffer map failed: {}", e)))?;

        let data = buffer_slice.get_mapped_range();
        let result: Vec<T> = bytemuck::cast_slice(&data).to_vec();
        drop(data);
        staging_buffer.unmap();

        Ok(result)
    }
}

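/// Placeholder buffer pool used by `GpuContext`. The pool map is currently
/// unused (hence the `_pools` name); `EnhancedGpuBufferPool` below is the
/// active pooling implementation.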
pub struct GpuBufferPool {
    _pools: HashMap<String, Vec<wgpu::Buffer>>,
}

impl GpuBufferPool {
    pub fn new() -> Self {
        Self {
            _pools: HashMap::new(),
        }
    }
}

impl Default for GpuBufferPool {
    fn default() -> Self {
        Self::new()
    }
}

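/// A cloneable, thread-safe GPU context meant to be shared across the crate.
/// The device and queue live behind `Arc`s; buffer pooling and shader caching
/// sit behind mutexes; and, when multi-GPU mode is enabled, workloads are
/// spread across devices by a load balancer and a workload coordinator.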
#[derive(Clone)]
pub struct SharedGpuContext {
    device: Arc<wgpu::Device>,
    queue: Arc<wgpu::Queue>,
    adapter_info: wgpu::AdapterInfo,
    buffer_pool: Arc<std::sync::Mutex<EnhancedGpuBufferPool>>,
    shader_cache: Arc<std::sync::Mutex<HashMap<String, Arc<wgpu::ComputePipeline>>>>,
    creation_time: Instant,

    multi_gpu_enabled: bool,
    gpu_devices: Arc<RwLock<HashMap<DeviceId, Arc<GpuDevice>>>>,
    load_balancer: Arc<IntelligentLoadBalancer>,
    workload_coordinator: Arc<WorkloadCoordinator>,
    primary_device_id: DeviceId,
}

impl SharedGpuContext {
    pub async fn global() -> UnifiedGpuResult<&'static Self> {
        let context = Self::new().await?;
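        // Leak the context to obtain a 'static reference. Note that each
        // call creates and leaks a fresh context; a OnceCell-style cache
        // would be needed to make this a true singleton.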
        Ok(Box::leak(Box::new(context)))
    }

    async fn new() -> UnifiedGpuResult<Self> {
        let instance = wgpu::Instance::new(wgpu::InstanceDescriptor {
            backends: wgpu::Backends::all(),
            flags: wgpu::InstanceFlags::default(),
            dx12_shader_compiler: wgpu::Dx12Compiler::default(),
            gles_minor_version: wgpu::Gles3MinorVersion::Automatic,
        });

        let adapter = instance
            .request_adapter(&wgpu::RequestAdapterOptions {
                power_preference: wgpu::PowerPreference::HighPerformance,
                compatible_surface: None,
                force_fallback_adapter: false,
            })
            .await
            .ok_or_else(|| {
                UnifiedGpuError::InvalidOperation("No suitable GPU adapter found".into())
            })?;

        let adapter_info = adapter.get_info();

        let (device, queue) = adapter
            .request_device(
                &wgpu::DeviceDescriptor {
                    label: Some("Shared Amari GPU Device"),
                    required_features: wgpu::Features::TIMESTAMP_QUERY,
                    required_limits: wgpu::Limits::default(),
                },
                None,
            )
            .await
            .map_err(|e| {
                UnifiedGpuError::InvalidOperation(format!("Device request failed: {:?}", e))
            })?;

        let primary_device_id = DeviceId(0);

        let gpu_device = Arc::new(
            GpuDevice::new(primary_device_id, &adapter, device, queue)
                .await
                .map_err(|_| {
                    UnifiedGpuError::InvalidOperation("Failed to create GPU device".into())
                })?,
        );

        let device_arc = Arc::clone(&gpu_device.device);
        let queue_arc = Arc::clone(&gpu_device.queue);

        let mut gpu_devices = HashMap::new();
        gpu_devices.insert(primary_device_id, gpu_device);

        Ok(Self {
            device: device_arc,
            queue: queue_arc,
            adapter_info,
            buffer_pool: Arc::new(std::sync::Mutex::new(EnhancedGpuBufferPool::new())),
            shader_cache: Arc::new(std::sync::Mutex::new(HashMap::new())),
            creation_time: Instant::now(),

            multi_gpu_enabled: false,
            gpu_devices: Arc::new(RwLock::new(gpu_devices)),
            load_balancer: Arc::new(IntelligentLoadBalancer::new(
                LoadBalancingStrategy::Balanced,
            )),
            workload_coordinator: Arc::new(WorkloadCoordinator::new()),
            primary_device_id,
        })
    }

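    /// Creates a context that enumerates every available adapter and
    /// registers each device that initializes successfully with a
    /// capability-aware load balancer. The first usable device supplies the
    /// primary device and queue.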
    pub async fn with_multi_gpu() -> UnifiedGpuResult<Self> {
        let instance = wgpu::Instance::new(wgpu::InstanceDescriptor {
            backends: wgpu::Backends::all(),
            flags: wgpu::InstanceFlags::default(),
            dx12_shader_compiler: wgpu::Dx12Compiler::default(),
            gles_minor_version: wgpu::Gles3MinorVersion::Automatic,
        });

        let adapters: Vec<_> = instance.enumerate_adapters(wgpu::Backends::all());

        if adapters.is_empty() {
            return Err(UnifiedGpuError::InvalidOperation(
                "No GPU adapters found".into(),
            ));
        }

        let mut gpu_devices = HashMap::new();
        let mut primary_device = None;
        let mut primary_queue = None;
        let mut primary_adapter_info = None;

        for (i, adapter) in adapters.iter().enumerate() {
            let device_id = DeviceId(i);

            if let Ok((device, queue)) = adapter
                .request_device(
                    &wgpu::DeviceDescriptor {
                        label: Some(&format!("Amari Multi-GPU Device {}", i)),
                        required_features: wgpu::Features::TIMESTAMP_QUERY,
                        required_limits: wgpu::Limits::default(),
                    },
                    None,
                )
                .await
            {
                if let Ok(gpu_device) = GpuDevice::new(device_id, adapter, device, queue).await {
                    if primary_device.is_none() {
                        primary_device = Some(Arc::clone(&gpu_device.device));
                        primary_queue = Some(Arc::clone(&gpu_device.queue));
                        primary_adapter_info = Some(adapter.get_info());
                    }

                    gpu_devices.insert(device_id, Arc::new(gpu_device));
                }
            }
        }

        if gpu_devices.is_empty() {
            return Err(UnifiedGpuError::InvalidOperation(
                "No usable GPU devices found".into(),
            ));
        }

        let primary_device_id = DeviceId(0);
        let load_balancer = Arc::new(IntelligentLoadBalancer::new(
            LoadBalancingStrategy::CapabilityAware,
        ));

        for device in gpu_devices.values() {
            load_balancer.add_device(Arc::clone(device)).await;
        }

        Ok(Self {
            device: primary_device.unwrap(),
            queue: primary_queue.unwrap(),
            adapter_info: primary_adapter_info.unwrap(),
            buffer_pool: Arc::new(std::sync::Mutex::new(EnhancedGpuBufferPool::new())),
            shader_cache: Arc::new(std::sync::Mutex::new(HashMap::new())),
            creation_time: Instant::now(),

            multi_gpu_enabled: true,
            gpu_devices: Arc::new(RwLock::new(gpu_devices)),
            load_balancer,
            workload_coordinator: Arc::new(WorkloadCoordinator::new()),
            primary_device_id,
        })
    }

    pub fn device(&self) -> &wgpu::Device {
        &self.device
    }

    pub fn queue(&self) -> &wgpu::Queue {
        &self.queue
    }

    pub fn adapter_info(&self) -> &wgpu::AdapterInfo {
        &self.adapter_info
    }

    pub fn get_buffer(
        &self,
        size: u64,
        usage: wgpu::BufferUsages,
        label: Option<&str>,
    ) -> wgpu::Buffer {
        if let Ok(mut pool) = self.buffer_pool.lock() {
            pool.get_or_create(&self.device, size, usage, label)
        } else {
            // Pool lock poisoned: fall back to a direct allocation.
            self.device.create_buffer(&wgpu::BufferDescriptor {
                label,
                size,
                usage,
                mapped_at_creation: false,
            })
        }
    }

    pub fn return_buffer(&self, buffer: wgpu::Buffer, size: u64, usage: wgpu::BufferUsages) {
        if let Ok(mut pool) = self.buffer_pool.lock() {
            pool.return_buffer(buffer, size, usage);
        }
    }

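    /// Returns a cached compute pipeline, compiling and caching it on first
    /// use. Pipelines are built against a fixed two-binding layout: a
    /// read-only storage buffer at binding 0 and a read-write storage buffer
    /// at binding 1, both visible to the compute stage.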
    pub fn get_compute_pipeline(
        &self,
        shader_key: &str,
        shader_source: &str,
        entry_point: &str,
    ) -> UnifiedGpuResult<Arc<wgpu::ComputePipeline>> {
        let cache_key = format!("{}:{}", shader_key, entry_point);

        if let Ok(mut cache) = self.shader_cache.lock() {
            if let Some(pipeline) = cache.get(&cache_key) {
                return Ok(Arc::clone(pipeline));
            }

            let shader_module = self
                .device
                .create_shader_module(wgpu::ShaderModuleDescriptor {
                    label: Some(&format!("{} Shader", shader_key)),
                    source: wgpu::ShaderSource::Wgsl(shader_source.into()),
                });

            let bind_group_layout =
                self.device
                    .create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
                        label: Some(&format!("{} Bind Group Layout", shader_key)),
                        entries: &[
                            wgpu::BindGroupLayoutEntry {
                                binding: 0,
                                visibility: wgpu::ShaderStages::COMPUTE,
                                ty: wgpu::BindingType::Buffer {
                                    ty: wgpu::BufferBindingType::Storage { read_only: true },
                                    has_dynamic_offset: false,
                                    min_binding_size: None,
                                },
                                count: None,
                            },
                            wgpu::BindGroupLayoutEntry {
                                binding: 1,
                                visibility: wgpu::ShaderStages::COMPUTE,
                                ty: wgpu::BindingType::Buffer {
                                    ty: wgpu::BufferBindingType::Storage { read_only: false },
                                    has_dynamic_offset: false,
                                    min_binding_size: None,
                                },
                                count: None,
                            },
                        ],
                    });

            let pipeline_layout =
                self.device
                    .create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
                        label: Some(&format!("{} Pipeline Layout", shader_key)),
                        bind_group_layouts: &[&bind_group_layout],
                        push_constant_ranges: &[],
                    });

            let pipeline = self
                .device
                .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
                    label: Some(&format!("{} Pipeline", shader_key)),
                    layout: Some(&pipeline_layout),
                    module: &shader_module,
                    entry_point,
                });

            let pipeline_arc = Arc::new(pipeline);
            cache.insert(cache_key, Arc::clone(&pipeline_arc));
            Ok(pipeline_arc)
        } else {
            Err(UnifiedGpuError::InvalidOperation(
                "Failed to access shader cache".into(),
            ))
        }
    }

    pub fn buffer_pool_stats(&self) -> BufferPoolStats {
        if let Ok(pool) = self.buffer_pool.lock() {
            pool.get_stats()
        } else {
            BufferPoolStats::default()
        }
    }

    pub fn uptime(&self) -> std::time::Duration {
        self.creation_time.elapsed()
    }

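    /// Suggests workgroup dimensions for a named operation category: 2D
    /// tiles for matrix and cellular-automata work, size-scaled 1D groups
    /// for vector-style operations, and (64, 1, 1) for anything unknown.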
    pub fn get_optimal_workgroup(&self, operation: &str, data_size: usize) -> (u32, u32, u32) {
        match operation {
            "matrix_multiply" | "matrix_operation" => (16, 16, 1),
            "vector_operation" | "reduce" | "scan" => {
                let workgroup_size = if data_size > 10000 {
                    256
                } else if data_size > 1000 {
                    128
                } else {
                    64
                };
                (workgroup_size, 1, 1)
            }
            "geometric_algebra" | "clifford_algebra" => (128, 1, 1),
            "cellular_automata" | "ca_evolution" => (16, 16, 1),
            "neural_network" | "batch_processing" => (256, 1, 1),
            "information_geometry" | "fisher_information" | "bregman_divergence" => (256, 1, 1),
            "tropical_algebra" | "tropical_matrix" => (128, 1, 1),
            "dual_number" | "automatic_differentiation" => (128, 1, 1),
            "fusion_system" | "llm_evaluation" => (256, 1, 1),
            "enumerative_geometry" | "intersection_theory" => (64, 1, 1),
            _ => (64, 1, 1),
        }
    }

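    /// Builds the WGSL `@compute @workgroup_size(...)` attribute for an
    /// operation, omitting trailing dimensions that equal 1.
    ///
    /// ```ignore
    /// // Illustrative: `ctx` is a SharedGpuContext.
    /// let decl = ctx.get_workgroup_declaration("vector_operation", 50_000);
    /// assert_eq!(decl, "@compute @workgroup_size(256)");
    /// ```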
    pub fn get_workgroup_declaration(&self, operation: &str, data_size: usize) -> String {
        let (x, y, z) = self.get_optimal_workgroup(operation, data_size);

        if y == 1 && z == 1 {
            format!("@compute @workgroup_size({})", x)
        } else if z == 1 {
            format!("@compute @workgroup_size({}, {})", x, y)
        } else {
            format!("@compute @workgroup_size({}, {}, {})", x, y, z)
        }
    }

    pub fn is_multi_gpu_enabled(&self) -> bool {
        self.multi_gpu_enabled
    }

    pub async fn device_count(&self) -> usize {
        self.gpu_devices.read().await.len()
    }

    pub async fn get_device_info(&self) -> Vec<(DeviceId, String, String)> {
        let devices = self.gpu_devices.read().await;
        devices
            .iter()
            .map(|(id, device)| {
                (
                    *id,
                    device.adapter_info.name.clone(),
                    format!("{:?}", device.capabilities.architecture),
                )
            })
            .collect()
    }

    pub async fn get_device(&self, device_id: DeviceId) -> Option<Arc<GpuDevice>> {
        let devices = self.gpu_devices.read().await;
        devices.get(&device_id).cloned()
    }

    pub async fn optimal_device_for_operation(
        &self,
        operation: &str,
        _data_size: usize,
    ) -> DeviceId {
        if !self.multi_gpu_enabled {
            return self.primary_device_id;
        }

        let devices = self.gpu_devices.read().await;
        let available_devices: Vec<_> = devices
            .values()
            .filter(|device| device.is_available())
            .collect();

        if available_devices.is_empty() {
            return self.primary_device_id;
        }

        available_devices
            .iter()
            .max_by(|a, b| {
                a.performance_score(operation)
                    .partial_cmp(&b.performance_score(operation))
                    .unwrap_or(std::cmp::Ordering::Equal)
            })
            .map(|device| device.id)
            .unwrap_or(self.primary_device_id)
    }

    pub async fn distribute_workload(
        &self,
        workload: Workload,
    ) -> UnifiedGpuResult<Vec<crate::multi_gpu::DeviceWorkload>> {
        if !self.multi_gpu_enabled {
            return Ok(vec![crate::multi_gpu::DeviceWorkload {
                device_id: self.primary_device_id,
                workload_fraction: 1.0,
                data_range: (0, workload.data_size),
                estimated_completion_ms: 100.0,
                memory_requirement_mb: workload.memory_requirement_mb,
            }]);
        }

        self.load_balancer
            .distribute_workload(&workload)
            .await
            .map_err(|e| {
                UnifiedGpuError::InvalidOperation(format!("Workload distribution failed: {:?}", e))
            })
    }

    pub async fn execute_multi_gpu_workload(
        &self,
        workload_id: String,
        workload: Workload,
    ) -> UnifiedGpuResult<Vec<Vec<u8>>> {
        if !self.multi_gpu_enabled {
            return Err(UnifiedGpuError::InvalidOperation(
                "Multi-GPU mode not enabled".into(),
            ));
        }

        let assignments = self.distribute_workload(workload).await?;

        self.workload_coordinator
            .submit_workload(workload_id.clone(), assignments)
            .await
            .map_err(|e| {
                UnifiedGpuError::InvalidOperation(format!("Workload submission failed: {:?}", e))
            })?;

        let timeout = std::time::Duration::from_secs(30);
        self.workload_coordinator
            .wait_for_completion(&workload_id, timeout)
            .await
            .map_err(|e| {
                UnifiedGpuError::InvalidOperation(format!("Workload execution failed: {:?}", e))
            })
    }

    pub async fn get_gpu_utilization(&self) -> HashMap<DeviceId, f32> {
        let devices = self.gpu_devices.read().await;
        devices
            .iter()
            .map(|(id, device)| (*id, device.current_load()))
            .collect()
    }

    pub async fn get_multi_gpu_stats(&self) -> MultiGpuStats {
        let devices = self.gpu_devices.read().await;
        let device_count = devices.len();

        let total_operations: usize = devices
            .values()
            .map(|device| {
                device
                    .total_operations
                    .load(std::sync::atomic::Ordering::Relaxed)
            })
            .sum();

        let total_errors: usize = devices
            .values()
            .map(|device| {
                device
                    .error_count
                    .load(std::sync::atomic::Ordering::Relaxed)
            })
            .sum();

        let avg_utilization = if !devices.is_empty() {
            devices
                .values()
                .map(|device| device.current_load())
                .sum::<f32>()
                / devices.len() as f32
        } else {
            0.0
        };

        MultiGpuStats {
            device_count,
            total_operations,
            total_errors,
            avg_utilization_percent: avg_utilization,
            uptime: self.creation_time.elapsed(),
        }
    }

    pub async fn set_load_balancing_strategy(
        &self,
        _strategy: LoadBalancingStrategy,
    ) -> UnifiedGpuResult<()> {
        if !self.multi_gpu_enabled {
            return Err(UnifiedGpuError::InvalidOperation(
                "Multi-GPU mode not enabled".into(),
            ));
        }

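        // Note: the strategy is currently fixed at construction time; this
        // method validates multi-GPU mode but does not swap strategies.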
        Ok(())
    }

    pub async fn add_gpu_device(&self, device: Arc<GpuDevice>) -> UnifiedGpuResult<()> {
        if !self.multi_gpu_enabled {
            return Err(UnifiedGpuError::InvalidOperation(
                "Multi-GPU mode not enabled".into(),
            ));
        }

        let mut devices = self.gpu_devices.write().await;
        devices.insert(device.id, Arc::clone(&device));

        self.load_balancer.add_device(device).await;

        Ok(())
    }

    pub async fn remove_gpu_device(&self, device_id: DeviceId) -> UnifiedGpuResult<()> {
        if !self.multi_gpu_enabled {
            return Err(UnifiedGpuError::InvalidOperation(
                "Multi-GPU mode not enabled".into(),
            ));
        }

        let mut devices = self.gpu_devices.write().await;
        devices.remove(&device_id);

        self.load_balancer.remove_device(device_id).await;

        Ok(())
    }
}

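/// Buffer pool keyed by `(size, usage)`, with per-key reuse statistics and
/// periodic cleanup of entries that have not been used recently.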
pub struct EnhancedGpuBufferPool {
    pools: HashMap<(u64, wgpu::BufferUsages), Vec<wgpu::Buffer>>,
    stats: HashMap<(u64, wgpu::BufferUsages), PoolEntryStats>,
    total_created: u64,
    total_reused: u64,
    last_cleanup: Instant,
}

#[derive(Debug, Clone, Default)]
pub struct PoolEntryStats {
    pub created_count: u64,
    pub reused_count: u64,
    pub last_used: Option<Instant>,
    pub total_size_bytes: u64,
}

#[derive(Debug, Clone, Default)]
pub struct BufferPoolStats {
    pub total_buffers_created: u64,
    pub total_buffers_reused: u64,
    pub current_pooled_count: usize,
    pub total_pooled_memory_mb: f32,
    pub hit_rate_percent: f32,
}

impl EnhancedGpuBufferPool {
    pub fn new() -> Self {
        Self {
            pools: HashMap::new(),
            stats: HashMap::new(),
            total_created: 0,
            total_reused: 0,
            last_cleanup: Instant::now(),
        }
    }
}

impl Default for EnhancedGpuBufferPool {
    fn default() -> Self {
        Self::new()
    }
}

impl EnhancedGpuBufferPool {
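    /// Pops a pooled buffer matching `(size, usage)` if one is available;
    /// otherwise creates a new buffer, records the allocation, and runs a
    /// cleanup pass if more than 30 seconds have passed since the last one.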
    pub fn get_or_create(
        &mut self,
        device: &wgpu::Device,
        size: u64,
        usage: wgpu::BufferUsages,
        label: Option<&str>,
    ) -> wgpu::Buffer {
        let key = (size, usage);

        if let Some(buffers) = self.pools.get_mut(&key) {
            if let Some(buffer) = buffers.pop() {
                self.total_reused += 1;
                let stats = self.stats.entry(key).or_default();
                stats.reused_count += 1;
                stats.last_used = Some(Instant::now());
                return buffer;
            }
        }

        let buffer = device.create_buffer(&wgpu::BufferDescriptor {
            label,
            size,
            usage,
            mapped_at_creation: false,
        });

        self.total_created += 1;
        let stats = self.stats.entry(key).or_default();
        stats.created_count += 1;
        stats.total_size_bytes += size;
        stats.last_used = Some(Instant::now());

        if self.last_cleanup.elapsed().as_secs() > 30 {
            self.cleanup_old_buffers();
        }

        buffer
    }

    pub fn return_buffer(&mut self, buffer: wgpu::Buffer, size: u64, usage: wgpu::BufferUsages) {
        let key = (size, usage);
        self.pools.entry(key).or_default().push(buffer);
    }

    pub fn get_stats(&self) -> BufferPoolStats {
        let total_ops = self.total_created + self.total_reused;
        let hit_rate = if total_ops > 0 {
            (self.total_reused as f32 / total_ops as f32) * 100.0
        } else {
            0.0
        };

        let current_pooled_count = self.pools.values().map(|v| v.len()).sum();
        let total_pooled_memory_mb: f32 = self
            .pools
            .iter()
            .map(|((size, _usage), buffers)| {
                (*size as f32 * buffers.len() as f32) / 1024.0 / 1024.0
            })
            .sum();

        BufferPoolStats {
            total_buffers_created: self.total_created,
            total_buffers_reused: self.total_reused,
            current_pooled_count,
            total_pooled_memory_mb,
            hit_rate_percent: hit_rate,
        }
    }

    fn cleanup_old_buffers(&mut self) {
        let now = Instant::now();
        let cleanup_threshold = std::time::Duration::from_secs(300);

        // Drop pools whose buffers have not been used within the threshold.
        self.pools.retain(|&key, buffers| {
            if let Some(stats) = self.stats.get(&key) {
                if let Some(last_used) = stats.last_used {
                    if now.duration_since(last_used) > cleanup_threshold {
                        buffers.clear();
                        return false;
                    }
                }
            }
            true
        });

        self.last_cleanup = now;
    }
}

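/// Dispatches work to the GPU when one is available and the workload is
/// large enough to amortize transfer overhead, otherwise falling back to a
/// CPU closure.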
pub struct GpuDispatcher {
    gpu_context: Option<GpuContext>,
    cpu_threshold: usize,
    gpu_threshold: usize,
}

impl GpuDispatcher {
    pub async fn new() -> UnifiedGpuResult<Self> {
        // The GPU context is optional: if initialization fails, every
        // dispatch falls back to the CPU path.
        let gpu_context = GpuContext::new().await.ok();

        Ok(Self {
            gpu_context,
            cpu_threshold: 100,
            gpu_threshold: 1000,
        })
    }

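    /// Returns true when a GPU context exists and `workload_size` clears
    /// both the CPU and GPU thresholds.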
    pub fn should_use_gpu(&self, workload_size: usize) -> bool {
        self.gpu_context.is_some()
            && workload_size >= self.cpu_threshold
            && workload_size >= self.gpu_threshold
    }

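    /// Runs `gpu_op` when the workload is large enough and a GPU is
    /// available, silently falling back to `cpu_op` if the GPU path fails.
    ///
    /// A sketch of the intended usage (the fence is `ignore` because it
    /// needs GPU hardware; `dispatcher` is a mutable `GpuDispatcher`):
    ///
    /// ```ignore
    /// let data = vec![1.0f32; 5000];
    /// let doubled = dispatcher
    ///     .execute(
    ///         data.len(),
    ///         |_ctx| Ok(data.iter().map(|x| x * 2.0).collect::<Vec<f32>>()),
    ///         || data.iter().map(|x| x * 2.0).collect::<Vec<f32>>(),
    ///     )
    ///     .await;
    /// ```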
    pub async fn execute<T, F, G>(&mut self, workload_size: usize, gpu_op: G, cpu_op: F) -> T
    where
        F: FnOnce() -> T,
        G: FnOnce(&mut GpuContext) -> UnifiedGpuResult<T>,
    {
        if self.should_use_gpu(workload_size) {
            if let Some(ref mut ctx) = self.gpu_context {
                if let Ok(result) = gpu_op(ctx) {
                    return result;
                }
            }
        }

        // Fall back to the CPU implementation.
        cpu_op()
    }
}

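/// Aggregate statistics across all registered GPU devices.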
#[derive(Debug, Clone)]
pub struct MultiGpuStats {
    pub device_count: usize,
    pub total_operations: usize,
    pub total_errors: usize,
    pub avg_utilization_percent: f32,
    pub uptime: std::time::Duration,
}

#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    #[ignore = "GPU hardware required, may fail in CI/CD environments"]
    async fn test_gpu_context_creation() {
        // Only checks that context creation does not panic; success depends
        // on the available hardware.
        let _result = GpuContext::new().await;
    }

    #[tokio::test]
    #[ignore = "GPU hardware required, may fail in CI/CD environments"]
    async fn test_gpu_dispatcher() {
        let dispatcher = GpuDispatcher::new().await;
        assert!(dispatcher.is_ok());
    }

    #[test]
    fn test_gpu_operation_params() {
        let mut params = GpuOperationParams::default();
        params
            .params
            .insert("scale".to_string(), GpuParam::Float(2.0));
        params.batch_size = 100;

        assert_eq!(params.batch_size, 100);
        match params.params.get("scale") {
            Some(GpuParam::Float(val)) => assert_eq!(*val, 2.0),
            _ => panic!("Expected float parameter"),
        }
    }
}