1use std::collections::{HashMap, HashSet};
7use std::sync::Arc;
8
9use anyhow::{Result, anyhow};
10use candle_core::{Device, Tensor as CandleTensor};
11use ronn_core::tensor::Tensor;
12use ronn_core::{
13 CompiledKernel, DataType, ExecutionProvider, KernelStats, MemoryType, MemoryUsage,
14 OperatorSpec, PerformanceProfile, ProviderCapability, ProviderConfig, ProviderId,
15 ResourceRequirements, SubGraph, TensorAllocator, TensorLayout,
16};
17use tracing::{debug, info, warn};
18
19use super::allocator::create_gpu_allocator;
20use super::cuda_kernels::{CudaCompileOptions, CudaKernelManager};
21use super::memory_manager::{MultiGpuMemoryConfig, MultiGpuMemoryManager};
22use super::topology::{GpuTopologyManager, TopologyConfig};
23
24pub struct GpuExecutionProvider {
26 devices: Vec<Device>,
28 allocators: Vec<Arc<dyn TensorAllocator>>,
30 supported_ops: HashSet<String>,
32 config: GpuProviderConfig,
34 device_manager: Arc<std::sync::Mutex<MultiGpuManager>>,
36 cuda_kernel_managers: Vec<Option<CudaKernelManager>>,
38 memory_manager: Option<Arc<MultiGpuMemoryManager>>,
40 topology_manager: Option<Arc<GpuTopologyManager>>,
42}
43
44#[derive(Debug, Clone)]
46pub struct GpuProviderConfig {
47 pub device_ids: Vec<usize>,
49 pub primary_device_id: usize,
51 pub memory_limit: Option<usize>,
53 pub enable_mixed_precision: bool,
55 pub enable_tensor_cores: bool,
57 pub stream_count: usize,
59 pub enable_multi_gpu: bool,
61 pub enable_p2p_transfer: bool,
63 pub load_balancing: LoadBalancingStrategy,
65 pub enable_custom_kernels: bool,
67 pub cuda_compile_options: CudaCompileOptions,
69 pub memory_config: MultiGpuMemoryConfig,
71 pub topology_config: TopologyConfig,
73}
74
/// Strategy used to pick a GPU device for a subgraph when multi-GPU is enabled.
///
/// The default strategy is [`LoadBalancingStrategy::RoundRobin`]; it is derived
/// with `#[default]` instead of a hand-written `Default` impl.
#[derive(Debug, Clone, Copy, Default)]
pub enum LoadBalancingStrategy {
    /// Cycle through the configured devices in order.
    #[default]
    RoundRobin,

    /// Choose the device with the lowest tracked memory usage.
    MemoryBased,

    /// Choose the device with the lowest tracked utilization.
    UtilizationBased,

    /// Choose a device based on the operator type.
    OperationBased,

    /// Choose the device with the lowest combined memory/utilization/time score.
    CostModel,
}
95
96impl Default for GpuProviderConfig {
97 fn default() -> Self {
98 Self {
99 device_ids: vec![0],
100 primary_device_id: 0,
101 memory_limit: None,
102 enable_mixed_precision: true,
103 enable_tensor_cores: true,
104 stream_count: 1,
105 enable_multi_gpu: false,
106 enable_p2p_transfer: true,
107 load_balancing: LoadBalancingStrategy::default(),
108 enable_custom_kernels: true,
109 cuda_compile_options: CudaCompileOptions::default(),
110 memory_config: MultiGpuMemoryConfig::default(),
111 topology_config: TopologyConfig::default(),
112 }
113 }
114}
115
116#[derive(Debug)]
118pub struct GpuKernel {
119 subgraph: SubGraph,
121 device: Device,
123 stats: std::sync::Mutex<GpuKernelStats>,
125 stream_id: usize,
127 kernel_cache: std::sync::Mutex<KernelCache>,
129}
130
131#[derive(Debug, Default)]
133struct KernelCache {
134 cached_ops: HashMap<String, CachedOperation>,
136 cache_size: usize,
138 max_cache_size: usize,
140}
141
142#[derive(Debug, Clone)]
144struct CachedOperation {
145 signature: String,
147 execution_path: OptimizedPath,
149 hit_count: u64,
151 last_accessed: std::time::Instant,
153}
154
/// How a cached operation is executed.
#[derive(Debug, Clone)]
enum OptimizedPath {
    /// A single named operation.
    Single(String),

    /// Several named operations executed as one fused unit.
    Fused(Vec<String>),

    /// Operations split by precision.
    MixedPrecision {
        /// Operations listed for FP16.
        fp16_ops: Vec<String>,
        /// Operations listed for FP32.
        fp32_ops: Vec<String>,
    },
}
168
/// Execution statistics collected for a single [`GpuKernel`].
///
/// All counters start at zero through the derived `Default`.
// NOTE(review): `min_time_us` also starts at 0; verify that the code updating it
// treats the first sample specially, otherwise a running minimum would stay at 0.
#[derive(Debug, Default)]
struct GpuKernelStats {
    /// Number of recorded executions.
    execution_count: u64,
    /// Accumulated execution time in microseconds.
    total_time_us: u64,
    /// Shortest recorded execution time in microseconds.
    min_time_us: u64,
    /// Longest recorded execution time in microseconds.
    max_time_us: u64,
    /// Peak memory figure recorded for the kernel.
    memory_peak: usize,
}
177
178#[derive(Debug)]
180struct MultiGpuManager {
181 config: GpuProviderConfig,
183 device_stats: HashMap<usize, DeviceStats>,
185 round_robin_counter: usize,
187 memory_usage: HashMap<usize, usize>,
189}
190
/// Per-device statistics tracked by the multi-GPU manager.
///
/// The type and its fields are public because the public
/// `GpuExecutionProvider::get_multi_gpu_stats` returns it; a private type there
/// triggers the `private_interfaces` lint and leaves callers unable to read
/// the values.
#[derive(Debug, Default, Clone)]
pub struct DeviceStats {
    /// Number of executions recorded for the device.
    pub operation_count: u64,

    /// Memory usage reported by the most recent update.
    pub current_memory: usize,

    /// Largest memory usage reported so far.
    pub peak_memory: usize,

    /// Exponentially smoothed execution time in microseconds.
    pub avg_execution_time: f64,

