Skip to main content

gam_terms/basis/
implicit_psi_derivative.rs

1use super::*;
2
3/// Implicit representation of ∂X/∂ψ_d that supports matrix-vector products
4/// without materializing the full (n x p) derivative matrices.
5///
6/// For anisotropic Matern / Duchon terms with D axes, the dense path creates
7/// D matrices of size (n x p_smooth) for dX/dpsi_d. At n=400K, p=2000, D=16,
8/// that is ~100 GB.
9///
10/// Two storage modes:
11///
12/// **Materialized** (small-to-medium problems): stores pre-computed arrays
13/// - `phi_values[i*n_knots + j]` = phi(r_{ij})
14/// - `q_values[i*n_knots + j]` = phi'(r_{ij}) / r_{ij}
15/// - `t_values[i*n_knots + j]` = (phi''(r_{ij}) - q_{ij}) / r_{ij}^2
16/// - `axis_components[i*n_knots + j, d]` = exp(2 eta_d) * (x_{id} - c_{jd})^2
17/// Memory: O(n * k * (D + 2)).
18///
19/// **Streaming** (large scale): stores only data/centers/eta/kernel params
20/// and recomputes (q, t, s_a) on the fly during each matvec.
21/// Memory: O(n*d + k*d) -- no per-(data,knot) storage.
22///
23/// The raw-psi chain rule:
24///   shape_a   = q * s_a
25///   shape_ab  = t * s_a * s_b + 2 q s_a 1[a=b]
26///   dphi/dpsi_a         = shape_a + c * phi
27///   d2phi/(dpsi_a dpsi_b) = shape_ab + c (shape_a + shape_b) + c^2 phi
28/// where `c = 0` for Matérn and `c = delta / d` for hybrid Duchon.
29#[derive(Debug, Clone)]
30pub struct ImplicitDesignPsiDerivative {
31    /// Pre-computed kernel values (materialized mode).
32    /// Shape: (n * n_knots,). Empty in streaming mode.
33    pub(crate) phi_values: Array1<f64>,
34
35    /// Pre-computed per (data, knot) pair axis components (materialized mode).
36    /// Shape: (n * n_knots, D) stored in row-major order.
37    /// Empty (0x0) in streaming mode.
38    pub(crate) axis_components: Array2<f64>,
39
40    /// Pre-computed R-operator first scalar (materialized mode).
41    /// Shape: (n * n_knots,). Empty in streaming mode.
42    pub(crate) q_values: Array1<f64>,
43
44    /// Pre-computed R-operator second scalar (materialized mode).
45    /// Shape: (n * n_knots,). Empty in streaming mode.
46    pub(crate) t_values: Array1<f64>,
47
48    /// When set, enables streaming recomputation of q/t/s from raw inputs
49    /// instead of reading from the pre-computed arrays above.
50    pub(crate) streaming: Option<StreamingRadialState>,
51
52    /// Identifiability/constraint transform Z: (n_knots x p_constrained).
53    /// Gauge ownership is upstream; the implicit operator stores this frozen
54    /// section only so forward/transpose matvecs can apply the already-gauged
55    /// chart without materializing derivative matrices. For Duchon this is the
56    /// kernel-constraint nullspace Z_kernel; for Matern with identifiability
57    /// constraints, it is the corresponding Z. `None` means the identity.
58    pub(crate) ident_transform: Option<Array2<f64>>,
59
60    /// Optional full identifiability transform applied after Z_kernel + padding.
61    /// This is likewise replay/application metadata for the matrix-free
62    /// operator, not a second coefficient-coordinate owner. For Duchon terms
63    /// that have an additional global identifiability transform, this is applied
64    /// after the kernel constraint and polynomial padding.
65    /// Shape: (p_constrained + n_poly, p_final).
66    pub(crate) full_ident_transform: Option<Array2<f64>>,
67
68    /// Number of data points.
69    pub(crate) n: usize,
70
71    /// Number of knots (raw basis functions before identifiability transform).
72    pub(crate) n_knots: usize,
73
74    /// Number of polynomial columns appended after the smooth part.
75    /// These have zero derivative with respect to psi_d.
76    pub(crate) n_poly: usize,
77
78    /// Number of axes (dimension D).
79    pub(crate) n_axes: usize,
80
81    /// Isotropic scaling contribution per raw anisotropic psi axis.
82    pub(crate) psi_scale_share: f64,
83
84    /// Optional exposed-axis to raw-axis linear combinations.
85    /// When present, axis `a` represents Σ_i coeff_i * raw_axis_i.
86    pub(crate) axis_combinations: Option<Vec<Vec<(usize, f64)>>>,
87}
88
89/// Streaming design derivative for one per-row latent coordinate `t[n, a]`.
90///
91/// The operator stores the shared latent matrix plus either radial-kernel
92/// ingredients or a precomputed non-radial derivative jet. Individual REML
93/// hyper-directions carry only a flat coordinate index and call
94/// `forward_mul_axis` / `transpose_mul_axis` to expose the corresponding
95/// one-row design derivative on demand.
96pub struct LatentCoordDesignDerivative {
97    pub(crate) provider: Arc<dyn LocalDesignJacobianProvider>,
98}
99
100#[derive(Debug, Clone)]
101pub(crate) struct RadialLatentCoordLocalDesignJacobian {
102    pub(crate) latent: Arc<crate::latent::LatentCoordValues>,
103    pub(crate) centers: Arc<Array2<f64>>,
104    pub(crate) radial_kind: RadialScalarKind,
105    pub(crate) ident_transform: Option<Array2<f64>>,
106    pub(crate) full_ident_transform: Option<Array2<f64>>,
107    pub(crate) n_poly: usize,
108    pub(crate) polynomial_order: Option<DuchonNullspaceOrder>,
109}
110
111#[derive(Debug, Clone)]
112pub(crate) struct JetLatentCoordLocalDesignJacobian {
113    pub(crate) latent: Arc<crate::latent::LatentCoordValues>,
114    pub(crate) jet: Arc<Array3<f64>>,
115    pub(crate) ident_transform: Option<Array2<f64>>,
116}
117
118impl std::fmt::Debug for LatentCoordDesignDerivative {
119    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
120        f.debug_struct("LatentCoordDesignDerivative")
121            .field("n_data", &self.n_data())
122            .field("latent_dim", &self.latent_dim())
123            .field("n_axes", &self.n_axes())
124            .field("p_out", &self.p_out())
125            .field("provider", &self.provider)
126            .finish()
127    }
128}
129
130impl Clone for LatentCoordDesignDerivative {
131    fn clone(&self) -> Self {
132        Self {
133            provider: Arc::clone(&self.provider),
134        }
135    }
136}
137
138impl RadialLatentCoordLocalDesignJacobian {
139    pub(crate) fn p_constrained(&self) -> usize {
140        self.ident_transform
141            .as_ref()
142            .map_or(self.centers.nrows(), Array2::ncols)
143    }
144
145    pub(crate) fn p_after_pad(&self) -> usize {
146        self.p_constrained() + self.n_poly
147    }
148
149    pub(crate) fn p_out(&self) -> usize {
150        self.full_ident_transform
151            .as_ref()
152            .map_or(self.p_after_pad(), Array2::ncols)
153    }
154}
155
156impl JetLatentCoordLocalDesignJacobian {
157    pub(crate) fn p_out(&self) -> usize {
158        self.ident_transform
159            .as_ref()
160            .map_or(self.jet.shape()[1], Array2::ncols)
161    }
162}
163
164/// The complete contract a per-row latent / novel-manifold coordinate type must
165/// supply to participate in the REML design-derivative operator surface.
166///
167/// Onboarding a new coordinate type (the SAE / novel-manifold frontier) reduces
168/// to implementing the small set of *required* methods below — the coordinate
169/// geometry (`n_data`, `latent_dim`, `n_axes`) plus the single genuinely-new
170/// payload `local_design_jacobian_row` (the local block ∂(design row)/∂(coord)).
171/// The streaming operator surface consumed by `LatentCoordDerivativeOp` in
172/// `src/solver/reml/mod.rs` — forward matvec, transpose matvec, and dense
173/// materialization, together with the flat-axis → (row, axis) decode — is
174/// inherited as *default* methods and never re-implemented per coordinate type.
175///
176/// This is the close condition for #767: a new coordinate type touches zero
177/// operator-surface code; it provides only its local Jacobian and geometry.
178pub trait LocalDesignJacobianProvider: Send + Sync + std::fmt::Debug {
179    /// Number of data rows `n` the operator spans.
180    fn n_data(&self) -> usize;
181
182    /// Latent coordinate dimension `d` (perturbation axes per row).
183    fn latent_dim(&self) -> usize;
184
185    /// Number of flat hyper-axes `n · d` (one per (row, coordinate-axis) pair).
186    fn n_axes(&self) -> usize;
187
188    /// Number of output-basis columns in each local design-Jacobian row.
189    fn p_out(&self) -> usize;
190
191    /// The only per-coordinate payload: the projected local design-Jacobian row
192    /// ∂(design row `row`)/∂(coordinate axis `axis`) in output-basis columns.
193    fn local_design_jacobian_row(&self, row: usize, axis: usize)
194    -> Result<Array1<f64>, BasisError>;
195
196    /// Decode a flat hyper-axis into its `(row, coordinate axis)`. Row-major over
197    /// `(row, axis)` with stride `latent_dim`; uniform across coordinate types.
198    fn row_axis(&self, flat_axis: usize) -> (usize, usize) {
199        let d = self.latent_dim();
200        (flat_axis / d, flat_axis % d)
201    }
202
203    /// Forward matvec for one flat hyper-axis: place `J_row · u` at `row`.
204    fn forward_mul_axis(
205        &self,
206        flat_axis: usize,
207        u: &ArrayView1<'_, f64>,
208    ) -> Result<Array1<f64>, BasisError> {
209        assert!(
210            flat_axis < self.n_axes(),
211            "latent-coordinate derivative flat axis out of bounds in forward_mul_axis: flat_axis={flat_axis}, n_axes={}",
212            self.n_axes()
213        );
214        let (row, axis) = self.row_axis(flat_axis);
215        let local_jacobian = self.local_design_jacobian_row(row, axis)?;
216        assert_eq!(
217            u.len(),
218            local_jacobian.len(),
219            "latent-coordinate derivative coefficient length mismatch in forward_mul_axis"
220        );
221        let value = local_jacobian.dot(u);
222        let mut out = Array1::<f64>::zeros(self.n_data());
223        out[row] = value;
224        Ok(out)
225    }
226
227    /// Transpose matvec for one flat hyper-axis: scatter `v[row] · J_rowᵀ`.
228    fn transpose_mul_axis(
229        &self,
230        flat_axis: usize,
231        v: &ArrayView1<'_, f64>,
232    ) -> Result<Array1<f64>, BasisError> {
233        assert!(
234            flat_axis < self.n_axes(),
235            "latent-coordinate derivative flat axis out of bounds in transpose_mul_axis: flat_axis={flat_axis}, n_axes={}",
236            self.n_axes()
237        );
238        assert_eq!(
239            v.len(),
240            self.n_data(),
241            "latent-coordinate derivative row-adjoint length mismatch in transpose_mul_axis"
242        );
243        let (row, axis) = self.row_axis(flat_axis);
244        let scale = v[row];
245        Ok(self
246            .local_design_jacobian_row(row, axis)?
247            .mapv(|value| scale * value))
248    }
249
250    /// Dense `(n_data × p_out)` materialization of one flat hyper-axis: the local
251    /// Jacobian row placed at `row`, all other rows zero.
252    fn materialize_axis(&self, flat_axis: usize) -> Result<Array2<f64>, BasisError> {
253        assert!(
254            flat_axis < self.n_axes(),
255            "latent-coordinate derivative flat axis out of bounds in materialize_axis: flat_axis={flat_axis}, n_axes={}",
256            self.n_axes()
257        );
258        let (row, axis) = self.row_axis(flat_axis);
259        let projected = self.local_design_jacobian_row(row, axis)?;
260        let mut out = Array2::<f64>::zeros((self.n_data(), projected.len()));
261        out.row_mut(row).assign(&projected);
262        Ok(out)
263    }
264}
265
/// Row count below which implicit matvecs stay sequential; at or above this
/// size the parallel path may be taken.
pub(crate) const IMPLICIT_MATVEC_PAR_THRESHOLD: usize = 10_000;

/// Data points handled by one rayon chunk in parallel implicit matvecs before
/// the partial results are reduced.
pub(crate) const IMPLICIT_MATVEC_CHUNK_SIZE: usize = 1_000;

