Skip to main content

gam_terms/basis/
streaming_design.rs

1use super::*;
2
/// Which radial kernel family is being used. Stored in the streaming operator
/// so that (q, t) scalars can be recomputed on the fly without a closure.
///
/// Each variant carries the parameters read by the methods of this type
/// (radial-scalar evaluation, collision smoothness, isotropic ψ share).
#[derive(Debug, Clone)]
pub enum RadialScalarKind {
    /// Matern kernel: (length_scale, nu).
    Matern { length_scale: f64, nu: MaternNu },
    /// Hybrid Duchon kernel: parameters needed for `duchon_radial_jets`.
    Duchon {
        length_scale: f64,
        p_order: usize,
        s_order: usize,
        dim: usize,
        coeffs: DuchonPartialFractionCoeffs,
    },
    /// Pure Duchon kernel: a single intrinsic polyharmonic block.
    ///
    /// `block_order` and `dim` drive the radial evaluation; `p_order`,
    /// `s_order` and `dim` drive the scaling exponent used for the
    /// collision-smoothness test and the isotropic ψ share.
    PureDuchon {
        block_order: usize,
        p_order: usize,
        s_order: usize,
        dim: usize,
    },
    /// Thin-Plate Spline kernel: isotropic with a scalar length-scale, used
    /// only with `n_axes = 1` (`ScalarTotal` streaming mode). The chain rule
    /// for ψ = log κ = -log(length_scale) and r̃ = ‖x − c‖ gives
    /// ∂φ/∂ψ   = φ'(z)·z
    /// ∂²φ/∂ψ² = φ''(z)·z² + φ'(z)·z
    /// where z = r̃ / length_scale. With the operator's c=0, s_0 = r̃²,
    /// `shape_0 = q·s_0` and `shape_00 = t·s_0² + 2 q s_0` reproduce both
    /// derivatives exactly when q = φ'(r̃)/r̃ and t = (φ''(r̃) − q)/r̃².
    ThinPlate { length_scale: f64, dim: usize },
}
34
35impl RadialScalarKind {
    /// Report whether the radial scalars stay finite at a center collision.
    ///
    /// Returns `true` iff both `q = φ'(r)/r` and `t = (φ''(r) − q)/r²` have
    /// finite limits as `r → 0+` for this kernel. When this returns `false`
    /// the design-row gradient/Hessian at a center collision (`r = 0`) is not
    /// defined by a single finite value; callers must either move off the
    /// collision or surface a `BasisError::DegenerateAtCollision`.
    ///
    /// Smoothness criteria used here (matching the analytic limits derived
    /// in this file and the comments on `eval_design_triplet`):
    ///   - Matérn ν = 1/2: `q = -s·E/r → -∞`, not smooth.
    ///   - Matérn ν = 3/2: `q` finite but `t = s³E/r → ∞`, not smooth.
    ///   - Matérn ν = 5/2, 7/2, 9/2: both finite, smooth.
    ///   - Duchon hybrid (`Duchon`): finite via the hybrid PFD identity;
    ///     the radial-jets routine produces a finite limit, so smooth.
    ///   - PureDuchon (raw polyharmonic block, exponent α = 2m − d):
    ///       non-log case and α ≥ 4 ⇒ both `q` and `t` vanish (smooth);
    ///       log case at any α, or α < 4 ⇒ at least one derivative diverges.
    ///   - ThinPlate dim = 1: φ = r³, `q = 3r → 0`, but `t = 3/r → ∞`. The
    ///     1-D Hessian formula `q·δ + t·s·s` at r = 0 has the only diagonal
    ///     entry contracted by `s_a = 0`, but the bare scalar limit is still
    ///     not finite, so we report it as non-smooth and let callers in 1-D
    ///     (where `s_a` literally vanishes) opt in by handling the error.
    ///     Dim 2 (log r), Dim 3 (-r) both diverge.
    #[inline]
    pub(crate) fn is_smooth_at_collision(&self) -> bool {
        match self {
            RadialScalarKind::Matern { nu, .. } => matches!(
                nu,
                MaternNu::FiveHalves | MaternNu::SevenHalves | MaternNu::NineHalves
            ),
            RadialScalarKind::Duchon { .. } => true,
            RadialScalarKind::PureDuchon {
                p_order,
                s_order,
                dim,
                ..
            } => {
                let alpha = duchon_scaling_exponent(*p_order, *s_order, *dim);
                // Log case: even dimension and α a non-negative even integer
                // (tested up to a 1e-12 tolerance).
                let is_log = (*dim) % 2 == 0 && {
                    let half = (alpha / 2.0).round();
                    half >= 0.0 && (half * 2.0 - alpha).abs() < 1e-12
                };
                !is_log && alpha >= 4.0
            }
            RadialScalarKind::ThinPlate { .. } => false,
        }
    }

    /// Evaluate the `(phi, q, t)` radial scalars for a given distance `r`.
    ///
    /// `q = φ'(r)/r` and `t = (φ''(r) - q)/r²` (with the appropriate
    /// finite limits at `r → 0`). This is exactly the scalar pair needed
    /// to assemble the first and second derivatives of `Φ(x) = φ(‖x − c‖)`
    /// with respect to the input location `x`:
    ///
    /// ```text
    /// ∂Φ/∂x_a       = q · (x − c)_a
    /// ∂²Φ/∂x_a∂x_b  = q · δ_ab + t · (x − c)_a (x − c)_b
    /// ```
    ///
    /// This re-points the existing ψ-derivative machinery at the first kernel
    /// argument, the input location (see `crate::basis::input_loc_derivatives`).
    ///
    /// # Errors
    ///
    /// Propagates any `BasisError` from the per-family radial routine. For
    /// `PureDuchon` with `r < 1e-14` it returns
    /// `BasisError::DegenerateAtCollision` unless
    /// `is_smooth_at_collision()` holds, in which case `(phi, 0, 0)` is
    /// returned. `ThinPlate` with `r < 1e-14` returns `(0, 0, 0)` rather than
    /// an error, for the reason given in the comment on that arm.
    pub fn eval_design_triplet(&self, r: f64) -> Result<(f64, f64, f64), BasisError> {
        match self {
            RadialScalarKind::Matern { length_scale, nu } => {
                // Only the first three of the five extended scalars are needed.
                let (phi, q, t, _, _) =
                    matern_aniso_extended_radial_scalars(r, *length_scale, *nu)?;
                Ok((phi, q, t))
            }
            RadialScalarKind::Duchon {
                length_scale,
                p_order,
                s_order,
                dim,
                coeffs,
            } => {
                let jets = duchon_radial_jets(r, *length_scale, *p_order, *s_order, *dim, coeffs)?;
                Ok((jets.phi, jets.q, jets.t))
            }
            RadialScalarKind::PureDuchon {
                block_order, dim, ..
            } => {
                let phi = polyharmonic_kernel(r, (*block_order) as f64, *dim);
                if r < 1e-14 {
                    // Collision: q = φ'/r and t = (φ'' − q)/r² generally
                    // diverge here. Only the non-log, α = 2m − d ≥ 4 case
                    // gives finite limits (both 0). Otherwise the design
                    // gradient/Hessian at r = 0 is undefined: surface a
                    // `DegenerateAtCollision` so callers can detect it.
                    if !self.is_smooth_at_collision() {
                        return Err(BasisError::DegenerateAtCollision {
                            kernel: "PureDuchon (polyharmonic)",
                            dim: *dim,
                            m: *block_order as f64,
                            message: "raw polyharmonic block φ(r) = c r^α (log r) is \
                                      not C² at r = 0 for α = 2m − d < 4 or for log \
                                      cases; first/second radial derivatives diverge",
                        });
                    }
                    return Ok((phi, 0.0, 0.0));
                }
                // Away from the collision only the first two jets are used.
                let (q, t, _, _) =
                    duchon_polyharmonic_operator_block_jets(r, *block_order as f64, *dim)?;
                Ok((phi, q, t))
            }
            RadialScalarKind::ThinPlate { length_scale, dim } => {
                // (q, t) individually diverge at r = 0 for ThinPlate
                // (q = 2 log r + 1 → −∞ in dim 2, q = −1/r → −∞ in dim 3,
                // t = 3/r → ∞ in dim 1, …) but the chain-rule coefficient
                // `c = raw_psi_isotropic_share` is 0 for ThinPlate, so every
                // consumer multiplies q by a squared displacement s_a and t
                // by s_a · s_b before use (design row uses φ alone, and
                // φ(0) = 0). The products
                //   q · s_a = (φ'(r) · r) · (s_a / r²),
                //   t · s_a · s_b = (φ''(r) · r² − φ'(r) · r) · (s_a/r²)·(s_b/r²)
                // both vanish as r → 0+, since r · φ'(r) → 0 and r² · φ''(r) → 0
                // for every standard TPS kernel (φ = r³ in dim 1, r² log r in
                // dim 2, −r in dim 3, and the general polyharmonic case for
                // d ≥ 4) and the ratios s_a/r², s_b/r² are bounded. The
                // closed-form ψ-derivative limit at the collision is
                // therefore (0, 0, 0).
                if r < 1e-14 {
                    return Ok((0.0, 0.0, 0.0));
                }
                let scaled_r = r / *length_scale;
                let (phi, phi_kernel_first, phi_kernel_second) =
                    thin_plate_kernel_triplet_from_scaled_distance(scaled_r, *dim)?;
                // The implicit operator uses derivatives w.r.t. the unscaled r
                // (the operator's chain rule will rescale them to ψ-derivatives
                // via s_0 = r²). Convert φ'(z), φ''(z) → φ'(r), φ''(r) by the
                // length-scale chain rule:
                //   φ'(r)  = φ'(z) / length_scale
                //   φ''(r) = φ''(z) / length_scale²
                let phi_r = phi_kernel_first / *length_scale;
                let phi_rr = phi_kernel_second / (*length_scale * *length_scale);
                let q = phi_r / r;
                let t = (phi_rr - q) / (r * r);
                Ok((phi, q, t))
            }
        }
    }
177
178    #[inline]
179    pub(crate) fn raw_psi_isotropic_share(&self) -> f64 {
180        match self {
181            RadialScalarKind::Matern { .. } => 0.0,
182            RadialScalarKind::Duchon {
183                p_order,
184                s_order,
185                dim,
186                ..
187            } => duchon_scaling_exponent(*p_order, *s_order, *dim) / *dim as f64,
188            RadialScalarKind::PureDuchon {
189                p_order,
190                s_order,
191                dim,
192                ..
193            } => duchon_scaling_exponent(*p_order, *s_order, *dim) / *dim as f64,
194            // ThinPlate is a pure radial kernel φ(z) with no κ^δ prefactor;
195            // the chain rule has no isotropic share term.
196            RadialScalarKind::ThinPlate { .. } => 0.0,
197        }
198    }
199
200    #[inline]
201    pub(crate) fn is_duchon_family(&self) -> bool {
202        matches!(
203            self,
204            RadialScalarKind::Duchon { .. } | RadialScalarKind::PureDuchon { .. }
205        )
206    }
207
208    /// Whether the radial-kind enforces a hard guard against accidental
209    /// dense `(n × p)` ψ-derivative materialization. Duchon-family terms
210    /// always do (they are streaming-only at any scale). ThinPlate joins
211    /// the guard list because the new scalar-streaming routing makes it
212    /// genuine to rely on the implicit operator at large scale, and a
213    /// downstream consumer that sneaks in a `materialize_dense()` call
214    /// would silently re-introduce the same `n × p` allocation we wired
215    /// streaming to avoid. The guard panics only when the resource
216    /// policy says the materialization would exceed budget — small `n`
217    /// problems still get the dense fast path.
218    #[inline]
219    pub(crate) fn enforces_dense_materialization_budget(&self) -> bool {
220        matches!(
221            self,
222            RadialScalarKind::Duchon { .. }
223                | RadialScalarKind::PureDuchon { .. }
224                | RadialScalarKind::ThinPlate { .. }
225        )
226    }
227}
228
229/// Shared chunked-operator machinery for the streaming basis evaluators.
230///
231/// `StreamingMaternEvaluator` and `StreamingBSplineEvaluator` differ only in
232/// how a single row chunk of the design is materialized (`for_row_chunk`) and
233/// the chunk size policy (`chunk_rows`); every other operator method — the
234/// chunked matvec, transpose-matvec, weighted Gram and dense materialization —
235/// is identical boilerplate over that one primitive. This trait carries those
236/// shared methods as defaults keyed on `for_row_chunk`, so each evaluator
237/// implements only the per-basis pieces. The `NAME` const bakes the struct
238/// name into the panic/error strings so diagnostics stay per-evaluator.
239trait ChunkedDesign {
240    /// Struct name used in assertion / error messages.
241    const NAME: &'static str;
242
243    /// Number of design rows (observations).
244    fn op_nrows(&self) -> usize;
245
246    /// Number of design columns (basis functions after any transform).
247    fn op_ncols(&self) -> usize;
248
249    /// Row-block size used to bound the per-chunk working set.
250    fn chunk_rows(&self) -> usize;
251
252    /// Materialize the dense design rows `[start, end)` — the only genuinely
253    /// per-evaluator computation.
254    fn for_row_chunk(&self, start: usize, end: usize) -> Array2<f64>;
255
256    /// Chunked matvec `output = X · theta`.
257    fn chunked_gradient_into(&self, theta: ArrayView1<'_, f64>, output: &mut Array1<f64>) {
258        assert_eq!(
259            theta.len(),
260            self.op_ncols(),
261            "{} theta width mismatch",
262            Self::NAME
263        );
264        assert_eq!(
265            output.len(),
266            self.op_nrows(),
267            "{} output length mismatch",
268            Self::NAME
269        );
270        output.fill(0.0);
271        let nrows = self.op_nrows();
272        for start in (0..nrows).step_by(self.chunk_rows()) {
273            let end = (start + self.chunk_rows()).min(nrows);
274            let chunk = self.for_row_chunk(start, end);
275            let values = chunk.dot(&theta);
276            output.slice_mut(s![start..end]).assign(&values);
277        }
278    }
279
280    /// Chunked matvec returning a fresh vector (`LinearOperator::apply`).
281    fn chunked_apply(&self, vector: &Array1<f64>) -> Array1<f64> {
282        let mut out = Array1::<f64>::zeros(self.op_nrows());
283        self.chunked_gradient_into(vector.view(), &mut out);
284        out
285    }
286
287    /// Chunked transpose-matvec `out = Xᵀ · vector`
288    /// (`LinearOperator::apply_transpose`).
289    fn chunked_apply_transpose(&self, vector: &Array1<f64>) -> Array1<f64> {
290        assert_eq!(
291            vector.len(),
292            self.op_nrows(),
293            "{} transpose vector length mismatch",
294            Self::NAME
295        );
296        let nrows = self.op_nrows();
297        let mut out = Array1::<f64>::zeros(self.op_ncols());
298        for start in (0..nrows).step_by(self.chunk_rows()) {
299            let end = (start + self.chunk_rows()).min(nrows);
300            let chunk = self.for_row_chunk(start, end);
301            let partial = chunk.t().dot(&vector.slice(s![start..end]));
302            out += &partial;
303        }
304        out
305    }
306
307    /// Chunked weighted Gram `XᵀWX` (`LinearOperator::diag_xtw_x`).
308    fn chunked_diag_xtw_x(&self, weights: &Array1<f64>) -> Result<Array2<f64>, String>
309    where
310        Self: Sync,
311    {
312        let nrows = self.op_nrows();
313        if weights.len() != nrows {
314            return Err(format!(
315                "{} diag_xtw_x weight length mismatch: weights={}, nrows={}",
316                Self::NAME,
317                weights.len(),
318                nrows
319            ));
320        }
321        let p = self.op_ncols();
322        let chunk_rows = self.chunk_rows();
323        let starts = (0..nrows).step_by(chunk_rows).collect::<Vec<_>>();
324        Ok(starts
325            .into_par_iter()
326            .fold(
327                || Array2::<f64>::zeros((p, p)),
328                |mut acc, start| {
329                    let end = (start + chunk_rows).min(nrows);
330                    let chunk = self.for_row_chunk(start, end);
331                    let mut weighted = chunk.clone();
332                    for local in 0..(end - start) {
333                        let w = weights[start + local];
334                        weighted.row_mut(local).mapv_inplace(|v| v * w);
335                    }
336                    acc += &chunk.t().dot(&weighted);
337                    acc
338                },
339            )
340            .reduce(
341                || Array2::<f64>::zeros((p, p)),
342                |mut a, b| {
343                    a += &b;
344                    a
345                },
346            ))
347    }
348
349    /// Chunked dense row fill (`DenseDesignOperator::row_chunk_into`).
350    fn chunked_row_chunk_into(
351        &self,
352        rows: Range<usize>,
353        mut out: ArrayViewMut2<'_, f64>,
354    ) -> Result<(), MatrixMaterializationError> {
355        if rows.end > self.op_nrows() || rows.start > rows.end {
356            return Err(MatrixMaterializationError::MissingRowChunk {
357                context: Self::ROW_RANGE_OOB,
358            });
359        }
360        if out.nrows() != rows.end - rows.start || out.ncols() != self.op_ncols() {
361            return Err(MatrixMaterializationError::MissingRowChunk {
362                context: Self::ROW_CHUNK_SHAPE_MISMATCH,
363            });
364        }
365        out.assign(&self.for_row_chunk(rows.start, rows.end));
366        Ok(())
367    }
368
369    /// Full dense materialization (`DenseDesignOperator::to_dense`).
370    fn chunked_to_dense(&self) -> Array2<f64> {
371        self.for_row_chunk(0, self.op_nrows())
372    }
373
374    /// Static `&str` context strings for the row-chunk errors — kept as
375    /// associated consts because `MatrixMaterializationError::MissingRowChunk`
376    /// stores `&'static str`, so a runtime-formatted name cannot be used.
377    const ROW_RANGE_OOB: &'static str;
378    const ROW_CHUNK_SHAPE_MISMATCH: &'static str;
379}
380
/// Streaming Matérn design operator: design rows are recomputed per row
/// chunk from `data` and `centers` instead of being stored densely.
#[derive(Debug, Clone)]
pub(crate) struct StreamingMaternEvaluator {
    /// Observation coordinates, one row per observation; stored in standard
    /// layout by `new` so it can be read as a flat slice.
    pub(crate) data: Arc<Array2<f64>>,
    /// Kernel centers, one row per center; same column count as `data`.
    pub(crate) centers: Arc<Array2<f64>>,
    /// Matérn length scale forwarded to `matern_kernel_from_distance`.
    pub(crate) length_scale: f64,
    /// Matérn smoothness forwarded to `matern_kernel_from_distance`.
    pub(crate) nu: MaternNu,
    /// Per-axis weights applied to squared coordinate differences; `new`
    /// sets them to `exp(2·η_a)` from the anisotropic log-scales, or to all
    /// ones when none are given.
    pub(crate) metric_weights: Arc<[f64]>,
    /// Optional identifiability transform applied to each raw kernel chunk
    /// via `fast_ab(&raw, z)`; its row count must equal the number of centers.
    pub(crate) ident_transform: Option<Arc<Array2<f64>>>,
    /// When set, a trailing column of ones is appended to every chunk.
    pub(crate) include_intercept: bool,
    /// Requested row-block size (at least 1 when built through `new`).
    pub(crate) chunk_size: usize,
    /// Total design columns: kernel columns plus the optional intercept.
    pub(crate) total_cols: usize,
}
393
394impl StreamingMaternEvaluator {
395    pub(crate) fn new(
396        data: Arc<Array2<f64>>,
397        centers: Arc<Array2<f64>>,
398        length_scale: f64,
399        nu: MaternNu,
400        aniso_log_scales: Option<Vec<f64>>,
401        ident_transform: Option<Arc<Array2<f64>>>,
402        include_intercept: bool,
403        chunk_size: Option<usize>,
404    ) -> Result<Self, String> {
405        if data.ncols() != centers.ncols() {
406            return Err(format!(
407                "StreamingMaternEvaluator: data dim {} != centers dim {}",
408                data.ncols(),
409                centers.ncols()
410            ));
411        }
412        let metric_weights = match aniso_log_scales {
413            Some(eta) => {
414                if eta.len() != data.ncols() {
415                    return Err(format!(
416                        "StreamingMaternEvaluator: aniso_log_scales len {} != data dim {}",
417                        eta.len(),
418                        data.ncols()
419                    ));
420                }
421                eta.into_iter().map(|v| (2.0 * v).exp()).collect::<Vec<_>>()
422            }
423            None => vec![1.0; data.ncols()],
424        };
425        if let Some(z) = ident_transform.as_ref()
426            && z.nrows() != centers.nrows()
427        {
428            return Err(format!(
429                "StreamingMaternEvaluator: identifiability transform rows {} != centers {}",
430                z.nrows(),
431                centers.nrows()
432            ));
433        }
434        let kernel_cols = ident_transform
435            .as_ref()
436            .map_or(centers.nrows(), |z| z.ncols());
437        Ok(Self {
438            data: Arc::new(data.as_standard_layout().to_owned()),
439            centers: Arc::new(centers.as_standard_layout().to_owned()),
440            length_scale,
441            nu,
442            metric_weights: Arc::from(metric_weights),
443            ident_transform,
444            include_intercept,
445            chunk_size: chunk_size.unwrap_or(DEFAULT_STREAMING_CHUNK_ROWS).max(1),
446            total_cols: kernel_cols + usize::from(include_intercept),
447        })
448    }
449
450    pub(crate) fn raw_kernel_chunk(&self, rows: Range<usize>) -> Array2<f64> {
451        let chunk_n = rows.end - rows.start;
452        let k_raw = self.centers.nrows();
453        let dim = self.data.ncols();
454        let data = self
455            .data
456            .as_slice()
457            .expect("StreamingMaternEvaluator stores standard-layout data");
458        let centers = self
459            .centers
460            .as_slice()
461            .expect("StreamingMaternEvaluator stores standard-layout centers");
462        let mut values = vec![0.0_f64; chunk_n * k_raw];
463        values
464            .par_chunks_mut(k_raw)
465            .enumerate()
466            .for_each(|(local, out_row)| {
467                let global = rows.start + local;
468                let x = &data[global * dim..(global + 1) * dim];
469                for j in 0..k_raw {
470                    let c = &centers[j * dim..(j + 1) * dim];
471                    let mut r2 = 0.0_f64;
472                    for axis in 0..dim {
473                        let h = x[axis] - c[axis];
474                        r2 += self.metric_weights[axis] * h * h;
475                    }
476                    out_row[j] = matern_kernel_from_distance(r2.sqrt(), self.length_scale, self.nu)
477                        .expect("validated Matérn inputs should not fail");
478                }
479            });
480        Array2::from_shape_vec((chunk_n, k_raw), values)
481            .expect("StreamingMaternEvaluator chunk shape should match generated values")
482    }
483
484    pub(crate) fn for_row_chunk_impl(&self, start: usize, end: usize) -> Array2<f64> {
485        let raw = self.raw_kernel_chunk(start..end);
486        let kernel = match self.ident_transform.as_ref() {
487            Some(z) => fast_ab(&raw, z),
488            None => raw,
489        };
490        if !self.include_intercept {
491            return kernel;
492        }
493        let mut out = Array2::<f64>::ones((end - start, kernel.ncols() + 1));
494        out.slice_mut(s![.., ..kernel.ncols()]).assign(&kernel);
495        out
496    }
497}
498
/// Per-basis pieces of the shared chunked operator for the Matérn evaluator.
impl ChunkedDesign for StreamingMaternEvaluator {
    const NAME: &'static str = "StreamingMaternEvaluator";
    const ROW_RANGE_OOB: &'static str = "StreamingMaternEvaluator row range out of bounds";
    const ROW_CHUNK_SHAPE_MISMATCH: &'static str =
        "StreamingMaternEvaluator row_chunk_into shape mismatch";

