1use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use std::sync::Arc;
9use thiserror::Error;
10
11use crate::backends::{Backend, CompiledKernel, DeviceMemory, GpuBackend, KernelArg};
12
/// Central registry that compiles, caches, and configures GPU kernels for a
/// single backend.
#[derive(Debug)]
pub struct KernelManager {
    /// Backend-specific compiler used to build kernel sources.
    compiler: Box<dyn KernelCompiler>,

    /// Compiled kernels keyed by name; `Arc` lets callers share handles cheaply.
    kernel_cache: HashMap<String, Arc<CompiledKernel>>,

    /// Built-in kernel source templates keyed by template name.
    templates: HashMap<String, KernelTemplate>,

    /// The GPU backend this manager targets.
    backend: GpuBackend,

    /// Options forwarded to the compiler on every compilation.
    compilation_options: CompilationOptions,
}
31
/// Options forwarded to the backend compiler for every kernel build.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompilationOptions {
    /// Compiler optimization level (the `Default` impl uses 2).
    pub optimization_level: u32,

    /// Enable fast-math transformations (may change floating-point results).
    pub fast_math: bool,

    /// Emit debug information in the compiled kernel.
    pub debug_info: bool,

    /// Target architecture string (e.g. "compute_70", the default).
    pub target_arch: String,

    /// Extra flags passed verbatim to the compiler.
    pub extra_flags: Vec<String>,

    /// Optional cap on registers per thread; `None` leaves it to the compiler.
    pub max_registers: Option<u32>,

    /// Optional shared-memory usage hint — presumably bytes; not read by any
    /// code visible in this file, confirm before relying on it.
    pub shared_memory_hint: Option<usize>,
}
56
/// A parameterized kernel source whose `{{name}}` placeholders are filled in
/// at generation time.
#[derive(Debug, Clone)]
pub struct KernelTemplate {
    /// Template identifier (also used as the registry key).
    pub name: String,

    /// Kernel source text containing `{{param}}` placeholders.
    pub source_template: String,

    /// Substitutable parameters, with their defaults and required-ness.
    pub parameters: Vec<TemplateParameter>,

    /// Backends this template may be generated for.
    pub supported_backends: Vec<GpuBackend>,

    /// Launch configuration to use when the caller supplies none.
    pub default_launch_config: LaunchConfig,
}
75
/// One substitutable parameter of a [`KernelTemplate`].
#[derive(Debug, Clone)]
pub struct TemplateParameter {
    /// Placeholder name as it appears inside `{{...}}` in the template source.
    pub name: String,

    /// Expected value category for this parameter.
    pub param_type: ParameterType,

    /// Value used when the caller does not supply one.
    pub default_value: Option<String>,

    /// If true, generation fails when neither a value nor a default exists.
    pub required: bool,
}
91
/// GPU launch geometry for a compiled kernel.
#[derive(Debug, Clone)]
pub struct LaunchConfig {
    /// Number of blocks along (x, y, z).
    pub grid_size: (u32, u32, u32),

    /// Threads per block along (x, y, z).
    pub block_size: (u32, u32, u32),

    /// Dynamic shared memory to allocate per block, in bytes.
    pub shared_memory_size: usize,

    /// Opaque stream/queue handle; `None` presumably selects the backend's
    /// default stream — confirm against the launch code.
    pub stream: Option<u64>,
}
107
/// Value category of a [`TemplateParameter`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ParameterType {
    /// Integer literal.
    Integer,
    /// Floating-point literal.
    Float,
    /// Free-form string.
    String,
    /// A scalar type name substituted into the kernel (e.g. "float").
    DataType,
    /// Boolean literal.
    Boolean,
}
122
/// Categories of built-in kernels, used to pick launch-configuration
/// heuristics in [`KernelManager::calculate_launch_config`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum KernelType {
    /// Plain SGD parameter update.
    SGDUpdate,
    /// Adam optimizer update.
    AdamUpdate,
    /// AdamW (decoupled weight decay) update.
    AdamWUpdate,
    /// RMSprop optimizer update.
    RMSpropUpdate,
    /// AdaGrad optimizer update.
    AdaGradUpdate,
    /// Classical momentum update.
    MomentumUpdate,
    /// Gradient norm/value clipping.
    GradientClipping,
    /// Element-wise vector operations.
    VectorOps,
    /// Matrix operations (2-D tiled launches).
    MatrixOps,
    /// Parallel reductions (uses per-block shared memory).
    Reduction,
}
147
/// Errors produced by kernel compilation, template generation, and launch
/// configuration.
#[derive(Debug, Error)]
pub enum KernelError {
    /// The backend compiler rejected the kernel source.
    #[error("Compilation failed: {reason}")]
    CompilationFailed { reason: String },

