Skip to main content

ronn_providers/gpu/
provider.rs

1//! GPU execution provider using Candle backend.
2//!
3//! This module provides GPU-accelerated execution using the Candle library
4//! with support for CUDA and Metal backends.
5
6use std::collections::{HashMap, HashSet};
7use std::sync::Arc;
8
9use anyhow::{Result, anyhow};
10use candle_core::{Device, Tensor as CandleTensor};
11use ronn_core::tensor::Tensor;
12use ronn_core::{
13    CompiledKernel, DataType, ExecutionProvider, KernelStats, MemoryType, MemoryUsage,
14    OperatorSpec, PerformanceProfile, ProviderCapability, ProviderConfig, ProviderId,
15    ResourceRequirements, SubGraph, TensorAllocator, TensorLayout,
16};
17use tracing::{debug, info, warn};
18
19use super::allocator::create_gpu_allocator;
20use super::cuda_kernels::{CudaCompileOptions, CudaKernelManager};
21use super::memory_manager::{MultiGpuMemoryConfig, MultiGpuMemoryManager};
22use super::topology::{GpuTopologyManager, TopologyConfig};
23
24/// GPU execution provider using Candle backend.
25pub struct GpuExecutionProvider {
26    /// GPU devices for multi-GPU execution.
27    devices: Vec<Device>,
28    /// Memory allocators per device.
29    allocators: Vec<Arc<dyn TensorAllocator>>,
30    /// Set of supported operations.
31    supported_ops: HashSet<String>,
32    /// Provider configuration.
33    config: GpuProviderConfig,
34    /// Multi-GPU device manager.
35    device_manager: Arc<std::sync::Mutex<MultiGpuManager>>,
36    /// CUDA kernel managers per device.
37    cuda_kernel_managers: Vec<Option<CudaKernelManager>>,
38    /// Multi-GPU memory manager.
39    memory_manager: Option<Arc<MultiGpuMemoryManager>>,
40    /// GPU topology manager.
41    topology_manager: Option<Arc<GpuTopologyManager>>,
42}
43
44/// Configuration for GPU execution provider.
45#[derive(Debug, Clone)]
46pub struct GpuProviderConfig {
47    /// GPU device IDs for multi-GPU support.
48    pub device_ids: Vec<usize>,
49    /// Primary device ID (first device in device_ids).
50    pub primary_device_id: usize,
51    /// Memory limit in bytes per device (None = no limit).
52    pub memory_limit: Option<usize>,
53    /// Enable mixed precision (F16) operations.
54    pub enable_mixed_precision: bool,
55    /// Enable tensor core optimizations (if available).
56    pub enable_tensor_cores: bool,
57    /// Stream count for async operations per device.
58    pub stream_count: usize,
59    /// Enable multi-GPU distribution.
60    pub enable_multi_gpu: bool,
61    /// P2P memory transfer optimization.
62    pub enable_p2p_transfer: bool,
63    /// Load balancing strategy for multi-GPU.
64    pub load_balancing: LoadBalancingStrategy,
65    /// Enable custom CUDA kernels for optimized operations.
66    pub enable_custom_kernels: bool,
67    /// CUDA compilation options for custom kernels.
68    pub cuda_compile_options: CudaCompileOptions,
69    /// Multi-GPU memory management configuration.
70    pub memory_config: MultiGpuMemoryConfig,
71    /// GPU topology detection configuration.
72    pub topology_config: TopologyConfig,
73}
74
/// Load balancing strategies for multi-GPU execution.
///
/// [`LoadBalancingStrategy::RoundRobin`] is the default. The strategy is only
/// consulted when multi-GPU distribution is enabled and more than one device
/// is configured.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub enum LoadBalancingStrategy {
    /// Rotate through the configured devices in order.
    #[default]
    RoundRobin,
    /// Prefer the device with the lowest tracked memory usage.
    MemoryBased,
    /// Prefer the device with the lowest recorded utilization.
    UtilizationBased,
    /// Static assignment based on operation type.
    OperationBased,
    /// Pick the device with the lowest combined cost score.
    CostModel,
}
95
96impl Default for GpuProviderConfig {
97    fn default() -> Self {
98        Self {
99            device_ids: vec![0],
100            primary_device_id: 0,
101            memory_limit: None,
102            enable_mixed_precision: true,
103            enable_tensor_cores: true,
104            stream_count: 1,
105            enable_multi_gpu: false,
106            enable_p2p_transfer: true,
107            load_balancing: LoadBalancingStrategy::default(),
108            enable_custom_kernels: true,
109            cuda_compile_options: CudaCompileOptions::default(),
110            memory_config: MultiGpuMemoryConfig::default(),
111            topology_config: TopologyConfig::default(),
112        }
113    }
114}
115
116/// GPU kernel implementation using Candle.
117#[derive(Debug)]
118pub struct GpuKernel {
119    /// Original subgraph.
120    subgraph: SubGraph,
121    /// GPU device for execution.
122    device: Device,
123    /// Execution statistics.
124    stats: std::sync::Mutex<GpuKernelStats>,
125    /// Stream ID for async execution.
126    stream_id: usize,
127    /// Compiled kernel cache.
128    kernel_cache: std::sync::Mutex<KernelCache>,
129}
130
131/// Cache for compiled GPU kernels.
132#[derive(Debug, Default)]
133struct KernelCache {
134    /// Cached operations by operation signature.
135    cached_ops: HashMap<String, CachedOperation>,
136    /// Total cache size in bytes.
137    cache_size: usize,
138    /// Maximum cache size.
139    max_cache_size: usize,
140}
141
142/// A cached GPU operation.
143#[derive(Debug, Clone)]
144struct CachedOperation {
145    /// Operation signature (hash of inputs + op type).
146    signature: String,
147    /// Optimized execution path.
148    execution_path: OptimizedPath,
149    /// Hit count for LRU eviction.
150    hit_count: u64,
151    /// Last access time.
152    last_accessed: std::time::Instant,
153}
154
/// Strategy chosen for executing a cached GPU operation.
#[derive(Debug, Clone)]
enum OptimizedPath {
    /// A single operation executed on its own.
    Single(String),
    /// A sequence of operations executed as one fused unit.
    Fused(Vec<String>),
    /// Operations split between half and single precision.
    MixedPrecision {
        /// Operations executed in FP16.
        fp16_ops: Vec<String>,
        /// Operations kept in FP32.
        fp32_ops: Vec<String>,
    },
}
168
/// Timing and memory counters kept (behind a mutex) by each [`GpuKernel`].
#[derive(Debug, Default)]
struct GpuKernelStats {
    /// Number of recorded executions.
    execution_count: u64,
    /// Sum of recorded execution times, in microseconds.
    total_time_us: u64,
    /// Fastest recorded execution, in microseconds.
    ///
    /// NOTE(review): `Default` starts this at 0; confirm the update logic
    /// seeds it from the first sample so the minimum is not stuck at 0.
    min_time_us: u64,
    /// Slowest recorded execution, in microseconds.
    max_time_us: u64,
    /// Peak memory observed; presumably bytes — TODO confirm.
    memory_peak: usize,
}
177
178/// Multi-GPU device manager for load balancing and coordination.
179#[derive(Debug)]
180struct MultiGpuManager {
181    /// Configuration for multi-GPU setup.
182    config: GpuProviderConfig,
183    /// Device utilization stats for load balancing.
184    device_stats: HashMap<usize, DeviceStats>,
185    /// Round-robin counter for device selection.
186    round_robin_counter: usize,
187    /// Current memory usage per device.
188    memory_usage: HashMap<usize, usize>,
189}
190
/// Per-device statistics tracked by the multi-GPU load balancer.
///
/// This type is returned by the public
/// `GpuExecutionProvider::get_multi_gpu_stats`. It and its fields are public
/// so callers can name the type and read the numbers; a private type in that
/// position triggers Rust's private-in-public check.
#[derive(Debug, Default, Clone, PartialEq)]
pub struct DeviceStats {
    /// Number of operations recorded for this device.
    pub operation_count: u64,
    /// Memory usage in bytes reported by the most recent update.
    pub current_memory: usize,
    /// Highest memory usage in bytes reported so far.
    pub peak_memory: usize,
    /// Smoothed execution time, in microseconds.
    pub avg_execution_time: f64,
    /// Last utilization measurement (0.0 to 1.0).
    ///
    /// NOTE(review): nothing in this file updates it yet, so it stays 0.0.
    pub utilization: f32,
}
205
206impl MultiGpuManager {
207    /// Create a new multi-GPU manager.
208    fn new(config: GpuProviderConfig) -> Self {
209        let mut device_stats = HashMap::new();
210        let mut memory_usage = HashMap::new();
211
212        for &device_id in &config.device_ids {
213            device_stats.insert(device_id, DeviceStats::default());
214            memory_usage.insert(device_id, 0);
215        }
216
217        Self {
218            config,
219            device_stats,
220            round_robin_counter: 0,
221            memory_usage,
222        }
223    }
224
225    /// Select the best device for operation execution based on load balancing strategy.
226    fn select_device(&mut self, op_type: &str, memory_requirement: usize) -> usize {
227        if self.config.device_ids.len() == 1 {
228            return self.config.device_ids[0];
229        }
230
231        if !self.config.enable_multi_gpu {
232            return self.config.primary_device_id;
233        }
234
235        match self.config.load_balancing {
236            LoadBalancingStrategy::RoundRobin => {
237                let device_id =
238                    self.config.device_ids[self.round_robin_counter % self.config.device_ids.len()];
239                self.round_robin_counter += 1;
240                device_id
241            }
242            LoadBalancingStrategy::MemoryBased => self.select_device_by_memory(memory_requirement),
243            LoadBalancingStrategy::UtilizationBased => self.select_device_by_utilization(),
244            LoadBalancingStrategy::OperationBased => self.select_device_by_operation_type(op_type),
245            LoadBalancingStrategy::CostModel => {
246                self.select_device_by_cost_model(op_type, memory_requirement)
247            }
248        }
249    }
250
251    /// Select device with the most available memory.
252    fn select_device_by_memory(&self, memory_requirement: usize) -> usize {
253        self.config
254            .device_ids
255            .iter()
256            .min_by_key(|&&device_id| {
257                self.memory_usage.get(&device_id).unwrap_or(&0) + memory_requirement
258            })
259            .copied()
260            .unwrap_or(self.config.primary_device_id)
261    }
262
263    /// Select device with the lowest utilization.
264    fn select_device_by_utilization(&self) -> usize {
265        self.config
266            .device_ids
267            .iter()
268            .min_by(|&&a, &&b| {
269                let util_a = self
270                    .device_stats
271                    .get(&a)
272                    .map(|s| s.utilization)
273                    .unwrap_or(0.0);
274                let util_b = self
275                    .device_stats
276                    .get(&b)
277                    .map(|s| s.utilization)
278                    .unwrap_or(0.0);
279                util_a
280                    .partial_cmp(&util_b)
281                    .unwrap_or(std::cmp::Ordering::Equal)
282            })
283            .copied()
284            .unwrap_or(self.config.primary_device_id)
285    }
286
287    /// Select device based on operation type preferences.
288    fn select_device_by_operation_type(&self, op_type: &str) -> usize {
289        // For now, use simple heuristics
290        match op_type {
291            // Compute-intensive operations prefer higher-end GPUs (lower device IDs)
292            "MatMul" | "Conv" | "ConvTranspose" => self
293                .config
294                .device_ids
295                .iter()
296                .min()
297                .copied()
298                .unwrap_or(self.config.primary_device_id),
299            // Memory-intensive operations prefer GPUs with more available memory
300            "Concat" | "Split" | "Reshape" => self.select_device_by_memory(0),
301            // Default to round-robin
302            _ => {
303                let device_id =
304                    self.config.device_ids[self.round_robin_counter % self.config.device_ids.len()];
305                device_id
306            }
307        }
308    }
309
310    /// Select device using a cost model that considers multiple factors.
311    fn select_device_by_cost_model(&self, _op_type: &str, memory_requirement: usize) -> usize {
312        let mut best_device = self.config.primary_device_id;
313        let mut best_score = f64::INFINITY;
314
315        for &device_id in &self.config.device_ids {
316            let default_stats = DeviceStats::default();
317            let stats = self.device_stats.get(&device_id).unwrap_or(&default_stats);
318            let memory_used = self.memory_usage.get(&device_id).unwrap_or(&0);
319
320            // Calculate cost score (lower is better)
321            let memory_pressure =
322                (*memory_used + memory_requirement) as f64 / (1024.0 * 1024.0 * 1024.0); // GB
323            let utilization_penalty = stats.utilization as f64 * 2.0;
324            let execution_time_penalty = stats.avg_execution_time / 1000.0; // Convert to ms
325
326            let total_score = memory_pressure + utilization_penalty + execution_time_penalty;
327
328            if total_score < best_score {
329                best_score = total_score;
330                best_device = device_id;
331            }
332        }
333
334        best_device
335    }
336
337    /// Update device statistics after operation execution.
338    fn update_device_stats(
339        &mut self,
340        device_id: usize,
341        execution_time_us: u64,
342        memory_used: usize,
343    ) {
344        if let Some(stats) = self.device_stats.get_mut(&device_id) {
345            stats.operation_count += 1;
346            stats.current_memory = memory_used;
347            stats.peak_memory = stats.peak_memory.max(memory_used);
348
349            // Update rolling average execution time
350            let alpha = 0.1; // Smoothing factor
351            if stats.avg_execution_time == 0.0 {
352                stats.avg_execution_time = execution_time_us as f64;
353            } else {
354                stats.avg_execution_time =
355                    alpha * execution_time_us as f64 + (1.0 - alpha) * stats.avg_execution_time;
356            }
357        }
358
359        self.memory_usage.insert(device_id, memory_used);
360    }
361
362    /// Get current device statistics for monitoring.
363    fn get_device_stats(&self) -> &HashMap<usize, DeviceStats> {
364        &self.device_stats
365    }
366}
367
368impl GpuExecutionProvider {
369    /// Create a new GPU execution provider with default configuration.
370    #[cfg(feature = "gpu")]
371    pub fn new() -> Result<Self> {
372        Self::with_config(GpuProviderConfig::default())
373    }
374
375    /// Create a GPU execution provider with custom configuration.
376    #[cfg(feature = "gpu")]
377    pub fn with_config(config: GpuProviderConfig) -> Result<Self> {
378        // Create GPU devices based on configuration
379        let mut devices = Vec::new();
380        let mut allocators = Vec::new();
381        let mut cuda_kernel_managers = Vec::new();
382
383        for &device_id in &config.device_ids {
384            let device = Self::create_gpu_device(device_id)?;
385            info!("Created GPU device {}: {:?}", device_id, device);
386
387            // Create CUDA kernel manager if enabled and device is CUDA
388            let cuda_manager = if config.enable_custom_kernels && matches!(device, Device::Cuda(_))
389            {
390                match CudaKernelManager::with_options(
391                    device.clone(),
392                    config.cuda_compile_options.clone(),
393                ) {
394                    Ok(manager) => {
395                        info!("Created CUDA kernel manager for device {}", device_id);
396                        Some(manager)
397                    }
398                    Err(e) => {
399                        warn!(
400                            "Failed to create CUDA kernel manager for device {}: {}",
401                            device_id, e
402                        );
403                        None
404                    }
405                }
406            } else {
407                None
408            };
409
410            devices.push(device);
411            cuda_kernel_managers.push(cuda_manager);
412
413            // Create allocator for each device
414            let allocator = create_gpu_allocator().map_err(|e| {
415                anyhow!(
416                    "Failed to create GPU allocator for device {}: {}",
417                    device_id,
418                    e
419                )
420            })?;
421            allocators.push(allocator);
422        }
423
424        if devices.is_empty() {
425            return Err(anyhow!("No GPU devices configured"));
426        }
427
428        info!("Created GPU provider with {} devices", devices.len());
429
430        // Create multi-GPU manager
431        let device_manager = Arc::new(std::sync::Mutex::new(MultiGpuManager::new(config.clone())));
432
433        // Create multi-GPU memory manager if multi-GPU is enabled
434        let memory_manager = if config.enable_multi_gpu && config.device_ids.len() > 1 {
435            match MultiGpuMemoryManager::new(
436                config.device_ids.clone(),
437                config.memory_config.clone(),
438            ) {
439                Ok(manager) => {
440                    info!("Created multi-GPU memory manager");
441                    Some(Arc::new(manager))
442                }
443                Err(e) => {
444                    warn!("Failed to create multi-GPU memory manager: {}", e);
445                    None
446                }
447            }
448        } else {
449            None
450        };
451
452        // Create GPU topology manager if multi-GPU is enabled
453        let topology_manager = if config.enable_multi_gpu && config.device_ids.len() > 1 {
454            match GpuTopologyManager::new(config.topology_config.clone()) {
455                Ok(mut manager) => {
456                    // Discover topology
457                    if let Err(e) = manager.discover_topology() {
458                        warn!("Failed to discover GPU topology: {}", e);
459                    } else {
460                        info!("GPU topology discovered successfully");
461                    }
462                    Some(Arc::new(manager))
463                }
464                Err(e) => {
465                    warn!("Failed to create topology manager: {}", e);
466                    None
467                }
468            }
469        } else {
470            None
471        };
472
473        // Define supported operations (GPU-optimized subset)
474        let mut supported_ops = HashSet::new();
475
476        // Basic arithmetic operations (highly optimized on GPU)
477        supported_ops.insert("Add".to_string());
478        supported_ops.insert("Sub".to_string());
479        supported_ops.insert("Mul".to_string());
480        supported_ops.insert("Div".to_string());
481
482        // Matrix operations (GPU's strength)
483        supported_ops.insert("MatMul".to_string());
484        supported_ops.insert("Gemm".to_string());
485
486        // Convolution operations (GPU-accelerated)
487        supported_ops.insert("Conv".to_string());
488        supported_ops.insert("ConvTranspose".to_string());
489
490        // Pooling operations
491        supported_ops.insert("MaxPool".to_string());
492        supported_ops.insert("AveragePool".to_string());
493        supported_ops.insert("GlobalAveragePool".to_string());
494
495        // Activation functions (element-wise, GPU-friendly)
496        supported_ops.insert("ReLU".to_string());
497        supported_ops.insert("Sigmoid".to_string());
498        supported_ops.insert("Tanh".to_string());
499        supported_ops.insert("Softmax".to_string());
500        supported_ops.insert("GELU".to_string());
501
502        // Normalization operations
503        supported_ops.insert("BatchNormalization".to_string());
504        supported_ops.insert("LayerNormalization".to_string());
505
506        // Reduction operations (efficient on GPU)
507        supported_ops.insert("Sum".to_string());
508        supported_ops.insert("Mean".to_string());
509        supported_ops.insert("Max".to_string());
510        supported_ops.insert("Min".to_string());
511
512        // Shape operations (fast on GPU)
513        supported_ops.insert("Reshape".to_string());
514        supported_ops.insert("Transpose".to_string());
515        supported_ops.insert("Concat".to_string());
516        supported_ops.insert("Split".to_string());
517
518        info!(
519            "GPU provider supports {} operation types",
520            supported_ops.len()
521        );
522
523        Ok(Self {
524            devices,
525            allocators,
526            supported_ops,
527            config,
528            device_manager,
529            cuda_kernel_managers,
530            memory_manager,
531            topology_manager,
532        })
533    }
534
535    /// Fallback constructor when GPU is not available.
536    #[cfg(not(feature = "gpu"))]
537    pub fn new() -> Result<Self> {
538        Err(anyhow!("GPU support not compiled in"))
539    }
540
541    /// Create a GPU execution provider with custom configuration.
542    #[cfg(not(feature = "gpu"))]
543    pub fn with_config(_config: GpuProviderConfig) -> Result<Self> {
544        Err(anyhow!("GPU support not compiled in"))
545    }
546
547    /// Create GPU device based on configuration.
548    #[cfg(feature = "gpu")]
549    fn create_gpu_device(device_id: usize) -> Result<Device> {
550        // Try CUDA first
551        if let Ok(device) = Device::new_cuda(device_id) {
552            info!("Using CUDA device {}", device_id);
553            return Ok(device);
554        }
555
556        // Try Metal on macOS
557        #[cfg(target_os = "macos")]
558        {
559            if let Ok(device) = Device::new_metal(device_id) {
560                info!("Using Metal device {}", device_id);
561                return Ok(device);
562            }
563        }
564
565        Err(anyhow!("No GPU devices available"))
566    }
567
568    /// Get the primary GPU device.
569    pub fn device(&self) -> &Device {
570        &self.devices[0]
571    }
572
573    /// Get the current configuration.
574    pub fn get_config(&self) -> &GpuProviderConfig {
575        &self.config
576    }
577
578    /// Check if an operation type is supported.
579    pub fn supports_operation(&self, op_type: &str) -> bool {
580        self.supported_ops.contains(op_type)
581    }
582
583    /// Estimate execution cost for an operation on GPU.
584    pub fn estimate_cost(&self, op_spec: &OperatorSpec) -> f64 {
585        // GPU cost estimation - generally lower for parallel operations
586        match op_spec.op_type.as_str() {
587            "Add" | "Sub" | "Mul" | "Div" => 0.1, // Very fast on GPU
588            "ReLU" | "Sigmoid" | "Tanh" => 0.2,   // Fast element-wise
589            "MatMul" | "Gemm" => 0.5,             // GPU's strength
590            "Conv" => 0.8,                        // Complex but GPU-optimized
591            "ConvTranspose" => 1.2,               // More complex
592            "BatchNormalization" => 0.3,          // Fast on GPU
593            "Softmax" => 0.4,                     // Reduction + element-wise
594            "MaxPool" | "AveragePool" => 0.3,     // Simple operations
595            _ => 1.0,                             // Default cost
596        }
597    }
598
599    /// Check if the provider can utilize tensor cores.
600    #[cfg(feature = "gpu")]
601    pub fn has_tensor_cores(&self) -> bool {
602        // In practice, would query GPU capabilities
603        // For now, assume modern CUDA devices have tensor cores
604        matches!(self.device, Device::Cuda(_)) && self.config.enable_tensor_cores
605    }
606
607    /// Check if the GPU has tensor cores for mixed-precision operations.
608    #[cfg(not(feature = "gpu"))]
609    pub fn has_tensor_cores(&self) -> bool {
610        false
611    }
612
613    /// Get GPU memory information.
614    #[cfg(feature = "gpu")]
615    pub fn get_gpu_memory_info(&self) -> Result<(usize, usize)> {
616        // In practice, would query actual GPU memory
617        // For now, return estimated values
618        match &self.devices[0] {
619            Device::Cuda(_) => Ok((8 * 1024 * 1024 * 1024, 0)), // 8GB total, 0 used
620            Device::Metal(_) => Ok((8 * 1024 * 1024 * 1024, 0)), // 8GB total, 0 used
621            _ => Err(anyhow!("Not a GPU device")),
622        }
623    }
624
625    /// Get GPU memory information (total, available) in bytes.
626    #[cfg(not(feature = "gpu"))]
627    pub fn get_gpu_memory_info(&self) -> Result<(usize, usize)> {
628        Err(anyhow!("GPU support not available"))
629    }
630}
631
632impl Default for GpuExecutionProvider {
633    fn default() -> Self {
634        Self::new().expect("Failed to create default GPU provider")
635    }
636}
637
638impl ExecutionProvider for GpuExecutionProvider {
639    fn provider_id(&self) -> ProviderId {
640        ProviderId::GPU
641    }
642
643    fn get_capability(&self) -> ProviderCapability {
644        let mut data_types = vec![
645            DataType::F32,
646            DataType::F16, // Important for GPU mixed precision
647            DataType::F64,
648            DataType::U8,
649            DataType::U32,
650        ];
651
652        // Add additional types if tensor cores are available
653        if self.has_tensor_cores() {
654            // Tensor cores work best with F16
655            data_types.insert(0, DataType::F16); // Prioritize F16
656        }
657
658        let gpu_memory = self
659            .get_gpu_memory_info()
660            .map(|(total, _)| total)
661            .unwrap_or(0);
662
663        ProviderCapability {
664            supported_ops: self.supported_ops.clone(),
665            data_types,
666            memory_types: vec![MemoryType::DeviceMemory, MemoryType::SharedMemory],
667            performance_profile: PerformanceProfile::GPU,
668            resource_requirements: ResourceRequirements {
669                min_memory_bytes: Some(512 * 1024 * 1024), // 512MB minimum
670                cpu_features: vec![],                      // No specific CPU requirements
671                gpu_memory_bytes: Some(gpu_memory),
672            },
673        }
674    }
675
676    fn can_handle(&self, operators: &[OperatorSpec]) -> Vec<bool> {
677        operators
678            .iter()
679            .map(|op| self.supports_operation(&op.op_type))
680            .collect()
681    }
682
683    fn compile_subgraph(&self, subgraph: SubGraph) -> Result<Box<dyn CompiledKernel>> {
684        debug!(
685            "Compiling subgraph with {} nodes for GPU",
686            subgraph.nodes.len()
687        );
688
689        // Validate that all operations are supported
690        for node in &subgraph.nodes {
691            if !self.supports_operation(&node.op_type) {
692                return Err(anyhow!(
693                    "Unsupported GPU operation '{}' in subgraph",
694                    node.op_type
695                ));
696            }
697        }
698
699        // Select optimal device for this subgraph using multi-GPU manager
700        let mut device_manager = self.device_manager.lock().unwrap();
701        let primary_op = subgraph
702            .nodes
703            .first()
704            .map(|n| n.op_type.as_str())
705            .unwrap_or("Unknown");
706
707        // Estimate memory requirement (rough estimation)
708        let estimated_memory = subgraph.nodes.len() * 1024 * 1024; // 1MB per operation
709
710        let selected_device_id = device_manager.select_device(primary_op, estimated_memory);
711
712        // Find the device index in our devices vector
713        let device_index = self
714            .config
715            .device_ids
716            .iter()
717            .position(|&id| id == selected_device_id)
718            .unwrap_or(0);
719
720        let selected_device = self.devices[device_index].clone();
721
722        debug!(
723            "Selected GPU device {} for subgraph compilation",
724            selected_device_id
725        );
726
727        drop(device_manager); // Release lock before kernel creation
728
729        // Create GPU kernel with selected device and stream
730        let stream_id = selected_device_id % self.config.stream_count;
731        let kernel = GpuKernel::with_stream(subgraph, selected_device, stream_id)?;
732
733        debug!(
734            "Successfully compiled GPU kernel on device {}",
735            selected_device_id
736        );
737
738        Ok(Box::new(kernel))
739    }
740
741    fn get_allocator(&self) -> Arc<dyn TensorAllocator> {
742        // Return the allocator for the primary device
743        self.allocators[0].clone()
744    }
745
746    fn configure(&mut self, config: ProviderConfig) -> Result<()> {
747        // Update memory limit
748        if let Some(memory_limit) = config.memory_limit {
749            self.config.memory_limit = Some(memory_limit);
750            info!("Updated GPU memory limit to {} bytes", memory_limit);
751        }
752
753        // Handle custom options
754        for (key, value) in &config.custom_options {
755            match key.as_str() {
756                "enable_mixed_precision" => {
757                    if let Ok(enable) = value.parse::<bool>() {
758                        self.config.enable_mixed_precision = enable;
759                        info!("Updated mixed precision to {}", enable);
760                    }
761                }
762                "enable_tensor_cores" => {
763                    if let Ok(enable) = value.parse::<bool>() {
764                        self.config.enable_tensor_cores = enable;
765                        info!("Updated tensor cores to {}", enable);
766                    }
767                }
768                "stream_count" => {
769                    if let Ok(count) = value.parse::<usize>() {
770                        self.config.stream_count = count;
771                        info!("Updated stream count to {}", count);
772                    }
773                }
774                _ => {
775                    warn!("Unknown GPU configuration option: {}", key);
776                }
777            }
778        }
779
780        Ok(())
781    }
782
783    fn shutdown(&self) -> Result<()> {
784        info!("Shutting down GPU execution provider");
785
786        // GPU cleanup would happen here in a real implementation
787        // Candle handles most cleanup automatically
788
789        debug!("GPU provider shutdown complete");
790
791        Ok(())
792    }
793}
794
795impl GpuExecutionProvider {
796    /// Get allocator for a specific device.
797    pub fn get_device_allocator(&self, device_id: usize) -> Option<Arc<dyn TensorAllocator>> {
798        let device_index = self
799            .config
800            .device_ids
801            .iter()
802            .position(|&id| id == device_id)?;
803        self.allocators.get(device_index).cloned()
804    }
805
806    /// Get multi-GPU device statistics.
807    pub fn get_multi_gpu_stats(&self) -> HashMap<usize, DeviceStats> {
808        let device_manager = self.device_manager.lock().unwrap();
809        device_manager.get_device_stats().clone()
810    }
811
812    /// Enable or disable multi-GPU support.
813    pub fn set_multi_gpu_enabled(&mut self, enabled: bool) {
814        self.config.enable_multi_gpu = enabled;
815        info!(
816            "Multi-GPU support {}",
817            if enabled { "enabled" } else { "disabled" }
818        );
819    }
820
821    /// Set load balancing strategy for multi-GPU execution.
822    pub fn set_load_balancing_strategy(&mut self, strategy: LoadBalancingStrategy) {
823        info!("Updated load balancing strategy to {:?}", strategy);
824        self.config.load_balancing = strategy;
825    }
826
827    /// Get the number of configured GPU devices.
828    pub fn device_count(&self) -> usize {
829        self.devices.len()
830    }
831
832    /// Check if a specific device ID is available.
833    pub fn has_device(&self, device_id: usize) -> bool {
834        self.config.device_ids.contains(&device_id)
835    }
836
837    /// Check if custom CUDA kernels are available for a specific device.
838    pub fn has_custom_kernels(&self, device_id: usize) -> bool {
839        if let Some(device_index) = self
840            .config
841            .device_ids
842            .iter()
843            .position(|&id| id == device_id)
844        {
845            self.cuda_kernel_managers
846                .get(device_index)
847                .map(|manager| manager.is_some())
848                .unwrap_or(false)
849        } else {
850            false
851        }
852    }
853
854    /// Get available custom kernel operations for a specific device.
855    pub fn get_custom_kernel_ops(&self, device_id: usize) -> Vec<String> {
856        if self.has_custom_kernels(device_id) {
857            vec![
858                "FusedMatMulBias".to_string(),
859                "OptimizedSoftmax".to_string(),
860                "FusedConvBnRelu".to_string(),
861                "WarpReduceSum".to_string(),
862                "TensorCoreGemm".to_string(),
863                "FastGelu".to_string(),
864            ]
865        } else {
866            vec![]
867        }
868    }
869
870    /// Execute an operation with custom CUDA kernel if available.
871    pub fn try_execute_with_custom_kernel(
872        &self,
873        device_id: usize,
874        op_type: &str,
875        inputs: &[CandleTensor],
876    ) -> Result<Option<Vec<CandleTensor>>> {
877        let device_index = self
878            .config
879            .device_ids
880            .iter()
881            .position(|&id| id == device_id)
882            .ok_or_else(|| anyhow!("Device {} not found", device_id))?;
883
884        if let Some(Some(kernel_manager)) = self.cuda_kernel_managers.get(device_index) {
885            // Check if we have a custom kernel for this operation
886            let tensor_shape = inputs
887                .first()
888                .map(|t| t.shape().dims().to_vec())
889                .unwrap_or_else(|| vec![1]);
890
891            match kernel_manager.get_optimized_kernel(op_type, &tensor_shape) {
892                Ok(mut kernel) => {
893                    info!("Using custom CUDA kernel for operation: {}", op_type);
894
895                    // Prepare outputs (simplified - would need proper output allocation)
896                    let mut outputs: Vec<CandleTensor> = inputs.iter()
897                        .map(|input| input.clone()) // Placeholder
898                        .collect();
899
900                    // Execute the custom kernel
901                    kernel_manager.execute_kernel(&mut kernel, inputs, &mut outputs)?;
902
903                    Ok(Some(outputs))
904                }
905                Err(_) => {
906                    // No custom kernel available for this operation
907                    Ok(None)
908                }
909            }
910        } else {
911            // No CUDA kernel manager available
912            Ok(None)
913        }
914    }
915
916    /// Clear all custom kernel caches to free memory.
917    pub fn clear_kernel_caches(&self) {
918        for kernel_manager in self.cuda_kernel_managers.iter().flatten() {
919            kernel_manager.clear_cache();
920        }
921        info!("Cleared all CUDA kernel caches");
922    }
923
924    /// Get custom kernel cache statistics.
925    pub fn get_kernel_cache_stats(&self) -> Vec<super::cuda_kernels::CacheStats> {
926        self.cuda_kernel_managers
927            .iter()
928            .filter_map(|manager| manager.as_ref().map(|km| km.get_cache_stats()))
929            .collect()
930    }
931
932    /// Transfer tensor data between devices using optimal path.
933    pub fn transfer_tensor_between_devices(
934        &self,
935        tensor: &CandleTensor,
936        target_device_id: usize,
937    ) -> Result<CandleTensor> {
938        if let Some(ref _memory_manager) = self.memory_manager {
939            // Use advanced memory manager for optimal transfers
940            info!(
941                "Using multi-GPU memory manager for tensor transfer to device {}",
942                target_device_id
943            );
944
945            // In a real implementation, this would:
946            // 1. Check if tensor is already on target device
947            // 2. Use P2P transfer if available
948            // 3. Fall back to host memory if needed
949            // 4. Update transfer statistics
950
951            // For now, use Candle's built-in transfer
952            let target_device = &self.devices[self
953                .config
954                .device_ids
955                .iter()
956                .position(|&id| id == target_device_id)
957                .unwrap_or(0)];
958            Ok(tensor.to_device(target_device)?)
959        } else {
960            // Fall back to standard Candle transfer
961            let target_device = &self.devices[self
962                .config
963                .device_ids
964                .iter()
965                .position(|&id| id == target_device_id)
966                .unwrap_or(0)];
967            Ok(tensor.to_device(target_device)?)
968        }
969    }
970
971    /// Synchronize memory operations across all devices.
972    pub fn synchronize_memory(&self) -> Result<()> {
973        if let Some(ref memory_manager) = self.memory_manager {
974            memory_manager.synchronize_all()
975        } else {
976            // Single device or no memory manager - nothing to sync
977            Ok(())
978        }
979    }
980
981    /// Get memory statistics across all devices.
982    pub fn get_memory_statistics(&self) -> HashMap<usize, super::memory_manager::MemoryPoolStats> {
983        if let Some(ref memory_manager) = self.memory_manager {
984            memory_manager.get_memory_stats()
985        } else {
986            HashMap::new()
987        }
988    }
989
990    /// Get global memory statistics.
991    pub fn get_global_memory_stats(&self) -> Option<super::memory_manager::GlobalMemoryStats> {
992        self.memory_manager.as_ref().map(|mm| mm.get_global_stats())
993    }
994
995    /// Get P2P connectivity information between devices.
996    pub fn get_p2p_connectivity(
997        &self,
998    ) -> HashMap<(usize, usize), super::memory_manager::P2PCapability> {
999        if let Some(ref memory_manager) = self.memory_manager {
1000            memory_manager.get_p2p_info()
1001        } else {
1002            HashMap::new()
1003        }
1004    }
1005
1006    /// Check if P2P is available between two devices.
1007    pub fn is_p2p_available(&self, src_device: usize, dst_device: usize) -> bool {
1008        if let Some(ref memory_manager) = self.memory_manager {
1009            let p2p_info = memory_manager.get_p2p_info();
1010            p2p_info
1011                .get(&(src_device, dst_device))
1012                .map(|cap| cap.supported)
1013                .unwrap_or(false)
1014        } else {
1015            false
1016        }
1017    }
1018
1019    /// Get optimal memory layout for a given workload.
1020    pub fn optimize_memory_layout(
1021        &self,
1022        access_pattern: &super::memory_manager::AccessPattern,
1023    ) -> Result<super::memory_manager::MemoryLayout> {
1024        if let Some(ref memory_manager) = self.memory_manager {
1025            memory_manager.optimize_memory_layout(access_pattern)
1026        } else {
1027            Err(anyhow!("Multi-GPU memory manager not available"))
1028        }
1029    }
1030
1031    /// Get GPU topology information.
1032    pub fn get_topology(&self) -> Option<super::topology::GpuTopology> {
1033        self.topology_manager.as_ref().map(|tm| tm.get_topology())
1034    }
1035
1036    /// Optimize workload placement using topology analysis.
1037    pub fn optimize_workload_placement(
1038        &self,
1039        workload: &super::topology::Workload,
1040        strategy: &str,
1041    ) -> Result<super::topology::PlacementPlan> {
1042        if let Some(ref topology_manager) = self.topology_manager {
1043            topology_manager.optimize_placement(workload, strategy)
1044        } else {
1045            Err(anyhow!("GPU topology manager not available"))
1046        }
1047    }
1048
1049    /// Compare multiple placement strategies for a workload.
1050    pub fn compare_placement_strategies(
1051        &self,
1052        workload: &super::topology::Workload,
1053        strategies: &[String],
1054    ) -> Result<Vec<(String, super::topology::PlacementPlan)>> {
1055        if let Some(ref topology_manager) = self.topology_manager {
1056            topology_manager.compare_strategies(workload, strategies)
1057        } else {
1058            Err(anyhow!("GPU topology manager not available"))
1059        }
1060    }
1061
1062    /// Get available placement strategies.
1063    pub fn get_available_placement_strategies(&self) -> Vec<String> {
1064        if let Some(ref topology_manager) = self.topology_manager {
1065            topology_manager.get_available_strategies()
1066        } else {
1067            vec![]
1068        }
1069    }
1070
1071    /// Check if topology-aware placement is available.
1072    pub fn has_topology_support(&self) -> bool {
1073        self.topology_manager.is_some()
1074    }
1075
1076    /// Get detailed device information including topology.
1077    pub fn get_detailed_device_info(&self) -> HashMap<usize, super::topology::GpuDeviceInfo> {
1078        if let Some(ref topology_manager) = self.topology_manager {
1079            topology_manager.get_topology().devices
1080        } else {
1081            HashMap::new()
1082        }
1083    }
1084
1085    /// Get interconnect information between devices.
1086    pub fn get_interconnect_info(
1087        &self,
1088    ) -> HashMap<(usize, usize), super::topology::InterconnectLink> {
1089        if let Some(ref topology_manager) = self.topology_manager {
1090            topology_manager.get_topology().links
1091        } else {
1092            HashMap::new()
1093        }
1094    }
1095
1096    /// Automatically select optimal devices for a workload.
1097    pub fn auto_select_devices(&self, workload: &super::topology::Workload) -> Result<Vec<usize>> {
1098        let plan = self.optimize_workload_placement(workload, "locality_aware")?;
1099        Ok(plan.device_assignments.values().copied().collect())
1100    }
1101}
1102
1103impl GpuKernel {
1104    /// Create a new GPU kernel.
1105    pub fn new(subgraph: SubGraph, device: Device) -> Result<Self> {
1106        Ok(Self {
1107            subgraph,
1108            device,
1109            stats: std::sync::Mutex::new(GpuKernelStats::default()),
1110            stream_id: 0, // Default stream
1111            kernel_cache: std::sync::Mutex::new(KernelCache {
1112                cached_ops: HashMap::new(),
1113                cache_size: 0,
1114                max_cache_size: 64 * 1024 * 1024, // 64MB cache
1115            }),
1116        })
1117    }
1118
1119    /// Create a GPU kernel with specific stream ID.
1120    pub fn with_stream(subgraph: SubGraph, device: Device, stream_id: usize) -> Result<Self> {
1121        Ok(Self {
1122            subgraph,
1123            device,
1124            stats: std::sync::Mutex::new(GpuKernelStats::default()),
1125            stream_id,
1126            kernel_cache: std::sync::Mutex::new(KernelCache {
1127                cached_ops: HashMap::new(),
1128                cache_size: 0,
1129                max_cache_size: 64 * 1024 * 1024, // 64MB cache
1130            }),
1131        })
1132    }
1133
1134    /// Execute a single operation on GPU using Candle.
1135    fn execute_gpu_operation(
1136        &self,
1137        op_type: &str,
1138        inputs: &[CandleTensor],
1139    ) -> Result<Vec<CandleTensor>> {
1140        match op_type {
1141            "Add" => {
1142                if inputs.len() != 2 {
1143                    return Err(anyhow!("Add requires exactly 2 inputs"));
1144                }
1145                let result = (&inputs[0] + &inputs[1])?;
1146                Ok(vec![result])
1147            }
1148
1149            "Sub" => {
1150                if inputs.len() != 2 {
1151                    return Err(anyhow!("Sub requires exactly 2 inputs"));
1152                }
1153                let result = (&inputs[0] - &inputs[1])?;
1154                Ok(vec![result])
1155            }
1156
1157            "Mul" => {
1158                if inputs.len() != 2 {
1159                    return Err(anyhow!("Mul requires exactly 2 inputs"));
1160                }
1161                let result = (&inputs[0] * &inputs[1])?;
1162                Ok(vec![result])
1163            }
1164
1165            "Div" => {
1166                if inputs.len() != 2 {
1167                    return Err(anyhow!("Div requires exactly 2 inputs"));
1168                }
1169                let result = (&inputs[0] / &inputs[1])?;
1170                Ok(vec![result])
1171            }
1172
1173            "MatMul" => {
1174                if inputs.len() != 2 {
1175                    return Err(anyhow!("MatMul requires exactly 2 inputs"));
1176                }
1177                let result = inputs[0].matmul(&inputs[1])?;
1178                Ok(vec![result])
1179            }
1180
1181            "ReLU" => {
1182                if inputs.len() != 1 {
1183                    return Err(anyhow!("ReLU requires exactly 1 input"));
1184                }
1185                let zero = inputs[0].zeros_like()?;
1186                let result = inputs[0].maximum(&zero)?;
1187                Ok(vec![result])
1188            }
1189
1190            "Softmax" => {
1191                if inputs.len() != 1 {
1192                    return Err(anyhow!("Softmax requires exactly 1 input"));
1193                }
1194                let result = candle_nn::ops::softmax_last_dim(&inputs[0])?;
1195                Ok(vec![result])
1196            }
1197
1198            "Sigmoid" => {
1199                if inputs.len() != 1 {
1200                    return Err(anyhow!("Sigmoid requires exactly 1 input"));
1201                }
1202                // Sigmoid(x) = 1 / (1 + exp(-x))
1203                let neg_input = inputs[0].neg()?;
1204                let exp_neg = neg_input.exp()?;
1205                let one = inputs[0].ones_like()?;
1206                let denominator = (&one + &exp_neg)?;
1207                let result = one.div(&denominator)?;
1208                Ok(vec![result])
1209            }
1210
1211            "Tanh" => {
1212                if inputs.len() != 1 {
1213                    return Err(anyhow!("Tanh requires exactly 1 input"));
1214                }
1215                let result = inputs[0].tanh()?;
1216                Ok(vec![result])
1217            }
1218
1219            "GELU" => {
1220                if inputs.len() != 1 {
1221                    return Err(anyhow!("GELU requires exactly 1 input"));
1222                }
1223                // GELU(x) = 0.5 * x * (1 + tanh(sqrt(2/π) * (x + 0.044715 * x^3)))
1224                let x = &inputs[0];
1225
1226                // Calculate x^3
1227                let x_cubed = x.powf(3.0)?;
1228
1229                // Calculate 0.044715 * x^3
1230                let coeff_tensor = x_cubed.affine(0.044715, 0.0)?;
1231
1232                // Calculate x + 0.044715 * x^3
1233                let x_plus_coeff = (x + &coeff_tensor)?;
1234
1235                // Calculate sqrt(2/π) * (x + 0.044715 * x^3)
1236                let sqrt_2_over_pi = (2.0 / std::f64::consts::PI).sqrt() as f64;
1237                let inner = x_plus_coeff.affine(sqrt_2_over_pi, 0.0)?;
1238
1239                // Calculate tanh(inner)
1240                let tanh_inner = inner.tanh()?;
1241
1242                // Calculate 1 + tanh(inner)
1243                let one = x.ones_like()?;
1244                let one_plus_tanh = (&one + &tanh_inner)?;
1245
1246                // Calculate 0.5 * x
1247                let half_x = x.affine(0.5, 0.0)?;
1248
1249                // Final result: 0.5 * x * (1 + tanh(...))
1250                let result = (&half_x * &one_plus_tanh)?;
1251                Ok(vec![result])
1252            }
1253
1254            "MaxPool" => {
1255                if inputs.len() != 1 {
1256                    return Err(anyhow!("MaxPool requires exactly 1 input"));
1257                }
1258                // For now, implement a simplified 2x2 max pooling
1259                // In practice, this would use proper pooling parameters from attributes
1260                let input = &inputs[0];
1261                let dims = input.dims();
1262
1263                if dims.len() < 3 {
1264                    return Err(anyhow!("MaxPool requires at least 3D input (CHW)"));
1265                }
1266
1267                // Simple 2x2 max pooling implementation
1268                // This is a placeholder - real implementation would use proper stride and kernel size
1269                let result = input.clone(); // Placeholder - return input for now
1270                Ok(vec![result])
1271            }
1272
1273            "AveragePool" => {
1274                if inputs.len() != 1 {
1275                    return Err(anyhow!("AveragePool requires exactly 1 input"));
1276                }
1277                // For now, implement a simplified 2x2 average pooling
1278                // In practice, this would use proper pooling parameters from attributes
1279                let input = &inputs[0];
1280                let dims = input.dims();
1281
1282                if dims.len() < 3 {
1283                    return Err(anyhow!("AveragePool requires at least 3D input (CHW)"));
1284                }
1285
1286                // Simple 2x2 average pooling implementation
1287                // This is a placeholder - real implementation would use proper stride and kernel size
1288                let result = input.clone(); // Placeholder - return input for now
1289                Ok(vec![result])
1290            }
1291
1292            "Conv" => {
1293                if inputs.len() < 2 {
1294                    return Err(anyhow!(
1295                        "Conv requires at least 2 inputs (input and weights)"
1296                    ));
1297                }
1298                let input = &inputs[0];
1299                let weights = &inputs[1];
1300
1301                // Basic 2D convolution using Candle's conv2d
1302                // In practice, this would parse stride, padding, etc. from attributes
1303                let result = input.conv2d(weights, 1, 1, 1, 1)?; // stride=1, padding=1
1304                Ok(vec![result])
1305            }
1306
1307            "ConvTranspose" => {
1308                if inputs.len() < 2 {
1309                    return Err(anyhow!("ConvTranspose requires at least 2 inputs"));
1310                }
1311                let input = &inputs[0];
1312                let weights = &inputs[1];
1313
1314                // Transpose convolution (deconvolution)
1315                // For now, use conv2d as placeholder - real implementation would use conv_transpose2d
1316                let result = input.conv2d(weights, 1, 1, 1, 1)?;
1317                Ok(vec![result])
1318            }
1319
1320            "BatchNormalization" => {
1321                if inputs.len() < 5 {
1322                    return Err(anyhow!(
1323                        "BatchNormalization requires 5 inputs: input, scale, bias, mean, var"
1324                    ));
1325                }
1326                let input = &inputs[0];
1327                let scale = &inputs[1]; // gamma
1328                let bias = &inputs[2]; // beta
1329                let mean = &inputs[3]; // running mean
1330                let var = &inputs[4]; // running variance
1331
1332                // BatchNorm formula: (x - mean) / sqrt(var + eps) * scale + bias
1333                let eps = 1e-5; // epsilon for numerical stability
1334
1335                // Expand dimensions if needed for broadcasting
1336                let input_dims = input.dims();
1337                let _batch_size = input_dims[0];
1338                let channels = input_dims[1];
1339
1340                // Reshape scale, bias, mean, var for broadcasting
1341                let scale_reshaped = if scale.dims().len() == 1 {
1342                    scale.reshape(&[1, channels, 1, 1])?
1343                } else {
1344                    scale.clone()
1345                };
1346
1347                let bias_reshaped = if bias.dims().len() == 1 {
1348                    bias.reshape(&[1, channels, 1, 1])?
1349                } else {
1350                    bias.clone()
1351                };
1352
1353                let mean_reshaped = if mean.dims().len() == 1 {
1354                    mean.reshape(&[1, channels, 1, 1])?
1355                } else {
1356                    mean.clone()
1357                };
1358
1359                let var_reshaped = if var.dims().len() == 1 {
1360                    var.reshape(&[1, channels, 1, 1])?
1361                } else {
1362                    var.clone()
1363                };
1364
1365                // Calculate: (x - mean) / sqrt(var + eps) * scale + bias
1366                let normalized = (input - &mean_reshaped)?;
1367                let var_plus_eps = (&var_reshaped + eps)?;
1368                let std_dev = var_plus_eps.sqrt()?;
1369                let normalized_scaled = (&normalized / &std_dev)?;
1370                let scaled = (&normalized_scaled * &scale_reshaped)?;
1371                let result = (&scaled + &bias_reshaped)?;
1372
1373                Ok(vec![result])
1374            }
1375
1376            "LayerNormalization" => {
1377                if inputs.len() < 3 {
1378                    return Err(anyhow!(
1379                        "LayerNormalization requires 3 inputs: input, scale, bias"
1380                    ));
1381                }
1382                let input = &inputs[0];
1383                let scale = &inputs[1]; // gamma
1384                let bias = &inputs[2]; // beta
1385
1386                // LayerNorm: normalize over the last dimension(s)
1387                let eps = 1e-5;
1388                let dims = input.dims();
1389                let last_dim = dims.len() - 1;
1390
1391                // Calculate mean and variance over the last dimension
1392                let mean = input.mean_keepdim(last_dim)?;
1393                let variance = {
1394                    let diff = (input - &mean)?;
1395                    let squared = (&diff * &diff)?;
1396                    squared.mean_keepdim(last_dim)?
1397                };
1398
1399                // Normalize: (x - mean) / sqrt(var + eps) * scale + bias
1400                let normalized = (input - &mean)?;
1401                let var_plus_eps = (&variance + eps)?;
1402                let std_dev = var_plus_eps.sqrt()?;
1403                let normalized_scaled = (&normalized / &std_dev)?;
1404                let scaled = (&normalized_scaled * scale)?;
1405                let result = (&scaled + bias)?;
1406
1407                Ok(vec![result])
1408            }
1409
1410            "GlobalAveragePool" => {
1411                if inputs.len() != 1 {
1412                    return Err(anyhow!("GlobalAveragePool requires exactly 1 input"));
1413                }
1414                let input = &inputs[0];
1415                let dims = input.dims();
1416
1417                if dims.len() != 4 {
1418                    return Err(anyhow!("GlobalAveragePool expects 4D input (NCHW)"));
1419                }
1420
1421                // Global average pooling: average over spatial dimensions (H, W)
1422                let result = input.mean_keepdim(2)?.mean_keepdim(3)?;
1423                Ok(vec![result])
1424            }
1425
1426            "Reshape" => {
1427                if inputs.len() != 1 {
1428                    return Err(anyhow!("Reshape requires exactly 1 input"));
1429                }
1430                // For simplicity, just return the input (reshape params would come from attributes)
1431                Ok(vec![inputs[0].clone()])
1432            }
1433
1434            _ => Err(anyhow!("Unsupported GPU operation: {}", op_type)),
1435        }
1436    }
1437
1438    /// Convert RONN Tensor to Candle Tensor.
1439    fn ronn_to_candle(&self, tensor: &ronn_core::tensor::Tensor) -> Result<CandleTensor> {
1440        let data = tensor.to_vec()?;
1441        let shape = tensor.shape();
1442        let dtype = match tensor.dtype() {
1443            DataType::F32 => candle_core::DType::F32,
1444            DataType::F16 => candle_core::DType::F16,
1445            DataType::F64 => candle_core::DType::F64,
1446            DataType::U8 => candle_core::DType::U8,
1447            DataType::U32 => candle_core::DType::U32,
1448            _ => candle_core::DType::F32, // Fallback
1449        };
1450
1451        let candle_tensor =
1452            CandleTensor::from_vec(data, shape.as_slice(), &self.device)?.to_dtype(dtype)?;
1453
1454        Ok(candle_tensor)
1455    }
1456
1457    /// Generate operation signature for caching.
1458    fn generate_operation_signature(&self, op_type: &str, inputs: &[CandleTensor]) -> String {
1459        use std::collections::hash_map::DefaultHasher;
1460        use std::hash::{Hash, Hasher};
1461
1462        let mut hasher = DefaultHasher::new();
1463        op_type.hash(&mut hasher);
1464
1465        // Hash input shapes and dtypes
1466        for input in inputs {
1467            input.dims().hash(&mut hasher);
1468            format!("{:?}", input.dtype()).hash(&mut hasher);
1469        }
1470
1471        format!("{}_{:x}", op_type, hasher.finish())
1472    }
1473
1474    /// Check if operation can use mixed precision.
1475    fn can_use_mixed_precision(&self, op_type: &str) -> bool {
1476        matches!(
1477            op_type,
1478            "Add"
1479                | "Sub"
1480                | "Mul"
1481                | "MatMul"
1482                | "Conv"
1483                | "ReLU"
1484                | "Sigmoid"
1485                | "Tanh"
1486                | "GELU"
1487                | "BatchNormalization"
1488                | "LayerNormalization"
1489        )
1490    }
1491
1492    /// Convert tensors to mixed precision if beneficial.
1493    fn apply_mixed_precision(
1494        &self,
1495        inputs: &[CandleTensor],
1496        op_type: &str,
1497    ) -> Result<Vec<CandleTensor>> {
1498        if !self.can_use_mixed_precision(op_type) {
1499            return Ok(inputs.to_vec());
1500        }
1501
1502        let mut converted = Vec::new();
1503        for input in inputs {
1504            // Convert large tensors to FP16 for memory efficiency
1505            let element_count = input.dims().iter().product::<usize>();
1506            if element_count > 1024 && input.dtype() == candle_core::DType::F32 {
1507                let fp16_tensor = input.to_dtype(candle_core::DType::F16)?;
1508                converted.push(fp16_tensor);
1509                debug!("Converted tensor to FP16 for operation: {}", op_type);
1510            } else {
1511                converted.push(input.clone());
1512            }
1513        }
1514        Ok(converted)
1515    }
1516
1517    /// Execute operation with caching and optimization.
1518    fn execute_optimized_operation(
1519        &self,
1520        op_type: &str,
1521        inputs: &[CandleTensor],
1522    ) -> Result<Vec<CandleTensor>> {
1523        let signature = self.generate_operation_signature(op_type, inputs);
1524
1525        // Check cache first
1526        {
1527            let mut cache = self.kernel_cache.lock().unwrap();
1528            if let Some(cached_op) = cache.cached_ops.get_mut(&signature) {
1529                cached_op.hit_count += 1;
1530                cached_op.last_accessed = std::time::Instant::now();
1531                debug!(
1532                    "Cache hit for operation: {} (signature: {})",
1533                    op_type, signature
1534                );
1535            }
1536        }
1537
1538        // Apply mixed precision if beneficial
1539        let optimized_inputs = self.apply_mixed_precision(inputs, op_type)?;
1540
1541        // Execute the operation
1542        let result = self.execute_gpu_operation(op_type, &optimized_inputs)?;
1543
1544        // Cache the operation
1545        {
1546            let mut cache = self.kernel_cache.lock().unwrap();
1547            let cached_op = CachedOperation {
1548                signature: signature.clone(),
1549                execution_path: OptimizedPath::Single(op_type.to_string()),
1550                hit_count: 1,
1551                last_accessed: std::time::Instant::now(),
1552            };
1553            cache.cached_ops.insert(signature, cached_op);
1554
1555            // Simple cache eviction if needed
1556            if cache.cached_ops.len() > 1000 {
1557                self.evict_cache_entries(&mut cache);
1558            }
1559        }
1560
1561        Ok(result)
1562    }
1563
1564    /// Evict old cache entries using LRU policy.
1565    fn evict_cache_entries(&self, cache: &mut KernelCache) {
1566        let current_time = std::time::Instant::now();
1567        let mut to_remove = Vec::new();
1568
1569        for (signature, cached_op) in &cache.cached_ops {
1570            // Remove entries not accessed in the last 5 minutes
1571            if current_time
1572                .duration_since(cached_op.last_accessed)
1573                .as_secs()
1574                > 300
1575            {
1576                to_remove.push(signature.clone());
1577            }
1578        }
1579
1580        for signature in to_remove {
1581            cache.cached_ops.remove(&signature);
1582        }
1583
1584        debug!("Evicted {} cache entries", cache.cached_ops.len());
1585    }
1586
1587    /// Get cache statistics.
1588    pub fn get_cache_stats(&self) -> (usize, usize, f64) {
1589        let cache = self.kernel_cache.lock().unwrap();
1590        let total_hits: u64 = cache.cached_ops.values().map(|op| op.hit_count).sum();
1591        let cache_count = cache.cached_ops.len();
1592        let hit_rate = if cache_count > 0 {
1593            total_hits as f64 / cache_count as f64
1594        } else {
1595            0.0
1596        };
1597        (cache_count, cache.cache_size, hit_rate)
1598    }
1599
1600    /// Convert Candle Tensor to RONN Tensor.
1601    fn candle_to_ronn(&self, tensor: &CandleTensor) -> Result<ronn_core::tensor::Tensor> {
1602        let shape = tensor.dims().to_vec();
1603        let data: Vec<f32> = tensor.to_vec1()?; // Convert to F32 for now
1604
1605        let ronn_tensor = Tensor::from_data(
1606            data,
1607            shape,
1608            DataType::F32, // Simplified for now
1609            TensorLayout::RowMajor,
1610        )?;
1611
1612        Ok(ronn_tensor)
1613    }
1614}
1615
1616impl CompiledKernel for GpuKernel {
1617    fn execute(
1618        &self,
1619        inputs: &[ronn_core::tensor::Tensor],
1620    ) -> Result<Vec<ronn_core::tensor::Tensor>> {
1621        let start_time = std::time::Instant::now();
1622
1623        // Convert RONN tensors to Candle tensors
1624        let mut candle_inputs = Vec::new();
1625        for input in inputs {
1626            let candle_tensor = self.ronn_to_candle(input)?;
1627            candle_inputs.push(candle_tensor);
1628        }
1629
1630        // Execute operations with caching and optimization
1631        let mut current_tensors = candle_inputs;
1632
1633        for node in &self.subgraph.nodes {
1634            debug!(
1635                "Executing GPU operation: {} on stream {}",
1636                node.op_type, self.stream_id
1637            );
1638            let outputs = self.execute_optimized_operation(&node.op_type, &current_tensors)?;
1639            current_tensors = outputs;
1640        }
1641
1642        // Convert back to RONN tensors
1643        let mut results = Vec::new();
1644        for candle_tensor in &current_tensors {
1645            let ronn_tensor = self.candle_to_ronn(candle_tensor)?;
1646            results.push(ronn_tensor);
1647        }
1648
1649        // Update statistics
1650        let execution_time = start_time.elapsed();
1651        {
1652            let mut stats = self.stats.lock().unwrap();
1653            stats.execution_count += 1;
1654            stats.total_time_us += execution_time.as_micros() as u64;
1655
1656            if stats.execution_count == 1 {
1657                stats.min_time_us = execution_time.as_micros() as u64;
1658                stats.max_time_us = execution_time.as_micros() as u64;
1659            } else {
1660                stats.min_time_us = stats.min_time_us.min(execution_time.as_micros() as u64);
1661                stats.max_time_us = stats.max_time_us.max(execution_time.as_micros() as u64);
1662            }
1663        }
1664
1665        debug!("GPU kernel executed in {:?}", execution_time);
1666
1667        Ok(results)
1668    }
1669
1670    fn get_memory_usage(&self) -> MemoryUsage {
1671        let stats = self.stats.lock().unwrap();
1672        MemoryUsage {
1673            peak_bytes: stats.memory_peak,
1674            current_bytes: 0, // Would track current usage in practice
1675            allocation_count: stats.execution_count as usize,
1676        }
1677    }
1678
1679    fn get_performance_stats(&self) -> KernelStats {
1680        let stats = self.stats.lock().unwrap();
1681
1682        let average_time_us = if stats.execution_count > 0 {
1683            stats.total_time_us as f64 / stats.execution_count as f64
1684        } else {
1685            0.0
1686        };
1687
1688        KernelStats {
1689            execution_count: stats.execution_count,
1690            average_time_us,
1691            min_time_us: stats.min_time_us as f64,
1692            max_time_us: stats.max_time_us as f64,
1693        }
1694    }
1695}
1696
1697/// Create a default GPU execution provider.
1698pub fn create_gpu_provider() -> Result<Arc<dyn ExecutionProvider>> {
1699    Ok(Arc::new(GpuExecutionProvider::new()?))
1700}
1701
1702/// Create a GPU execution provider with custom configuration.
1703pub fn create_gpu_provider_with_config(
1704    config: GpuProviderConfig,
1705) -> Result<Arc<dyn ExecutionProvider>> {
1706    Ok(Arc::new(GpuExecutionProvider::with_config(config)?))
1707}
1708
#[cfg(test)]
mod tests {
    use super::*;
    use ronn_core::GraphNode;

