// optirs_gpu/kernels.rs

1//! GPU Kernel Management and Compilation
2//!
3//! This module provides kernel management, compilation, and execution for GPU-accelerated
4//! optimization algorithms. It supports CUDA, ROCm, Metal, and WebGPU backends.
5
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use std::sync::Arc;
9use thiserror::Error;
10
11use crate::backends::{Backend, CompiledKernel, DeviceMemory, GpuBackend, KernelArg};
12
/// Kernel manager for compiling and executing GPU kernels.
///
/// Owns a backend-specific [`KernelCompiler`], a cache of compiled kernels
/// (keyed by kernel name), and the built-in kernel templates.
#[derive(Debug)]
pub struct KernelManager {
    /// Backend-specific kernel compiler (CUDA, ROCm, Metal, or WebGPU)
    compiler: Box<dyn KernelCompiler>,

    /// Compiled kernel cache, keyed by kernel name
    kernel_cache: HashMap<String, Arc<CompiledKernel>>,

    /// Kernel templates, keyed by template name
    templates: HashMap<String, KernelTemplate>,

    /// Backend type this manager was created for
    backend: GpuBackend,

    /// Compilation options applied to every kernel compiled by this manager
    compilation_options: CompilationOptions,
}
31
/// Kernel compilation configuration.
///
/// Passed to [`KernelCompiler::compile_kernel`] for every compilation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompilationOptions {
    /// Optimization level (0-3), analogous to `-O0`..`-O3`
    pub optimization_level: u32,

    /// Enable fast math optimizations (may relax IEEE-754 semantics)
    pub fast_math: bool,

    /// Enable debug information in the compiled kernel
    pub debug_info: bool,

    /// Target architecture string (e.g. `"compute_70"` for CUDA)
    pub target_arch: String,

    /// Additional compiler flags passed through verbatim
    pub extra_flags: Vec<String>,

    /// Maximum register count per thread (None = compiler default)
    pub max_registers: Option<u32>,

    /// Shared memory size hint in bytes (None = no hint)
    pub shared_memory_hint: Option<usize>,
}
56
/// Template for generating GPU kernels.
///
/// Source text may contain `{{name}}` placeholders that are substituted by
/// [`KernelManager::generate_kernel_from_template`].
#[derive(Debug, Clone)]
pub struct KernelTemplate {
    /// Template name (also the key in the manager's template map)
    pub name: String,

    /// Kernel source code template with `{{param}}` placeholders
    pub source_template: String,

    /// Parameter placeholders available in the template
    pub parameters: Vec<TemplateParameter>,

    /// Backends this template can be compiled for
    pub supported_backends: Vec<GpuBackend>,

    /// Default block and grid sizes for launching the generated kernel
    pub default_launch_config: LaunchConfig,
}
75
/// Template parameter for kernel generation.
///
/// Describes one `{{name}}` placeholder in a [`KernelTemplate`].
#[derive(Debug, Clone)]
pub struct TemplateParameter {
    /// Parameter name, matching the `{{name}}` placeholder in the template
    pub name: String,

    /// Parameter type
    pub param_type: ParameterType,

    /// Default value used when the caller supplies no value (if any)
    pub default_value: Option<String>,

    /// Whether generation fails when neither a value nor a default exists
    pub required: bool,
}
91
/// Kernel launch configuration (grid/block geometry plus shared memory).
#[derive(Debug, Clone)]
pub struct LaunchConfig {
    /// Grid dimensions (x, y, z) — number of blocks per axis
    pub grid_size: (u32, u32, u32),

    /// Block dimensions (x, y, z) — number of threads per axis
    pub block_size: (u32, u32, u32),

    /// Dynamic shared memory size in bytes per block
    pub shared_memory_size: usize,