    /// No kernel or template registered under the given name.
    #[error("Kernel not found: {name}")]
    KernelNotFound { name: String },

    /// A required template parameter was missing or malformed.
    #[error("Invalid template parameter: {param}")]
    InvalidTemplateParameter { param: String },

    /// The requested backend has no compiler or is excluded by the template.
    #[error("Backend not supported: {backend:?}")]
    BackendNotSupported { backend: GpuBackend },

    /// The launch geometry violates device limits.
    #[error("Launch configuration invalid: {reason}")]
    InvalidLaunchConfig { reason: String },

    /// The kernel failed at execution time.
    #[error("Kernel execution failed: {reason}")]
    ExecutionFailed { reason: String },

    /// Placeholder substitution could not produce valid source.
    #[error("Template generation failed: {reason}")]
    TemplateGenerationFailed { reason: String },
}
172
/// Backend-specific kernel compilation interface.
///
/// Implementations must be `Send + Sync` so a [`KernelManager`] holding one
/// can be shared across threads.
pub trait KernelCompiler: Send + Sync + std::fmt::Debug {
    /// Compiles `source` with the given options into a loadable kernel.
    fn compile_kernel(
        &self,
        source: &str,
        options: &CompilationOptions,
    ) -> Result<CompiledKernel, KernelError>;

    /// The backend this compiler produces binaries for.
    fn backend_type(&self) -> GpuBackend;

    /// Whether this compiler/backend supports `feature`.
    fn supports_feature(&self, feature: KernelFeature) -> bool;

    /// Static device/compiler limits, used for launch-configuration sizing.
    fn get_capabilities(&self) -> CompilerCapabilities;
}
191
/// Hardware/compiler features a backend may or may not provide.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum KernelFeature {
    /// 16-bit floating-point arithmetic.
    HalfPrecision,
    /// 64-bit floating-point arithmetic.
    DoublePrecision,
    /// Matrix-multiply accelerator units (CUDA tensor cores).
    TensorCores,
    /// Grid-wide cooperative thread groups.
    CooperativeGroups,
    /// Kernels launching kernels from the device.
    DynamicParallelism,
    /// A single address space shared between host and device.
    UnifiedMemory,
}
208
/// Static limits reported by a backend compiler, consumed by
/// [`KernelManager::calculate_launch_config`].
#[derive(Debug, Clone)]
pub struct CompilerCapabilities {
    /// Maximum threads in a single block.
    pub max_threads_per_block: u32,

    /// Maximum grid dimensions along (x, y, z).
    pub max_blocks_per_grid: (u32, u32, u32),

    /// Maximum shared memory per block, in bytes.
    pub max_shared_memory: usize,

    /// Scalar type names the backend can compile.
    pub supported_data_types: Vec<String>,