    /// Device utilization used by utilization-based and cost-model selection.
    // NOTE(review): no visible code writes this field, so it stays at 0.0 unless
    // it is updated elsewhere.
    pub utilization: f32,
}
205
206impl MultiGpuManager {
207 fn new(config: GpuProviderConfig) -> Self {
209 let mut device_stats = HashMap::new();
210 let mut memory_usage = HashMap::new();
211
212 for &device_id in &config.device_ids {
213 device_stats.insert(device_id, DeviceStats::default());
214 memory_usage.insert(device_id, 0);
215 }
216
217 Self {
218 config,
219 device_stats,
220 round_robin_counter: 0,
221 memory_usage,
222 }
223 }
224
225 fn select_device(&mut self, op_type: &str, memory_requirement: usize) -> usize {
227 if self.config.device_ids.len() == 1 {
228 return self.config.device_ids[0];
229 }
230
231 if !self.config.enable_multi_gpu {
232 return self.config.primary_device_id;
233 }
234
235 match self.config.load_balancing {
236 LoadBalancingStrategy::RoundRobin => {
237 let device_id =
238 self.config.device_ids[self.round_robin_counter % self.config.device_ids.len()];
239 self.round_robin_counter += 1;
240 device_id
241 }
242 LoadBalancingStrategy::MemoryBased => self.select_device_by_memory(memory_requirement),
243 LoadBalancingStrategy::UtilizationBased => self.select_device_by_utilization(),
244 LoadBalancingStrategy::OperationBased => self.select_device_by_operation_type(op_type),
245 LoadBalancingStrategy::CostModel => {
246 self.select_device_by_cost_model(op_type, memory_requirement)
247 }
248 }
249 }
250
251 fn select_device_by_memory(&self, memory_requirement: usize) -> usize {
253 self.config
254 .device_ids
255 .iter()
256 .min_by_key(|&&device_id| {
257 self.memory_usage.get(&device_id).unwrap_or(&0) + memory_requirement
258 })
259 .copied()
260 .unwrap_or(self.config.primary_device_id)
261 }
262
263 fn select_device_by_utilization(&self) -> usize {
265 self.config
266 .device_ids
267 .iter()
268 .min_by(|&&a, &&b| {
269 let util_a = self
270 .device_stats
271 .get(&a)
272 .map(|s| s.utilization)
273 .unwrap_or(0.0);
274 let util_b = self
275 .device_stats
276 .get(&b)
277 .map(|s| s.utilization)
278 .unwrap_or(0.0);
279 util_a
280 .partial_cmp(&util_b)
281 .unwrap_or(std::cmp::Ordering::Equal)
282 })
283 .copied()
284 .unwrap_or(self.config.primary_device_id)
285 }
286
287 fn select_device_by_operation_type(&self, op_type: &str) -> usize {
289 match op_type {
291 "MatMul" | "Conv" | "ConvTranspose" => self
293 .config
294 .device_ids
295 .iter()
296 .min()
297 .copied()
298 .unwrap_or(self.config.primary_device_id),
299 "Concat" | "Split" | "Reshape" => self.select_device_by_memory(0),
301 _ => {
303 let device_id =
304 self.config.device_ids[self.round_robin_counter % self.config.device_ids.len()];
305 device_id
306 }
307 }
308 }
309
310 fn select_device_by_cost_model(&self, _op_type: &str, memory_requirement: usize) -> usize {
312 let mut best_device = self.config.primary_device_id;
313 let mut best_score = f64::INFINITY;
314
315 for &device_id in &self.config.device_ids {
316 let default_stats = DeviceStats::default();
317 let stats = self.device_stats.get(&device_id).unwrap_or(&default_stats);
318 let memory_used = self.memory_usage.get(&device_id).unwrap_or(&0);
319
320 let memory_pressure =
322 (*memory_used + memory_requirement) as f64 / (1024.0 * 1024.0 * 1024.0); let utilization_penalty = stats.utilization as f64 * 2.0;
324 let execution_time_penalty = stats.avg_execution_time / 1000.0; let total_score = memory_pressure + utilization_penalty + execution_time_penalty;
327
328 if total_score < best_score {
329 best_score = total_score;
330 best_device = device_id;
331 }
332 }
333
334 best_device
335 }
336
337 fn update_device_stats(
339 &mut self,
340 device_id: usize,
341 execution_time_us: u64,
342 memory_used: usize,
343 ) {
344 if let Some(stats) = self.device_stats.get_mut(&device_id) {
345 stats.operation_count += 1;
346 stats.current_memory = memory_used;
347 stats.peak_memory = stats.peak_memory.max(memory_used);
348
349 let alpha = 0.1; if stats.avg_execution_time == 0.0 {
352 stats.avg_execution_time = execution_time_us as f64;
353 } else {
354 stats.avg_execution_time =
355 alpha * execution_time_us as f64 + (1.0 - alpha) * stats.avg_execution_time;
356 }
357 }
358
359 self.memory_usage.insert(device_id, memory_used);
360 }
361
362 fn get_device_stats(&self) -> &HashMap<usize, DeviceStats> {
364 &self.device_stats
365 }
366}
367
368impl GpuExecutionProvider {
/// Creates a provider with [`GpuProviderConfig::default`].
///
/// # Errors
///
/// Propagates any error from [`Self::with_config`].
#[cfg(feature = "gpu")]
pub fn new() -> Result<Self> {
    let default_config = GpuProviderConfig::default();
    Self::with_config(default_config)
}
374
/// Creates a provider from an explicit configuration.
///
/// For every id in `config.device_ids` a device and an allocator are created,
/// plus a CUDA kernel manager when custom kernels are enabled and the device
/// is a CUDA device (a failing kernel manager is logged and skipped). The
/// multi-GPU memory manager and topology manager are only attempted when
/// multi-GPU is enabled and more than one device is configured; their
/// failures are logged and leave the corresponding field `None`.
///
/// # Errors
///
/// Fails when a device or allocator cannot be created, or when
/// `config.device_ids` is empty.
#[cfg(feature = "gpu")]
pub fn with_config(config: GpuProviderConfig) -> Result<Self> {
    // Operator names reported as supported by this provider.
    const SUPPORTED_OPS: &[&str] = &[
        // Element-wise arithmetic
        "Add",
        "Sub",
        "Mul",
        "Div",
        // Matrix products
        "MatMul",
        "Gemm",
        // Convolutions
        "Conv",
        "ConvTranspose",
        // Pooling
        "MaxPool",
        "AveragePool",
        "GlobalAveragePool",
        // Activations
        "ReLU",
        "Sigmoid",
        "Tanh",
        "Softmax",
        "GELU",
        // Normalization
        "BatchNormalization",
        "LayerNormalization",
        // Reductions
        "Sum",
        "Mean",
        "Max",
        "Min",
        // Shape manipulation
        "Reshape",
        "Transpose",
        "Concat",
        "Split",
    ];

    let device_total = config.device_ids.len();
    let mut devices = Vec::with_capacity(device_total);
    let mut allocators = Vec::with_capacity(device_total);
    let mut cuda_kernel_managers = Vec::with_capacity(device_total);

    for &device_id in &config.device_ids {
        let device = Self::create_gpu_device(device_id)?;
        info!("Created GPU device {}: {:?}", device_id, device);

        let wants_custom_kernels =
            config.enable_custom_kernels && matches!(device, Device::Cuda(_));

        let cuda_manager = if wants_custom_kernels {
            let built = CudaKernelManager::with_options(
                device.clone(),
                config.cuda_compile_options.clone(),
            );
            match built {
                Ok(manager) => {
                    info!("Created CUDA kernel manager for device {}", device_id);
                    Some(manager)
                }
                Err(e) => {
                    // Custom kernels are optional; keep going without them.
                    warn!(
                        "Failed to create CUDA kernel manager for device {}: {}",
                        device_id, e
                    );
                    None
                }
            }
        } else {
            None
        };

        devices.push(device);
        cuda_kernel_managers.push(cuda_manager);

        let allocator = create_gpu_allocator().map_err(|e| {
            anyhow!(
                "Failed to create GPU allocator for device {}: {}",
                device_id,
                e
            )
        })?;
        allocators.push(allocator);
    }

    if devices.is_empty() {
        return Err(anyhow!("No GPU devices configured"));
    }

    info!("Created GPU provider with {} devices", devices.len());

    // The manager keeps its own clone of the configuration.
    let device_manager = Arc::new(std::sync::Mutex::new(MultiGpuManager::new(config.clone())));

    let multi_gpu_active = config.enable_multi_gpu && config.device_ids.len() > 1;

    let memory_manager = if multi_gpu_active {
        match MultiGpuMemoryManager::new(
            config.device_ids.clone(),
            config.memory_config.clone(),
        ) {
            Ok(manager) => {
                info!("Created multi-GPU memory manager");
                Some(Arc::new(manager))
            }
            Err(e) => {
                warn!("Failed to create multi-GPU memory manager: {}", e);
                None
            }
        }
    } else {
        None
    };

    let topology_manager = if multi_gpu_active {
        match GpuTopologyManager::new(config.topology_config.clone()) {
            Ok(mut manager) => {
                // A failed discovery is logged; the manager is kept regardless.
                match manager.discover_topology() {
                    Ok(_) => info!("GPU topology discovered successfully"),
                    Err(e) => warn!("Failed to discover GPU topology: {}", e),
                }
                Some(Arc::new(manager))
            }
            Err(e) => {
                warn!("Failed to create topology manager: {}", e);
                None
            }
        }
    } else {
        None
    };

    let supported_ops: HashSet<String> =
        SUPPORTED_OPS.iter().map(|op| op.to_string()).collect();

    info!(
        "GPU provider supports {} operation types",
        supported_ops.len()
    );

    Ok(Self {
        devices,
        allocators,
        supported_ops,
        config,
        device_manager,
        cuda_kernel_managers,
        memory_manager,
        topology_manager,
    })
}
534
535 #[cfg(not(feature = "gpu"))]
537 pub fn new() -> Result<Self> {
538 Err(anyhow!("GPU support not compiled in"))
539 }
540
541 #[cfg(not(feature = "gpu"))]
543 pub fn with_config(_config: GpuProviderConfig) -> Result<Self> {
544 Err(anyhow!("GPU support not compiled in"))
545 }
546
/// Opens the GPU device with the given id.
///
/// CUDA is tried first; on macOS a Metal device is tried next. The reason a
/// backend was unavailable is discarded.
///
/// # Errors
///
/// Fails with "No GPU devices available" when no backend yields a device.
#[cfg(feature = "gpu")]
fn create_gpu_device(device_id: usize) -> Result<Device> {
    match Device::new_cuda(device_id) {
        Ok(cuda_device) => {
            info!("Using CUDA device {}", device_id);
            return Ok(cuda_device);
        }
        Err(_) => {}
    }

    #[cfg(target_os = "macos")]
    {
        match Device::new_metal(device_id) {
            Ok(metal_device) => {
                info!("Using Metal device {}", device_id);
                return Ok(metal_device);
            }
            Err(_) => {}
        }
    }

    Err(anyhow!("No GPU devices available"))
}
567
568 pub fn device(&self) -> &Device {
570 &self.devices[0]
571 }
572
573 pub fn get_config(&self) -> &GpuProviderConfig {
575 &self.config
576 }
577
578 pub fn supports_operation(&self, op_type: &str) -> bool {
580 self.supported_ops.contains(op_type)
581 }
582
583 pub fn estimate_cost(&self, op_spec: &OperatorSpec) -> f64 {
585 match op_spec.op_type.as_str() {
587 "Add" | "Sub" | "Mul" | "Div" => 0.1, "ReLU" | "Sigmoid" | "Tanh" => 0.2, "MatMul" | "Gemm" => 0.5, "Conv" => 0.8, "ConvTranspose" => 1.2, "BatchNormalization" => 0.3, "Softmax" => 0.4, "MaxPool" | "AveragePool" => 0.3, _ => 1.0, }
597 }
598
/// Reports whether tensor cores are considered usable.
///
/// This is `true` when tensor cores are enabled in the configuration and the
/// first device is a CUDA device. No hardware capability query is performed.
///
/// The previous body read `self.device`, a field that does not exist on this
/// struct (the devices live in `self.devices`), so it could not compile with
/// the `gpu` feature enabled.
#[cfg(feature = "gpu")]
pub fn has_tensor_cores(&self) -> bool {
    self.config.enable_tensor_cores && matches!(self.devices.first(), Some(Device::Cuda(_)))
}
606
607 #[cfg(not(feature = "gpu"))]
609 pub fn has_tensor_cores(&self) -> bool {
610 false
611 }
612
/// Returns a pair of memory figures for the first device.
///
/// The values are hard-coded: `(8 GiB, 0)` for CUDA and Metal devices alike.
/// `get_capability` treats the first element as the total memory.
// NOTE(review): the meaning of the second element is not visible here, and no
// real device query is performed — confirm before relying on these numbers.
///
/// # Errors
///
/// Fails with "Not a GPU device" for any other device kind.
///
/// # Panics
///
/// Panics if the device list is empty.
#[cfg(feature = "gpu")]
pub fn get_gpu_memory_info(&self) -> Result<(usize, usize)> {
    const REPORTED_TOTAL: usize = 8 * 1024 * 1024 * 1024;

    match &self.devices[0] {
        Device::Cuda(_) | Device::Metal(_) => Ok((REPORTED_TOTAL, 0)),
        _ => Err(anyhow!("Not a GPU device")),
    }
}
624
625 #[cfg(not(feature = "gpu"))]
627 pub fn get_gpu_memory_info(&self) -> Result<(usize, usize)> {
628 Err(anyhow!("GPU support not available"))
629 }
630}
631
632impl Default for GpuExecutionProvider {
633 fn default() -> Self {
634 Self::new().expect("Failed to create default GPU provider")
635 }
636}
637
638impl ExecutionProvider for GpuExecutionProvider {
639 fn provider_id(&self) -> ProviderId {
640 ProviderId::GPU
641 }
642
643 fn get_capability(&self) -> ProviderCapability {
644 let mut data_types = vec![
645 DataType::F32,
646 DataType::F16, DataType::F64,
648 DataType::U8,
649 DataType::U32,
650 ];
651
652 if self.has_tensor_cores() {
654 data_types.insert(0, DataType::F16); }
657
658 let gpu_memory = self
659 .get_gpu_memory_info()
660 .map(|(total, _)| total)
661 .unwrap_or(0);
662
663 ProviderCapability {
664 supported_ops: self.supported_ops.clone(),
665 data_types,
666 memory_types: vec![MemoryType::DeviceMemory, MemoryType::SharedMemory],
667 performance_profile: PerformanceProfile::GPU,
668 resource_requirements: ResourceRequirements {
669 min_memory_bytes: Some(512 * 1024 * 1024), cpu_features: vec![], gpu_memory_bytes: Some(gpu_memory),
672 },
673 }
674 }
675
676 fn can_handle(&self, operators: &[OperatorSpec]) -> Vec<bool> {
677 operators
678 .iter()
679 .map(|op| self.supports_operation(&op.op_type))
680 .collect()
681 }
682
683 fn compile_subgraph(&self, subgraph: SubGraph) -> Result<Box<dyn CompiledKernel>> {
684 debug!(
685 "Compiling subgraph with {} nodes for GPU",
686 subgraph.nodes.len()
687 );
688
689 for node in &subgraph.nodes {
691 if !self.supports_operation(&node.op_type) {
692 return Err(anyhow!(
693 "Unsupported GPU operation '{}' in subgraph",
694 node.op_type
695 ));
696 }
697 }
698
699 let mut device_manager = self.device_manager.lock().unwrap();
701 let primary_op = subgraph
702 .nodes
703 .first()
704 .map(|n| n.op_type.as_str())
705 .unwrap_or("Unknown");
706
707 let estimated_memory = subgraph.nodes.len() * 1024 * 1024; let selected_device_id = device_manager.select_device(primary_op, estimated_memory);
711
712 let device_index = self
714 .config
715 .device_ids
716 .iter()
717 .position(|&id| id == selected_device_id)
718 .unwrap_or(0);
719
720 let selected_device = self.devices[device_index].clone();
721
722 debug!(
723 "Selected GPU device {} for subgraph compilation",
724 selected_device_id
725 );
726
727 drop(device_manager); let stream_id = selected_device_id % self.config.stream_count;
731 let kernel = GpuKernel::with_stream(subgraph, selected_device, stream_id)?;
732
733 debug!(
734 "Successfully compiled GPU kernel on device {}",
735 selected_device_id
736 );
737
738 Ok(Box::new(kernel))
739 }
740
741 fn get_allocator(&self) -> Arc<dyn TensorAllocator> {
742 self.allocators[0].clone()
744 }
745
746 fn configure(&mut self, config: ProviderConfig) -> Result<()> {
747 if let Some(memory_limit) = config.memory_limit {
749 self.config.memory_limit = Some(memory_limit);
750 info!("Updated GPU memory limit to {} bytes", memory_limit);
751 }
752
753 for (key, value) in &config.custom_options {
755 match key.as_str() {
756 "enable_mixed_precision" => {
757 if let Ok(enable) = value.parse::<bool>() {
758 self.config.enable_mixed_precision = enable;
759 info!("Updated mixed precision to {}", enable);
760 }
761 }
762 "enable_tensor_cores" => {
763 if let Ok(enable) = value.parse::<bool>() {
764 self.config.enable_tensor_cores = enable;
765 info!("Updated tensor cores to {}", enable);
766 }
767 }
768 "stream_count" => {
769 if let Ok(count) = value.parse::<usize>() {
770 self.config.stream_count = count;
771 info!("Updated stream count to {}", count);
772 }
773 }
774 _ => {
775 warn!("Unknown GPU configuration option: {}", key);
776 }
777 }
778 }
779
780 Ok(())
781 }
782
783 fn shutdown(&self) -> Result<()> {
784 info!("Shutting down GPU execution provider");
785
786 debug!("GPU provider shutdown complete");
790
791 Ok(())
792 }
793}
794
795impl GpuExecutionProvider {
796 pub fn get_device_allocator(&self, device_id: usize) -> Option<Arc<dyn TensorAllocator>> {
798 let device_index = self
799 .config
800 .device_ids
801 .iter()
802 .position(|&id| id == device_id)?;
803 self.allocators.get(device_index).cloned()
804 }
805
806 pub fn get_multi_gpu_stats(&self) -> HashMap<usize, DeviceStats> {
808 let device_manager = self.device_manager.lock().unwrap();
809 device_manager.get_device_stats().clone()
810 }
811
812 pub fn set_multi_gpu_enabled(&mut self, enabled: bool) {
814 self.config.enable_multi_gpu = enabled;
815 info!(
816 "Multi-GPU support {}",
817 if enabled { "enabled" } else { "disabled" }
818 );
819 }
820
821 pub fn set_load_balancing_strategy(&mut self, strategy: LoadBalancingStrategy) {
823 info!("Updated load balancing strategy to {:?}", strategy);
824 self.config.load_balancing = strategy;
825 }
826
827 pub fn device_count(&self) -> usize {
829 self.devices.len()
830 }
831
832 pub fn has_device(&self, device_id: usize) -> bool {
834 self.config.device_ids.contains(&device_id)
835 }
836
837 pub fn has_custom_kernels(&self, device_id: usize) -> bool {
839 if let Some(device_index) = self
840 .config
841 .device_ids
842 .iter()
843 .position(|&id| id == device_id)
844 {
845 self.cuda_kernel_managers
846 .get(device_index)
847 .map(|manager| manager.is_some())
848 .unwrap_or(false)
849 } else {
850 false
851 }
852 }
853
854 pub fn get_custom_kernel_ops(&self, device_id: usize) -> Vec<String> {
856 if self.has_custom_kernels(device_id) {
857 vec![
858 "FusedMatMulBias".to_string(),
859 "OptimizedSoftmax".to_string(),
860 "FusedConvBnRelu".to_string(),
861 "WarpReduceSum".to_string(),
862 "TensorCoreGemm".to_string(),
863 "FastGelu".to_string(),
864 ]
865 } else {
866 vec![]
867 }
868 }
869
870 pub fn try_execute_with_custom_kernel(
872 &self,
873 device_id: usize,
874 op_type: &str,
875 inputs: &[CandleTensor],
876 ) -> Result<Option<Vec<CandleTensor>>> {
877 let device_index = self
878 .config
879 .device_ids
880 .iter()
881 .position(|&id| id == device_id)
882 .ok_or_else(|| anyhow!("Device {} not found", device_id))?;
883
884 if let Some(Some(kernel_manager)) = self.cuda_kernel_managers.get(device_index) {
885 let tensor_shape = inputs
887 .first()
888 .map(|t| t.shape().dims().to_vec())
889 .unwrap_or_else(|| vec![1]);
890
891 match kernel_manager.get_optimized_kernel(op_type, &tensor_shape) {
892 Ok(mut kernel) => {
893 info!("Using custom CUDA kernel for operation: {}", op_type);
894
895 let mut outputs: Vec<CandleTensor> = inputs.iter()
897 .map(|input| input.clone()) .collect();
899
900 kernel_manager.execute_kernel(&mut kernel, inputs, &mut outputs)?;
902
903 Ok(Some(outputs))
904 }
905 Err(_) => {
906 Ok(None)
908 }
909 }
910 } else {
911 Ok(None)
913 }
914 }
915
916 pub fn clear_kernel_caches(&self) {
918 for kernel_manager in self.cuda_kernel_managers.iter().flatten() {
919 kernel_manager.clear_cache();
920 }
921 info!("Cleared all CUDA kernel caches");
922 }
923
924 pub fn get_kernel_cache_stats(&self) -> Vec<super::cuda_kernels::CacheStats> {
926 self.cuda_kernel_managers
927 .iter()
928 .filter_map(|manager| manager.as_ref().map(|km| km.get_cache_stats()))
929 .collect()
930 }
931
932 pub fn transfer_tensor_between_devices(
934 &self,
935 tensor: &CandleTensor,
936 target_device_id: usize,
937 ) -> Result<CandleTensor> {
938 if let Some(ref _memory_manager) = self.memory_manager {
939 info!(
941 "Using multi-GPU memory manager for tensor transfer to device {}",
942 target_device_id
943 );
944
945 let target_device = &self.devices[self
953 .config
954 .device_ids
955 .iter()
956 .position(|&id| id == target_device_id)
957 .unwrap_or(0)];
958 Ok(tensor.to_device(target_device)?)
959 } else {
960 let target_device = &self.devices[self
962 .config
963 .device_ids
964 .iter()
965 .position(|&id| id == target_device_id)
966 .unwrap_or(0)];
967 Ok(tensor.to_device(target_device)?)
968 }
969 }
970
971 pub fn synchronize_memory(&self) -> Result<()> {
973 if let Some(ref memory_manager) = self.memory_manager {
974 memory_manager.synchronize_all()
975 } else {
976 Ok(())
978 }
979 }
980
981 pub fn get_memory_statistics(&self) -> HashMap<usize, super::memory_manager::MemoryPoolStats> {
983 if let Some(ref memory_manager) = self.memory_manager {
984 memory_manager.get_memory_stats()
985 } else {
986 HashMap::new()
987 }
988 }
989
990 pub fn get_global_memory_stats(&self) -> Option<super::memory_manager::GlobalMemoryStats> {
992 self.memory_manager.as_ref().map(|mm| mm.get_global_stats())
993 }
994
995 pub fn get_p2p_connectivity(
997 &self,
998 ) -> HashMap<(usize, usize), super::memory_manager::P2PCapability> {
999 if let Some(ref memory_manager) = self.memory_manager {
1000 memory_manager.get_p2p_info()
1001 } else {
1002 HashMap::new()
1003 }
1004 }
1005
1006 pub fn is_p2p_available(&self, src_device: usize, dst_device: usize) -> bool {
1008 if let Some(ref memory_manager) = self.memory_manager {
1009 let p2p_info = memory_manager.get_p2p_info();
1010 p2p_info
1011 .get(&(src_device, dst_device))
1012 .map(|cap| cap.supported)
1013 .unwrap_or(false)
1014 } else {
1015 false
1016 }
1017 }
1018
1019 pub fn optimize_memory_layout(
1021 &self,
1022 access_pattern: &super::memory_manager::AccessPattern,
1023 ) -> Result<super::memory_manager::MemoryLayout> {
1024 if let Some(ref memory_manager) = self.memory_manager {
1025 memory_manager.optimize_memory_layout(access_pattern)
1026 } else {
1027 Err(anyhow!("Multi-GPU memory manager not available"))
1028 }
1029 }
1030
1031 pub fn get_topology(&self) -> Option<super::topology::GpuTopology> {
1033 self.topology_manager.as_ref().map(|tm| tm.get_topology())
1034 }
1035
1036 pub fn optimize_workload_placement(
1038 &self,
1039 workload: &super::topology::Workload,
1040 strategy: &str,
1041 ) -> Result<super::topology::PlacementPlan> {
1042 if let Some(ref topology_manager) = self.topology_manager {
1043 topology_manager.optimize_placement(workload, strategy)
1044 } else {
1045 Err(anyhow!("GPU topology manager not available"))
1046 }
1047 }
1048
1049 pub fn compare_placement_strategies(
1051 &self,
1052 workload: &super::topology::Workload,
1053 strategies: &[String],
1054 ) -> Result<Vec<(String, super::topology::PlacementPlan)>> {
1055 if let Some(ref topology_manager) = self.topology_manager {
1056 topology_manager.compare_strategies(workload, strategies)
1057 } else {
1058 Err(anyhow!("GPU topology manager not available"))
1059 }
1060 }
1061
1062 pub fn get_available_placement_strategies(&self) -> Vec<String> {
1064 if let Some(ref topology_manager) = self.topology_manager {
1065 topology_manager.get_available_strategies()
1066 } else {
1067 vec![]
1068 }
1069 }
1070
1071 pub fn has_topology_support(&self) -> bool {
1073 self.topology_manager.is_some()
1074 }
1075
1076 pub fn get_detailed_device_info(&self) -> HashMap<usize, super::topology::GpuDeviceInfo> {
1078 if let Some(ref topology_manager) = self.topology_manager {
1079 topology_manager.get_topology().devices
1080 } else {
1081 HashMap::new()
1082 }
1083 }
1084
1085 pub fn get_interconnect_info(
1087 &self,
1088 ) -> HashMap<(usize, usize), super::topology::InterconnectLink> {
1089 if let Some(ref topology_manager) = self.topology_manager {
1090 topology_manager.get_topology().links
1091 } else {
1092 HashMap::new()
1093 }
1094 }
1095
1096 pub fn auto_select_devices(&self, workload: &super::topology::Workload) -> Result<Vec<usize>> {
1098 let plan = self.optimize_workload_placement(workload, "locality_aware")?;
1099 Ok(plan.device_assignments.values().copied().collect())
1100 }
1101}
1102
1103impl GpuKernel {
1104 pub fn new(subgraph: SubGraph, device: Device) -> Result<Self> {
1106 Ok(Self {
1107 subgraph,
1108 device,
1109 stats: std::sync::Mutex::new(GpuKernelStats::default()),
1110 stream_id: 0, kernel_cache: std::sync::Mutex::new(KernelCache {
1112 cached_ops: HashMap::new(),
1113 cache_size: 0,
1114 max_cache_size: 64 * 1024 * 1024, }),
1116 })
1117 }
1118
1119 pub fn with_stream(subgraph: SubGraph, device: Device, stream_id: usize) -> Result<Self> {
1121 Ok(Self {
1122 subgraph,
1123 device,
1124 stats: std::sync::Mutex::new(GpuKernelStats::default()),
1125 stream_id,
1126 kernel_cache: std::sync::Mutex::new(KernelCache {
1127 cached_ops: HashMap::new(),
1128 cache_size: 0,
1129 max_cache_size: 64 * 1024 * 1024, }),
1131 })
1132 }
1133
1134 fn execute_gpu_operation(
1136 &self,
1137 op_type: &str,
1138 inputs: &[CandleTensor],
1139 ) -> Result<Vec<CandleTensor>> {
1140 match op_type {
1141 "Add" => {
1142 if inputs.len() != 2 {
1143 return Err(anyhow!("Add requires exactly 2 inputs"));
1144 }
1145 let result = (&inputs[0] + &inputs[1])?;
1146 Ok(vec![result])
1147 }
1148
1149 "Sub" => {
1150 if inputs.len() != 2 {
1151 return Err(anyhow!("Sub requires exactly 2 inputs"));
1152 }
1153 let result = (&inputs[0] - &inputs[1])?;
1154 Ok(vec![result])
1155 }
1156
1157 "Mul" => {
1158 if inputs.len() != 2 {
1159 return Err(anyhow!("Mul requires exactly 2 inputs"));
1160 }
1161 let result = (&inputs[0] * &inputs[1])?;
1162 Ok(vec![result])
1163 }
1164
1165 "Div" => {
1166 if inputs.len() != 2 {
1167 return Err(anyhow!("Div requires exactly 2 inputs"));
1168 }
1169 let result = (&inputs[0] / &inputs[1])?;
1170 Ok(vec![result])
1171 }
1172
1173 "MatMul" => {
1174 if inputs.len() != 2 {
1175 return Err(anyhow!("MatMul requires exactly 2 inputs"));
1176 }
1177 let result = inputs[0].matmul(&inputs[1])?;
1178 Ok(vec![result])
1179 }
1180
1181 "ReLU" => {
1182 if inputs.len() != 1 {
1183 return Err(anyhow!("ReLU requires exactly 1 input"));
1184 }
1185 let zero = inputs[0].zeros_like()?;
1186 let result = inputs[0].maximum(&zero)?;
1187 Ok(vec![result])
1188 }
1189
1190 "Softmax" => {
1191 if inputs.len() != 1 {
1192 return Err(anyhow!("Softmax requires exactly 1 input"));
1193 }
1194 let result = candle_nn::ops::softmax_last_dim(&inputs[0])?;
1195 Ok(vec![result])
1196 }
1197
1198 "Sigmoid" => {
1199 if inputs.len() != 1 {
1200 return Err(anyhow!("Sigmoid requires exactly 1 input"));
1201 }
1202 let neg_input = inputs[0].neg()?;
1204 let exp_neg = neg_input.exp()?;
1205 let one = inputs[0].ones_like()?;
1206 let denominator = (&one + &exp_neg)?;
1207 let result = one.div(&denominator)?;
1208 Ok(vec![result])
1209 }
1210
1211 "Tanh" => {
1212 if inputs.len() != 1 {
1213 return Err(anyhow!("Tanh requires exactly 1 input"));
1214 }
1215 let result = inputs[0].tanh()?;
1216 Ok(vec![result])
1217 }
1218
1219 "GELU" => {
1220 if inputs.len() != 1 {
1221 return Err(anyhow!("GELU requires exactly 1 input"));
1222 }
1223 let x = &inputs[0];
1225
1226 let x_cubed = x.powf(3.0)?;
1228
1229 let coeff_tensor = x_cubed.affine(0.044715, 0.0)?;
1231
1232 let x_plus_coeff = (x + &coeff_tensor)?;
1234
1235 let sqrt_2_over_pi = (2.0 / std::f64::consts::PI).sqrt() as f64;
1237 let inner = x_plus_coeff.affine(sqrt_2_over_pi, 0.0)?;
1238
1239 let tanh_inner = inner.tanh()?;
1241
1242 let one = x.ones_like()?;
1244 let one_plus_tanh = (&one + &tanh_inner)?;
1245
1246 let half_x = x.affine(0.5, 0.0)?;
1248
1249 let result = (&half_x * &one_plus_tanh)?;
1251 Ok(vec![result])
1252 }
1253
1254 "MaxPool" => {
1255 if inputs.len() != 1 {
1256 return Err(anyhow!("MaxPool requires exactly 1 input"));
1257 }
1258 let input = &inputs[0];
1261 let dims = input.dims();
1262
1263 if dims.len() < 3 {
1264 return Err(anyhow!("MaxPool requires at least 3D input (CHW)"));
1265 }
1266
1267 let result = input.clone(); Ok(vec![result])
1271 }
1272
1273 "AveragePool" => {
1274 if inputs.len() != 1 {
1275 return Err(anyhow!("AveragePool requires exactly 1 input"));
1276 }
1277 let input = &inputs[0];
1280 let dims = input.dims();
1281
1282 if dims.len() < 3 {
1283 return Err(anyhow!("AveragePool requires at least 3D input (CHW)"));
1284 }
1285
1286 let result = input.clone(); Ok(vec![result])
1290 }
1291
1292 "Conv" => {
1293 if inputs.len() < 2 {
1294 return Err(anyhow!(
1295 "Conv requires at least 2 inputs (input and weights)"
1296 ));
1297 }
1298 let input = &inputs[0];
1299 let weights = &inputs[1];
1300
1301 let result = input.conv2d(weights, 1, 1, 1, 1)?; Ok(vec![result])
1305 }
1306
1307 "ConvTranspose" => {
1308 if inputs.len() < 2 {
1309 return Err(anyhow!("ConvTranspose requires at least 2 inputs"));
1310 }
1311 let input = &inputs[0];
1312 let weights = &inputs[1];
1313
1314 let result = input.conv2d(weights, 1, 1, 1, 1)?;
1317 Ok(vec![result])
1318 }
1319
1320 "BatchNormalization" => {
1321 if inputs.len() < 5 {
1322 return Err(anyhow!(
1323 "BatchNormalization requires 5 inputs: input, scale, bias, mean, var"
1324 ));
1325 }
1326 let input = &inputs[0];
1327 let scale = &inputs[1]; let bias = &inputs[2]; let mean = &inputs[3]; let var = &inputs[4]; let eps = 1e-5; let input_dims = input.dims();
1337 let _batch_size = input_dims[0];
1338 let channels = input_dims[1];
1339
1340 let scale_reshaped = if scale.dims().len() == 1 {
1342 scale.reshape(&[1, channels, 1, 1])?
1343 } else {
1344 scale.clone()
1345 };
1346
1347 let bias_reshaped = if bias.dims().len() == 1 {
1348 bias.reshape(&[1, channels, 1, 1])?
1349 } else {
1350 bias.clone()
1351 };
1352
1353 let mean_reshaped = if mean.dims().len() == 1 {
1354 mean.reshape(&[1, channels, 1, 1])?
1355 } else {
1356 mean.clone()
1357 };
1358
1359 let var_reshaped = if var.dims().len() == 1 {
1360 var.reshape(&[1, channels, 1, 1])?
1361 } else {
1362 var.clone()
1363 };
1364
1365 let normalized = (input - &mean_reshaped)?;
1367 let var_plus_eps = (&var_reshaped + eps)?;
1368 let std_dev = var_plus_eps.sqrt()?;
1369 let normalized_scaled = (&normalized / &std_dev)?;
1370 let scaled = (&normalized_scaled * &scale_reshaped)?;
1371 let result = (&scaled + &bias_reshaped)?;
1372
1373 Ok(vec![result])
1374 }
1375
1376 "LayerNormalization" => {
1377 if inputs.len() < 3 {
1378 return Err(anyhow!(
1379 "LayerNormalization requires 3 inputs: input, scale, bias"
1380 ));
1381 }
1382 let input = &inputs[0];
1383 let scale = &inputs[1]; let bias = &inputs[2]; let eps = 1e-5;
1388 let dims = input.dims();
1389 let last_dim = dims.len() - 1;
1390
1391 let mean = input.mean_keepdim(last_dim)?;
1393 let variance = {
1394 let diff = (input - &mean)?;
1395 let squared = (&diff * &diff)?;
1396 squared.mean_keepdim(last_dim)?
1397 };
1398
1399 let normalized = (input - &mean)?;
1401 let var_plus_eps = (&variance + eps)?;
1402 let std_dev = var_plus_eps.sqrt()?;
1403 let normalized_scaled = (&normalized / &std_dev)?;
1404 let scaled = (&normalized_scaled * scale)?;
1405 let result = (&scaled + bias)?;
1406
1407 Ok(vec![result])
1408 }
1409
1410 "GlobalAveragePool" => {
1411 if inputs.len() != 1 {
1412 return Err(anyhow!("GlobalAveragePool requires exactly 1 input"));
1413 }
1414 let input = &inputs[0];
1415 let dims = input.dims();
1416
1417 if dims.len() != 4 {
1418 return Err(anyhow!("GlobalAveragePool expects 4D input (NCHW)"));
1419 }
1420
1421 let result = input.mean_keepdim(2)?.mean_keepdim(3)?;
1423 Ok(vec![result])
1424 }
1425
1426 "Reshape" => {
1427 if inputs.len() != 1 {
1428 return Err(anyhow!("Reshape requires exactly 1 input"));
1429 }
1430 Ok(vec![inputs[0].clone()])
1432 }
1433
1434 _ => Err(anyhow!("Unsupported GPU operation: {}", op_type)),
1435 }
1436 }
1437
1438 fn ronn_to_candle(&self, tensor: &ronn_core::tensor::Tensor) -> Result<CandleTensor> {
1440 let data = tensor.to_vec()?;
1441 let shape = tensor.shape();
1442 let dtype = match tensor.dtype() {
1443 DataType::F32 => candle_core::DType::F32,
1444 DataType::F16 => candle_core::DType::F16,
1445 DataType::F64 => candle_core::DType::F64,
1446 DataType::U8 => candle_core::DType::U8,
1447 DataType::U32 => candle_core::DType::U32,
1448 _ => candle_core::DType::F32, };
1450
1451 let candle_tensor =
1452 CandleTensor::from_vec(data, shape.as_slice(), &self.device)?.to_dtype(dtype)?;
1453
1454 Ok(candle_tensor)
1455 }
1456
1457 fn generate_operation_signature(&self, op_type: &str, inputs: &[CandleTensor]) -> String {
1459 use std::collections::hash_map::DefaultHasher;
1460 use std::hash::{Hash, Hasher};
1461
1462 let mut hasher = DefaultHasher::new();
1463 op_type.hash(&mut hasher);
1464
1465 for input in inputs {
1467 input.dims().hash(&mut hasher);
1468 format!("{:?}", input.dtype()).hash(&mut hasher);
1469 }
1470
1471 format!("{}_{:x}", op_type, hasher.finish())
1472 }
1473
1474 fn can_use_mixed_precision(&self, op_type: &str) -> bool {
1476 matches!(
1477 op_type,
1478 "Add"
1479 | "Sub"
1480 | "Mul"
1481 | "MatMul"
1482 | "Conv"
1483 | "ReLU"
1484 | "Sigmoid"
1485 | "Tanh"
1486 | "GELU"
1487 | "BatchNormalization"
1488 | "LayerNormalization"
1489 )
1490 }
1491
1492 fn apply_mixed_precision(
1494 &self,
1495 inputs: &[CandleTensor],
1496 op_type: &str,
1497 ) -> Result<Vec<CandleTensor>> {
1498 if !self.can_use_mixed_precision(op_type) {
1499 return Ok(inputs.to_vec());
1500 }
1501
1502 let mut converted = Vec::new();
1503 for input in inputs {
1504 let element_count = input.dims().iter().product::<usize>();
1506 if element_count > 1024 && input.dtype() == candle_core::DType::F32 {
1507 let fp16_tensor = input.to_dtype(candle_core::DType::F16)?;
1508 converted.push(fp16_tensor);
1509 debug!("Converted tensor to FP16 for operation: {}", op_type);
1510 } else {
1511 converted.push(input.clone());
1512 }
1513 }
1514 Ok(converted)
1515 }
1516
1517 fn execute_optimized_operation(
1519 &self,
1520 op_type: &str,
1521 inputs: &[CandleTensor],
1522 ) -> Result<Vec<CandleTensor>> {
1523 let signature = self.generate_operation_signature(op_type, inputs);
1524
1525 {
1527 let mut cache = self.kernel_cache.lock().unwrap();
1528 if let Some(cached_op) = cache.cached_ops.get_mut(&signature) {
1529 cached_op.hit_count += 1;
1530 cached_op.last_accessed = std::time::Instant::now();
1531 debug!(
1532 "Cache hit for operation: {} (signature: {})",
1533 op_type, signature
1534 );
1535 }
1536 }
1537
1538 let optimized_inputs = self.apply_mixed_precision(inputs, op_type)?;
1540
1541 let result = self.execute_gpu_operation(op_type, &optimized_inputs)?;
1543
1544 {
1546 let mut cache = self.kernel_cache.lock().unwrap();
1547 let cached_op = CachedOperation {
1548 signature: signature.clone(),
1549 execution_path: OptimizedPath::Single(op_type.to_string()),
1550 hit_count: 1,
1551 last_accessed: std::time::Instant::now(),
1552 };
1553 cache.cached_ops.insert(signature, cached_op);
1554
1555 if cache.cached_ops.len() > 1000 {
1557 self.evict_cache_entries(&mut cache);
1558 }
1559 }
1560
1561 Ok(result)
1562 }
1563
1564 fn evict_cache_entries(&self, cache: &mut KernelCache) {
1566 let current_time = std::time::Instant::now();
1567 let mut to_remove = Vec::new();
1568
1569 for (signature, cached_op) in &cache.cached_ops {
1570 if current_time
1572 .duration_since(cached_op.last_accessed)
1573 .as_secs()
1574 > 300
1575 {
1576 to_remove.push(signature.clone());
1577 }
1578 }
1579
1580 for signature in to_remove {
1581 cache.cached_ops.remove(&signature);
1582 }
1583
1584 debug!("Evicted {} cache entries", cache.cached_ops.len());
1585 }
1586
1587 pub fn get_cache_stats(&self) -> (usize, usize, f64) {
1589 let cache = self.kernel_cache.lock().unwrap();
1590 let total_hits: u64 = cache.cached_ops.values().map(|op| op.hit_count).sum();
1591 let cache_count = cache.cached_ops.len();
1592 let hit_rate = if cache_count > 0 {
1593 total_hits as f64 / cache_count as f64
1594 } else {
1595 0.0
1596 };
1597 (cache_count, cache.cache_size, hit_rate)
1598 }
1599
1600 fn candle_to_ronn(&self, tensor: &CandleTensor) -> Result<ronn_core::tensor::Tensor> {
1602 let shape = tensor.dims().to_vec();
1603 let data: Vec<f32> = tensor.to_vec1()?; let ronn_tensor = Tensor::from_data(
1606 data,
1607 shape,
1608 DataType::F32, TensorLayout::RowMajor,
1610 )?;
1611
1612 Ok(ronn_tensor)
1613 }
1614}
1615
1616impl CompiledKernel for GpuKernel {
1617 fn execute(
1618 &self,
1619 inputs: &[ronn_core::tensor::Tensor],
1620 ) -> Result<Vec<ronn_core::tensor::Tensor>> {
1621 let start_time = std::time::Instant::now();
1622
1623 let mut candle_inputs = Vec::new();
1625 for input in inputs {
1626 let candle_tensor = self.ronn_to_candle(input)?;
1627 candle_inputs.push(candle_tensor);
1628 }
1629
1630 let mut current_tensors = candle_inputs;
1632
1633 for node in &self.subgraph.nodes {
1634 debug!(
1635 "Executing GPU operation: {} on stream {}",
1636 node.op_type, self.stream_id
1637 );
1638 let outputs = self.execute_optimized_operation(&node.op_type, ¤t_tensors)?;
1639 current_tensors = outputs;
1640 }
1641
1642 let mut results = Vec::new();
1644 for candle_tensor in ¤t_tensors {
1645 let ronn_tensor = self.candle_to_ronn(candle_tensor)?;
1646 results.push(ronn_tensor);
1647 }
1648
1649 let execution_time = start_time.elapsed();
1651 {
1652 let mut stats = self.stats.lock().unwrap();
1653 stats.execution_count += 1;
1654 stats.total_time_us += execution_time.as_micros() as u64;
1655
1656 if stats.execution_count == 1 {
1657 stats.min_time_us = execution_time.as_micros() as u64;
1658 stats.max_time_us = execution_time.as_micros() as u64;
1659 } else {
1660 stats.min_time_us = stats.min_time_us.min(execution_time.as_micros() as u64);
1661 stats.max_time_us = stats.max_time_us.max(execution_time.as_micros() as u64);
1662 }
1663 }
1664
1665 debug!("GPU kernel executed in {:?}", execution_time);
1666
1667 Ok(results)
1668 }
1669
1670 fn get_memory_usage(&self) -> MemoryUsage {
1671 let stats = self.stats.lock().unwrap();
1672 MemoryUsage {
1673 peak_bytes: stats.memory_peak,
1674 current_bytes: 0, allocation_count: stats.execution_count as usize,
1676 }
1677 }
1678
1679 fn get_performance_stats(&self) -> KernelStats {
1680 let stats = self.stats.lock().unwrap();
1681
1682 let average_time_us = if stats.execution_count > 0 {
1683 stats.total_time_us as f64 / stats.execution_count as f64
1684 } else {
1685 0.0
1686 };
1687
1688 KernelStats {
1689 execution_count: stats.execution_count,
1690 average_time_us,
1691 min_time_us: stats.min_time_us as f64,
1692 max_time_us: stats.max_time_us as f64,
1693 }
1694 }
1695}
1696
1697pub fn create_gpu_provider() -> Result<Arc<dyn ExecutionProvider>> {
1699 Ok(Arc::new(GpuExecutionProvider::new()?))
1700}
1701
1702pub fn create_gpu_provider_with_config(
1704 config: GpuProviderConfig,
1705) -> Result<Arc<dyn ExecutionProvider>> {
1706 Ok(Arc::new(GpuExecutionProvider::with_config(config)?))
1707}
1708
1709#[cfg(test)]
1710mod tests {
1711 use super::*;
1712 use ronn_core::{AttributeValue, GraphNode};
1713
1714 fn create_test_subgraph() -> SubGraph {
1715 let node = GraphNode {
1716 id: 0,
1717 op_type: "Add".to_string(),
1718 attributes: HashMap::new(),
1719 inputs: vec!["input1".to_string(), "input2".to_string()],
1720 outputs: vec!["output1".to_string()],
1721 name: Some("gpu_add".to_string()),
1722 };
1723
1724 SubGraph {
1725 nodes: vec![node],
1726 edges: vec![],
1727 inputs: vec!["input1".to_string(), "input2".to_string()],
1728 outputs: vec!["output1".to_string()],
1729 }
1730 }
1731
1732 #[test]
1733 fn test_gpu_provider_creation() {
1734 match GpuExecutionProvider::new() {
1736 Ok(provider) => {
1737 assert_eq!(provider.provider_id(), ProviderId::GPU);
1738
1739 let capability = provider.get_capability();
1740 assert_eq!(capability.performance_profile, PerformanceProfile::GPU);
1741 assert!(!capability.supported_ops.is_empty());
1742 assert!(capability.data_types.contains(&DataType::F32));
1743 }
1744 Err(e) => {
1745 println!("GPU not available: {}", e);
1746 }
1748 }
1749 }
1750
1751 #[test]
1752 fn test_gpu_provider_config() {
1753 let config = GpuProviderConfig {
1754 device_ids: vec![0],
1755 enable_mixed_precision: false,
1756 enable_tensor_cores: false,
1757 ..Default::default()
1758 };
1759
1760 match GpuExecutionProvider::with_config(config) {
1761 Ok(provider) => {
1762 assert!(!provider.get_config().enable_mixed_precision);
1763 assert!(!provider.get_config().enable_tensor_cores);
1764 }
1765 Err(_) => {
1766 }
1768 }
1769 }
1770
1771 #[test]
1772 fn test_operation_support() {
1773 match GpuExecutionProvider::new() {
1774 Ok(provider) => {
1775 assert!(provider.supports_operation("Add"));
1777 assert!(provider.supports_operation("MatMul"));
1778 assert!(provider.supports_operation("Conv"));
1779 assert!(provider.supports_operation("ReLU"));
1780 assert!(!provider.supports_operation("NonexistentOp"));
1781
1782 let add_op = OperatorSpec {
1784 op_type: "Add".to_string(),
1785 input_types: vec![DataType::F32],
1786 output_types: vec![DataType::F32],
1787 attributes: HashMap::new(),
1788 };
1789
1790 let conv_op = OperatorSpec {
1791 op_type: "Conv".to_string(),
1792 input_types: vec![DataType::F32],
1793 output_types: vec![DataType::F32],
1794 attributes: HashMap::new(),
1795 };
1796
1797 let add_cost = provider.estimate_cost(&add_op);
1798 let conv_cost = provider.estimate_cost(&conv_op);
1799
1800 assert!(conv_cost > add_cost);
1802 assert!(add_cost < 1.0); }
1804 Err(_) => {
1805 }
1807 }
1808 }
1809
1810 #[test]
1811 fn test_subgraph_compilation() {
1812 match GpuExecutionProvider::new() {
1813 Ok(provider) => {
1814 let subgraph = create_test_subgraph();
1815
1816 match provider.compile_subgraph(subgraph) {
1817 Ok(kernel) => {
1818 let stats = kernel.get_performance_stats();
1819 assert_eq!(stats.execution_count, 0); }
1821 Err(e) => {
1822 println!("Compilation failed: {}", e);
1823 }
1824 }
1825 }
1826 Err(_) => {
1827 }
1829 }
1830 }
1831
1832 #[test]
1833 fn test_factory_functions() {
1834 match create_gpu_provider() {
1836 Ok(provider) => {
1837 assert_eq!(provider.provider_id(), ProviderId::GPU);
1838 }
1839 Err(_) => {
1840 }
1842 }
1843
1844 let config = GpuProviderConfig::default();
1845 match create_gpu_provider_with_config(config) {
1846 Ok(provider) => {
1847 assert_eq!(provider.provider_id(), ProviderId::GPU);
1848 }
1849 Err(_) => {
1850 }
1852 }
1853 }
1854
1855 #[test]
1856 fn test_complex_gpu_operations() {
1857 match GpuExecutionProvider::new() {
1859 Ok(provider) => {
1860 let capability = provider.get_capability();
1861
1862 assert!(capability.supported_ops.contains("Conv"));
1864 assert!(capability.supported_ops.contains("BatchNormalization"));
1865 assert!(capability.supported_ops.contains("LayerNormalization"));
1866 assert!(capability.supported_ops.contains("GlobalAveragePool"));
1867
1868 assert!(capability.supported_ops.contains("Sigmoid"));
1870 assert!(capability.supported_ops.contains("Tanh"));
1871 assert!(capability.supported_ops.contains("GELU"));
1872
1873 println!(
1874 "✅ GPU provider supports {} complex operations",
1875 capability.supported_ops.len()
1876 );
1877 }
1878 Err(e) => {
1879 println!("GPU not available: {}", e);
1880 }
1882 }
1883 }
1884
1885 #[test]
1886 fn test_gpu_benchmarks() {
1887 match GpuExecutionProvider::new() {
1889 Ok(provider) => {
1890 println!("🚀 Running GPU performance benchmarks...");
1891
1892 benchmark_basic_operations(&provider);
1894
1895 benchmark_complex_operations(&provider);
1897
1898 benchmark_mixed_precision(&provider);
1900
1901 benchmark_cache_performance(&provider);
1903
1904 benchmark_memory_throughput(&provider);
1906
1907 println!("✅ GPU benchmarks completed!");
1908 }
1909 Err(e) => {
1910 println!("GPU not available for benchmarks: {}", e);
1911 }
1912 }
1913 }
1914
1915 fn benchmark_basic_operations(provider: &GpuExecutionProvider) {
1916 use std::time::Instant;
1917
1918 println!(" 📊 Basic Operations Benchmark:");
1919
1920 let ops = ["Add", "Mul", "MatMul", "ReLU", "Sigmoid", "Tanh"];
1921
1922 for op in ops {
1923 let subgraph = create_single_op_subgraph(op);
1924 if let Ok(kernel) = provider.compile_subgraph(subgraph) {
1925 let test_input = ronn_core::tensor::Tensor::ones(
1927 vec![1024, 1024],
1928 DataType::F32,
1929 TensorLayout::RowMajor,
1930 )
1931 .unwrap();
1932
1933 let start = Instant::now();
1934 for _ in 0..10 {
1935 let _ = kernel.execute(&[test_input.clone()]);
1936 }
1937 let avg_time = start.elapsed() / 10;
1938
1939 println!(" {} avg: {:?}", op, avg_time);
1940 }
1941 }
1942 }
1943
1944 fn benchmark_complex_operations(provider: &GpuExecutionProvider) {
1945 use std::time::Instant;
1946
1947 println!(" 🧠 Complex Operations Benchmark:");
1948
1949 let complex_ops = [
1950 "Conv",
1951 "BatchNormalization",
1952 "LayerNormalization",
1953 "GlobalAveragePool",
1954 ];
1955
1956 for op in complex_ops {
1957 let subgraph = create_single_op_subgraph(op);
1958 if let Ok(kernel) = provider.compile_subgraph(subgraph) {
1959 let test_input = match op {
1961 "Conv" => ronn_core::tensor::Tensor::ones(
1962 vec![1, 64, 224, 224], DataType::F32,
1964 TensorLayout::RowMajor,
1965 )
1966 .unwrap(),
1967 _ => ronn_core::tensor::Tensor::ones(
1968 vec![32, 512],
1969 DataType::F32,
1970 TensorLayout::RowMajor,
1971 )
1972 .unwrap(),
1973 };
1974
1975 let start = Instant::now();
1976 for _ in 0..5 {
1977 let _ = kernel.execute(&[test_input.clone()]);
1978 }
1979 let avg_time = start.elapsed() / 5;
1980
1981 println!(" {} avg: {:?}", op, avg_time);
1982 }
1983 }
1984 }
1985
1986 fn benchmark_mixed_precision(provider: &GpuExecutionProvider) {
1987 println!(" 🎯 Mixed Precision Benchmark:");
1988
1989 if provider.has_tensor_cores() {
1990 println!(" Tensor cores available - mixed precision enabled");
1991 } else {
1992 println!(" Tensor cores not available - mixed precision simulation");
1993 }
1994
1995 let sizes = [512, 1024, 2048];
1997
1998 for size in sizes {
1999 println!(" Matrix size: {}x{}", size, size);
2000 }
2002 }
2003
2004 fn benchmark_cache_performance(provider: &GpuExecutionProvider) {
2005 use std::time::Instant;
2006
2007 println!(" 💾 Cache Performance Benchmark:");
2008
2009 let subgraph = create_single_op_subgraph("Add");
2010 if let Ok(kernel) = provider.compile_subgraph(subgraph) {
2011 let test_input = ronn_core::tensor::Tensor::ones(
2012 vec![512, 512],
2013 DataType::F32,
2014 TensorLayout::RowMajor,
2015 )
2016 .unwrap();
2017
2018 for _ in 0..5 {
2020 let _ = kernel.execute(&[test_input.clone()]);
2021 }
2022
2023 let start = Instant::now();
2025 for _ in 0..20 {
2026 let _ = kernel.execute(&[test_input.clone()]);
2027 }
2028 let cached_time = start.elapsed() / 20;
2029
2030 println!(" Cached execution avg: {:?}", cached_time);
2036 }
2037 }
2038
2039 fn benchmark_memory_throughput(provider: &GpuExecutionProvider) {
2040 println!(" 🚀 Memory Throughput Benchmark:");
2041
2042 if let Ok((total_memory, _used_memory)) = provider.get_gpu_memory_info() {
2043 println!(
2044 " GPU Memory: {:.2} GB total",
2045 total_memory as f64 / (1024.0 * 1024.0 * 1024.0)
2046 );
2047 }
2048
2049 let allocator = provider.get_allocator();
2050
2051 let start = std::time::Instant::now();
2053 let mut buffers = Vec::new();
2054
2055 for _ in 0..100 {
2056 if let Ok(buffer) = allocator.allocate(&[1024], DataType::F32) {
2057 buffers.push(buffer);
2058 }
2059 }
2060
2061 let alloc_time = start.elapsed();
2062
2063 let start = std::time::Instant::now();
2064 for buffer in buffers {
2065 let _ = allocator.deallocate(buffer);
2066 }
2067 let dealloc_time = start.elapsed();
2068
2069 println!(" 100 allocations: {:?}", alloc_time);
2070 println!(" 100 deallocations: {:?}", dealloc_time);
2071 }
2072
2073 fn create_single_op_subgraph(op_type: &str) -> SubGraph {
2074 let node = GraphNode {
2075 id: 0,
2076 op_type: op_type.to_string(),
2077 attributes: HashMap::new(),
2078 inputs: vec!["input1".to_string()],
2079 outputs: vec!["output1".to_string()],
2080 name: Some(format!("test_{}", op_type)),
2081 };
2082
2083 SubGraph {
2084 nodes: vec![node],
2085 edges: vec![],
2086 inputs: vec!["input1".to_string()],
2087 outputs: vec!["output1".to_string()],
2088 }
2089 }
2090
2091 #[test]
2092 fn test_stream_execution() {
2093 match GpuExecutionProvider::new() {
2095 Ok(provider) => {
2096 if provider.get_config().stream_count > 1 {
2097 println!(
2098 "🌊 Testing stream-based execution with {} streams",
2099 provider.get_config().stream_count
2100 );
2101
2102 let subgraph1 = create_single_op_subgraph("Add");
2104 let subgraph2 = create_single_op_subgraph("Mul");
2105
2106 if let (Ok(kernel1), Ok(kernel2)) = (
2107 GpuKernel::with_stream(subgraph1, provider.device().clone(), 0),
2108 GpuKernel::with_stream(subgraph2, provider.device().clone(), 1),
2109 ) {
2110 println!(" ✅ Successfully created kernels on different streams");
2111
2112 let test_input = ronn_core::tensor::Tensor::ones(
2114 vec![256, 256],
2115 DataType::F32,
2116 TensorLayout::RowMajor,
2117 )
2118 .unwrap();
2119
2120 let start = std::time::Instant::now();
2121 let _result1 = kernel1.execute(&[test_input.clone()]);
2122 let _result2 = kernel2.execute(&[test_input.clone()]);
2123 let concurrent_time = start.elapsed();
2124
2125 println!(" Concurrent execution time: {:?}", concurrent_time);
2126 }
2127 } else {
2128 println!("🌊 Single stream execution (stream_count = 1)");
2129 }
2130 }
2131 Err(_) => {
2132 println!("GPU not available for stream testing");
2133 }
2134 }
2135 }
2136
2137 #[test]
2138 fn test_kernel_cache_operations() {
2139 match GpuExecutionProvider::new() {
2141 Ok(provider) => {
2142 println!("💾 Testing kernel cache operations...");
2143
2144 let subgraph = create_single_op_subgraph("MatMul");
2145 if let Ok(kernel) = provider.compile_subgraph(subgraph) {
2146 let test_input = ronn_core::tensor::Tensor::ones(
2147 vec![128, 128],
2148 DataType::F32,
2149 TensorLayout::RowMajor,
2150 )
2151 .unwrap();
2152
2153 for i in 0..10 {
2155 let _ = kernel.execute(&[test_input.clone()]);
2156
2157 if i == 0 {
2158 println!(" First execution (cold cache)");
2159 } else if i == 9 {
2160 println!(" Tenth execution (warm cache)");
2161 }
2162 }
2163
2164 println!(" ✅ Cache operations test completed");
2165 }
2166 }
2167 Err(_) => {
2168 println!("GPU not available for cache testing");
2169 }
2170 }
2171 }
2172}