pub mod matmul {
    pub const TILED_16X16: &str = "\
struct Dims { M: u32, N: u32, K: u32 }
var<push_constant> dims: Dims;

@group(0) @binding(0) var<storage, read> A: array<f32>;
@group(0) @binding(1) var<storage, read> B: array<f32>;
@group(0) @binding(2) var<storage, read_write> C: array<f32>;

var<workgroup> tile_a: array<f32, 256>;
var<workgroup> tile_b: array<f32, 256>;

@compute @workgroup_size(16, 16)
fn main(
    @builtin(local_invocation_id) lid: vec3<u32>,
    @builtin(workgroup_id) wid: vec3<u32>,
) {
    let row = wid.y * 16u + lid.y;
    let col = wid.x * 16u + lid.x;
    let lr = lid.y;
    let lc = lid.x;

    var sum = 0.0;
    let num_tiles = (dims.K + 15u) / 16u;

    for (var t = 0u; t < num_tiles; t++) {
        let a_col = t * 16u + lc;
        if row < dims.M && a_col < dims.K {
            tile_a[lr * 16u + lc] = A[row * dims.K + a_col];
        } else {
            tile_a[lr * 16u + lc] = 0.0;
        }

        let b_row = t * 16u + lr;
        if b_row < dims.K && col < dims.N {
            tile_b[lr * 16u + lc] = B[b_row * dims.N + col];
        } else {
            tile_b[lr * 16u + lc] = 0.0;
        }

        workgroupBarrier();

        for (var k = 0u; k < 16u; k++) {
            sum += tile_a[lr * 16u + k] * tile_b[k * 16u + lc];
        }

        workgroupBarrier();
    }

    if row < dims.M && col < dims.N {
        C[row * dims.N + col] = sum;
    }
}";

    #[cfg(feature = "cuda")]
    pub const TILED_16X16_CUDA: &str = "\
extern \"C\" __global__ void matmul_tiled_16x16(
    const float* A, const float* B, float* C,
    unsigned int M, unsigned int N, unsigned int K
) {
    __shared__ float tile_a[256];
    __shared__ float tile_b[256];

    unsigned int row = blockIdx.y * 16 + threadIdx.y;
    unsigned int col = blockIdx.x * 16 + threadIdx.x;
    unsigned int lr = threadIdx.y;
    unsigned int lc = threadIdx.x;

    float sum = 0.0f;
    unsigned int num_tiles = (K + 15) / 16;

    for (unsigned int t = 0; t < num_tiles; t++) {
        unsigned int a_col = t * 16 + lc;
        tile_a[lr * 16 + lc] = (row < M && a_col < K) ? A[row * K + a_col] : 0.0f;

        unsigned int b_row = t * 16 + lr;
        tile_b[lr * 16 + lc] = (b_row < K && col < N) ? B[b_row * N + col] : 0.0f;

        __syncthreads();

        for (unsigned int k = 0; k < 16; k++) {
            sum += tile_a[lr * 16 + k] * tile_b[k * 16 + lc];
        }

        __syncthreads();
    }

    if (row < M && col < N) {
        C[row * N + col] = sum;
    }
}";

    pub const COARSE_64X64: &str = "\
struct Dims { M: u32, N: u32, K: u32 }
var<push_constant> dims: Dims;

@group(0) @binding(0) var<storage, read> A: array<f32>;
@group(0) @binding(1) var<storage, read> B: array<f32>;
@group(0) @binding(2) var<storage, read_write> C: array<f32>;

var<workgroup> sa: array<f32, 1088>;
var<workgroup> sb: array<f32, 1024>;