    /// Builds a single-node `Add` subgraph reading `input1` and `input2` into `output1`.
    fn create_test_subgraph() -> SubGraph {
        let node = GraphNode {
            id: 0,
            op_type: "Add".to_string(),
            attributes: HashMap::new(),
            inputs: vec!["input1".to_string(), "input2".to_string()],
            outputs: vec!["output1".to_string()],
            name: Some("gpu_add".to_string()),
        };

        SubGraph {
            nodes: vec![node],
            edges: vec![],
            inputs: vec!["input1".to_string(), "input2".to_string()],
            outputs: vec!["output1".to_string()],
        }
    }

    /// Builds a single-node subgraph applying `op_type` to `input1` and writing `output1`.
    fn create_single_op_subgraph(op_type: &str) -> SubGraph {
        let node = GraphNode {
            id: 0,
            op_type: op_type.to_string(),
            attributes: HashMap::new(),
            inputs: vec!["input1".to_string()],
            outputs: vec!["output1".to_string()],
            name: Some(format!("test_{}", op_type)),
        };

        SubGraph {
            nodes: vec![node],
            edges: vec![],
            inputs: vec!["input1".to_string()],
            outputs: vec!["output1".to_string()],
        }
    }

    /// Calls `run` `iterations` times and returns the mean wall-clock duration per call.
    ///
    /// The first failing call aborts the measurement and its error is returned. This way a
    /// kernel that errors out is reported as failed instead of being timed as if it ran.
    fn time_runs<T>(
        iterations: u32,
        mut run: impl FnMut() -> Result<T>,
    ) -> Result<std::time::Duration> {
        let start = std::time::Instant::now();
        for _ in 0..iterations {
            run()?;
        }
        Ok(start.elapsed() / iterations.max(1))
    }

