// amari_gpu/unified.rs

//! Unified GPU acceleration infrastructure for all mathematical domains
//!
//! This module provides a common interface and infrastructure for GPU acceleration
//! across tropical algebra, automatic differentiation, fusion systems, and other
//! mathematical domains in the Amari library.

use crate::{
    multi_gpu::{
        DeviceId, GpuDevice, IntelligentLoadBalancer, LoadBalancingStrategy, Workload,
        WorkloadCoordinator,
    },
    GpuError,
};
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Instant;
use thiserror::Error;
use tokio::sync::RwLock;
use wgpu::util::DeviceExt;

#[derive(Error, Debug)]
pub enum UnifiedGpuError {
    #[error("GPU error: {0}")]
    Gpu(#[from] GpuError),

    #[error("Shader compilation failed: {0}")]
    ShaderCompilation(String),

    #[error("Buffer size mismatch: expected {expected}, got {actual}")]
    BufferSizeMismatch { expected: usize, actual: usize },

    #[error("Invalid operation: {0}")]
    InvalidOperation(String),

    #[error("Memory allocation failed: {0}")]
    MemoryAllocation(String),
}

pub type UnifiedGpuResult<T> = Result<T, UnifiedGpuError>;

/// Universal trait for GPU-accelerated mathematical operations
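///
/// # Example
///
/// A minimal sketch of an implementation for a hypothetical wrapper type;
/// `ScalarField`, its `values` field, and the buffer layout are illustrative
/// and not part of this crate:
///
/// ```ignore
/// struct ScalarField {
///     values: Vec<f32>,
/// }
///
/// impl GpuAccelerated<ScalarField> for ScalarField {
///     fn to_gpu_buffer(&self, context: &GpuContext) -> UnifiedGpuResult<wgpu::Buffer> {
///         // Upload the raw f32 data as a storage buffer.
///         Ok(context.create_buffer_with_data(
///             "scalar_field",
///             &self.values,
///             wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
///         ))
///     }
///
///     fn from_gpu_buffer(_buffer: &wgpu::Buffer, _context: &GpuContext) -> UnifiedGpuResult<ScalarField> {
///         // A real implementation would read the buffer back asynchronously
///         // via `GpuContext::read_buffer` and rebuild the type from the bytes.
///         todo!()
///     }
///
///     fn gpu_operation(
///         &self,
///         operation: &str,
///         _context: &GpuContext,
///         _params: &GpuOperationParams,
///     ) -> UnifiedGpuResult<ScalarField> {
///         match operation {
///             "scale" => todo!("dispatch a compute shader"),
///             other => Err(UnifiedGpuError::InvalidOperation(format!(
///                 "unknown operation: {}",
///                 other
///             ))),
///         }
///     }
/// }
/// ```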
pub trait GpuAccelerated<T> {
    /// Convert data to GPU buffer format
    fn to_gpu_buffer(&self, context: &GpuContext) -> UnifiedGpuResult<wgpu::Buffer>;

    /// Reconstruct data from GPU buffer
    fn from_gpu_buffer(buffer: &wgpu::Buffer, context: &GpuContext) -> UnifiedGpuResult<T>;

    /// Execute GPU operation
    fn gpu_operation(
        &self,
        operation: &str,
        context: &GpuContext,
        params: &GpuOperationParams,
    ) -> UnifiedGpuResult<T>;
}

/// GPU operation parameters for flexible operation dispatch
#[derive(Debug, Clone)]
pub struct GpuOperationParams {
    /// Operation-specific parameters
    pub params: HashMap<String, GpuParam>,
    /// Batch size for operations
    pub batch_size: usize,
    /// Workgroup size for compute shaders
    pub workgroup_size: (u32, u32, u32),
}

/// Parameter types for GPU operations
#[derive(Debug, Clone)]
pub enum GpuParam {
    Float(f32),
    Double(f64),
    Integer(i32),
    UnsignedInteger(u32),
    Buffer(String), // Buffer identifier
    Array(Vec<f32>),
}

impl Default for GpuOperationParams {
    fn default() -> Self {
        Self {
            params: HashMap::new(),
            batch_size: 1,
            workgroup_size: (1, 1, 1),
        }
    }
}

/// Unified GPU context managing device, queue, and shader cache
pub struct GpuContext {
    pub device: wgpu::Device,
    pub queue: wgpu::Queue,
    shader_cache: HashMap<String, wgpu::ComputePipeline>,
    #[allow(dead_code)]
    buffer_pool: GpuBufferPool,
}

impl GpuContext {
    /// Initialize GPU context with WebGPU
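    ///
    /// # Example
    ///
    /// A minimal sketch, assuming this module is exported as
    /// `amari_gpu::unified`, an async runtime such as `tokio`, and a machine
    /// with a usable GPU adapter:
    ///
    /// ```no_run
    /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
    /// use amari_gpu::unified::GpuContext;
    ///
    /// // Fails with `UnifiedGpuError` if no adapter or device is available.
    /// let _context = GpuContext::new().await?;
    /// # Ok(())
    /// # }
    /// ```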
    pub async fn new() -> UnifiedGpuResult<Self> {
        let instance = wgpu::Instance::default();

        let adapter = instance
            .request_adapter(&wgpu::RequestAdapterOptions {
                power_preference: wgpu::PowerPreference::HighPerformance,
                compatible_surface: None,
                force_fallback_adapter: false,
            })
            .await
            .ok_or_else(|| {
                UnifiedGpuError::Gpu(GpuError::InitializationError(
                    "No GPU adapter found".to_string(),
                ))
            })?;

        let (device, queue) = adapter
            .request_device(
                &wgpu::DeviceDescriptor {
                    label: Some("Amari Unified GPU Device"),
                    required_features: wgpu::Features::empty(),
                    required_limits: wgpu::Limits::default(),
                },
                None,
            )
            .await
            .map_err(|e| UnifiedGpuError::Gpu(GpuError::InitializationError(e.to_string())))?;

        Ok(Self {
            device,
            queue,
            shader_cache: HashMap::new(),
            buffer_pool: GpuBufferPool::new(),
        })
    }

    /// Get a cached compute pipeline, compiling the shader on first use
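    ///
    /// # Example
    ///
    /// A minimal sketch of compiling (and caching) a trivial WGSL kernel; the
    /// shader source, label, and layout here are illustrative:
    ///
    /// ```no_run
    /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
    /// use amari_gpu::unified::GpuContext;
    ///
    /// let mut context = GpuContext::new().await?;
    ///
    /// // One read-write storage buffer at binding 0.
    /// let layout = context
    ///     .device
    ///     .create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
    ///         label: Some("Doubler Layout"),
    ///         entries: &[wgpu::BindGroupLayoutEntry {
    ///             binding: 0,
    ///             visibility: wgpu::ShaderStages::COMPUTE,
    ///             ty: wgpu::BindingType::Buffer {
    ///                 ty: wgpu::BufferBindingType::Storage { read_only: false },
    ///                 has_dynamic_offset: false,
    ///                 min_binding_size: None,
    ///             },
    ///             count: None,
    ///         }],
    ///     });
    ///
    /// // The entry point must be `main`, which this method hard-codes.
    /// let shader = r#"
    ///     @group(0) @binding(0) var<storage, read_write> data: array<f32>;
    ///     @compute @workgroup_size(64)
    ///     fn main(@builtin(global_invocation_id) id: vec3<u32>) {
    ///         data[id.x] = data[id.x] * 2.0;
    ///     }
    /// "#;
    ///
    /// // First call compiles; later calls with the same key hit the cache.
    /// let _pipeline = context.get_compute_pipeline("doubler", shader, &layout)?;
    /// # Ok(())
    /// # }
    /// ```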
    pub fn get_compute_pipeline(
        &mut self,
        shader_key: &str,
        shader_source: &str,
        bind_group_layout: &wgpu::BindGroupLayout,
    ) -> UnifiedGpuResult<&wgpu::ComputePipeline> {
        if !self.shader_cache.contains_key(shader_key) {
            let shader_module = self
                .device
                .create_shader_module(wgpu::ShaderModuleDescriptor {
                    label: Some(&format!("{} Shader", shader_key)),
                    source: wgpu::ShaderSource::Wgsl(shader_source.into()),
                });

            let pipeline_layout =
                self.device
                    .create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
                        label: Some(&format!("{} Pipeline Layout", shader_key)),
                        bind_group_layouts: &[bind_group_layout],
                        push_constant_ranges: &[],
                    });

            let compute_pipeline =
                self.device
                    .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
                        label: Some(&format!("{} Pipeline", shader_key)),
                        layout: Some(&pipeline_layout),
                        module: &shader_module,
                        entry_point: "main",
                    });

            self.shader_cache
                .insert(shader_key.to_string(), compute_pipeline);
        }

        Ok(self
            .shader_cache
            .get(shader_key)
            .expect("Pipeline should exist"))
    }

    /// Create buffer with data
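    ///
    /// # Example
    ///
    /// A minimal sketch uploading a slice of `f32` values (label and usage
    /// flags are illustrative):
    ///
    /// ```no_run
    /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
    /// use amari_gpu::unified::GpuContext;
    ///
    /// let context = GpuContext::new().await?;
    /// let data = vec![1.0_f32, 2.0, 3.0, 4.0];
    /// let _buffer = context.create_buffer_with_data(
    ///     "input",
    ///     &data,
    ///     wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
    /// );
    /// # Ok(())
    /// # }
    /// ```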
    pub fn create_buffer_with_data<T: bytemuck::Pod>(
        &self,
        label: &str,
        data: &[T],
        usage: wgpu::BufferUsages,
    ) -> wgpu::Buffer {
        self.device
            .create_buffer_init(&wgpu::util::BufferInitDescriptor {
                label: Some(label),
                contents: bytemuck::cast_slice(data),
                usage,
            })
    }

    /// Create empty buffer
    pub fn create_buffer(&self, label: &str, size: u64, usage: wgpu::BufferUsages) -> wgpu::Buffer {
        self.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some(label),
            size,
            usage,
            mapped_at_creation: false,
        })
    }

    /// Execute compute shader
    pub fn execute_compute(
        &self,
        pipeline: &wgpu::ComputePipeline,
        bind_group: &wgpu::BindGroup,
        workgroup_count: (u32, u32, u32),
    ) {
        let mut encoder = self
            .device
            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
                label: Some("Compute Encoder"),
            });

        {
            let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
                label: Some("Compute Pass"),
                timestamp_writes: None,
            });
            compute_pass.set_pipeline(pipeline);
            compute_pass.set_bind_group(0, bind_group, &[]);
            compute_pass.dispatch_workgroups(
                workgroup_count.0,
                workgroup_count.1,
                workgroup_count.2,
            );
        }

        self.queue.submit([encoder.finish()]);
    }

    /// Read buffer data back to CPU
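    ///
    /// # Example
    ///
    /// A minimal sketch of the upload / read-back round trip (no compute pass
    /// in between, for brevity); the source buffer needs `COPY_SRC` so it can
    /// be copied to the internal staging buffer:
    ///
    /// ```no_run
    /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
    /// use amari_gpu::unified::GpuContext;
    ///
    /// let context = GpuContext::new().await?;
    /// let input = vec![1.0_f32, 2.0, 3.0, 4.0];
    /// let buffer = context.create_buffer_with_data(
    ///     "roundtrip",
    ///     &input,
    ///     wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
    /// );
    ///
    /// let size = (input.len() * std::mem::size_of::<f32>()) as u64;
    /// let output: Vec<f32> = context.read_buffer(&buffer, size).await?;
    /// assert_eq!(input, output);
    /// # Ok(())
    /// # }
    /// ```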
    pub async fn read_buffer<T: bytemuck::Pod + Clone>(
        &self,
        buffer: &wgpu::Buffer,
        size: u64,
    ) -> UnifiedGpuResult<Vec<T>> {
        let staging_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("Staging Buffer"),
            size,
            usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
            mapped_at_creation: false,
        });

        let mut encoder = self
            .device
            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
                label: Some("Copy Encoder"),
            });

        encoder.copy_buffer_to_buffer(buffer, 0, &staging_buffer, 0, size);
        self.queue.submit([encoder.finish()]);

        let buffer_slice = staging_buffer.slice(..);
        let (tx, rx) = futures::channel::oneshot::channel();
        buffer_slice.map_async(wgpu::MapMode::Read, move |result| {
            tx.send(result).ok();
        });

        self.device.poll(wgpu::Maintain::Wait);

        rx.await
            .map_err(|_| {
                UnifiedGpuError::InvalidOperation("Buffer map result channel closed".to_string())
            })?
            .map_err(|e| UnifiedGpuError::InvalidOperation(format!("Buffer map failed: {}", e)))?;

        let data = buffer_slice.get_mapped_range();
        let result: Vec<T> = bytemuck::cast_slice(&data).to_vec();
        drop(data);
        staging_buffer.unmap();

        Ok(result)
    }
}

/// GPU buffer pool for efficient memory management
pub struct GpuBufferPool {
    _pools: HashMap<String, Vec<wgpu::Buffer>>, // Future: implement buffer pooling
}

impl GpuBufferPool {
    pub fn new() -> Self {
        Self {
            _pools: HashMap::new(),
        }
    }

    // Future: Add buffer pooling methods
    // pub fn get_buffer(&mut self, size: u64, usage: wgpu::BufferUsages) -> wgpu::Buffer
    // pub fn return_buffer(&mut self, buffer: wgpu::Buffer)
}

impl Default for GpuBufferPool {
    fn default() -> Self {
        Self::new()
    }
}

/// Shared GPU context for efficient resource management across all crates.
/// Supports both single-GPU (legacy) and multi-GPU operation.
#[derive(Clone)]
pub struct SharedGpuContext {
    // Legacy single-GPU support (primary device)
    device: Arc<wgpu::Device>,
    queue: Arc<wgpu::Queue>,
    adapter_info: wgpu::AdapterInfo,
    buffer_pool: Arc<std::sync::Mutex<EnhancedGpuBufferPool>>,
    shader_cache: Arc<std::sync::Mutex<HashMap<String, Arc<wgpu::ComputePipeline>>>>,
    creation_time: Instant,

    // Multi-GPU support (v0.9.6+)
    multi_gpu_enabled: bool,
    gpu_devices: Arc<RwLock<HashMap<DeviceId, Arc<GpuDevice>>>>,
    load_balancer: Arc<IntelligentLoadBalancer>,
    workload_coordinator: Arc<WorkloadCoordinator>,
    primary_device_id: DeviceId,
}

impl SharedGpuContext {
    /// Get the global shared GPU context (singleton pattern).
    ///
    /// Note: this currently creates and leaks a new context on each call; in
    /// production it would be a proper singleton with atomic initialization.
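    ///
    /// # Example
    ///
    /// A minimal sketch, assuming a `tokio` runtime and an available GPU:
    ///
    /// ```no_run
    /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
    /// use amari_gpu::unified::SharedGpuContext;
    ///
    /// let context = SharedGpuContext::global().await?;
    /// println!("adapter: {}", context.adapter_info().name);
    /// # Ok(())
    /// # }
    /// ```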
    pub async fn global() -> UnifiedGpuResult<&'static Self> {
        let context = Self::new().await?;
        // Leak the context to make it 'static - in production, this would be managed properly
        Ok(Box::leak(Box::new(context)))
    }

    /// Create a new shared GPU context (single GPU mode for backward compatibility)
    async fn new() -> UnifiedGpuResult<Self> {
        let instance = wgpu::Instance::new(wgpu::InstanceDescriptor {
            backends: wgpu::Backends::all(),
            flags: wgpu::InstanceFlags::default(),
            dx12_shader_compiler: wgpu::Dx12Compiler::default(),
            gles_minor_version: wgpu::Gles3MinorVersion::Automatic,
        });

        let adapter = instance
            .request_adapter(&wgpu::RequestAdapterOptions {
                power_preference: wgpu::PowerPreference::HighPerformance,
                compatible_surface: None,
                force_fallback_adapter: false,
            })
            .await
            .ok_or_else(|| {
                UnifiedGpuError::InvalidOperation("No suitable GPU adapter found".into())
            })?;

        let adapter_info = adapter.get_info();

        let (device, queue) = adapter
            .request_device(
                &wgpu::DeviceDescriptor {
                    label: Some("Shared Amari GPU Device"),
                    required_features: wgpu::Features::TIMESTAMP_QUERY,
                    required_limits: wgpu::Limits::default(),
                },
                None,
            )
            .await
            .map_err(|e| {
                UnifiedGpuError::InvalidOperation(format!("Device request failed: {:?}", e))
            })?;

        let primary_device_id = DeviceId(0);

        // Create single-GPU device for multi-GPU compatibility
        let gpu_device = Arc::new(
            GpuDevice::new(primary_device_id, &adapter, device, queue)
                .await
                .map_err(|_| {
                    UnifiedGpuError::InvalidOperation("Failed to create GPU device".into())
                })?,
        );

        let device_arc = Arc::clone(&gpu_device.device);
        let queue_arc = Arc::clone(&gpu_device.queue);

        let mut gpu_devices = HashMap::new();
        gpu_devices.insert(primary_device_id, gpu_device);

        Ok(Self {
            device: device_arc,
            queue: queue_arc,
            adapter_info,
            buffer_pool: Arc::new(std::sync::Mutex::new(EnhancedGpuBufferPool::new())),
            shader_cache: Arc::new(std::sync::Mutex::new(HashMap::new())),
            creation_time: Instant::now(),

            // Multi-GPU fields (initially single-GPU mode)
            multi_gpu_enabled: false,
            gpu_devices: Arc::new(RwLock::new(gpu_devices)),
            load_balancer: Arc::new(IntelligentLoadBalancer::new(
                LoadBalancingStrategy::Balanced,
            )),
            workload_coordinator: Arc::new(WorkloadCoordinator::new()),
            primary_device_id,
        })
    }

    /// Create a new shared GPU context with multi-GPU support
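    ///
    /// # Example
    ///
    /// A minimal sketch enumerating whatever devices the host exposes:
    ///
    /// ```no_run
    /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
    /// use amari_gpu::unified::SharedGpuContext;
    ///
    /// let context = SharedGpuContext::with_multi_gpu().await?;
    /// assert!(context.is_multi_gpu_enabled());
    /// for (_id, name, arch) in context.get_device_info().await {
    ///     println!("{} ({})", name, arch);
    /// }
    /// # Ok(())
    /// # }
    /// ```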
    pub async fn with_multi_gpu() -> UnifiedGpuResult<Self> {
        let instance = wgpu::Instance::new(wgpu::InstanceDescriptor {
            backends: wgpu::Backends::all(),
            flags: wgpu::InstanceFlags::default(),
            dx12_shader_compiler: wgpu::Dx12Compiler::default(),
            gles_minor_version: wgpu::Gles3MinorVersion::Automatic,
        });

        // Enumerate all available adapters
        let adapters: Vec<_> = instance.enumerate_adapters(wgpu::Backends::all());

        if adapters.is_empty() {
            return Err(UnifiedGpuError::InvalidOperation(
                "No GPU adapters found".into(),
            ));
        }

        // Initialize devices from adapters
        let mut gpu_devices = HashMap::new();
        let mut primary_device = None;
        let mut primary_queue = None;
        let mut primary_adapter_info = None;
        let mut primary_device_id = None;

        for (i, adapter) in adapters.iter().enumerate() {
            let device_id = DeviceId(i);

            // Try to create device
            if let Ok((device, queue)) = adapter
                .request_device(
                    &wgpu::DeviceDescriptor {
                        label: Some(&format!("Amari Multi-GPU Device {}", i)),
                        required_features: wgpu::Features::TIMESTAMP_QUERY,
                        required_limits: wgpu::Limits::default(),
                    },
                    None,
                )
                .await
            {
                // Create GPU device wrapper
                if let Ok(gpu_device) = GpuDevice::new(device_id, adapter, device, queue).await {
                    // Set primary device (first successfully created device)
                    if primary_device.is_none() {
                        primary_device = Some(Arc::clone(&gpu_device.device));
                        primary_queue = Some(Arc::clone(&gpu_device.queue));
                        primary_adapter_info = Some(adapter.get_info());
                        primary_device_id = Some(device_id);
                    }

                    gpu_devices.insert(device_id, Arc::new(gpu_device));
                }
            }
        }

        if gpu_devices.is_empty() {
            return Err(UnifiedGpuError::InvalidOperation(
                "No usable GPU devices found".into(),
            ));
        }

        // The first adapter may have failed device creation, so use the ID
        // recorded for the first device that actually succeeded.
        let primary_device_id = primary_device_id.expect("at least one device was created");
        let load_balancer = Arc::new(IntelligentLoadBalancer::new(
            LoadBalancingStrategy::CapabilityAware,
        ));

        // Add all devices to load balancer
        for device in gpu_devices.values() {
            load_balancer.add_device(Arc::clone(device)).await;
        }

        Ok(Self {
            device: primary_device.unwrap(),
            queue: primary_queue.unwrap(),
            adapter_info: primary_adapter_info.unwrap(),
            buffer_pool: Arc::new(std::sync::Mutex::new(EnhancedGpuBufferPool::new())),
            shader_cache: Arc::new(std::sync::Mutex::new(HashMap::new())),
            creation_time: Instant::now(),

            // Multi-GPU configuration
            multi_gpu_enabled: true,
            gpu_devices: Arc::new(RwLock::new(gpu_devices)),
            load_balancer,
            workload_coordinator: Arc::new(WorkloadCoordinator::new()),
            primary_device_id,
        })
    }

    /// Get the device
    pub fn device(&self) -> &wgpu::Device {
        &self.device
    }

    /// Get the queue
    pub fn queue(&self) -> &wgpu::Queue {
        &self.queue
    }

    /// Get adapter info
    pub fn adapter_info(&self) -> &wgpu::AdapterInfo {
        &self.adapter_info
    }

    /// Get or create a buffer from the pool
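    ///
    /// # Example
    ///
    /// A minimal sketch of the pool's get / return cycle; returning the buffer
    /// with the same size and usage lets the next `get_buffer` call reuse it:
    ///
    /// ```no_run
    /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
    /// use amari_gpu::unified::SharedGpuContext;
    ///
    /// let context = SharedGpuContext::global().await?;
    /// let size = 1024;
    /// let usage = wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC;
    ///
    /// let buffer = context.get_buffer(size, usage, Some("scratch"));
    /// // ... use the buffer ...
    /// context.return_buffer(buffer, size, usage);
    /// # Ok(())
    /// # }
    /// ```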
    pub fn get_buffer(
        &self,
        size: u64,
        usage: wgpu::BufferUsages,
        label: Option<&str>,
    ) -> wgpu::Buffer {
        if let Ok(mut pool) = self.buffer_pool.lock() {
            pool.get_or_create(&self.device, size, usage, label)
        } else {
            // Fallback if mutex is poisoned
            self.device.create_buffer(&wgpu::BufferDescriptor {
                label,
                size,
                usage,
                mapped_at_creation: false,
            })
        }
    }

    /// Return a buffer to the pool for reuse
    pub fn return_buffer(&self, buffer: wgpu::Buffer, size: u64, usage: wgpu::BufferUsages) {
        if let Ok(mut pool) = self.buffer_pool.lock() {
            pool.return_buffer(buffer, size, usage);
        }
        // If mutex is poisoned, just drop the buffer
    }

    /// Get or create a compute pipeline from cache
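    ///
    /// # Example
    ///
    /// A minimal sketch; the pipeline is built against this method's fixed
    /// two-binding layout (read-only input at binding 0, read-write output at
    /// binding 1), so the shader must match that shape:
    ///
    /// ```no_run
    /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
    /// use amari_gpu::unified::SharedGpuContext;
    ///
    /// let context = SharedGpuContext::global().await?;
    /// let shader = r#"
    ///     @group(0) @binding(0) var<storage, read> in_data: array<f32>;
    ///     @group(0) @binding(1) var<storage, read_write> out_data: array<f32>;
    ///     @compute @workgroup_size(64)
    ///     fn main(@builtin(global_invocation_id) id: vec3<u32>) {
    ///         out_data[id.x] = in_data[id.x] + 1.0;
    ///     }
    /// "#;
    ///
    /// let first = context.get_compute_pipeline("add_one", shader, "main")?;
    /// let second = context.get_compute_pipeline("add_one", shader, "main")?;
    /// // The second call is a cache hit: both Arcs point at the same pipeline.
    /// assert!(std::sync::Arc::ptr_eq(&first, &second));
    /// # Ok(())
    /// # }
    /// ```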
    pub fn get_compute_pipeline(
        &self,
        shader_key: &str,
        shader_source: &str,
        entry_point: &str,
    ) -> UnifiedGpuResult<Arc<wgpu::ComputePipeline>> {
        let cache_key = format!("{}:{}", shader_key, entry_point);

        if let Ok(mut cache) = self.shader_cache.lock() {
            if let Some(pipeline) = cache.get(&cache_key) {
                return Ok(Arc::clone(pipeline));
            }

            // Create new pipeline
            let shader_module = self
                .device
                .create_shader_module(wgpu::ShaderModuleDescriptor {
                    label: Some(&format!("{} Shader", shader_key)),
                    source: wgpu::ShaderSource::Wgsl(shader_source.into()),
                });

            let bind_group_layout =
                self.device
                    .create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
                        label: Some(&format!("{} Bind Group Layout", shader_key)),
                        entries: &[
                            wgpu::BindGroupLayoutEntry {
                                binding: 0,
                                visibility: wgpu::ShaderStages::COMPUTE,
                                ty: wgpu::BindingType::Buffer {
                                    ty: wgpu::BufferBindingType::Storage { read_only: true },
                                    has_dynamic_offset: false,
                                    min_binding_size: None,
                                },
                                count: None,
                            },
                            wgpu::BindGroupLayoutEntry {
                                binding: 1,
                                visibility: wgpu::ShaderStages::COMPUTE,
                                ty: wgpu::BindingType::Buffer {
                                    ty: wgpu::BufferBindingType::Storage { read_only: false },
                                    has_dynamic_offset: false,
                                    min_binding_size: None,
                                },
                                count: None,
                            },
                        ],
                    });

            let pipeline_layout =
                self.device
                    .create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
                        label: Some(&format!("{} Pipeline Layout", shader_key)),
                        bind_group_layouts: &[&bind_group_layout],
                        push_constant_ranges: &[],
                    });

            let pipeline = self
                .device
                .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
                    label: Some(&format!("{} Pipeline", shader_key)),
                    layout: Some(&pipeline_layout),
                    module: &shader_module,
                    entry_point,
                });

            let pipeline_arc = Arc::new(pipeline);
            cache.insert(cache_key, Arc::clone(&pipeline_arc));
            Ok(pipeline_arc)
        } else {
            Err(UnifiedGpuError::InvalidOperation(
                "Failed to access shader cache".into(),
            ))
        }
    }

    /// Get buffer pool statistics
    pub fn buffer_pool_stats(&self) -> BufferPoolStats {
        if let Ok(pool) = self.buffer_pool.lock() {
            pool.get_stats()
        } else {
            BufferPoolStats::default()
        }
    }

    /// Get uptime of this context
    pub fn uptime(&self) -> std::time::Duration {
        self.creation_time.elapsed()
    }

    /// Get optimal workgroup configuration for given operation type and data size
    pub fn get_optimal_workgroup(&self, operation: &str, data_size: usize) -> (u32, u32, u32) {
        match operation {
            "matrix_multiply" | "matrix_operation" => {
                // 2D workgroups optimized for matrix operations
                // Use larger workgroups for better occupancy
                (16, 16, 1)
            }
            "vector_operation" | "reduce" | "scan" => {
                // 1D operations - prefer large workgroups for coalesced memory access
                let workgroup_size = if data_size > 10000 {
                    256 // Large batches benefit from maximum occupancy
                } else if data_size > 1000 {
                    128 // Medium batches
                } else {
                    64 // Small batches
                };
                (workgroup_size, 1, 1)
            }
            "geometric_algebra" | "clifford_algebra" => {
                // GA operations with moderate computational complexity
                (128, 1, 1)
            }
            "cellular_automata" | "ca_evolution" => {
                // 2D grid operations, optimized for spatial locality
                (16, 16, 1)
            }
            "neural_network" | "batch_processing" => {
                // Large 1D workgroups for high-throughput batch processing
                (256, 1, 1)
            }
            "information_geometry" | "fisher_information" | "bregman_divergence" => {
                // Statistical manifold computations - large workgroups
                (256, 1, 1)
            }
            "tropical_algebra" | "tropical_matrix" => {
                // Tropical operations, moderate workgroup size
                (128, 1, 1)
            }
            "dual_number" | "automatic_differentiation" => {
                // AD operations, balanced workgroup size
                (128, 1, 1)
            }
            "fusion_system" | "llm_evaluation" => {
                // Complex fusion operations, large workgroups
                (256, 1, 1)
            }
            "enumerative_geometry" | "intersection_theory" => {
                // Geometric computations, moderate workgroups
                (64, 1, 1)
            }
            _ => (64, 1, 1), // Conservative default for unknown operations
        }
    }

    /// Generate optimized WGSL workgroup declaration for operation
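    ///
    /// # Example
    ///
    /// A minimal sketch of splicing the generated attribute into a WGSL
    /// source string:
    ///
    /// ```no_run
    /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
    /// use amari_gpu::unified::SharedGpuContext;
    ///
    /// let context = SharedGpuContext::global().await?;
    /// // "vector_operation" with > 10000 elements selects a 1D workgroup of 256.
    /// let decl = context.get_workgroup_declaration("vector_operation", 50_000);
    /// assert_eq!(decl, "@compute @workgroup_size(256)");
    ///
    /// let _shader = format!(
    ///     "{decl}\nfn main(@builtin(global_invocation_id) id: vec3<u32>) {{ }}"
    /// );
    /// # Ok(())
    /// # }
    /// ```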
    pub fn get_workgroup_declaration(&self, operation: &str, data_size: usize) -> String {
        let (x, y, z) = self.get_optimal_workgroup(operation, data_size);

        if y == 1 && z == 1 {
            format!("@compute @workgroup_size({})", x)
        } else if z == 1 {
            format!("@compute @workgroup_size({}, {})", x, y)
        } else {
            format!("@compute @workgroup_size({}, {}, {})", x, y, z)
        }
    }

    // === Multi-GPU Methods (v0.9.6+) ===

    /// Check if multi-GPU mode is enabled
    pub fn is_multi_gpu_enabled(&self) -> bool {
        self.multi_gpu_enabled
    }

    /// Get the number of available GPU devices
    pub async fn device_count(&self) -> usize {
        self.gpu_devices.read().await.len()
    }

    /// Get information about all GPU devices
    pub async fn get_device_info(&self) -> Vec<(DeviceId, String, String)> {
        let devices = self.gpu_devices.read().await;
        devices
            .iter()
            .map(|(id, device)| {
                (
                    *id,
                    device.adapter_info.name.clone(),
                    format!("{:?}", device.capabilities.architecture),
                )
            })
            .collect()
    }

    /// Get a specific GPU device by ID
    pub async fn get_device(&self, device_id: DeviceId) -> Option<Arc<GpuDevice>> {
        let devices = self.gpu_devices.read().await;
        devices.get(&device_id).cloned()
    }

    /// Get the optimal device for a specific operation
    pub async fn optimal_device_for_operation(
        &self,
        operation: &str,
        _data_size: usize,
    ) -> DeviceId {
        if !self.multi_gpu_enabled {
            return self.primary_device_id;
        }

        let devices = self.gpu_devices.read().await;
        let available_devices: Vec<_> = devices
            .values()
            .filter(|device| device.is_available())
            .collect();

        if available_devices.is_empty() {
            return self.primary_device_id;
        }

        // Find device with best performance score for this operation
        available_devices
            .iter()
            .max_by(|a, b| {
                a.performance_score(operation)
                    .partial_cmp(&b.performance_score(operation))
                    .unwrap_or(std::cmp::Ordering::Equal)
            })
            .map(|device| device.id)
            .unwrap_or(self.primary_device_id)
    }

    /// Distribute a workload across multiple GPUs
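    ///
    /// # Example
    ///
    /// A minimal sketch; the `Workload` struct lives in the `multi_gpu`
    /// module, and its constructor and fields beyond `data_size` and
    /// `memory_requirement_mb` are assumed here:
    ///
    /// ```ignore
    /// let context = SharedGpuContext::with_multi_gpu().await?;
    /// // Hypothetical construction; check `multi_gpu::Workload` for the real fields.
    /// let workload = Workload {
    ///     data_size: 1_000_000,
    ///     memory_requirement_mb: 64.0,
    ///     ..Default::default()
    /// };
    ///
    /// for assignment in context.distribute_workload(workload).await? {
    ///     println!(
    ///         "{:?} takes {:.0}% of the data",
    ///         assignment.device_id,
    ///         assignment.workload_fraction * 100.0
    ///     );
    /// }
    /// ```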
    pub async fn distribute_workload(
        &self,
        workload: Workload,
    ) -> UnifiedGpuResult<Vec<crate::multi_gpu::DeviceWorkload>> {
        if !self.multi_gpu_enabled {
            // Single GPU fallback
            return Ok(vec![crate::multi_gpu::DeviceWorkload {
                device_id: self.primary_device_id,
                workload_fraction: 1.0,
                data_range: (0, workload.data_size),
                estimated_completion_ms: 100.0,
                memory_requirement_mb: workload.memory_requirement_mb,
            }]);
        }

        self.load_balancer
            .distribute_workload(&workload)
            .await
            .map_err(|e| {
                UnifiedGpuError::InvalidOperation(format!("Workload distribution failed: {:?}", e))
            })
    }

    /// Execute a workload on multiple GPUs and aggregate results
    pub async fn execute_multi_gpu_workload(
        &self,
        workload_id: String,
        workload: Workload,
    ) -> UnifiedGpuResult<Vec<Vec<u8>>> {
        if !self.multi_gpu_enabled {
            return Err(UnifiedGpuError::InvalidOperation(
                "Multi-GPU mode not enabled".into(),
            ));
        }

        // Distribute workload
        let assignments = self.distribute_workload(workload).await?;

        // Submit to coordinator
        self.workload_coordinator
            .submit_workload(workload_id.clone(), assignments)
            .await
            .map_err(|e| {
                UnifiedGpuError::InvalidOperation(format!("Workload submission failed: {:?}", e))
            })?;

        // Wait for completion (with timeout)
        let timeout = std::time::Duration::from_secs(30);
        self.workload_coordinator
            .wait_for_completion(&workload_id, timeout)
            .await
            .map_err(|e| {
                UnifiedGpuError::InvalidOperation(format!("Workload execution failed: {:?}", e))
            })
    }

    /// Get real-time GPU utilization across all devices
    pub async fn get_gpu_utilization(&self) -> HashMap<DeviceId, f32> {
        let devices = self.gpu_devices.read().await;
        devices
            .iter()
            .map(|(id, device)| (*id, device.current_load()))
            .collect()
    }

    /// Get performance statistics for multi-GPU operations
    pub async fn get_multi_gpu_stats(&self) -> MultiGpuStats {
        let devices = self.gpu_devices.read().await;
        let device_count = devices.len();

        let total_operations: usize = devices
            .values()
            .map(|device| {
                device
                    .total_operations
                    .load(std::sync::atomic::Ordering::Relaxed)
            })
            .sum();

        let total_errors: usize = devices
            .values()
            .map(|device| {
                device
                    .error_count
                    .load(std::sync::atomic::Ordering::Relaxed)
            })
            .sum();

        let avg_utilization = if !devices.is_empty() {
            devices
                .values()
                .map(|device| device.current_load())
                .sum::<f32>()
                / devices.len() as f32
        } else {
            0.0
        };

        MultiGpuStats {
            device_count,
            total_operations,
            total_errors,
            avg_utilization_percent: avg_utilization,
            uptime: self.creation_time.elapsed(),
        }
    }

    /// Set load balancing strategy for multi-GPU operations
    pub async fn set_load_balancing_strategy(
        &self,
        _strategy: LoadBalancingStrategy,
    ) -> UnifiedGpuResult<()> {
        if !self.multi_gpu_enabled {
            return Err(UnifiedGpuError::InvalidOperation(
                "Multi-GPU mode not enabled".into(),
            ));
        }

        // Note: In a full implementation, this would update the load balancer's strategy.
        // For now, this is a placeholder that could be extended.
        Ok(())
    }

    /// Add a new GPU device to the multi-GPU context (hot-plugging support)
    pub async fn add_gpu_device(&self, device: Arc<GpuDevice>) -> UnifiedGpuResult<()> {
        if !self.multi_gpu_enabled {
            return Err(UnifiedGpuError::InvalidOperation(
                "Multi-GPU mode not enabled".into(),
            ));
        }

        let mut devices = self.gpu_devices.write().await;
        devices.insert(device.id, Arc::clone(&device));

        // Add to load balancer
        self.load_balancer.add_device(device).await;

        Ok(())
    }

    /// Remove a GPU device from the multi-GPU context
    pub async fn remove_gpu_device(&self, device_id: DeviceId) -> UnifiedGpuResult<()> {
        if !self.multi_gpu_enabled {
            return Err(UnifiedGpuError::InvalidOperation(
                "Multi-GPU mode not enabled".into(),
            ));
        }

        let mut devices = self.gpu_devices.write().await;
        devices.remove(&device_id);

        // Remove from load balancer
        self.load_balancer.remove_device(device_id).await;

        Ok(())
    }
}

/// Enhanced buffer pool with statistics and eviction policies
pub struct EnhancedGpuBufferPool {
    pools: HashMap<(u64, wgpu::BufferUsages), Vec<wgpu::Buffer>>,
    stats: HashMap<(u64, wgpu::BufferUsages), PoolEntryStats>,
    total_created: u64,
    total_reused: u64,
    last_cleanup: Instant,
}

#[derive(Debug, Clone, Default)]
pub struct PoolEntryStats {
    pub created_count: u64,
    pub reused_count: u64,
    pub last_used: Option<Instant>,
    pub total_size_bytes: u64,
}

#[derive(Debug, Clone, Default)]
pub struct BufferPoolStats {
    pub total_buffers_created: u64,
    pub total_buffers_reused: u64,
    pub current_pooled_count: usize,
    pub total_pooled_memory_mb: f32,
    pub hit_rate_percent: f32,
}

impl EnhancedGpuBufferPool {
    pub fn new() -> Self {
        Self {
            pools: HashMap::new(),
            stats: HashMap::new(),
            total_created: 0,
            total_reused: 0,
            last_cleanup: Instant::now(),
        }
    }

    pub fn get_or_create(
        &mut self,
        device: &wgpu::Device,
        size: u64,
        usage: wgpu::BufferUsages,
        label: Option<&str>,
    ) -> wgpu::Buffer {
        let key = (size, usage);

        // Try to reuse from pool
        if let Some(buffers) = self.pools.get_mut(&key) {
            if let Some(buffer) = buffers.pop() {
                self.total_reused += 1;
                let stats = self.stats.entry(key).or_default();
                stats.reused_count += 1;
                stats.last_used = Some(Instant::now());
                return buffer;
            }
        }

        // Create new buffer
        let buffer = device.create_buffer(&wgpu::BufferDescriptor {
            label,
            size,
            usage,
            mapped_at_creation: false,
        });

        self.total_created += 1;
        let stats = self.stats.entry(key).or_default();
        stats.created_count += 1;
        stats.total_size_bytes += size;
        stats.last_used = Some(Instant::now());

        // Periodic cleanup
        if self.last_cleanup.elapsed().as_secs() > 30 {
            self.cleanup_old_buffers();
        }

        buffer
    }

    pub fn return_buffer(&mut self, buffer: wgpu::Buffer, size: u64, usage: wgpu::BufferUsages) {
        let key = (size, usage);
        self.pools.entry(key).or_default().push(buffer);
    }

    pub fn get_stats(&self) -> BufferPoolStats {
        let total_ops = self.total_created + self.total_reused;
        let hit_rate = if total_ops > 0 {
            (self.total_reused as f32 / total_ops as f32) * 100.0
        } else {
            0.0
        };

        let current_pooled_count = self.pools.values().map(|v| v.len()).sum();
        let total_pooled_memory_mb: f32 = self
            .pools
            .iter()
            .map(|((size, _usage), buffers)| {
                (*size as f32 * buffers.len() as f32) / 1024.0 / 1024.0
            })
            .sum();

        BufferPoolStats {
            total_buffers_created: self.total_created,
            total_buffers_reused: self.total_reused,
            current_pooled_count,
            total_pooled_memory_mb,
            hit_rate_percent: hit_rate,
        }
    }

    fn cleanup_old_buffers(&mut self) {
        let now = Instant::now();
        let cleanup_threshold = std::time::Duration::from_secs(300); // 5 minutes

        self.pools.retain(|&key, buffers| {
            if let Some(stats) = self.stats.get(&key) {
                if let Some(last_used) = stats.last_used {
                    if now.duration_since(last_used) > cleanup_threshold {
                        // Remove old unused buffers
                        buffers.clear();
                        return false;
                    }
                }
            }
            true
        });

        self.last_cleanup = now;
    }
}

impl Default for EnhancedGpuBufferPool {
    fn default() -> Self {
        Self::new()
    }
}

/// Smart GPU/CPU dispatch based on workload characteristics
pub struct GpuDispatcher {
    gpu_context: Option<GpuContext>,
    cpu_threshold: usize,
    gpu_threshold: usize,
}

impl GpuDispatcher {
    /// Create new dispatcher with GPU context
    pub async fn new() -> UnifiedGpuResult<Self> {
        let gpu_context = GpuContext::new().await.ok(); // Graceful fallback to CPU-only

        Ok(Self {
            gpu_context,
            cpu_threshold: 100,  // Use CPU for small workloads
            gpu_threshold: 1000, // Use GPU for large workloads
        })
    }

    /// Determine optimal compute strategy
    pub fn should_use_gpu(&self, workload_size: usize) -> bool {
        self.gpu_context.is_some()
            && workload_size >= self.cpu_threshold
            && workload_size >= self.gpu_threshold
    }

    /// Execute operation with optimal strategy
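    ///
    /// # Example
    ///
    /// A minimal sketch: the GPU closure is tried first for large workloads,
    /// and any GPU error falls back to the CPU closure; the GPU path here is
    /// deliberately stubbed out:
    ///
    /// ```no_run
    /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
    /// use amari_gpu::unified::{GpuDispatcher, UnifiedGpuError};
    ///
    /// let mut dispatcher = GpuDispatcher::new().await?;
    /// let data = vec![1.0_f32; 4096];
    ///
    /// let sum = dispatcher
    ///     .execute(
    ///         data.len(),
    ///         |_ctx| {
    ///             // A real GPU path would dispatch a reduction shader here.
    ///             Err(UnifiedGpuError::InvalidOperation(
    ///                 "GPU path not implemented in this sketch".into(),
    ///             ))
    ///         },
    ///         || data.iter().sum::<f32>(),
    ///     )
    ///     .await;
    /// println!("sum = {sum}");
    /// # Ok(())
    /// # }
    /// ```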
    pub async fn execute<T, F, G>(&mut self, workload_size: usize, gpu_op: G, cpu_op: F) -> T
    where
        F: FnOnce() -> T,
        G: FnOnce(&mut GpuContext) -> UnifiedGpuResult<T>,
    {
        if self.should_use_gpu(workload_size) {
            if let Some(ref mut ctx) = self.gpu_context {
                if let Ok(result) = gpu_op(ctx) {
                    return result;
                }
            }
        }

        // Fallback to CPU
        cpu_op()
    }
}

/// Multi-GPU statistics for monitoring and optimization
#[derive(Debug, Clone)]
pub struct MultiGpuStats {
    pub device_count: usize,
    pub total_operations: usize,
    pub total_errors: usize,
    pub avg_utilization_percent: f32,
    pub uptime: std::time::Duration,
}

#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    #[ignore = "GPU hardware required, may fail in CI/CD environments"]
    async fn test_gpu_context_creation() {
        // Creation should not panic even without a GPU; it returns an error instead
        let _result = GpuContext::new().await;
        // Don't assert success since a GPU might not be available in CI
    }

    #[tokio::test]
    #[ignore = "GPU hardware required, may fail in CI/CD environments"]
    async fn test_gpu_dispatcher() {
        let dispatcher = GpuDispatcher::new().await;
        assert!(dispatcher.is_ok());
    }

    #[test]
    fn test_gpu_operation_params() {
        let mut params = GpuOperationParams::default();
        params
            .params
            .insert("scale".to_string(), GpuParam::Float(2.0));
        params.batch_size = 100;

        assert_eq!(params.batch_size, 100);
        match params.params.get("scale") {
            Some(GpuParam::Float(val)) => assert_eq!(*val, 2.0),
            _ => panic!("Expected float parameter"),
        }
    }
}