numrs/backend/webgpu/
codegen.rs1use crate::llo::ElementwiseKind;
2
3pub fn elementwise_wgsl(kind: ElementwiseKind, _inputs: Vec<usize>, output_shape: Vec<usize>) -> String {
5 let op = match kind {
6 ElementwiseKind::Add => "a + b",
7 ElementwiseKind::Mul => "a * b",
8 ElementwiseKind::Sub => "a - b",
9 ElementwiseKind::Div => "a / b",
10 ElementwiseKind::Sqrt => "sqrt(a)",
11 ElementwiseKind::Sin => "sin(a)",
12 ElementwiseKind::Cos => "cos(a)",
13 ElementwiseKind::Pow => "pow(a, b)",
14 ElementwiseKind::Abs => "abs(a)",
15 ElementwiseKind::Neg => "-a",
16 ElementwiseKind::Exp => "exp(a)",
17 ElementwiseKind::Log => "log(a)",
18 ElementwiseKind::Tan => "tan(a)",
19 ElementwiseKind::Asin => "asin(a)",
20 ElementwiseKind::Acos => "acos(a)",
21 ElementwiseKind::Atan => "atan(a)",
22 ElementwiseKind::Relu => "max(a, 0.0)",
23 ElementwiseKind::LeakyRelu => "select(0.01 * a, a, a > 0.0)",
24 ElementwiseKind::Sigmoid => "1.0 / (1.0 + exp(-a))",
25 ElementwiseKind::Tanh => "tanh(a)",
26 ElementwiseKind::Softplus => "log(1.0 + exp(a))",
27 };
28
29 format!(r#"// WGSL elementwise kernel (prototype)
30fn compute(a: f32, b: f32) -> f32 {{
31 {}
32}}
33// output shape: {:?}
34"#, op, output_shape)
35}
36
/// Returns fixed placeholder text for a WGSL reduction kernel.
///
/// Every argument is currently ignored, so the same text is returned for any
/// axis, inputs or output shape. The emitted `compute_reduce` has an empty
/// `/* ... */` body: it is a template to be filled in, not a finished shader.
pub fn reduction_wgsl(_axis: Option<usize>, _inputs: Vec<usize>, _output_shape: Vec<usize>) -> String {
    const TEMPLATE: &str = concat!(
        "// WGSL reduction kernel template (prototype)\n",
        "fn compute_reduce(a: array<f32>) -> f32 { /* ... */ }\n"
    );
    String::from(TEMPLATE)
}