    /// Prints a benchmark line for `label`, or the error that prevented timing it.
    fn report_timing(label: &str, timing: Result<std::time::Duration>) {
        match timing {
            Ok(avg_time) => println!("    {} avg: {:?}", label, avg_time),
            Err(e) => println!("    {} failed: {}", label, e),
        }
    }

    #[test]
    fn test_gpu_provider_creation() {
        // Construction fails on machines without a usable GPU; that case is not a failure.
        let provider = match GpuExecutionProvider::new() {
            Ok(provider) => provider,
            Err(e) => {
                println!("GPU not available: {}", e);
                return;
            }
        };

        assert_eq!(provider.provider_id(), ProviderId::GPU);

        let capability = provider.get_capability();
        assert_eq!(capability.performance_profile, PerformanceProfile::GPU);
        assert!(!capability.supported_ops.is_empty());
        assert!(capability.data_types.contains(&DataType::F32));
    }

    #[test]
    fn test_gpu_provider_config() {
        let config = GpuProviderConfig {
            device_ids: vec![0],
            enable_mixed_precision: false,
            enable_tensor_cores: false,
            ..Default::default()
        };

        // GPU not available: nothing to check.
        let Ok(provider) = GpuExecutionProvider::with_config(config) else {
            return;
        };

        assert!(!provider.get_config().enable_mixed_precision);
        assert!(!provider.get_config().enable_tensor_cores);
    }