    /// CUDA compute capability (major, minor); `None` for non-CUDA backends.
    pub compute_capability: Option<(u32, u32)>,
}
227
228impl KernelManager {
229 pub fn new(backend: GpuBackend, options: CompilationOptions) -> Result<Self, KernelError> {
231 let compiler = Self::create_compiler(backend)?;
232 let templates = Self::load_builtin_templates();
233
234 Ok(Self {
235 compiler,
236 kernel_cache: HashMap::new(),
237 templates,
238 backend,
239 compilation_options: options,
240 })
241 }
242
243 fn create_compiler(backend: GpuBackend) -> Result<Box<dyn KernelCompiler>, KernelError> {
245 match backend {
246 GpuBackend::Cuda => Ok(Box::new(CudaCompiler::new())),
247 GpuBackend::Rocm => Ok(Box::new(RocmCompiler::new())),
248 GpuBackend::Metal => Ok(Box::new(MetalCompiler::new())),
249 GpuBackend::Wgpu => Ok(Box::new(WgpuCompiler::new())),
250 GpuBackend::Cpu => Err(KernelError::BackendNotSupported { backend }),
251 }
252 }
253
254 fn load_builtin_templates() -> HashMap<String, KernelTemplate> {
256 let mut templates = HashMap::new();
257
258 templates.insert(
260 "sgd_update".to_string(),
261 KernelTemplate {
262 name: "sgd_update".to_string(),
263 source_template: sgd_kernel_template().to_string(),
264 parameters: vec![
265 TemplateParameter {
266 name: "dtype".to_string(),
267 param_type: ParameterType::DataType,
268 default_value: Some("float".to_string()),
269 required: true,
270 },
271 TemplateParameter {
272 name: "learning_rate".to_string(),
273 param_type: ParameterType::Float,
274 default_value: Some("0.01".to_string()),
275 required: true,
276 },
277 ],
278 supported_backends: vec![GpuBackend::Cuda, GpuBackend::Rocm, GpuBackend::Metal],
279 default_launch_config: LaunchConfig {
280 grid_size: (1, 1, 1),
281 block_size: (256, 1, 1),
282 shared_memory_size: 0,
283 stream: None,
284 },
285 },
286 );
287
288 templates.insert(
290 "adam_update".to_string(),
291 KernelTemplate {
292 name: "adam_update".to_string(),
293 source_template: adam_kernel_template().to_string(),
294 parameters: vec![
295 TemplateParameter {
296 name: "dtype".to_string(),
297 param_type: ParameterType::DataType,
298 default_value: Some("float".to_string()),
299 required: true,
300 },
301 TemplateParameter {
302 name: "learning_rate".to_string(),
303 param_type: ParameterType::Float,
304 default_value: Some("0.001".to_string()),
305 required: true,
306 },
307 TemplateParameter {
308 name: "beta1".to_string(),
309 param_type: ParameterType::Float,
310 default_value: Some("0.9".to_string()),
311 required: true,
312 },
313 TemplateParameter {
314 name: "beta2".to_string(),
315 param_type: ParameterType::Float,
316 default_value: Some("0.999".to_string()),
317 required: true,
318 },
319 ],
320 supported_backends: vec![GpuBackend::Cuda, GpuBackend::Rocm, GpuBackend::Metal],
321 default_launch_config: LaunchConfig {
322 grid_size: (1, 1, 1),
323 block_size: (256, 1, 1),
324 shared_memory_size: 0,
325 stream: None,
326 },
327 },
328 );
329
330 templates
331 }
332
333 pub fn generate_kernel_from_template(
335 &self,
336 template_name: &str,
337 parameters: &HashMap<String, String>,
338 ) -> Result<String, KernelError> {
339 let template =
340 self.templates
341 .get(template_name)
342 .ok_or_else(|| KernelError::KernelNotFound {
343 name: template_name.to_string(),
344 })?;
345
346 if !template.supported_backends.contains(&self.backend) {
347 return Err(KernelError::BackendNotSupported {
348 backend: self.backend,
349 });
350 }
351
352 let mut source = template.source_template.clone();
353
354 for param in &template.parameters {
356 let value = if let Some(v) = parameters.get(¶m.name) {
357 v.clone()
358 } else if let Some(default) = ¶m.default_value {
359 default.clone()
360 } else if param.required {
361 return Err(KernelError::InvalidTemplateParameter {
362 param: param.name.clone(),
363 });
364 } else {
365 continue;
366 };
367
368 let placeholder = format!("{{{{{}}}}}", param.name);
369 source = source.replace(&placeholder, &value);
370 }
371
372 Ok(source)
373 }
374
375 pub fn compile_kernel(
377 &mut self,
378 name: String,
379 source: String,
380 ) -> Result<Arc<CompiledKernel>, KernelError> {
381 if let Some(cached) = self.kernel_cache.get(&name) {
383 return Ok(cached.clone());
384 }
385
386 let compiled = self
388 .compiler
389 .compile_kernel(&source, &self.compilation_options)?;
390 let kernel_arc = Arc::new(compiled);
391
392 self.kernel_cache.insert(name, kernel_arc.clone());
394
395 Ok(kernel_arc)
396 }
397
398 pub fn get_kernel(&self, name: &str) -> Option<Arc<CompiledKernel>> {
400 self.kernel_cache.get(name).cloned()
401 }
402
403 pub fn clear_cache(&mut self) {
405 self.kernel_cache.clear();
406 }
407
408 pub fn get_capabilities(&self) -> CompilerCapabilities {
410 self.compiler.get_capabilities()
411 }
412
413 pub fn calculate_launch_config(
415 &self,
416 kernel_type: KernelType,
417 data_size: usize,
418 ) -> LaunchConfig {
419 let capabilities = self.get_capabilities();
420
421 match kernel_type {
422 KernelType::VectorOps | KernelType::SGDUpdate | KernelType::AdamUpdate => {
423 let threads_per_block = 256.min(capabilities.max_threads_per_block);
424 let blocks = data_size.div_ceil(threads_per_block as usize);
425
426 LaunchConfig {
427 grid_size: (blocks as u32, 1, 1),
428 block_size: (threads_per_block, 1, 1),
429 shared_memory_size: 0,
430 stream: None,
431 }
432 }
433 KernelType::MatrixOps => {
434 let block_dim = 16; let grid_dim = ((data_size as f64).sqrt().ceil() as u32).div_ceil(block_dim);
436
437 LaunchConfig {
438 grid_size: (grid_dim, grid_dim, 1),
439 block_size: (block_dim, block_dim, 1),
440 shared_memory_size: block_dim as usize * block_dim as usize * 4, stream: None,
442 }
443 }
444 KernelType::Reduction => {
445 let threads_per_block = 512.min(capabilities.max_threads_per_block);
446 let blocks = data_size.div_ceil(threads_per_block as usize);
447
448 LaunchConfig {
449 grid_size: (blocks as u32, 1, 1),
450 block_size: (threads_per_block, 1, 1),
451 shared_memory_size: threads_per_block as usize * 4, stream: None,
453 }
454 }
455 _ => {
456 LaunchConfig {
458 grid_size: (1, 1, 1),
459 block_size: (256, 1, 1),
460 shared_memory_size: 0,
461 stream: None,
462 }
463 }
464 }
465 }
466}
467
/// Kernel compiler for the CUDA backend.
#[derive(Debug)]
pub struct CudaCompiler {
    /// Path to the `nvcc` executable.
    /// NOTE(review): not read by any code visible in this file — confirm
    /// whether real nvcc invocation is still planned.
    nvcc_path: String,
}