    /// One design row per row of `data`.
    fn op_nrows(&self) -> usize {
        self.data.nrows()
    }

    /// Kernel columns plus the optional intercept column.
    fn op_ncols(&self) -> usize {
        self.total_cols
    }

    /// Configured chunk size, capped at the row count (treated as 1 when
    /// there are no rows).
    fn chunk_rows(&self) -> usize {
        self.chunk_size.min(self.data.nrows().max(1))
    }

    /// Bounds-checked entry point; panics when `[start, end)` is inverted or
    /// extends past the data.
    fn for_row_chunk(&self, start: usize, end: usize) -> Array2<f64> {
        assert!(
            start <= end && end <= self.data.nrows(),
            "StreamingMaternEvaluator row chunk out of bounds"
        );
        self.for_row_chunk_impl(start, end)
    }
}
525
/// `LinearOperator` facade: every method delegates to the matching
/// `ChunkedDesign` default.
impl LinearOperator for StreamingMaternEvaluator {
    fn nrows(&self) -> usize {
        self.op_nrows()
    }

    fn ncols(&self) -> usize {
        self.op_ncols()
    }

    /// `X · vector`, computed chunk by chunk.
    fn apply(&self, vector: &Array1<f64>) -> Array1<f64> {
        self.chunked_apply(vector)
    }