    #[test]
    fn test_operation_support() {
        let Ok(provider) = GpuExecutionProvider::new() else {
            return; // GPU not available
        };

        // GPU-optimized operations are advertised; unknown ones are not.
        for op in ["Add", "MatMul", "Conv", "ReLU"] {
            assert!(provider.supports_operation(op), "expected {op} to be supported");
        }
        assert!(!provider.supports_operation("NonexistentOp"));

        // Cost estimation
        let spec = |op_type: &str| OperatorSpec {
            op_type: op_type.to_string(),
            input_types: vec![DataType::F32],
            output_types: vec![DataType::F32],
            attributes: HashMap::new(),
        };

        let add_cost = provider.estimate_cost(&spec("Add"));
        let conv_cost = provider.estimate_cost(&spec("Conv"));

        // Both should be cheap on GPU, but Conv is the more complex operation.
        assert!(conv_cost > add_cost);
        assert!(add_cost < 1.0); // Should be less than 1.0 for GPU
    }

    #[test]
    fn test_subgraph_compilation() {
        let Ok(provider) = GpuExecutionProvider::new() else {
            return; // GPU not available
        };

        match provider.compile_subgraph(create_test_subgraph()) {
            // A freshly compiled kernel has not been executed yet.
            Ok(kernel) => assert_eq!(kernel.get_performance_stats().execution_count, 0),
            Err(e) => println!("Compilation failed: {}", e),
        }
    }

