// ghostflow_cuda/kernels/mod.rs

1//! CUDA kernel definitions
2//!
3//! This module contains the Rust-side definitions for CUDA kernels.
4//! The actual CUDA code would be in .cu files compiled by nvcc.
5
/// Kernel launch configuration: grid/block dimensions plus dynamic shared
/// memory, mirroring the three arguments of a CUDA `<<<grid, block, shmem>>>`
/// launch.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct LaunchConfig {
    /// Number of blocks along (x, y, z).
    pub grid_dim: (u32, u32, u32),
    /// Threads per block along (x, y, z).
    pub block_dim: (u32, u32, u32),
    /// Dynamic shared memory per block, in bytes.
    pub shared_mem: usize,
}

impl LaunchConfig {
    /// Create a 1D launch config covering `n` elements with `block_size`
    /// threads per block. The grid is sized by ceiling division so the final
    /// (possibly partial) block still covers the tail elements.
    ///
    /// # Panics
    /// Panics if `block_size` is 0 (division by zero).
    pub fn linear(n: usize, block_size: u32) -> Self {
        // `div_ceil` replaces the manual `(n + b - 1) / b` idiom, which wraps
        // in release builds when `n as u32` is close to `u32::MAX` and would
        // silently yield a far-too-small grid.
        // NOTE(review): `n as u32` still truncates for n > u32::MAX; assumed
        // element counts fit in u32 — confirm for very large tensors.
        let grid_size = (n as u32).div_ceil(block_size);
        LaunchConfig {
            grid_dim: (grid_size, 1, 1),
            block_dim: (block_size, 1, 1),
            shared_mem: 0,
        }
    }

    /// Create a 2D launch config covering a `rows` x `cols` domain with a
    /// `block_x` x `block_y` thread block. The x dimension maps to columns
    /// and y to rows, following the usual CUDA convention.
    ///
    /// # Panics
    /// Panics if `block_x` or `block_y` is 0 (division by zero).
    pub fn grid_2d(rows: usize, cols: usize, block_x: u32, block_y: u32) -> Self {
        // Same overflow-safe ceiling division as `linear`.
        let grid_x = (cols as u32).div_ceil(block_x);
        let grid_y = (rows as u32).div_ceil(block_y);
        LaunchConfig {
            grid_dim: (grid_x, grid_y, 1),
            block_dim: (block_x, block_y, 1),
            shared_mem: 0,
        }
    }

    /// Set the dynamic shared memory size in bytes (builder-style; returns
    /// the modified copy, so the result must be used).
    #[must_use]
    pub fn with_shared_mem(mut self, bytes: usize) -> Self {
        self.shared_mem = bytes;
        self
    }
}
42
/// Raw kernel entry-point type.
///
/// NOTE(review): the erased `unsafe extern "C" fn()` signature carries no
/// kernel arguments, so this presumably serves only as an opaque handle to a
/// loaded symbol rather than something called directly — confirm against the
/// module-loading code.
pub type KernelFn = unsafe extern "C" fn();
45
/// Elementwise kernel symbol names.
///
/// Each constant is the exported symbol name of a kernel in the compiled CUDA
/// module; it must match the function name in the corresponding `.cu` source
/// exactly, or symbol lookup at load time will fail.
pub mod elementwise {
    /// Add kernel: `c[i] = a[i] + b[i]`
    pub const ADD_KERNEL: &str = "ghostflow_add_f32";

    /// Multiply kernel: `c[i] = a[i] * b[i]`
    pub const MUL_KERNEL: &str = "ghostflow_mul_f32";

    /// ReLU kernel: `y[i] = max(0, x[i])`
    pub const RELU_KERNEL: &str = "ghostflow_relu_f32";

    /// GELU activation kernel (exact formulation defined in the `.cu` source).
    pub const GELU_KERNEL: &str = "ghostflow_gelu_f32";

    /// Sigmoid activation kernel.
    pub const SIGMOID_KERNEL: &str = "ghostflow_sigmoid_f32";
}
63
/// Matrix multiplication kernel symbol names.
///
/// Constants are exported symbol names that must match the compiled `.cu`
/// code exactly for lookup to succeed.
pub mod matmul {
    /// Basic GEMM kernel.
    pub const GEMM_KERNEL: &str = "ghostflow_gemm_f32";

    /// Tiled GEMM kernel (better cache utilization).
    pub const GEMM_TILED_KERNEL: &str = "ghostflow_gemm_tiled_f32";

    /// Tensor Core GEMM via WMMA, f16 inputs (for Volta+ architectures).
    pub const GEMM_WMMA_KERNEL: &str = "ghostflow_gemm_wmma_f16";
}
75
/// Attention kernel symbol names.
///
/// Constants are exported symbol names that must match the compiled `.cu`
/// code exactly for lookup to succeed.
pub mod attention {
    /// Flash Attention forward kernel.
    pub const FLASH_ATTN_FWD: &str = "ghostflow_flash_attention_fwd";

    /// Flash Attention backward kernel.
    pub const FLASH_ATTN_BWD: &str = "ghostflow_flash_attention_bwd";

    /// Standard (non-fused) attention, kept for comparison/fallback.
    pub const STANDARD_ATTN: &str = "ghostflow_standard_attention";
}
87
/// Reduction kernel symbol names.
///
/// Constants are exported symbol names that must match the compiled `.cu`
/// code exactly for lookup to succeed.
pub mod reduction {
    /// Sum reduction.
    pub const SUM_KERNEL: &str = "ghostflow_sum_f32";

    /// Max reduction.
    pub const MAX_KERNEL: &str = "ghostflow_max_f32";

    /// Softmax kernel (listed here with the reductions; presumably it uses a
    /// max/sum reduction internally — confirm against the `.cu` source).
    pub const SOFTMAX_KERNEL: &str = "ghostflow_softmax_f32";
}
99
/// Normalization kernel symbol names.
///
/// Constants are exported symbol names that must match the compiled `.cu`
/// code exactly for lookup to succeed.
pub mod normalization {
    /// Layer normalization.
    pub const LAYER_NORM_KERNEL: &str = "ghostflow_layer_norm_f32";

    /// RMS normalization.
    pub const RMS_NORM_KERNEL: &str = "ghostflow_rms_norm_f32";

    /// Batch normalization.
    pub const BATCH_NORM_KERNEL: &str = "ghostflow_batch_norm_f32";
}