Skip to main content

oxicuda_webgpu/
shader.rs

1//! WGSL shader source generation for common compute kernels.
2//!
3//! Each function returns a complete, self-contained WGSL source string
4//! suitable for passing to `device.create_shader_module()`.
5
/// Generate WGSL source for a GEMM kernel: `C = alpha * A * B + beta * C`.
///
/// Each invocation computes one element of the `m × n` output `C` as a plain
/// dot product over `k`. The kernel does not stage tiles of `A`/`B` in
/// workgroup memory; `tile_size` only sets the workgroup shape
/// (`tile_size × tile_size` invocations). The host must dispatch at least
/// `ceil(n / tile_size)` workgroups along x (columns) and
/// `ceil(m / tile_size)` along y (rows). Invocations outside `m × n` exit
/// immediately.
///
/// All matrices are dense and row-major: `A` is `m × k`, `B` is `k × n`,
/// `C` is `m × n`.
///
/// Following the BLAS/cuBLAS convention, `C` is not read when `beta == 0`.
/// Stale NaN/Inf values in a reused output buffer therefore cannot leak
/// into the result.
///
/// # Arguments
///
/// * `tile_size` — workgroup edge length (e.g. 8 or 16). The workgroup has
///   `tile_size²` invocations. WebGPU's default
///   `maxComputeInvocationsPerWorkgroup` limit is 256, so values above 16
///   need a device created with raised limits. A value of 0 yields an
///   invalid shader.
pub fn gemm_wgsl(tile_size: u32) -> String {
    format!(
        r#"
struct GemmParams {{
    m:     u32,
    n:     u32,
    k:     u32,
    alpha: f32,
    beta:  f32,
}}

@group(0) @binding(0) var<storage, read>       a:      array<f32>;
@group(0) @binding(1) var<storage, read>       b:      array<f32>;
@group(0) @binding(2) var<storage, read_write> c:      array<f32>;
@group(0) @binding(3) var<uniform>             params: GemmParams;

@compute @workgroup_size({ts}, {ts})
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {{
    let row = gid.y;
    let col = gid.x;
    if (row >= params.m || col >= params.n) {{ return; }}

    var acc: f32 = 0.0;
    for (var i: u32 = 0u; i < params.k; i = i + 1u) {{
        acc += a[row * params.k + i] * b[i * params.n + col];
    }}

    let idx = row * params.n + col;
    var result: f32 = params.alpha * acc;
    // BLAS/cuBLAS convention: when beta == 0, C is output-only and is never
    // read, so NaN/Inf left in a reused buffer cannot poison the result.
    if (params.beta != 0.0) {{
        result += params.beta * c[idx];
    }}
    c[idx] = result;
}}
"#,
        ts = tile_size
    )
}
47
/// Generate WGSL source for an element-wise unary operation.
///
/// Invocation `i` (the x component of its global invocation id) reads
/// `input[i]`, applies the operation, and writes the result to `output[i]`.
///
/// No element count is passed in. The shader processes
/// `min(arrayLength(&input), arrayLength(&output))` elements, so the range
/// is bounded by the sizes of the bound buffers. Invocations past that count
/// exit immediately. The workgroup size is 256, so dispatch at least
/// `ceil(len / 256)` workgroups along x.
///
/// # Arguments
///
/// * `op` — one of: `"relu"`, `"sigmoid"`, `"tanh"`, `"exp"`, `"log"`,
///   `"sqrt"`, `"abs"`, `"neg"`. Unknown ops are treated as identity.
pub fn elementwise_wgsl(op: &str) -> String {
    let op_expr = match op {
        "relu" => "max(x, 0.0)",
        "sigmoid" => "1.0 / (1.0 + exp(-x))",
        "tanh" => "tanh(x)",
        "exp" => "exp(x)",
        "log" => "log(x)",
        "sqrt" => "sqrt(x)",
        "abs" => "abs(x)",
        "neg" => "-x",
        _ => "x",
    };

    format!(
        r#"
@group(0) @binding(0) var<storage, read>       input:  array<f32>;
@group(0) @binding(1) var<storage, read_write> output: array<f32>;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {{
    let i = gid.x;
    // Guard on both buffers: WGSL does not guarantee that an out-of-bounds
    // store is dropped (it may land on another element of the buffer), so a
    // shorter `output` must limit the range as well.
    let len = min(arrayLength(&input), arrayLength(&output));
    if (i >= len) {{ return; }}
    let x = input[i];
    output[i] = {op};
}}
"#,
        op = op_expr
    )
}
86
/// Generate WGSL source for the first pass of a parallel reduction.
///
/// This is pass one of a two-pass reduction:
///
/// 1. Each workgroup of 256 invocations loads 256 consecutive elements,
///    starting at `input[wgid.x * 256]`.
/// 2. It tree-reduces them in workgroup memory.
/// 3. It writes the workgroup's result to `partial_sums[wgid.x]`.
///
/// Dispatch `ceil(n / 256)` workgroups along x and give `partial_sums` one
/// slot per workgroup. Then run the shader from [`reduction_final_wgsl`] to
/// fold the partial results into a scalar.
///
/// Slots at or past `params.n` are padded with the operation's identity:
/// `0.0` for sum, and `f32::MIN` / `f32::MAX` for max / min. WGSL has no
/// infinity literal, so the most extreme finite values are used. Padding
/// therefore never changes the result for finite inputs. An input made
/// only of `-inf` (max) or `+inf` (min) may come back as `f32::MIN` /
/// `f32::MAX` instead of the infinity.
///
/// # Arguments
///
/// * `op` — one of: `"sum"`, `"max"`, `"min"`, `"mean"`. `"mean"` behaves
///   like `"sum"` in the shader; the CPU is responsible for dividing by N.
///   Unknown ops fall back to `"sum"`.
pub fn reduction_wgsl(op: &str) -> String {
    // Identity element and combine expression for each operation. The
    // max/min identities must not beat any finite input, so they are
    // exactly -f32::MAX / +f32::MAX. `3.4028234663852886e38` is the decimal
    // form of f32::MAX that round-trips exactly through f64 (WGSL's
    // AbstractFloat), so the f32 conversion is exact.
    let (neutral, combine) = match op {
        "max" => ("f32(-3.4028234663852886e38)", "max(acc, val)"),
        "min" => ("f32(3.4028234663852886e38)", "min(acc, val)"),
        // "sum" and "mean" use the same reduction body.
        _ => ("f32(0.0)", "acc + val"),
    };

    format!(
        r#"
// Reduction params: total element count.
struct ReduceParams {{
    n: u32,
}}

@group(0) @binding(0) var<storage, read>       input:        array<f32>;
@group(0) @binding(1) var<storage, read_write> partial_sums: array<f32>;
@group(0) @binding(2) var<uniform>             params:       ReduceParams;

var<workgroup> shared_data: array<f32, 256>;

@compute @workgroup_size(256)
fn main(
    @builtin(global_invocation_id) gid:  vec3<u32>,
    @builtin(local_invocation_id)  lid:  vec3<u32>,
    @builtin(workgroup_id)         wgid: vec3<u32>,
) {{
    let tid         = lid.x;
    let global_idx  = gid.x;

    // Load or use neutral element when out of range.
    if (global_idx < params.n) {{
        shared_data[tid] = input[global_idx];
    }} else {{
        shared_data[tid] = {neutral};
    }}
    workgroupBarrier();

    // Parallel tree reduction within the workgroup.
    var stride: u32 = 128u;
    loop {{
        if (stride == 0u) {{ break; }}
        if (tid < stride) {{
            let acc = shared_data[tid];
            let val = shared_data[tid + stride];
            shared_data[tid] = {combine};
        }}
        workgroupBarrier();
        stride = stride >> 1u;
    }}

    // Thread 0 writes the workgroup result to the partial-sums buffer.
    if (tid == 0u) {{
        partial_sums[wgid.x] = shared_data[0];
    }}
}}
"#,
        neutral = neutral,
        combine = combine,
    )
}
161
/// Generate WGSL for the final scalar reduction of partial results.
///
/// Reduces `partial_sums[0..num_groups]` to a single value stored at
/// `output[0]`. The input is what the shader from [`reduction_wgsl`] writes
/// for the same `op`. Dispatch exactly one workgroup of 256 threads; every
/// workgroup writes `output[0]`, so more than one would race.
///
/// Any `num_groups` is supported, including more than 256, in two steps:
///
/// 1. Thread `tid` folds `partial_sums[tid]`, `partial_sums[tid + 256]`, …
///    into a private accumulator.
/// 2. The 256 accumulators are tree-reduced in workgroup memory.
///
/// The max/min identities are `f32::MIN` / `f32::MAX`, because WGSL has no
/// infinity literal. Finite inputs are unaffected by them.
///
/// # Arguments
///
/// * `op` — one of: `"sum"`, `"max"`, `"min"`, `"mean"`. `"mean"` reduces
///   like `"sum"`; the CPU divides by N. Unknown ops fall back to `"sum"`.
pub fn reduction_final_wgsl(op: &str) -> String {
    // Identity element and combine expression for each operation. The
    // max/min identities are exactly -f32::MAX / +f32::MAX (the decimal
    // literal round-trips exactly through f64), so they never beat a finite
    // input.
    let (neutral, combine) = match op {
        "max" => ("f32(-3.4028234663852886e38)", "max(acc, val)"),
        "min" => ("f32(3.4028234663852886e38)", "min(acc, val)"),
        _ => ("f32(0.0)", "acc + val"),
    };

    format!(
        r#"
struct FinalReduceParams {{
    num_groups: u32,
}}

@group(0) @binding(0) var<storage, read>       partial_sums: array<f32>;
@group(0) @binding(1) var<storage, read_write> output:       array<f32>;
@group(0) @binding(2) var<uniform>             params:       FinalReduceParams;

var<workgroup> shared_data: array<f32, 256>;

fn reduce_op(acc: f32, val: f32) -> f32 {{
    return {combine};
}}

@compute @workgroup_size(256)
fn main(
    @builtin(local_invocation_id) lid: vec3<u32>,
) {{
    let tid = lid.x;

    // Strided pre-pass: thread `tid` folds partial_sums[tid], [tid + 256], ...
    // so that partial results beyond the first 256 are not dropped.
    var local_acc: f32 = {neutral};
    for (var i: u32 = tid; i < params.num_groups; i = i + 256u) {{
        local_acc = reduce_op(local_acc, partial_sums[i]);
    }}
    shared_data[tid] = local_acc;
    workgroupBarrier();

    // Parallel tree reduction of the 256 per-thread accumulators.
    var stride: u32 = 128u;
    loop {{
        if (stride == 0u) {{ break; }}
        if (tid < stride) {{
            shared_data[tid] = reduce_op(shared_data[tid], shared_data[tid + stride]);
        }}
        workgroupBarrier();
        stride = stride >> 1u;
    }}

    if (tid == 0u) {{
        output[0] = shared_data[0];
    }}
}}
"#,
        neutral = neutral,
        combine = combine,
    )
}
220
#[cfg(test)]
mod tests {
    use super::*;