    #[test]
    fn test_factory_functions() {
        // Either factory may fail without a GPU; only successful results are checked.
        if let Ok(provider) = create_gpu_provider() {
            assert_eq!(provider.provider_id(), ProviderId::GPU);
        }

        if let Ok(provider) = create_gpu_provider_with_config(GpuProviderConfig::default()) {
            assert_eq!(provider.provider_id(), ProviderId::GPU);
        }
    }

    #[test]
    fn test_complex_gpu_operations() {
        let provider = match GpuExecutionProvider::new() {
            Ok(provider) => provider,
            Err(e) => {
                println!("GPU not available: {}", e);
                return;
            }
        };
        let capability = provider.get_capability();

        let expected = [
            // Complex operations
            "Conv",
            "BatchNormalization",
            "LayerNormalization",
            "GlobalAveragePool",
            // Activation functions
            "Sigmoid",
            "Tanh",
            "GELU",
        ];
        for op in expected {
            assert!(
                capability.supported_ops.contains(op),
                "expected {op} to be supported"
            );
        }

        println!(
            "✅ GPU provider supports {} complex operations",
            capability.supported_ops.len()
        );
    }

    #[test]
    #[ignore = "benchmark: heavy GPU workload with no assertions; run with `cargo test -- --ignored`"]
    fn test_gpu_benchmarks() {
        let provider = match GpuExecutionProvider::new() {
            Ok(provider) => provider,
            Err(e) => {
                println!("GPU not available for benchmarks: {}", e);
                return;
            }
        };

        println!("🚀 Running GPU performance benchmarks...");

        benchmark_basic_operations(&provider);
        benchmark_complex_operations(&provider);
        benchmark_mixed_precision(&provider);
        benchmark_cache_performance(&provider);
        benchmark_memory_throughput(&provider);

        println!("✅ GPU benchmarks completed!");
    }

