// cjc_runtime/kernel_bridge.rs
// ---------------------------------------------------------------------------
// 2d. Raw-Pointer Kernel Bridge — bypass interpreter overhead for hot loops
// ---------------------------------------------------------------------------

7/// Raw-pointer kernel functions that operate directly on f64 slices.
8///
9/// These bypass the `Value::Tensor` wrapper, operating on contiguous `&[f64]`
10/// data. The interpreter resolves tensor pointers once at call entry, then
11/// dispatches to these zero-overhead kernels.
12///
13/// All functions here are safe Rust — they accept slices, not raw pointers.
14/// The "raw pointer" concept means: the interpreter does one `to_vec()` or
15/// `buffer.borrow()` at entry, then passes the contiguous slice through.
16pub mod kernel {
17    use cjc_repro::{kahan_sum_f64, KahanAccumulatorF64};
18
19    /// Matrix multiply: C[m,n] = A[m,k] × B[k,n] with Kahan-summed dots.
20    ///
21    /// `a`, `b` are row-major contiguous slices; `c` is the output buffer
22    /// (must be pre-allocated to `m * n`).
23    ///
24    /// Uses in-place `KahanAccumulatorF64` — zero heap allocation per dot product.
25    #[inline]
26    pub fn matmul_raw(
27        a: &[f64], b: &[f64], c: &mut [f64],
28        m: usize, k: usize, n: usize,
29    ) {
30        debug_assert_eq!(a.len(), m * k);
31        debug_assert_eq!(b.len(), k * n);
32        debug_assert_eq!(c.len(), m * n);
33        for i in 0..m {
34            for j in 0..n {
35                let mut acc = KahanAccumulatorF64::new();
36                for p in 0..k {
37                    acc.add(a[i * k + p] * b[p * n + j]);
38                }
39                c[i * n + j] = acc.finalize();
40            }
41        }
42    }
43
44    /// Softmax over the last dimension of a contiguous buffer.
45    ///
46    /// `data` is the input (length = `outer * n`), `out` is the output.
47    /// Applies two-pass stable softmax per row of length `n`.
48    #[inline]
49    pub fn softmax_raw(data: &[f64], out: &mut [f64], outer: usize, n: usize) {
50        debug_assert_eq!(data.len(), outer * n);
51        debug_assert_eq!(out.len(), outer * n);
52        for row in 0..outer {
53            let start = row * n;
54            let slice = &data[start..start + n];
55
56            // Pass 1: max
57            let mut max_val = f64::NEG_INFINITY;
58            for &v in slice {
59                if v > max_val { max_val = v; }
60            }
61
62            // Pass 2: exp + Kahan sum
63            let mut sum = 0.0f64;
64            let mut comp = 0.0f64;
65            for i in 0..n {
66                let e = (slice[i] - max_val).exp();
67                out[start + i] = e;
68                let y = e - comp;
69                let t = sum + y;
70                comp = (t - sum) - y;
71                sum = t;
72            }
73
74            // Normalize
75            if sum == 0.0 {
76                let uniform = 1.0 / n as f64;
77                for i in 0..n {
78                    out[start + i] = uniform;
79                }
80            } else {
81                for i in 0..n {
82                    out[start + i] /= sum;
83                }
84            }
85        }
86    }
87
88    /// Linear projection: Y[outer, out_f] = X[outer, in_f] @ W^T[out_f, in_f] + bias[out_f]
89    #[inline]
90    pub fn linear_raw(
91        x: &[f64], w: &[f64], bias: &[f64], out: &mut [f64],
92        outer: usize, in_f: usize, out_f: usize,
93    ) {
94        debug_assert_eq!(x.len(), outer * in_f);
95        debug_assert_eq!(w.len(), out_f * in_f);
96        debug_assert_eq!(bias.len(), out_f);
97        debug_assert_eq!(out.len(), outer * out_f);
98        for row in 0..outer {
99            let x_start = row * in_f;
100            let x_slice = &x[x_start..x_start + in_f];
101            let y_start = row * out_f;
102            for j in 0..out_f {
103                let w_start = j * in_f;
104                let mut acc = KahanAccumulatorF64::new();
105                for p in 0..in_f {
106                    acc.add(x_slice[p] * w[w_start + p]);
107                }
108                out[y_start + j] = acc.finalize() + bias[j];
109            }
110        }
111    }
112
113    /// Layer normalization over the last dimension.
114    ///
115    /// For each row of length `n`: normalize to mean=0, var=1, then
116    /// scale by gamma and shift by beta.
117    #[inline]
118    pub fn layer_norm_raw(
119        data: &[f64], gamma: &[f64], beta: &[f64], out: &mut [f64],
120        outer: usize, n: usize, eps: f64,
121    ) {
122        debug_assert_eq!(data.len(), outer * n);
123        debug_assert_eq!(gamma.len(), n);
124        debug_assert_eq!(beta.len(), n);
125        debug_assert_eq!(out.len(), outer * n);
126        for row in 0..outer {
127            let start = row * n;
128            let slice = &data[start..start + n];
129
130            // Mean (Kahan)
131            let mean = kahan_sum_f64(slice) / n as f64;
132
133            // Variance (Kahan)
134            let diffs: Vec<f64> = slice.iter().map(|&x| (x - mean) * (x - mean)).collect();
135            let var = kahan_sum_f64(&diffs) / n as f64;
136            let inv_std = 1.0 / (var + eps).sqrt();
137
138            for i in 0..n {
139                out[start + i] = (slice[i] - mean) * inv_std * gamma[i] + beta[i];
140            }
141        }
142    }
143
144    /// ReLU: max(0, x) element-wise.
145    #[inline]
146    pub fn relu_raw(data: &[f64], out: &mut [f64]) {
147        debug_assert_eq!(data.len(), out.len());
148        for (o, &x) in out.iter_mut().zip(data.iter()) {
149            *o = if x > 0.0 { x } else { 0.0 };
150        }
151    }
152
153    /// Approximate GELU: x * 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
154    #[inline]
155    pub fn gelu_raw(data: &[f64], out: &mut [f64]) {
156        debug_assert_eq!(data.len(), out.len());
157        let sqrt_2_over_pi: f64 = (2.0 / std::f64::consts::PI).sqrt();
158        for (o, &x) in out.iter_mut().zip(data.iter()) {
159            let inner = sqrt_2_over_pi * (x + 0.044715 * x * x * x);
160            *o = 0.5 * x * (1.0 + inner.tanh());
161        }
162    }
163
164    /// 1D convolution with stride=1, no padding ("valid" mode).
165    ///
166    /// `signal`: input signal of length `signal_len`.
167    /// `filters`: `[out_channels, kernel_size]` row-major.
168    /// `bias`: per-channel bias of length `out_channels`.
169    /// `out`: output buffer `[out_channels, out_len]`, `out_len = signal_len - kernel_size + 1`.
170    /// Uses Kahan summation for deterministic dot products.
171    pub fn conv1d_raw(
172        signal: &[f64], filters: &[f64], bias: &[f64], out: &mut [f64],
173        signal_len: usize, out_channels: usize, kernel_size: usize,
174    ) {
175        debug_assert!(signal_len >= kernel_size);
176        let out_len = signal_len - kernel_size + 1;
177        debug_assert_eq!(signal.len(), signal_len);
178        debug_assert_eq!(filters.len(), out_channels * kernel_size);
179        debug_assert_eq!(bias.len(), out_channels);
180        debug_assert_eq!(out.len(), out_channels * out_len);
181
182        for ch in 0..out_channels {
183            let filter_start = ch * kernel_size;
184            let filter_slice = &filters[filter_start..filter_start + kernel_size];
185            let out_row_start = ch * out_len;
186            for pos in 0..out_len {
187                let products: Vec<f64> = (0..kernel_size)
188                    .map(|k| signal[pos + k] * filter_slice[k])
189                    .collect();
190                out[out_row_start + pos] = kahan_sum_f64(&products) + bias[ch];
191            }
192        }
193    }
194
195    /// 1D convolution on a sliding window of a circular buffer.
196    ///
197    /// Extracts the most recent `window_size` samples from `buffer`
198    /// (handling wrap-around at `write_pos`) into `window`, then
199    /// delegates to `conv1d_raw`.
200    pub fn conv1d_circular(
201        buffer: &[f64], write_pos: usize, window_size: usize,
202        window: &mut [f64],
203        filters: &[f64], bias: &[f64], out: &mut [f64],
204        out_channels: usize, kernel_size: usize,
205    ) {
206        let buf_len = buffer.len();
207        debug_assert!(window_size <= buf_len);
208        debug_assert_eq!(window.len(), window_size);
209
210        let start = if write_pos >= window_size {
211            write_pos - window_size
212        } else {
213            buf_len - (window_size - write_pos)
214        };
215        for i in 0..window_size {
216            window[i] = buffer[(start + i) % buf_len];
217        }
218
219        conv1d_raw(window, filters, bias, out, window_size, out_channels, kernel_size);
220    }
221
222    // -- Phase 7: 2D Spatial Kernels ------------------------------------------
223
224    /// 2D convolution — NCHW layout, valid mode (no padding), configurable stride.
225    ///
226    /// # Layout
227    /// - `input`:   `[N, C_in, H_in, W_in]`  row-major contiguous
228    /// - `filters`: `[C_out, C_in, kH, kW]`  row-major contiguous
229    /// - `bias`:    `[C_out]`
230    /// - `out`:     `[N, C_out, H_out, W_out]`  pre-allocated by caller
231    ///
232    /// where `H_out = (H_in - kH) / stride + 1` and `W_out = (W_in - kW) / stride + 1`.
233    ///
234    /// # Numerical contract
235    /// Every kernel-to-patch dot product uses `BinnedAccumulatorF64`, guaranteeing
236    /// bit-identical results regardless of stride, batch size, or channel count.
237    ///
238    /// # NoGC guarantee
239    /// All index arithmetic uses `u64` before narrowing to `usize`, preventing
240    /// overflow for high-resolution inputs (e.g., 8192×8192). The output buffer
241    /// is caller-allocated; this function performs zero heap allocations.
242    #[allow(clippy::too_many_arguments)]
243    pub fn conv2d_raw(
244        input:   &[f64],
245        filters: &[f64],
246        bias:    &[f64],
247        out:     &mut [f64],
248        n: usize, c_in: usize, h_in: usize, w_in: usize,
249        c_out: usize, kh: usize, kw: usize,
250        stride: usize,
251    ) {
252        use crate::accumulator::BinnedAccumulatorF64;
253
254        let h_out: u64 = ((h_in - kh) / stride + 1) as u64;
255        let w_out: u64 = ((w_in - kw) / stride + 1) as u64;
256
257        // Strides in the input tensor (NCHW, row-major).
258        let s_n:   u64 = (c_in  * h_in * w_in) as u64;
259        let s_cin: u64 = (h_in  * w_in) as u64;
260        let s_hin: u64 = w_in as u64;
261
262        // Strides in the filter tensor [C_out, C_in, kH, kW].
263        let f_cout: u64 = (c_in * kh * kw) as u64;
264        let f_cin:  u64 = (kh * kw) as u64;
265        let f_kh:   u64 = kw as u64;
266
267        // Strides in the output tensor (NCHW).
268        let o_n:    u64 = c_out as u64 * h_out * w_out;
269        let o_cout: u64 = h_out * w_out;
270
271        debug_assert_eq!(input.len(),   n * c_in  * h_in * w_in);
272        debug_assert_eq!(filters.len(), c_out * c_in * kh * kw);
273        debug_assert_eq!(bias.len(),    c_out);
274        debug_assert_eq!(out.len(),     n * c_out * h_out as usize * w_out as usize);
275
276        for bn in 0..n as u64 {
277            for co in 0..c_out as u64 {
278                for oh in 0..h_out {
279                    for ow in 0..w_out {
280                        let mut acc = BinnedAccumulatorF64::new();
281
282                        // Sum over input channels and kernel spatial extent.
283                        for ci in 0..c_in as u64 {
284                            for ki in 0..kh as u64 {
285                                for kj in 0..kw as u64 {
286                                    let ih: u64 = oh * stride as u64 + ki;
287                                    let iw: u64 = ow * stride as u64 + kj;
288
289                                    let inp_idx = (bn  * s_n
290                                                 + ci  * s_cin
291                                                 + ih  * s_hin
292                                                 + iw) as usize;
293                                    let flt_idx = (co  * f_cout
294                                                 + ci  * f_cin
295                                                 + ki  * f_kh
296                                                 + kj) as usize;
297
298                                    acc.add(input[inp_idx] * filters[flt_idx]);
299                                }
300                            }
301                        }
302
303                        let out_idx = (bn * o_n
304                                     + co * o_cout
305                                     + oh * w_out
306                                     + ow) as usize;
307                        out[out_idx] = acc.finalize() + bias[co as usize];
308                    }
309                }
310            }
311        }
312    }
313
314    /// 2D convolution using dispatched summation strategy.
315    ///
316    /// Identical to `conv2d_raw` but selects Kahan or Binned based on the
317    /// reduction context.  Useful when callers want runtime-configurable
318    /// accumulation precision.
319    #[allow(clippy::too_many_arguments)]
320    pub fn conv2d_dispatched(
321        input:   &[f64],
322        filters: &[f64],
323        bias:    &[f64],
324        out:     &mut [f64],
325        n: usize, c_in: usize, h_in: usize, w_in: usize,
326        c_out: usize, kh: usize, kw: usize,
327        stride: usize,
328        ctx: &crate::dispatch::ReductionContext,
329    ) {
330        let h_out = (h_in - kh) / stride + 1;
331        let w_out = (w_in - kw) / stride + 1;
332
333        let s_n   = c_in  * h_in * w_in;
334        let s_cin = h_in  * w_in;
335        let s_hin = w_in;
336
337        let f_cout = c_in * kh * kw;
338        let f_cin  = kh * kw;
339        let f_kh   = kw;
340
341        let o_n    = c_out * h_out * w_out;
342        let o_cout = h_out * w_out;
343
344        for bn in 0..n {
345            for co in 0..c_out {
346                for oh in 0..h_out {
347                    for ow in 0..w_out {
348                        let mut terms = Vec::with_capacity(c_in * kh * kw);
349                        for ci in 0..c_in {
350                            for ki in 0..kh {
351                                for kj in 0..kw {
352                                    let ih = oh * stride + ki;
353                                    let iw = ow * stride + kj;
354                                    let inp_idx = bn * s_n  + ci * s_cin + ih * s_hin + iw;
355                                    let flt_idx = co * f_cout + ci * f_cin + ki * f_kh + kj;
356                                    terms.push(input[inp_idx] * filters[flt_idx]);
357                                }
358                            }
359                        }
360                        let out_idx = bn * o_n + co * o_cout + oh * w_out + ow;
361                        out[out_idx] =
362                            crate::dispatch::dispatch_sum_f64(&terms, ctx) + bias[co];
363                    }
364                }
365            }
366        }
367    }
368
369    /// 2D max-pooling — NCHW layout, stride = pool_size (non-overlapping).
370    ///
371    /// - `input`: `[N, C, H_in, W_in]`
372    /// - `out`:   `[N, C, H_in/ph, W_in/pw]`  (floor division, pre-allocated)
373    ///
374    /// All index arithmetic uses `u64` to support large spatial extents.
375    pub fn maxpool2d_raw(
376        input:  &[f64],
377        out:    &mut [f64],
378        n: usize, c: usize, h_in: usize, w_in: usize,
379        ph: usize, pw: usize,
380    ) {
381        let h_out: u64 = (h_in / ph) as u64;
382        let w_out: u64 = (w_in / pw) as u64;
383
384        let s_n:   u64 = (c * h_in * w_in) as u64;
385        let s_c:   u64 = (h_in * w_in) as u64;
386        let s_hin: u64 = w_in as u64;
387
388        let o_n:   u64 = (c as u64) * h_out * w_out;
389        let o_c:   u64 = h_out * w_out;
390
391        debug_assert_eq!(input.len(), n * c * h_in * w_in);
392        debug_assert_eq!(out.len(),   n * c * h_out as usize * w_out as usize);
393
394        for bn in 0..n as u64 {
395            for ch in 0..c as u64 {
396                for oh in 0..h_out {
397                    for ow in 0..w_out {
398                        let mut max_val = f64::NEG_INFINITY;
399                        for pi in 0..ph as u64 {
400                            for pj in 0..pw as u64 {
401                                let ih: u64 = oh * ph as u64 + pi;
402                                let iw: u64 = ow * pw as u64 + pj;
403                                let idx = (bn * s_n + ch * s_c + ih * s_hin + iw) as usize;
404                                let v = input[idx];
405                                if v > max_val { max_val = v; }
406                            }
407                        }
408                        let o_idx = (bn * o_n + ch * o_c + oh * w_out + ow) as usize;
409                        out[o_idx] = max_val;
410                    }
411                }
412            }
413        }
414    }
415
416    /// Max-pooling over 1D signal, stride = pool_size.
417    pub fn maxpool1d_raw(data: &[f64], out: &mut [f64], data_len: usize, pool_size: usize) {
418        debug_assert_eq!(data.len(), data_len);
419        let out_len = data_len / pool_size;
420        debug_assert_eq!(out.len(), out_len);
421        for i in 0..out_len {
422            let start = i * pool_size;
423            let mut max_val = data[start];
424            for j in 1..pool_size {
425                let v = data[start + j];
426                if v > max_val { max_val = v; }
427            }
428            out[i] = max_val;
429        }
430    }
431
432    // -- Dispatched kernel variants (Milestone 2.7) ---------------------------
433
434    /// Matrix multiply using dispatched summation strategy.
435    ///
436    /// Identical to `matmul_raw` but uses `dispatch_dot_f64` for each dot product,
437    /// selecting Kahan or Binned based on the reduction context.
438    #[inline]
439    pub fn matmul_dispatched(
440        a: &[f64], b: &[f64], c: &mut [f64],
441        m: usize, k: usize, n: usize,
442        ctx: &crate::dispatch::ReductionContext,
443    ) {
444        debug_assert_eq!(a.len(), m * k);
445        debug_assert_eq!(b.len(), k * n);
446        debug_assert_eq!(c.len(), m * n);
447        for i in 0..m {
448            for j in 0..n {
449                // Collect column from B for the dot product.
450                let a_row = &a[i * k..(i + 1) * k];
451                let b_col: Vec<f64> = (0..k).map(|p| b[p * n + j]).collect();
452                c[i * n + j] = crate::dispatch::dispatch_dot_f64(a_row, &b_col, ctx);
453            }
454        }
455    }
456
457    /// Linear projection using dispatched summation.
458    #[inline]
459    pub fn linear_dispatched(
460        x: &[f64], w: &[f64], bias: &[f64], out: &mut [f64],
461        outer: usize, in_f: usize, out_f: usize,
462        ctx: &crate::dispatch::ReductionContext,
463    ) {
464        debug_assert_eq!(x.len(), outer * in_f);
465        debug_assert_eq!(w.len(), out_f * in_f);
466        debug_assert_eq!(bias.len(), out_f);
467        debug_assert_eq!(out.len(), outer * out_f);
468        for row in 0..outer {
469            let x_start = row * in_f;
470            let x_slice = &x[x_start..x_start + in_f];
471            let y_start = row * out_f;
472            for j in 0..out_f {
473                let w_start = j * in_f;
474                let w_slice = &w[w_start..w_start + in_f];
475                out[y_start + j] = crate::dispatch::dispatch_dot_f64(x_slice, w_slice, ctx) + bias[j];
476            }
477        }
478    }
479
480    /// Layer normalization using dispatched summation for mean/variance.
481    #[inline]
482    pub fn layer_norm_dispatched(
483        data: &[f64], gamma: &[f64], beta: &[f64], out: &mut [f64],
484        outer: usize, n: usize, eps: f64,
485        ctx: &crate::dispatch::ReductionContext,
486    ) {
487        debug_assert_eq!(data.len(), outer * n);
488        debug_assert_eq!(gamma.len(), n);
489        debug_assert_eq!(beta.len(), n);
490        debug_assert_eq!(out.len(), outer * n);
491        for row in 0..outer {
492            let start = row * n;
493            let slice = &data[start..start + n];
494
495            let mean = crate::dispatch::dispatch_sum_f64(slice, ctx) / n as f64;
496
497            let diffs: Vec<f64> = slice.iter().map(|&x| (x - mean) * (x - mean)).collect();
498            let var = crate::dispatch::dispatch_sum_f64(&diffs, ctx) / n as f64;
499            let inv_std = 1.0 / (var + eps).sqrt();
500
501            for i in 0..n {
502                out[start + i] = (slice[i] - mean) * inv_std * gamma[i] + beta[i];
503            }
504        }
505    }
506
507    /// 1D convolution using dispatched summation.
508    pub fn conv1d_dispatched(
509        signal: &[f64], filters: &[f64], bias: &[f64], out: &mut [f64],
510        signal_len: usize, out_channels: usize, kernel_size: usize,
511        ctx: &crate::dispatch::ReductionContext,
512    ) {
513        debug_assert!(signal_len >= kernel_size);
514        let out_len = signal_len - kernel_size + 1;
515        for ch in 0..out_channels {
516            let filter_start = ch * kernel_size;
517            let filter_slice = &filters[filter_start..filter_start + kernel_size];
518            let out_row_start = ch * out_len;
519            for pos in 0..out_len {
520                let sig_slice = &signal[pos..pos + kernel_size];
521                out[out_row_start + pos] =
522                    crate::dispatch::dispatch_dot_f64(sig_slice, filter_slice, ctx) + bias[ch];
523            }
524        }
525    }
526}
527