impl Default for CudaCompiler {
    fn default() -> Self {
        Self::new()
    }
}

impl CudaCompiler {
    /// Creates a compiler that expects `nvcc` to be on `PATH`.
    pub fn new() -> Self {
        Self {
            nvcc_path: "nvcc".to_string(),
        }
    }
}
489
490impl KernelCompiler for CudaCompiler {
491 fn compile_kernel(
492 &self,
493 source: &str,
494 options: &CompilationOptions,
495 ) -> Result<CompiledKernel, KernelError> {
496 Ok(CompiledKernel {
498 name: "cuda_kernel".to_string(),
499 backend: GpuBackend::Cuda,
500 binary: source.as_bytes().to_vec(),
501 })
502 }
503
504 fn backend_type(&self) -> GpuBackend {
505 GpuBackend::Cuda
506 }
507
508 fn supports_feature(&self, feature: KernelFeature) -> bool {
509 match feature {
510 KernelFeature::HalfPrecision => true,
511 KernelFeature::DoublePrecision => true,
512 KernelFeature::TensorCores => true,
513 KernelFeature::CooperativeGroups => true,
514 KernelFeature::DynamicParallelism => true,
515 KernelFeature::UnifiedMemory => true,
516 }
517 }
518
519 fn get_capabilities(&self) -> CompilerCapabilities {
520 CompilerCapabilities {
521 max_threads_per_block: 1024,
522 max_blocks_per_grid: (65535, 65535, 65535),
523 max_shared_memory: 49152,
524 supported_data_types: vec![
525 "float".to_string(),
526 "double".to_string(),
527 "half".to_string(),
528 "int".to_string(),
529 ],
530 compute_capability: Some((8, 6)),
531 }
532 }
533}
534
/// Kernel compiler for the ROCm/HIP backend (stateless).
#[derive(Debug)]
pub struct RocmCompiler;

impl Default for RocmCompiler {
    fn default() -> Self {
        Self::new()
    }
}

