cuda_rust_wasm/transpiler/kernel_translator.rs

use crate::{Result, translation_error};
use crate::parser::ast::*;
use quote::{quote, format_ident};
use proc_macro2::TokenStream;

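/// Pattern-based translator that lowers recognized CUDA kernel shapes
/// (vector add, matrix multiply, reduction, stencil) into Rust `TokenStream`s.
///
/// A minimal usage sketch, illustrative only: it assumes a `KernelDef` already
/// produced by the parser and elides error handling.
///
/// ```ignore
/// let translator = KernelTranslator::new().with_block_dims(256, 1, 1);
/// let tokens = match translator.detect_pattern(&kernel) {
///     KernelPattern::VectorAdd => translator.translate_vector_add(&kernel)?,
///     KernelPattern::MatrixMul => translator.translate_matrix_mul(&kernel)?,
///     KernelPattern::Reduction => translator.translate_reduction(&kernel)?,
///     KernelPattern::Stencil   => translator.translate_stencil(&kernel)?,
///     KernelPattern::Generic   => unimplemented!("fall back to a generic translation"),
/// };
/// ```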
pub struct KernelTranslator {
    block_dims: Option<(u32, u32, u32)>,
    grid_dims: Option<(u32, u32, u32)>,
}

impl KernelTranslator {
    pub fn new() -> Self {
        Self {
            block_dims: None,
            grid_dims: None,
        }
    }

    pub fn with_block_dims(mut self, x: u32, y: u32, z: u32) -> Self {
        self.block_dims = Some((x, y, z));
        self
    }

    pub fn with_grid_dims(mut self, x: u32, y: u32, z: u32) -> Self {
        self.grid_dims = Some((x, y, z));
        self
    }

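    /// Emits an element-wise vector-add kernel: each thread computes one 1D
    /// global index and writes `c[idx] = a[idx] + b[idx]` behind a bounds check.
    /// Fails unless the source kernel declares exactly three parameters.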
    pub fn translate_vector_add(&self, kernel: &KernelDef) -> Result<TokenStream> {
        if kernel.params.len() != 3 {
            return Err(translation_error!("Vector addition requires 3 parameters"));
        }

        let kernel_name = format_ident!("{}", kernel.name);

        Ok(quote! {
            #[kernel]
            pub fn #kernel_name(
                a: &[f32],
                b: &[f32],
                c: &mut [f32],
            ) {
                let idx = thread::index().x + block::index().x * block::dim().x;
                if idx < c.len() as u32 {
                    c[idx as usize] = a[idx as usize] + b[idx as usize];
                }
            }
        })
    }

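    /// Emits a naive row-by-column matrix-multiply kernel: each thread owns one
    /// `(row, col)` element of the `m x n` result and accumulates over `k`.
    /// Fails unless the source kernel declares at least five parameters.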
    pub fn translate_matrix_mul(&self, kernel: &KernelDef) -> Result<TokenStream> {
        if kernel.params.len() < 5 {
            return Err(translation_error!("Matrix multiplication requires at least 5 parameters"));
        }

        let kernel_name = format_ident!("{}", kernel.name);

        Ok(quote! {
            #[kernel]
            pub fn #kernel_name(
                a: &[f32],
                b: &[f32],
                c: &mut [f32],
                m: u32,
                n: u32,
                k: u32,
            ) {
                let row = thread::index().y + block::index().y * block::dim().y;
                let col = thread::index().x + block::index().x * block::dim().x;

                if row < m && col < n {
                    let mut sum = 0.0f32;
                    for i in 0..k {
                        sum += a[(row * k + i) as usize] * b[(i * n + col) as usize];
                    }
                    c[(row * n + col) as usize] = sum;
                }
            }
        })
    }

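    /// Emits a block-level sum reduction: each thread accumulates a grid-stride
    /// partial sum into a 256-element shared buffer, a tree reduction collapses
    /// the buffer, and thread 0 writes one partial sum per block to
    /// `output[blockIdx.x]`. The generated code assumes a power-of-two block
    /// size of at most 256 threads.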
    pub fn translate_reduction(&self, kernel: &KernelDef) -> Result<TokenStream> {
        let kernel_name = format_ident!("{}", kernel.name);

        Ok(quote! {
            #[kernel]
            pub fn #kernel_name(
                input: &[f32],
                output: &mut [f32],
                n: u32,
            ) {
                #[shared]
                static mut PARTIAL_SUMS: [f32; 256] = [0.0; 256];

                let tid = thread::index().x;
                let gid = block::index().x * block::dim().x + tid;
                let block_size = block::dim().x;

                let mut sum = 0.0f32;
                let mut i = gid;
                while i < n {
                    sum += input[i as usize];
                    i += grid::dim().x * block_size;
                }

                unsafe {
                    PARTIAL_SUMS[tid as usize] = sum;
                }

                cuda_rust_wasm::runtime::sync_threads();

                let mut stride = block_size / 2;
                while stride > 0 {
                    if tid < stride {
                        unsafe {
                            PARTIAL_SUMS[tid as usize] += PARTIAL_SUMS[(tid + stride) as usize];
                        }
                    }
                    cuda_rust_wasm::runtime::sync_threads();
                    stride /= 2;
                }

                if tid == 0 {
                    output[block::index().x as usize] = unsafe { PARTIAL_SUMS[0] };
                }
            }
        })
    }

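    /// Emits a 5-point stencil kernel: each interior cell of a `width x height`
    /// grid becomes 0.2 times the sum of itself and its four north/south/east/west
    /// neighbors; boundary cells are left untouched.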
    pub fn translate_stencil(&self, kernel: &KernelDef) -> Result<TokenStream> {
        let kernel_name = format_ident!("{}", kernel.name);

        Ok(quote! {
            #[kernel]
            pub fn #kernel_name(
                input: &[f32],
                output: &mut [f32],
                width: u32,
                height: u32,
            ) {
                let x = thread::index().x + block::index().x * block::dim().x;
                let y = thread::index().y + block::index().y * block::dim().y;

                if x > 0 && x < width - 1 && y > 0 && y < height - 1 {
                    let idx = (y * width + x) as usize;
                    let idx_n = ((y - 1) * width + x) as usize;
                    let idx_s = ((y + 1) * width + x) as usize;
                    let idx_e = (y * width + (x + 1)) as usize;
                    let idx_w = (y * width + (x - 1)) as usize;

                    output[idx] = 0.2 * (
                        input[idx] +
                        input[idx_n] +
                        input[idx_s] +
                        input[idx_e] +
                        input[idx_w]
                    );
                }
            }
        })
    }

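    /// Heuristically classifies a kernel by its parameter count and body shape
    /// (indexing style, shared memory, thread synchronization, neighbor offsets).
    /// The checks run in order, so a kernel matching several heuristics is
    /// reported as the first match; anything unrecognized is `Generic`.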
    pub fn detect_pattern(&self, kernel: &KernelDef) -> KernelPattern {
        if self.is_vector_pattern(kernel) {
            KernelPattern::VectorAdd
        } else if self.is_matrix_pattern(kernel) {
            KernelPattern::MatrixMul
        } else if self.is_reduction_pattern(kernel) {
            KernelPattern::Reduction
        } else if self.is_stencil_pattern(kernel) {
            KernelPattern::Stencil
        } else {
            KernelPattern::Generic
        }
    }

    fn is_vector_pattern(&self, kernel: &KernelDef) -> bool {
        kernel.params.len() >= 3 && self.has_linear_indexing(&kernel.body)
    }

    fn is_matrix_pattern(&self, kernel: &KernelDef) -> bool {
        kernel.params.len() >= 5 && self.has_2d_indexing(&kernel.body)
    }

    fn is_reduction_pattern(&self, kernel: &KernelDef) -> bool {
        self.has_shared_memory(&kernel.body) && self.has_sync_threads(&kernel.body)
    }

    fn is_stencil_pattern(&self, kernel: &KernelDef) -> bool {
        self.has_neighbor_access(&kernel.body)
    }

    fn has_linear_indexing(&self, block: &Block) -> bool {
        block.statements.iter().any(|stmt| {
            match stmt {
                Statement::VarDecl { init: Some(expr), .. } => self.is_linear_index_expr(expr),
                Statement::Expr(expr) => self.contains_linear_index(expr),
                _ => false,
            }
        })
    }

    fn has_2d_indexing(&self, block: &Block) -> bool {
        let has_x = block.statements.iter().any(|stmt| self.uses_dimension(stmt, &Dimension::X));
        let has_y = block.statements.iter().any(|stmt| self.uses_dimension(stmt, &Dimension::Y));
        has_x && has_y
    }

    fn has_shared_memory(&self, block: &Block) -> bool {
        block.statements.iter().any(|stmt| {
            match stmt {
                Statement::VarDecl { storage, .. } => matches!(storage, StorageClass::Shared),
                _ => false,
            }
        })
    }

    fn has_sync_threads(&self, block: &Block) -> bool {
        block.statements.iter().any(|stmt| matches!(stmt, Statement::SyncThreads))
    }

    fn has_neighbor_access(&self, block: &Block) -> bool {
        block.statements.iter().any(|stmt| self.has_offset_access(stmt))
    }

    fn is_linear_index_expr(&self, expr: &Expression) -> bool {
        match expr {
            Expression::Binary { op: BinaryOp::Add, left, right } => {
                matches!(left.as_ref(), Expression::ThreadIdx(Dimension::X)) ||
                self.is_block_offset(right)
            },
            _ => false,
        }
    }

    fn contains_linear_index(&self, expr: &Expression) -> bool {
        match expr {
            Expression::Binary { left, right, .. } => {
                self.contains_linear_index(left) || self.contains_linear_index(right)
            },
            Expression::Index { index, .. } => self.is_linear_index_expr(index),
            _ => false,
        }
    }

    fn is_block_offset(&self, expr: &Expression) -> bool {
        match expr {
            Expression::Binary { op: BinaryOp::Mul, left, right } => {
                matches!(left.as_ref(), Expression::BlockIdx(Dimension::X)) &&
                matches!(right.as_ref(), Expression::BlockDim(Dimension::X))
            },
            _ => false,
        }
    }

    fn uses_dimension(&self, stmt: &Statement, dim: &Dimension) -> bool {
        match stmt {
            Statement::VarDecl { init: Some(expr), .. } => self.expr_uses_dimension(expr, dim),
            Statement::Expr(expr) => self.expr_uses_dimension(expr, dim),
            _ => false,
        }
    }

    fn expr_uses_dimension(&self, expr: &Expression, dim: &Dimension) -> bool {
        match expr {
            Expression::ThreadIdx(d) | Expression::BlockIdx(d) |
            Expression::BlockDim(d) | Expression::GridDim(d) => d == dim,
            Expression::Binary { left, right, .. } => {
                self.expr_uses_dimension(left, dim) || self.expr_uses_dimension(right, dim)
            },
            _ => false,
        }
    }

    fn has_offset_access(&self, stmt: &Statement) -> bool {
        match stmt {
            Statement::Expr(expr) => self.expr_has_offset_access(expr),
            Statement::VarDecl { init: Some(expr), .. } => self.expr_has_offset_access(expr),
            _ => false,
        }
    }

    fn expr_has_offset_access(&self, expr: &Expression) -> bool {
        match expr {
            Expression::Index { index, .. } => self.has_unit_offset(index),
            Expression::Binary { left, right, .. } => {
                self.expr_has_offset_access(left) || self.expr_has_offset_access(right)
            },
            _ => false,
        }
    }

    fn has_unit_offset(&self, expr: &Expression) -> bool {
        match expr {
            Expression::Binary { op: BinaryOp::Add | BinaryOp::Sub, left: _, right } => {
                matches!(right.as_ref(), Expression::Literal(Literal::Int(1)))
            },
            _ => false,
        }
    }
}

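/// Kernel shapes the translator knows how to specialize; `Generic` marks
/// kernels that match none of the heuristics in [`KernelTranslator::detect_pattern`].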
#[derive(Debug, Clone, PartialEq)]
pub enum KernelPattern {
    VectorAdd,
    MatrixMul,
    Reduction,
    Stencil,
    Generic,
}

impl Default for KernelTranslator {
    fn default() -> Self {
        Self::new()
    }
}