    /// Times 10 runs of each elementwise/MatMul op on a 1024x1024 tensor of ones.
    fn benchmark_basic_operations(provider: &GpuExecutionProvider) {
        println!("  📊 Basic Operations Benchmark:");

        for op in ["Add", "Mul", "MatMul", "ReLU", "Sigmoid", "Tanh"] {
            let Ok(kernel) = provider.compile_subgraph(create_single_op_subgraph(op)) else {
                continue;
            };
            let test_input = ronn_core::tensor::Tensor::ones(
                vec![1024, 1024],
                DataType::F32,
                TensorLayout::RowMajor,
            )
            .unwrap();

            report_timing(
                op,
                time_runs(10, || kernel.execute(std::slice::from_ref(&test_input))),
            );
        }
    }

    /// Times 5 runs of each complex op; Conv gets a 4-D NCHW input, the rest a 32x512 one.
    fn benchmark_complex_operations(provider: &GpuExecutionProvider) {
        println!("  🧠 Complex Operations Benchmark:");

        let complex_ops = [
            "Conv",
            "BatchNormalization",
            "LayerNormalization",
            "GlobalAveragePool",
        ];

        for op in complex_ops {
            let Ok(kernel) = provider.compile_subgraph(create_single_op_subgraph(op)) else {
                continue;
            };
            let shape = if op == "Conv" {
                vec![1, 64, 224, 224] // NCHW format
            } else {
                vec![32, 512]
            };
            let test_input =
                ronn_core::tensor::Tensor::ones(shape, DataType::F32, TensorLayout::RowMajor)
                    .unwrap();

            report_timing(
                op,
                time_runs(5, || kernel.execute(std::slice::from_ref(&test_input))),
            );
        }
    }

