/// Generates WGSL source for a straightforward (non-tiled-memory) GEMM kernel.
///
/// The emitted shader computes `c[idx] = alpha * acc + beta * c[idx]`, where
/// `acc` is the dot product of one row of `a` with one column of `b`:
///
/// * `a` is indexed as `a[row * k + i]`,
/// * `b` is indexed as `b[i * n + col]`,
/// * `c` is indexed as `c[row * n + col]`.
///
/// `m`, `n`, `k`, `alpha` and `beta` are read from the uniform `GemmParams`
/// bound at `@binding(3)`. Every invocation produces a single element of `c`;
/// invocations whose `(row, col)` lies outside `m x n` return immediately.
///
/// # Parameters
///
/// * `tile_size` - used for both the x and y dimension of `@workgroup_size`.
///   A value of `0` is clamped to `1`, because WGSL rejects a zero workgroup
///   dimension at shader-creation time.
///
/// # Returns
///
/// The complete WGSL module as a `String`.
///
/// NOTE(review): no upper bound is enforced on `tile_size`; the caller is
/// expected to keep `tile_size * tile_size` within the device's workgroup
/// invocation limit.
pub fn gemm_wgsl(tile_size: u32) -> String {
    // `@workgroup_size(0, 0)` is not a valid WGSL attribute value, so fall
    // back to the smallest legal workgroup instead of emitting a shader that
    // can never compile.
    let ts = tile_size.max(1);

    format!(
        r#"
struct GemmParams {{
    m: u32,
    n: u32,
    k: u32,
    alpha: f32,
    beta: f32,
}}

@group(0) @binding(0) var<storage, read> a: array<f32>;
@group(0) @binding(1) var<storage, read> b: array<f32>;
@group(0) @binding(2) var<storage, read_write> c: array<f32>;
@group(0) @binding(3) var<uniform> params: GemmParams;

@compute @workgroup_size({ts}, {ts})
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {{
    let row = gid.y;
    let col = gid.x;
    if (row >= params.m || col >= params.n) {{ return; }}

    var acc: f32 = 0.0;
    for (var i: u32 = 0u; i < params.k; i = i + 1u) {{
        acc += a[row * params.k + i] * b[i * params.n + col];
    }}

    let idx = row * params.n + col;
    c[idx] = params.alpha * acc + params.beta * c[idx];
}}
"#,
        ts = ts
    )
}
47
/// Generates WGSL source for a unary element-wise kernel.
///
/// The shader reads `x = input[i]` and stores the selected expression into
/// `output[i]`, one element per invocation, with a workgroup size of 256.
/// Invocations whose index is not below `arrayLength(&input)` return early.
///
/// # Parameters
///
/// * `op` - operation name. Recognised names are `relu`, `sigmoid`, `tanh`,
///   `exp`, `log`, `sqrt`, `abs` and `neg`. Any other name yields the
///   identity expression `x`.
///
/// # Returns
///
/// The complete WGSL module as a `String`.
pub fn elementwise_wgsl(op: &str) -> String {
    // (operation name, WGSL expression evaluated on the loaded value `x`).
    const EXPRESSIONS: [(&str, &str); 8] = [
        ("relu", "max(x, 0.0)"),
        ("sigmoid", "1.0 / (1.0 + exp(-x))"),
        ("tanh", "tanh(x)"),
        ("exp", "exp(x)"),
        ("log", "log(x)"),
        ("sqrt", "sqrt(x)"),
        ("abs", "abs(x)"),
        ("neg", "-x"),
    ];

    // Unknown operation names deliberately degrade to a plain copy.
    let expression = EXPRESSIONS
        .iter()
        .copied()
        .find(|&(name, _)| name == op)
        .map_or("x", |(_, wgsl)| wgsl);

    format!(
        r#"
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read_write> output: array<f32>;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {{
    let i = gid.x;
    if (i >= arrayLength(&input)) {{ return; }}
    let x = input[i];
    output[i] = {op};
}}
"#,
        op = expression
    )
}
86
/// Generates WGSL source for the first pass of a parallel reduction.
///
/// Each workgroup of 256 invocations loads one element of `input` per
/// invocation into workgroup memory (or the operation's identity value when
/// the global index is not below `params.n`), combines the 256 slots with a
/// halving-stride tree, and has invocation 0 store the workgroup's result in
/// `partial_sums[wgid.x]`.
///
/// # Parameters
///
/// * `op` - `"max"` and `"min"` select the corresponding WGSL builtin; every
///   other name (including `"sum"` and `"mean"`) generates an additive
///   reduction. Any scaling needed for a mean is not part of this shader.
///
/// # Returns
///
/// The complete WGSL module as a `String`.
///
/// NOTE(review): only the x component of the builtin ids is used, so a
/// one-dimensional dispatch is presumably expected - confirm with the caller.
pub fn reduction_wgsl(op: &str) -> String {
    // The padding value must never win against a real element. `3.4028234e38`
    // rounds to the largest finite f32 without exceeding it, so it is a valid
    // WGSL literal and is the identity of `max`/`min` over all finite inputs.
    // (The previous +/-1e38 padding produced wrong results for inputs beyond
    // that magnitude, e.g. buffers masked with `f32::MIN`.)
    let (neutral, combine) = match op {
        "max" => ("f32(-3.4028234e38)", "max(acc, val)"),
        "min" => ("f32(3.4028234e38)", "min(acc, val)"),
        _ => ("f32(0.0)", "acc + val"),
    };

    format!(
        r#"
// Reduction params: total element count.
struct ReduceParams {{
    n: u32,
}}

@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read_write> partial_sums: array<f32>;
@group(0) @binding(2) var<uniform> params: ReduceParams;

var<workgroup> shared_data: array<f32, 256>;

@compute @workgroup_size(256)
fn main(
    @builtin(global_invocation_id) gid: vec3<u32>,
    @builtin(local_invocation_id) lid: vec3<u32>,
    @builtin(workgroup_id) wgid: vec3<u32>,
) {{
    let tid = lid.x;
    let global_idx = gid.x;

    // Load or use neutral element when out of range.
    if (global_idx < params.n) {{
        shared_data[tid] = input[global_idx];
    }} else {{
        shared_data[tid] = {neutral};
    }}
    workgroupBarrier();

    // Parallel tree reduction within the workgroup.
    var stride: u32 = 128u;
    loop {{
        if (stride == 0u) {{ break; }}
        if (tid < stride) {{
            let acc = shared_data[tid];
            let val = shared_data[tid + stride];
            shared_data[tid] = {combine};
        }}
        workgroupBarrier();
        stride = stride >> 1u;
    }}

    // Thread 0 writes the workgroup result to the partial-sums buffer.
    if (tid == 0u) {{
        partial_sums[wgid.x] = shared_data[0];
    }}
}}
"#,
        neutral = neutral,
        combine = combine,
    )
}
161
/// Generates WGSL source for the final pass of a parallel reduction.
///
/// A single workgroup of 256 invocations folds the `params.num_groups`
/// entries of `partial_sums` into one value, which invocation 0 stores in
/// `output[0]`.
///
/// Every invocation first accumulates the entries at indices `tid`,
/// `tid + 256`, `tid + 512`, ... so that all partial results are consumed even
/// when `num_groups` exceeds the workgroup size (previously entries past
/// index 255 were silently ignored). The 256 per-invocation accumulators are
/// then combined with a halving-stride tree in workgroup memory.
///
/// # Parameters
///
/// * `op` - `"max"` and `"min"` select the corresponding WGSL builtin; every
///   other name (including `"sum"` and `"mean"`) generates an additive
///   reduction.
///
/// # Returns
///
/// The complete WGSL module as a `String`.
///
/// NOTE(review): the shader only uses `local_invocation_id`, so it presumably
/// has to be dispatched as exactly one workgroup - confirm with the caller.
pub fn reduction_final_wgsl(op: &str) -> String {
    // `3.4028234e38` rounds to the largest finite f32 without exceeding it,
    // which makes it the identity of `max`/`min` over all finite inputs.
    let (neutral, combine) = match op {
        "max" => ("f32(-3.4028234e38)", "max(acc, val)"),
        "min" => ("f32(3.4028234e38)", "min(acc, val)"),
        _ => ("f32(0.0)", "acc + val"),
    };

    format!(
        r#"
struct FinalReduceParams {{
    num_groups: u32,
}}

@group(0) @binding(0) var<storage, read> partial_sums: array<f32>;
@group(0) @binding(1) var<storage, read_write> output: array<f32>;
@group(0) @binding(2) var<uniform> params: FinalReduceParams;

var<workgroup> shared_data: array<f32, 256>;

@compute @workgroup_size(256)
fn main(
    @builtin(local_invocation_id) lid: vec3<u32>,
) {{
    let tid = lid.x;

    // Strided accumulation: consume every partial result, 256 at a time.
    var acc: f32 = {neutral};
    for (var i: u32 = tid; i < params.num_groups; i = i + 256u) {{
        let val = partial_sums[i];
        acc = {combine};
    }}
    shared_data[tid] = acc;
    workgroupBarrier();

    // Tree reduction; `acc` always mirrors this invocation's own slot.
    var stride: u32 = 128u;
    loop {{
        if (stride == 0u) {{ break; }}
        if (tid < stride) {{
            let val = shared_data[tid + stride];
            acc = {combine};
            shared_data[tid] = acc;
        }}
        workgroupBarrier();
        stride = stride >> 1u;
    }}

    if (tid == 0u) {{
        output[0] = shared_data[0];
    }}
}}
"#,
        neutral = neutral,
        combine = combine,
    )
}
220
#[cfg(test)]
mod tests {
    use super::*;