    /// CUDA stream handle (if applicable; None = default stream)
    pub stream: Option<u64>,
}
107
/// Parameter types for kernel templates.
///
/// Values are always carried as strings (see [`TemplateParameter`]); this
/// enum records the intended interpretation.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ParameterType {
    /// Integer type
    Integer,
    /// Floating point type
    Float,
    /// String/text type
    String,
    /// Data type specifier (e.g. `"float"`, substituted for `{{dtype}}`)
    DataType,
    /// Boolean type
    Boolean,
}
122
/// Kernel types for different optimization algorithms.
///
/// Used by [`KernelManager::calculate_launch_config`] to pick a launch
/// geometry appropriate for the kernel's access pattern.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum KernelType {
    /// SGD update kernel
    SGDUpdate,
    /// Adam update kernel
    AdamUpdate,
    /// AdamW update kernel
    AdamWUpdate,
    /// RMSprop update kernel
    RMSpropUpdate,
    /// AdaGrad update kernel
    AdaGradUpdate,
    /// Momentum update kernel
    MomentumUpdate,
    /// Gradient clipping kernel
    GradientClipping,
    /// Vector operations
    VectorOps,
    /// Matrix operations
    MatrixOps,
    /// Reduction operations
    Reduction,
}
147
/// Errors that can occur during kernel operations.
#[derive(Debug, Error)]
pub enum KernelError {
    /// Source failed to compile for the target backend
    #[error("Compilation failed: {reason}")]
    CompilationFailed { reason: String },

    /// No kernel or template registered under this name
    #[error("Kernel not found: {name}")]
    KernelNotFound { name: String },

    /// A required template parameter had no value and no default
    #[error("Invalid template parameter: {param}")]
    InvalidTemplateParameter { param: String },

    /// The backend has no compiler or the template excludes it
    #[error("Backend not supported: {backend:?}")]
    BackendNotSupported { backend: GpuBackend },

    /// The launch configuration violates device limits
    #[error("Launch configuration invalid: {reason}")]
    InvalidLaunchConfig { reason: String },

    /// The kernel launched but failed at runtime
    #[error("Kernel execution failed: {reason}")]
    ExecutionFailed { reason: String },

    /// Template-to-source generation failed
    #[error("Template generation failed: {reason}")]
    TemplateGenerationFailed { reason: String },
}
172
/// Trait for backend-specific kernel compilation.
///
/// One implementation exists per GPU backend (CUDA, ROCm, Metal, WebGPU);
/// [`KernelManager`] holds the active one as a trait object.
pub trait KernelCompiler: Send + Sync + std::fmt::Debug {
    /// Compile kernel source code for the target backend.
    ///
    /// # Errors
    /// Returns [`KernelError::CompilationFailed`] if the source cannot be
    /// compiled with the given options.
    fn compile_kernel(
        &self,
        source: &str,
        options: &CompilationOptions,
    ) -> Result<CompiledKernel, KernelError>;

    /// Get the backend type this compiler targets.
    fn backend_type(&self) -> GpuBackend;

    /// Check whether an optional hardware/runtime feature is supported.
    fn supports_feature(&self, feature: KernelFeature) -> bool;

    /// Get the compilation capabilities (device limits, supported types).
    fn get_capabilities(&self) -> CompilerCapabilities;
}
191
/// Kernel features that may or may not be supported by a backend.
///
/// Queried through [`KernelCompiler::supports_feature`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum KernelFeature {
    /// Half precision (f16) support
    HalfPrecision,
    /// Double precision (f64) support
    DoublePrecision,
    /// Tensor core operations
    TensorCores,
    /// Cooperative groups
    CooperativeGroups,
    /// Dynamic parallelism
    DynamicParallelism,
    /// Unified memory
    UnifiedMemory,
}
208
/// Compiler capabilities and device limits reported by a backend.
#[derive(Debug, Clone)]
pub struct CompilerCapabilities {
    /// Maximum threads per block
    pub max_threads_per_block: u32,

    /// Maximum blocks per grid for each of the (x, y, z) axes
    pub max_blocks_per_grid: (u32, u32, u32),

    /// Maximum shared memory per block, in bytes
    pub max_shared_memory: usize,

    /// Supported data type names (e.g. "float", "half")
    pub supported_data_types: Vec<String>,

