// ghostflow_cuda/kernels/mod.rs

/// Grid/block geometry and shared-memory size used to launch a kernel.
#[derive(Debug, Clone, Copy)]
pub struct LaunchConfig {
    /// Number of blocks along (x, y, z).
    pub grid_dim: (u32, u32, u32),
    /// Block dimensions along (x, y, z).
    pub block_dim: (u32, u32, u32),
    /// Shared-memory size in bytes; 0 unless set via [`LaunchConfig::with_shared_mem`].
    pub shared_mem: usize,
}

impl LaunchConfig {
    /// Builds a 1-D launch covering `n` elements with `block_size` threads per block.
    ///
    /// The grid has `ceil(n / block_size)` blocks along x. When `n` is not a multiple
    /// of `block_size`, the last block contains indices `>= n`, which the kernel must
    /// ignore. `n == 0` yields a grid of 0 blocks; callers should skip the launch then.
    ///
    /// # Panics
    ///
    /// Panics if `block_size` is 0 or if the required block count exceeds `u32::MAX`.
    pub fn linear(n: usize, block_size: u32) -> Self {
        let grid_size = Self::blocks_needed(n, block_size);
        LaunchConfig {
            grid_dim: (grid_size, 1, 1),
            block_dim: (block_size, 1, 1),
            shared_mem: 0,
        }
    }

    /// Builds a 2-D launch over a `rows x cols` domain: x spans `cols`, y spans `rows`.
    ///
    /// Each grid dimension is rounded up (`ceil(cols / block_x)`, `ceil(rows / block_y)`),
    /// so edge blocks may extend past the domain.
    ///
    /// # Panics
    ///
    /// Panics if `block_x` or `block_y` is 0, or if either block count exceeds `u32::MAX`.
    pub fn grid_2d(rows: usize, cols: usize, block_x: u32, block_y: u32) -> Self {
        let grid_x = Self::blocks_needed(cols, block_x);
        let grid_y = Self::blocks_needed(rows, block_y);
        LaunchConfig {
            grid_dim: (grid_x, grid_y, 1),
            block_dim: (block_x, block_y, 1),
            shared_mem: 0,
        }
    }

    /// Returns this configuration with `shared_mem` set to `bytes`.
    pub fn with_shared_mem(mut self, bytes: usize) -> Self {
        self.shared_mem = bytes;
        self
    }

    /// `ceil(extent / block)` as a `u32` grid dimension.
    ///
    /// The arithmetic stays in `usize` and avoids the `extent + block - 1` rounding
    /// term. As a result, a large `extent` can neither be truncated by a `usize -> u32`
    /// cast nor overflow the addition; an unrepresentable result panics explicitly.
    fn blocks_needed(extent: usize, block: u32) -> u32 {
        assert!(block > 0, "LaunchConfig: block dimension must be non-zero");
        let block = block as usize; // u32 -> usize is lossless on 32/64-bit targets
        let blocks = extent / block + usize::from(extent % block != 0);
        u32::try_from(blocks).expect("LaunchConfig: grid dimension exceeds u32::MAX")
    }
}

/// Host-side function-pointer type for a kernel entry: an `unsafe extern "C"`
/// function with no parameters and no return value.
///
/// NOTE(review): kernel arguments are not part of this signature, so this looks
/// like an opaque handle passed to the launch machinery — confirm against callers.
pub type KernelFn = unsafe extern "C" fn();

pub mod elementwise {
    //! Symbol names of the element-wise kernels.

    /// Symbol name of the f32 add kernel.
    pub const ADD_KERNEL: &str = "ghostflow_add_f32";

    /// Symbol name of the f32 multiply kernel.
    pub const MUL_KERNEL: &str = "ghostflow_mul_f32";

    /// Symbol name of the f32 ReLU kernel.
    pub const RELU_KERNEL: &str = "ghostflow_relu_f32";

    /// Symbol name of the f32 GELU kernel.
    pub const GELU_KERNEL: &str = "ghostflow_gelu_f32";

    /// Symbol name of the f32 sigmoid kernel.
    pub const SIGMOID_KERNEL: &str = "ghostflow_sigmoid_f32";
}

pub mod matmul {
    //! Symbol names of the matrix-multiplication (GEMM) kernels.

    /// Symbol name of the basic f32 GEMM kernel.
    pub const GEMM_KERNEL: &str = "ghostflow_gemm_f32";

    /// Symbol name of the tiled f32 GEMM kernel.
    pub const GEMM_TILED_KERNEL: &str = "ghostflow_gemm_tiled_f32";

    /// Symbol name of the WMMA GEMM kernel (f16 variant, per its symbol suffix).
    pub const GEMM_WMMA_KERNEL: &str = "ghostflow_gemm_wmma_f16";
}

pub mod attention {
    //! Symbol names of the attention kernels.

    /// Symbol name of the flash-attention forward kernel.
    pub const FLASH_ATTN_FWD: &str = "ghostflow_flash_attention_fwd";

    /// Symbol name of the flash-attention backward kernel.
    pub const FLASH_ATTN_BWD: &str = "ghostflow_flash_attention_bwd";

    /// Symbol name of the standard (non-flash) attention kernel.
    pub const STANDARD_ATTN: &str = "ghostflow_standard_attention";
}

pub mod reduction {
    //! Symbol names of the reduction kernels.

    /// Symbol name of the f32 sum kernel.
    pub const SUM_KERNEL: &str = "ghostflow_sum_f32";

    /// Symbol name of the f32 max kernel.
    pub const MAX_KERNEL: &str = "ghostflow_max_f32";

    /// Symbol name of the f32 softmax kernel.
    pub const SOFTMAX_KERNEL: &str = "ghostflow_softmax_f32";
}

pub mod normalization {
    //! Symbol names of the normalization kernels.

    /// Symbol name of the f32 layer-norm kernel.
    pub const LAYER_NORM_KERNEL: &str = "ghostflow_layer_norm_f32";

    /// Symbol name of the f32 RMS-norm kernel.
    pub const RMS_NORM_KERNEL: &str = "ghostflow_rms_norm_f32";

    /// Symbol name of the f32 batch-norm kernel.
    pub const BATCH_NORM_KERNEL: &str = "ghostflow_batch_norm_f32";
}