@compute @workgroup_size(16, 16)
fn main(
    @builtin(local_invocation_id) lid: vec3<u32>,
    @builtin(local_invocation_index) li: u32,
    @builtin(workgroup_id) wid: vec3<u32>,
) {
    let block_row = wid.y * 64u;
    let block_col = wid.x * 64u;
    let tr = lid.y * 4u;
    let tc = lid.x * 4u;

    var acc: array<f32, 16>;
    for (var i = 0u; i < 16u; i++) { acc[i] = 0.0; }

    let num_k_tiles = (dims.K + 15u) / 16u;

    for (var kt = 0u; kt < num_k_tiles; kt++) {
        // Load A tile [64x16] into padded layout (stride 17)
        for (var x = 0u; x < 4u; x++) {
            let flat = li * 4u + x;
            let r = flat / 16u;
            let c = flat % 16u;
            let gr = block_row + r;
            let gc = kt * 16u + c;
            if gr < dims.M && gc < dims.K {
                sa[r * 17u + c] = A[gr * dims.K + gc];
            } else {
                sa[r * 17u + c] = 0.0;
            }
        }

        // Load B tile [16x64]
        for (var x = 0u; x < 4u; x++) {
            let flat = li * 4u + x;
            let r = flat / 64u;
            let c = flat % 64u;
            let gr = kt * 16u + r;
            let gc = block_col + c;
            if gr < dims.K && gc < dims.N {
                sb[flat] = B[gr * dims.N + gc];
            } else {
                sb[flat] = 0.0;
            }
        }

        workgroupBarrier();

        for (var k = 0u; k < 16u; k++) {
            for (var i = 0u; i < 4u; i++) {
                let a_val = sa[(tr + i) * 17u + k];
                for (var j = 0u; j < 4u; j++) {
                    acc[i * 4u + j] += a_val * sb[k * 64u + tc + j];
                }
            }
        }

        workgroupBarrier();
    }

    for (var i = 0u; i < 4u; i++) {
        for (var j = 0u; j < 4u; j++) {
            let gr = block_row + tr + i;
            let gc = block_col + tc + j;
            if gr < dims.M && gc < dims.N {
                C[gr * dims.N + gc] = acc[i * 4u + j];
            }
        }
    }
}";
    pub const COARSE_8X8: &str = "\
struct Dims { M: u32, N: u32, K: u32 }
var<push_constant> dims: Dims;

@group(0) @binding(0) var<storage, read> A: array<f32>;
@group(0) @binding(1) var<storage, read> B: array<f32>;
@group(0) @binding(2) var<storage, read_write> C: array<f32>;

var<workgroup> sa: array<f32, 2176>;
var<workgroup> sb: array<f32, 2048>;

fn store_row(gr: u32, gc: u32, lo: vec4<f32>, hi: vec4<f32>) {
    if gr >= dims.M { return; }
    let base = gr * dims.N + gc;
    if gc < dims.N { C[base] = lo.x; }
    if gc + 1u < dims.N { C[base + 1u] = lo.y; }
    if gc + 2u < dims.N { C[base + 2u] = lo.z; }
    if gc + 3u < dims.N { C[base + 3u] = lo.w; }
    if gc + 4u < dims.N { C[base + 4u] = hi.x; }
    if gc + 5u < dims.N { C[base + 5u] = hi.y; }
    if gc + 6u < dims.N { C[base + 6u] = hi.z; }
    if gc + 7u < dims.N { C[base + 7u] = hi.w; }
}

