//! `cuda_rust_wasm/parser/kernel_extractor.rs`: extraction and static analysis
//! of kernels and device functions from a parsed CUDA AST.

use super::ast::*;
6
7#[derive(Debug, Clone)]
9pub struct KernelInfo {
10 pub name: String,
12 pub params: Vec<Parameter>,
14 pub attributes: Vec<KernelAttribute>,
16 pub uses_shared_memory: bool,
18 pub uses_sync_threads: bool,
20 pub referenced_builtins: Vec<String>,
22 pub called_functions: Vec<String>,
24}
25
26pub fn extract_kernels(ast: &Ast) -> Vec<KernelInfo> {
28 ast.items
29 .iter()
30 .filter_map(|item| {
31 if let Item::Kernel(kernel) = item {
32 Some(analyze_kernel(kernel))
33 } else {
34 None
35 }
36 })
37 .collect()
38}
39
40pub fn extract_kernel_by_name<'a>(ast: &'a Ast, name: &str) -> Option<&'a KernelDef> {
42 ast.items.iter().find_map(|item| {
43 if let Item::Kernel(kernel) = item {
44 if kernel.name == name {
45 return Some(kernel);
46 }
47 }
48 None
49 })
50}
51
52pub fn extract_device_functions(ast: &Ast) -> Vec<&FunctionDef> {
54 ast.items
55 .iter()
56 .filter_map(|item| {
57 if let Item::DeviceFunction(func) = item {
58 Some(func)
59 } else {
60 None
61 }
62 })
63 .collect()
64}
65
66fn analyze_kernel(kernel: &KernelDef) -> KernelInfo {
68 let mut info = KernelInfo {
69 name: kernel.name.clone(),
70 params: kernel.params.clone(),
71 attributes: kernel.attributes.clone(),
72 uses_shared_memory: false,
73 uses_sync_threads: false,
74 referenced_builtins: Vec::new(),
75 called_functions: Vec::new(),
76 };
77
78 visit_block(&kernel.body, &mut info);
79
80 info.referenced_builtins.sort();
82 info.referenced_builtins.dedup();
83 info.called_functions.sort();
84 info.called_functions.dedup();
85
86 info
87}
88
89fn visit_block(block: &Block, info: &mut KernelInfo) {
90 for stmt in &block.statements {
91 visit_statement(stmt, info);
92 }
93}
94
95fn visit_statement(stmt: &Statement, info: &mut KernelInfo) {
96 match stmt {
97 Statement::VarDecl { storage, init, .. } => {
98 if matches!(storage, StorageClass::Shared) {
99 info.uses_shared_memory = true;
100 }
101 if let Some(expr) = init {
102 visit_expression(expr, info);
103 }
104 }
105 Statement::Expr(expr) => {
106 visit_expression(expr, info);
107 }
108 Statement::Block(block) => {
109 visit_block(block, info);
110 }
111 Statement::If { condition, then_branch, else_branch } => {
112 visit_expression(condition, info);
113 visit_statement(then_branch, info);
114 if let Some(else_stmt) = else_branch {
115 visit_statement(else_stmt, info);
116 }
117 }
118 Statement::For { init, condition, update, body } => {
119 if let Some(init_stmt) = init {
120 visit_statement(init_stmt, info);
121 }
122 if let Some(cond) = condition {
123 visit_expression(cond, info);
124 }
125 if let Some(upd) = update {
126 visit_expression(upd, info);
127 }
128 visit_statement(body, info);
129 }
130 Statement::While { condition, body } => {
131 visit_expression(condition, info);
132 visit_statement(body, info);
133 }
134 Statement::Return(Some(expr)) => {
135 visit_expression(expr, info);
136 }
137 Statement::SyncThreads => {
138 info.uses_sync_threads = true;
139 }
140 _ => {}
141 }
142}
143
144fn visit_expression(expr: &Expression, info: &mut KernelInfo) {
145 match expr {
146 Expression::ThreadIdx(dim) => {
147 info.referenced_builtins.push(format!("threadIdx.{}", dim_str(dim)));
148 }
149 Expression::BlockIdx(dim) => {
150 info.referenced_builtins.push(format!("blockIdx.{}", dim_str(dim)));
151 }
152 Expression::BlockDim(dim) => {
153 info.referenced_builtins.push(format!("blockDim.{}", dim_str(dim)));
154 }
155 Expression::GridDim(dim) => {
156 info.referenced_builtins.push(format!("gridDim.{}", dim_str(dim)));
157 }
158 Expression::Binary { left, right, .. } => {
159 visit_expression(left, info);
160 visit_expression(right, info);
161 }
162 Expression::Unary { expr, .. } => {
163 visit_expression(expr, info);
164 }
165 Expression::Call { name, args } => {
166 if name != "__syncthreads" && name != "__ternary__" && name != "sizeof" {
167 info.called_functions.push(name.clone());
168 }
169 if name == "__syncthreads" {
170 info.uses_sync_threads = true;
171 }
172 for arg in args {
173 visit_expression(arg, info);
174 }
175 }
176 Expression::Index { array, index } => {
177 visit_expression(array, info);
178 visit_expression(index, info);
179 }
180 Expression::Member { object, .. } => {
181 visit_expression(object, info);
182 }
183 Expression::Cast { expr, .. } => {
184 visit_expression(expr, info);
185 }
186 Expression::WarpPrimitive { args, .. } => {
187 for arg in args {
188 visit_expression(arg, info);
189 }
190 }
191 _ => {}
192 }
193}
194
195fn dim_str(dim: &Dimension) -> &'static str {
196 match dim {
197 Dimension::X => "x",
198 Dimension::Y => "y",
199 Dimension::Z => "z",
200 }
201}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::CudaParser;

    /// Parses `src` with a fresh `CudaParser` and summarizes every kernel in it.
    fn kernels_in(src: &str) -> Vec<KernelInfo> {
        let parser = CudaParser::new();
        let ast = parser.parse(src).unwrap();
        extract_kernels(&ast)
    }

    #[test]
    fn test_extract_vector_add() {
        let src = r#"
__global__ void vectorAdd(const float* a, const float* b, float* c, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) {
        c[i] = a[i] + b[i];
    }
}
"#;
        let kernels = kernels_in(src);
        assert_eq!(kernels.len(), 1);

        let kernel = &kernels[0];
        assert_eq!(kernel.name, "vectorAdd");
        assert_eq!(kernel.params.len(), 4);
        assert!(!kernel.uses_shared_memory);
        assert!(!kernel.uses_sync_threads);
        for builtin in ["threadIdx.x", "blockIdx.x", "blockDim.x"] {
            assert!(kernel.referenced_builtins.iter().any(|b| b == builtin));
        }
    }

    #[test]
    fn test_extract_shared_memory_kernel() {
        let src = r#"
__global__ void matMul(float* A, float* B, float* C, int M, int N, int K) {
    __shared__ float sA[16][16];
    __shared__ float sB[16][16];
    int row = blockIdx.y * blockDim.y + threadIdx.y;
    int col = blockIdx.x * blockDim.x + threadIdx.x;
    float sum = 0.0f;
    for (int t = 0; t < (K + 15) / 16; t++) {
        sA[threadIdx.y][threadIdx.x] = A[row * K + t * 16 + threadIdx.x];
        sB[threadIdx.y][threadIdx.x] = B[(t * 16 + threadIdx.y) * N + col];
        __syncthreads();
        for (int k = 0; k < 16; k++) {
            sum += sA[threadIdx.y][k] * sB[k][threadIdx.x];
        }
        __syncthreads();
    }
    C[row * N + col] = sum;
}
"#;
        let kernels = kernels_in(src);
        assert_eq!(kernels.len(), 1);

        let kernel = &kernels[0];
        assert_eq!(kernel.name, "matMul");
        assert!(kernel.uses_shared_memory);
        assert!(kernel.uses_sync_threads);
    }

    #[test]
    fn test_extract_kernel_by_name() {
        let src = r#"
__global__ void kernel1(int* a) { a[threadIdx.x] = 0; }
__global__ void kernel2(float* b) { b[threadIdx.x] = 1.0f; }
"#;
        let parser = CudaParser::new();
        let ast = parser.parse(src).unwrap();

        for present in ["kernel1", "kernel2"] {
            assert!(extract_kernel_by_name(&ast, present).is_some());
        }
        assert!(extract_kernel_by_name(&ast, "kernel3").is_none());
    }
}