use crate::{Result, translation_error};
use crate::parser::ast::*;
use quote::{quote, format_ident};
use proc_macro2::TokenStream;
/// Generates Rust kernel source (as a [`TokenStream`]) for a handful of
/// recognised CUDA kernel shapes and classifies parsed kernels into those shapes.
///
/// Launch dimensions are optional and are set with the `with_block_dims` /
/// `with_grid_dims` builder methods.
#[derive(Debug, Clone)]
pub struct KernelTranslator {
    /// Block dimensions as `(x, y, z)`; `None` until `with_block_dims` is called.
    block_dims: Option<(u32, u32, u32)>,
    /// Grid dimensions as `(x, y, z)`; `None` until `with_grid_dims` is called.
    grid_dims: Option<(u32, u32, u32)>,
}
impl KernelTranslator {
/// Creates a translator with neither block nor grid dimensions configured.
pub fn new() -> Self {
    let block_dims = None;
    let grid_dims = None;
    Self { block_dims, grid_dims }
}
pub fn with_block_dims(mut self, x: u32, y: u32, z: u32) -> Self {
self.block_dims = Some((x, y, z));
self
}
pub fn with_grid_dims(mut self, x: u32, y: u32, z: u32) -> Self {
self.grid_dims = Some((x, y, z));
self
}
/// Emits the Rust source for an element-wise `f32` vector addition kernel.
///
/// The generated `#[kernel]` function is named after `kernel.name` and always has
/// the fixed signature `(a: &[f32], b: &[f32], c: &mut [f32])`; it stores
/// `a[idx] + b[idx]` into `c[idx]` for every global index `idx` below `c.len()`.
/// Only the *number* of entries in `kernel.params` is inspected.
///
/// # Errors
///
/// Returns a translation error unless `kernel.params` has exactly three entries.
///
/// NOTE(review): `is_vector_pattern` accepts three *or more* parameters, so a
/// kernel classified as `VectorAdd` can still be rejected here — confirm which
/// bound is intended.
pub fn translate_vector_add(&self, kernel: &KernelDef) -> Result<TokenStream> {
    let arity = kernel.params.len();
    if arity != 3 {
        return Err(translation_error!("Vector addition requires 3 parameters"));
    }

    let fn_ident = format_ident!("{}", kernel.name);
    let tokens = quote! {
        #[kernel]
        pub fn #fn_ident(
            a: &[f32],
            b: &[f32],
            c: &mut [f32],
        ) {
            let idx = thread::index().x + block::index().x * block::dim().x;
            if idx < c.len() as u32 {
                c[idx as usize] = a[idx as usize] + b[idx as usize];
            }
        }
    };
    Ok(tokens)
}
/// Emits the Rust source for a naive `f32` matrix multiplication kernel.
///
/// The generated `#[kernel]` function is named after `kernel.name` and always has
/// the fixed signature `(a, b, c, m, n, k)`. Each thread computes one output
/// element `c[row * n + col]` as the dot product of row `row` of `a`
/// (indexed `row * k + i`) and column `col` of `b` (indexed `i * n + col`),
/// guarded by `row < m && col < n`. Only the *number* of entries in
/// `kernel.params` is inspected.
///
/// # Errors
///
/// Returns a translation error when `kernel.params` has fewer than five entries.
///
/// NOTE(review): the check admits five parameters while the emitted signature
/// has six — confirm the intended minimum.
pub fn translate_matrix_mul(&self, kernel: &KernelDef) -> Result<TokenStream> {
    let has_enough_params = kernel.params.len() >= 5;
    if !has_enough_params {
        return Err(translation_error!("Matrix multiplication requires at least 5 parameters"));
    }

    let fn_ident = format_ident!("{}", kernel.name);
    let tokens = quote! {
        #[kernel]
        pub fn #fn_ident(
            a: &[f32],
            b: &[f32],
            c: &mut [f32],
            m: u32,
            n: u32,
            k: u32,
        ) {
            let row = thread::index().y + block::index().y * block::dim().y;
            let col = thread::index().x + block::index().x * block::dim().x;
            if row < m && col < n {
                let mut sum = 0.0f32;
                for i in 0..k {
                    sum += a[(row * k + i) as usize] * b[(i * n + col) as usize];
                }
                c[(row * n + col) as usize] = sum;
            }
        }
    };
    Ok(tokens)
}
pub fn translate_reduction(&self, kernel: &KernelDef) -> Result<TokenStream> {
let kernel_name = format_ident!("{}", kernel.name);
Ok(quote! {
#[kernel]
pub fn #kernel_name(
input: &[f32],
output: &mut [f32],
n: u32,
) {
#[shared]
static mut PARTIAL_SUMS: [f32; 256] = [0.0; 256];
let tid = thread::index().x;
let gid = block::index().x * block::dim().x + tid;
let block_size = block::dim().x;
let mut sum = 0.0f32;
let mut i = gid;
while i < n {
sum += input[i as usize];
i += grid::dim().x * block_size;
}
unsafe {
PARTIAL_SUMS[tid as usize] = sum;
}
cuda_rust_wasm::runtime::sync_threads();
let mut stride = block_size / 2;
while stride > 0 {
if tid < stride {
unsafe {
PARTIAL_SUMS[tid as usize] += PARTIAL_SUMS[(tid + stride) as usize];
}
}
cuda_rust_wasm::runtime::sync_threads();
stride /= 2;
}
if tid == 0 {
output[block::index().x as usize] = unsafe { PARTIAL_SUMS[0] };
}
}
})
}
/// Emits the Rust source for a 2D five-point `f32` stencil kernel.
///
/// The generated `#[kernel]` function is named after `kernel.name` and always has
/// the fixed signature `(input: &[f32], output: &mut [f32], width: u32, height: u32)`.
/// For every interior cell `(x, y)` it writes `0.2 *` the sum of the cell and its
/// north, south, east and west neighbours (row-major index `y * width + x`).
/// Cells on the border (`x == 0`, `y == 0`, last column, last row) are not written.
///
/// The interior test is emitted as `x + 1 < width` / `y + 1 < height` rather than
/// `x < width - 1` / `y < height - 1`: the subtraction underflows `u32` when
/// `width` or `height` is 0, which would panic in checked builds or let the
/// guard pass and index out of bounds in wrapping builds.
///
/// The entries of `kernel.params` are not inspected.
pub fn translate_stencil(&self, kernel: &KernelDef) -> Result<TokenStream> {
    let kernel_name = format_ident!("{}", kernel.name);
    Ok(quote! {
        #[kernel]
        pub fn #kernel_name(
            input: &[f32],
            output: &mut [f32],
            width: u32,
            height: u32,
        ) {
            let x = thread::index().x + block::index().x * block::dim().x;
            let y = thread::index().y + block::index().y * block::dim().y;
            if x > 0 && x + 1 < width && y > 0 && y + 1 < height {
                let idx = (y * width + x) as usize;
                let idx_n = ((y - 1) * width + x) as usize;
                let idx_s = ((y + 1) * width + x) as usize;
                let idx_e = (y * width + (x + 1)) as usize;
                let idx_w = (y * width + (x - 1)) as usize;
                output[idx] = 0.2 * (
                    input[idx] +
                    input[idx_n] +
                    input[idx_s] +
                    input[idx_e] +
                    input[idx_w]
                );
            }
        }
    })
}
/// Classifies `kernel` into one of the [`KernelPattern`] shapes.
///
/// The checks run in a fixed priority order — matrix multiplication, reduction,
/// stencil, vector — and the first one that matches wins. A kernel that matches
/// none of them is reported as [`KernelPattern::Generic`].
pub fn detect_pattern(&self, kernel: &KernelDef) -> KernelPattern {
    if self.is_matrix_pattern(kernel) {
        return KernelPattern::MatrixMul;
    }
    if self.is_reduction_pattern(kernel) {
        return KernelPattern::Reduction;
    }
    if self.is_stencil_pattern(kernel) {
        return KernelPattern::Stencil;
    }
    if self.is_vector_pattern(kernel) {
        return KernelPattern::VectorAdd;
    }
    KernelPattern::Generic
}
/// Returns `true` when the kernel has at least three parameters and its body
/// contains a linear index expression (see `has_linear_indexing`).
fn is_vector_pattern(&self, kernel: &KernelDef) -> bool {
    if kernel.params.len() < 3 {
        return false;
    }
    self.has_linear_indexing(&kernel.body)
}
/// Returns `true` when the kernel has at least five parameters and its body
/// uses both the X and the Y dimension (see `has_2d_indexing`).
fn is_matrix_pattern(&self, kernel: &KernelDef) -> bool {
    if kernel.params.len() < 5 {
        return false;
    }
    self.has_2d_indexing(&kernel.body)
}
/// Returns `true` when the kernel body both declares shared storage and
/// contains a thread synchronisation statement.
fn is_reduction_pattern(&self, kernel: &KernelDef) -> bool {
    let body = &kernel.body;
    if !self.has_shared_memory(body) {
        return false;
    }
    self.has_sync_threads(body)
}
/// Returns `true` when the kernel body contains an index expression with a
/// `± 1` offset (see `has_neighbor_access`).
fn is_stencil_pattern(&self, kernel: &KernelDef) -> bool {
    let body = &kernel.body;
    self.has_neighbor_access(body)
}
/// Returns `true` if any top-level statement of `block` shows linear indexing.
///
/// A variable declaration counts when its initializer itself is a linear index
/// expression; an expression statement counts when a linear index appears
/// anywhere `contains_linear_index` looks. Other statement kinds — including
/// nested blocks and control flow — are not searched.
fn has_linear_indexing(&self, block: &Block) -> bool {
    for stmt in block.statements.iter() {
        let found = match stmt {
            Statement::Expr(expr) => self.contains_linear_index(expr),
            Statement::VarDecl { init: Some(expr), .. } => self.is_linear_index_expr(expr),
            _ => false,
        };
        if found {
            return true;
        }
    }
    false
}
/// Returns `true` when the top-level statements of `block` use both
/// `Dimension::X` and `Dimension::Y` (not necessarily in the same statement).
fn has_2d_indexing(&self, block: &Block) -> bool {
    let mut seen_x = false;
    let mut seen_y = false;
    for stmt in block.statements.iter() {
        seen_x = seen_x || self.uses_dimension(stmt, &Dimension::X);
        seen_y = seen_y || self.uses_dimension(stmt, &Dimension::Y);
    }
    seen_x && seen_y
}
/// Returns `true` if any top-level statement of `block` is a variable
/// declaration whose storage class is `StorageClass::Shared`.
fn has_shared_memory(&self, block: &Block) -> bool {
    block.statements.iter().any(|stmt| {
        matches!(
            stmt,
            Statement::VarDecl { storage: StorageClass::Shared, .. }
        )
    })
}
/// Returns `true` if any top-level statement of `block` is
/// `Statement::SyncThreads`. Nested statements are not searched.
fn has_sync_threads(&self, block: &Block) -> bool {
    for stmt in block.statements.iter() {
        if let Statement::SyncThreads = stmt {
            return true;
        }
    }
    false
}
/// Returns `true` if any statement of `block` (searched recursively by
/// `has_offset_access`) indexes with a `± 1` offset.
fn has_neighbor_access(&self, block: &Block) -> bool {
    for stmt in block.statements.iter() {
        if self.has_offset_access(stmt) {
            return true;
        }
    }
    false
}
/// Returns `true` when `expr` looks like a 1D global thread index.
///
/// `expr` must be an addition. It matches when *either* operand is
/// `threadIdx.x`, or *either* operand is a block offset as recognised by
/// `is_block_offset`. Checking both operand positions matters because addition
/// is commutative: the common CUDA spelling
/// `blockIdx.x * blockDim.x + threadIdx.x` puts the thread index on the right,
/// which a left-only check would miss.
///
/// The test is deliberately lenient (one matching operand is enough), so
/// everything that matched the one-sided form still matches.
fn is_linear_index_expr(&self, expr: &Expression) -> bool {
    match expr {
        Expression::Binary { op: BinaryOp::Add, left, right } => {
            let thread_x_on_left =
                matches!(left.as_ref(), Expression::ThreadIdx(Dimension::X));
            let thread_x_on_right =
                matches!(right.as_ref(), Expression::ThreadIdx(Dimension::X));
            thread_x_on_left
                || thread_x_on_right
                || self.is_block_offset(left)
                || self.is_block_offset(right)
        }
        _ => false,
    }
}
/// Returns `true` if `expr` contains an index operation whose subscript is a
/// linear index expression.
///
/// Binary expressions are searched through both operands; for an index
/// expression only the subscript itself is tested. Every other expression
/// kind yields `false`.
fn contains_linear_index(&self, expr: &Expression) -> bool {
    if let Expression::Index { index, .. } = expr {
        return self.is_linear_index_expr(index);
    }
    if let Expression::Binary { left, right, .. } = expr {
        return self.contains_linear_index(left) || self.contains_linear_index(right);
    }
    false
}
/// Returns `true` when `expr` is the product of `blockIdx.x` and `blockDim.x`.
///
/// Both operand orders are accepted (`blockIdx.x * blockDim.x` and
/// `blockDim.x * blockIdx.x`) because multiplication is commutative and kernel
/// authors write it either way.
fn is_block_offset(&self, expr: &Expression) -> bool {
    match expr {
        Expression::Binary { op: BinaryOp::Mul, left, right } => matches!(
            (left.as_ref(), right.as_ref()),
            (Expression::BlockIdx(Dimension::X), Expression::BlockDim(Dimension::X))
                | (Expression::BlockDim(Dimension::X), Expression::BlockIdx(Dimension::X))
        ),
        _ => false,
    }
}
/// Returns `true` if `stmt` is an initialized variable declaration or an
/// expression statement whose expression references `dim`
/// (see `expr_uses_dimension`). All other statement kinds yield `false`.
fn uses_dimension(&self, stmt: &Statement, dim: &Dimension) -> bool {
    if let Statement::VarDecl { init: Some(expr), .. } = stmt {
        return self.expr_uses_dimension(expr, dim);
    }
    if let Statement::Expr(expr) = stmt {
        return self.expr_uses_dimension(expr, dim);
    }
    false
}
/// Returns `true` if `expr` references `dim` through a thread index, block
/// index, block dimension or grid dimension built-in.
///
/// Binary expressions are searched through both operands; any other
/// expression kind yields `false`.
fn expr_uses_dimension(&self, expr: &Expression, dim: &Dimension) -> bool {
    match expr {
        Expression::Binary { left, right, .. } => {
            if self.expr_uses_dimension(left, dim) {
                return true;
            }
            self.expr_uses_dimension(right, dim)
        }
        Expression::ThreadIdx(d) => d == dim,
        Expression::BlockIdx(d) => d == dim,
        Expression::BlockDim(d) => d == dim,
        Expression::GridDim(d) => d == dim,
        _ => false,
    }
}
/// Returns `true` if `stmt`, or any statement nested inside it, contains an
/// index expression with a `± 1` offset.
///
/// Recurses through `if` branches, blocks, and `for`/`while` bodies. For
/// expression statements and initialized declarations the expression is
/// tested with `expr_has_offset_access`. Other statement kinds yield `false`.
fn has_offset_access(&self, stmt: &Statement) -> bool {
    match stmt {
        Statement::VarDecl { init: Some(expr), .. } => self.expr_has_offset_access(expr),
        Statement::Expr(expr) => self.expr_has_offset_access(expr),
        Statement::Block(inner) => {
            for child in inner.statements.iter() {
                if self.has_offset_access(child) {
                    return true;
                }
            }
            false
        }
        Statement::If { then_branch, else_branch, .. } => {
            if self.has_offset_access(then_branch) {
                return true;
            }
            match else_branch {
                Some(otherwise) => self.has_offset_access(otherwise),
                None => false,
            }
        }
        Statement::For { body, .. } => self.has_offset_access(body),
        Statement::While { body, .. } => self.has_offset_access(body),
        _ => false,
    }
}
/// Returns `true` if `expr` contains an index operation whose subscript has a
/// `± 1` offset (see `has_unit_offset`).
///
/// Binary expressions are searched through both operands; for an index
/// expression only the subscript is tested. Other expression kinds yield `false`.
fn expr_has_offset_access(&self, expr: &Expression) -> bool {
    if let Expression::Index { index, .. } = expr {
        return self.has_unit_offset(index);
    }
    if let Expression::Binary { left, right, .. } = expr {
        return self.expr_has_offset_access(left) || self.expr_has_offset_access(right);
    }
    false
}
/// Returns `true` if the subscript `expr` contains a `+ 1` or `- 1` offset.
///
/// A binary node matches directly when it is `<e> + 1`, `1 + <e>` or `<e> - 1`
/// (`1 - <e>` is not an offset and does not match). If the node itself does not
/// match, both operands are searched recursively, so flattened 2D neighbour
/// indices such as `(y - 1) * width + x` or `y * width + (x + 1)` are
/// recognised in addition to the simple `idx + 1` form.
///
/// Non-binary expressions yield `false`.
fn has_unit_offset(&self, expr: &Expression) -> bool {
    match expr {
        Expression::Binary { op, left, right } => {
            let left_is_one =
                matches!(left.as_ref(), Expression::Literal(Literal::Int(1)));
            let right_is_one =
                matches!(right.as_ref(), Expression::Literal(Literal::Int(1)));
            let direct = match op {
                // Addition is commutative: the literal may sit on either side.
                BinaryOp::Add => left_is_one || right_is_one,
                BinaryOp::Sub => right_is_one,
                _ => false,
            };
            direct || self.has_unit_offset(left) || self.has_unit_offset(right)
        }
        _ => false,
    }
}
}
/// Kernel shapes reported by `KernelTranslator::detect_pattern`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum KernelPattern {
    /// At least three parameters and a linear (1D) index expression in the body.
    VectorAdd,
    /// At least five parameters and use of both the X and Y dimensions.
    MatrixMul,
    /// Shared storage plus a thread synchronisation statement in the body.
    Reduction,
    /// An index expression with a `± 1` offset somewhere in the body.
    Stencil,
    /// Fallback when none of the specific patterns matched.
    Generic,
}
impl Default for KernelTranslator {
fn default() -> Self {
Self::new()
}
}