    /// Placeholder: reports tensor-core availability and the sizes an F16/F32 comparison
    /// would use; no kernels are executed.
    fn benchmark_mixed_precision(provider: &GpuExecutionProvider) {
        println!("  🎯 Mixed Precision Benchmark:");

        if provider.has_tensor_cores() {
            println!("    Tensor cores available - mixed precision enabled");
        } else {
            println!("    Tensor cores not available - mixed precision simulation");
        }

        for size in [512, 1024, 2048] {
            println!("    Matrix size: {}x{}", size, size);
            // Would test actual F16 vs F32 performance here
        }
    }

    /// Times 20 `Add` runs on a 512x512 input after 5 untimed warm-up runs.
    fn benchmark_cache_performance(provider: &GpuExecutionProvider) {
        println!("  💾 Cache Performance Benchmark:");

        let Ok(kernel) = provider.compile_subgraph(create_single_op_subgraph("Add")) else {
            return;
        };
        let test_input = ronn_core::tensor::Tensor::ones(
            vec![512, 512],
            DataType::F32,
            TensorLayout::RowMajor,
        )
        .unwrap();
        let inputs = std::slice::from_ref(&test_input);

        // Warm up before measuring; a persistent failure surfaces in the timed runs below.
        for _ in 0..5 {
            let _ = kernel.execute(inputs);
        }

        // Reading the kernel's cache hit statistics would require downcasting to the
        // concrete kernel type, so only the cached execution time is reported.
        match time_runs(20, || kernel.execute(inputs)) {
            Ok(cached_time) => println!("    Cached execution avg: {:?}", cached_time),
            Err(e) => println!("    Cached execution failed: {}", e),
        }
    }