    /// Assert that `src` has balanced, properly nested curly braces.
    ///
    /// Every literal brace in the `format!` templates must be written as
    /// `{{` / `}}`; a slip there produces unbalanced WGSL that only fails at
    /// shader-module creation time on a real device.
    fn assert_braces_balanced(src: &str) {
        let mut depth: i64 = 0;
        for ch in src.chars() {
            match ch {
                '{' => depth += 1,
                '}' => {
                    depth -= 1;
                    assert!(depth >= 0, "closing brace without opener in:\n{}", src);
                }
                _ => {}
            }
        }
        assert_eq!(depth, 0, "unbalanced braces in:\n{}", src);
    }

    const ELEMENTWISE_OPS: [&str; 9] = [
        "relu", "sigmoid", "tanh", "exp", "log", "sqrt", "abs", "neg", "unknown",
    ];
    const REDUCTION_OPS: [&str; 5] = ["sum", "mean", "max", "min", "unknown"];

    #[test]
    fn wgsl_gemm_contains_workgroup() {
        let src = gemm_wgsl(16);
        assert!(src.contains("@compute @workgroup_size(16, 16)"));
        assert!(src.contains("GemmParams"));
        assert!(src.contains("alpha"));
        assert!(src.contains("beta"));
    }

    #[test]
    fn wgsl_gemm_tile_size_embedded() {
        let src8 = gemm_wgsl(8);
        assert!(src8.contains("@workgroup_size(8, 8)"));
        let src32 = gemm_wgsl(32);
        assert!(src32.contains("@workgroup_size(32, 32)"));
    }