    /// Compute capability as (major, minor) — CUDA only, None elsewhere
    pub compute_capability: Option<(u32, u32)>,
}
227
228impl KernelManager {
229    /// Create a new kernel manager for the specified backend
230    pub fn new(backend: GpuBackend, options: CompilationOptions) -> Result<Self, KernelError> {
231        let compiler = Self::create_compiler(backend)?;
232        let templates = Self::load_builtin_templates();
233
234        Ok(Self {
235            compiler,
236            kernel_cache: HashMap::new(),
237            templates,
238            backend,
239            compilation_options: options,
240        })
241    }
242
243    /// Create a backend-specific compiler
244    fn create_compiler(backend: GpuBackend) -> Result<Box<dyn KernelCompiler>, KernelError> {
245        match backend {
246            GpuBackend::Cuda => Ok(Box::new(CudaCompiler::new())),
247            GpuBackend::Rocm => Ok(Box::new(RocmCompiler::new())),
248            GpuBackend::Metal => Ok(Box::new(MetalCompiler::new())),
249            GpuBackend::Wgpu => Ok(Box::new(WgpuCompiler::new())),
250            GpuBackend::Cpu => Err(KernelError::BackendNotSupported { backend }),
251        }
252    }
253
254    /// Load built-in kernel templates
255    fn load_builtin_templates() -> HashMap<String, KernelTemplate> {
256        let mut templates = HashMap::new();
257
258        // SGD kernel template
259        templates.insert(
260            "sgd_update".to_string(),
261            KernelTemplate {
262                name: "sgd_update".to_string(),
263                source_template: sgd_kernel_template().to_string(),
264                parameters: vec![
265                    TemplateParameter {
266                        name: "dtype".to_string(),
267                        param_type: ParameterType::DataType,
268                        default_value: Some("float".to_string()),
269                        required: true,
270                    },
271                    TemplateParameter {
272                        name: "learning_rate".to_string(),
273                        param_type: ParameterType::Float,
274                        default_value: Some("0.01".to_string()),
275                        required: true,
276                    },
277                ],
278                supported_backends: vec![GpuBackend::Cuda, GpuBackend::Rocm, GpuBackend::Metal],
279                default_launch_config: LaunchConfig {
280                    grid_size: (1, 1, 1),
281                    block_size: (256, 1, 1),
282                    shared_memory_size: 0,
283                    stream: None,
284                },
285            },
286        );
287
288        // Adam kernel template
289        templates.insert(
290            "adam_update".to_string(),
291            KernelTemplate {
292                name: "adam_update".to_string(),
293                source_template: adam_kernel_template().to_string(),
294                parameters: vec![
295                    TemplateParameter {
296                        name: "dtype".to_string(),
297                        param_type: ParameterType::DataType,
298                        default_value: Some("float".to_string()),
299                        required: true,
300                    },
301                    TemplateParameter {
302                        name: "learning_rate".to_string(),
303                        param_type: ParameterType::Float,
304                        default_value: Some("0.001".to_string()),
305                        required: true,
306                    },
307                    TemplateParameter {
308                        name: "beta1".to_string(),
309                        param_type: ParameterType::Float,
310                        default_value: Some("0.9".to_string()),
311                        required: true,
312                    },
313                    TemplateParameter {
314                        name: "beta2".to_string(),
315                        param_type: ParameterType::Float,
316                        default_value: Some("0.999".to_string()),
317                        required: true,
318                    },
319                ],
320                supported_backends: vec![GpuBackend::Cuda, GpuBackend::Rocm, GpuBackend::Metal],
321                default_launch_config: LaunchConfig {
322                    grid_size: (1, 1, 1),
323                    block_size: (256, 1, 1),
324                    shared_memory_size: 0,
325                    stream: None,
326                },
327            },
328        );
329
330        templates
331    }
332
333    /// Generate kernel from template
334    pub fn generate_kernel_from_template(
335        &self,
336        template_name: &str,
337        parameters: &HashMap<String, String>,
338    ) -> Result<String, KernelError> {
339        let template =
340            self.templates
341                .get(template_name)
342                .ok_or_else(|| KernelError::KernelNotFound {
343                    name: template_name.to_string(),
344                })?;
345
346        if !template.supported_backends.contains(&self.backend) {
347            return Err(KernelError::BackendNotSupported {
348                backend: self.backend,
349            });
350        }
351
352        let mut source = template.source_template.clone();
353
354        // Replace template parameters
355        for param in &template.parameters {
356            let value = if let Some(v) = parameters.get(&param.name) {
357                v.clone()
358            } else if let Some(default) = &param.default_value {
359                default.clone()
360            } else if param.required {
361                return Err(KernelError::InvalidTemplateParameter {
362                    param: param.name.clone(),
363                });
364            } else {
365                continue;
366            };
367
368            let placeholder = format!("{{{{{}}}}}", param.name);
369            source = source.replace(&placeholder, &value);
370        }
371
372        Ok(source)
373    }
374
375    /// Compile and cache a kernel
376    pub fn compile_kernel(
377        &mut self,
378        name: String,
379        source: String,
380    ) -> Result<Arc<CompiledKernel>, KernelError> {
381        // Check cache first
382        if let Some(cached) = self.kernel_cache.get(&name) {
383            return Ok(cached.clone());
384        }
385
386        // Compile kernel
387        let compiled = self
388            .compiler
389            .compile_kernel(&source, &self.compilation_options)?;
390        let kernel_arc = Arc::new(compiled);
391
392        // Cache the compiled kernel
393        self.kernel_cache.insert(name, kernel_arc.clone());
394
395        Ok(kernel_arc)
396    }
397
398    /// Get cached kernel
399    pub fn get_kernel(&self, name: &str) -> Option<Arc<CompiledKernel>> {
400        self.kernel_cache.get(name).cloned()
401    }
402
403    /// Clear kernel cache
404    pub fn clear_cache(&mut self) {
405        self.kernel_cache.clear();
406    }
407
408    /// Get compiler capabilities
409    pub fn get_capabilities(&self) -> CompilerCapabilities {
410        self.compiler.get_capabilities()
411    }
412
413    /// Calculate optimal launch configuration
414    pub fn calculate_launch_config(
415        &self,
416        kernel_type: KernelType,
417        data_size: usize,
418    ) -> LaunchConfig {
419        let capabilities = self.get_capabilities();
420
421        match kernel_type {
422            KernelType::VectorOps | KernelType::SGDUpdate | KernelType::AdamUpdate => {
423                let threads_per_block = 256.min(capabilities.max_threads_per_block);
424                let blocks = data_size.div_ceil(threads_per_block as usize);
425
426                LaunchConfig {
427                    grid_size: (blocks as u32, 1, 1),
428                    block_size: (threads_per_block, 1, 1),
429                    shared_memory_size: 0,
430                    stream: None,
431                }
432            }
433            KernelType::MatrixOps => {
434                let block_dim = 16; // 16x16 thread blocks for matrix operations
435                let grid_dim = ((data_size as f64).sqrt().ceil() as u32).div_ceil(block_dim);
436
437                LaunchConfig {
438                    grid_size: (grid_dim, grid_dim, 1),
439                    block_size: (block_dim, block_dim, 1),
440                    shared_memory_size: block_dim as usize * block_dim as usize * 4, // 4 bytes per float
441                    stream: None,
442                }
443            }
444            KernelType::Reduction => {
445                let threads_per_block = 512.min(capabilities.max_threads_per_block);
446                let blocks = data_size.div_ceil(threads_per_block as usize);
447
448                LaunchConfig {
449                    grid_size: (blocks as u32, 1, 1),
450                    block_size: (threads_per_block, 1, 1),
451                    shared_memory_size: threads_per_block as usize * 4, // Shared memory for reduction
452                    stream: None,
453                }
454            }
455            _ => {
456                // Default configuration
457                LaunchConfig {
458                    grid_size: (1, 1, 1),
459                    block_size: (256, 1, 1),
460                    shared_memory_size: 0,
461                    stream: None,
462                }
463            }
464        }
465    }
466}
467
468// Backend-specific compilers
469
/// CUDA kernel compiler.
#[derive(Debug)]
pub struct CudaCompiler {
    // Path to the nvcc binary; presumably intended for a real
    // shell-out/NVRTC build — currently unused by the stub
    // `compile_kernel` implementation.
    nvcc_path: String,
}
475
476impl Default for CudaCompiler {
477    fn default() -> Self {
478        Self::new()
479    }
480}
481
482impl CudaCompiler {
483    pub fn new() -> Self {
484        Self {
485            nvcc_path: "nvcc".to_string(),
486        }
487    }
488}
489
490impl KernelCompiler for CudaCompiler {
491    fn compile_kernel(
492        &self,
493        source: &str,
494        options: &CompilationOptions,
495    ) -> Result<CompiledKernel, KernelError> {
496        // In a real implementation, this would call nvcc or use NVRTC
497        Ok(CompiledKernel {
498            name: "cuda_kernel".to_string(),
499            backend: GpuBackend::Cuda,
500            binary: source.as_bytes().to_vec(),
501        })
502    }
503
504    fn backend_type(&self) -> GpuBackend {
505        GpuBackend::Cuda
506    }
507
508    fn supports_feature(&self, feature: KernelFeature) -> bool {
509        match feature {
510            KernelFeature::HalfPrecision => true,
511            KernelFeature::DoublePrecision => true,
512            KernelFeature::TensorCores => true,
513            KernelFeature::CooperativeGroups => true,
514            KernelFeature::DynamicParallelism => true,
515            KernelFeature::UnifiedMemory => true,
516        }
517    }
518
519    fn get_capabilities(&self) -> CompilerCapabilities {
520        CompilerCapabilities {
521            max_threads_per_block: 1024,
522            max_blocks_per_grid: (65535, 65535, 65535),
523            max_shared_memory: 49152,
524            supported_data_types: vec![
525                "float".to_string(),
526                "double".to_string(),
527                "half".to_string(),
528                "int".to_string(),
529            ],
530            compute_capability: Some((8, 6)),
531        }
532    }
533}
534
/// ROCm kernel compiler (stateless unit struct).
#[derive(Debug)]
pub struct RocmCompiler;
538
539impl Default for RocmCompiler {
540    fn default() -> Self {
541        Self::new()
542    }
543}
544
545impl RocmCompiler {
546    pub fn new() -> Self {
547        Self
548    }
549}
550
551impl KernelCompiler for RocmCompiler {
552    fn compile_kernel(
553        &self,
554        source: &str,
555        _options: &CompilationOptions,
556    ) -> Result<CompiledKernel, KernelError> {
557        Ok(CompiledKernel {
558            name: "rocm_kernel".to_string(),
559            backend: GpuBackend::Rocm,
560            binary: source.as_bytes().to_vec(),
561        })
562    }
563
564    fn backend_type(&self) -> GpuBackend {
565        GpuBackend::Rocm
566    }
567
568    fn supports_feature(&self, feature: KernelFeature) -> bool {
569        match feature {
570            KernelFeature::HalfPrecision => true,
571            KernelFeature::DoublePrecision => true,
572            KernelFeature::TensorCores => false,
573            KernelFeature::CooperativeGroups => true,
574            KernelFeature::DynamicParallelism => false,
575            KernelFeature::UnifiedMemory => false,
576        }
577    }
578
579    fn get_capabilities(&self) -> CompilerCapabilities {
580        CompilerCapabilities {
581            max_threads_per_block: 1024,
582            max_blocks_per_grid: (65535, 65535, 65535),
583            max_shared_memory: 65536,
584            supported_data_types: vec![
585                "float".to_string(),
586                "double".to_string(),
587                "half".to_string(),
588                "int".to_string(),
589            ],
590            compute_capability: None,
591        }
592    }
593}
594
/// Metal kernel compiler (stateless unit struct).
#[derive(Debug)]
pub struct MetalCompiler;
598
599impl Default for MetalCompiler {
600    fn default() -> Self {
601        Self::new()
602    }
603}
604
605impl MetalCompiler {
606    pub fn new() -> Self {
607        Self
608    }
609}
610
611impl KernelCompiler for MetalCompiler {
612    fn compile_kernel(
613        &self,
614        source: &str,
615        _options: &CompilationOptions,
616    ) -> Result<CompiledKernel, KernelError> {
617        Ok(CompiledKernel {
618            name: "metal_kernel".to_string(),
619            backend: GpuBackend::Metal,
620            binary: source.as_bytes().to_vec(),
621        })
622    }
623
624    fn backend_type(&self) -> GpuBackend {
625        GpuBackend::Metal
626    }
627
628    fn supports_feature(&self, feature: KernelFeature) -> bool {
629        match feature {
630            KernelFeature::HalfPrecision => true,
631            KernelFeature::DoublePrecision => false,
632            KernelFeature::TensorCores => false,
633            KernelFeature::CooperativeGroups => false,
634            KernelFeature::DynamicParallelism => false,
635            KernelFeature::UnifiedMemory => true,
636        }
637    }
638
639    fn get_capabilities(&self) -> CompilerCapabilities {
640        CompilerCapabilities {
641            max_threads_per_block: 1024,
642            max_blocks_per_grid: (65535, 65535, 65535),
643            max_shared_memory: 32768,
644            supported_data_types: vec!["float".to_string(), "half".to_string(), "int".to_string()],
645            compute_capability: None,
646        }
647    }
648}
649
/// WebGPU kernel compiler (stateless unit struct).
#[derive(Debug)]
pub struct WgpuCompiler;
653
654impl Default for WgpuCompiler {
655    fn default() -> Self {
656        Self::new()
657    }
658}
659
660impl WgpuCompiler {
661    pub fn new() -> Self {
662        Self
663    }
664}
665
666impl KernelCompiler for WgpuCompiler {
667    fn compile_kernel(
668        &self,
669        source: &str,
670        _options: &CompilationOptions,
671    ) -> Result<CompiledKernel, KernelError> {
672        Ok(CompiledKernel {
673            name: "wgpu_kernel".to_string(),
674            backend: GpuBackend::Wgpu,
675            binary: source.as_bytes().to_vec(),
676        })
677    }
678
679    fn backend_type(&self) -> GpuBackend {
680        GpuBackend::Wgpu
681    }
682
683    fn supports_feature(&self, feature: KernelFeature) -> bool {
684        match feature {
685            KernelFeature::HalfPrecision => false,
686            KernelFeature::DoublePrecision => false,
687            KernelFeature::TensorCores => false,
688            KernelFeature::CooperativeGroups => false,
689            KernelFeature::DynamicParallelism => false,
690            KernelFeature::UnifiedMemory => false,
691        }
692    }
693
694    fn get_capabilities(&self) -> CompilerCapabilities {
695        CompilerCapabilities {
696            max_threads_per_block: 256,
697            max_blocks_per_grid: (65535, 65535, 65535),
698            max_shared_memory: 16384,
699            supported_data_types: vec!["float".to_string(), "int".to_string()],
700            compute_capability: None,
701        }
702    }
703}
704
705impl Default for CompilationOptions {
706    fn default() -> Self {
707        Self {
708            optimization_level: 2,
709            fast_math: true,
710            debug_info: false,
711            target_arch: "compute_70".to_string(),
712            extra_flags: Vec::new(),
713            max_registers: None,
714            shared_memory_hint: None,
715        }
716    }
717}
718
// Placeholder templates (in a real implementation, these would be separate files)