    /// Reports total GPU memory (when available) and times up to 100 allocator round trips.
    fn benchmark_memory_throughput(provider: &GpuExecutionProvider) {
        println!("  🚀 Memory Throughput Benchmark:");

        if let Ok((total_memory, _used_memory)) = provider.get_gpu_memory_info() {
            println!(
                "    GPU Memory: {:.2} GB total",
                total_memory as f64 / (1024.0 * 1024.0 * 1024.0)
            );
        }

        let allocator = provider.get_allocator();

        // Failed allocations are skipped, so report the number that actually succeeded.
        let start = std::time::Instant::now();
        let buffers: Vec<_> = (0..100)
            .filter_map(|_| allocator.allocate(&[1024], DataType::F32).ok())
            .collect();
        let alloc_time = start.elapsed();
        let allocated = buffers.len();

        let start = std::time::Instant::now();
        for buffer in buffers {
            let _ = allocator.deallocate(buffer);
        }
        let dealloc_time = start.elapsed();

        println!("    {} allocations: {:?}", allocated, alloc_time);
        println!("    {} deallocations: {:?}", allocated, dealloc_time);
    }

    #[test]
    fn test_stream_execution() {
        let Ok(provider) = GpuExecutionProvider::new() else {
            println!("GPU not available for stream testing");
            return;
        };

        let stream_count = provider.get_config().stream_count;
        if stream_count <= 1 {
            println!("🌊 Single stream execution (stream_count = 1)");
            return;
        }

        println!(
            "🌊 Testing stream-based execution with {} streams",
            stream_count
        );

        // Kernels bound to two different streams.
        let (Ok(kernel1), Ok(kernel2)) = (
            GpuKernel::with_stream(create_single_op_subgraph("Add"), provider.device().clone(), 0),
            GpuKernel::with_stream(create_single_op_subgraph("Mul"), provider.device().clone(), 1),
        ) else {
            return;
        };

        println!("    ✅ Successfully created kernels on different streams");

        let test_input = ronn_core::tensor::Tensor::ones(
            vec![256, 256],
            DataType::F32,
            TensorLayout::RowMajor,
        )
        .unwrap();
        let inputs = std::slice::from_ref(&test_input);

        // The two kernels run back-to-back on this thread. This only checks that kernels on
        // different streams execute; it does not check that their work overlaps.
        let start = std::time::Instant::now();
        let _result1 = kernel1.execute(inputs);
        let _result2 = kernel2.execute(inputs);
        let concurrent_time = start.elapsed();

        println!("    Concurrent execution time: {:?}", concurrent_time);
    }

    #[test]
    fn test_kernel_cache_operations() {
        let Ok(provider) = GpuExecutionProvider::new() else {
            println!("GPU not available for cache testing");
            return;
        };

        println!("💾 Testing kernel cache operations...");

        let Ok(kernel) = provider.compile_subgraph(create_single_op_subgraph("MatMul")) else {
            return;
        };
        let test_input = ronn_core::tensor::Tensor::ones(
            vec![128, 128],
            DataType::F32,
            TensorLayout::RowMajor,
        )
        .unwrap();
        let inputs = std::slice::from_ref(&test_input);

        // Execute repeatedly so later runs can reuse anything cached by earlier ones.
        for i in 0..10 {
            let _ = kernel.execute(inputs);

            match i {
                0 => println!("    First execution (cold cache)"),
                9 => println!("    Tenth execution (warm cache)"),
                _ => {}
            }
        }

        println!("    ✅ Cache operations test completed");
    }
}