    /// `Xᵀ · vector`, computed chunk by chunk.
    fn apply_transpose(&self, vector: &Array1<f64>) -> Array1<f64> {
        self.chunked_apply_transpose(vector)
    }

    /// Weighted Gram `XᵀWX`; errors when `weights` has the wrong length.
    fn diag_xtw_x(&self, weights: &Array1<f64>) -> Result<Array2<f64>, String> {
        self.chunked_diag_xtw_x(weights)
    }
}
547
/// `DenseDesignOperator` facade over the `ChunkedDesign` defaults.
impl DenseDesignOperator for StreamingMaternEvaluator {
    /// Fill `out` with design rows `rows`; range and shape are validated by
    /// `chunked_row_chunk_into`.
    fn row_chunk_into(
        &self,
        rows: Range<usize>,
        out: ArrayViewMut2<'_, f64>,
    ) -> Result<(), MatrixMaterializationError> {
        self.chunked_row_chunk_into(rows, out)
    }

    /// Materialize the full design matrix.
    fn to_dense(&self) -> Array2<f64> {
        self.chunked_to_dense()
    }
}
561
/// Streaming B-spline design operator: basis rows are rebuilt per row chunk
/// from `data` and `knots` instead of being stored densely.
#[derive(Debug, Clone)]
pub(crate) struct StreamingBSplineEvaluator {
    /// Observation values, one per design row.
    pub(crate) data: Arc<Array1<f64>>,
    /// Knot vector forwarded to `bspline_raw_row_chunk`.
    pub(crate) knots: Arc<Array1<f64>>,
    /// Spline degree forwarded to `bspline_raw_row_chunk`.
    pub(crate) degree: usize,
    /// Optional periodic specification forwarded unchanged to the B-spline
    /// helpers. NOTE(review): the meaning of the three tuple entries is not
    /// visible in this file — confirm against `bspline_raw_column_count`.
    pub(crate) periodic: Option<(f64, f64, usize)>,
    /// Optional transform applied to each raw chunk via `fast_ab(&raw, z)`;
    /// its row count must equal the raw basis column count.
    pub(crate) transform: Option<Arc<Array2<f64>>>,
    /// Requested row-block size (at least 1 when built through `new`).
    pub(crate) chunk_size: usize,
    /// Design columns after the optional transform.
    pub(crate) total_cols: usize,
}
572
573impl StreamingBSplineEvaluator {
574    pub(crate) fn new(
575        data: Arc<Array1<f64>>,
576        knots: Arc<Array1<f64>>,
577        degree: usize,
578        periodic: Option<(f64, f64, usize)>,
579        transform: Option<Arc<Array2<f64>>>,
580        chunk_size: Option<usize>,
581    ) -> Result<Self, String> {
582        let raw_cols = bspline_raw_column_count(knots.as_ref(), degree, periodic)?;
583        if let Some(z) = transform.as_ref()
584            && z.nrows() != raw_cols
585        {
586            return Err(format!(
587                "StreamingBSplineEvaluator: transform rows {} != raw basis columns {}",
588                z.nrows(),
589                raw_cols
590            ));
591        }
592        Ok(Self {
593            data: Arc::new(data.as_standard_layout().to_owned()),
594            knots: Arc::new(knots.as_standard_layout().to_owned()),
595            degree,
596            periodic,
597            total_cols: transform.as_ref().map_or(raw_cols, |z| z.ncols()),
598            transform,
599            chunk_size: chunk_size.unwrap_or(DEFAULT_STREAMING_CHUNK_ROWS).max(1),
600        })
601    }
602
603    pub(crate) fn raw_chunk(&self, start: usize, end: usize) -> Array2<f64> {
604        bspline_raw_row_chunk(
605            self.data.view(),
606            self.knots.view(),
607            self.degree,
608            self.periodic,
609            start,
610            end,
611        )
612        .expect("StreamingBSplineEvaluator validated inputs should build row chunks")
613    }
614
615    pub(crate) fn for_row_chunk_impl(&self, start: usize, end: usize) -> Array2<f64> {
616        let raw = self.raw_chunk(start, end);
617        match self.transform.as_ref() {
618            Some(z) => fast_ab(&raw, z),
619            None => raw,
620        }
621    }
622}
623
/// Per-basis pieces of the shared chunked operator for the B-spline evaluator.
impl ChunkedDesign for StreamingBSplineEvaluator {
    const NAME: &'static str = "StreamingBSplineEvaluator";
    const ROW_RANGE_OOB: &'static str = "StreamingBSplineEvaluator row range out of bounds";
    const ROW_CHUNK_SHAPE_MISMATCH: &'static str =
        "StreamingBSplineEvaluator row_chunk_into shape mismatch";