impl RocmCompiler {
    /// Creates a new ROCm compiler.
    pub fn new() -> Self {
        Self
    }
}
550
551impl KernelCompiler for RocmCompiler {
552 fn compile_kernel(
553 &self,
554 source: &str,
555 _options: &CompilationOptions,
556 ) -> Result<CompiledKernel, KernelError> {
557 Ok(CompiledKernel {
558 name: "rocm_kernel".to_string(),
559 backend: GpuBackend::Rocm,
560 binary: source.as_bytes().to_vec(),
561 })
562 }
563
564 fn backend_type(&self) -> GpuBackend {
565 GpuBackend::Rocm
566 }
567
568 fn supports_feature(&self, feature: KernelFeature) -> bool {
569 match feature {
570 KernelFeature::HalfPrecision => true,
571 KernelFeature::DoublePrecision => true,
572 KernelFeature::TensorCores => false,
573 KernelFeature::CooperativeGroups => true,
574 KernelFeature::DynamicParallelism => false,
575 KernelFeature::UnifiedMemory => false,
576 }
577 }
578
579 fn get_capabilities(&self) -> CompilerCapabilities {
580 CompilerCapabilities {
581 max_threads_per_block: 1024,
582 max_blocks_per_grid: (65535, 65535, 65535),
583 max_shared_memory: 65536,
584 supported_data_types: vec![
585 "float".to_string(),
586 "double".to_string(),
587 "half".to_string(),
588 "int".to_string(),
589 ],
590 compute_capability: None,
591 }
592 }
593}
594
/// Kernel compiler for the Apple Metal backend (stateless).
#[derive(Debug)]
pub struct MetalCompiler;

impl Default for MetalCompiler {
    fn default() -> Self {
        Self::new()
    }
}

impl MetalCompiler {
    /// Creates a new Metal compiler.
    pub fn new() -> Self {
        Self
    }
}
610
611impl KernelCompiler for MetalCompiler {
612 fn compile_kernel(
613 &self,
614 source: &str,
615 _options: &CompilationOptions,
616 ) -> Result<CompiledKernel, KernelError> {
617 Ok(CompiledKernel {
618 name: "metal_kernel".to_string(),
619 backend: GpuBackend::Metal,
620 binary: source.as_bytes().to_vec(),
621 })
622 }
623
624 fn backend_type(&self) -> GpuBackend {
625 GpuBackend::Metal
626 }
627
628 fn supports_feature(&self, feature: KernelFeature) -> bool {
629 match feature {
630 KernelFeature::HalfPrecision => true,
631 KernelFeature::DoublePrecision => false,
632 KernelFeature::TensorCores => false,
633 KernelFeature::CooperativeGroups => false,
634 KernelFeature::DynamicParallelism => false,
635 KernelFeature::UnifiedMemory => true,
636 }
637 }
638
639 fn get_capabilities(&self) -> CompilerCapabilities {
640 CompilerCapabilities {
641 max_threads_per_block: 1024,
642 max_blocks_per_grid: (65535, 65535, 65535),
643 max_shared_memory: 32768,
644 supported_data_types: vec!["float".to_string(), "half".to_string(), "int".to_string()],
645 compute_capability: None,
646 }
647 }
648}
649
/// Kernel compiler for the portable wgpu backend (stateless).
#[derive(Debug)]
pub struct WgpuCompiler;

impl Default for WgpuCompiler {
    fn default() -> Self {
        Self::new()
    }
}

impl WgpuCompiler {
    /// Creates a new wgpu compiler.
    pub fn new() -> Self {
        Self
    }
}
665
666impl KernelCompiler for WgpuCompiler {
667 fn compile_kernel(
668 &self,
669 source: &str,
670 _options: &CompilationOptions,
671 ) -> Result<CompiledKernel, KernelError> {
672 Ok(CompiledKernel {
673 name: "wgpu_kernel".to_string(),
674 backend: GpuBackend::Wgpu,
675 binary: source.as_bytes().to_vec(),
676 })
677 }
678
679 fn backend_type(&self) -> GpuBackend {
680 GpuBackend::Wgpu
681 }
682
683 fn supports_feature(&self, feature: KernelFeature) -> bool {
684 match feature {
685 KernelFeature::HalfPrecision => false,
686 KernelFeature::DoublePrecision => false,
687 KernelFeature::TensorCores => false,
688 KernelFeature::CooperativeGroups => false,
689 KernelFeature::DynamicParallelism => false,
690 KernelFeature::UnifiedMemory => false,
691 }
692 }
693
694 fn get_capabilities(&self) -> CompilerCapabilities {
695 CompilerCapabilities {
696 max_threads_per_block: 256,
697 max_blocks_per_grid: (65535, 65535, 65535),
698 max_shared_memory: 16384,
699 supported_data_types: vec!["float".to_string(), "int".to_string()],
700 compute_capability: None,
701 }
702 }
703}
704
705impl Default for CompilationOptions {
706 fn default() -> Self {
707 Self {
708 optimization_level: 2,
709 fast_math: true,
710 debug_info: false,
711 target_arch: "compute_70".to_string(),
712 extra_flags: Vec::new(),
713 max_registers: None,
714 shared_memory_hint: None,
715 }
716 }
717}
718
/// CUDA source template for the SGD parameter-update kernel.
///
/// The `{{dtype}}` placeholder is substituted by
/// `KernelManager::generate_kernel_from_template`; one thread updates one
/// parameter element.
fn sgd_kernel_template() -> &'static str {
    r#"
