1
2
3pub mod kernel {
17 use cjc_repro::{kahan_sum_f64, KahanAccumulatorF64};
18
19 #[inline]
26 pub fn matmul_raw(
27 a: &[f64], b: &[f64], c: &mut [f64],
28 m: usize, k: usize, n: usize,
29 ) {
30 debug_assert_eq!(a.len(), m * k);
31 debug_assert_eq!(b.len(), k * n);
32 debug_assert_eq!(c.len(), m * n);
33 for i in 0..m {
34 for j in 0..n {
35 let mut acc = KahanAccumulatorF64::new();
36 for p in 0..k {
37 acc.add(a[i * k + p] * b[p * n + j]);
38 }
39 c[i * n + j] = acc.finalize();
40 }
41 }
42 }
43
/// Row-wise numerically-stable softmax over `outer` rows of length `n`.
///
/// For each row: subtracts the row maximum before exponentiating,
/// accumulates the normalizer with Kahan compensation, then divides.
/// If the normalizer is exactly zero the row is filled with the
/// uniform distribution `1/n`.
#[inline]
pub fn softmax_raw(data: &[f64], out: &mut [f64], outer: usize, n: usize) {
    debug_assert_eq!(data.len(), outer * n);
    debug_assert_eq!(out.len(), outer * n);
    for row in 0..outer {
        let base = row * n;
        let src = &data[base..base + n];
        let dst = &mut out[base..base + n];

        // Row maximum, for numerical stability of exp().
        let row_max = src
            .iter()
            .fold(f64::NEG_INFINITY, |m, &v| if v > m { v } else { m });

        // Exponentiate and accumulate the sum with Kahan compensation.
        let mut total = 0.0f64;
        let mut compensation = 0.0f64;
        for (d, &v) in dst.iter_mut().zip(src.iter()) {
            let e = (v - row_max).exp();
            *d = e;
            let adjusted = e - compensation;
            let next = total + adjusted;
            compensation = (next - total) - adjusted;
            total = next;
        }

        if total == 0.0 {
            // Degenerate row: fall back to a uniform distribution.
            let uniform = 1.0 / n as f64;
            for d in dst.iter_mut() {
                *d = uniform;
            }
        } else {
            for d in dst.iter_mut() {
                *d /= total;
            }
        }
    }
}
87
88 #[inline]
90 pub fn linear_raw(
91 x: &[f64], w: &[f64], bias: &[f64], out: &mut [f64],
92 outer: usize, in_f: usize, out_f: usize,
93 ) {
94 debug_assert_eq!(x.len(), outer * in_f);
95 debug_assert_eq!(w.len(), out_f * in_f);
96 debug_assert_eq!(bias.len(), out_f);
97 debug_assert_eq!(out.len(), outer * out_f);
98 for row in 0..outer {
99 let x_start = row * in_f;
100 let x_slice = &x[x_start..x_start + in_f];
101 let y_start = row * out_f;
102 for j in 0..out_f {
103 let w_start = j * in_f;
104 let mut acc = KahanAccumulatorF64::new();
105 for p in 0..in_f {
106 acc.add(x_slice[p] * w[w_start + p]);
107 }
108 out[y_start + j] = acc.finalize() + bias[j];
109 }
110 }
111 }
112
113 #[inline]
118 pub fn layer_norm_raw(
119 data: &[f64], gamma: &[f64], beta: &[f64], out: &mut [f64],
120 outer: usize, n: usize, eps: f64,
121 ) {
122 debug_assert_eq!(data.len(), outer * n);
123 debug_assert_eq!(gamma.len(), n);
124 debug_assert_eq!(beta.len(), n);
125 debug_assert_eq!(out.len(), outer * n);
126 for row in 0..outer {
127 let start = row * n;
128 let slice = &data[start..start + n];
129
130 let mean = kahan_sum_f64(slice) / n as f64;
132
133 let diffs: Vec<f64> = slice.iter().map(|&x| (x - mean) * (x - mean)).collect();
135 let var = kahan_sum_f64(&diffs) / n as f64;
136 let inv_std = 1.0 / (var + eps).sqrt();
137
138 for i in 0..n {
139 out[start + i] = (slice[i] - mean) * inv_std * gamma[i] + beta[i];
140 }
141 }
142 }
143
/// Element-wise ReLU: copies `data` into `out`, clamping negative
/// values (and NaN, which fails the `>` comparison) to 0.0.
#[inline]
pub fn relu_raw(data: &[f64], out: &mut [f64]) {
    debug_assert_eq!(data.len(), out.len());
    for (&v, slot) in data.iter().zip(out.iter_mut()) {
        *slot = if v > 0.0 { v } else { 0.0 };
    }
}
152
/// Element-wise GELU using the tanh approximation:
/// `0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`.
#[inline]
pub fn gelu_raw(data: &[f64], out: &mut [f64]) {
    debug_assert_eq!(data.len(), out.len());
    // sqrt(2/pi), the tanh-approximation constant.
    let coeff: f64 = (2.0 / std::f64::consts::PI).sqrt();
    for (dst, &x) in out.iter_mut().zip(data) {
        let cubic = 0.044715 * x * x * x;
        let arg = coeff * (x + cubic);
        *dst = 0.5 * x * (1.0 + arg.tanh());
    }
}
163
164 pub fn conv1d_raw(
172 signal: &[f64], filters: &[f64], bias: &[f64], out: &mut [f64],
173 signal_len: usize, out_channels: usize, kernel_size: usize,
174 ) {
175 debug_assert!(signal_len >= kernel_size);
176 let out_len = signal_len - kernel_size + 1;
177 debug_assert_eq!(signal.len(), signal_len);
178 debug_assert_eq!(filters.len(), out_channels * kernel_size);
179 debug_assert_eq!(bias.len(), out_channels);
180 debug_assert_eq!(out.len(), out_channels * out_len);
181
182 for ch in 0..out_channels {
183 let filter_start = ch * kernel_size;
184 let filter_slice = &filters[filter_start..filter_start + kernel_size];
185 let out_row_start = ch * out_len;
186 for pos in 0..out_len {
187 let products: Vec<f64> = (0..kernel_size)
188 .map(|k| signal[pos + k] * filter_slice[k])
189 .collect();
190 out[out_row_start + pos] = kahan_sum_f64(&products) + bias[ch];
191 }
192 }
193 }
194
195 pub fn conv1d_circular(
201 buffer: &[f64], write_pos: usize, window_size: usize,
202 window: &mut [f64],
203 filters: &[f64], bias: &[f64], out: &mut [f64],
204 out_channels: usize, kernel_size: usize,
205 ) {
206 let buf_len = buffer.len();
207 debug_assert!(window_size <= buf_len);
208 debug_assert_eq!(window.len(), window_size);
209
210 let start = if write_pos >= window_size {
211 write_pos - window_size
212 } else {
213 buf_len - (window_size - write_pos)
214 };
215 for i in 0..window_size {
216 window[i] = buffer[(start + i) % buf_len];
217 }
218
219 conv1d_raw(window, filters, bias, out, window_size, out_channels, kernel_size);
220 }
221
222 #[allow(clippy::too_many_arguments)]
243 pub fn conv2d_raw(
244 input: &[f64],
245 filters: &[f64],
246 bias: &[f64],
247 out: &mut [f64],
248 n: usize, c_in: usize, h_in: usize, w_in: usize,
249 c_out: usize, kh: usize, kw: usize,
250 stride: usize,
251 ) {
252 use crate::accumulator::BinnedAccumulatorF64;
253
254 let h_out: u64 = ((h_in - kh) / stride + 1) as u64;
255 let w_out: u64 = ((w_in - kw) / stride + 1) as u64;
256
257 let s_n: u64 = (c_in * h_in * w_in) as u64;
259 let s_cin: u64 = (h_in * w_in) as u64;
260 let s_hin: u64 = w_in as u64;
261
262 let f_cout: u64 = (c_in * kh * kw) as u64;
264 let f_cin: u64 = (kh * kw) as u64;
265 let f_kh: u64 = kw as u64;
266
267 let o_n: u64 = c_out as u64 * h_out * w_out;
269 let o_cout: u64 = h_out * w_out;
270
271 debug_assert_eq!(input.len(), n * c_in * h_in * w_in);
272 debug_assert_eq!(filters.len(), c_out * c_in * kh * kw);
273 debug_assert_eq!(bias.len(), c_out);
274 debug_assert_eq!(out.len(), n * c_out * h_out as usize * w_out as usize);
275
276 for bn in 0..n as u64 {
277 for co in 0..c_out as u64 {
278 for oh in 0..h_out {
279 for ow in 0..w_out {
280 let mut acc = BinnedAccumulatorF64::new();
281
282 for ci in 0..c_in as u64 {
284 for ki in 0..kh as u64 {
285 for kj in 0..kw as u64 {
286 let ih: u64 = oh * stride as u64 + ki;
287 let iw: u64 = ow * stride as u64 + kj;
288
289 let inp_idx = (bn * s_n
290 + ci * s_cin
291 + ih * s_hin
292 + iw) as usize;
293 let flt_idx = (co * f_cout
294 + ci * f_cin
295 + ki * f_kh
296 + kj) as usize;
297
298 acc.add(input[inp_idx] * filters[flt_idx]);
299 }
300 }
301 }
302
303 let out_idx = (bn * o_n
304 + co * o_cout
305 + oh * w_out
306 + ow) as usize;
307 out[out_idx] = acc.finalize() + bias[co as usize];
308 }
309 }
310 }
311 }
312 }
313
314 #[allow(clippy::too_many_arguments)]
320 pub fn conv2d_dispatched(
321 input: &[f64],
322 filters: &[f64],
323 bias: &[f64],
324 out: &mut [f64],
325 n: usize, c_in: usize, h_in: usize, w_in: usize,
326 c_out: usize, kh: usize, kw: usize,
327 stride: usize,
328 ctx: &crate::dispatch::ReductionContext,
329 ) {
330 let h_out = (h_in - kh) / stride + 1;
331 let w_out = (w_in - kw) / stride + 1;
332
333 let s_n = c_in * h_in * w_in;
334 let s_cin = h_in * w_in;
335 let s_hin = w_in;
336
337 let f_cout = c_in * kh * kw;
338 let f_cin = kh * kw;
339 let f_kh = kw;
340
341 let o_n = c_out * h_out * w_out;
342 let o_cout = h_out * w_out;
343
344 for bn in 0..n {
345 for co in 0..c_out {
346 for oh in 0..h_out {
347 for ow in 0..w_out {
348 let mut terms = Vec::with_capacity(c_in * kh * kw);
349 for ci in 0..c_in {
350 for ki in 0..kh {
351 for kj in 0..kw {
352 let ih = oh * stride + ki;
353 let iw = ow * stride + kj;
354 let inp_idx = bn * s_n + ci * s_cin + ih * s_hin + iw;
355 let flt_idx = co * f_cout + ci * f_cin + ki * f_kh + kj;
356 terms.push(input[inp_idx] * filters[flt_idx]);
357 }
358 }
359 }
360 let out_idx = bn * o_n + co * o_cout + oh * w_out + ow;
361 out[out_idx] =
362 crate::dispatch::dispatch_sum_f64(&terms, ctx) + bias[co];
363 }
364 }
365 }
366 }
367 }
368
/// 2-D max pooling over an NCHW tensor with non-overlapping `ph x pw`
/// windows. Trailing rows/columns that do not fill a whole window are
/// dropped (`h_out = h_in / ph`, `w_out = w_in / pw`).
pub fn maxpool2d_raw(
    input: &[f64],
    out: &mut [f64],
    n: usize, c: usize, h_in: usize, w_in: usize,
    ph: usize, pw: usize,
) {
    let h_out = h_in / ph;
    let w_out = w_in / pw;

    // Row-major strides for input and output (N, C, H, W).
    let in_batch = c * h_in * w_in;
    let in_chan = h_in * w_in;
    let out_batch = c * h_out * w_out;
    let out_chan = h_out * w_out;

    debug_assert_eq!(input.len(), n * c * h_in * w_in);
    debug_assert_eq!(out.len(), n * c * h_out * w_out);

    for bn in 0..n {
        for ch in 0..c {
            let in_base = bn * in_batch + ch * in_chan;
            let out_base = bn * out_batch + ch * out_chan;
            for oh in 0..h_out {
                for ow in 0..w_out {
                    let mut best = f64::NEG_INFINITY;
                    for pi in 0..ph {
                        let row = in_base + (oh * ph + pi) * w_in;
                        for pj in 0..pw {
                            let v = input[row + ow * pw + pj];
                            if v > best {
                                best = v;
                            }
                        }
                    }
                    out[out_base + oh * w_out + ow] = best;
                }
            }
        }
    }
}
415
/// 1-D max pooling with non-overlapping windows of `pool_size`.
/// A trailing partial window (when `data_len` is not a multiple of
/// `pool_size`) is dropped.
pub fn maxpool1d_raw(data: &[f64], out: &mut [f64], data_len: usize, pool_size: usize) {
    debug_assert_eq!(data.len(), data_len);
    let out_len = data_len / pool_size;
    debug_assert_eq!(out.len(), out_len);
    for (slot, window) in out.iter_mut().zip(data.chunks_exact(pool_size)) {
        // Seed with the first element so the comparison semantics
        // (strict `>`, NaN never wins over the seed) match a plain scan.
        let mut best = window[0];
        for &v in &window[1..] {
            if v > best {
                best = v;
            }
        }
        *slot = best;
    }
}
431
432 #[inline]
439 pub fn matmul_dispatched(
440 a: &[f64], b: &[f64], c: &mut [f64],
441 m: usize, k: usize, n: usize,
442 ctx: &crate::dispatch::ReductionContext,
443 ) {
444 debug_assert_eq!(a.len(), m * k);
445 debug_assert_eq!(b.len(), k * n);
446 debug_assert_eq!(c.len(), m * n);
447 for i in 0..m {
448 for j in 0..n {
449 let a_row = &a[i * k..(i + 1) * k];
451 let b_col: Vec<f64> = (0..k).map(|p| b[p * n + j]).collect();
452 c[i * n + j] = crate::dispatch::dispatch_dot_f64(a_row, &b_col, ctx);
453 }
454 }
455 }
456
457 #[inline]
459 pub fn linear_dispatched(
460 x: &[f64], w: &[f64], bias: &[f64], out: &mut [f64],
461 outer: usize, in_f: usize, out_f: usize,
462 ctx: &crate::dispatch::ReductionContext,
463 ) {
464 debug_assert_eq!(x.len(), outer * in_f);
465 debug_assert_eq!(w.len(), out_f * in_f);
466 debug_assert_eq!(bias.len(), out_f);
467 debug_assert_eq!(out.len(), outer * out_f);
468 for row in 0..outer {
469 let x_start = row * in_f;
470 let x_slice = &x[x_start..x_start + in_f];
471 let y_start = row * out_f;
472 for j in 0..out_f {
473 let w_start = j * in_f;
474 let w_slice = &w[w_start..w_start + in_f];
475 out[y_start + j] = crate::dispatch::dispatch_dot_f64(x_slice, w_slice, ctx) + bias[j];
476 }
477 }
478 }
479
480 #[inline]
482 pub fn layer_norm_dispatched(
483 data: &[f64], gamma: &[f64], beta: &[f64], out: &mut [f64],
484 outer: usize, n: usize, eps: f64,
485 ctx: &crate::dispatch::ReductionContext,
486 ) {
487 debug_assert_eq!(data.len(), outer * n);
488 debug_assert_eq!(gamma.len(), n);
489 debug_assert_eq!(beta.len(), n);
490 debug_assert_eq!(out.len(), outer * n);
491 for row in 0..outer {
492 let start = row * n;
493 let slice = &data[start..start + n];
494
495 let mean = crate::dispatch::dispatch_sum_f64(slice, ctx) / n as f64;
496
497 let diffs: Vec<f64> = slice.iter().map(|&x| (x - mean) * (x - mean)).collect();
498 let var = crate::dispatch::dispatch_sum_f64(&diffs, ctx) / n as f64;
499 let inv_std = 1.0 / (var + eps).sqrt();
500
501 for i in 0..n {
502 out[start + i] = (slice[i] - mean) * inv_std * gamma[i] + beta[i];
503 }
504 }
505 }
506
507 pub fn conv1d_dispatched(
509 signal: &[f64], filters: &[f64], bias: &[f64], out: &mut [f64],
510 signal_len: usize, out_channels: usize, kernel_size: usize,
511 ctx: &crate::dispatch::ReductionContext,
512 ) {
513 debug_assert!(signal_len >= kernel_size);
514 let out_len = signal_len - kernel_size + 1;
515 for ch in 0..out_channels {
516 let filter_start = ch * kernel_size;
517 let filter_slice = &filters[filter_start..filter_start + kernel_size];
518 let out_row_start = ch * out_len;
519 for pos in 0..out_len {
520 let sig_slice = &signal[pos..pos + kernel_size];
521 out[out_row_start + pos] =
522 crate::dispatch::dispatch_dot_f64(sig_slice, filter_slice, ctx) + bias[ch];
523 }
524 }
525 }
526}
527