    /// One design row per entry of `data`.
    fn op_nrows(&self) -> usize {
        self.data.len()
    }

    /// Design columns after the optional transform.
    fn op_ncols(&self) -> usize {
        self.total_cols
    }

    /// Configured chunk size, capped at the row count (treated as 1 when
    /// there are no rows).
    fn chunk_rows(&self) -> usize {
        self.chunk_size.min(self.data.len().max(1))
    }

    /// Bounds-checked entry point; panics when `[start, end)` is inverted or
    /// extends past the data.
    fn for_row_chunk(&self, start: usize, end: usize) -> Array2<f64> {
        assert!(
            start <= end && end <= self.data.len(),
            "StreamingBSplineEvaluator row chunk out of bounds"
        );
        self.for_row_chunk_impl(start, end)
    }
}
650
/// `LinearOperator` facade: every method delegates to the matching
/// `ChunkedDesign` default.
impl LinearOperator for StreamingBSplineEvaluator {
    fn nrows(&self) -> usize {
        self.op_nrows()
    }

    fn ncols(&self) -> usize {
        self.op_ncols()
    }

    /// `X · vector`, computed chunk by chunk.
    fn apply(&self, vector: &Array1<f64>) -> Array1<f64> {
        self.chunked_apply(vector)
    }

    /// `Xᵀ · vector`, computed chunk by chunk.
    fn apply_transpose(&self, vector: &Array1<f64>) -> Array1<f64> {
        self.chunked_apply_transpose(vector)
    }

