amari_gpu/unified.rs

//! Unified GPU acceleration infrastructure for all mathematical domains
//!
//! This module provides a common interface and infrastructure for GPU acceleration
//! across tropical algebra, automatic differentiation, fusion systems, and other
//! mathematical domains in the Amari library.

use crate::GpuError;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Instant;
use thiserror::Error;
use wgpu::util::DeviceExt;

#[derive(Error, Debug)]
pub enum UnifiedGpuError {
    #[error("GPU error: {0}")]
    Gpu(#[from] GpuError),

    #[error("Shader compilation failed: {0}")]
    ShaderCompilation(String),

    #[error("Buffer size mismatch: expected {expected}, got {actual}")]
    BufferSizeMismatch { expected: usize, actual: usize },

    #[error("Invalid operation: {0}")]
    InvalidOperation(String),

    #[error("Memory allocation failed: {0}")]
    MemoryAllocation(String),
}

pub type UnifiedGpuResult<T> = Result<T, UnifiedGpuError>;

/// Universal trait for GPU-accelerated mathematical operations
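///
/// # Example
///
/// A minimal sketch of an implementor, assuming a hypothetical `ScalarField`
/// type backed by a flat `Vec<f32>` (the type and its `values` field are
/// illustrative, not part of this module):
///
/// ```ignore
/// impl GpuAccelerated<ScalarField> for ScalarField {
///     fn to_gpu_buffer(&self, context: &GpuContext) -> UnifiedGpuResult<wgpu::Buffer> {
///         // Upload the flat data as a storage buffer that can be copied back.
///         Ok(context.create_buffer_with_data(
///             "scalar_field",
///             &self.values,
///             wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
///         ))
///     }
///
///     fn from_gpu_buffer(_buffer: &wgpu::Buffer, _context: &GpuContext) -> UnifiedGpuResult<ScalarField> {
///         // Readback is asynchronous; see `GpuContext::read_buffer`.
///         unimplemented!("sketch only")
///     }
///
///     fn gpu_operation(
///         &self,
///         operation: &str,
///         _context: &GpuContext,
///         _params: &GpuOperationParams,
///     ) -> UnifiedGpuResult<ScalarField> {
///         match operation {
///             "scale" => unimplemented!("dispatch a compute shader here"),
///             other => Err(UnifiedGpuError::InvalidOperation(other.to_string())),
///         }
///     }
/// }
/// ```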
pub trait GpuAccelerated<T> {
    /// Convert data to GPU buffer format
    fn to_gpu_buffer(&self, context: &GpuContext) -> UnifiedGpuResult<wgpu::Buffer>;

    /// Reconstruct data from GPU buffer
    fn from_gpu_buffer(buffer: &wgpu::Buffer, context: &GpuContext) -> UnifiedGpuResult<T>;

    /// Execute GPU operation
    fn gpu_operation(
        &self,
        operation: &str,
        context: &GpuContext,
        params: &GpuOperationParams,
    ) -> UnifiedGpuResult<T>;
}

/// GPU operation parameters for flexible operation dispatch
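///
/// # Example
///
/// Building parameters for a hypothetical element-wise "scale" operation (the
/// key name is illustrative; operations define their own keys):
///
/// ```ignore
/// let mut params = GpuOperationParams::default();
/// params.params.insert("scale".to_string(), GpuParam::Float(2.0));
/// params.batch_size = 4096;
/// params.workgroup_size = (256, 1, 1);
/// ```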
#[derive(Debug, Clone)]
pub struct GpuOperationParams {
    /// Operation-specific parameters
    pub params: HashMap<String, GpuParam>,
    /// Batch size for operations
    pub batch_size: usize,
    /// Workgroup size for compute shaders
    pub workgroup_size: (u32, u32, u32),
}

/// Parameter types for GPU operations
#[derive(Debug, Clone)]
pub enum GpuParam {
    Float(f32),
    Double(f64),
    Integer(i32),
    UnsignedInteger(u32),
    Buffer(String), // Buffer identifier
    Array(Vec<f32>),
}

impl Default for GpuOperationParams {
    fn default() -> Self {
        Self {
            params: HashMap::new(),
            batch_size: 1,
            workgroup_size: (1, 1, 1),
        }
    }
}

/// Unified GPU context managing device, queue, and shader cache
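///
/// # Example
///
/// A minimal round-trip sketch (create a context, upload data, read it back),
/// assuming an async executor such as `pollster`; running a shader would
/// additionally go through `get_compute_pipeline` and `execute_compute`:
///
/// ```ignore
/// let ctx = pollster::block_on(GpuContext::new())?;
/// let input = ctx.create_buffer_with_data(
///     "input",
///     &[1.0f32, 2.0, 3.0, 4.0],
///     wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
/// );
/// let round_trip: Vec<f32> = pollster::block_on(ctx.read_buffer(&input, 16))?;
/// assert_eq!(round_trip, vec![1.0, 2.0, 3.0, 4.0]);
/// ```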
pub struct GpuContext {
    pub device: wgpu::Device,
    pub queue: wgpu::Queue,
    shader_cache: HashMap<String, wgpu::ComputePipeline>,
    #[allow(dead_code)]
    buffer_pool: GpuBufferPool,
}

impl GpuContext {
    /// Initialize GPU context with WebGPU
    pub async fn new() -> UnifiedGpuResult<Self> {
        let instance = wgpu::Instance::default();

        let adapter = instance
            .request_adapter(&wgpu::RequestAdapterOptions {
                power_preference: wgpu::PowerPreference::HighPerformance,
                compatible_surface: None,
                force_fallback_adapter: false,
            })
            .await
            .ok_or_else(|| {
                UnifiedGpuError::Gpu(GpuError::InitializationError(
                    "No GPU adapter found".to_string(),
                ))
            })?;

        let (device, queue) = adapter
            .request_device(
                &wgpu::DeviceDescriptor {
                    label: Some("Amari Unified GPU Device"),
                    required_features: wgpu::Features::empty(),
                    required_limits: wgpu::Limits::default(),
                },
                None,
            )
            .await
            .map_err(|e| UnifiedGpuError::Gpu(GpuError::InitializationError(e.to_string())))?;

        Ok(Self {
            device,
            queue,
            shader_cache: HashMap::new(),
            buffer_pool: GpuBufferPool::new(),
        })
    }

    /// Get or compile a compute pipeline, cached by shader key
    pub fn get_compute_pipeline(
        &mut self,
        shader_key: &str,
        shader_source: &str,
        bind_group_layout: &wgpu::BindGroupLayout,
    ) -> UnifiedGpuResult<&wgpu::ComputePipeline> {
        if !self.shader_cache.contains_key(shader_key) {
            let shader_module = self
                .device
                .create_shader_module(wgpu::ShaderModuleDescriptor {
                    label: Some(&format!("{} Shader", shader_key)),
                    source: wgpu::ShaderSource::Wgsl(shader_source.into()),
                });

            let pipeline_layout =
                self.device
                    .create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
                        label: Some(&format!("{} Pipeline Layout", shader_key)),
                        bind_group_layouts: &[bind_group_layout],
                        push_constant_ranges: &[],
                    });

            let compute_pipeline =
                self.device
                    .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
                        label: Some(&format!("{} Pipeline", shader_key)),
                        layout: Some(&pipeline_layout),
                        module: &shader_module,
                        entry_point: "main",
                    });

            self.shader_cache
                .insert(shader_key.to_string(), compute_pipeline);
        }

        Ok(self
            .shader_cache
            .get(shader_key)
            .expect("Pipeline should exist"))
    }

    /// Create buffer with data
    pub fn create_buffer_with_data<T: bytemuck::Pod>(
        &self,
        label: &str,
        data: &[T],
        usage: wgpu::BufferUsages,
    ) -> wgpu::Buffer {
        self.device
            .create_buffer_init(&wgpu::util::BufferInitDescriptor {
                label: Some(label),
                contents: bytemuck::cast_slice(data),
                usage,
            })
    }

    /// Create empty buffer
    pub fn create_buffer(&self, label: &str, size: u64, usage: wgpu::BufferUsages) -> wgpu::Buffer {
        self.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some(label),
            size,
            usage,
            mapped_at_creation: false,
        })
    }

    /// Execute compute shader
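    ///
    /// # Example
    ///
    /// A sketch of a 1D dispatch, assuming `pipeline` and `bind_group` were
    /// built elsewhere for a shader declared with `@workgroup_size(64)`:
    ///
    /// ```ignore
    /// let n: u32 = 10_000;
    /// let workgroups = n.div_ceil(64); // one thread per element, rounded up
    /// ctx.execute_compute(&pipeline, &bind_group, (workgroups, 1, 1));
    /// ```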
    pub fn execute_compute(
        &self,
        pipeline: &wgpu::ComputePipeline,
        bind_group: &wgpu::BindGroup,
        workgroup_count: (u32, u32, u32),
    ) {
        let mut encoder = self
            .device
            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
                label: Some("Compute Encoder"),
            });

        {
            let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
                label: Some("Compute Pass"),
                timestamp_writes: None,
            });
            compute_pass.set_pipeline(pipeline);
            compute_pass.set_bind_group(0, bind_group, &[]);
            compute_pass.dispatch_workgroups(
                workgroup_count.0,
                workgroup_count.1,
                workgroup_count.2,
            );
        }

        self.queue.submit([encoder.finish()]);
    }

    /// Read buffer data back to CPU
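    ///
    /// # Example
    ///
    /// A sketch, assuming an async executor such as `pollster` and a source
    /// buffer created with `COPY_SRC` usage:
    ///
    /// ```ignore
    /// let values: Vec<f32> = pollster::block_on(ctx.read_buffer(&buffer, byte_len))?;
    /// ```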
    pub async fn read_buffer<T: bytemuck::Pod + Clone>(
        &self,
        buffer: &wgpu::Buffer,
        size: u64,
    ) -> UnifiedGpuResult<Vec<T>> {
        let staging_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("Staging Buffer"),
            size,
            usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
            mapped_at_creation: false,
        });

        let mut encoder = self
            .device
            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
                label: Some("Copy Encoder"),
            });

        encoder.copy_buffer_to_buffer(buffer, 0, &staging_buffer, 0, size);
        self.queue.submit([encoder.finish()]);

        let buffer_slice = staging_buffer.slice(..);
        let (tx, rx) = futures::channel::oneshot::channel();
        buffer_slice.map_async(wgpu::MapMode::Read, move |result| {
            tx.send(result).ok();
        });

        self.device.poll(wgpu::Maintain::Wait);

        rx.await
            .map_err(|_| UnifiedGpuError::InvalidOperation("Buffer read timeout".to_string()))?
            .map_err(|e| UnifiedGpuError::InvalidOperation(format!("Buffer map failed: {}", e)))?;

        let data = buffer_slice.get_mapped_range();
        let result: Vec<T> = bytemuck::cast_slice(&data).to_vec();
        drop(data);
        staging_buffer.unmap();

        Ok(result)
    }
}

/// GPU buffer pool for efficient memory management
pub struct GpuBufferPool {
    _pools: HashMap<String, Vec<wgpu::Buffer>>, // Future: implement buffer pooling
}

impl GpuBufferPool {
    pub fn new() -> Self {
        Self {
            _pools: HashMap::new(),
        }
    }

    // Future: Add buffer pooling methods
    // pub fn get_buffer(&mut self, size: u64, usage: wgpu::BufferUsages) -> wgpu::Buffer
    // pub fn return_buffer(&mut self, buffer: wgpu::Buffer)
}

impl Default for GpuBufferPool {
    fn default() -> Self {
        Self::new()
    }
}

/// Shared GPU context for efficient resource management across all crates
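///
/// # Example
///
/// A sketch of sharing the context across call sites, assuming an async
/// caller and a WGSL string `SHADER_SRC` with a `main` entry point (both
/// supplied by the caller):
///
/// ```ignore
/// let usage = wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC;
/// let ctx = SharedGpuContext::global().await?;
/// let pipeline = ctx.get_compute_pipeline("tropical_matmul", SHADER_SRC, "main")?;
/// let scratch = ctx.get_buffer(4096, usage, Some("scratch"));
/// // ... encode and submit work with ctx.device() and ctx.queue() ...
/// ctx.return_buffer(scratch, 4096, usage);
/// println!("pool hit rate: {:.1}%", ctx.buffer_pool_stats().hit_rate_percent);
/// ```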
#[derive(Clone)]
pub struct SharedGpuContext {
    device: Arc<wgpu::Device>,
    queue: Arc<wgpu::Queue>,
    adapter_info: wgpu::AdapterInfo,
    buffer_pool: Arc<std::sync::Mutex<EnhancedGpuBufferPool>>,
    shader_cache: Arc<std::sync::Mutex<HashMap<String, Arc<wgpu::ComputePipeline>>>>,
    creation_time: Instant,
}

impl SharedGpuContext {
    /// Get the global shared GPU context (singleton pattern)
    /// Note: This creates a new context each time for now. In production,
    /// this would be a proper singleton with atomic initialization.
    pub async fn global() -> UnifiedGpuResult<&'static Self> {
        let context = Self::new().await?;
        // Leak the context to make it 'static - in production, this would be managed properly
        Ok(Box::leak(Box::new(context)))
    }

    /// Create a new shared GPU context
    async fn new() -> UnifiedGpuResult<Self> {
        let instance = wgpu::Instance::new(wgpu::InstanceDescriptor {
            backends: wgpu::Backends::all(),
            flags: wgpu::InstanceFlags::default(),
            dx12_shader_compiler: wgpu::Dx12Compiler::default(),
            gles_minor_version: wgpu::Gles3MinorVersion::Automatic,
        });

        let adapter = instance
            .request_adapter(&wgpu::RequestAdapterOptions {
                power_preference: wgpu::PowerPreference::HighPerformance,
                compatible_surface: None,
                force_fallback_adapter: false,
            })
            .await
            .ok_or_else(|| {
                UnifiedGpuError::InvalidOperation("No suitable GPU adapter found".into())
            })?;

        let adapter_info = adapter.get_info();

        let (device, queue) = adapter
            .request_device(
                &wgpu::DeviceDescriptor {
                    label: Some("Shared Amari GPU Device"),
                    required_features: wgpu::Features::TIMESTAMP_QUERY,
                    required_limits: wgpu::Limits::default(),
                },
                None,
            )
            .await
            .map_err(|e| {
                UnifiedGpuError::InvalidOperation(format!("Device request failed: {:?}", e))
            })?;

        Ok(Self {
            device: Arc::new(device),
            queue: Arc::new(queue),
            adapter_info,
            buffer_pool: Arc::new(std::sync::Mutex::new(EnhancedGpuBufferPool::new())),
            shader_cache: Arc::new(std::sync::Mutex::new(HashMap::new())),
            creation_time: Instant::now(),
        })
    }

    /// Get the device
    pub fn device(&self) -> &wgpu::Device {
        &self.device
    }

    /// Get the queue
    pub fn queue(&self) -> &wgpu::Queue {
        &self.queue
    }

    /// Get adapter info
    pub fn adapter_info(&self) -> &wgpu::AdapterInfo {
        &self.adapter_info
    }

    /// Get or create a buffer from the pool
    pub fn get_buffer(
        &self,
        size: u64,
        usage: wgpu::BufferUsages,
        label: Option<&str>,
    ) -> wgpu::Buffer {
        if let Ok(mut pool) = self.buffer_pool.lock() {
            pool.get_or_create(&self.device, size, usage, label)
        } else {
            // Fallback if mutex is poisoned
            self.device.create_buffer(&wgpu::BufferDescriptor {
                label,
                size,
                usage,
                mapped_at_creation: false,
            })
        }
    }

    /// Return a buffer to the pool for reuse
    pub fn return_buffer(&self, buffer: wgpu::Buffer, size: u64, usage: wgpu::BufferUsages) {
        if let Ok(mut pool) = self.buffer_pool.lock() {
            pool.return_buffer(buffer, size, usage);
        }
        // If mutex is poisoned, just drop the buffer
    }

    /// Get or create a compute pipeline from the cache.
    ///
    /// Pipelines created here use a fixed bind group layout: a read-only
    /// storage buffer at binding 0 and a read-write storage buffer at binding 1.
    pub fn get_compute_pipeline(
        &self,
        shader_key: &str,
        shader_source: &str,
        entry_point: &str,
    ) -> UnifiedGpuResult<Arc<wgpu::ComputePipeline>> {
        let cache_key = format!("{}:{}", shader_key, entry_point);

        if let Ok(mut cache) = self.shader_cache.lock() {
            if let Some(pipeline) = cache.get(&cache_key) {
                return Ok(Arc::clone(pipeline));
            }

            // Create new pipeline
            let shader_module = self
                .device
                .create_shader_module(wgpu::ShaderModuleDescriptor {
                    label: Some(&format!("{} Shader", shader_key)),
                    source: wgpu::ShaderSource::Wgsl(shader_source.into()),
                });

            let bind_group_layout =
                self.device
                    .create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
                        label: Some(&format!("{} Bind Group Layout", shader_key)),
                        entries: &[
                            wgpu::BindGroupLayoutEntry {
                                binding: 0,
                                visibility: wgpu::ShaderStages::COMPUTE,
                                ty: wgpu::BindingType::Buffer {
                                    ty: wgpu::BufferBindingType::Storage { read_only: true },
                                    has_dynamic_offset: false,
                                    min_binding_size: None,
                                },
                                count: None,
                            },
                            wgpu::BindGroupLayoutEntry {
                                binding: 1,
                                visibility: wgpu::ShaderStages::COMPUTE,
                                ty: wgpu::BindingType::Buffer {
                                    ty: wgpu::BufferBindingType::Storage { read_only: false },
                                    has_dynamic_offset: false,
                                    min_binding_size: None,
                                },
                                count: None,
                            },
                        ],
                    });

            let pipeline_layout =
                self.device
                    .create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
                        label: Some(&format!("{} Pipeline Layout", shader_key)),
                        bind_group_layouts: &[&bind_group_layout],
                        push_constant_ranges: &[],
                    });

            let pipeline = self
                .device
                .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
                    label: Some(&format!("{} Pipeline", shader_key)),
                    layout: Some(&pipeline_layout),
                    module: &shader_module,
                    entry_point,
                });

            let pipeline_arc = Arc::new(pipeline);
            cache.insert(cache_key, Arc::clone(&pipeline_arc));
            Ok(pipeline_arc)
        } else {
            Err(UnifiedGpuError::InvalidOperation(
                "Failed to access shader cache".into(),
            ))
        }
    }

    /// Get buffer pool statistics
    pub fn buffer_pool_stats(&self) -> BufferPoolStats {
        if let Ok(pool) = self.buffer_pool.lock() {
            pool.get_stats()
        } else {
            BufferPoolStats::default()
        }
    }

    /// Get uptime of this context
    pub fn uptime(&self) -> std::time::Duration {
        self.creation_time.elapsed()
    }

    /// Get optimal workgroup configuration for given operation type and data size
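    ///
    /// # Example
    ///
    /// A sketch of turning the returned workgroup size into a dispatch count
    /// for a 1D operation (the data size is illustrative):
    ///
    /// ```ignore
    /// let n = 50_000usize;
    /// let (wx, _, _) = ctx.get_optimal_workgroup("vector_operation", n); // (256, 1, 1)
    /// let workgroups = (n as u32).div_ceil(wx);
    /// ```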
    pub fn get_optimal_workgroup(&self, operation: &str, data_size: usize) -> (u32, u32, u32) {
        match operation {
            "matrix_multiply" | "matrix_operation" => {
                // 2D workgroups optimized for matrix operations
                // Use larger workgroups for better occupancy
                (16, 16, 1)
            }
            "vector_operation" | "reduce" | "scan" => {
                // 1D operations - prefer large workgroups for coalesced memory access
                let workgroup_size = if data_size > 10000 {
                    256 // Large batches benefit from maximum occupancy
                } else if data_size > 1000 {
                    128 // Medium batches
                } else {
                    64 // Small batches
                };
                (workgroup_size, 1, 1)
            }
            "geometric_algebra" | "clifford_algebra" => {
                // GA operations with moderate computational complexity
                (128, 1, 1)
            }
            "cellular_automata" | "ca_evolution" => {
                // 2D grid operations, optimized for spatial locality
                (16, 16, 1)
            }
            "neural_network" | "batch_processing" => {
                // Large 1D workgroups for high-throughput batch processing
                (256, 1, 1)
            }
            "information_geometry" | "fisher_information" | "bregman_divergence" => {
                // Statistical manifold computations - large workgroups
                (256, 1, 1)
            }
            "tropical_algebra" | "tropical_matrix" => {
                // Tropical operations, moderate workgroup size
                (128, 1, 1)
            }
            "dual_number" | "automatic_differentiation" => {
                // AD operations, balanced workgroup size
                (128, 1, 1)
            }
            "fusion_system" | "llm_evaluation" => {
                // Complex fusion operations, large workgroups
                (256, 1, 1)
            }
            "enumerative_geometry" | "intersection_theory" => {
                // Geometric computations, moderate workgroups
                (64, 1, 1)
            }
            _ => (64, 1, 1), // Conservative default for unknown operations
        }
    }

    /// Generate optimized WGSL workgroup declaration for operation
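    ///
    /// # Example
    ///
    /// A sketch assuming `ctx` is a `SharedGpuContext`:
    ///
    /// ```ignore
    /// let decl = ctx.get_workgroup_declaration("matrix_multiply", 4096);
    /// assert_eq!(decl, "@compute @workgroup_size(16, 16)");
    /// ```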
    pub fn get_workgroup_declaration(&self, operation: &str, data_size: usize) -> String {
        let (x, y, z) = self.get_optimal_workgroup(operation, data_size);

        if y == 1 && z == 1 {
            format!("@compute @workgroup_size({})", x)
        } else if z == 1 {
            format!("@compute @workgroup_size({}, {})", x, y)
        } else {
            format!("@compute @workgroup_size({}, {}, {})", x, y, z)
        }
    }
}

/// Enhanced buffer pool with statistics and eviction policies
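///
/// # Example
///
/// A sketch of direct use, assuming a `wgpu::Device` is already available
/// (callers going through `SharedGpuContext` never touch the pool directly):
///
/// ```ignore
/// let usage = wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC;
/// let mut pool = EnhancedGpuBufferPool::new();
/// let buffer = pool.get_or_create(&device, 1024, usage, Some("scratch"));
/// // ... use the buffer ...
/// pool.return_buffer(buffer, 1024, usage);
/// assert_eq!(pool.get_stats().total_buffers_created, 1);
/// ```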
pub struct EnhancedGpuBufferPool {
    pools: HashMap<(u64, wgpu::BufferUsages), Vec<wgpu::Buffer>>,
    stats: HashMap<(u64, wgpu::BufferUsages), PoolEntryStats>,
    total_created: u64,
    total_reused: u64,
    last_cleanup: Instant,
}

#[derive(Debug, Clone, Default)]
pub struct PoolEntryStats {
    pub created_count: u64,
    pub reused_count: u64,
    pub last_used: Option<Instant>,
    pub total_size_bytes: u64,
}

#[derive(Debug, Clone, Default)]
pub struct BufferPoolStats {
    pub total_buffers_created: u64,
    pub total_buffers_reused: u64,
    pub current_pooled_count: usize,
    pub total_pooled_memory_mb: f32,
    pub hit_rate_percent: f32,
}

impl EnhancedGpuBufferPool {
    pub fn new() -> Self {
        Self {
            pools: HashMap::new(),
            stats: HashMap::new(),
            total_created: 0,
            total_reused: 0,
            last_cleanup: Instant::now(),
        }
    }
}

impl Default for EnhancedGpuBufferPool {
    fn default() -> Self {
        Self::new()
    }
}

impl EnhancedGpuBufferPool {
    pub fn get_or_create(
        &mut self,
        device: &wgpu::Device,
        size: u64,
        usage: wgpu::BufferUsages,
        label: Option<&str>,
    ) -> wgpu::Buffer {
        let key = (size, usage);

        // Try to reuse from pool
        if let Some(buffers) = self.pools.get_mut(&key) {
            if let Some(buffer) = buffers.pop() {
                self.total_reused += 1;
                self.stats.entry(key).or_default().reused_count += 1;
                self.stats.get_mut(&key).unwrap().last_used = Some(Instant::now());
                return buffer;
            }
        }

        // Create new buffer
        let buffer = device.create_buffer(&wgpu::BufferDescriptor {
            label,
            size,
            usage,
            mapped_at_creation: false,
        });

        self.total_created += 1;
        let stats = self.stats.entry(key).or_default();
        stats.created_count += 1;
        stats.total_size_bytes += size;
        stats.last_used = Some(Instant::now());

        // Periodic cleanup
        if self.last_cleanup.elapsed().as_secs() > 30 {
            self.cleanup_old_buffers();
        }

        buffer
    }

    pub fn return_buffer(&mut self, buffer: wgpu::Buffer, size: u64, usage: wgpu::BufferUsages) {
        let key = (size, usage);
        self.pools.entry(key).or_default().push(buffer);
    }

    pub fn get_stats(&self) -> BufferPoolStats {
        let total_ops = self.total_created + self.total_reused;
        let hit_rate = if total_ops > 0 {
            (self.total_reused as f32 / total_ops as f32) * 100.0
        } else {
            0.0
        };

        let current_pooled_count = self.pools.values().map(|v| v.len()).sum();
        let total_pooled_memory_mb: f32 = self
            .pools
            .iter()
            .map(|((size, _usage), buffers)| {
                (*size as f32 * buffers.len() as f32) / 1024.0 / 1024.0
            })
            .sum();

        BufferPoolStats {
            total_buffers_created: self.total_created,
            total_buffers_reused: self.total_reused,
            current_pooled_count,
            total_pooled_memory_mb,
            hit_rate_percent: hit_rate,
        }
    }

    fn cleanup_old_buffers(&mut self) {
        let now = Instant::now();
        let cleanup_threshold = std::time::Duration::from_secs(300); // 5 minutes

        self.pools.retain(|&key, buffers| {
            if let Some(stats) = self.stats.get(&key) {
                if let Some(last_used) = stats.last_used {
                    if now.duration_since(last_used) > cleanup_threshold {
                        // Remove old unused buffers
                        buffers.clear();
                        return false;
                    }
                }
            }
            true
        });

        self.last_cleanup = now;
    }
}

/// Smart GPU/CPU dispatch based on workload characteristics
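///
/// # Example
///
/// A sketch of dispatching between a GPU path and a CPU fallback; the
/// closures and workload are illustrative, and the GPU path is left unwired:
///
/// ```ignore
/// let mut dispatcher = GpuDispatcher::new().await?;
/// let data = vec![1.0f32; 100_000];
/// let doubled = dispatcher
///     .execute(
///         data.len(),
///         |_ctx| {
///             // Real code would upload `data`, run a shader, and read back.
///             Err(UnifiedGpuError::InvalidOperation("GPU path not wired in this sketch".into()))
///         },
///         || data.iter().map(|x| x * 2.0).collect::<Vec<f32>>(),
///     )
///     .await;
/// ```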
pub struct GpuDispatcher {
    gpu_context: Option<GpuContext>,
    cpu_threshold: usize,
    gpu_threshold: usize,
}

impl GpuDispatcher {
    /// Create new dispatcher with GPU context
    pub async fn new() -> UnifiedGpuResult<Self> {
        let gpu_context = (GpuContext::new().await).ok(); // Graceful fallback to CPU-only

        Ok(Self {
            gpu_context,
            cpu_threshold: 100,  // Use CPU for small workloads
            gpu_threshold: 1000, // Use GPU for large workloads
        })
    }

    /// Determine optimal compute strategy
    pub fn should_use_gpu(&self, workload_size: usize) -> bool {
        self.gpu_context.is_some()
            && workload_size >= self.cpu_threshold
            && workload_size >= self.gpu_threshold
    }

    /// Execute operation with optimal strategy
    pub async fn execute<T, F, G>(&mut self, workload_size: usize, gpu_op: G, cpu_op: F) -> T
    where
        F: FnOnce() -> T,
        G: FnOnce(&mut GpuContext) -> UnifiedGpuResult<T>,
    {
        if self.should_use_gpu(workload_size) {
            if let Some(ref mut ctx) = self.gpu_context {
                if let Ok(result) = gpu_op(ctx) {
                    return result;
                }
            }
        }

        // Fallback to CPU
        cpu_op()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    #[ignore = "GPU hardware required, may fail in CI/CD environments"]
    async fn test_gpu_context_creation() {
        // Test should pass even without GPU (graceful fallback)
        let _result = GpuContext::new().await;
        // Don't assert success since GPU might not be available in CI
    }

    #[tokio::test]
    #[ignore = "GPU hardware required, may fail in CI/CD environments"]
    async fn test_gpu_dispatcher() {
        let dispatcher = GpuDispatcher::new().await;
        assert!(dispatcher.is_ok());
    }

    #[test]
    fn test_gpu_operation_params() {
        let mut params = GpuOperationParams::default();
        params
            .params
            .insert("scale".to_string(), GpuParam::Float(2.0));
        params.batch_size = 100;

        assert_eq!(params.batch_size, 100);
        match params.params.get("scale") {
            Some(GpuParam::Float(val)) => assert_eq!(*val, 2.0),
            _ => panic!("Expected float parameter"),
        }
    }
}