/// CUDA source template for a plain SGD parameter update.
///
/// `{{dtype}}` placeholders are substituted by
/// `KernelManager::generate_kernel_from_template`; one thread updates one
/// parameter element (`params[i] -= lr * grad[i]`).
fn sgd_kernel_template() -> &'static str {
    r#"
__global__ void sgd_update_kernel(
    {{dtype}}* params,
    const {{dtype}}* gradients,
    const {{dtype}} learning_rate,
    const int size
) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < size) {
        params[idx] -= learning_rate * gradients[idx];
    }
}
"#
}
735
/// CUDA source template for the Adam optimizer update.
///
/// `{{dtype}}` placeholders are substituted by
/// `KernelManager::generate_kernel_from_template`. One thread updates one
/// parameter element, maintaining the first (`m`) and second (`v`) moment
/// buffers with bias correction.
///
/// Fix: the `v_hat` line previously read `v_val / (1.0 - beta2, step))` —
/// a corrupted expression with unbalanced parentheses and a missing `pow`,
/// so the generated CUDA would not compile; it now applies the standard
/// bias correction `v_val / (1.0 - pow(beta2, step))`, mirroring `m_hat`.
fn adam_kernel_template() -> &'static str {
    r#"
__global__ void adam_update_kernel(
    {{dtype}}* params,
    {{dtype}}* m,
    {{dtype}}* v,
    const {{dtype}}* gradients,
    const {{dtype}} learning_rate,
    const {{dtype}} beta1,
    const {{dtype}} beta2,
    const {{dtype}} epsilon,
    const int step,
    const int size
) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < size) {
        {{dtype}} grad = gradients[idx];
        {{dtype}} m_val = beta1 * m[idx] + (1.0 - beta1) * grad;
        {{dtype}} v_val = beta2 * v[idx] + (1.0 - beta2) * grad * grad;

        {{dtype}} m_hat = m_val / (1.0 - pow(beta1, step));
        {{dtype}} v_hat = v_val / (1.0 - pow(beta2, step));

        params[idx] -= learning_rate * m_hat / (sqrt(v_hat) + epsilon);

        m[idx] = m_val;
        v[idx] = v_val;
    }
}
"#
}
767
768// Template constants are provided by the functions above
769
// Hack to provide template content when files don't exist
// NOTE(review): this free function mirrors the name of the std
// `include_str!` macro, which is confusing, and it appears unused within
// this file — consider removing or renaming it. Unknown paths silently
// yield an empty template rather than an error.
fn include_str(path: &str) -> &str {
    match path {
        "templates/sgd_kernel.template" => sgd_kernel_template(),
        "templates/adam_kernel.template" => adam_kernel_template(),
        _ => "",
    }
}
778
#[cfg(test)]
mod tests {
    use super::*;