    /// Weighted Gram `XᵀWX`; errors when `weights` has the wrong length.
    fn diag_xtw_x(&self, weights: &Array1<f64>) -> Result<Array2<f64>, String> {
        self.chunked_diag_xtw_x(weights)
    }
}
672
/// `DenseDesignOperator` facade over the `ChunkedDesign` defaults.
impl DenseDesignOperator for StreamingBSplineEvaluator {
    /// Fill `out` with design rows `rows`; range and shape are validated by
    /// `chunked_row_chunk_into`.
    fn row_chunk_into(
        &self,
        rows: Range<usize>,
        out: ArrayViewMut2<'_, f64>,
    ) -> Result<(), MatrixMaterializationError> {
        self.chunked_row_chunk_into(rows, out)
    }

    /// Materialize the full design matrix.
    fn to_dense(&self) -> Array2<f64> {
        self.chunked_to_dense()
    }
}
686
/// How per-pair axis components are exposed to the derivative operator.
/// Both variants carry the per-axis metric weights.
#[derive(Debug, Clone)]
pub(crate) enum StreamingAxisMode {
    /// Per-axis anisotropic ψ_a derivatives: expose one `s_a` component per axis.
    PerAxis { metric_weights: Arc<[f64]> },
    /// Scalar ψ derivative: expose a single component equal to the total
    /// scaled squared radius r² = Σ_a exp(2η_a) h_a².
    ScalarTotal { metric_weights: Arc<[f64]> },
}