@compute @workgroup_size(16, 16)
fn main(
    @builtin(local_invocation_id) lid: vec3<u32>,
    @builtin(local_invocation_index) li: u32,
    @builtin(workgroup_id) wid: vec3<u32>,
) {
    let block_row = wid.y * 128u;
    let block_col = wid.x * 128u;
    let tr = lid.y * 8u;
    let tc = lid.x * 8u;

    // 16 named vec4 accumulators; avoids array-based register spill.
    var r0l = vec4<f32>(0.0); var r0h = vec4<f32>(0.0);
    var r1l = vec4<f32>(0.0); var r1h = vec4<f32>(0.0);
    var r2l = vec4<f32>(0.0); var r2h = vec4<f32>(0.0);
    var r3l = vec4<f32>(0.0); var r3h = vec4<f32>(0.0);
    var r4l = vec4<f32>(0.0); var r4h = vec4<f32>(0.0);
    var r5l = vec4<f32>(0.0); var r5h = vec4<f32>(0.0);
    var r6l = vec4<f32>(0.0); var r6h = vec4<f32>(0.0);
    var r7l = vec4<f32>(0.0); var r7h = vec4<f32>(0.0);

    let num_k_tiles = (dims.K + 15u) / 16u;

    for (var kt = 0u; kt < num_k_tiles; kt++) {
        // Load A tile [128x16]: 2048 elements, 8 per thread, padded stride 17
        for (var x = 0u; x < 8u; x++) {
            let flat = li * 8u + x;
            let r = flat / 16u;
            let c = flat % 16u;
            let gr = block_row + r;
            let gc = kt * 16u + c;
            if gr < dims.M && gc < dims.K {
                sa[r * 17u + c] = A[gr * dims.K + gc];
            } else {
                sa[r * 17u + c] = 0.0;
            }
        }

        // Load B tile [16x128]: 2048 elements, 8 per thread
        for (var x = 0u; x < 8u; x++) {
            let flat = li * 8u + x;
            let r = flat / 128u;
            let c = flat % 128u;
            let gr = kt * 16u + r;
            let gc = block_col + c;
            if gr < dims.K && gc < dims.N {
                sb[flat] = B[gr * dims.N + gc];
            } else {
                sb[flat] = 0.0;
            }
        }

        workgroupBarrier();

        // Inner loop: 8 a-scalar loads + 2 vec4 b-loads + 16 vec4 FMAs per k
        for (var k = 0u; k < 16u; k++) {
            let bk = k * 128u + tc;
            let bl = vec4<f32>(sb[bk], sb[bk+1u], sb[bk+2u], sb[bk+3u]);
            let bh = vec4<f32>(sb[bk+4u], sb[bk+5u], sb[bk+6u], sb[bk+7u]);

            let a0 = sa[(tr   ) * 17u + k]; r0l += a0 * bl; r0h += a0 * bh;
            let a1 = sa[(tr+1u) * 17u + k]; r1l += a1 * bl; r1h += a1 * bh;
            let a2 = sa[(tr+2u) * 17u + k]; r2l += a2 * bl; r2h += a2 * bh;
            let a3 = sa[(tr+3u) * 17u + k]; r3l += a3 * bl; r3h += a3 * bh;
            let a4 = sa[(tr+4u) * 17u + k]; r4l += a4 * bl; r4h += a4 * bh;
            let a5 = sa[(tr+5u) * 17u + k]; r5l += a5 * bl; r5h += a5 * bh;
            let a6 = sa[(tr+6u) * 17u + k]; r6l += a6 * bl; r6h += a6 * bh;
            let a7 = sa[(tr+7u) * 17u + k]; r7l += a7 * bl; r7h += a7 * bh;
        }

        workgroupBarrier();
    }

    let gc = block_col + tc;
    store_row(block_row + tr,      gc, r0l, r0h);
    store_row(block_row + tr + 1u, gc, r1l, r1h);
    store_row(block_row + tr + 2u, gc, r2l, r2h);
    store_row(block_row + tr + 3u, gc, r3l, r3h);
    store_row(block_row + tr + 4u, gc, r4l, r4h);
    store_row(block_row + tr + 5u, gc, r5l, r5h);
    store_row(block_row + tr + 6u, gc, r6l, r6h);
    store_row(block_row + tr + 7u, gc, r7l, r7h);
}";
}

pub mod elementwise {
    pub const BIAS_ADD: &str = "\
struct Dims { N: u32, cols: u32 }
var<push_constant> dims: Dims;

@group(0) @binding(0) var<storage, read> z: array<f32>;
@group(0) @binding(1) var<storage, read> bias: array<f32>;
@group(0) @binding(2) var<storage, read_write> out: array<f32>;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let i = gid.x;
    if i >= dims.N { return; }
    out[i] = z[i] + bias[i % dims.cols];
}";

    #[cfg(feature = "cuda")]
    pub const BIAS_ADD_CUDA: &str = "\
extern \"C\" __global__ void bias_add(
    const float* z, const float* bias, float* out,
    unsigned int N, unsigned int cols
) {
    unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= N) return;
    out[i] = z[i] + bias[i % cols];
}";

    pub const RELU: &str = "\
struct Dims { N: u32 }
var<push_constant> dims: Dims;