    /// A CUDA-backed manager should construct successfully.
    #[test]
    fn test_kernel_manager_creation() {
        let manager = KernelManager::new(GpuBackend::Cuda, CompilationOptions::default());
        assert!(manager.is_ok());
    }

    /// The built-in SGD template should render with explicit parameters.
    #[test]
    fn test_template_generation() {
        let manager =
            KernelManager::new(GpuBackend::Cuda, CompilationOptions::default()).unwrap();

        let params = HashMap::from([
            ("dtype".to_string(), "float".to_string()),
            ("learning_rate".to_string(), "0.01".to_string()),
        ]);

        assert!(manager
            .generate_kernel_from_template("sgd_update", &params)
            .is_ok());
    }

    /// Launch configs for element-wise kernels must be non-degenerate.
    #[test]
    fn test_launch_config_calculation() {
        let manager =
            KernelManager::new(GpuBackend::Cuda, CompilationOptions::default()).unwrap();

        let cfg = manager.calculate_launch_config(KernelType::SGDUpdate, 10000);
        assert!(cfg.grid_size.0 > 0);
        assert!(cfg.block_size.0 > 0);
    }

    /// The CUDA compiler must report sane, non-empty capabilities.
    #[test]
    fn test_compiler_capabilities() {
        let caps = CudaCompiler::new().get_capabilities();
        assert!(caps.max_threads_per_block > 0);
        assert!(!caps.supported_data_types.is_empty());
    }
}