numrs/backend/webgpu/codegen.rs

1use crate::llo::ElementwiseKind;
2
3/// Generate a simple WGSL elementwise kernel for Add/Mul
4pub fn elementwise_wgsl(kind: ElementwiseKind, _inputs: Vec<usize>, output_shape: Vec<usize>) -> String {
5    let op = match kind {
6        ElementwiseKind::Add => "a + b",
7        ElementwiseKind::Mul => "a * b",
8        ElementwiseKind::Sub => "a - b",
9        ElementwiseKind::Div => "a / b",
10        ElementwiseKind::Sqrt => "sqrt(a)",
11        ElementwiseKind::Sin => "sin(a)",
12        ElementwiseKind::Cos => "cos(a)",
13        ElementwiseKind::Pow => "pow(a, b)",
14        ElementwiseKind::Abs => "abs(a)",
15        ElementwiseKind::Neg => "-a",
16        ElementwiseKind::Exp => "exp(a)",
17        ElementwiseKind::Log => "log(a)",
18        ElementwiseKind::Tan => "tan(a)",
19        ElementwiseKind::Asin => "asin(a)",
20        ElementwiseKind::Acos => "acos(a)",
21        ElementwiseKind::Atan => "atan(a)",
22        ElementwiseKind::Relu => "max(a, 0.0)",
23        ElementwiseKind::LeakyRelu => "select(0.01 * a, a, a > 0.0)",
24        ElementwiseKind::Sigmoid => "1.0 / (1.0 + exp(-a))",
25        ElementwiseKind::Tanh => "tanh(a)",
26        ElementwiseKind::Softplus => "log(1.0 + exp(a))",
27    };
28
29    format!(r#"// WGSL elementwise kernel (prototype)
30fn compute(a: f32, b: f32) -> f32 {{
31    {}
32}}
33// output shape: {:?}
34"#, op, output_shape)
35}
36
/// Return the WGSL text for a reduction kernel.
///
/// Prototype stub: all three arguments are ignored, and the same fixed template
/// is returned on every call.
///
/// NOTE(review): the template is a placeholder and does not look like valid WGSL yet.
/// `compute_reduce` declares `-> f32` but its body is only a comment, and a
/// runtime-sized `array<f32>` is not allowed as a WGSL function parameter.
/// TODO: replace before passing this to a shader compiler.
pub fn reduction_wgsl(
    _axis: Option<usize>,
    _inputs: Vec<usize>,
    _output_shape: Vec<usize>,
) -> String {
    const TEMPLATE: &str = concat!(
        "// WGSL reduction kernel template (prototype)\n",
        "fn compute_reduce(a: array<f32>) -> f32 { /* ... */ }\n",
    );
    String::from(TEMPLATE)
}