@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read_write> out: array<f32>;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let i = gid.x;
    if i >= dims.N { return; }
    out[i] = max(0.0, input[i]);
}";

    #[cfg(feature = "cuda")]
    pub const RELU_CUDA: &str = "\
extern \"C\" __global__ void relu(
    const float* input, float* out,
    unsigned int N
) {
    unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= N) return;
    out[i] = fmaxf(0.0f, input[i]);
}";

    pub const TANH: &str = "\
struct Dims { N: u32 }
var<push_constant> dims: Dims;

@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read_write> out: array<f32>;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let i = gid.x;
    if i >= dims.N { return; }
    out[i] = tanh(input[i]);
}";

    #[cfg(feature = "cuda")]
    pub const TANH_CUDA: &str = "\
extern \"C\" __global__ void tanh_fwd(
    const float* input, float* out,
    unsigned int N
) {
    unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= N) return;
    out[i] = tanhf(input[i]);
}";

    pub const SIGMOID: &str = "\
struct Dims { N: u32 }
var<push_constant> dims: Dims;

@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read_write> out: array<f32>;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let i = gid.x;
    if i >= dims.N { return; }
    out[i] = 1.0 / (1.0 + exp(-input[i]));
}";

    #[cfg(feature = "cuda")]
    pub const SIGMOID_CUDA: &str = "\
extern \"C\" __global__ void sigmoid(
    const float* input, float* out,
    unsigned int N
) {
    unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= N) return;
    out[i] = 1.0f / (1.0f + expf(-input[i]));
}";
}

pub mod backward {
    pub const RELU_BACKWARD: &str = "\
struct Dims { N: u32 }
var<push_constant> dims: Dims;

@group(0) @binding(0) var<storage, read> grad: array<f32>;
@group(0) @binding(1) var<storage, read> z: array<f32>;
@group(0) @binding(2) var<storage, read_write> out: array<f32>;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let i = gid.x;
    if i >= dims.N { return; }
    out[i] = select(0.0, grad[i], z[i] > 0.0);
}";

    #[cfg(feature = "cuda")]
    pub const RELU_BACKWARD_CUDA: &str = "\
extern \"C\" __global__ void relu_backward(
    const float* grad, const float* z, float* out,
    unsigned int N
) {
    unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= N) return;
    out[i] = z[i] > 0.0f ? grad[i] : 0.0f;
}";

    pub const SIGMOID_BACKWARD: &str = "\
struct Dims { N: u32 }
var<push_constant> dims: Dims;

@group(0) @binding(0) var<storage, read> grad: array<f32>;
@group(0) @binding(1) var<storage, read> activated: array<f32>;
@group(0) @binding(2) var<storage, read_write> out: array<f32>;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let i = gid.x;
    if i >= dims.N { return; }
    let a = activated[i];
    out[i] = grad[i] * a * (1.0 - a);
}";

    #[cfg(feature = "cuda")]
    pub const SIGMOID_BACKWARD_CUDA: &str = "\
extern \"C\" __global__ void sigmoid_backward(
    const float* grad, const float* activated, float* out,
    unsigned int N
) {
    unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= N) return;
    float a = activated[i];
    out[i] = grad[i] * a * (1.0f - a);
}";

    pub const TANH_BACKWARD: &str = "\
struct Dims { N: u32 }
var<push_constant> dims: Dims;

@group(0) @binding(0) var<storage, read> grad: array<f32>;
@group(0) @binding(1) var<storage, read> activated: array<f32>;
@group(0) @binding(2) var<storage, read_write> out: array<f32>;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let i = gid.x;
    if i >= dims.N { return; }
    let a = activated[i];
    out[i] = grad[i] * (1.0 - a * a);
}";

    #[cfg(feature = "cuda")]
    pub const TANH_BACKWARD_CUDA: &str = "\
extern \"C\" __global__ void tanh_backward(
    const float* grad, const float* activated, float* out,
    unsigned int N
) {
    unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= N) return;
    float a = activated[i];
    out[i] = grad[i] * (1.0f - a * a);
}";

    pub const TRANSPOSE: &str = "\
struct Dims { rows: u32, cols: u32 }
var<push_constant> dims: Dims;