    /// The GEMM shader carries the workgroup attribute and the parameter struct.
    #[test]
    fn wgsl_gemm_contains_workgroup() {
        let shader = gemm_wgsl(16);
        for needle in ["@compute @workgroup_size(16, 16)", "GemmParams", "alpha", "beta"] {
            assert!(shader.contains(needle));
        }
    }

    /// The requested tile size is embedded in both workgroup dimensions.
    #[test]
    fn wgsl_gemm_tile_size_embedded() {
        for (tile, attribute) in [(8, "@workgroup_size(8, 8)"), (32, "@workgroup_size(32, 32)")] {
            assert!(gemm_wgsl(tile).contains(attribute));
        }
    }

    /// `relu` maps to a clamp at zero.
    #[test]
    fn wgsl_elementwise_relu_contains_max() {
        assert!(elementwise_wgsl("relu").contains("max(x, 0.0)"));
    }

    /// Every remaining operation name maps to its WGSL expression, and an
    /// unknown name falls back to the identity copy.
    #[test]
    fn wgsl_elementwise_all_ops() {
        let expectations = [
            ("sigmoid", "exp(-x)"),
            ("tanh", "tanh(x)"),
            ("exp", "exp(x)"),
            ("log", "log(x)"),
            ("sqrt", "sqrt(x)"),
            ("abs", "abs(x)"),
            ("neg", "-x"),
            ("identity_op", "output[i] = x;"),
        ];
        for (op, fragment) in expectations {
            assert!(elementwise_wgsl(op).contains(fragment));
        }
    }

    /// The additive reduction combines with `+` and synchronises the workgroup.
    #[test]
    fn wgsl_reduction_sum_contains_addition() {
        let shader = reduction_wgsl("sum");
        assert!(shader.contains("acc + val"));
        assert!(shader.contains("workgroupBarrier"));
    }

    /// `max` reductions use the WGSL `max` builtin.
    #[test]
    fn wgsl_reduction_max_uses_max_fn() {
        assert!(reduction_wgsl("max").contains("max(acc, val)"));
    }

    /// `min` reductions use the WGSL `min` builtin.
    #[test]
    fn wgsl_reduction_min_uses_min_fn() {
        assert!(reduction_wgsl("min").contains("min(acc, val)"));
    }

    /// `mean` generates exactly the same shader as `sum`.
    #[test]
    fn wgsl_reduction_mean_same_as_sum() {
        assert_eq!(reduction_wgsl("sum"), reduction_wgsl("mean"));
    }

    /// The final pass reads the group count and writes a single output slot.
    #[test]
    fn wgsl_reduction_final_sum() {
        let shader = reduction_final_wgsl("sum");
        assert!(shader.contains("num_groups"));
        assert!(shader.contains("output[0]"));
    }
}