// cjc_runtime/kernel_bridge.rs
//! Raw-pointer kernel bridge -- bypass interpreter overhead for hot loops.
//!
//! Provides zero-overhead numerical kernels that operate directly on
//! contiguous `&[f64]` slices, avoiding the per-element dispatch cost of
//! the `Value::Tensor` wrapper. The interpreter resolves tensor pointers
//! once at call entry, then dispatches to these kernels for the inner loop.
//!
//! # Determinism
//!
//! Every reduction (dot product, summation, normalization) uses either
//! [`KahanAccumulatorF64`] or [`BinnedAccumulatorF64`] to guarantee
//! deterministic floating-point results regardless of thread count or
//! input ordering.
//!
//! # NoGC guarantee
//!
//! All kernels operate on caller-provided buffers with no heap allocation
//! in the hot path. Output buffers are always pre-allocated by the caller.
//!
//! # Dispatched variants
//!
//! Each kernel has a `_dispatched` variant that accepts a
//! [`ReductionContext`](crate::dispatch::ReductionContext) to select
//! between Kahan and Binned accumulation strategies at runtime.

// ---------------------------------------------------------------------------
// 2d. Raw-Pointer Kernel Bridge — bypass interpreter overhead for hot loops
// ---------------------------------------------------------------------------

/// Raw-pointer kernel functions that operate directly on f64 slices.
///
/// These bypass the `Value::Tensor` wrapper, operating on contiguous `&[f64]`
/// data. The interpreter resolves tensor pointers once at call entry, then
/// dispatches to these zero-overhead kernels.
///
/// All functions here are safe Rust -- they accept slices, not raw pointers.
/// The "raw pointer" concept means: the interpreter does one `to_vec()` or
/// `buffer.borrow()` at entry, then passes the contiguous slice through.
39pub mod kernel {
40    use cjc_repro::{kahan_sum_f64, KahanAccumulatorF64};
41
42    /// Matrix multiply: C[m,n] = A[m,k] × B[k,n] with Kahan-summed dots.
43    ///
44    /// `a`, `b` are row-major contiguous slices; `c` is the output buffer
45    /// (must be pre-allocated to `m * n`).
46    ///
47    /// Uses in-place `KahanAccumulatorF64` — zero heap allocation per dot product.
48    #[inline]
49    pub fn matmul_raw(
50        a: &[f64], b: &[f64], c: &mut [f64],
51        m: usize, k: usize, n: usize,
52    ) {
53        debug_assert_eq!(a.len(), m * k);
54        debug_assert_eq!(b.len(), k * n);
55        debug_assert_eq!(c.len(), m * n);
56        for i in 0..m {
57            for j in 0..n {
58                let mut acc = KahanAccumulatorF64::new();
59                for p in 0..k {
60                    acc.add(a[i * k + p] * b[p * n + j]);
61                }
62                c[i * n + j] = acc.finalize();
63            }
64        }
65    }
66
67    /// Softmax over the last dimension of a contiguous buffer.
68    ///
69    /// `data` is the input (length = `outer * n`), `out` is the output.
70    /// Applies two-pass stable softmax per row of length `n`.
71    #[inline]
72    pub fn softmax_raw(data: &[f64], out: &mut [f64], outer: usize, n: usize) {
73        debug_assert_eq!(data.len(), outer * n);
74        debug_assert_eq!(out.len(), outer * n);
75        for row in 0..outer {
76            let start = row * n;
77            let slice = &data[start..start + n];
78
79            // Pass 1: max
80            let mut max_val = f64::NEG_INFINITY;
81            for &v in slice {
82                if v > max_val { max_val = v; }
83            }
84
85            // Pass 2: exp + Kahan sum
86            let mut sum = 0.0f64;
87            let mut comp = 0.0f64;
88            for i in 0..n {
89                let e = (slice[i] - max_val).exp();
90                out[start + i] = e;
91                let y = e - comp;
92                let t = sum + y;
93                comp = (t - sum) - y;
94                sum = t;
95            }
96
97            // Normalize
98            if sum == 0.0 {
99                let uniform = 1.0 / n as f64;
100                for i in 0..n {
101                    out[start + i] = uniform;
102                }
103            } else {
104                for i in 0..n {
105                    out[start + i] /= sum;
106                }
107            }
108        }
109    }
110
111    /// Linear projection: Y[outer, out_f] = X[outer, in_f] @ W^T[out_f, in_f] + bias[out_f]
112    #[inline]
113    pub fn linear_raw(
114        x: &[f64], w: &[f64], bias: &[f64], out: &mut [f64],
115        outer: usize, in_f: usize, out_f: usize,
116    ) {
117        debug_assert_eq!(x.len(), outer * in_f);
118        debug_assert_eq!(w.len(), out_f * in_f);
119        debug_assert_eq!(bias.len(), out_f);
120        debug_assert_eq!(out.len(), outer * out_f);
121        for row in 0..outer {
122            let x_start = row * in_f;
123            let x_slice = &x[x_start..x_start + in_f];
124            let y_start = row * out_f;
125            for j in 0..out_f {
126                let w_start = j * in_f;
127                let mut acc = KahanAccumulatorF64::new();
128                for p in 0..in_f {
129                    acc.add(x_slice[p] * w[w_start + p]);
130                }
131                out[y_start + j] = acc.finalize() + bias[j];
132            }
133        }
134    }
135
136    /// Layer normalization over the last dimension.
137    ///
138    /// For each row of length `n`: normalize to mean=0, var=1, then
139    /// scale by gamma and shift by beta.
140    #[inline]
141    pub fn layer_norm_raw(
142        data: &[f64], gamma: &[f64], beta: &[f64], out: &mut [f64],
143        outer: usize, n: usize, eps: f64,
144    ) {
145        debug_assert_eq!(data.len(), outer * n);
146        debug_assert_eq!(gamma.len(), n);
147        debug_assert_eq!(beta.len(), n);
148        debug_assert_eq!(out.len(), outer * n);
149        for row in 0..outer {
150            let start = row * n;
151            let slice = &data[start..start + n];
152
153            // Mean (Kahan)
154            let mean = kahan_sum_f64(slice) / n as f64;
155
156            // Variance (Kahan)
157            let diffs: Vec<f64> = slice.iter().map(|&x| (x - mean) * (x - mean)).collect();
158            let var = kahan_sum_f64(&diffs) / n as f64;
159            let inv_std = 1.0 / (var + eps).sqrt();
160
161            for i in 0..n {
162                out[start + i] = (slice[i] - mean) * inv_std * gamma[i] + beta[i];
163            }
164        }
165    }
166
167    /// ReLU: max(0, x) element-wise.
168    #[inline]
169    pub fn relu_raw(data: &[f64], out: &mut [f64]) {
170        debug_assert_eq!(data.len(), out.len());
171        for (o, &x) in out.iter_mut().zip(data.iter()) {
172            *o = if x > 0.0 { x } else { 0.0 };
173        }
174    }
175
176    /// Approximate GELU: x * 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
177    #[inline]
178    pub fn gelu_raw(data: &[f64], out: &mut [f64]) {
179        debug_assert_eq!(data.len(), out.len());
180        let sqrt_2_over_pi: f64 = (2.0 / std::f64::consts::PI).sqrt();
181        for (o, &x) in out.iter_mut().zip(data.iter()) {
182            let inner = sqrt_2_over_pi * (x + 0.044715 * x * x * x);
183            *o = 0.5 * x * (1.0 + inner.tanh());
184        }
185    }
186
187    /// 1D convolution with stride=1, no padding ("valid" mode).
188    ///
189    /// `signal`: input signal of length `signal_len`.
190    /// `filters`: `[out_channels, kernel_size]` row-major.
191    /// `bias`: per-channel bias of length `out_channels`.
192    /// `out`: output buffer `[out_channels, out_len]`, `out_len = signal_len - kernel_size + 1`.
193    /// Uses Kahan summation for deterministic dot products.
194    pub fn conv1d_raw(
195        signal: &[f64], filters: &[f64], bias: &[f64], out: &mut [f64],
196        signal_len: usize, out_channels: usize, kernel_size: usize,
197    ) {
198        debug_assert!(signal_len >= kernel_size);
199        let out_len = signal_len - kernel_size + 1;
200        debug_assert_eq!(signal.len(), signal_len);
201        debug_assert_eq!(filters.len(), out_channels * kernel_size);
202        debug_assert_eq!(bias.len(), out_channels);
203        debug_assert_eq!(out.len(), out_channels * out_len);
204
205        for ch in 0..out_channels {
206            let filter_start = ch * kernel_size;
207            let filter_slice = &filters[filter_start..filter_start + kernel_size];
208            let out_row_start = ch * out_len;
209            for pos in 0..out_len {
210                let products: Vec<f64> = (0..kernel_size)
211                    .map(|k| signal[pos + k] * filter_slice[k])
212                    .collect();
213                out[out_row_start + pos] = kahan_sum_f64(&products) + bias[ch];
214            }
215        }
216    }
217
218    /// 1D convolution on a sliding window of a circular buffer.
219    ///
220    /// Extracts the most recent `window_size` samples from `buffer`
221    /// (handling wrap-around at `write_pos`) into `window`, then
222    /// delegates to `conv1d_raw`.
223    pub fn conv1d_circular(
224        buffer: &[f64], write_pos: usize, window_size: usize,
225        window: &mut [f64],
226        filters: &[f64], bias: &[f64], out: &mut [f64],
227        out_channels: usize, kernel_size: usize,
228    ) {
229        let buf_len = buffer.len();
230        debug_assert!(window_size <= buf_len);
231        debug_assert_eq!(window.len(), window_size);
232
233        let start = if write_pos >= window_size {
234            write_pos - window_size
235        } else {
236            buf_len - (window_size - write_pos)
237        };
238        for i in 0..window_size {
239            window[i] = buffer[(start + i) % buf_len];
240        }
241
242        conv1d_raw(window, filters, bias, out, window_size, out_channels, kernel_size);
243    }
244
    // -- Phase 7: 2D Spatial Kernels ------------------------------------------

247    /// 2D convolution -- NCHW layout, valid mode (no padding), configurable stride.
248    ///
249    /// # Layout
250    /// - `input`:   `[N, C_in, H_in, W_in]`  row-major contiguous
251    /// - `filters`: `[C_out, C_in, kH, kW]`  row-major contiguous
252    /// - `bias`:    `[C_out]`
253    /// - `out`:     `[N, C_out, H_out, W_out]`  pre-allocated by caller
254    ///
255    /// where `H_out = (H_in - kH) / stride + 1` and `W_out = (W_in - kW) / stride + 1`.
256    ///
257    /// # Numerical contract
258    /// Every kernel-to-patch dot product uses `BinnedAccumulatorF64`, guaranteeing
259    /// bit-identical results regardless of stride, batch size, or channel count.
260    ///
261    /// # NoGC guarantee
262    /// All index arithmetic uses `u64` before narrowing to `usize`, preventing
263    /// overflow for high-resolution inputs (e.g., 8192×8192). The output buffer
264    /// is caller-allocated; this function performs zero heap allocations.
265    #[allow(clippy::too_many_arguments)]
266    pub fn conv2d_raw(
267        input:   &[f64],
268        filters: &[f64],
269        bias:    &[f64],
270        out:     &mut [f64],
271        n: usize, c_in: usize, h_in: usize, w_in: usize,
272        c_out: usize, kh: usize, kw: usize,
273        stride: usize,
274    ) {
275        use crate::accumulator::BinnedAccumulatorF64;
276
277        let h_out: u64 = ((h_in - kh) / stride + 1) as u64;
278        let w_out: u64 = ((w_in - kw) / stride + 1) as u64;
279
280        // Strides in the input tensor (NCHW, row-major).
281        let s_n:   u64 = (c_in  * h_in * w_in) as u64;
282        let s_cin: u64 = (h_in  * w_in) as u64;
283        let s_hin: u64 = w_in as u64;
284
285        // Strides in the filter tensor [C_out, C_in, kH, kW].
286        let f_cout: u64 = (c_in * kh * kw) as u64;
287        let f_cin:  u64 = (kh * kw) as u64;
288        let f_kh:   u64 = kw as u64;
289
290        // Strides in the output tensor (NCHW).
291        let o_n:    u64 = c_out as u64 * h_out * w_out;
292        let o_cout: u64 = h_out * w_out;
293
294        debug_assert_eq!(input.len(),   n * c_in  * h_in * w_in);
295        debug_assert_eq!(filters.len(), c_out * c_in * kh * kw);
296        debug_assert_eq!(bias.len(),    c_out);
297        debug_assert_eq!(out.len(),     n * c_out * h_out as usize * w_out as usize);
298
299        for bn in 0..n as u64 {
300            for co in 0..c_out as u64 {
301                for oh in 0..h_out {
302                    for ow in 0..w_out {
303                        let mut acc = BinnedAccumulatorF64::new();
304
305                        // Sum over input channels and kernel spatial extent.
306                        for ci in 0..c_in as u64 {
307                            for ki in 0..kh as u64 {
308                                for kj in 0..kw as u64 {
309                                    let ih: u64 = oh * stride as u64 + ki;
310                                    let iw: u64 = ow * stride as u64 + kj;
311
312                                    let inp_idx = (bn  * s_n
313                                                 + ci  * s_cin
314                                                 + ih  * s_hin
315                                                 + iw) as usize;
316                                    let flt_idx = (co  * f_cout
317                                                 + ci  * f_cin
318                                                 + ki  * f_kh
319                                                 + kj) as usize;
320
321                                    acc.add(input[inp_idx] * filters[flt_idx]);
322                                }
323                            }
324                        }
325
326                        let out_idx = (bn * o_n
327                                     + co * o_cout
328                                     + oh * w_out
329                                     + ow) as usize;
330                        out[out_idx] = acc.finalize() + bias[co as usize];
331                    }
332                }
333            }
334        }
335    }
336
337    /// 2D convolution using dispatched summation strategy.
338    ///
339    /// Identical to `conv2d_raw` but selects Kahan or Binned based on the
340    /// reduction context.  Useful when callers want runtime-configurable
341    /// accumulation precision.
342    #[allow(clippy::too_many_arguments)]
343    pub fn conv2d_dispatched(
344        input:   &[f64],
345        filters: &[f64],
346        bias:    &[f64],
347        out:     &mut [f64],
348        n: usize, c_in: usize, h_in: usize, w_in: usize,
349        c_out: usize, kh: usize, kw: usize,
350        stride: usize,
351        ctx: &crate::dispatch::ReductionContext,
352    ) {
353        let h_out = (h_in - kh) / stride + 1;
354        let w_out = (w_in - kw) / stride + 1;
355
356        let s_n   = c_in  * h_in * w_in;
357        let s_cin = h_in  * w_in;
358        let s_hin = w_in;
359
360        let f_cout = c_in * kh * kw;
361        let f_cin  = kh * kw;
362        let f_kh   = kw;
363
364        let o_n    = c_out * h_out * w_out;
365        let o_cout = h_out * w_out;
366
367        for bn in 0..n {
368            for co in 0..c_out {
369                for oh in 0..h_out {
370                    for ow in 0..w_out {
371                        let mut terms = Vec::with_capacity(c_in * kh * kw);
372                        for ci in 0..c_in {
373                            for ki in 0..kh {
374                                for kj in 0..kw {
375                                    let ih = oh * stride + ki;
376                                    let iw = ow * stride + kj;
377                                    let inp_idx = bn * s_n  + ci * s_cin + ih * s_hin + iw;
378                                    let flt_idx = co * f_cout + ci * f_cin + ki * f_kh + kj;
379                                    terms.push(input[inp_idx] * filters[flt_idx]);
380                                }
381                            }
382                        }
383                        let out_idx = bn * o_n + co * o_cout + oh * w_out + ow;
384                        out[out_idx] =
385                            crate::dispatch::dispatch_sum_f64(&terms, ctx) + bias[co];
386                    }
387                }
388            }
389        }
390    }
391
392    /// 2D max-pooling — NCHW layout, stride = pool_size (non-overlapping).
393    ///
394    /// - `input`: `[N, C, H_in, W_in]`
395    /// - `out`:   `[N, C, H_in/ph, W_in/pw]`  (floor division, pre-allocated)
396    ///
397    /// All index arithmetic uses `u64` to support large spatial extents.
398    pub fn maxpool2d_raw(
399        input:  &[f64],
400        out:    &mut [f64],
401        n: usize, c: usize, h_in: usize, w_in: usize,
402        ph: usize, pw: usize,
403    ) {
404        let h_out: u64 = (h_in / ph) as u64;
405        let w_out: u64 = (w_in / pw) as u64;
406
407        let s_n:   u64 = (c * h_in * w_in) as u64;
408        let s_c:   u64 = (h_in * w_in) as u64;
409        let s_hin: u64 = w_in as u64;
410
411        let o_n:   u64 = (c as u64) * h_out * w_out;
412        let o_c:   u64 = h_out * w_out;
413
414        debug_assert_eq!(input.len(), n * c * h_in * w_in);
415        debug_assert_eq!(out.len(),   n * c * h_out as usize * w_out as usize);
416
417        for bn in 0..n as u64 {
418            for ch in 0..c as u64 {
419                for oh in 0..h_out {
420                    for ow in 0..w_out {
421                        let mut max_val = f64::NEG_INFINITY;
422                        for pi in 0..ph as u64 {
423                            for pj in 0..pw as u64 {
424                                let ih: u64 = oh * ph as u64 + pi;
425                                let iw: u64 = ow * pw as u64 + pj;
426                                let idx = (bn * s_n + ch * s_c + ih * s_hin + iw) as usize;
427                                let v = input[idx];
428                                if v > max_val { max_val = v; }
429                            }
430                        }
431                        let o_idx = (bn * o_n + ch * o_c + oh * w_out + ow) as usize;
432                        out[o_idx] = max_val;
433                    }
434                }
435            }
436        }
437    }
438
439    /// 1D max-pooling with stride equal to `pool_size` (non-overlapping windows).
440    ///
441    /// # Arguments
442    ///
443    /// * `data` - Input signal of length `data_len`.
444    /// * `out` - Pre-allocated output buffer of length `data_len / pool_size`.
445    /// * `data_len` - Length of the input signal.
446    /// * `pool_size` - Window (and stride) size for pooling.
447    pub fn maxpool1d_raw(data: &[f64], out: &mut [f64], data_len: usize, pool_size: usize) {
448        debug_assert_eq!(data.len(), data_len);
449        let out_len = data_len / pool_size;
450        debug_assert_eq!(out.len(), out_len);
451        for i in 0..out_len {
452            let start = i * pool_size;
453            let mut max_val = data[start];
454            for j in 1..pool_size {
455                let v = data[start + j];
456                if v > max_val { max_val = v; }
457            }
458            out[i] = max_val;
459        }
460    }

    // -- Dispatched kernel variants (Milestone 2.7) ---------------------------

464    /// Matrix multiply using dispatched summation strategy.
465    ///
466    /// Identical to `matmul_raw` but uses `dispatch_dot_f64` for each dot product,
467    /// selecting Kahan or Binned based on the reduction context.
468    #[inline]
469    pub fn matmul_dispatched(
470        a: &[f64], b: &[f64], c: &mut [f64],
471        m: usize, k: usize, n: usize,
472        ctx: &crate::dispatch::ReductionContext,
473    ) {
474        debug_assert_eq!(a.len(), m * k);
475        debug_assert_eq!(b.len(), k * n);
476        debug_assert_eq!(c.len(), m * n);
477        for i in 0..m {
478            for j in 0..n {
479                // Collect column from B for the dot product.
480                let a_row = &a[i * k..(i + 1) * k];
481                let b_col: Vec<f64> = (0..k).map(|p| b[p * n + j]).collect();
482                c[i * n + j] = crate::dispatch::dispatch_dot_f64(a_row, &b_col, ctx);
483            }
484        }
485    }
486
487    /// Linear projection using dispatched summation.
488    #[inline]
489    pub fn linear_dispatched(
490        x: &[f64], w: &[f64], bias: &[f64], out: &mut [f64],
491        outer: usize, in_f: usize, out_f: usize,
492        ctx: &crate::dispatch::ReductionContext,
493    ) {
494        debug_assert_eq!(x.len(), outer * in_f);
495        debug_assert_eq!(w.len(), out_f * in_f);
496        debug_assert_eq!(bias.len(), out_f);
497        debug_assert_eq!(out.len(), outer * out_f);
498        for row in 0..outer {
499            let x_start = row * in_f;
500            let x_slice = &x[x_start..x_start + in_f];
501            let y_start = row * out_f;
502            for j in 0..out_f {
503                let w_start = j * in_f;
504                let w_slice = &w[w_start..w_start + in_f];
505                out[y_start + j] = crate::dispatch::dispatch_dot_f64(x_slice, w_slice, ctx) + bias[j];
506            }
507        }
508    }
509
510    /// Layer normalization using dispatched summation for mean/variance.
511    #[inline]
512    pub fn layer_norm_dispatched(
513        data: &[f64], gamma: &[f64], beta: &[f64], out: &mut [f64],
514        outer: usize, n: usize, eps: f64,
515        ctx: &crate::dispatch::ReductionContext,
516    ) {
517        debug_assert_eq!(data.len(), outer * n);
518        debug_assert_eq!(gamma.len(), n);
519        debug_assert_eq!(beta.len(), n);
520        debug_assert_eq!(out.len(), outer * n);
521        for row in 0..outer {
522            let start = row * n;
523            let slice = &data[start..start + n];
524
525            let mean = crate::dispatch::dispatch_sum_f64(slice, ctx) / n as f64;
526
527            let diffs: Vec<f64> = slice.iter().map(|&x| (x - mean) * (x - mean)).collect();
528            let var = crate::dispatch::dispatch_sum_f64(&diffs, ctx) / n as f64;
529            let inv_std = 1.0 / (var + eps).sqrt();
530
531            for i in 0..n {
532                out[start + i] = (slice[i] - mean) * inv_std * gamma[i] + beta[i];
533            }
534        }
535    }
536
537    /// 1D convolution using dispatched summation.
538    pub fn conv1d_dispatched(
539        signal: &[f64], filters: &[f64], bias: &[f64], out: &mut [f64],
540        signal_len: usize, out_channels: usize, kernel_size: usize,
541        ctx: &crate::dispatch::ReductionContext,
542    ) {
543        debug_assert!(signal_len >= kernel_size);
544        let out_len = signal_len - kernel_size + 1;
545        for ch in 0..out_channels {
546            let filter_start = ch * kernel_size;
547            let filter_slice = &filters[filter_start..filter_start + kernel_size];
548            let out_row_start = ch * out_len;
549            for pos in 0..out_len {
550                let sig_slice = &signal[pos..pos + kernel_size];
551                out[out_row_start + pos] =
552                    crate::dispatch::dispatch_dot_f64(sig_slice, filter_slice, ctx) + bias[ch];
553            }
554        }
555    }
556}
557