@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read_write> out: array<f32>;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let i = gid.x;
    let n = dims.rows * dims.cols;
    if i >= n { return; }
    let row = i / dims.cols;
    let col = i % dims.cols;
    out[col * dims.rows + row] = input[i];
}";

    #[cfg(feature = "cuda")]
    pub const TRANSPOSE_CUDA: &str = "\
extern \"C\" __global__ void transpose_2d(
    const float* input, float* out,
    unsigned int rows, unsigned int cols
) {
    unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= rows * cols) return;
    unsigned int row = i / cols;
    unsigned int col = i % cols;
    out[col * rows + row] = input[i];
}";

    pub const SCALE: &str = "\
struct Dims { N: u32, alpha: f32 }
var<push_constant> dims: Dims;

@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read_write> out: array<f32>;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let i = gid.x;
    if i >= dims.N { return; }
    out[i] = input[i] * dims.alpha;
}";

    #[cfg(feature = "cuda")]
    pub const SCALE_CUDA: &str = "\
extern \"C\" __global__ void scale_fwd(
    const float* input, float* out,
    unsigned int N, float alpha
) {
    unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= N) return;
    out[i] = input[i] * alpha;
}";

    pub const REDUCE_COLS: &str = "\
struct Dims { rows: u32, cols: u32, scale: f32 }
var<push_constant> dims: Dims;

@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read_write> out: array<f32>;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let j = gid.x;
    if j >= dims.cols { return; }
    var sum = 0.0;
    for (var i = 0u; i < dims.rows; i++) {
        sum += input[i * dims.cols + j];
    }
    out[j] = sum * dims.scale;
}";

    #[cfg(feature = "cuda")]
    pub const REDUCE_COLS_CUDA: &str = "\
extern \"C\" __global__ void reduce_cols(
    const float* input, float* out,
    unsigned int rows, unsigned int cols, float scale
) {
    unsigned int j = blockIdx.x * blockDim.x + threadIdx.x;
    if (j >= cols) return;
    float sum = 0.0f;
    for (unsigned int i = 0; i < rows; i++) {
        sum += input[i * cols + j];
    }
    out[j] = sum * scale;
}";
}

pub mod distance {
    pub const PAIRWISE_EUCLIDEAN: &str = "\
struct Dims { n_q: u32, n_t: u32, dim: u32 }
var<push_constant> dims: Dims;

@group(0) @binding(0) var<storage, read> queries: array<f32>;
@group(0) @binding(1) var<storage, read> train: array<f32>;
@group(0) @binding(2) var<storage, read_write> dists: array<f32>;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let idx = gid.x;
    let total = dims.n_q * dims.n_t;
    if idx >= total { return; }

    let i = idx / dims.n_t;
    let j = idx % dims.n_t;

    var sum = 0.0;
    let q_base = i * dims.dim;
    let t_base = j * dims.dim;

    for (var d = 0u; d < dims.dim; d++) {
        let diff = queries[q_base + d] - train[t_base + d];
        sum += diff * diff;
    }

    dists[idx] = sum;
}";

    #[cfg(feature = "cuda")]
    pub const PAIRWISE_EUCLIDEAN_CUDA: &str = "\
extern \"C\" __global__ void pairwise_euclidean(
    const float* queries, const float* train, float* dists,
    unsigned int n_q, unsigned int n_t, unsigned int dim
) {
    unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x;
    unsigned int total = n_q * n_t;
    if (idx >= total) return;

    unsigned int i = idx / n_t;
    unsigned int j = idx % n_t;

    float sum = 0.0f;
    unsigned int q_base = i * dim;
    unsigned int t_base = j * dim;

    for (unsigned int d = 0; d < dim; d++) {
        float diff = queries[q_base + d] - train[t_base + d];
        sum += diff * diff;
    }

    dists[idx] = sum;
}";
}