__global__ void sgd_update_kernel(
    {{dtype}}* params,
    const {{dtype}}* gradients,
    const {{dtype}} learning_rate,
    const int size
) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < size) {
        params[idx] -= learning_rate * gradients[idx];
    }
}
"#
}
735
/// CUDA source template for the Adam optimizer update kernel.
///
/// The `{{dtype}}` placeholder is substituted by
/// `KernelManager::generate_kernel_from_template`.
///
/// Fix: the `v_hat` line previously read `(1.0 - beta2, step))` — a C comma
/// expression with unbalanced parentheses that would not compile; it now
/// applies the bias correction `1 - beta2^step`, mirroring `m_hat`.
fn adam_kernel_template() -> &'static str {
    r#"
__global__ void adam_update_kernel(
    {{dtype}}* params,
    {{dtype}}* m,
    {{dtype}}* v,
    const {{dtype}}* gradients,
    const {{dtype}} learning_rate,
    const {{dtype}} beta1,
    const {{dtype}} beta2,
    const {{dtype}} epsilon,
    const int step,
    const int size
) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < size) {
        {{dtype}} grad = gradients[idx];
        {{dtype}} m_val = beta1 * m[idx] + (1.0 - beta1) * grad;
        {{dtype}} v_val = beta2 * v[idx] + (1.0 - beta2) * grad * grad;

        {{dtype}} m_hat = m_val / (1.0 - pow(beta1, step));
        {{dtype}} v_hat = v_val / (1.0 - pow(beta2, step));

        params[idx] -= learning_rate * m_hat / (sqrt(v_hat) + epsilon);

        m[idx] = m_val;
        v[idx] = v_val;
    }
}
"#
}
767
768fn include_str(path: &str) -> &str {
772 match path {
773 "templates/sgd_kernel.template" => sgd_kernel_template(),
774 "templates/adam_kernel.template" => adam_kernel_template(),
775 _ => "",
776 }
777}
778
#[cfg(test)]
mod tests {
    use super::*;

    /// Manager construction succeeds for a backend that has a compiler.
    #[test]
    fn test_kernel_manager_creation() {
        let options = CompilationOptions::default();
        let manager = KernelManager::new(GpuBackend::Cuda, options);
        assert!(manager.is_ok());
    }

    /// Template instantiation succeeds when all required parameters are given.
    /// (Fix: the argument previously read `¶ms` — mojibake for `&params`.)
    #[test]
    fn test_template_generation() {
        let options = CompilationOptions::default();
        let manager = KernelManager::new(GpuBackend::Cuda, options).unwrap();

        let mut params = HashMap::new();
        params.insert("dtype".to_string(), "float".to_string());
        params.insert("learning_rate".to_string(), "0.01".to_string());

        let source = manager.generate_kernel_from_template("sgd_update", &params);
        assert!(source.is_ok());
    }

    /// Launch-config sizing must yield non-empty grid and block dimensions.
    #[test]
    fn test_launch_config_calculation() {
        let options = CompilationOptions::default();
        let manager = KernelManager::new(GpuBackend::Cuda, options).unwrap();

        let config = manager.calculate_launch_config(KernelType::SGDUpdate, 10000);
        assert!(config.grid_size.0 > 0);
        assert!(config.block_size.0 > 0);
    }

    /// Reported capabilities should be populated with sane values.
    #[test]
    fn test_compiler_capabilities() {
        let compiler = CudaCompiler::new();
        let capabilities = compiler.get_capabilities();

        assert!(capabilities.max_threads_per_block > 0);
        assert!(!capabilities.supported_data_types.is_empty());
    }
}