    #[test]
    fn wgsl_elementwise_relu_contains_max() {
        let src = elementwise_wgsl("relu");
        assert!(src.contains("max(x, 0.0)"));
    }

    #[test]
    fn wgsl_elementwise_all_ops() {
        assert!(elementwise_wgsl("sigmoid").contains("exp(-x)"));
        assert!(elementwise_wgsl("tanh").contains("tanh(x)"));
        assert!(elementwise_wgsl("exp").contains("exp(x)"));
        assert!(elementwise_wgsl("log").contains("log(x)"));
        assert!(elementwise_wgsl("sqrt").contains("sqrt(x)"));
        assert!(elementwise_wgsl("abs").contains("abs(x)"));
        assert!(elementwise_wgsl("neg").contains("-x"));
        // Unknown op is identity.
        assert!(elementwise_wgsl("identity_op").contains("output[i] = x;"));
    }

    #[test]
    fn wgsl_elementwise_workgroup_size() {
        for op in ELEMENTWISE_OPS {
            assert!(elementwise_wgsl(op).contains("@compute @workgroup_size(256)"));
        }
    }

    #[test]
    fn wgsl_reduction_sum_contains_addition() {
        let src = reduction_wgsl("sum");
        assert!(src.contains("acc + val"));
        assert!(src.contains("workgroupBarrier"));
    }