/// Lower-triangular center rows per tile when assembling dense ThinPlate
/// penalty ψ-derivative kernel blocks.
pub(crate) const THIN_PLATE_PENALTY_PSI_TILE_ROWS: usize = 32;
276
277impl LatentCoordDesignDerivative {
278    pub(crate) fn from_local_design_jacobian_provider(
279        provider: Arc<dyn LocalDesignJacobianProvider>,
280    ) -> Self {
281        Self { provider }
282    }
283
284    pub fn new_matern(
285        latent: Arc<crate::latent::LatentCoordValues>,
286        centers: Arc<Array2<f64>>,
287        length_scale: f64,
288        nu: MaternNu,
289        include_intercept: bool,
290        ident_transform: Option<Array2<f64>>,
291    ) -> Result<Self, BasisError> {
292        if latent.latent_dim() != centers.ncols() {
293            crate::bail_dim_basis!(
294                "LatentCoordDesignDerivative Matérn dimension mismatch: latent d={} centers d={}",
295                latent.latent_dim(),
296                centers.ncols()
297            );
298        }
299        Ok(Self::from_local_design_jacobian_provider(Arc::new(
300            RadialLatentCoordLocalDesignJacobian {
301                latent,
302                centers,
303                radial_kind: RadialScalarKind::Matern { length_scale, nu },
304                ident_transform,
305                full_ident_transform: None,
306                n_poly: usize::from(include_intercept),
307                polynomial_order: None,
308            },
309        )))
310    }
311
312    pub fn new_duchon(
313        latent: Arc<crate::latent::LatentCoordValues>,
314        centers: Arc<Array2<f64>>,
315        length_scale: Option<f64>,
316        power: f64,
317        nullspace_order: DuchonNullspaceOrder,
318        full_ident_transform: Option<Array2<f64>>,
319    ) -> Result<Self, BasisError> {
320        if latent.latent_dim() != centers.ncols() {
321            crate::bail_dim_basis!(
322                "LatentCoordDesignDerivative Duchon dimension mismatch: latent d={} centers d={}",
323                latent.latent_dim(),
324                centers.ncols()
325            );
326        }
327        let effective_order = duchon_effective_nullspace_order(centers.view(), nullspace_order);
328        let p_order = duchon_p_from_nullspace_order(effective_order);
329        let s_order = power.max(0.0).round() as usize;
330        let radial_kind = if let Some(length_scale) = length_scale {
331            RadialScalarKind::Duchon {
332                length_scale,
333                p_order,
334                s_order,
335                dim: centers.ncols(),
336                coeffs: duchon_partial_fraction_coeffs(
337                    p_order,
338                    s_order,
339                    1.0 / length_scale.max(1e-300),
340                ),
341            }
342        } else {
343            RadialScalarKind::PureDuchon {
344                block_order: pure_duchon_block_order(p_order, power).max(1.0) as usize,
345                p_order,
346                s_order,
347                dim: centers.ncols(),
348            }
349        };
350        let mut workspace = BasisWorkspace::default();
351        let ident_transform =
352            kernel_constraint_nullspace(centers.view(), effective_order, &mut workspace.cache)?;
353        let n_poly = polynomial_block_from_order(centers.view(), effective_order).ncols();
354        Ok(Self::from_local_design_jacobian_provider(Arc::new(
355            RadialLatentCoordLocalDesignJacobian {
356                latent,
357                centers,
358                radial_kind,
359                ident_transform: Some(ident_transform),
360                full_ident_transform,
361                n_poly,
362                polynomial_order: Some(effective_order),
363            },
364        )))
365    }
366
367    pub fn new_sphere(
368        latent: Arc<crate::latent::LatentCoordValues>,
369        centers: Arc<Array2<f64>>,
370        penalty_order: usize,
371        ident_transform: Option<Array2<f64>>,
372    ) -> Result<Self, BasisError> {
373        if latent.latent_dim() != centers.ncols() {
374            crate::bail_dim_basis!(
375                "LatentCoordDesignDerivative sphere dimension mismatch: latent d={} centers d={}",
376                latent.latent_dim(),
377                centers.ncols()
378            );
379        }
380        let raw_jet = sphere_first_derivative_nd(
381            latent.as_matrix().view(),
382            centers.view(),
383            penalty_order,
384            true,
385        )?;
386        let jet = latent.design_gradient_wrt_t_dispatch(
387            crate::latent::InputLocationDerivative::Jet(raw_jet.view()),
388        )?;
389        Self::from_jet(latent, jet, ident_transform)
390    }
391
392    pub fn new_periodic_bspline(
393        latent: Arc<crate::latent::LatentCoordValues>,
394        data_range: (f64, f64),
395        degree: usize,
396        num_basis: usize,
397        ident_transform: Option<Array2<f64>>,
398    ) -> Result<Self, BasisError> {
399        let raw_jet = periodic_bspline_first_derivative_nd(
400            latent.as_matrix().view(),
401            data_range,
402            degree,
403            num_basis,
404        )?;
405        let jet = latent.design_gradient_wrt_t_dispatch(
406            crate::latent::InputLocationDerivative::Jet(raw_jet.view()),
407        )?;
408        Self::from_jet(latent, jet, ident_transform)
409    }
410
411    pub fn new_tensor_bspline(
412        latent: Arc<crate::latent::LatentCoordValues>,
413        knots_per_axis: Vec<Array1<f64>>,
414        degrees: Vec<usize>,
415        ident_transform: Option<Array2<f64>>,
416    ) -> Result<Self, BasisError> {
417        let knot_views = knots_per_axis
418            .iter()
419            .map(|knots| knots.view())
420            .collect::<Vec<_>>();
421        let raw_jet =
422            bspline_tensor_first_derivative(latent.as_matrix().view(), &knot_views, &degrees)?;
423        let jet = latent.design_gradient_wrt_t_dispatch(
424            crate::latent::InputLocationDerivative::Jet(raw_jet.view()),
425        )?;
426        Self::from_jet(latent, jet, ident_transform)
427    }
428
429    pub fn new_pca(
430        latent: Arc<crate::latent::LatentCoordValues>,
431        basis_matrix: Arc<Array2<f64>>,
432    ) -> Result<Self, BasisError> {
433        if latent.latent_dim() != basis_matrix.nrows() {
434            crate::bail_dim_basis!(
435                "LatentCoordDesignDerivative Pca dimension mismatch: latent d={} basis rows={}",
436                latent.latent_dim(),
437                basis_matrix.nrows()
438            );
439        }
440        let mut jet =
441            Array3::<f64>::zeros((latent.n_obs(), basis_matrix.ncols(), basis_matrix.nrows()));
442        for row in 0..latent.n_obs() {
443            for axis in 0..basis_matrix.nrows() {
444                for col in 0..basis_matrix.ncols() {
445                    jet[[row, col, axis]] = basis_matrix[[axis, col]];
446                }
447            }
448        }
449        Self::from_jet(latent, jet, None)
450    }
451
452    pub fn from_jet(
453        latent: Arc<crate::latent::LatentCoordValues>,
454        jet: Array3<f64>,
455        ident_transform: Option<Array2<f64>>,
456    ) -> Result<Self, BasisError> {
457        if jet.shape()[0] != latent.n_obs() || jet.shape()[2] != latent.latent_dim() {
458            crate::bail_dim_basis!(
459                "LatentCoordDesignDerivative jet shape {:?} does not match latent shape ({}, {}, {})",
460                jet.shape(),
461                latent.n_obs(),
462                jet.shape()[1],
463                latent.latent_dim()
464            );
465        }
466        if let Some(z) = ident_transform.as_ref()
467            && z.nrows() != jet.shape()[1]
468        {
469            crate::bail_dim_basis!(
470                "LatentCoordDesignDerivative identifiability transform has {} rows but derivative jet has {} basis columns",
471                z.nrows(),
472                jet.shape()[1]
473            );
474        }
475        Ok(Self::from_local_design_jacobian_provider(Arc::new(
476            JetLatentCoordLocalDesignJacobian {
477                latent,
478                jet: Arc::new(jet),
479                ident_transform,
480            },
481        )))
482    }
483
484    pub(crate) fn n_data(&self) -> usize {
485        self.provider.n_data()
486    }
487
488    pub(crate) fn latent_dim(&self) -> usize {
489        self.provider.latent_dim()
490    }
491
492    pub fn n_axes(&self) -> usize {
493        self.provider.n_axes()
494    }
495
496    pub fn p_out(&self) -> usize {
497        self.provider.p_out()
498    }
499}
500
501impl RadialLatentCoordLocalDesignJacobian {
502    pub(crate) fn project_and_pad(
503        &self,
504        raw_knot: &Array1<f64>,
505        raw_poly: &Array1<f64>,
506    ) -> Result<Array1<f64>, BasisError> {
507        let constrained = match &self.ident_transform {
508            Some(z) => z.t().dot(raw_knot),
509            None => raw_knot.clone(),
510        };
511        let mut padded = Array1::<f64>::zeros(constrained.len() + self.n_poly);
512        padded
513            .slice_mut(s![..constrained.len()])
514            .assign(&constrained);
515        if self.n_poly > 0 {
516            padded.slice_mut(s![constrained.len()..]).assign(raw_poly);
517        }
518        Ok(match &self.full_ident_transform {
519            Some(zf) => zf.t().dot(&padded),
520            None => padded,
521        })
522    }
523
524    pub(crate) fn kernel_axis_scalar(
525        &self,
526        row: usize,
527        center: usize,
528        axis: usize,
529    ) -> Result<f64, BasisError> {
530        let t_row = self.latent.row(row);
531        let mut r2 = 0.0_f64;
532        for a in 0..self.latent.latent_dim() {
533            let delta = t_row[a] - self.centers[[center, a]];
534            r2 += delta * delta;
535        }
536        let r = r2.sqrt();
537        if r == 0.0 {
538            // At a center collision the axis component s_axis = (t − c)_axis
539            // is exactly zero. The product q · s_axis is therefore 0 for any
540            // kernel whose q has a finite limit; for kernels where q diverges
541            // the value is genuinely indeterminate (0 · ∞) and we must not
542            // pretend it is zero. Defer to the kernel's classification.
543            if self.radial_kind.is_smooth_at_collision() {
544                return Ok(0.0);
545            }
546            return Err(BasisError::DegenerateAtCollision {
547                kernel: "RadialScalarKind (design axis)",
548                dim: self.latent.latent_dim(),
549                m: 0.0,
550                message: "radial scalar q = φ'/r has no finite limit at r = 0; \
551                          the design row axis component is undefined",
552            });
553        }
554        let (_, q, _) = self.radial_kind.eval_design_triplet(r)?;
555        Ok(q * (t_row[axis] - self.centers[[center, axis]]))
556    }
557
558    pub(crate) fn polynomial_axis_values(&self, row: usize, axis: usize) -> Array1<f64> {
559        let Some(order) = self.polynomial_order else {
560            return Array1::<f64>::zeros(self.n_poly);
561        };
562        let max_degree = match order {
563            DuchonNullspaceOrder::Zero => 0usize,
564            DuchonNullspaceOrder::Linear => 1usize,
565            DuchonNullspaceOrder::Degree(k) => k,
566        };
567        let t_row = self.latent.row(row);
568        let exponents = monomial_exponents(self.latent.latent_dim(), max_degree);
569        let mut out = Array1::<f64>::zeros(exponents.len());
570        for (col, alpha) in exponents.iter().enumerate() {
571            let a_axis = alpha[axis];
572            if a_axis == 0 {
573                continue;
574            }
575            let mut value = a_axis as f64;
576            for a in 0..self.latent.latent_dim() {
577                let exp_a = if a == axis { a_axis - 1 } else { alpha[a] };
578                if exp_a != 0 {
579                    value *= t_row[a].powi(exp_a as i32);
580                }
581            }
582            out[col] = value;
583        }
584        out
585    }
586}
587
588impl JetLatentCoordLocalDesignJacobian {
589    pub(crate) fn project_jet(&self, raw_knot: &Array1<f64>) -> Result<Array1<f64>, BasisError> {
590        Ok(match &self.ident_transform {
591            Some(z) => z.t().dot(raw_knot),
592            None => raw_knot.clone(),
593        })
594    }
595}
596
597impl LocalDesignJacobianProvider for LatentCoordDesignDerivative {
598    fn n_data(&self) -> usize {
599        self.provider.n_data()
600    }
601
602    fn latent_dim(&self) -> usize {
603        self.provider.latent_dim()
604    }
605
606    fn n_axes(&self) -> usize {
607        self.provider.n_axes()
608    }
609
610    fn p_out(&self) -> usize {
611        self.provider.p_out()
612    }
613
614    fn local_design_jacobian_row(
615        &self,
616        row: usize,
617        axis: usize,
618    ) -> Result<Array1<f64>, BasisError> {
619        self.provider.local_design_jacobian_row(row, axis)
620    }
621}
622
623impl LocalDesignJacobianProvider for RadialLatentCoordLocalDesignJacobian {
624    fn n_data(&self) -> usize {
625        self.latent.n_obs()
626    }
627
628    fn latent_dim(&self) -> usize {
629        self.latent.latent_dim()
630    }
631
632    fn n_axes(&self) -> usize {
633        self.latent.len()
634    }
635
636    fn p_out(&self) -> usize {
637        Self::p_out(self)
638    }
639
640    fn local_design_jacobian_row(
641        &self,
642        row: usize,
643        axis: usize,
644    ) -> Result<Array1<f64>, BasisError> {
645        let mut raw_knot = Array1::<f64>::zeros(self.centers.nrows());
646        for center in 0..self.centers.nrows() {
647            raw_knot[center] = self.kernel_axis_scalar(row, center, axis)?;
648        }
649        let raw_poly = self.polynomial_axis_values(row, axis);
650        self.project_and_pad(&raw_knot, &raw_poly)
651    }
652}
653
654impl LocalDesignJacobianProvider for JetLatentCoordLocalDesignJacobian {
655    fn n_data(&self) -> usize {
656        self.latent.n_obs()
657    }
658
659    fn latent_dim(&self) -> usize {
660        self.latent.latent_dim()
661    }
662
663    fn n_axes(&self) -> usize {
664        self.latent.len()
665    }
666
667    fn p_out(&self) -> usize {
668        Self::p_out(self)
669    }
670
671    fn local_design_jacobian_row(
672        &self,
673        row: usize,
674        axis: usize,
675    ) -> Result<Array1<f64>, BasisError> {
676        let mut raw_knot = Array1::<f64>::zeros(self.jet.shape()[1]);
677        for basis_col in 0..self.jet.shape()[1] {
678            raw_knot[basis_col] = self.jet[[row, basis_col, axis]];
679        }
680        self.project_jet(&raw_knot)
681    }
682}
683
684impl ImplicitDesignPsiDerivative {
685    /// Construct from pre-computed radial jet scalars.
686    ///
687    /// # Arguments
688    /// - `q_values`: (n * n_knots,) — φ'(r)/r for each (data, knot) pair.
689    /// - `t_values`: (n * n_knots,) — (φ''(r) - q) / r² for each pair.
690    /// - `axis_components`: (n * n_knots, D) — s_{d,ij} = exp(2η_d) · h_d² for each pair/axis.
691    /// - `ident_transform`: optional (n_knots × p_constrained) constraint projection.
692    /// - `full_ident_transform`: optional further projection after padding.
693    /// - `n`, `n_knots`, `n_poly`, `n_axes`: dimensions.
694    /// Construct from pre-computed (materialized) radial jet scalars.
695    /// This is the original path for small-to-medium problems where
696    /// O(n*k*(d+2)) storage is acceptable.
697    pub fn new(
698        phi_values: Array1<f64>,
699        q_values: Array1<f64>,
700        t_values: Array1<f64>,
701        axis_components: Array2<f64>,
702        ident_transform: Option<Array2<f64>>,
703        full_ident_transform: Option<Array2<f64>>,
704        n: usize,
705        n_knots: usize,
706        n_poly: usize,
707        n_axes: usize,
708    ) -> Self {
709        assert_eq!(
710            phi_values.len(),
711            n * n_knots,
712            "implicit psi derivative phi length mismatch: expected n*n_knots={}*{}={}, got {}",
713            n,
714            n_knots,
715            n * n_knots,
716            phi_values.len()
717        );
718        assert_eq!(
719            q_values.len(),
720            n * n_knots,
721            "implicit psi derivative q length mismatch: expected n*n_knots={}*{}={}, got {}",
722            n,
723            n_knots,
724            n * n_knots,
725            q_values.len()
726        );
727        assert_eq!(
728            t_values.len(),
729            n * n_knots,
730            "implicit psi derivative t length mismatch: expected n*n_knots={}*{}={}, got {}",
731            n,
732            n_knots,
733            n * n_knots,
734            t_values.len()
735        );
736        assert_eq!(
737            axis_components.nrows(),
738            n * n_knots,
739            "implicit psi derivative axis-component row mismatch: expected n*n_knots={}*{}={}, got {}",
740            n,
741            n_knots,
742            n * n_knots,
743            axis_components.nrows()
744        );
745        assert_eq!(
746            axis_components.ncols(),
747            n_axes,
748            "implicit psi derivative axis-component column mismatch: expected n_axes={n_axes}, got {}",
749            axis_components.ncols()
750        );
751        Self {
752            phi_values,
753            axis_components,
754            q_values,
755            t_values,
756            streaming: None,
757            ident_transform,
758            full_ident_transform,
759            n,
760            n_knots,
761            n_poly,
762            n_axes,
763            psi_scale_share: 0.0,
764            axis_combinations: None,
765        }
766    }
767
768    pub(crate) fn with_psi_scale_share(mut self, psi_scale_share: f64) -> Self {
769        self.psi_scale_share = psi_scale_share;
770        self
771    }
772
773    /// Construct a streaming operator that recomputes (q, t, s_a) on the fly
774    /// from raw data/centers/eta during each matvec. No O(n*k) arrays are stored.
775    /// This is the large-scale path.
776    ///
777    /// `pub` like the sibling `new_*` constructors: after the engine crate carve
778    /// (#1521) the REML planner tests live in `gam-solve` and build streaming
779    /// operators as fixtures, so this constructor is part of the cross-crate
780    /// surface, not a crate-private helper.
781    pub fn new_streaming(
782        data: Arc<Array2<f64>>,
783        centers: Arc<Array2<f64>>,
784        eta: Vec<f64>,
785        radial_kind: RadialScalarKind,
786        ident_transform: Option<Array2<f64>>,
787        full_ident_transform: Option<Array2<f64>>,
788        n_poly: usize,
789    ) -> Self {
790        let n = data.nrows();
791        let n_knots = centers.nrows();
792        let n_axes = data.ncols();
793        let psi_scale_share = radial_kind.raw_psi_isotropic_share();
794        assert_eq!(eta.len(), n_axes);
795        assert_eq!(
796            centers.ncols(),
797            n_axes,
798            "streaming radial centers have {} columns but data/eta have {n_axes}",
799            centers.ncols()
800        );
801        let metric_weights: Arc<[f64]> = Arc::from(centered_aniso_metric_weights(&eta));
802        Self {
803            // Empty arrays -- not used in streaming mode.
804            phi_values: Array1::<f64>::zeros(0),
805            axis_components: Array2::<f64>::zeros((0, 0)),
806            q_values: Array1::<f64>::zeros(0),
807            t_values: Array1::<f64>::zeros(0),
808            streaming: Some(StreamingRadialState {
809                data,
810                centers,
811                axis_mode: StreamingAxisMode::PerAxis { metric_weights },
812                radial_kind,
813                triplet_cache: Arc::new(std::sync::OnceLock::new()),
814            }),
815            ident_transform,
816            full_ident_transform,
817            n,
818            n_knots,
819            n_poly,
820            n_axes,
821            psi_scale_share,
822            axis_combinations: None,
823        }
824    }
825
826    /// Construct a streaming operator for a scalar ψ derivative. The operator
827    /// exposes a single axis component equal to the full scaled squared
828    /// distance r² under the fixed metric defined by `eta`.
829    pub(crate) fn new_streaming_scalar(
830        data: Arc<Array2<f64>>,
831        centers: Arc<Array2<f64>>,
832        eta: Vec<f64>,
833        radial_kind: RadialScalarKind,
834        ident_transform: Option<Array2<f64>>,
835        full_ident_transform: Option<Array2<f64>>,
836        n_poly: usize,
837    ) -> Self {
838        let n = data.nrows();
839        let n_knots = centers.nrows();
840        let dim = data.ncols();
841        assert_eq!(eta.len(), dim);
842        assert_eq!(
843            centers.ncols(),
844            dim,
845            "streaming scalar radial centers have {} columns but data/eta have {dim}",
846            centers.ncols()
847        );
848        let metric_weights: Arc<[f64]> = Arc::from(centered_aniso_metric_weights(&eta));
849        Self {
850            phi_values: Array1::<f64>::zeros(0),
851            axis_components: Array2::<f64>::zeros((0, 0)),
852            q_values: Array1::<f64>::zeros(0),
853            t_values: Array1::<f64>::zeros(0),
854            streaming: Some(StreamingRadialState {
855                data,
856                centers,
857                axis_mode: StreamingAxisMode::ScalarTotal { metric_weights },
858                radial_kind,
859                triplet_cache: Arc::new(std::sync::OnceLock::new()),
860            }),
861            ident_transform,
862            full_ident_transform,
863            n,
864            n_knots,
865            n_poly,
866            n_axes: 1,
867            psi_scale_share: 0.0,
868            axis_combinations: None,
869        }
870    }
871
872    /// Whether this operator is in streaming (recompute-on-the-fly) mode.
873    #[inline]
874    pub(crate) fn is_streaming(&self) -> bool {
875        self.streaming.is_some()
876    }
877
878    /// Number of data points.
879    pub fn n_data(&self) -> usize {
880        self.n
881    }
882
883    /// Number of axes (D).
884    pub fn n_axes(&self) -> usize {
885        self.axis_combinations
886            .as_ref()
887            .map_or(self.n_axes, Vec::len)
888    }
889
890    pub fn is_duchon_family(&self) -> bool {
891        self.streaming.as_ref().is_some_and(|state| {
892            matches!(
893                state.radial_kind,
894                RadialScalarKind::Duchon { .. } | RadialScalarKind::PureDuchon { .. }
895            )
896        }) || self.psi_scale_share != 0.0
897    }
898
899    /// Whether this operator is wired up by a basis whose large-scale path
900    /// is supposed to stay implicit, so a dense `(n × p)` materialization
901    /// here is a regression rather than a normal compute path. Duchon-family
902    /// terms qualify because they are streaming-only at any scale; ThinPlate
903    /// qualifies because the new scalar-streaming routing relies on the
904    /// implicit operator above the policy threshold and a sneaky
905    /// `materialize_dense()` would silently re-introduce the n × p
906    /// allocation we just removed. The flag is consulted by the
907    /// materialize_first / materialize_second_diag / materialize_second_cross
908    /// guards to fire `assert_no_dense_derivative_materialization` for these
909    /// kinds whenever the resource policy says the materialization would
910    /// exceed budget. Small-n problems still pass the assertion and get the
911    /// dense fast path.
912    pub(crate) fn enforces_dense_materialization_budget(&self) -> bool {
913        if self
914            .streaming
915            .as_ref()
916            .is_some_and(|state| state.radial_kind.enforces_dense_materialization_budget())
917        {
918            return true;
919        }
920        // The materialized-mode path keeps no `radial_kind` to inspect, but
921        // a non-zero psi_scale_share is the unambiguous Duchon-family
922        // signature there (Matern uses 0, ThinPlate uses 0). Materialized
923        // ThinPlate / Matern terms are in the dense fast path and the
924        // guard does not need to fire for them.
925        self.psi_scale_share != 0.0
926    }
927
928    /// Output dimension: total basis columns in the final space.
929    pub fn p_out(&self) -> usize {
930        if let Some(ref zf) = self.full_ident_transform {
931            zf.ncols()
932        } else {
933            self.p_after_pad()
934        }
935    }
936
937    pub fn append_full_transform(
938        mut self,
939        transform: &Array2<f64>,
940    ) -> Result<Self, BasisError> {
941        if transform.nrows() != self.p_out() {
942            crate::bail_dim_basis!(
943                "implicit psi derivative transform has {} rows but operator has {} output columns",
944                transform.nrows(),
945                self.p_out()
946            );
947        }
948        self.full_ident_transform = Some(match self.full_ident_transform.take() {
949            Some(existing) => fast_ab(&existing, transform),
950            None => transform.clone(),
951        });
952        Ok(self)
953    }
954
955    /// Dimension after kernel constraint + polynomial padding (before full ident).
956    pub(crate) fn p_after_pad(&self) -> usize {
957        let p_constrained = self.p_constrained();
958        p_constrained + self.n_poly
959    }
960
961    /// Dimension after kernel constraint projection (before poly padding).
962    pub(crate) fn p_constrained(&self) -> usize {
963        match &self.ident_transform {
964            Some(z) => z.ncols(),
965            None => self.n_knots,
966        }
967    }
968
969    /// Accumulate raw knot-space vector from weighted (data, knot) contributions.
970    /// Returns a vector of length n_knots: Σ_i w_i · scalar_{ij} for each knot j.
971    ///
972    /// This is the core primitive: for each data point i, accumulate
973    /// `v[i] * per_pair_scalar(i,j)` into knot j.
974    pub(crate) fn accumulate_knot_vector<F>(&self, v: &ArrayView1<f64>, per_pair: F) -> Array1<f64>
975    where
976        F: Fn(usize) -> f64 + Send + Sync,
977    {
978        let n = self.n;
979        let k = self.n_knots;
980
981        if n >= IMPLICIT_MATVEC_PAR_THRESHOLD {
982            // Parallel path: chunk data points and reduce.
983            let n_chunks = n.div_ceil(IMPLICIT_MATVEC_CHUNK_SIZE);
984            let partial_sums: Vec<Array1<f64>> = (0..n_chunks)
985                .into_par_iter()
986                .map(|chunk_idx| {
987                    let start = chunk_idx * IMPLICIT_MATVEC_CHUNK_SIZE;
988                    let end = (start + IMPLICIT_MATVEC_CHUNK_SIZE).min(n);
989                    let mut local = Array1::<f64>::zeros(k);
990                    for i in start..end {
991                        let vi = v[i];
992                        if vi == 0.0 {
993                            continue;
994                        }
995                        let base = i * k;
996                        for j in 0..k {
997                            local[j] += vi * per_pair(base + j);
998                        }
999                    }
1000                    local
1001                })
1002                .collect();
1003            let mut total = Array1::<f64>::zeros(k);
1004            for p in partial_sums {
1005                total += &p;
1006            }
1007            total
1008        } else {
1009            // Sequential path.
1010            let mut total = Array1::<f64>::zeros(k);
1011            for i in 0..n {
1012                let vi = v[i];
1013                if vi == 0.0 {
1014                    continue;
1015                }
1016                let base = i * k;
1017                for j in 0..k {
1018                    total[j] += vi * per_pair(base + j);
1019                }
1020            }
1021            total
1022        }
1023    }
1024
1025    /// Streaming accumulate knot vector from on-the-fly radial scalars.
1026    pub(crate) fn streaming_accumulate_knot_vector<G>(
1027        &self,
1028        v: &ArrayView1<f64>,
1029        deriv_fn: G,
1030    ) -> Result<Array1<f64>, BasisError>
1031    where
1032        G: Fn(f64, f64, f64, &[f64]) -> f64 + Send + Sync,
1033    {
1034        let st = self.streaming.as_ref().unwrap();
1035        let (n, k, dim) = (self.n, self.n_knots, self.n_axes);
1036        if n >= IMPLICIT_MATVEC_PAR_THRESHOLD {
1037            let err_flag = std::sync::atomic::AtomicBool::new(false);
1038            let nc = n.div_ceil(IMPLICIT_MATVEC_CHUNK_SIZE);
1039            let ps: Vec<Array1<f64>> = (0..nc)
1040                .into_par_iter()
1041                .map(|ci| {
1042                    let s = ci * IMPLICIT_MATVEC_CHUNK_SIZE;
1043                    let e = (s + IMPLICIT_MATVEC_CHUNK_SIZE).min(n);
1044                    let mut loc = Array1::<f64>::zeros(k);
1045                    let mut sb = vec![0.0; dim];
1046                    for i in s..e {
1047                        let vi = v[i];
1048                        if vi == 0.0 {
1049                            continue;
1050                        }
1051                        for j in 0..k {
1052                            match st.compute_pair(i, j, &mut sb) {
1053                                Ok((phi, q, t)) => {
1054                                    loc[j] += vi * deriv_fn(phi, q, t, &sb);
1055                                }
1056                                Err(_) => {
1057                                    err_flag.store(true, std::sync::atomic::Ordering::Relaxed);
1058                                    return loc;
1059                                }
1060                            }
1061                        }
1062                    }
1063                    loc
1064                })
1065                .collect();
1066            if err_flag.load(std::sync::atomic::Ordering::Relaxed) {
1067                crate::bail_invalid_basis!(
1068                    "radial scalar evaluation failed during streaming accumulate_knot_vector"
1069                        .into(),
1070                );
1071            }
1072            let mut tot = Array1::<f64>::zeros(k);
1073            for p in ps {
1074                tot += &p;
1075            }
1076            Ok(tot)
1077        } else {
1078            let mut tot = Array1::<f64>::zeros(k);
1079            let mut sb = vec![0.0; dim];
1080            for i in 0..n {
1081                let vi = v[i];
1082                if vi == 0.0 {
1083                    continue;
1084                }
1085                for j in 0..k {
1086                    let (phi, q, t) = st.compute_pair(i,j,&mut sb).map_err(|e| BasisError::InvalidInput(
1087                        format!("radial scalar evaluation failed during streaming accumulate_knot_vector: {e}"),
1088                    ))?;
1089                    tot[j] += vi * deriv_fn(phi, q, t, &sb);
1090                }
1091            }
1092            Ok(tot)
1093        }
1094    }
1095    /// Streaming forward multiply.
1096    pub(crate) fn streaming_forward_mul<G>(
1097        &self,
1098        u_knot: &Array1<f64>,
1099        deriv_fn: G,
1100    ) -> Result<Array1<f64>, BasisError>
1101    where
1102        G: Fn(f64, f64, f64, &[f64]) -> f64 + Send + Sync,
1103    {
1104        let st = self.streaming.as_ref().unwrap();
1105        let (n, k, dim) = (self.n, self.n_knots, self.n_axes);
1106        if n >= IMPLICIT_MATVEC_PAR_THRESHOLD {
1107            let err_flag = std::sync::atomic::AtomicBool::new(false);
1108            let nc = n.div_ceil(IMPLICIT_MATVEC_CHUNK_SIZE);
1109            let cr: Vec<(usize, Vec<f64>)> = (0..nc)
1110                .into_par_iter()
1111                .map(|ci| {
1112                    let s = ci * IMPLICIT_MATVEC_CHUNK_SIZE;
1113                    let e = (s + IMPLICIT_MATVEC_CHUNK_SIZE).min(n);
1114                    let mut loc = vec![0.0; e - s];
1115                    let mut sb = vec![0.0; dim];
1116                    for i in s..e {
1117                        let mut val = 0.0;
1118                        for j in 0..k {
1119                            match st.compute_pair(i, j, &mut sb) {
1120                                Ok((phi, q, t)) => {
1121                                    val += deriv_fn(phi, q, t, &sb) * u_knot[j];
1122                                }
1123                                Err(_) => {
1124                                    err_flag.store(true, std::sync::atomic::Ordering::Relaxed);
1125                                    break;
1126                                }
1127                            }
1128                        }
1129                        loc[i - s] = val;
1130                    }
1131                    (s, loc)
1132                })
1133                .collect();
1134            if err_flag.load(std::sync::atomic::Ordering::Relaxed) {
1135                crate::bail_invalid_basis!(
1136                    "radial scalar evaluation failed during streaming forward_mul".into(),
1137                );
1138            }
1139            let mut res = Array1::<f64>::zeros(n);
1140            for (s, vs) in cr {
1141                for (o, &v) in vs.iter().enumerate() {
1142                    res[s + o] = v;
1143                }
1144            }
1145            Ok(res)
1146        } else {
1147            let mut res = Array1::<f64>::zeros(n);
1148            let mut sb = vec![0.0; dim];
1149            for i in 0..n {
1150                let mut val = 0.0;
1151                for j in 0..k {
1152                    let (phi, q, t) = st.compute_pair(i, j, &mut sb).map_err(|e| {
1153                        BasisError::InvalidInput(format!(
1154                            "radial scalar evaluation failed during streaming forward_mul: {e}"
1155                        ))
1156                    })?;
1157                    val += deriv_fn(phi, q, t, &sb) * u_knot[j];
1158                }
1159                res[i] = val;
1160            }
1161            Ok(res)
1162        }
1163    }
1164    /// Streaming materialization: build (n x k) raw matrix then project.
1165    pub(crate) fn streaming_materialize<G>(&self, deriv_fn: G) -> Result<Array2<f64>, BasisError>
1166    where
1167        G: Fn(f64, f64, f64, &[f64]) -> f64 + Send + Sync,
1168    {
1169        let st = self.streaming.as_ref().unwrap();
1170        let (n, k, dim) = (self.n, self.n_knots, self.n_axes);
1171        let mut raw = Array2::<f64>::zeros((n, k));
1172        let cs = IMPLICIT_MATVEC_CHUNK_SIZE;
1173        let nc = n.div_ceil(cs);
1174        let err_flag = std::sync::atomic::AtomicBool::new(false);
1175        {
1176            let rp = SendPtr(raw.as_mut_ptr());
1177            let ef = &err_flag;
1178            (0..nc).into_par_iter().for_each(move |ci| {
1179                let s = ci * cs;
1180                let e = (s + cs).min(n);
1181                let mut sb = vec![0.0; dim];
1182                for i in s..e {
1183                    for j in 0..k {
1184                        match st.compute_pair(i, j, &mut sb) {
1185                            // SAFETY: chunk ci owns rows [s..e) of the raw n×k buffer,
1186                            // so offsets i*k+j for i ∈ [s,e), j ∈ [0,k) are pairwise
1187                            // disjoint across workers and stay within n*k = raw.len().
1188                            Ok((phi, q, t)) => unsafe {
1189                                *rp.add(i * k + j) = deriv_fn(phi, q, t, &sb);
1190                            },
1191                            Err(_) => {
1192                                ef.store(true, std::sync::atomic::Ordering::Relaxed);
1193                                return;
1194                            }
1195                        }
1196                    }
1197                }
1198            });
1199        }
1200        if err_flag.load(std::sync::atomic::Ordering::Relaxed) {
1201            crate::bail_invalid_basis!(
1202                "radial scalar evaluation failed during streaming materialize".into(),
1203            );
1204        }
1205        Ok(self.project_matrix(raw))
1206    }
1207
1208    /// Project a raw knot-space vector through the identifiability transform
1209    /// and pad with zeros for polynomial columns.
1210    pub(crate) fn project_and_pad(&self, raw_knot_vec: &Array1<f64>) -> Array1<f64> {
1211        // Step 1: apply kernel constraint Z (if present).
1212        let constrained = match &self.ident_transform {
1213            Some(z) => z.t().dot(raw_knot_vec),
1214            None => raw_knot_vec.clone(),
1215        };
1216
1217        // Step 2: pad with polynomial zeros.
1218        let p_padded = constrained.len() + self.n_poly;
1219        let mut padded = Array1::<f64>::zeros(p_padded);
1220        padded
1221            .slice_mut(s![..constrained.len()])
1222            .assign(&constrained);
1223
1224        // Step 3: apply full identifiability transform (if present).
1225        match &self.full_ident_transform {
1226            Some(zf) => zf.t().dot(&padded),
1227            None => padded,
1228        }
1229    }
1230
1231    /// Expand a coefficient vector from the final space back to raw knot space.
1232    /// This is the transpose path: p_out → (padded) → (constrained) → n_knots.
1233    pub(crate) fn unproject(&self, u: &ArrayView1<f64>) -> Array1<f64> {
1234        // Step 1: undo full identifiability transform.
1235        let after_full = match &self.full_ident_transform {
1236            Some(zf) => zf.dot(u),
1237            None => u.to_owned(),
1238        };
1239
1240        // Step 2: extract smooth part (drop polynomial padding).
1241        let p_constrained = self.p_constrained();
1242        let smooth_part = after_full.slice(s![..p_constrained]);
1243
1244        // Step 3: undo kernel constraint Z.
1245        match &self.ident_transform {
1246            Some(z) => z.dot(&smooth_part),
1247            None => smooth_part.to_owned(),
1248        }
1249    }
1250
1251    /// Batched `unproject` for a (p_out × rank) coefficient matrix.
1252    /// Returns (n_knots × rank) via two BLAS3 matmuls — the same algebra as
1253    /// `unproject`, but amortized across all rank columns of `u`. Used by
1254    /// `forward_mul_matrix` so per-axis trace evaluations can be a single
1255    /// chunked GEMM rather than rank-many `forward_mul` calls.
1256    pub fn unproject_matrix(&self, u: &ArrayView2<f64>) -> Array2<f64> {
1257        assert_eq!(u.nrows(), self.p_out());
1258        // Step 1: undo full identifiability transform → (p_after_pad, rank).
1259        let after_full = match &self.full_ident_transform {
1260            Some(zf) => fast_ab(zf, u),
1261            None => u.to_owned(),
1262        };
1263        // Step 2: drop polynomial padding rows → (p_constrained, rank).
1264        let p_constrained = self.p_constrained();
1265        let smooth_part = after_full.slice(s![..p_constrained, ..]);
1266        // Step 3: undo kernel constraint Z → (n_knots, rank).
1267        match &self.ident_transform {
1268            Some(z) => fast_ab(z, &smooth_part),
1269            None => smooth_part.to_owned(),
1270        }
1271    }
1272
    /// Compute (∂X/∂ψ_d)^T v for a given axis d and vector v of length n.
    ///
    /// Returns a vector of length p_out (total basis dimension after all transforms).
    ///
    /// Formula in raw knot space (per-axis case):
    ///   [raw]_j = Σ_i v_i · (q_{ij} · s_{d,ij} + c · φ_{ij})
    /// where `c = self.psi_scale_share`, then project through Z and pad.
    /// When `axis_combinations` is set, `axis` instead selects a combination
    /// of raw axes, and each per-pair value comes from
    /// `transformed_first_kernel_value` applied to the combined component.
    ///
    /// Note: q = φ_r/r and s_d = exp(2ψ_d)·h_d² are UNNORMALIZED axis components.
    /// With this convention, q·s_d = (φ_r/r)·(exp(2ψ_d)·h_d²) = φ_r·(s_d/r),
    /// which equals the correct ∂φ/∂ψ_d = φ_r·∂r/∂ψ_d = φ_r·s_d/r.
    /// No r² correction is needed — that would be required only if s_d were
    /// the fractional quantity s_d/r².
    ///
    /// # Panics
    /// Panics if `axis >= self.n_axes()` or `v.len()` differs from the number
    /// of data rows.
    ///
    /// # Errors
    /// Returns `BasisError` if streaming radial scalar evaluation fails.
    pub fn transpose_mul(
        &self,
        axis: usize,
        v: &ArrayView1<f64>,
    ) -> Result<Array1<f64>, BasisError> {
        assert!(
            axis < self.n_axes(),
            "implicit psi first transpose axis out of bounds: axis={axis}, n_axes={}",
            self.n_axes()
        );
        assert_eq!(
            v.len(),
            self.n,
            "implicit psi first transpose row-adjoint length mismatch"
        );
        if self.axis_combinations.is_some() {
            // Transformed-axis path: `axis` indexes a combination of raw axes.
            let combo = self.transformed_axis_combination(axis);
            let combo_sum = Self::transformed_combo_sum(combo);
            if self.is_streaming() {
                let c = self.psi_scale_share;
                let raw = self.streaming_accumulate_knot_vector(v, |phi, q, _, sb| {
                    // Combined axis component Σ coeff · s_{raw_axis}.
                    let s_combo = combo
                        .iter()
                        .map(|(raw_axis, coeff)| coeff * sb[*raw_axis])
                        .sum();
                    Self::transformed_first_kernel_value(phi, q, s_combo, combo_sum, c)
                })?;
                return Ok(self.project_and_pad(&raw));
            }
            // Materialized transformed-axis path.
            let c = self.psi_scale_share;
            let raw = self.accumulate_knot_vector(v, |idx| {
                let s_combo = self.transformed_combo_axis_value_materialized(idx, combo);
                Self::transformed_first_kernel_value(
                    self.phi_values[idx],
                    self.q_values[idx],
                    s_combo,
                    combo_sum,
                    c,
                )
            });
            return Ok(self.project_and_pad(&raw));
        }
        // Per-axis streaming path: q·s_d + c·φ, recomputed per pair.
        if self.is_streaming() {
            let c = self.psi_scale_share;
            let raw =
                self.streaming_accumulate_knot_vector(v, |phi, q, _, sb| q * sb[axis] + c * phi)?;
            return Ok(self.project_and_pad(&raw));
        }
        // Per-axis materialized path: same per-pair value from cached arrays.
        let c = self.psi_scale_share;
        let af = &self.axis_components;
        let pv = &self.phi_values;
        let qv = &self.q_values;
        let raw = self.accumulate_knot_vector(v, |idx| qv[idx] * af[[idx, axis]] + c * pv[idx]);
        Ok(self.project_and_pad(&raw))
    }
1341
    /// Compute (∂X/∂ψ_d) u for a given axis d and vector u of length p_out.
    ///
    /// Returns a vector of length n.
    ///
    /// Formula (per-axis case): for each data point i,
    ///   result_i = Σ_j (q_{ij} · s_{d,ij} + c · φ_{ij}) · u_knot_j
    /// where u_knot = `unproject(u)` (coefficients carried back to raw knot
    /// space) and `c = self.psi_scale_share`. When `axis_combinations` is
    /// set, `axis` selects a combination of raw axes and the per-pair value
    /// comes from `transformed_first_kernel_value` instead.
    ///
    /// # Panics
    /// Panics if `axis >= self.n_axes()` or `u.len() != self.p_out()`.
    ///
    /// # Errors
    /// Returns `BasisError` if streaming radial scalar evaluation fails.
    pub fn forward_mul(&self, axis: usize, u: &ArrayView1<f64>) -> Result<Array1<f64>, BasisError> {
        assert!(
            axis < self.n_axes(),
            "implicit psi first forward axis out of bounds: axis={axis}, n_axes={}",
            self.n_axes()
        );
        assert_eq!(
            u.len(),
            self.p_out(),
            "implicit psi first forward coefficient length mismatch"
        );
        // Carry the output-space coefficients back to raw knot space once.
        let u_knot = self.unproject(u);
        if self.axis_combinations.is_some() {
            // Transformed-axis path: `axis` indexes a combination of raw axes.
            let combo = self.transformed_axis_combination(axis);
            let combo_sum = Self::transformed_combo_sum(combo);
            if self.is_streaming() {
                let c = self.psi_scale_share;
                return self.streaming_forward_mul(&u_knot, |phi, q, _, sb| {
                    let s_combo = combo
                        .iter()
                        .map(|(raw_axis, coeff)| coeff * sb[*raw_axis])
                        .sum();
                    Self::transformed_first_kernel_value(phi, q, s_combo, combo_sum, c)
                });
            }
            // Materialized transformed-axis path.
            let n = self.n;
            let k = self.n_knots;
            let c = self.psi_scale_share;
            if n >= IMPLICIT_MATVEC_PAR_THRESHOLD {
                let mut result = Array1::<f64>::zeros(n);
                let n_chunks = n.div_ceil(IMPLICIT_MATVEC_CHUNK_SIZE);
                // Each chunk returns its starting row plus its row results.
                let chunk_results: Vec<(usize, Vec<f64>)> = (0..n_chunks)
                    .into_par_iter()
                    .map(|chunk_idx| {
                        let start = chunk_idx * IMPLICIT_MATVEC_CHUNK_SIZE;
                        let end = (start + IMPLICIT_MATVEC_CHUNK_SIZE).min(n);
                        let mut local = vec![0.0; end - start];
                        for i in start..end {
                            let base = i * k;
                            let mut val = 0.0;
                            for j in 0..k {
                                let idx = base + j;
                                let s_combo =
                                    self.transformed_combo_axis_value_materialized(idx, combo);
                                val += Self::transformed_first_kernel_value(
                                    self.phi_values[idx],
                                    self.q_values[idx],
                                    s_combo,
                                    combo_sum,
                                    c,
                                ) * u_knot[j];
                            }
                            local[i - start] = val;
                        }
                        (start, local)
                    })
                    .collect();
                for (start, vals) in chunk_results {
                    for (offset, &v) in vals.iter().enumerate() {
                        result[start + offset] = v;
                    }
                }
                return Ok(result);
            }
            let mut result = Array1::<f64>::zeros(n);
            for i in 0..n {
                let base = i * k;
                let mut val = 0.0;
                for j in 0..k {
                    let idx = base + j;
                    let s_combo = self.transformed_combo_axis_value_materialized(idx, combo);
                    val += Self::transformed_first_kernel_value(
                        self.phi_values[idx],
                        self.q_values[idx],
                        s_combo,
                        combo_sum,
                        c,
                    ) * u_knot[j];
                }
                result[i] = val;
            }
            return Ok(result);
        }
        // Per-axis streaming path: q·s_d + c·φ, recomputed per pair.
        if self.is_streaming() {
            let c = self.psi_scale_share;
            return self.streaming_forward_mul(&u_knot, |phi, q, _, sb| q * sb[axis] + c * phi);
        }
        // Per-axis materialized path: same per-pair value from cached arrays.
        let n = self.n;
        let k = self.n_knots;
        let c = self.psi_scale_share;
        let af = &self.axis_components;
        let pv = &self.phi_values;
        let qv = &self.q_values;

        if n >= IMPLICIT_MATVEC_PAR_THRESHOLD {
            let mut result = Array1::<f64>::zeros(n);
            // Parallel over chunks of data points.
            let n_chunks = n.div_ceil(IMPLICIT_MATVEC_CHUNK_SIZE);
            let chunk_results: Vec<(usize, Vec<f64>)> = (0..n_chunks)
                .into_par_iter()
                .map(|chunk_idx| {
                    let start = chunk_idx * IMPLICIT_MATVEC_CHUNK_SIZE;
                    let end = (start + IMPLICIT_MATVEC_CHUNK_SIZE).min(n);
                    let mut local = vec![0.0; end - start];
                    for i in start..end {
                        let base = i * k;
                        let mut val = 0.0;
                        for j in 0..k {
                            val += (qv[base + j] * af[[base + j, axis]] + c * pv[base + j])
                                * u_knot[j];
                        }
                        local[i - start] = val;
                    }
                    (start, local)
                })
                .collect();
            for (start, vals) in chunk_results {
                for (offset, &v) in vals.iter().enumerate() {
                    result[start + offset] = v;
                }
            }
            Ok(result)
        } else {
            let mut result = Array1::<f64>::zeros(n);
            for i in 0..n {
                let base = i * k;
                let mut val = 0.0;
                for j in 0..k {
                    val += (qv[base + j] * af[[base + j, axis]] + c * pv[base + j]) * u_knot[j];
                }
                result[i] = val;
            }
            Ok(result)
        }
    }
1484
    /// Compute (∂²X/∂ψ_d²)^T v — diagonal second derivative, same axis.
    ///
    /// Matrix-free variant of `materialize_second_diag`: avoids forming the
    /// full (n × p_out) matrix when only a single adjoint matvec is needed.
    ///
    /// Per-pair value in the per-axis case, with `s = s_{d,ij}` and
    /// `c = self.psi_scale_share`:
    ///   t·s² + 2·q·s + 2·c·q·s + c²·φ
    /// i.e. `shape_dd + 2c·shape_d + c²·φ` from the raw-ψ chain rule with
    /// a = b. The values are accumulated into knot space with weights `v`,
    /// then projected and padded. When `axis_combinations` is set, the value
    /// comes from `transformed_second_kernel_value` with the same combination
    /// on both sides.
    ///
    /// # Panics
    /// Panics if `axis >= self.n_axes()` or `v.len()` differs from the number
    /// of data rows.
    ///
    /// # Errors
    /// Returns `BasisError` if streaming radial scalar evaluation fails.
    pub fn transpose_mul_second_diag(
        &self,
        axis: usize,
        v: &ArrayView1<f64>,
    ) -> Result<Array1<f64>, BasisError> {
        assert!(
            axis < self.n_axes(),
            "implicit psi second diagonal transpose axis out of bounds: axis={axis}, n_axes={}",
            self.n_axes()
        );
        assert_eq!(
            v.len(),
            self.n,
            "implicit psi second diagonal transpose row-adjoint length mismatch"
        );
        if self.axis_combinations.is_some() {
            // Transformed-axis path: both derivative slots use the same combination.
            let combo = self.transformed_axis_combination(axis);
            let combo_sum = Self::transformed_combo_sum(combo);
            if self.is_streaming() {
                let c = self.psi_scale_share;
                let raw = self.streaming_accumulate_knot_vector(v, |phi, q, t, sb| {
                    let s_combo = combo
                        .iter()
                        .map(|(raw_axis, coeff)| coeff * sb[*raw_axis])
                        .sum();
                    let overlap_s = Self::transformed_combo_overlap_streaming(combo, combo, sb);
                    Self::transformed_second_kernel_value(
                        phi, q, t, s_combo, combo_sum, s_combo, combo_sum, overlap_s, c,
                    )
                })?;
                return Ok(self.project_and_pad(&raw));
            }
            // Materialized transformed-axis path.
            let c = self.psi_scale_share;
            let raw = self.accumulate_knot_vector(v, |idx| {
                let s_combo = self.transformed_combo_axis_value_materialized(idx, combo);
                let overlap_s = self.transformed_combo_overlap_materialized(idx, combo, combo);
                Self::transformed_second_kernel_value(
                    self.phi_values[idx],
                    self.q_values[idx],
                    self.t_values[idx],
                    s_combo,
                    combo_sum,
                    s_combo,
                    combo_sum,
                    overlap_s,
                    c,
                )
            });
            return Ok(self.project_and_pad(&raw));
        }
        // Per-axis streaming path: t·s² + 2q·s + 2c·q·s + c²·φ per pair.
        if self.is_streaming() {
            let c = self.psi_scale_share;
            let raw = self.streaming_accumulate_knot_vector(v, |phi, q, t, sb| {
                let s = sb[axis];
                2.0 * q * s + t * s * s + 2.0 * c * q * s + c * c * phi
            })?;
            return Ok(self.project_and_pad(&raw));
        }
        // Per-axis materialized path: same per-pair value from cached arrays.
        let c = self.psi_scale_share;
        let af = &self.axis_components;
        let pv = &self.phi_values;
        let qv = &self.q_values;
        let tv = &self.t_values;
        let raw = self.accumulate_knot_vector(v, |idx| {
            let s = af[[idx, axis]];
            2.0 * qv[idx] * s + tv[idx] * s * s + 2.0 * c * qv[idx] * s + c * c * pv[idx]
        });
        Ok(self.project_and_pad(&raw))
    }
1558
1559    /// Compute (∂²X/∂ψ_d∂ψ_e)^T v — cross second derivative (d ≠ e).
1560    pub fn transpose_mul_second_cross(
1561        &self,
1562        axis_d: usize,
1563        axis_e: usize,
1564        v: &ArrayView1<f64>,
1565    ) -> Result<Array1<f64>, BasisError> {
1566        assert!(
1567            axis_d < self.n_axes(),
1568            "implicit psi second cross transpose first axis out of bounds: axis_d={axis_d}, n_axes={}",
1569            self.n_axes()
1570        );
1571        assert!(
1572            axis_e < self.n_axes(),
1573            "implicit psi second cross transpose second axis out of bounds: axis_e={axis_e}, n_axes={}",
1574            self.n_axes()
1575        );
1576        assert_ne!(
1577            axis_d, axis_e,
1578            "implicit psi second cross transpose requires distinct axes: axis_d={axis_d}, axis_e={axis_e}"
1579        );
1580        assert_eq!(
1581            v.len(),
1582            self.n,
1583            "implicit psi second cross transpose row-adjoint length mismatch"
1584        );
1585        if self.axis_combinations.is_some() {
1586            let combo_d = self.transformed_axis_combination(axis_d);
1587            let combo_e = self.transformed_axis_combination(axis_e);
1588            let sum_d = Self::transformed_combo_sum(combo_d);
1589            let sum_e = Self::transformed_combo_sum(combo_e);
1590            if self.is_streaming() {
1591                let c = self.psi_scale_share;
1592                let raw = self.streaming_accumulate_knot_vector(v, |phi, q, t, sb| {
1593                    let s_d = combo_d
1594                        .iter()
1595                        .map(|(raw_axis, coeff)| coeff * sb[*raw_axis])
1596                        .sum();
1597                    let s_e = combo_e
1598                        .iter()
1599                        .map(|(raw_axis, coeff)| coeff * sb[*raw_axis])
1600                        .sum();
1601                    let overlap_s = Self::transformed_combo_overlap_streaming(combo_d, combo_e, sb);
1602                    Self::transformed_second_kernel_value(
1603                        phi, q, t, s_d, sum_d, s_e, sum_e, overlap_s, c,
1604                    )
1605                })?;
1606                return Ok(self.project_and_pad(&raw));
1607            }
1608            let c = self.psi_scale_share;
1609            let raw = self.accumulate_knot_vector(v, |idx| {
1610                let s_d = self.transformed_combo_axis_value_materialized(idx, combo_d);
1611                let s_e = self.transformed_combo_axis_value_materialized(idx, combo_e);
1612                let overlap_s = self.transformed_combo_overlap_materialized(idx, combo_d, combo_e);
1613                Self::transformed_second_kernel_value(
1614                    self.phi_values[idx],
1615                    self.q_values[idx],
1616                    self.t_values[idx],
1617                    s_d,
1618                    sum_d,
1619                    s_e,
1620                    sum_e,
1621                    overlap_s,
1622                    c,
1623                )
1624            });
1625            return Ok(self.project_and_pad(&raw));
1626        }
1627        if self.is_streaming() {
1628            let c = self.psi_scale_share;
1629            let raw = self.streaming_accumulate_knot_vector(v, |phi, q, t, sb| {
1630                t * sb[axis_d] * sb[axis_e] + c * q * (sb[axis_d] + sb[axis_e]) + c * c * phi
1631            })?;
1632            return Ok(self.project_and_pad(&raw));
1633        }
1634        let c = self.psi_scale_share;
1635        let af = &self.axis_components;
1636        let pv = &self.phi_values;
1637        let qv = &self.q_values;
1638        let tv = &self.t_values;
1639        let raw = self.accumulate_knot_vector(v, |idx| {
1640            tv[idx] * af[[idx, axis_d]] * af[[idx, axis_e]]
1641                + c * qv[idx] * (af[[idx, axis_d]] + af[[idx, axis_e]])
1642                + c * c * pv[idx]
1643        });
1644        Ok(self.project_and_pad(&raw))
1645    }
1646
1647    /// Compute (∂²X/∂ψ_d²) u — forward diagonal second derivative.
1648    pub fn forward_mul_second_diag(
1649        &self,
1650        axis: usize,
1651        u: &ArrayView1<f64>,
1652    ) -> Result<Array1<f64>, BasisError> {
1653        assert!(
1654            axis < self.n_axes(),
1655            "implicit psi second diagonal forward axis out of bounds: axis={axis}, n_axes={}",
1656            self.n_axes()
1657        );
1658        assert_eq!(
1659            u.len(),
1660            self.p_out(),
1661            "implicit psi second diagonal forward coefficient length mismatch"
1662        );
1663        let u_knot = self.unproject(u);
1664        if self.axis_combinations.is_some() {
1665            let combo = self.transformed_axis_combination(axis);
1666            let combo_sum = Self::transformed_combo_sum(combo);
1667            if self.is_streaming() {
1668                let c = self.psi_scale_share;
1669                return self.streaming_forward_mul(&u_knot, |phi, q, t, sb| {
1670                    let s_combo = combo
1671                        .iter()
1672                        .map(|(raw_axis, coeff)| coeff * sb[*raw_axis])
1673                        .sum();
1674                    let overlap_s = Self::transformed_combo_overlap_streaming(combo, combo, sb);
1675                    Self::transformed_second_kernel_value(
1676                        phi, q, t, s_combo, combo_sum, s_combo, combo_sum, overlap_s, c,
1677                    )
1678                });
1679            }
1680            let n = self.n;
1681            let k = self.n_knots;
1682            let c = self.psi_scale_share;
1683            let compute_row = |i: usize| -> f64 {
1684                let base = i * k;
1685                let mut val = 0.0;
1686                for j in 0..k {
1687                    let idx = base + j;
1688                    let s_combo = self.transformed_combo_axis_value_materialized(idx, combo);
1689                    let overlap_s = self.transformed_combo_overlap_materialized(idx, combo, combo);
1690                    val += Self::transformed_second_kernel_value(
1691                        self.phi_values[idx],
1692                        self.q_values[idx],
1693                        self.t_values[idx],
1694                        s_combo,
1695                        combo_sum,
1696                        s_combo,
1697                        combo_sum,
1698                        overlap_s,
1699                        c,
1700                    ) * u_knot[j];
1701                }
1702                val
1703            };
1704            if n >= IMPLICIT_MATVEC_PAR_THRESHOLD {
1705                let n_chunks = n.div_ceil(IMPLICIT_MATVEC_CHUNK_SIZE);
1706                let mut result = Array1::<f64>::zeros(n);
1707                let chunk_results: Vec<(usize, Vec<f64>)> = (0..n_chunks)
1708                    .into_par_iter()
1709                    .map(|chunk_idx| {
1710                        let start = chunk_idx * IMPLICIT_MATVEC_CHUNK_SIZE;
1711                        let end = (start + IMPLICIT_MATVEC_CHUNK_SIZE).min(n);
1712                        let local: Vec<f64> = (start..end).map(compute_row).collect();
1713                        (start, local)
1714                    })
1715                    .collect();
1716                for (start, vals) in chunk_results {
1717                    for (offset, &value) in vals.iter().enumerate() {
1718                        result[start + offset] = value;
1719                    }
1720                }
1721                return Ok(result);
1722            }
1723            return Ok(Array1::from_vec((0..n).map(compute_row).collect()));
1724        }
1725        if self.is_streaming() {
1726            let c = self.psi_scale_share;
1727            return self.streaming_forward_mul(&u_knot, |phi, q, t, sb| {
1728                let s = sb[axis];
1729                2.0 * q * s + t * s * s + 2.0 * c * q * s + c * c * phi
1730            });
1731        }
1732        let n = self.n;
1733        let k = self.n_knots;
1734        let c = self.psi_scale_share;
1735        let af = &self.axis_components;
1736        let pv = &self.phi_values;
1737        let qv = &self.q_values;
1738        let tv = &self.t_values;
1739        let compute_row = |i: usize| -> f64 {
1740            let base = i * k;
1741            let mut val = 0.0;
1742            for j in 0..k {
1743                let s = af[[base + j, axis]];
1744                val += (2.0 * qv[base + j] * s
1745                    + tv[base + j] * s * s
1746                    + 2.0 * c * qv[base + j] * s
1747                    + c * c * pv[base + j])
1748                    * u_knot[j];
1749            }
1750            val
1751        };
1752
1753        if n >= IMPLICIT_MATVEC_PAR_THRESHOLD {
1754            let n_chunks = n.div_ceil(IMPLICIT_MATVEC_CHUNK_SIZE);
1755            let mut result = Array1::<f64>::zeros(n);
1756            let chunk_results: Vec<(usize, Vec<f64>)> = (0..n_chunks)
1757                .into_par_iter()
1758                .map(|chunk_idx| {
1759                    let start = chunk_idx * IMPLICIT_MATVEC_CHUNK_SIZE;
1760                    let end = (start + IMPLICIT_MATVEC_CHUNK_SIZE).min(n);
1761                    let local: Vec<f64> = (start..end).map(compute_row).collect();
1762                    (start, local)
1763                })
1764                .collect();
1765            for (start, vals) in chunk_results {
1766                for (offset, &value) in vals.iter().enumerate() {
1767                    result[start + offset] = value;
1768                }
1769            }
1770            Ok(result)
1771        } else {
1772            Ok(Array1::from_vec((0..n).map(compute_row).collect()))
1773        }
1774    }
1775
1776    /// Compute (∂²X/∂ψ_d∂ψ_e) u — forward cross second derivative.
1777    pub fn forward_mul_second_cross(
1778        &self,
1779        axis_d: usize,
1780        axis_e: usize,
1781        u: &ArrayView1<f64>,
1782    ) -> Result<Array1<f64>, BasisError> {
1783        assert!(
1784            axis_d < self.n_axes(),
1785            "implicit psi second cross forward first axis out of bounds: axis_d={axis_d}, n_axes={}",
1786            self.n_axes()
1787        );
1788        assert!(
1789            axis_e < self.n_axes(),
1790            "implicit psi second cross forward second axis out of bounds: axis_e={axis_e}, n_axes={}",
1791            self.n_axes()
1792        );
1793        assert_ne!(
1794            axis_d, axis_e,
1795            "implicit psi second cross forward requires distinct axes: axis_d={axis_d}, axis_e={axis_e}"
1796        );
1797        assert_eq!(
1798            u.len(),
1799            self.p_out(),
1800            "implicit psi second cross forward coefficient length mismatch"
1801        );
1802        let u_knot = self.unproject(u);
1803        if self.axis_combinations.is_some() {
1804            let combo_d = self.transformed_axis_combination(axis_d);
1805            let combo_e = self.transformed_axis_combination(axis_e);
1806            let sum_d = Self::transformed_combo_sum(combo_d);
1807            let sum_e = Self::transformed_combo_sum(combo_e);
1808            if self.is_streaming() {
1809                let c = self.psi_scale_share;
1810                return self.streaming_forward_mul(&u_knot, |phi, q, t, sb| {
1811                    let s_d = combo_d
1812                        .iter()
1813                        .map(|(raw_axis, coeff)| coeff * sb[*raw_axis])
1814                        .sum();
1815                    let s_e = combo_e
1816                        .iter()
1817                        .map(|(raw_axis, coeff)| coeff * sb[*raw_axis])
1818                        .sum();
1819                    let overlap_s = Self::transformed_combo_overlap_streaming(combo_d, combo_e, sb);
1820                    Self::transformed_second_kernel_value(
1821                        phi, q, t, s_d, sum_d, s_e, sum_e, overlap_s, c,
1822                    )
1823                });
1824            }
1825            let n = self.n;
1826            let k = self.n_knots;
1827            let c = self.psi_scale_share;
1828            let compute_row = |i: usize| -> f64 {
1829                let base = i * k;
1830                let mut val = 0.0;
1831                for j in 0..k {
1832                    let idx = base + j;
1833                    let s_d = self.transformed_combo_axis_value_materialized(idx, combo_d);
1834                    let s_e = self.transformed_combo_axis_value_materialized(idx, combo_e);
1835                    let overlap_s =
1836                        self.transformed_combo_overlap_materialized(idx, combo_d, combo_e);
1837                    val += Self::transformed_second_kernel_value(
1838                        self.phi_values[idx],
1839                        self.q_values[idx],
1840                        self.t_values[idx],
1841                        s_d,
1842                        sum_d,
1843                        s_e,
1844                        sum_e,
1845                        overlap_s,
1846                        c,
1847                    ) * u_knot[j];
1848                }
1849                val
1850            };
1851            if n >= IMPLICIT_MATVEC_PAR_THRESHOLD {
1852                let n_chunks = n.div_ceil(IMPLICIT_MATVEC_CHUNK_SIZE);
1853                let mut result = Array1::<f64>::zeros(n);
1854                let chunk_results: Vec<(usize, Vec<f64>)> = (0..n_chunks)
1855                    .into_par_iter()
1856                    .map(|chunk_idx| {
1857                        let start = chunk_idx * IMPLICIT_MATVEC_CHUNK_SIZE;
1858                        let end = (start + IMPLICIT_MATVEC_CHUNK_SIZE).min(n);
1859                        let local: Vec<f64> = (start..end).map(compute_row).collect();
1860                        (start, local)
1861                    })
1862                    .collect();
1863                for (start, vals) in chunk_results {
1864                    for (offset, &value) in vals.iter().enumerate() {
1865                        result[start + offset] = value;
1866                    }
1867                }
1868                return Ok(result);
1869            }
1870            return Ok(Array1::from_vec((0..n).map(compute_row).collect()));
1871        }
1872        if self.is_streaming() {
1873            let c = self.psi_scale_share;
1874            return self.streaming_forward_mul(&u_knot, |phi, q, t, sb| {
1875                t * sb[axis_d] * sb[axis_e] + c * q * (sb[axis_d] + sb[axis_e]) + c * c * phi
1876            });
1877        }
1878        let n = self.n;
1879        let k = self.n_knots;
1880        let c = self.psi_scale_share;
1881        let af = &self.axis_components;
1882        let pv = &self.phi_values;
1883        let qv = &self.q_values;
1884        let tv = &self.t_values;
1885        let compute_row = |i: usize| -> f64 {
1886            let base = i * k;
1887            let mut val = 0.0;
1888            for j in 0..k {
1889                val += (tv[base + j] * af[[base + j, axis_d]] * af[[base + j, axis_e]]
1890                    + c * qv[base + j] * (af[[base + j, axis_d]] + af[[base + j, axis_e]])
1891                    + c * c * pv[base + j])
1892                    * u_knot[j];
1893            }
1894            val
1895        };
1896
1897        if n >= IMPLICIT_MATVEC_PAR_THRESHOLD {
1898            let n_chunks = n.div_ceil(IMPLICIT_MATVEC_CHUNK_SIZE);
1899            let mut result = Array1::<f64>::zeros(n);
1900            let chunk_results: Vec<(usize, Vec<f64>)> = (0..n_chunks)
1901                .into_par_iter()
1902                .map(|chunk_idx| {
1903                    let start = chunk_idx * IMPLICIT_MATVEC_CHUNK_SIZE;
1904                    let end = (start + IMPLICIT_MATVEC_CHUNK_SIZE).min(n);
1905                    let local: Vec<f64> = (start..end).map(compute_row).collect();
1906                    (start, local)
1907                })
1908                .collect();
1909            for (start, vals) in chunk_results {
1910                for (offset, &value) in vals.iter().enumerate() {
1911                    result[start + offset] = value;
1912                }
1913            }
1914            Ok(result)
1915        } else {
1916            Ok(Array1::from_vec((0..n).map(compute_row).collect()))
1917        }
1918    }
1919
1920    /// Materialize the full (n × p_out) first-derivative matrix for axis d.
1921    ///
1922    /// Efficient O(n * k) construction: builds the raw (n × k) kernel derivative
1923    /// matrix directly, then projects through identifiability transforms.
1924    /// This is used when the dense matrix is needed temporarily (e.g., for
1925    /// HyperCoord construction) while avoiding simultaneous storage of all D axes.
1926    pub fn materialize_first(&self, axis: usize) -> Result<Array2<f64>, BasisError> {
1927        assert!(
1928            axis < self.n_axes(),
1929            "implicit psi first materialization axis out of bounds: axis={axis}, n_axes={}",
1930            self.n_axes()
1931        );
1932        if self.enforces_dense_materialization_budget() {
1933            assert_no_dense_derivative_materialization(self.n, self.p_out(), self.n_axes());
1934        }
1935        if self.axis_combinations.is_some() {
1936            let combo = self.transformed_axis_combination(axis);
1937            let combo_sum = Self::transformed_combo_sum(combo);
1938            if self.is_streaming() {
1939                let c = self.psi_scale_share;
1940                return self.streaming_materialize(|phi, q, _, sb| {
1941                    let s_combo = combo
1942                        .iter()
1943                        .map(|(raw_axis, coeff)| coeff * sb[*raw_axis])
1944                        .sum();
1945                    Self::transformed_first_kernel_value(phi, q, s_combo, combo_sum, c)
1946                });
1947            }
1948            let n = self.n;
1949            let k = self.n_knots;
1950            let c = self.psi_scale_share;
1951            let mut raw = Array2::<f64>::zeros((n, k));
1952            for i in 0..n {
1953                let base = i * k;
1954                for j in 0..k {
1955                    let idx = base + j;
1956                    let s_combo = self.transformed_combo_axis_value_materialized(idx, combo);
1957                    raw[[i, j]] = Self::transformed_first_kernel_value(
1958                        self.phi_values[idx],
1959                        self.q_values[idx],
1960                        s_combo,
1961                        combo_sum,
1962                        c,
1963                    );
1964                }
1965            }
1966            return Ok(self.project_matrix(raw));
1967        }
1968        if self.is_streaming() {
1969            let c = self.psi_scale_share;
1970            return self.streaming_materialize(|phi, q, _, sb| q * sb[axis] + c * phi);
1971        }
1972        let n = self.n;
1973        let k = self.n_knots;
1974        let c = self.psi_scale_share;
1975        let mut raw = Array2::<f64>::zeros((n, k));
1976        for i in 0..n {
1977            let base = i * k;
1978            for j in 0..k {
1979                raw[[i, j]] = self.q_values[base + j] * self.axis_components[[base + j, axis]]
1980                    + c * self.phi_values[base + j];
1981            }
1982        }
1983        Ok(self.project_matrix(raw))
1984    }
1985
1986    /// Materialize the full (n × p_out) second diagonal derivative matrix for axis d.
1987    pub fn materialize_second_diag(&self, axis: usize) -> Result<Array2<f64>, BasisError> {
1988        assert!(
1989            axis < self.n_axes(),
1990            "implicit psi second diagonal materialization axis out of bounds: axis={axis}, n_axes={}",
1991            self.n_axes()
1992        );
1993        if self.enforces_dense_materialization_budget() {
1994            assert_no_dense_derivative_materialization(self.n, self.p_out(), self.n_axes());
1995        }
1996        if self.axis_combinations.is_some() {
1997            let combo = self.transformed_axis_combination(axis);
1998            let combo_sum = Self::transformed_combo_sum(combo);
1999            if self.is_streaming() {
2000                let c = self.psi_scale_share;
2001                return self.streaming_materialize(|phi, q, t, sb| {
2002                    let s_combo = combo
2003                        .iter()
2004                        .map(|(raw_axis, coeff)| coeff * sb[*raw_axis])
2005                        .sum();
2006                    let overlap_s = Self::transformed_combo_overlap_streaming(combo, combo, sb);
2007                    Self::transformed_second_kernel_value(
2008                        phi, q, t, s_combo, combo_sum, s_combo, combo_sum, overlap_s, c,
2009                    )
2010                });
2011            }
2012            let n = self.n;
2013            let k = self.n_knots;
2014            let c = self.psi_scale_share;
2015            let mut raw = Array2::<f64>::zeros((n, k));
2016            for i in 0..n {
2017                let base = i * k;
2018                for j in 0..k {
2019                    let idx = base + j;
2020                    let s_combo = self.transformed_combo_axis_value_materialized(idx, combo);
2021                    let overlap_s = self.transformed_combo_overlap_materialized(idx, combo, combo);
2022                    raw[[i, j]] = Self::transformed_second_kernel_value(
2023                        self.phi_values[idx],
2024                        self.q_values[idx],
2025                        self.t_values[idx],
2026                        s_combo,
2027                        combo_sum,
2028                        s_combo,
2029                        combo_sum,
2030                        overlap_s,
2031                        c,
2032                    );
2033                }
2034            }
2035            return Ok(self.project_matrix(raw));
2036        }
2037        if self.is_streaming() {
2038            let c = self.psi_scale_share;
2039            return self.streaming_materialize(|phi, q, t, sb| {
2040                let s = sb[axis];
2041                2.0 * q * s + t * s * s + 2.0 * c * q * s + c * c * phi
2042            });
2043        }
2044        let n = self.n;
2045        let k = self.n_knots;
2046        let c = self.psi_scale_share;
2047        let mut raw = Array2::<f64>::zeros((n, k));
2048        for i in 0..n {
2049            let base = i * k;
2050            for j in 0..k {
2051                let s = self.axis_components[[base + j, axis]];
2052                raw[[i, j]] = 2.0 * self.q_values[base + j] * s
2053                    + self.t_values[base + j] * s * s
2054                    + 2.0 * c * self.q_values[base + j] * s
2055                    + c * c * self.phi_values[base + j];
2056            }
2057        }
2058        Ok(self.project_matrix(raw))
2059    }
2060
2061    /// Materialize the full (n × p_out) cross second derivative matrix for axes (d, e).
2062    ///
2063    /// Dense materialization of the t · s_d · s_e cross coupling.
2064    pub fn materialize_second_cross(
2065        &self,
2066        axis_d: usize,
2067        axis_e: usize,
2068    ) -> Result<Array2<f64>, BasisError> {
2069        assert!(
2070            axis_d < self.n_axes(),
2071            "implicit psi second cross materialization first axis out of bounds: axis_d={axis_d}, n_axes={}",
2072            self.n_axes()
2073        );
2074        assert!(
2075            axis_e < self.n_axes(),
2076            "implicit psi second cross materialization second axis out of bounds: axis_e={axis_e}, n_axes={}",
2077            self.n_axes()
2078        );
2079        assert_ne!(
2080            axis_d, axis_e,
2081            "implicit psi second cross materialization requires distinct axes: axis_d={axis_d}, axis_e={axis_e}"
2082        );
2083        if self.enforces_dense_materialization_budget() {
2084            assert_no_dense_derivative_materialization(self.n, self.p_out(), self.n_axes());
2085        }
2086        if self.axis_combinations.is_some() {
2087            let combo_d = self.transformed_axis_combination(axis_d);
2088            let combo_e = self.transformed_axis_combination(axis_e);
2089            let sum_d = Self::transformed_combo_sum(combo_d);
2090            let sum_e = Self::transformed_combo_sum(combo_e);
2091            if self.is_streaming() {
2092                let c = self.psi_scale_share;
2093                return self.streaming_materialize(|phi, q, t, sb| {
2094                    let s_d = combo_d
2095                        .iter()
2096                        .map(|(raw_axis, coeff)| coeff * sb[*raw_axis])
2097                        .sum();
2098                    let s_e = combo_e
2099                        .iter()
2100                        .map(|(raw_axis, coeff)| coeff * sb[*raw_axis])
2101                        .sum();
2102                    let overlap_s = Self::transformed_combo_overlap_streaming(combo_d, combo_e, sb);
2103                    Self::transformed_second_kernel_value(
2104                        phi, q, t, s_d, sum_d, s_e, sum_e, overlap_s, c,
2105                    )
2106                });
2107            }
2108            let n = self.n;
2109            let k = self.n_knots;
2110            let c = self.psi_scale_share;
2111            let mut raw = Array2::<f64>::zeros((n, k));
2112            for i in 0..n {
2113                let base = i * k;
2114                for j in 0..k {
2115                    let idx = base + j;
2116                    let s_d = self.transformed_combo_axis_value_materialized(idx, combo_d);
2117                    let s_e = self.transformed_combo_axis_value_materialized(idx, combo_e);
2118                    let overlap_s =
2119                        self.transformed_combo_overlap_materialized(idx, combo_d, combo_e);
2120                    raw[[i, j]] = Self::transformed_second_kernel_value(
2121                        self.phi_values[idx],
2122                        self.q_values[idx],
2123                        self.t_values[idx],
2124                        s_d,
2125                        sum_d,
2126                        s_e,
2127                        sum_e,
2128                        overlap_s,
2129                        c,
2130                    );
2131                }
2132            }
2133            return Ok(self.project_matrix(raw));
2134        }
2135        if self.is_streaming() {
2136            let c = self.psi_scale_share;
2137            return self.streaming_materialize(|phi, q, t, sb| {
2138                t * sb[axis_d] * sb[axis_e] + c * q * (sb[axis_d] + sb[axis_e]) + c * c * phi
2139            });
2140        }
2141        let n = self.n;
2142        let k = self.n_knots;
2143        let c = self.psi_scale_share;
2144        let mut raw = Array2::<f64>::zeros((n, k));
2145        for i in 0..n {
2146            let base = i * k;
2147            for j in 0..k {
2148                raw[[i, j]] = self.t_values[base + j]
2149                    * self.axis_components[[base + j, axis_d]]
2150                    * self.axis_components[[base + j, axis_e]]
2151                    + c * self.q_values[base + j]
2152                        * (self.axis_components[[base + j, axis_d]]
2153                            + self.axis_components[[base + j, axis_e]])
2154                    + c * c * self.phi_values[base + j];
2155            }
2156        }
2157        Ok(self.project_matrix(raw))
2158    }
2159
2160    /// Project a raw (n × k) kernel-space matrix through all transforms to
2161    /// produce an (n × p_out) matrix: Z_kernel → pad poly → full ident.
2162    pub(crate) fn project_matrix(&self, raw: Array2<f64>) -> Array2<f64> {
2163        // Step 1: kernel constraint projection.
2164        let constrained = match &self.ident_transform {
2165            Some(z) => fast_ab(&raw, z),
2166            None => raw,
2167        };
2168
2169        // Step 2: polynomial padding.
2170        let padded = if self.n_poly > 0 {
2171            let cols = constrained.ncols();
2172            let mut out = Array2::<f64>::zeros((self.n, cols + self.n_poly));
2173            out.slice_mut(s![.., ..cols]).assign(&constrained);
2174            out
2175        } else {
2176            constrained
2177        };
2178
2179        // Step 3: full identifiability transform.
2180        match &self.full_ident_transform {
2181            Some(zf) => fast_ab(&padded, zf),
2182            None => padded,
2183        }
2184    }
2185
2186    pub(crate) fn project_matrix_rows(&self, raw: Array2<f64>) -> Array2<f64> {
2187        let nrows = raw.nrows();
2188        let constrained = match &self.ident_transform {
2189            Some(z) => fast_ab(&raw, z),
2190            None => raw,
2191        };
2192        let padded = if self.n_poly > 0 {
2193            let cols = constrained.ncols();
2194            let mut out = Array2::<f64>::zeros((nrows, cols + self.n_poly));
2195            out.slice_mut(s![.., ..cols]).assign(&constrained);
2196            out
2197        } else {
2198            constrained
2199        };
2200        match &self.full_ident_transform {
2201            Some(zf) => fast_ab(&padded, zf),
2202            None => padded,
2203        }
2204    }
2205
2206    pub(crate) fn row_chunk_with_kernel<G>(
2207        &self,
2208        rows: std::ops::Range<usize>,
2209        deriv_fn: G,
2210    ) -> Result<Array2<f64>, BasisError>
2211    where
2212        G: Fn(f64, f64, f64, &[f64], usize) -> f64,
2213    {
2214        let raw = self.row_chunk_with_kernel_raw(rows, deriv_fn)?;
2215        Ok(self.project_matrix_rows(raw))
2216    }
2217
2218    /// Like `row_chunk_with_kernel` but returns the raw (chunk × n_knots)
2219    /// kernel scalars without the identifiability/padding projection. Used
2220    /// by `forward_mul_matrix`, which does the projection on the rank side
2221    /// instead (`unproject_matrix(F)`) so the (n × p_out) projected
2222    /// derivative is never materialized for large-scale row counts.
2223    pub(crate) fn row_chunk_with_kernel_raw<G>(
2224        &self,
2225        rows: std::ops::Range<usize>,
2226        deriv_fn: G,
2227    ) -> Result<Array2<f64>, BasisError>
2228    where
2229        G: Fn(f64, f64, f64, &[f64], usize) -> f64,
2230    {
2231        let mut raw = Array2::<f64>::zeros((rows.end - rows.start, self.n_knots));
2232        if let Some(st) = self.streaming.as_ref() {
2233            let mut sb = vec![0.0; self.n_axes];
2234            if let Some(cache) = st.ensure_triplet_cache() {
2235                for (local, i) in rows.enumerate() {
2236                    let base = i * self.n_knots;
2237                    for j in 0..self.n_knots {
2238                        let idx = base + j;
2239                        st.fill_s_buf(i, j, &mut sb);
2240                        raw[[local, j]] =
2241                            deriv_fn(cache.phi[idx], cache.q[idx], cache.t[idx], &sb, idx);
2242                    }
2243                }
2244            } else {
2245                for (local, i) in rows.enumerate() {
2246                    for j in 0..self.n_knots {
2247                        let (phi, q, t) = st.compute_pair(i, j, &mut sb)?;
2248                        raw[[local, j]] = deriv_fn(phi, q, t, &sb, i * self.n_knots + j);
2249                    }
2250                }
2251            }
2252        } else {
2253            for (local, i) in rows.enumerate() {
2254                let base = i * self.n_knots;
2255                for j in 0..self.n_knots {
2256                    let idx = base + j;
2257                    raw[[local, j]] = deriv_fn(
2258                        self.phi_values[idx],
2259                        self.q_values[idx],
2260                        self.t_values[idx],
2261                        &[],
2262                        idx,
2263                    );
2264                }
2265            }
2266        }
2267        Ok(raw)
2268    }
2269
2270    pub fn row_chunk_first(
2271        &self,
2272        axis: usize,
2273        rows: std::ops::Range<usize>,
2274    ) -> Result<Array2<f64>, BasisError> {
2275        assert!(
2276            axis < self.n_axes(),
2277            "implicit psi first row chunk axis out of bounds: axis={axis}, n_axes={}",
2278            self.n_axes()
2279        );
2280        let c = self.psi_scale_share;
2281        if self.axis_combinations.is_some() {
2282            let combo = self.transformed_axis_combination(axis);
2283            let combo_sum = Self::transformed_combo_sum(combo);
2284            return self.row_chunk_with_kernel(rows, |phi, q, _, sb, idx| {
2285                let s_combo = if sb.is_empty() {
2286                    self.transformed_combo_axis_value_materialized(idx, combo)
2287                } else {
2288                    combo
2289                        .iter()
2290                        .map(|(raw_axis, coeff)| coeff * sb[*raw_axis])
2291                        .sum()
2292                };
2293                Self::transformed_first_kernel_value(phi, q, s_combo, combo_sum, c)
2294            });
2295        }
2296        self.row_chunk_with_kernel(rows, |phi, q, _, sb, idx| {
2297            let s = if sb.is_empty() {
2298                self.axis_components[[idx, axis]]
2299            } else {
2300                sb[axis]
2301            };
2302            q * s + c * phi
2303        })
2304    }
2305
2306    /// Raw (chunk × n_knots) first-order kernel scalars for axis d, without
2307    /// the identifiability/padding projection. Pairs with `unproject_matrix`
2308    /// in `forward_mul_matrix`: the kernel scalars stay in raw knot space
2309    /// while the rank side (F) is unprojected to knot space, so the per-chunk
2310    /// GEMM is (chunk × n_knots) · (n_knots × rank) rather than (chunk × p_out)
2311    /// · (p_out × rank). Saves both flops and a (chunk × p_out) intermediate.
2312    pub fn row_chunk_first_raw(
2313        &self,
2314        axis: usize,
2315        rows: std::ops::Range<usize>,
2316    ) -> Result<Array2<f64>, BasisError> {
2317        assert!(
2318            axis < self.n_axes(),
2319            "implicit psi first raw row chunk axis out of bounds: axis={axis}, n_axes={}",
2320            self.n_axes()
2321        );
2322        let c = self.psi_scale_share;
2323        if self.axis_combinations.is_some() {
2324            let combo = self.transformed_axis_combination(axis);
2325            let combo_sum = Self::transformed_combo_sum(combo);
2326            return self.row_chunk_with_kernel_raw(rows, |phi, q, _, sb, idx| {
2327                let s_combo = if sb.is_empty() {
2328                    self.transformed_combo_axis_value_materialized(idx, combo)
2329                } else {
2330                    combo
2331                        .iter()
2332                        .map(|(raw_axis, coeff)| coeff * sb[*raw_axis])
2333                        .sum()
2334                };
2335                Self::transformed_first_kernel_value(phi, q, s_combo, combo_sum, c)
2336            });
2337        }
2338        self.row_chunk_with_kernel_raw(rows, |phi, q, _, sb, idx| {
2339            let s = if sb.is_empty() {
2340                self.axis_components[[idx, axis]]
2341            } else {
2342                sb[axis]
2343            };
2344            q * s + c * phi
2345        })
2346    }
2347
2348    pub fn row_chunk_second_diag(
2349        &self,
2350        axis: usize,
2351        rows: std::ops::Range<usize>,
2352    ) -> Result<Array2<f64>, BasisError> {
2353        assert!(
2354            axis < self.n_axes(),
2355            "implicit psi second diagonal row chunk axis out of bounds: axis={axis}, n_axes={}",
2356            self.n_axes()
2357        );
2358        let c = self.psi_scale_share;
2359        if self.axis_combinations.is_some() {
2360            let combo = self.transformed_axis_combination(axis);
2361            let combo_sum = Self::transformed_combo_sum(combo);
2362            return self.row_chunk_with_kernel(rows, |phi, q, t, sb, idx| {
2363                let s_combo = if sb.is_empty() {
2364                    self.transformed_combo_axis_value_materialized(idx, combo)
2365                } else {
2366                    combo
2367                        .iter()
2368                        .map(|(raw_axis, coeff)| coeff * sb[*raw_axis])
2369                        .sum()
2370                };
2371                let overlap = if sb.is_empty() {
2372                    self.transformed_combo_overlap_materialized(idx, combo, combo)
2373                } else {
2374                    Self::transformed_combo_overlap_streaming(combo, combo, sb)
2375                };
2376                Self::transformed_second_kernel_value(
2377                    phi, q, t, s_combo, combo_sum, s_combo, combo_sum, overlap, c,
2378                )
2379            });
2380        }
2381        self.row_chunk_with_kernel(rows, |phi, q, t, sb, idx| {
2382            let s = if sb.is_empty() {
2383                self.axis_components[[idx, axis]]
2384            } else {
2385                sb[axis]
2386            };
2387            2.0 * q * s + t * s * s + 2.0 * c * q * s + c * c * phi
2388        })
2389    }
2390
2391    pub fn row_chunk_second_cross(
2392        &self,
2393        axis_d: usize,
2394        axis_e: usize,
2395        rows: std::ops::Range<usize>,
2396    ) -> Result<Array2<f64>, BasisError> {
2397        assert!(
2398            axis_d < self.n_axes(),
2399            "implicit psi second cross row chunk first axis out of bounds: axis_d={axis_d}, n_axes={}",
2400            self.n_axes()
2401        );
2402        assert!(
2403            axis_e < self.n_axes(),
2404            "implicit psi second cross row chunk second axis out of bounds: axis_e={axis_e}, n_axes={}",
2405            self.n_axes()
2406        );
2407        assert_ne!(
2408            axis_d, axis_e,
2409            "implicit psi second cross row chunk requires distinct axes: axis_d={axis_d}, axis_e={axis_e}"
2410        );
2411        let c = self.psi_scale_share;
2412        if self.axis_combinations.is_some() {
2413            let combo_d = self.transformed_axis_combination(axis_d);
2414            let combo_e = self.transformed_axis_combination(axis_e);
2415            let sum_d = Self::transformed_combo_sum(combo_d);
2416            let sum_e = Self::transformed_combo_sum(combo_e);
2417            return self.row_chunk_with_kernel(rows, |phi, q, t, sb, idx| {
2418                let s_d = if sb.is_empty() {
2419                    self.transformed_combo_axis_value_materialized(idx, combo_d)
2420                } else {
2421                    combo_d
2422                        .iter()
2423                        .map(|(raw_axis, coeff)| coeff * sb[*raw_axis])
2424                        .sum()
2425                };
2426                let s_e = if sb.is_empty() {
2427                    self.transformed_combo_axis_value_materialized(idx, combo_e)
2428                } else {
2429                    combo_e
2430                        .iter()
2431                        .map(|(raw_axis, coeff)| coeff * sb[*raw_axis])
2432                        .sum()
2433                };
2434                let overlap = if sb.is_empty() {
2435                    self.transformed_combo_overlap_materialized(idx, combo_d, combo_e)
2436                } else {
2437                    Self::transformed_combo_overlap_streaming(combo_d, combo_e, sb)
2438                };
2439                Self::transformed_second_kernel_value(phi, q, t, s_d, sum_d, s_e, sum_e, overlap, c)
2440            });
2441        }
2442        self.row_chunk_with_kernel(rows, |phi, q, t, sb, idx| {
2443            let sd = if sb.is_empty() {
2444                self.axis_components[[idx, axis_d]]
2445            } else {
2446                sb[axis_d]
2447            };
2448            let se = if sb.is_empty() {
2449                self.axis_components[[idx, axis_e]]
2450            } else {
2451                sb[axis_e]
2452            };
2453            t * sd * se + c * q * (sd + se) + c * c * phi
2454        })
2455    }
2456
2457    /// Single-row specialization of `row_chunk_first(axis, row..row+1)` that
2458    /// writes the length-`p_out` row directly into the caller-provided buffer.
2459    ///
2460    /// This is the row-local API used by `CustomFamilyPsiLinearMapRef::row_vector`
2461    /// for survival rowwise exact-Hessian paths, which previously applied a
2462    /// unit-vector `transpose_mul` trick (O(n·K) per row) to recover a single
2463    /// row. Avoids allocating a temporary (1 × p_out) matrix per row call.
2464    pub fn row_vector_first_into(
2465        &self,
2466        axis: usize,
2467        row: usize,
2468        mut out: ArrayViewMut1<'_, f64>,
2469    ) -> Result<(), BasisError> {
2470        assert!(
2471            row < self.n,
2472            "implicit psi row-vector request out of bounds: row={row}, n={}",
2473            self.n
2474        );
2475        assert_eq!(
2476            out.len(),
2477            self.p_out(),
2478            "implicit psi row-vector output length mismatch"
2479        );
2480        let chunk = self.row_chunk_first(axis, row..row + 1)?;
2481        out.assign(&chunk.row(0));
2482        Ok(())
2483    }
2484
2485    pub(crate) fn transformed_axis_combination(&self, axis: usize) -> &[(usize, f64)] {
2486        self.axis_combinations
2487            .as_ref()
2488            .expect("transformed axis combinations")
2489            .get(axis)
2490            .map(Vec::as_slice)
2491            .expect("transformed axis index")
2492    }
2493
2494    #[inline]
2495    pub(crate) fn transformed_combo_sum(combo: &[(usize, f64)]) -> f64 {
2496        combo.iter().map(|(_, coeff)| *coeff).sum()
2497    }
2498
2499    #[inline]
2500    pub(crate) fn transformed_combo_axis_value_materialized(
2501        &self,
2502        idx: usize,
2503        combo: &[(usize, f64)],
2504    ) -> f64 {
2505        combo
2506            .iter()
2507            .map(|(raw_axis, coeff)| coeff * self.axis_components[[idx, *raw_axis]])
2508            .sum()
2509    }
2510
2511    #[inline]
2512    pub(crate) fn transformed_combo_overlap_streaming(
2513        combo_left: &[(usize, f64)],
2514        combo_right: &[(usize, f64)],
2515        sb: &[f64],
2516    ) -> f64 {
2517        let mut overlap = 0.0;
2518        for &(left_axis, left_coeff) in combo_left {
2519            for &(right_axis, right_coeff) in combo_right {
2520                if left_axis == right_axis {
2521                    overlap += left_coeff * right_coeff * sb[left_axis];
2522                }
2523            }
2524        }
2525        overlap
2526    }
2527
2528    #[inline]
2529    pub(crate) fn transformed_combo_overlap_materialized(
2530        &self,
2531        idx: usize,
2532        combo_left: &[(usize, f64)],
2533        combo_right: &[(usize, f64)],
2534    ) -> f64 {
2535        let mut overlap = 0.0;
2536        for &(left_axis, left_coeff) in combo_left {
2537            for &(right_axis, right_coeff) in combo_right {
2538                if left_axis == right_axis {
2539                    overlap += left_coeff * right_coeff * self.axis_components[[idx, left_axis]];
2540                }
2541            }
2542        }
2543        overlap
2544    }
2545
2546    #[inline]
2547    pub(crate) fn transformed_first_kernel_value(
2548        phi: f64,
2549        q: f64,
2550        s_combo: f64,
2551        coeff_sum: f64,
2552        psi_scale_share: f64,
2553    ) -> f64 {
2554        q * s_combo + psi_scale_share * coeff_sum * phi
2555    }
2556
2557    #[inline]
2558    pub(crate) fn transformed_second_kernel_value(
2559        phi: f64,
2560        q: f64,
2561        t: f64,
2562        s_left: f64,
2563        left_sum: f64,
2564        s_right: f64,
2565        right_sum: f64,
2566        overlap_s: f64,
2567        psi_scale_share: f64,
2568    ) -> f64 {
2569        t * s_left * s_right
2570            + 2.0 * q * overlap_s
2571            + psi_scale_share * q * (right_sum * s_left + left_sum * s_right)
2572            + psi_scale_share * psi_scale_share * left_sum * right_sum * phi
2573    }
2574}
2575
2576pub(crate) fn build_aniso_design_psi_derivatives_shared(
2577    data: ArrayView2<'_, f64>,
2578    centers: ArrayView2<'_, f64>,
2579    eta: &[f64],
2580    p_final: usize,
2581    ident_transform: Option<Array2<f64>>,
2582    full_ident_transform: Option<Array2<f64>>,
2583    n_poly: usize,
2584    radial_kind: RadialScalarKind,
2585) -> Result<AnisoBasisPsiDerivatives, BasisError> {
2586    let n = data.nrows();
2587    let k = centers.nrows();
2588    let dim = data.ncols();
2589    if eta.len() != dim {
2590        crate::bail_dim_basis!(
2591            "aniso design derivatives: eta.len()={} != data dimension {dim}",
2592            eta.len()
2593        );
2594    }
2595
2596    let policy = gam_runtime::resource::ResourcePolicy::default_library();
2597    let force_operator = radial_kind.is_duchon_family();
2598    let dense_derivatives_exceed_budget =
2599        should_use_implicit_operators_with_policy(n, p_final, dim, &policy);
2600    let operator_only = force_operator || dense_derivatives_exceed_budget;
2601    let cache_radial_components = should_cache_implicit_radial_components(n, k, dim, &policy);
2602    // gam#1376 — the per-axis ψ derivatives this operator produces are ALREADY
2603    // the derivatives w.r.t. the κ-optimizer's raw coordinate, so NO cross-axis
2604    // centering projection is installed (for any family). The optimizer's per-
2605    // axis coordinate `psi_a` is decoded into both the global length scale
2606    // `ℓ = exp(−mean(psi))` and the centered contrast `eta_a = psi_a − mean(psi)`
2607    // simultaneously; in the kernel argument `x² = r²/ℓ² = Σ_a exp(2·psi_a)·h_a²`
2608    // the `mean(psi)` cancels, so the effective per-axis exponent is the raw
2609    // `psi_a` and `∂φ/∂psi_a = q·s_a` is the native per-axis ψ derivative. The
2610    // earlier `with_raw_eta_centering` projection annihilated the all-ones
2611    // (global-scale) direction and broke the analytic↔FD match (rel≈0.85). The
2612    // dense path (`build_matern_basis_log_kappa_aniso_derivatives`) is corrected
2613    // identically — it no longer centers downstream.
2614
2615    // ── Streaming path: large scale ─────────────────────────────────────
2616    // When even the compact radial cache would exceed the operator-cache
2617    // budget, store only data/centers/eta/radial_kind and recompute
2618    // (q, t, s_a) chunkwise during each matvec. Otherwise the operator-only
2619    // path below caches phi/q/t/s_a without materializing dense derivative
2620    // matrices.
2621    if operator_only && !cache_radial_components {
2622        let op = ImplicitDesignPsiDerivative::new_streaming(
2623            shared_owned_data_matrix_from_view(data),
2624            shared_owned_centers_matrix_from_view(centers),
2625            eta.to_vec(),
2626            radial_kind,
2627            ident_transform,
2628            full_ident_transform,
2629            n_poly,
2630        );
2631        return Ok(AnisoBasisPsiDerivatives {
2632            design_first: Vec::new(),
2633            design_second_diag: Vec::new(),
2634            design_second_cross: Vec::new(),
2635            design_second_cross_pairs: Vec::new(),
2636            penalties_first: vec![Vec::new(); dim],
2637            penalties_second_diag: vec![Vec::new(); dim],
2638            penalties_cross_pairs: Vec::new(),
2639            penalties_cross_provider: None,
2640            implicit_operator: Some(op),
2641        });
2642    }
2643
2644    // ── Materialized radial-cache path ────────────────────────────────────
2645    // Allocate O(n*k) arrays up front and fill with parallel chunks that
2646    // write directly into preallocated storage via raw pointers. No
2647    // intermediate Vec<(i, q_row, t_row, s_row)> collection.
2648    let nk = n.checked_mul(k).ok_or_else(|| {
2649        BasisError::InvalidInput("aniso radial cache has too many data-center pairs".to_string())
2650    })?;
2651    if nk.checked_mul(dim).is_none() {
2652        crate::bail_invalid_basis!("aniso radial cache axis component storage is too large");
2653    }
2654    let mut phi_values = Array1::<f64>::zeros(nk);
2655    let mut q_values = Array1::<f64>::zeros(nk);
2656    let mut t_values = Array1::<f64>::zeros(nk);
2657    let mut axis_components = Array2::<f64>::zeros((nk, dim));
2658
2659    let psi_scale_share = radial_kind.raw_psi_isotropic_share();
2660
2661    let cs = IMPLICIT_MATVEC_CHUNK_SIZE;
2662    let nc = n.div_ceil(cs);
2663    // Capture the *first* underlying radial-evaluation error rather than a
2664    // bare boolean: at an extreme trial hyperparameter the anisotropic
2665    // distance `r` can push the Duchon/Matérn radial kernel out of its
2666    // evaluable range, and the caller (the spatial-κ optimizer) needs the
2667    // real cause to decide whether the trial point is merely infeasible
2668    // (retreat) versus a genuine invariant violation (abort). Swallowing it
2669    // as "radial scalar evaluation failed" hid both the cause and the
2670    // recoverability.
2671    let first_err: std::sync::Mutex<Option<BasisError>> = std::sync::Mutex::new(None);
2672    // For large sweeps, replace per-pair exact radial evaluation with a
2673    // certified 1-D Chebyshev profile built once from a distance-only
2674    // pre-pass over the radius range (see `radial_profile`): at the 16-D
2675    // power-9 hybrid Duchon configuration a single exact triplet costs tens
2676    // of microseconds across its partial-fraction blocks, and this n·k
2677    // sweep was the dominant per-κ-trial cost of large-scale fits (#979).
2678    // Out-of-range radii and uncertified builds fall back to the exact
2679    // evaluator per pair.
2680    let profile = if nk >= RADIAL_PROFILE_MIN_PAIRS {
2681        let mut r_lo = f64::INFINITY;
2682        let mut r_hi = 0.0_f64;
2683        let mut drb = vec![0.0; dim];
2684        let mut cb = vec![0.0; dim];
2685        for i in 0..n {
2686            for a in 0..dim {
2687                drb[a] = data[[i, a]];
2688            }
2689            for j in 0..k {
2690                for a in 0..dim {
2691                    cb[a] = centers[[j, a]];
2692                }
2693                let (r, _) = aniso_distance_and_components(&drb, &cb, eta);
2694                if r > 0.0 {
2695                    r_lo = r_lo.min(r);
2696                    r_hi = r_hi.max(r);
2697                }
2698            }
2699        }
2700        if r_lo.is_finite() && r_hi > r_lo {
2701            radial_profile::RadialProfile::build(&radial_kind, r_lo, r_hi)
2702        } else {
2703            None
2704        }
2705    } else {
2706        None
2707    };
2708    {
2709        let pp = SendPtr(phi_values.as_mut_ptr());
2710        let qp = SendPtr(q_values.as_mut_ptr());
2711        let tp = SendPtr(t_values.as_mut_ptr());
2712        let ap = SendPtr(axis_components.as_mut_ptr());
2713        let ferr = &first_err;
2714        let profile_ref = profile.as_ref();
2715        (0..nc).into_par_iter().for_each(move |ci| {
2716            let start = ci * cs;
2717            let end = start.saturating_add(cs).min(n);
2718            let mut drb = vec![0.0; dim];
2719            let mut cb = vec![0.0; dim];
2720            for i in start..end {
2721                for a in 0..dim {
2722                    drb[a] = data[[i, a]];
2723                }
2724                for j in 0..k {
2725                    for a in 0..dim {
2726                        cb[a] = centers[[j, a]];
2727                    }
2728                    let (r, sv) = aniso_distance_and_components(&drb, &cb, eta);
2729                    let triplet = match profile_ref {
2730                        Some(profile) => profile.eval_or_exact(&radial_kind, r),
2731                        None => radial_kind.eval_design_triplet(r),
2732                    };
2733                    let (phi, q, t) = match triplet {
2734                        Ok(p) => p,
2735                        Err(e) => {
2736                            let mut slot = ferr.lock().unwrap_or_else(|p| p.into_inner());
2737                            if slot.is_none() {
2738                                *slot = Some(e);
2739                            }
2740                            return;
2741                        }
2742                    };
2743                    let flat = i * k + j;
2744                    // SAFETY: each Rayon chunk owns a disjoint i-row range,
2745                    // so flat=i*k+j stays in 0..nk for phi/q/t and
2746                    // flat*dim+a stays in 0..nk*dim for axis_components.
2747                    unsafe {
2748                        *pp.add(flat) = phi;
2749                        *qp.add(flat) = q;
2750                        *tp.add(flat) = t;
2751                        for a in 0..dim {
2752                            *ap.add(flat * dim + a) = sv[a];
2753                        }
2754                    }
2755                }
2756            }
2757        });
2758    }
2759    if let Some(cause) = first_err.into_inner().unwrap_or_else(|p| p.into_inner()) {
2760        return Err(BasisError::InvalidInput(format!(
2761            "radial scalar evaluation failed during aniso derivative construction \
2762             (eta={eta:?}): {cause}"
2763        )));
2764    }
2765
2766    let op = ImplicitDesignPsiDerivative::new(
2767        phi_values,
2768        q_values,
2769        t_values,
2770        axis_components,
2771        ident_transform,
2772        full_ident_transform,
2773        n,
2774        k,
2775        n_poly,
2776        dim,
2777    )
2778    .with_psi_scale_share(psi_scale_share);
2779
2780    // gam#1376 — the operator stays in the NATIVE per-axis ψ frame (no
2781    // `with_raw_eta_centering`): the κ-optimizer coordinate `psi_a` already maps
2782    // to the effective per-axis exponent `psi_a` of the kernel argument (the
2783    // `mean(psi)` it injects into the centered contrast is exactly cancelled by
2784    // the `ℓ = exp(−mean(psi))` it injects into the length scale), so the native
2785    // `∂φ/∂psi_a` produced by `materialize_first`/`materialize_second_*` (and by
2786    // the operator matvecs) is the correct raw-coordinate derivative. The
2787    // earlier centering broke the analytic↔FD match — see the comment above.
2788
2789    if operator_only {
2790        return Ok(AnisoBasisPsiDerivatives {
2791            design_first: Vec::new(),
2792            design_second_diag: Vec::new(),
2793            design_second_cross: Vec::new(),
2794            design_second_cross_pairs: Vec::new(),
2795            penalties_first: vec![Vec::new(); dim],
2796            penalties_second_diag: vec![Vec::new(); dim],
2797            penalties_cross_pairs: Vec::new(),
2798            penalties_cross_provider: None,
2799            implicit_operator: Some(op),
2800        });
2801    }
2802
2803    let design_first = (0..dim)
2804        .map(|a| op.materialize_first(a))
2805        .collect::<Result<Vec<_>, _>>()?;
2806    let design_second_diag = (0..dim)
2807        .map(|a| op.materialize_second_diag(a))
2808        .collect::<Result<Vec<_>, _>>()?;
2809
2810    Ok(AnisoBasisPsiDerivatives {
2811        design_first,
2812        design_second_diag,
2813        design_second_cross: Vec::new(),
2814        design_second_cross_pairs: Vec::new(),
2815        penalties_first: vec![Vec::new(); dim],
2816        penalties_second_diag: vec![Vec::new(); dim],
2817        penalties_cross_pairs: Vec::new(),
2818        penalties_cross_provider: None,
2819        implicit_operator: Some(op),
2820    })
2821}
2822
/// Scalar (single ψ coordinate) design-derivative bundle, as returned by
/// `build_scalar_design_psi_derivatives_shared`.
#[derive(Debug, Clone)]
pub(crate) struct ScalarDesignPsiDerivatives {
    /// Dense first ψ-derivative of the design. The builder leaves this as an
    /// empty (0 × 0) matrix when it returns operator-only.
    pub(crate) design_first: Array2<f64>,
    /// Dense diagonal second ψ-derivative of the design. Empty (0 × 0) in the
    /// operator-only case.
    pub(crate) design_second_diag: Array2<f64>,
    /// Implicit derivative operator; `build_scalar_design_psi_derivatives_shared`
    /// always sets this to `Some`.
    pub(crate) implicit_operator: Option<ImplicitDesignPsiDerivative>,
}
2829
2830pub(crate) fn build_scalar_design_psi_derivatives_shared(
2831    data: ArrayView2<'_, f64>,
2832    centers: ArrayView2<'_, f64>,
2833    fixed_eta: Option<&[f64]>,
2834    p_final: usize,
2835    ident_transform: Option<Array2<f64>>,
2836    full_ident_transform: Option<Array2<f64>>,
2837    n_poly: usize,
2838    radial_kind: RadialScalarKind,
2839    psi_scale_share: f64,
2840) -> Result<ScalarDesignPsiDerivatives, BasisError> {
2841    let n = data.nrows();
2842    let k = centers.nrows();
2843    let dim = data.ncols();
2844    if let Some(eta) = fixed_eta
2845        && eta.len() != dim
2846    {
2847        crate::bail_dim_basis!(
2848            "scalar design derivatives: eta.len()={} != data dimension {dim}",
2849            eta.len()
2850        );
2851    }
2852
2853    let policy = gam_runtime::resource::ResourcePolicy::default_library();
2854    let force_operator = radial_kind.is_duchon_family();
2855    let dense_derivatives_exceed_budget =
2856        should_use_implicit_operators_with_policy(n, p_final, 1, &policy);
2857    let operator_only = force_operator || dense_derivatives_exceed_budget;
2858    let cache_radial_components = should_cache_implicit_radial_components(n, k, 1, &policy);
2859    if operator_only && !cache_radial_components {
2860        let metric_eta = fixed_eta
2861            .map(|eta| eta.to_vec())
2862            .unwrap_or_else(|| vec![0.0; dim]);
2863        let op = ImplicitDesignPsiDerivative::new_streaming_scalar(
2864            shared_owned_data_matrix_from_view(data),
2865            shared_owned_centers_matrix_from_view(centers),
2866            metric_eta,
2867            radial_kind,
2868            ident_transform,
2869            full_ident_transform,
2870            n_poly,
2871        )
2872        .with_psi_scale_share(psi_scale_share);
2873        return Ok(ScalarDesignPsiDerivatives {
2874            design_first: Array2::<f64>::zeros((0, 0)),
2875            design_second_diag: Array2::<f64>::zeros((0, 0)),
2876            implicit_operator: Some(op),
2877        });
2878    }
2879
2880    let nk = n.checked_mul(k).ok_or_else(|| {
2881        BasisError::InvalidInput("scalar radial cache has too many data-center pairs".to_string())
2882    })?;
2883    let mut phi_values = Array1::<f64>::zeros(nk);
2884    let mut q_values = Array1::<f64>::zeros(nk);
2885    let mut t_values = Array1::<f64>::zeros(nk);
2886    let mut axis_components = Array2::<f64>::zeros((nk, 1));
2887
2888    let cs = IMPLICIT_MATVEC_CHUNK_SIZE;
2889    let nc = n.div_ceil(cs);
2890    let first_err: std::sync::Mutex<Option<BasisError>> = std::sync::Mutex::new(None);
2891    // Same certified radial-profile amortization as the per-axis sweep
2892    // above: one distance-only pre-pass for the radius range, one profile
2893    // build, Clenshaw per pair, exact fallback out of range (#979).
2894    let pair_r = |i: usize, j: usize, drb: &mut [f64], cb: &mut [f64]| -> f64 {
2895        if let Some(eta) = fixed_eta {
2896            for a in 0..dim {
2897                drb[a] = data[[i, a]];
2898                cb[a] = centers[[j, a]];
2899            }
2900            aniso_distance_and_components(drb, cb, eta).0
2901        } else {
2902            stable_euclidean_norm((0..dim).map(|a| data[[i, a]] - centers[[j, a]]))
2903        }
2904    };
2905    let profile = if nk >= RADIAL_PROFILE_MIN_PAIRS {
2906        let mut r_lo = f64::INFINITY;
2907        let mut r_hi = 0.0_f64;
2908        let mut drb = vec![0.0; dim];
2909        let mut cb = vec![0.0; dim];
2910        for i in 0..n {
2911            for j in 0..k {
2912                let r = pair_r(i, j, &mut drb, &mut cb);
2913                if r > 0.0 {
2914                    r_lo = r_lo.min(r);
2915                    r_hi = r_hi.max(r);
2916                }
2917            }
2918        }
2919        if r_lo.is_finite() && r_hi > r_lo {
2920            radial_profile::RadialProfile::build(&radial_kind, r_lo, r_hi)
2921        } else {
2922            None
2923        }
2924    } else {
2925        None
2926    };
2927    {
2928        let pp = SendPtr(phi_values.as_mut_ptr());
2929        let qp = SendPtr(q_values.as_mut_ptr());
2930        let tp = SendPtr(t_values.as_mut_ptr());
2931        let ap = SendPtr(axis_components.as_mut_ptr());
2932        let ferr = &first_err;
2933        let profile_ref = profile.as_ref();
2934        (0..nc).into_par_iter().for_each(move |ci| {
2935            let start = ci * cs;
2936            let end = start.saturating_add(cs).min(n);
2937            let mut data_row_buf = vec![0.0; dim];
2938            let mut center_buf = vec![0.0; dim];
2939            for i in start..end {
2940                for a in 0..dim {
2941                    data_row_buf[a] = data[[i, a]];
2942                }
2943                for j in 0..k {
2944                    let (r, scalar_component) = if let Some(eta) = fixed_eta {
2945                        for a in 0..dim {
2946                            center_buf[a] = centers[[j, a]];
2947                        }
2948                        let (r, components) =
2949                            aniso_distance_and_components(&data_row_buf, &center_buf, eta);
2950                        (r, components.into_iter().sum::<f64>())
2951                    } else {
2952                        let r =
2953                            stable_euclidean_norm((0..dim).map(|a| data[[i, a]] - centers[[j, a]]));
2954                        (r, r * r)
2955                    };
2956                    let triplet = match profile_ref {
2957                        Some(profile) => profile.eval_or_exact(&radial_kind, r),
2958                        None => radial_kind.eval_design_triplet(r),
2959                    };
2960                    let (phi, q, t) = match triplet {
2961                        Ok(p) => p,
2962                        Err(e) => {
2963                            let mut slot = ferr.lock().unwrap_or_else(|p| p.into_inner());
2964                            if slot.is_none() {
2965                                *slot = Some(e);
2966                            }
2967                            return;
2968                        }
2969                    };
2970                    let flat = i * k + j;
2971                    // SAFETY: each Rayon chunk owns a disjoint i-row range
2972                    // of the nk-long phi/q/t/axis buffers, so flat=i*k+j is
2973                    // in-bounds for every write and never aliases another worker.
2974                    unsafe {
2975                        *pp.add(flat) = phi;
2976                        *qp.add(flat) = q;
2977                        *tp.add(flat) = t;
2978                        *ap.add(flat) = scalar_component;
2979                    }
2980                }
2981            }
2982        });
2983    }
2984    if let Some(cause) = first_err.into_inner().unwrap_or_else(|p| p.into_inner()) {
2985        return Err(BasisError::InvalidInput(format!(
2986            "radial scalar evaluation failed during scalar derivative construction: {cause}"
2987        )));
2988    }
2989
2990    let op = ImplicitDesignPsiDerivative::new(
2991        phi_values,
2992        q_values,
2993        t_values,
2994        axis_components,
2995        ident_transform,
2996        full_ident_transform,
2997        n,
2998        k,
2999        n_poly,
3000        1,
3001    )
3002    .with_psi_scale_share(psi_scale_share);
3003
3004    if operator_only {
3005        return Ok(ScalarDesignPsiDerivatives {
3006            design_first: Array2::<f64>::zeros((0, 0)),
3007            design_second_diag: Array2::<f64>::zeros((0, 0)),
3008            implicit_operator: Some(op),
3009        });
3010    }
3011
3012    Ok(ScalarDesignPsiDerivatives {
3013        design_first: op.materialize_first(0)?,
3014        design_second_diag: op.materialize_second_diag(0)?,
3015        implicit_operator: Some(op),
3016    })
3017}