/// Data stored for streaming (on-the-fly) recomputation of radial jet scalars.
/// Instead of persisting O(n*k*(d+2)) arrays, the operator stores the original
/// data/centers/eta and recomputes q/t/s per chunk during matvec operations.
#[derive(Debug, Clone)]
pub(crate) struct StreamingRadialState {
    /// Data matrix, shape (n, d).
    pub(crate) data: Arc<Array2<f64>>,
    /// Center matrix, shape (k, d).
    pub(crate) centers: Arc<Array2<f64>>,
    /// How per-pair axis components are exposed to the derivative operator.
    pub(crate) axis_mode: StreamingAxisMode,
    /// Which radial kernel family to use for recomputation.
    pub(crate) radial_kind: RadialScalarKind,
    /// Lazily materialized radial-scalar cache. (phi, q, t) per (i, j) pair
    /// — independent of axis, identical across every per-axis chunk loop —
    /// so collapses (axes × calls × chunks × n × n_knots) streaming radial
    /// evaluations into a single O(n × n_knots) sweep per operator. The
    /// inner `Option` is `None` when the fill encountered a radial
    /// evaluation error (e.g. a non-finite r); callers fall back to the
    /// streaming path which propagates the error through `compute_pair`.
    /// The cell sits behind an `Arc`, so clones of this state share it.
    pub(crate) triplet_cache: Arc<std::sync::OnceLock<Option<StreamingTripletCache>>>,
}
718
/// Materialized `(phi, q, t)` radial scalars for every (observation, center)
/// pair. Each vector holds `n × n_knots` entries.
/// NOTE(review): the flat index order (presumably row-major,
/// `i * n_knots + j`) is set by the fill routine — confirm there.
#[derive(Debug)]
pub(crate) struct StreamingTripletCache {
    /// Kernel values φ per pair.
    pub(crate) phi: Vec<f64>,
    /// First-derivative scalars `q` per pair.
    pub(crate) q: Vec<f64>,
    /// Second-derivative scalars `t` per pair.
    pub(crate) t: Vec<f64>,
}
725
/// Memory cap (bytes) above which we keep streaming the radial scalars
/// instead of materializing the (phi, q, t) triplet cache. Three `Vec<f64>`
/// arrays of length `n × n_knots` consume `24 × n × n_knots` bytes; the cap
/// (`1 << 30` bytes = 1 GiB) keeps the resident footprint bounded for
/// designs whose cache would grow past that size.
pub(crate) const STREAMING_TRIPLET_CACHE_BYTE_BUDGET: usize = 1 << 30;
732
733impl StreamingRadialState {
734    pub(crate) fn cache_fits_budget(&self) -> bool {
735        let total = self
736            .data
737            .nrows()
738            .saturating_mul(self.centers.nrows())
739            .saturating_mul(std::mem::size_of::<f64>())
740            .saturating_mul(3);
741        total <= STREAMING_TRIPLET_CACHE_BYTE_BUDGET
742    }
743
744    pub(crate) fn ensure_triplet_cache(&self) -> Option<&StreamingTripletCache> {
745        if !self.cache_fits_budget() {
746            return None;
747        }
748        let n = self.data.nrows();
749        let n_knots = self.centers.nrows();
750        if n == 0 || n_knots == 0 {
751            return None;
752        }
753        // The OnceLock holds `Option<StreamingTripletCache>` so a fill that
754        // hits an invalid `eval_design_triplet` (e.g. a non-finite r) does
755        // not poison the cache silently — consumers see `None` and fall back
756        // to the streaming `compute_pair` path that propagates the error
757        // through `Result<…, BasisError>`.
758        self.triplet_cache
759            .get_or_init(|| self.materialize_triplet_cache())
760            .as_ref()
761    }
762
763    pub(crate) fn materialize_triplet_cache(&self) -> Option<StreamingTripletCache> {
764        let n = self.data.nrows();
765        let n_knots = self.centers.nrows();
766        let total = n * n_knots;
767        let mut phi = vec![0.0_f64; total];
768        let mut q = vec![0.0_f64; total];
769        let mut t = vec![0.0_f64; total];
770
771        let metric_weights: &[f64] = match &self.axis_mode {
772            StreamingAxisMode::PerAxis { metric_weights }
773            | StreamingAxisMode::ScalarTotal { metric_weights } => metric_weights,
774        };
775        let dim = metric_weights.len();
776        assert_eq!(dim, self.data.ncols());
777        assert_eq!(dim, self.centers.ncols());
778
779        // SERIAL fill: `ensure_triplet_cache` is called from inside outer
780        // `into_par_iter` workers (e.g. the per-axis cross-trace sweep at
781        // `projected_operator_terms_batched`). A nested `par_chunks_mut`
782        // inside this `OnceLock::get_or_init` closure would deadlock the
783        // global rayon pool — every outer worker blocks on the OnceLock
784        // while the one that won the race tries to schedule child tasks no
785        // worker is free to pick up (see `feedback_oncelock_rayon_deadlock`).
786        //
787        // The serial sweep is only affordable when the per-pair radial
788        // evaluation is cheap. For the 16-D power-9 hybrid Duchon kernel a
789        // single exact `eval_design_triplet` costs tens of microseconds
790        // across its partial-fraction blocks, and at the large-scale
791        // conditional-PGS shape (n·k ≈ 480k pairs) this loop was ~15–20 s
792        // of single-threaded work per κ-trial — the dominant cost of the
793        // whole CTN stage-1 fit (#979; the cost model in the previous
794        // version of this comment assumed a cheap kernel). For large sweeps
795        // we therefore build a certified 1-D Chebyshev radial profile once
796        // (a few hundred exact evaluations, see `radial_profile`) from a
797        // distance-only pre-pass over the radius range, and answer per-pair
798        // queries with a Clenshaw contraction; out-of-range or uncertified
799        // cases fall back to the exact evaluator per pair.
800        let pair_radius = |i: usize, j: usize| -> f64 {
801            let mut r2 = 0.0_f64;
802            for a in 0..dim {
803                // Streaming constructors set n=data.nrows(), n_knots=centers.nrows(),
804                // and require dim=data.ncols()=centers.ncols(); the loop ranges
805                // therefore keep both uget reads in-bounds.
806                let h = unsafe { self.data.uget((i, a)) - self.centers.uget((j, a)) }; // SAFETY: bounds per the comment immediately above
807                r2 += metric_weights[a] * h * h;
808            }
809            r2.sqrt()
810        };
811        let profile = if total >= RADIAL_PROFILE_MIN_PAIRS {
812            let mut r_lo = f64::INFINITY;
813            let mut r_hi = 0.0_f64;
814            for i in 0..n {
815                for j in 0..n_knots {
816                    let r = pair_radius(i, j);
817                    if r > 0.0 {
818                        r_lo = r_lo.min(r);
819                        r_hi = r_hi.max(r);
820                    }
821                }
822            }
823            if r_lo.is_finite() && r_hi > r_lo {
824                radial_profile::RadialProfile::build(&self.radial_kind, r_lo, r_hi)
825            } else {
826                None
827            }
828        } else {
829            None
830        };
831        for i in 0..n {
832            let row_off = i * n_knots;
833            for j in 0..n_knots {
834                let r = pair_radius(i, j);
835                let triplet = match profile.as_ref() {
836                    Some(profile) => profile.eval_or_exact(&self.radial_kind, r),
837                    None => self.radial_kind.eval_design_triplet(r),
838                };
839                match triplet {
840                    Ok((pv, qv, tv)) => {
841                        phi[row_off + j] = pv;
842                        q[row_off + j] = qv;
843                        t[row_off + j] = tv;
844                    }
845                    Err(_) => return None,
846                }
847            }
848        }
849        Some(StreamingTripletCache { phi, q, t })
850    }
851
852    #[inline]
853    pub(crate) fn fill_s_buf(&self, i: usize, j: usize, s_buf: &mut [f64]) {
854        match &self.axis_mode {
855            StreamingAxisMode::PerAxis { metric_weights } => {
856                let dim = metric_weights.len();
857                assert_eq!(s_buf.len(), dim);
858                for a in 0..dim {
859                    // SAFETY: compute_pair/ensure_triplet_cache callers pass i <
860                    // data.nrows() and j < centers.nrows(); streaming constructors
861                    // require dim=data.ncols()=centers.ncols(), and this loop has a < dim.
862                    let h = unsafe { self.data.uget((i, a)) - self.centers.uget((j, a)) };
863                    s_buf[a] = metric_weights[a] * h * h;
864                }
865            }
866            StreamingAxisMode::ScalarTotal { metric_weights } => {
867                assert_eq!(s_buf.len(), 1);
868                let dim = metric_weights.len();
869                let mut r2 = 0.0;
870                for a in 0..dim {
871                    // SAFETY: compute_pair/ensure_triplet_cache callers pass i <
872                    // data.nrows() and j < centers.nrows(); streaming constructors
873                    // require dim=data.ncols()=centers.ncols(), and this loop has a < dim.
874                    let h = unsafe { self.data.uget((i, a)) - self.centers.uget((j, a)) };
875                    r2 += metric_weights[a] * h * h;
876                }
877                s_buf[0] = r2;
878            }
879        }
880    }
881
882    /// Compute `(phi, q, t, s_a[0..d])` for a single `(data_row i, center j)` pair.
883    ///
884    /// Returns `(phi, q, t)` and writes per-axis components into `s_buf` (length d).
885    #[inline]
886    pub(crate) fn compute_pair(
887        &self,
888        i: usize,
889        j: usize,
890        s_buf: &mut [f64],
891    ) -> Result<(f64, f64, f64), BasisError> {
892        assert!(i < self.data.nrows() && j < self.centers.nrows());
893        self.fill_s_buf(i, j, s_buf);
894        match &self.axis_mode {
895            StreamingAxisMode::PerAxis { metric_weights } => {
896                let r2: f64 = (0..metric_weights.len()).map(|a| s_buf[a]).sum();
897                self.radial_kind.eval_design_triplet(r2.sqrt())
898            }
899            StreamingAxisMode::ScalarTotal { .. } => {
900                let r2 = s_buf[0];
901                self.radial_kind.eval_design_triplet(r2.sqrt())
902            }
903        }
904    }
905}