Skip to main content

cuda_rust_wasm/parser/
kernel_extractor.rs

1//! Kernel extraction utilities
2//!
3//! Extracts kernel definitions and metadata from a parsed CUDA AST.
4
5use super::ast::*;
6
7/// Information about an extracted kernel
8#[derive(Debug, Clone)]
9pub struct KernelInfo {
10    /// Kernel name
11    pub name: String,
12    /// Kernel parameters
13    pub params: Vec<Parameter>,
14    /// Kernel attributes (launch bounds, etc.)
15    pub attributes: Vec<KernelAttribute>,
16    /// Whether the kernel uses shared memory
17    pub uses_shared_memory: bool,
18    /// Whether the kernel uses syncthreads
19    pub uses_sync_threads: bool,
20    /// Set of CUDA builtins referenced (threadIdx, blockIdx, etc.)
21    pub referenced_builtins: Vec<String>,
22    /// Names of functions called from within the kernel
23    pub called_functions: Vec<String>,
24}
25
26/// Extract all kernel definitions from an AST
27pub fn extract_kernels(ast: &Ast) -> Vec<KernelInfo> {
28    ast.items
29        .iter()
30        .filter_map(|item| {
31            if let Item::Kernel(kernel) = item {
32                Some(analyze_kernel(kernel))
33            } else {
34                None
35            }
36        })
37        .collect()
38}
39
40/// Extract a single kernel by name
41pub fn extract_kernel_by_name<'a>(ast: &'a Ast, name: &str) -> Option<&'a KernelDef> {
42    ast.items.iter().find_map(|item| {
43        if let Item::Kernel(kernel) = item {
44            if kernel.name == name {
45                return Some(kernel);
46            }
47        }
48        None
49    })
50}
51
52/// Extract all device functions from the AST
53pub fn extract_device_functions(ast: &Ast) -> Vec<&FunctionDef> {
54    ast.items
55        .iter()
56        .filter_map(|item| {
57            if let Item::DeviceFunction(func) = item {
58                Some(func)
59            } else {
60                None
61            }
62        })
63        .collect()
64}
65
66/// Analyze a kernel definition to produce KernelInfo
67fn analyze_kernel(kernel: &KernelDef) -> KernelInfo {
68    let mut info = KernelInfo {
69        name: kernel.name.clone(),
70        params: kernel.params.clone(),
71        attributes: kernel.attributes.clone(),
72        uses_shared_memory: false,
73        uses_sync_threads: false,
74        referenced_builtins: Vec::new(),
75        called_functions: Vec::new(),
76    };
77
78    visit_block(&kernel.body, &mut info);
79
80    // Deduplicate
81    info.referenced_builtins.sort();
82    info.referenced_builtins.dedup();
83    info.called_functions.sort();
84    info.called_functions.dedup();
85
86    info
87}
88
89fn visit_block(block: &Block, info: &mut KernelInfo) {
90    for stmt in &block.statements {
91        visit_statement(stmt, info);
92    }
93}
94
95fn visit_statement(stmt: &Statement, info: &mut KernelInfo) {
96    match stmt {
97        Statement::VarDecl { storage, init, .. } => {
98            if matches!(storage, StorageClass::Shared) {
99                info.uses_shared_memory = true;
100            }
101            if let Some(expr) = init {
102                visit_expression(expr, info);
103            }
104        }
105        Statement::Expr(expr) => {
106            visit_expression(expr, info);
107        }
108        Statement::Block(block) => {
109            visit_block(block, info);
110        }
111        Statement::If { condition, then_branch, else_branch } => {
112            visit_expression(condition, info);
113            visit_statement(then_branch, info);
114            if let Some(else_stmt) = else_branch {
115                visit_statement(else_stmt, info);
116            }
117        }
118        Statement::For { init, condition, update, body } => {
119            if let Some(init_stmt) = init {
120                visit_statement(init_stmt, info);
121            }
122            if let Some(cond) = condition {
123                visit_expression(cond, info);
124            }
125            if let Some(upd) = update {
126                visit_expression(upd, info);
127            }
128            visit_statement(body, info);
129        }
130        Statement::While { condition, body } => {
131            visit_expression(condition, info);
132            visit_statement(body, info);
133        }
134        Statement::Return(Some(expr)) => {
135            visit_expression(expr, info);
136        }
137        Statement::SyncThreads => {
138            info.uses_sync_threads = true;
139        }
140        _ => {}
141    }
142}
143
144fn visit_expression(expr: &Expression, info: &mut KernelInfo) {
145    match expr {
146        Expression::ThreadIdx(dim) => {
147            info.referenced_builtins.push(format!("threadIdx.{}", dim_str(dim)));
148        }
149        Expression::BlockIdx(dim) => {
150            info.referenced_builtins.push(format!("blockIdx.{}", dim_str(dim)));
151        }
152        Expression::BlockDim(dim) => {
153            info.referenced_builtins.push(format!("blockDim.{}", dim_str(dim)));
154        }
155        Expression::GridDim(dim) => {
156            info.referenced_builtins.push(format!("gridDim.{}", dim_str(dim)));
157        }
158        Expression::Binary { left, right, .. } => {
159            visit_expression(left, info);
160            visit_expression(right, info);
161        }
162        Expression::Unary { expr, .. } => {
163            visit_expression(expr, info);
164        }
165        Expression::Call { name, args } => {
166            if name != "__syncthreads" && name != "__ternary__" && name != "sizeof" {
167                info.called_functions.push(name.clone());
168            }
169            if name == "__syncthreads" {
170                info.uses_sync_threads = true;
171            }
172            for arg in args {
173                visit_expression(arg, info);
174            }
175        }
176        Expression::Index { array, index } => {
177            visit_expression(array, info);
178            visit_expression(index, info);
179        }
180        Expression::Member { object, .. } => {
181            visit_expression(object, info);
182        }
183        Expression::Cast { expr, .. } => {
184            visit_expression(expr, info);
185        }
186        Expression::WarpPrimitive { args, .. } => {
187            for arg in args {
188                visit_expression(arg, info);
189            }
190        }
191        _ => {}
192    }
193}
194
195fn dim_str(dim: &Dimension) -> &'static str {
196    match dim {
197        Dimension::X => "x",
198        Dimension::Y => "y",
199        Dimension::Z => "z",
200    }
201}
202
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::CudaParser;

    /// Parse `src` (panicking on a parse error) and extract its kernels.
    fn kernels_in(src: &str) -> Vec<KernelInfo> {
        let parser = CudaParser::new();
        let ast = parser.parse(src).unwrap();
        extract_kernels(&ast)
    }

    #[test]
    fn test_extract_vector_add() {
        let kernels = kernels_in(
            r#"
__global__ void vectorAdd(const float* a, const float* b, float* c, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) {
        c[i] = a[i] + b[i];
    }
}
"#,
        );
        assert_eq!(kernels.len(), 1);
        let kernel = &kernels[0];
        assert_eq!(kernel.name, "vectorAdd");
        assert_eq!(kernel.params.len(), 4);
        assert!(!kernel.uses_shared_memory);
        assert!(!kernel.uses_sync_threads);
        for builtin in ["threadIdx.x", "blockIdx.x", "blockDim.x"] {
            assert!(
                kernel.referenced_builtins.iter().any(|b| b == builtin),
                "missing builtin {}",
                builtin
            );
        }
    }

    #[test]
    fn test_extract_shared_memory_kernel() {
        let kernels = kernels_in(
            r#"
__global__ void matMul(float* A, float* B, float* C, int M, int N, int K) {
    __shared__ float sA[16][16];
    __shared__ float sB[16][16];
    int row = blockIdx.y * blockDim.y + threadIdx.y;
    int col = blockIdx.x * blockDim.x + threadIdx.x;
    float sum = 0.0f;
    for (int t = 0; t < (K + 15) / 16; t++) {
        sA[threadIdx.y][threadIdx.x] = A[row * K + t * 16 + threadIdx.x];
        sB[threadIdx.y][threadIdx.x] = B[(t * 16 + threadIdx.y) * N + col];
        __syncthreads();
        for (int k = 0; k < 16; k++) {
            sum += sA[threadIdx.y][k] * sB[k][threadIdx.x];
        }
        __syncthreads();
    }
    C[row * N + col] = sum;
}
"#,
        );
        assert_eq!(kernels.len(), 1);
        let kernel = &kernels[0];
        assert_eq!(kernel.name, "matMul");
        assert!(kernel.uses_shared_memory);
        assert!(kernel.uses_sync_threads);
    }

    #[test]
    fn test_extract_kernel_by_name() {
        let src = r#"
__global__ void kernel1(int* a) { a[threadIdx.x] = 0; }
__global__ void kernel2(float* b) { b[threadIdx.x] = 1.0f; }
"#;
        let parser = CudaParser::new();
        let ast = parser.parse(src).unwrap();
        for present in ["kernel1", "kernel2"] {
            assert!(
                extract_kernel_by_name(&ast, present).is_some(),
                "expected kernel {}",
                present
            );
        }
        assert!(extract_kernel_by_name(&ast, "kernel3").is_none());
    }
}