    #[test]
    fn wgsl_reduction_max_uses_max_fn() {
        let src = reduction_wgsl("max");
        assert!(src.contains("max(acc, val)"));
    }

    #[test]
    fn wgsl_reduction_min_uses_min_fn() {
        let src = reduction_wgsl("min");
        assert!(src.contains("min(acc, val)"));
    }

    #[test]
    fn wgsl_reduction_mean_same_as_sum() {
        // "mean" divides on the CPU side; the shader is identical to sum.
        let sum_src = reduction_wgsl("sum");
        let mean_src = reduction_wgsl("mean");
        assert_eq!(sum_src, mean_src);
    }

    #[test]
    fn wgsl_reduction_unknown_falls_back_to_sum() {
        assert_eq!(reduction_wgsl("bogus"), reduction_wgsl("sum"));
        assert_eq!(reduction_final_wgsl("bogus"), reduction_final_wgsl("sum"));
    }

    #[test]
    fn wgsl_reduction_final_sum() {
        let src = reduction_final_wgsl("sum");
        assert!(src.contains("num_groups"));
        assert!(src.contains("output[0]"));
    }

    #[test]
    fn wgsl_reduction_final_min_max() {
        assert!(reduction_final_wgsl("max").contains("max(acc, val)"));
        assert!(reduction_final_wgsl("min").contains("min(acc, val)"));
        assert_eq!(reduction_final_wgsl("mean"), reduction_final_wgsl("sum"));
    }

    #[test]
    fn wgsl_all_generators_have_balanced_braces() {
        for ts in [8, 16, 32] {
            assert_braces_balanced(&gemm_wgsl(ts));
        }
        for op in ELEMENTWISE_OPS {
            assert_braces_balanced(&elementwise_wgsl(op));
        }
        for op in REDUCTION_OPS {
            assert_braces_balanced(&reduction_wgsl(op));
            assert_braces_balanced(&reduction_final_wgsl(op));
        }
    }

    #[test]
    fn wgsl_all_generators_have_single_entry_point() {
        let mut sources = vec![gemm_wgsl(16)];
        sources.extend(ELEMENTWISE_OPS.iter().map(|op| elementwise_wgsl(op)));
        sources.extend(REDUCTION_OPS.iter().map(|op| reduction_wgsl(op)));
        sources.extend(REDUCTION_OPS.iter().map(|op| reduction_final_wgsl(op)));
        for src in &sources {
            assert_eq!(src.matches("@compute").count(), 1, "in:\n{}", src);
            assert_eq!(src.matches("fn main(").count(), 1, "in:\n{}", src);
        }
    }
}