1pub mod kernel {
40 use cjc_repro::{kahan_sum_f64, KahanAccumulatorF64};
41
42 #[inline]
49 pub fn matmul_raw(
50 a: &[f64], b: &[f64], c: &mut [f64],
51 m: usize, k: usize, n: usize,
52 ) {
53 debug_assert_eq!(a.len(), m * k);
54 debug_assert_eq!(b.len(), k * n);
55 debug_assert_eq!(c.len(), m * n);
56 for i in 0..m {
57 for j in 0..n {
58 let mut acc = KahanAccumulatorF64::new();
59 for p in 0..k {
60 acc.add(a[i * k + p] * b[p * n + j]);
61 }
62 c[i * n + j] = acc.finalize();
63 }
64 }
65 }
66
/// Row-wise numerically stable softmax over `outer` rows of length `n`.
///
/// Each row is shifted by its maximum before exponentiation, the
/// exponentials are summed with Kahan compensation, and the row is then
/// normalized. A row whose exponential sum is exactly zero falls back to
/// the uniform distribution.
#[inline]
pub fn softmax_raw(data: &[f64], out: &mut [f64], outer: usize, n: usize) {
    debug_assert_eq!(data.len(), outer * n);
    debug_assert_eq!(out.len(), outer * n);
    for r in 0..outer {
        let base = r * n;
        let row = &data[base..base + n];

        // Row maximum; NaN entries never replace the running maximum,
        // matching a strict `>` comparison.
        let peak = row.iter().copied().fold(f64::NEG_INFINITY, f64::max);

        // Exponentiate with the shift, accumulating a compensated sum.
        let mut total = 0.0f64;
        let mut carry = 0.0f64;
        for (dst, &v) in out[base..base + n].iter_mut().zip(row) {
            let e = (v - peak).exp();
            *dst = e;
            let adjusted = e - carry;
            let next = total + adjusted;
            carry = (next - total) - adjusted;
            total = next;
        }

        if total == 0.0 {
            // Degenerate row: spread the probability mass uniformly.
            let share = 1.0 / n as f64;
            for dst in &mut out[base..base + n] {
                *dst = share;
            }
        } else {
            for dst in &mut out[base..base + n] {
                *dst /= total;
            }
        }
    }
}
109 }
110
111 #[inline]
113 pub fn linear_raw(
114 x: &[f64], w: &[f64], bias: &[f64], out: &mut [f64],
115 outer: usize, in_f: usize, out_f: usize,
116 ) {
117 debug_assert_eq!(x.len(), outer * in_f);
118 debug_assert_eq!(w.len(), out_f * in_f);
119 debug_assert_eq!(bias.len(), out_f);
120 debug_assert_eq!(out.len(), outer * out_f);
121 for row in 0..outer {
122 let x_start = row * in_f;
123 let x_slice = &x[x_start..x_start + in_f];
124 let y_start = row * out_f;
125 for j in 0..out_f {
126 let w_start = j * in_f;
127 let mut acc = KahanAccumulatorF64::new();
128 for p in 0..in_f {
129 acc.add(x_slice[p] * w[w_start + p]);
130 }
131 out[y_start + j] = acc.finalize() + bias[j];
132 }
133 }
134 }
135
136 #[inline]
141 pub fn layer_norm_raw(
142 data: &[f64], gamma: &[f64], beta: &[f64], out: &mut [f64],
143 outer: usize, n: usize, eps: f64,
144 ) {
145 debug_assert_eq!(data.len(), outer * n);
146 debug_assert_eq!(gamma.len(), n);
147 debug_assert_eq!(beta.len(), n);
148 debug_assert_eq!(out.len(), outer * n);
149 for row in 0..outer {
150 let start = row * n;
151 let slice = &data[start..start + n];
152
153 let mean = kahan_sum_f64(slice) / n as f64;
155
156 let diffs: Vec<f64> = slice.iter().map(|&x| (x - mean) * (x - mean)).collect();
158 let var = kahan_sum_f64(&diffs) / n as f64;
159 let inv_std = 1.0 / (var + eps).sqrt();
160
161 for i in 0..n {
162 out[start + i] = (slice[i] - mean) * inv_std * gamma[i] + beta[i];
163 }
164 }
165 }
166
/// Elementwise ReLU: `out[i] = max(data[i], 0)`, with anything not
/// strictly greater than zero (including NaN) mapped to `0.0`.
#[inline]
pub fn relu_raw(data: &[f64], out: &mut [f64]) {
    debug_assert_eq!(data.len(), out.len());
    for (src, dst) in data.iter().zip(out.iter_mut()) {
        *dst = match *src > 0.0 {
            true => *src,
            false => 0.0,
        };
    }
}
175
/// Elementwise GELU using the tanh approximation:
/// `0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`.
#[inline]
pub fn gelu_raw(data: &[f64], out: &mut [f64]) {
    debug_assert_eq!(data.len(), out.len());
    // Constant of the tanh approximation, hoisted out of the loop.
    let coeff: f64 = (2.0 / std::f64::consts::PI).sqrt();
    out.iter_mut().zip(data).for_each(|(dst, &x)| {
        let arg = coeff * (x + 0.044715 * x * x * x);
        *dst = 0.5 * x * (1.0 + arg.tanh());
    });
}
186
187 pub fn conv1d_raw(
195 signal: &[f64], filters: &[f64], bias: &[f64], out: &mut [f64],
196 signal_len: usize, out_channels: usize, kernel_size: usize,
197 ) {
198 debug_assert!(signal_len >= kernel_size);
199 let out_len = signal_len - kernel_size + 1;
200 debug_assert_eq!(signal.len(), signal_len);
201 debug_assert_eq!(filters.len(), out_channels * kernel_size);
202 debug_assert_eq!(bias.len(), out_channels);
203 debug_assert_eq!(out.len(), out_channels * out_len);
204
205 for ch in 0..out_channels {
206 let filter_start = ch * kernel_size;
207 let filter_slice = &filters[filter_start..filter_start + kernel_size];
208 let out_row_start = ch * out_len;
209 for pos in 0..out_len {
210 let products: Vec<f64> = (0..kernel_size)
211 .map(|k| signal[pos + k] * filter_slice[k])
212 .collect();
213 out[out_row_start + pos] = kahan_sum_f64(&products) + bias[ch];
214 }
215 }
216 }
217
218 pub fn conv1d_circular(
224 buffer: &[f64], write_pos: usize, window_size: usize,
225 window: &mut [f64],
226 filters: &[f64], bias: &[f64], out: &mut [f64],
227 out_channels: usize, kernel_size: usize,
228 ) {
229 let buf_len = buffer.len();
230 debug_assert!(window_size <= buf_len);
231 debug_assert_eq!(window.len(), window_size);
232
233 let start = if write_pos >= window_size {
234 write_pos - window_size
235 } else {
236 buf_len - (window_size - write_pos)
237 };
238 for i in 0..window_size {
239 window[i] = buffer[(start + i) % buf_len];
240 }
241
242 conv1d_raw(window, filters, bias, out, window_size, out_channels, kernel_size);
243 }
244
/// Batched 2-D convolution (cross-correlation): NCHW layout, valid
/// padding, one `stride` shared by both spatial axes, and a binned
/// compensated accumulator per output element.
///
/// Row-major layouts:
/// * `input`   — `n x c_in x h_in x w_in`
/// * `filters` — `c_out x c_in x kh x kw`
/// * `bias`    — `c_out`
/// * `out`     — `n x c_out x h_out x w_out`, where
///   `h_out = (h_in - kh) / stride + 1` (and analogously `w_out`).
#[allow(clippy::too_many_arguments)]
pub fn conv2d_raw(
    input: &[f64],
    filters: &[f64],
    bias: &[f64],
    out: &mut [f64],
    n: usize, c_in: usize, h_in: usize, w_in: usize,
    c_out: usize, kh: usize, kw: usize,
    stride: usize,
) {
    use crate::accumulator::BinnedAccumulatorF64;

    let h_out: u64 = ((h_in - kh) / stride + 1) as u64;
    let w_out: u64 = ((w_in - kw) / stride + 1) as u64;

    // Flattened-index strides for `input`, kept in u64 so the index
    // arithmetic below cannot overflow a 32-bit usize before the final
    // cast back to usize.
    let s_n: u64 = (c_in * h_in * w_in) as u64;
    let s_cin: u64 = (h_in * w_in) as u64;
    let s_hin: u64 = w_in as u64;

    // Flattened-index strides for `filters`.
    let f_cout: u64 = (c_in * kh * kw) as u64;
    let f_cin: u64 = (kh * kw) as u64;
    let f_kh: u64 = kw as u64;

    // Flattened-index strides for `out`.
    let o_n: u64 = c_out as u64 * h_out * w_out;
    let o_cout: u64 = h_out * w_out;

    debug_assert_eq!(input.len(), n * c_in * h_in * w_in);
    debug_assert_eq!(filters.len(), c_out * c_in * kh * kw);
    debug_assert_eq!(bias.len(), c_out);
    debug_assert_eq!(out.len(), n * c_out * h_out as usize * w_out as usize);

    for bn in 0..n as u64 {
        for co in 0..c_out as u64 {
            for oh in 0..h_out {
                for ow in 0..w_out {
                    let mut acc = BinnedAccumulatorF64::new();

                    // Reduce over every input channel and kernel tap.
                    for ci in 0..c_in as u64 {
                        for ki in 0..kh as u64 {
                            for kj in 0..kw as u64 {
                                // Top-left corner of the receptive field
                                // plus the tap offset.
                                let ih: u64 = oh * stride as u64 + ki;
                                let iw: u64 = ow * stride as u64 + kj;

                                let inp_idx = (bn * s_n
                                    + ci * s_cin
                                    + ih * s_hin
                                    + iw) as usize;
                                let flt_idx = (co * f_cout
                                    + ci * f_cin
                                    + ki * f_kh
                                    + kj) as usize;

                                acc.add(input[inp_idx] * filters[flt_idx]);
                            }
                        }
                    }

                    let out_idx = (bn * o_n
                        + co * o_cout
                        + oh * w_out
                        + ow) as usize;
                    out[out_idx] = acc.finalize() + bias[co as usize];
                }
            }
        }
    }
}
336
337 #[allow(clippy::too_many_arguments)]
343 pub fn conv2d_dispatched(
344 input: &[f64],
345 filters: &[f64],
346 bias: &[f64],
347 out: &mut [f64],
348 n: usize, c_in: usize, h_in: usize, w_in: usize,
349 c_out: usize, kh: usize, kw: usize,
350 stride: usize,
351 ctx: &crate::dispatch::ReductionContext,
352 ) {
353 let h_out = (h_in - kh) / stride + 1;
354 let w_out = (w_in - kw) / stride + 1;
355
356 let s_n = c_in * h_in * w_in;
357 let s_cin = h_in * w_in;
358 let s_hin = w_in;
359
360 let f_cout = c_in * kh * kw;
361 let f_cin = kh * kw;
362 let f_kh = kw;
363
364 let o_n = c_out * h_out * w_out;
365 let o_cout = h_out * w_out;
366
367 for bn in 0..n {
368 for co in 0..c_out {
369 for oh in 0..h_out {
370 for ow in 0..w_out {
371 let mut terms = Vec::with_capacity(c_in * kh * kw);
372 for ci in 0..c_in {
373 for ki in 0..kh {
374 for kj in 0..kw {
375 let ih = oh * stride + ki;
376 let iw = ow * stride + kj;
377 let inp_idx = bn * s_n + ci * s_cin + ih * s_hin + iw;
378 let flt_idx = co * f_cout + ci * f_cin + ki * f_kh + kj;
379 terms.push(input[inp_idx] * filters[flt_idx]);
380 }
381 }
382 }
383 let out_idx = bn * o_n + co * o_cout + oh * w_out + ow;
384 out[out_idx] =
385 crate::dispatch::dispatch_sum_f64(&terms, ctx) + bias[co];
386 }
387 }
388 }
389 }
390 }
391
/// Batched 2-D max pooling over NCHW data with non-overlapping
/// `ph x pw` windows; trailing rows/columns that do not fill a complete
/// window are dropped.
pub fn maxpool2d_raw(
    input: &[f64],
    out: &mut [f64],
    n: usize, c: usize, h_in: usize, w_in: usize,
    ph: usize, pw: usize,
) {
    let h_out: u64 = (h_in / ph) as u64;
    let w_out: u64 = (w_in / pw) as u64;

    // Flattened-index strides, kept in u64 to match the loop counters.
    let s_n: u64 = (c * h_in * w_in) as u64;
    let s_c: u64 = (h_in * w_in) as u64;
    let s_hin: u64 = w_in as u64;

    let o_n: u64 = (c as u64) * h_out * w_out;
    let o_c: u64 = h_out * w_out;

    debug_assert_eq!(input.len(), n * c * h_in * w_in);
    debug_assert_eq!(out.len(), n * c * h_out as usize * w_out as usize);

    for bn in 0..n as u64 {
        for ch in 0..c as u64 {
            for oh in 0..h_out {
                for ow in 0..w_out {
                    // Scan the pooling window, keeping whatever compares
                    // strictly greater (NaN entries never win).
                    let mut best = f64::NEG_INFINITY;
                    for pi in 0..ph as u64 {
                        let ih: u64 = oh * ph as u64 + pi;
                        let row_base = bn * s_n + ch * s_c + ih * s_hin;
                        for pj in 0..pw as u64 {
                            let iw: u64 = ow * pw as u64 + pj;
                            let v = input[(row_base + iw) as usize];
                            if v > best {
                                best = v;
                            }
                        }
                    }
                    let o_idx = (bn * o_n + ch * o_c + oh * w_out + ow) as usize;
                    out[o_idx] = best;
                }
            }
        }
    }
}
438
/// 1-D max pooling with non-overlapping windows of `pool_size`;
/// any trailing partial window is dropped.
pub fn maxpool1d_raw(data: &[f64], out: &mut [f64], data_len: usize, pool_size: usize) {
    debug_assert_eq!(data.len(), data_len);
    let out_len = data_len / pool_size;
    debug_assert_eq!(out.len(), out_len);
    for (dst, window) in out.iter_mut().zip(data.chunks_exact(pool_size)) {
        // Seed with the first element, then keep whatever compares
        // strictly greater (so a leading NaN propagates, as before).
        let mut best = window[0];
        for &v in &window[1..] {
            if v > best {
                best = v;
            }
        }
        *dst = best;
    }
}
461
462 #[inline]
469 pub fn matmul_dispatched(
470 a: &[f64], b: &[f64], c: &mut [f64],
471 m: usize, k: usize, n: usize,
472 ctx: &crate::dispatch::ReductionContext,
473 ) {
474 debug_assert_eq!(a.len(), m * k);
475 debug_assert_eq!(b.len(), k * n);
476 debug_assert_eq!(c.len(), m * n);
477 for i in 0..m {
478 for j in 0..n {
479 let a_row = &a[i * k..(i + 1) * k];
481 let b_col: Vec<f64> = (0..k).map(|p| b[p * n + j]).collect();
482 c[i * n + j] = crate::dispatch::dispatch_dot_f64(a_row, &b_col, ctx);
483 }
484 }
485 }
486
487 #[inline]
489 pub fn linear_dispatched(
490 x: &[f64], w: &[f64], bias: &[f64], out: &mut [f64],
491 outer: usize, in_f: usize, out_f: usize,
492 ctx: &crate::dispatch::ReductionContext,
493 ) {
494 debug_assert_eq!(x.len(), outer * in_f);
495 debug_assert_eq!(w.len(), out_f * in_f);
496 debug_assert_eq!(bias.len(), out_f);
497 debug_assert_eq!(out.len(), outer * out_f);
498 for row in 0..outer {
499 let x_start = row * in_f;
500 let x_slice = &x[x_start..x_start + in_f];
501 let y_start = row * out_f;
502 for j in 0..out_f {
503 let w_start = j * in_f;
504 let w_slice = &w[w_start..w_start + in_f];
505 out[y_start + j] = crate::dispatch::dispatch_dot_f64(x_slice, w_slice, ctx) + bias[j];
506 }
507 }
508 }
509
510 #[inline]
512 pub fn layer_norm_dispatched(
513 data: &[f64], gamma: &[f64], beta: &[f64], out: &mut [f64],
514 outer: usize, n: usize, eps: f64,
515 ctx: &crate::dispatch::ReductionContext,
516 ) {
517 debug_assert_eq!(data.len(), outer * n);
518 debug_assert_eq!(gamma.len(), n);
519 debug_assert_eq!(beta.len(), n);
520 debug_assert_eq!(out.len(), outer * n);
521 for row in 0..outer {
522 let start = row * n;
523 let slice = &data[start..start + n];
524
525 let mean = crate::dispatch::dispatch_sum_f64(slice, ctx) / n as f64;
526
527 let diffs: Vec<f64> = slice.iter().map(|&x| (x - mean) * (x - mean)).collect();
528 let var = crate::dispatch::dispatch_sum_f64(&diffs, ctx) / n as f64;
529 let inv_std = 1.0 / (var + eps).sqrt();
530
531 for i in 0..n {
532 out[start + i] = (slice[i] - mean) * inv_std * gamma[i] + beta[i];
533 }
534 }
535 }
536
537 pub fn conv1d_dispatched(
539 signal: &[f64], filters: &[f64], bias: &[f64], out: &mut [f64],
540 signal_len: usize, out_channels: usize, kernel_size: usize,
541 ctx: &crate::dispatch::ReductionContext,
542 ) {
543 debug_assert!(signal_len >= kernel_size);
544 let out_len = signal_len - kernel_size + 1;
545 for ch in 0..out_channels {
546 let filter_start = ch * kernel_size;
547 let filter_slice = &filters[filter_start..filter_start + kernel_size];
548 let out_row_start = ch * out_len;
549 for pos in 0..out_len {
550 let sig_slice = &signal[pos..pos + kernel_size];
551 out[out_row_start + pos] =
552 crate::dispatch::dispatch_dot_f64(sig_slice, filter_slice, ctx) + bias[ch];
553 }
554 }
555 }
556}
557