Skip to main content

gam_terms/basis/
implicit_psi_derivative.rs

1use super::*;
2
/// Implicit representation of ∂X/∂ψ_d that supports matrix-vector products
/// without materializing the full (n x p) derivative matrices.
///
/// For anisotropic Matern / Duchon terms with D axes, the dense path creates
/// D matrices of size (n x p_smooth) for dX/dpsi_d. At n=400K, p=2000, D=16,
/// that is ~100 GB.
///
/// Two storage modes:
///
/// **Materialized** (small-to-medium problems): stores pre-computed arrays
/// - `phi_values[i*n_knots + j]` = phi(r_{ij})
/// - `q_values[i*n_knots + j]` = phi'(r_{ij}) / r_{ij}
/// - `t_values[i*n_knots + j]` = (phi''(r_{ij}) - q_{ij}) / r_{ij}^2
/// - `axis_components[i*n_knots + j, d]` = exp(2 eta_d) * (x_{id} - c_{jd})^2
///
/// Memory: O(n * k * (D + 2)).
///
/// **Streaming** (large scale): stores only data/centers/eta/kernel params
/// and recomputes (q, t, s_a) on the fly during each matvec.
/// Memory: O(n*d + k*d) -- no per-(data,knot) storage.
///
/// The raw-psi chain rule:
///   shape_a   = q * s_a
///   shape_ab  = t * s_a * s_b + 2 q s_a 1[a=b]
///   dphi/dpsi_a         = shape_a + c * phi
///   d2phi/(dpsi_a dpsi_b) = shape_ab + c (shape_a + shape_b) + c^2 phi
/// where `c = 0` for Matérn and `c = delta / d` for hybrid Duchon.
#[derive(Debug, Clone)]
pub struct ImplicitDesignPsiDerivative {
    /// Pre-computed kernel values phi(r_{ij}) (materialized mode).
    /// Shape: (n * n_knots,), flattened with the data index outermost.
    /// Empty in streaming mode.
    pub(crate) phi_values: Array1<f64>,

    /// Pre-computed per (data, knot) pair axis components (materialized mode).
    /// Shape: (n * n_knots, D) stored in row-major order.
    /// Empty (0x0) in streaming mode.
    pub(crate) axis_components: Array2<f64>,

    /// Pre-computed R-operator first scalar q = phi'(r)/r (materialized mode).
    /// Shape: (n * n_knots,). Empty in streaming mode.
    pub(crate) q_values: Array1<f64>,

    /// Pre-computed R-operator second scalar t = (phi''(r) - q)/r^2
    /// (materialized mode).
    /// Shape: (n * n_knots,). Empty in streaming mode.
    pub(crate) t_values: Array1<f64>,

    /// When set, enables streaming recomputation of q/t/s from raw inputs
    /// instead of reading from the pre-computed arrays above.
    /// `None` selects the materialized path.
    pub(crate) streaming: Option<StreamingRadialState>,

    /// Identifiability/constraint transform Z: (n_knots x p_constrained).
    /// Gauge ownership is upstream; the implicit operator stores this frozen
    /// section only so forward/transpose matvecs can apply the already-gauged
    /// chart without materializing derivative matrices. For Duchon this is the
    /// kernel-constraint nullspace Z_kernel; for Matern with identifiability
    /// constraints, it is the corresponding Z. `None` means the identity.
    pub(crate) ident_transform: Option<Array2<f64>>,

    /// Optional full identifiability transform applied after Z_kernel + padding.
    /// This is likewise replay/application metadata for the matrix-free
    /// operator, not a second coefficient-coordinate owner. For Duchon terms
    /// that have an additional global identifiability transform, this is applied
    /// after the kernel constraint and polynomial padding.
    /// Shape: (p_constrained + n_poly, p_final).
    pub(crate) full_ident_transform: Option<Array2<f64>>,

    /// Number of data points.
    pub(crate) n: usize,

    /// Number of knots (raw basis functions before identifiability transform).
    pub(crate) n_knots: usize,

    /// Number of polynomial columns appended after the smooth part.
    /// These have zero derivative with respect to psi_d.
    pub(crate) n_poly: usize,

    /// Number of axes (dimension D).
    pub(crate) n_axes: usize,

    /// Isotropic scaling contribution per raw anisotropic psi axis.
    /// The materialized constructor initializes this to 0.0; the per-axis
    /// streaming constructor takes it from the radial kernel kind.
    /// NOTE(review): presumably the `c` of the chain rule above -- confirm.
    pub(crate) psi_scale_share: f64,

    /// Optional exposed-axis to raw-axis linear combinations.
    /// When present, axis `a` represents Σ_i coeff_i * raw_axis_i, with each
    /// inner entry a `(raw_axis_index, coeff)` pair.
    pub(crate) axis_combinations: Option<Vec<Vec<(usize, f64)>>>,
}
88
/// Streaming design derivative for one per-row latent coordinate `t[n, a]`.
///
/// This type is a thin, cheaply clonable handle around a shared
/// [`LocalDesignJacobianProvider`]. The provider is what actually stores the
/// shared latent matrix plus either radial-kernel ingredients or a
/// precomputed non-radial derivative jet. Individual REML hyper-directions
/// carry only a flat coordinate index and call `forward_mul_axis` /
/// `transpose_mul_axis` to expose the corresponding one-row design derivative
/// on demand.
pub struct LatentCoordDesignDerivative {
    /// Shared provider of the local design-Jacobian rows; every method of
    /// this handle delegates to it.
    pub(crate) provider: Arc<dyn LocalDesignJacobianProvider>,
}
99
/// Local design-Jacobian provider for radial-kernel bases (Matérn / Duchon)
/// evaluated at per-row latent coordinates.
///
/// Each Jacobian row is assembled from one kernel term per center, followed
/// by polynomial-column derivatives, and is then projected through the
/// optional identifiability transforms.
#[derive(Debug, Clone)]
pub(crate) struct RadialLatentCoordLocalDesignJacobian {
    /// Shared per-row latent coordinate values.
    pub(crate) latent: Arc<crate::latent::LatentCoordValues>,
    /// Kernel centers, one row per center; the column count must equal the
    /// latent dimension (checked by the constructors).
    pub(crate) centers: Arc<Array2<f64>>,
    /// Which radial scalar kernel is evaluated for the (q = φ'/r) factor.
    pub(crate) radial_kind: RadialScalarKind,
    /// Optional constraint transform Z applied as Zᵀ · raw_knot; `None`
    /// leaves the raw per-center values untouched.
    pub(crate) ident_transform: Option<Array2<f64>>,
    /// Optional second transform applied (transposed) after the polynomial
    /// columns have been appended.
    pub(crate) full_ident_transform: Option<Array2<f64>>,
    /// Number of polynomial columns appended after the constrained kernel
    /// columns.
    pub(crate) n_poly: usize,
    /// Polynomial nullspace order used to differentiate the polynomial
    /// columns; `None` makes those derivative entries identically zero.
    pub(crate) polynomial_order: Option<DuchonNullspaceOrder>,
}
110
/// Local design-Jacobian provider backed by a precomputed derivative jet.
///
/// Used for bases whose derivative is tabulated up front rather than
/// recomputed from a radial kernel.
#[derive(Debug, Clone)]
pub(crate) struct JetLatentCoordLocalDesignJacobian {
    /// Shared per-row latent coordinate values.
    pub(crate) latent: Arc<crate::latent::LatentCoordValues>,
    /// Derivative jet indexed as `[row, basis_col, axis]`; `from_jet` checks
    /// that the first and last extents match the latent row count and
    /// latent dimension.
    pub(crate) jet: Arc<Array3<f64>>,
    /// Optional constraint transform Z applied as Zᵀ · jet_row; `None`
    /// leaves the raw basis columns untouched.
    pub(crate) ident_transform: Option<Array2<f64>>,
}
117
118impl std::fmt::Debug for LatentCoordDesignDerivative {
119    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
120        f.debug_struct("LatentCoordDesignDerivative")
121            .field("n_data", &self.n_data())
122            .field("latent_dim", &self.latent_dim())
123            .field("n_axes", &self.n_axes())
124            .field("p_out", &self.p_out())
125            .field("provider", &self.provider)
126            .finish()
127    }
128}
129
130impl Clone for LatentCoordDesignDerivative {
131    fn clone(&self) -> Self {
132        Self {
133            provider: Arc::clone(&self.provider),
134        }
135    }
136}
137
138impl RadialLatentCoordLocalDesignJacobian {
139    pub(crate) fn p_constrained(&self) -> usize {
140        self.ident_transform
141            .as_ref()
142            .map_or(self.centers.nrows(), Array2::ncols)
143    }
144
145    pub(crate) fn p_after_pad(&self) -> usize {
146        self.p_constrained() + self.n_poly
147    }
148
149    pub(crate) fn p_out(&self) -> usize {
150        self.full_ident_transform
151            .as_ref()
152            .map_or(self.p_after_pad(), Array2::ncols)
153    }
154}
155
156impl JetLatentCoordLocalDesignJacobian {
157    pub(crate) fn p_out(&self) -> usize {
158        self.ident_transform
159            .as_ref()
160            .map_or(self.jet.shape()[1], Array2::ncols)
161    }
162}
163
164/// The complete contract a per-row latent / novel-manifold coordinate type must
165/// supply to participate in the REML design-derivative operator surface.
166///
167/// Onboarding a new coordinate type (the SAE / novel-manifold frontier) reduces
168/// to implementing the small set of *required* methods below — the coordinate
169/// geometry (`n_data`, `latent_dim`, `n_axes`) plus the single genuinely-new
170/// payload `local_design_jacobian_row` (the local block ∂(design row)/∂(coord)).
171/// The streaming operator surface consumed by `LatentCoordDerivativeOp` in
172/// `src/solver/reml/mod.rs` — forward matvec, transpose matvec, and dense
173/// materialization, together with the flat-axis → (row, axis) decode — is
174/// inherited as *default* methods and never re-implemented per coordinate type.
175///
176/// This is the close condition for #767: a new coordinate type touches zero
177/// operator-surface code; it provides only its local Jacobian and geometry.
178pub trait LocalDesignJacobianProvider: Send + Sync + std::fmt::Debug {
179    /// Number of data rows `n` the operator spans.
180    fn n_data(&self) -> usize;
181
182    /// Latent coordinate dimension `d` (perturbation axes per row).
183    fn latent_dim(&self) -> usize;
184
185    /// Number of flat hyper-axes `n · d` (one per (row, coordinate-axis) pair).
186    fn n_axes(&self) -> usize;
187
188    /// Number of output-basis columns in each local design-Jacobian row.
189    fn p_out(&self) -> usize;
190
191    /// The only per-coordinate payload: the projected local design-Jacobian row
192    /// ∂(design row `row`)/∂(coordinate axis `axis`) in output-basis columns.
193    fn local_design_jacobian_row(&self, row: usize, axis: usize)
194    -> Result<Array1<f64>, BasisError>;
195
196    /// Decode a flat hyper-axis into its `(row, coordinate axis)`. Row-major over
197    /// `(row, axis)` with stride `latent_dim`; uniform across coordinate types.
198    fn row_axis(&self, flat_axis: usize) -> (usize, usize) {
199        let d = self.latent_dim();
200        (flat_axis / d, flat_axis % d)
201    }
202
203    /// Forward matvec for one flat hyper-axis: place `J_row · u` at `row`.
204    fn forward_mul_axis(
205        &self,
206        flat_axis: usize,
207        u: &ArrayView1<'_, f64>,
208    ) -> Result<Array1<f64>, BasisError> {
209        assert!(
210            flat_axis < self.n_axes(),
211            "latent-coordinate derivative flat axis out of bounds in forward_mul_axis: flat_axis={flat_axis}, n_axes={}",
212            self.n_axes()
213        );
214        let (row, axis) = self.row_axis(flat_axis);
215        let local_jacobian = self.local_design_jacobian_row(row, axis)?;
216        assert_eq!(
217            u.len(),
218            local_jacobian.len(),
219            "latent-coordinate derivative coefficient length mismatch in forward_mul_axis"
220        );
221        let value = local_jacobian.dot(u);
222        let mut out = Array1::<f64>::zeros(self.n_data());
223        out[row] = value;
224        Ok(out)
225    }
226
227    /// Transpose matvec for one flat hyper-axis: scatter `v[row] · J_rowᵀ`.
228    fn transpose_mul_axis(
229        &self,
230        flat_axis: usize,
231        v: &ArrayView1<'_, f64>,
232    ) -> Result<Array1<f64>, BasisError> {
233        assert!(
234            flat_axis < self.n_axes(),
235            "latent-coordinate derivative flat axis out of bounds in transpose_mul_axis: flat_axis={flat_axis}, n_axes={}",
236            self.n_axes()
237        );
238        assert_eq!(
239            v.len(),
240            self.n_data(),
241            "latent-coordinate derivative row-adjoint length mismatch in transpose_mul_axis"
242        );
243        let (row, axis) = self.row_axis(flat_axis);
244        let scale = v[row];
245        Ok(self
246            .local_design_jacobian_row(row, axis)?
247            .mapv(|value| scale * value))
248    }
249
250    /// Dense `(n_data × p_out)` materialization of one flat hyper-axis: the local
251    /// Jacobian row placed at `row`, all other rows zero.
252    fn materialize_axis(&self, flat_axis: usize) -> Result<Array2<f64>, BasisError> {
253        assert!(
254            flat_axis < self.n_axes(),
255            "latent-coordinate derivative flat axis out of bounds in materialize_axis: flat_axis={flat_axis}, n_axes={}",
256            self.n_axes()
257        );
258        let (row, axis) = self.row_axis(flat_axis);
259        let projected = self.local_design_jacobian_row(row, axis)?;
260        let mut out = Array2::<f64>::zeros((self.n_data(), projected.len()));
261        out.row_mut(row).assign(&projected);
262        Ok(out)
263    }
264}
265
/// The rayon chunk size for parallel implicit matvec operations.
/// Each chunk processes this many data points before reducing.
/// NOTE(review): the consumers of this constant are outside this view.
pub(crate) const IMPLICIT_MATVEC_CHUNK_SIZE: usize = 1000;

/// Minimum data size (number of data points) to activate parallel iteration
/// for implicit matvecs.
/// NOTE(review): whether the comparison is strict or inclusive is decided at
/// the use site, which is outside this view.
pub(crate) const IMPLICIT_MATVEC_PAR_THRESHOLD: usize = 10_000;

/// Number of lower-triangular center rows per tile when assembling dense
/// ThinPlate penalty ψ-derivative kernel blocks.
pub(crate) const THIN_PLATE_PENALTY_PSI_TILE_ROWS: usize = 32;
276
277impl LatentCoordDesignDerivative {
278    pub(crate) fn from_local_design_jacobian_provider(
279        provider: Arc<dyn LocalDesignJacobianProvider>,
280    ) -> Self {
281        Self { provider }
282    }
283
284    pub fn new_matern(
285        latent: Arc<crate::latent::LatentCoordValues>,
286        centers: Arc<Array2<f64>>,
287        length_scale: f64,
288        nu: MaternNu,
289        include_intercept: bool,
290        ident_transform: Option<Array2<f64>>,
291    ) -> Result<Self, BasisError> {
292        if latent.latent_dim() != centers.ncols() {
293            crate::bail_dim_basis!(
294                "LatentCoordDesignDerivative Matérn dimension mismatch: latent d={} centers d={}",
295                latent.latent_dim(),
296                centers.ncols()
297            );
298        }
299        Ok(Self::from_local_design_jacobian_provider(Arc::new(
300            RadialLatentCoordLocalDesignJacobian {
301                latent,
302                centers,
303                radial_kind: RadialScalarKind::Matern { length_scale, nu },
304                ident_transform,
305                full_ident_transform: None,
306                n_poly: usize::from(include_intercept),
307                polynomial_order: None,
308            },
309        )))
310    }
311
312    pub fn new_duchon(
313        latent: Arc<crate::latent::LatentCoordValues>,
314        centers: Arc<Array2<f64>>,
315        length_scale: Option<f64>,
316        power: f64,
317        nullspace_order: DuchonNullspaceOrder,
318        full_ident_transform: Option<Array2<f64>>,
319    ) -> Result<Self, BasisError> {
320        if latent.latent_dim() != centers.ncols() {
321            crate::bail_dim_basis!(
322                "LatentCoordDesignDerivative Duchon dimension mismatch: latent d={} centers d={}",
323                latent.latent_dim(),
324                centers.ncols()
325            );
326        }
327        let effective_order = duchon_effective_nullspace_order(centers.view(), nullspace_order);
328        let p_order = duchon_p_from_nullspace_order(effective_order);
329        let s_order = power.max(0.0).round() as usize;
330        let radial_kind = if let Some(length_scale) = length_scale {
331            RadialScalarKind::Duchon {
332                length_scale,
333                p_order,
334                s_order,
335                dim: centers.ncols(),
336                coeffs: duchon_partial_fraction_coeffs(
337                    p_order,
338                    s_order,
339                    1.0 / length_scale.max(1e-300),
340                ),
341            }
342        } else {
343            RadialScalarKind::PureDuchon {
344                block_order: pure_duchon_block_order(p_order, power).max(1.0) as usize,
345                p_order,
346                s_order,
347                dim: centers.ncols(),
348            }
349        };
350        let mut workspace = BasisWorkspace::default();
351        let ident_transform =
352            kernel_constraint_nullspace(centers.view(), effective_order, &mut workspace.cache)?;
353        let n_poly = polynomial_block_from_order(centers.view(), effective_order).ncols();
354        Ok(Self::from_local_design_jacobian_provider(Arc::new(
355            RadialLatentCoordLocalDesignJacobian {
356                latent,
357                centers,
358                radial_kind,
359                ident_transform: Some(ident_transform),
360                full_ident_transform,
361                n_poly,
362                polynomial_order: Some(effective_order),
363            },
364        )))
365    }
366
367    pub fn new_sphere(
368        latent: Arc<crate::latent::LatentCoordValues>,
369        centers: Arc<Array2<f64>>,
370        penalty_order: usize,
371        ident_transform: Option<Array2<f64>>,
372    ) -> Result<Self, BasisError> {
373        if latent.latent_dim() != centers.ncols() {
374            crate::bail_dim_basis!(
375                "LatentCoordDesignDerivative sphere dimension mismatch: latent d={} centers d={}",
376                latent.latent_dim(),
377                centers.ncols()
378            );
379        }
380        let raw_jet = sphere_first_derivative_nd(
381            latent.as_matrix().view(),
382            centers.view(),
383            penalty_order,
384            true,
385        )?;
386        let jet = latent.design_gradient_wrt_t_dispatch(
387            crate::latent::InputLocationDerivative::Jet(raw_jet.view()),
388        )?;
389        Self::from_jet(latent, jet, ident_transform)
390    }
391
392    pub fn new_periodic_bspline(
393        latent: Arc<crate::latent::LatentCoordValues>,
394        data_range: (f64, f64),
395        degree: usize,
396        num_basis: usize,
397        ident_transform: Option<Array2<f64>>,
398    ) -> Result<Self, BasisError> {
399        let raw_jet = periodic_bspline_first_derivative_nd(
400            latent.as_matrix().view(),
401            data_range,
402            degree,
403            num_basis,
404        )?;
405        let jet = latent.design_gradient_wrt_t_dispatch(
406            crate::latent::InputLocationDerivative::Jet(raw_jet.view()),
407        )?;
408        Self::from_jet(latent, jet, ident_transform)
409    }
410
411    pub fn new_tensor_bspline(
412        latent: Arc<crate::latent::LatentCoordValues>,
413        knots_per_axis: Vec<Array1<f64>>,
414        degrees: Vec<usize>,
415        ident_transform: Option<Array2<f64>>,
416    ) -> Result<Self, BasisError> {
417        let knot_views = knots_per_axis
418            .iter()
419            .map(|knots| knots.view())
420            .collect::<Vec<_>>();
421        let raw_jet =
422            bspline_tensor_first_derivative(latent.as_matrix().view(), &knot_views, &degrees)?;
423        let jet = latent.design_gradient_wrt_t_dispatch(
424            crate::latent::InputLocationDerivative::Jet(raw_jet.view()),
425        )?;
426        Self::from_jet(latent, jet, ident_transform)
427    }
428
429    pub fn new_pca(
430        latent: Arc<crate::latent::LatentCoordValues>,
431        basis_matrix: Arc<Array2<f64>>,
432    ) -> Result<Self, BasisError> {
433        if latent.latent_dim() != basis_matrix.nrows() {
434            crate::bail_dim_basis!(
435                "LatentCoordDesignDerivative Pca dimension mismatch: latent d={} basis rows={}",
436                latent.latent_dim(),
437                basis_matrix.nrows()
438            );
439        }
440        let mut jet =
441            Array3::<f64>::zeros((latent.n_obs(), basis_matrix.ncols(), basis_matrix.nrows()));
442        for row in 0..latent.n_obs() {
443            for axis in 0..basis_matrix.nrows() {
444                for col in 0..basis_matrix.ncols() {
445                    jet[[row, col, axis]] = basis_matrix[[axis, col]];
446                }
447            }
448        }
449        Self::from_jet(latent, jet, None)
450    }
451
452    pub fn from_jet(
453        latent: Arc<crate::latent::LatentCoordValues>,
454        jet: Array3<f64>,
455        ident_transform: Option<Array2<f64>>,
456    ) -> Result<Self, BasisError> {
457        if jet.shape()[0] != latent.n_obs() || jet.shape()[2] != latent.latent_dim() {
458            crate::bail_dim_basis!(
459                "LatentCoordDesignDerivative jet shape {:?} does not match latent shape ({}, {}, {})",
460                jet.shape(),
461                latent.n_obs(),
462                jet.shape()[1],
463                latent.latent_dim()
464            );
465        }
466        if let Some(z) = ident_transform.as_ref()
467            && z.nrows() != jet.shape()[1]
468        {
469            crate::bail_dim_basis!(
470                "LatentCoordDesignDerivative identifiability transform has {} rows but derivative jet has {} basis columns",
471                z.nrows(),
472                jet.shape()[1]
473            );
474        }
475        Ok(Self::from_local_design_jacobian_provider(Arc::new(
476            JetLatentCoordLocalDesignJacobian {
477                latent,
478                jet: Arc::new(jet),
479                ident_transform,
480            },
481        )))
482    }
483
484    pub(crate) fn n_data(&self) -> usize {
485        self.provider.n_data()
486    }
487
488    pub(crate) fn latent_dim(&self) -> usize {
489        self.provider.latent_dim()
490    }
491
492    pub fn n_axes(&self) -> usize {
493        self.provider.n_axes()
494    }
495
496    pub fn p_out(&self) -> usize {
497        self.provider.p_out()
498    }
499}
500
501impl RadialLatentCoordLocalDesignJacobian {
502    pub(crate) fn project_and_pad(
503        &self,
504        raw_knot: &Array1<f64>,
505        raw_poly: &Array1<f64>,
506    ) -> Result<Array1<f64>, BasisError> {
507        let constrained = match &self.ident_transform {
508            Some(z) => z.t().dot(raw_knot),
509            None => raw_knot.clone(),
510        };
511        let mut padded = Array1::<f64>::zeros(constrained.len() + self.n_poly);
512        padded
513            .slice_mut(s![..constrained.len()])
514            .assign(&constrained);
515        if self.n_poly > 0 {
516            padded.slice_mut(s![constrained.len()..]).assign(raw_poly);
517        }
518        Ok(match &self.full_ident_transform {
519            Some(zf) => zf.t().dot(&padded),
520            None => padded,
521        })
522    }
523
524    pub(crate) fn kernel_axis_scalar(
525        &self,
526        row: usize,
527        center: usize,
528        axis: usize,
529    ) -> Result<f64, BasisError> {
530        let t_row = self.latent.row(row);
531        let mut r2 = 0.0_f64;
532        for a in 0..self.latent.latent_dim() {
533            let delta = t_row[a] - self.centers[[center, a]];
534            r2 += delta * delta;
535        }
536        let r = r2.sqrt();
537        if r == 0.0 {
538            // At a center collision the axis component s_axis = (t − c)_axis
539            // is exactly zero. The product q · s_axis is therefore 0 for any
540            // kernel whose q has a finite limit; for kernels where q diverges
541            // the value is genuinely indeterminate (0 · ∞) and we must not
542            // pretend it is zero. Defer to the kernel's classification.
543            if self.radial_kind.is_smooth_at_collision() {
544                return Ok(0.0);
545            }
546            return Err(BasisError::DegenerateAtCollision {
547                kernel: "RadialScalarKind (design axis)",
548                dim: self.latent.latent_dim(),
549                m: 0.0,
550                message: "radial scalar q = φ'/r has no finite limit at r = 0; \
551                          the design row axis component is undefined",
552            });
553        }
554        let (_, q, _) = self.radial_kind.eval_design_triplet(r)?;
555        Ok(q * (t_row[axis] - self.centers[[center, axis]]))
556    }
557
558    pub(crate) fn polynomial_axis_values(&self, row: usize, axis: usize) -> Array1<f64> {
559        let Some(order) = self.polynomial_order else {
560            return Array1::<f64>::zeros(self.n_poly);
561        };
562        let max_degree = match order {
563            DuchonNullspaceOrder::Zero => 0usize,
564            DuchonNullspaceOrder::Linear => 1usize,
565            DuchonNullspaceOrder::Degree(k) => k,
566        };
567        let t_row = self.latent.row(row);
568        let exponents = monomial_exponents(self.latent.latent_dim(), max_degree);
569        let mut out = Array1::<f64>::zeros(exponents.len());
570        for (col, alpha) in exponents.iter().enumerate() {
571            let a_axis = alpha[axis];
572            if a_axis == 0 {
573                continue;
574            }
575            let mut value = a_axis as f64;
576            for a in 0..self.latent.latent_dim() {
577                let exp_a = if a == axis { a_axis - 1 } else { alpha[a] };
578                if exp_a != 0 {
579                    value *= t_row[a].powi(exp_a as i32);
580                }
581            }
582            out[col] = value;
583        }
584        out
585    }
586}
587
588impl JetLatentCoordLocalDesignJacobian {
589    pub(crate) fn project_jet(&self, raw_knot: &Array1<f64>) -> Result<Array1<f64>, BasisError> {
590        Ok(match &self.ident_transform {
591            Some(z) => z.t().dot(raw_knot),
592            None => raw_knot.clone(),
593        })
594    }
595}
596
597impl LocalDesignJacobianProvider for LatentCoordDesignDerivative {
598    fn n_data(&self) -> usize {
599        self.provider.n_data()
600    }
601
602    fn latent_dim(&self) -> usize {
603        self.provider.latent_dim()
604    }
605
606    fn n_axes(&self) -> usize {
607        self.provider.n_axes()
608    }
609
610    fn p_out(&self) -> usize {
611        self.provider.p_out()
612    }
613
614    fn local_design_jacobian_row(
615        &self,
616        row: usize,
617        axis: usize,
618    ) -> Result<Array1<f64>, BasisError> {
619        self.provider.local_design_jacobian_row(row, axis)
620    }
621}
622
623impl LocalDesignJacobianProvider for RadialLatentCoordLocalDesignJacobian {
624    fn n_data(&self) -> usize {
625        self.latent.n_obs()
626    }
627
628    fn latent_dim(&self) -> usize {
629        self.latent.latent_dim()
630    }
631
632    fn n_axes(&self) -> usize {
633        self.latent.len()
634    }
635
636    fn p_out(&self) -> usize {
637        Self::p_out(self)
638    }
639
640    fn local_design_jacobian_row(
641        &self,
642        row: usize,
643        axis: usize,
644    ) -> Result<Array1<f64>, BasisError> {
645        let mut raw_knot = Array1::<f64>::zeros(self.centers.nrows());
646        for center in 0..self.centers.nrows() {
647            raw_knot[center] = self.kernel_axis_scalar(row, center, axis)?;
648        }
649        let raw_poly = self.polynomial_axis_values(row, axis);
650        self.project_and_pad(&raw_knot, &raw_poly)
651    }
652}
653
654impl LocalDesignJacobianProvider for JetLatentCoordLocalDesignJacobian {
655    fn n_data(&self) -> usize {
656        self.latent.n_obs()
657    }
658
659    fn latent_dim(&self) -> usize {
660        self.latent.latent_dim()
661    }
662
663    fn n_axes(&self) -> usize {
664        self.latent.len()
665    }
666
667    fn p_out(&self) -> usize {
668        Self::p_out(self)
669    }
670
671    fn local_design_jacobian_row(
672        &self,
673        row: usize,
674        axis: usize,
675    ) -> Result<Array1<f64>, BasisError> {
676        let mut raw_knot = Array1::<f64>::zeros(self.jet.shape()[1]);
677        for basis_col in 0..self.jet.shape()[1] {
678            raw_knot[basis_col] = self.jet[[row, basis_col, axis]];
679        }
680        self.project_jet(&raw_knot)
681    }
682}
683
684impl ImplicitDesignPsiDerivative {
685    /// Construct from pre-computed radial jet scalars.
686    ///
687    /// # Arguments
688    /// - `q_values`: (n * n_knots,) — φ'(r)/r for each (data, knot) pair.
689    /// - `t_values`: (n * n_knots,) — (φ''(r) - q) / r² for each pair.
690    /// - `axis_components`: (n * n_knots, D) — s_{d,ij} = exp(2η_d) · h_d² for each pair/axis.
691    /// - `ident_transform`: optional (n_knots × p_constrained) constraint projection.
692    /// - `full_ident_transform`: optional further projection after padding.
693    /// - `n`, `n_knots`, `n_poly`, `n_axes`: dimensions.
694    /// Construct from pre-computed (materialized) radial jet scalars.
695    /// This is the original path for small-to-medium problems where
696    /// O(n*k*(d+2)) storage is acceptable.
697    pub fn new(
698        phi_values: Array1<f64>,
699        q_values: Array1<f64>,
700        t_values: Array1<f64>,
701        axis_components: Array2<f64>,
702        ident_transform: Option<Array2<f64>>,
703        full_ident_transform: Option<Array2<f64>>,
704        n: usize,
705        n_knots: usize,
706        n_poly: usize,
707        n_axes: usize,
708    ) -> Self {
709        assert_eq!(
710            phi_values.len(),
711            n * n_knots,
712            "implicit psi derivative phi length mismatch: expected n*n_knots={}*{}={}, got {}",
713            n,
714            n_knots,
715            n * n_knots,
716            phi_values.len()
717        );
718        assert_eq!(
719            q_values.len(),
720            n * n_knots,
721            "implicit psi derivative q length mismatch: expected n*n_knots={}*{}={}, got {}",
722            n,
723            n_knots,
724            n * n_knots,
725            q_values.len()
726        );
727        assert_eq!(
728            t_values.len(),
729            n * n_knots,
730            "implicit psi derivative t length mismatch: expected n*n_knots={}*{}={}, got {}",
731            n,
732            n_knots,
733            n * n_knots,
734            t_values.len()
735        );
736        assert_eq!(
737            axis_components.nrows(),
738            n * n_knots,
739            "implicit psi derivative axis-component row mismatch: expected n*n_knots={}*{}={}, got {}",
740            n,
741            n_knots,
742            n * n_knots,
743            axis_components.nrows()
744        );
745        assert_eq!(
746            axis_components.ncols(),
747            n_axes,
748            "implicit psi derivative axis-component column mismatch: expected n_axes={n_axes}, got {}",
749            axis_components.ncols()
750        );
751        Self {
752            phi_values,
753            axis_components,
754            q_values,
755            t_values,
756            streaming: None,
757            ident_transform,
758            full_ident_transform,
759            n,
760            n_knots,
761            n_poly,
762            n_axes,
763            psi_scale_share: 0.0,
764            axis_combinations: None,
765        }
766    }
767
768    pub(crate) fn with_psi_scale_share(mut self, psi_scale_share: f64) -> Self {
769        self.psi_scale_share = psi_scale_share;
770        self
771    }
772
773    /// Construct a streaming operator that recomputes (q, t, s_a) on the fly
774    /// from raw data/centers/eta during each matvec. No O(n*k) arrays are stored.
775    /// This is the large-scale path.
776    ///
777    /// `pub` like the sibling `new_*` constructors: after the engine crate carve
778    /// (#1521) the REML planner tests live in `gam-solve` and build streaming
779    /// operators as fixtures, so this constructor is part of the cross-crate
780    /// surface, not a crate-private helper.
781    pub fn new_streaming(
782        data: Arc<Array2<f64>>,
783        centers: Arc<Array2<f64>>,
784        eta: Vec<f64>,
785        radial_kind: RadialScalarKind,
786        ident_transform: Option<Array2<f64>>,
787        full_ident_transform: Option<Array2<f64>>,
788        n_poly: usize,
789    ) -> Self {
790        let n = data.nrows();
791        let n_knots = centers.nrows();
792        let n_axes = data.ncols();
793        let psi_scale_share = radial_kind.raw_psi_isotropic_share();
794        assert_eq!(eta.len(), n_axes);
795        assert_eq!(
796            centers.ncols(),
797            n_axes,
798            "streaming radial centers have {} columns but data/eta have {n_axes}",
799            centers.ncols()
800        );
801        let metric_weights: Arc<[f64]> = Arc::from(centered_aniso_metric_weights(&eta));
802        Self {
803            // Empty arrays -- not used in streaming mode.
804            phi_values: Array1::<f64>::zeros(0),
805            axis_components: Array2::<f64>::zeros((0, 0)),
806            q_values: Array1::<f64>::zeros(0),
807            t_values: Array1::<f64>::zeros(0),
808            streaming: Some(StreamingRadialState {
809                data,
810                centers,
811                axis_mode: StreamingAxisMode::PerAxis { metric_weights },
812                radial_kind,
813                triplet_cache: Arc::new(std::sync::OnceLock::new()),
814            }),
815            ident_transform,
816            full_ident_transform,
817            n,
818            n_knots,
819            n_poly,
820            n_axes,
821            psi_scale_share,
822            axis_combinations: None,
823        }
824    }
825
826    /// Construct a streaming operator for a scalar ψ derivative. The operator
827    /// exposes a single axis component equal to the full scaled squared
828    /// distance r² under the fixed metric defined by `eta`.
829    pub(crate) fn new_streaming_scalar(
830        data: Arc<Array2<f64>>,
831        centers: Arc<Array2<f64>>,
832        eta: Vec<f64>,
833        radial_kind: RadialScalarKind,
834        ident_transform: Option<Array2<f64>>,
835        full_ident_transform: Option<Array2<f64>>,
836        n_poly: usize,
837    ) -> Self {
838        let n = data.nrows();
839        let n_knots = centers.nrows();
840        let dim = data.ncols();
841        assert_eq!(eta.len(), dim);
842        assert_eq!(
843            centers.ncols(),
844            dim,
845            "streaming scalar radial centers have {} columns but data/eta have {dim}",
846            centers.ncols()
847        );
848        let metric_weights: Arc<[f64]> = Arc::from(centered_aniso_metric_weights(&eta));
849        Self {
850            phi_values: Array1::<f64>::zeros(0),
851            axis_components: Array2::<f64>::zeros((0, 0)),
852            q_values: Array1::<f64>::zeros(0),
853            t_values: Array1::<f64>::zeros(0),
854            streaming: Some(StreamingRadialState {
855                data,
856                centers,
857                axis_mode: StreamingAxisMode::ScalarTotal { metric_weights },
858                radial_kind,
859                triplet_cache: Arc::new(std::sync::OnceLock::new()),
860            }),
861            ident_transform,
862            full_ident_transform,
863            n,
864            n_knots,
865            n_poly,
866            n_axes: 1,
867            psi_scale_share: 0.0,
868            axis_combinations: None,
869        }
870    }
871
872    /// Whether this operator is in streaming (recompute-on-the-fly) mode.
873    #[inline]
874    pub(crate) fn is_streaming(&self) -> bool {
875        self.streaming.is_some()
876    }
877
878    /// Number of data points.
879    pub fn n_data(&self) -> usize {
880        self.n
881    }
882
883    /// Number of axes (D).
884    pub fn n_axes(&self) -> usize {
885        self.axis_combinations
886            .as_ref()
887            .map_or(self.n_axes, Vec::len)
888    }
889
890    pub fn is_duchon_family(&self) -> bool {
891        self.streaming.as_ref().is_some_and(|state| {
892            matches!(
893                state.radial_kind,
894                RadialScalarKind::Duchon { .. } | RadialScalarKind::PureDuchon { .. }
895            )
896        }) || self.psi_scale_share != 0.0
897    }
898
899    /// Whether this operator is wired up by a basis whose large-scale path
900    /// is supposed to stay implicit, so a dense `(n × p)` materialization
901    /// here is a regression rather than a normal compute path. Duchon-family
902    /// terms qualify because they are streaming-only at any scale; ThinPlate
903    /// qualifies because the new scalar-streaming routing relies on the
904    /// implicit operator above the policy threshold and a sneaky
905    /// `materialize_dense()` would silently re-introduce the n × p
906    /// allocation we just removed. The flag is consulted by the
907    /// materialize_first / materialize_second_diag / materialize_second_cross
908    /// guards to fire `assert_no_dense_derivative_materialization` for these
909    /// kinds whenever the resource policy says the materialization would
910    /// exceed budget. Small-n problems still pass the assertion and get the
911    /// dense fast path.
912    pub(crate) fn enforces_dense_materialization_budget(&self) -> bool {
913        if self
914            .streaming
915            .as_ref()
916            .is_some_and(|state| state.radial_kind.enforces_dense_materialization_budget())
917        {
918            return true;
919        }
920        // The materialized-mode path keeps no `radial_kind` to inspect, but
921        // a non-zero psi_scale_share is the unambiguous Duchon-family
922        // signature there (Matern uses 0, ThinPlate uses 0). Materialized
923        // ThinPlate / Matern terms are in the dense fast path and the
924        // guard does not need to fire for them.
925        self.psi_scale_share != 0.0
926    }
927
928    /// Output dimension: total basis columns in the final space.
929    pub fn p_out(&self) -> usize {
930        if let Some(ref zf) = self.full_ident_transform {
931            zf.ncols()
932        } else {
933            self.p_after_pad()
934        }
935    }
936
937    pub fn append_full_transform(mut self, transform: &Array2<f64>) -> Result<Self, BasisError> {
938        if transform.nrows() != self.p_out() {
939            crate::bail_dim_basis!(
940                "implicit psi derivative transform has {} rows but operator has {} output columns",
941                transform.nrows(),
942                self.p_out()
943            );
944        }
945        self.full_ident_transform = Some(match self.full_ident_transform.take() {
946            Some(existing) => fast_ab(&existing, transform),
947            None => transform.clone(),
948        });
949        Ok(self)
950    }
951
952    /// Dimension after kernel constraint + polynomial padding (before full ident).
953    pub(crate) fn p_after_pad(&self) -> usize {
954        let p_constrained = self.p_constrained();
955        p_constrained + self.n_poly
956    }
957
958    /// Dimension after kernel constraint projection (before poly padding).
959    pub(crate) fn p_constrained(&self) -> usize {
960        match &self.ident_transform {
961            Some(z) => z.ncols(),
962            None => self.n_knots,
963        }
964    }
965
966    /// Accumulate raw knot-space vector from weighted (data, knot) contributions.
967    /// Returns a vector of length n_knots: Σ_i w_i · scalar_{ij} for each knot j.
968    ///
969    /// This is the core primitive: for each data point i, accumulate
970    /// `v[i] * per_pair_scalar(i,j)` into knot j.
971    pub(crate) fn accumulate_knot_vector<F>(&self, v: &ArrayView1<f64>, per_pair: F) -> Array1<f64>
972    where
973        F: Fn(usize) -> f64 + Send + Sync,
974    {
975        let n = self.n;
976        let k = self.n_knots;
977
978        if n >= IMPLICIT_MATVEC_PAR_THRESHOLD {
979            // Parallel path: chunk data points and reduce.
980            let n_chunks = n.div_ceil(IMPLICIT_MATVEC_CHUNK_SIZE);
981            let partial_sums: Vec<Array1<f64>> = (0..n_chunks)
982                .into_par_iter()
983                .map(|chunk_idx| {
984                    let start = chunk_idx * IMPLICIT_MATVEC_CHUNK_SIZE;
985                    let end = (start + IMPLICIT_MATVEC_CHUNK_SIZE).min(n);
986                    let mut local = Array1::<f64>::zeros(k);
987                    for i in start..end {
988                        let vi = v[i];
989                        if vi == 0.0 {
990                            continue;
991                        }
992                        let base = i * k;
993                        for j in 0..k {
994                            local[j] += vi * per_pair(base + j);
995                        }
996                    }
997                    local
998                })
999                .collect();
1000            let mut total = Array1::<f64>::zeros(k);
1001            for p in partial_sums {
1002                total += &p;
1003            }
1004            total
1005        } else {
1006            // Sequential path.
1007            let mut total = Array1::<f64>::zeros(k);
1008            for i in 0..n {
1009                let vi = v[i];
1010                if vi == 0.0 {
1011                    continue;
1012                }
1013                let base = i * k;
1014                for j in 0..k {
1015                    total[j] += vi * per_pair(base + j);
1016                }
1017            }
1018            total
1019        }
1020    }
1021
1022    /// Streaming accumulate knot vector from on-the-fly radial scalars.
1023    pub(crate) fn streaming_accumulate_knot_vector<G>(
1024        &self,
1025        v: &ArrayView1<f64>,
1026        deriv_fn: G,
1027    ) -> Result<Array1<f64>, BasisError>
1028    where
1029        G: Fn(f64, f64, f64, &[f64]) -> f64 + Send + Sync,
1030    {
1031        let st = self.streaming.as_ref().unwrap();
1032        let (n, k, dim) = (self.n, self.n_knots, self.n_axes);
1033        if n >= IMPLICIT_MATVEC_PAR_THRESHOLD {
1034            let err_flag = std::sync::atomic::AtomicBool::new(false);
1035            let nc = n.div_ceil(IMPLICIT_MATVEC_CHUNK_SIZE);
1036            let ps: Vec<Array1<f64>> = (0..nc)
1037                .into_par_iter()
1038                .map(|ci| {
1039                    let s = ci * IMPLICIT_MATVEC_CHUNK_SIZE;
1040                    let e = (s + IMPLICIT_MATVEC_CHUNK_SIZE).min(n);
1041                    let mut loc = Array1::<f64>::zeros(k);
1042                    let mut sb = vec![0.0; dim];
1043                    for i in s..e {
1044                        let vi = v[i];
1045                        if vi == 0.0 {
1046                            continue;
1047                        }
1048                        for j in 0..k {
1049                            match st.compute_pair(i, j, &mut sb) {
1050                                Ok((phi, q, t)) => {
1051                                    loc[j] += vi * deriv_fn(phi, q, t, &sb);
1052                                }
1053                                Err(_) => {
1054                                    err_flag.store(true, std::sync::atomic::Ordering::Relaxed);
1055                                    return loc;
1056                                }
1057                            }
1058                        }
1059                    }
1060                    loc
1061                })
1062                .collect();
1063            if err_flag.load(std::sync::atomic::Ordering::Relaxed) {
1064                crate::bail_invalid_basis!(
1065                    "radial scalar evaluation failed during streaming accumulate_knot_vector"
1066                        .into(),
1067                );
1068            }
1069            let mut tot = Array1::<f64>::zeros(k);
1070            for p in ps {
1071                tot += &p;
1072            }
1073            Ok(tot)
1074        } else {
1075            let mut tot = Array1::<f64>::zeros(k);
1076            let mut sb = vec![0.0; dim];
1077            for i in 0..n {
1078                let vi = v[i];
1079                if vi == 0.0 {
1080                    continue;
1081                }
1082                for j in 0..k {
1083                    let (phi, q, t) = st.compute_pair(i,j,&mut sb).map_err(|e| BasisError::InvalidInput(
1084                        format!("radial scalar evaluation failed during streaming accumulate_knot_vector: {e}"),
1085                    ))?;
1086                    tot[j] += vi * deriv_fn(phi, q, t, &sb);
1087                }
1088            }
1089            Ok(tot)
1090        }
1091    }
1092    /// Streaming forward multiply.
1093    pub(crate) fn streaming_forward_mul<G>(
1094        &self,
1095        u_knot: &Array1<f64>,
1096        deriv_fn: G,
1097    ) -> Result<Array1<f64>, BasisError>
1098    where
1099        G: Fn(f64, f64, f64, &[f64]) -> f64 + Send + Sync,
1100    {
1101        let st = self.streaming.as_ref().unwrap();
1102        let (n, k, dim) = (self.n, self.n_knots, self.n_axes);
1103        if n >= IMPLICIT_MATVEC_PAR_THRESHOLD {
1104            let err_flag = std::sync::atomic::AtomicBool::new(false);
1105            let nc = n.div_ceil(IMPLICIT_MATVEC_CHUNK_SIZE);
1106            let cr: Vec<(usize, Vec<f64>)> = (0..nc)
1107                .into_par_iter()
1108                .map(|ci| {
1109                    let s = ci * IMPLICIT_MATVEC_CHUNK_SIZE;
1110                    let e = (s + IMPLICIT_MATVEC_CHUNK_SIZE).min(n);
1111                    let mut loc = vec![0.0; e - s];
1112                    let mut sb = vec![0.0; dim];
1113                    for i in s..e {
1114                        let mut val = 0.0;
1115                        for j in 0..k {
1116                            match st.compute_pair(i, j, &mut sb) {
1117                                Ok((phi, q, t)) => {
1118                                    val += deriv_fn(phi, q, t, &sb) * u_knot[j];
1119                                }
1120                                Err(_) => {
1121                                    err_flag.store(true, std::sync::atomic::Ordering::Relaxed);
1122                                    break;
1123                                }
1124                            }
1125                        }
1126                        loc[i - s] = val;
1127                    }
1128                    (s, loc)
1129                })
1130                .collect();
1131            if err_flag.load(std::sync::atomic::Ordering::Relaxed) {
1132                crate::bail_invalid_basis!(
1133                    "radial scalar evaluation failed during streaming forward_mul".into(),
1134                );
1135            }
1136            let mut res = Array1::<f64>::zeros(n);
1137            for (s, vs) in cr {
1138                for (o, &v) in vs.iter().enumerate() {
1139                    res[s + o] = v;
1140                }
1141            }
1142            Ok(res)
1143        } else {
1144            let mut res = Array1::<f64>::zeros(n);
1145            let mut sb = vec![0.0; dim];
1146            for i in 0..n {
1147                let mut val = 0.0;
1148                for j in 0..k {
1149                    let (phi, q, t) = st.compute_pair(i, j, &mut sb).map_err(|e| {
1150                        BasisError::InvalidInput(format!(
1151                            "radial scalar evaluation failed during streaming forward_mul: {e}"
1152                        ))
1153                    })?;
1154                    val += deriv_fn(phi, q, t, &sb) * u_knot[j];
1155                }
1156                res[i] = val;
1157            }
1158            Ok(res)
1159        }
1160    }
1161    /// Streaming materialization: build (n x k) raw matrix then project.
1162    pub(crate) fn streaming_materialize<G>(&self, deriv_fn: G) -> Result<Array2<f64>, BasisError>
1163    where
1164        G: Fn(f64, f64, f64, &[f64]) -> f64 + Send + Sync,
1165    {
1166        let st = self.streaming.as_ref().unwrap();
1167        let (n, k, dim) = (self.n, self.n_knots, self.n_axes);
1168        let mut raw = Array2::<f64>::zeros((n, k));
1169        let cs = IMPLICIT_MATVEC_CHUNK_SIZE;
1170        let nc = n.div_ceil(cs);
1171        let err_flag = std::sync::atomic::AtomicBool::new(false);
1172        {
1173            let rp = SendPtr(raw.as_mut_ptr());
1174            let ef = &err_flag;
1175            (0..nc).into_par_iter().for_each(move |ci| {
1176                let s = ci * cs;
1177                let e = (s + cs).min(n);
1178                let mut sb = vec![0.0; dim];
1179                for i in s..e {
1180                    for j in 0..k {
1181                        match st.compute_pair(i, j, &mut sb) {
1182                            // SAFETY: chunk ci owns rows [s..e) of the raw n×k buffer,
1183                            // so offsets i*k+j for i ∈ [s,e), j ∈ [0,k) are pairwise
1184                            // disjoint across workers and stay within n*k = raw.len().
1185                            Ok((phi, q, t)) => unsafe {
1186                                *rp.add(i * k + j) = deriv_fn(phi, q, t, &sb);
1187                            },
1188                            Err(_) => {
1189                                ef.store(true, std::sync::atomic::Ordering::Relaxed);
1190                                return;
1191                            }
1192                        }
1193                    }
1194                }
1195            });
1196        }
1197        if err_flag.load(std::sync::atomic::Ordering::Relaxed) {
1198            crate::bail_invalid_basis!(
1199                "radial scalar evaluation failed during streaming materialize".into(),
1200            );
1201        }
1202        Ok(self.project_matrix(raw))
1203    }
1204
1205    /// Project a raw knot-space vector through the identifiability transform
1206    /// and pad with zeros for polynomial columns.
1207    pub(crate) fn project_and_pad(&self, raw_knot_vec: &Array1<f64>) -> Array1<f64> {
1208        // Step 1: apply kernel constraint Z (if present).
1209        let constrained = match &self.ident_transform {
1210            Some(z) => z.t().dot(raw_knot_vec),
1211            None => raw_knot_vec.clone(),
1212        };
1213
1214        // Step 2: pad with polynomial zeros.
1215        let p_padded = constrained.len() + self.n_poly;
1216        let mut padded = Array1::<f64>::zeros(p_padded);
1217        padded
1218            .slice_mut(s![..constrained.len()])
1219            .assign(&constrained);
1220
1221        // Step 3: apply full identifiability transform (if present).
1222        match &self.full_ident_transform {
1223            Some(zf) => zf.t().dot(&padded),
1224            None => padded,
1225        }
1226    }
1227
1228    /// Expand a coefficient vector from the final space back to raw knot space.
1229    /// This is the transpose path: p_out → (padded) → (constrained) → n_knots.
1230    pub(crate) fn unproject(&self, u: &ArrayView1<f64>) -> Array1<f64> {
1231        // Step 1: undo full identifiability transform.
1232        let after_full = match &self.full_ident_transform {
1233            Some(zf) => zf.dot(u),
1234            None => u.to_owned(),
1235        };
1236
1237        // Step 2: extract smooth part (drop polynomial padding).
1238        let p_constrained = self.p_constrained();
1239        let smooth_part = after_full.slice(s![..p_constrained]);
1240
1241        // Step 3: undo kernel constraint Z.
1242        match &self.ident_transform {
1243            Some(z) => z.dot(&smooth_part),
1244            None => smooth_part.to_owned(),
1245        }
1246    }
1247
1248    /// Batched `unproject` for a (p_out × rank) coefficient matrix.
1249    /// Returns (n_knots × rank) via two BLAS3 matmuls — the same algebra as
1250    /// `unproject`, but amortized across all rank columns of `u`. Used by
1251    /// `forward_mul_matrix` so per-axis trace evaluations can be a single
1252    /// chunked GEMM rather than rank-many `forward_mul` calls.
1253    pub fn unproject_matrix(&self, u: &ArrayView2<f64>) -> Array2<f64> {
1254        assert_eq!(u.nrows(), self.p_out());
1255        // Step 1: undo full identifiability transform → (p_after_pad, rank).
1256        let after_full = match &self.full_ident_transform {
1257            Some(zf) => fast_ab(zf, u),
1258            None => u.to_owned(),
1259        };
1260        // Step 2: drop polynomial padding rows → (p_constrained, rank).
1261        let p_constrained = self.p_constrained();
1262        let smooth_part = after_full.slice(s![..p_constrained, ..]);
1263        // Step 3: undo kernel constraint Z → (n_knots, rank).
1264        match &self.ident_transform {
1265            Some(z) => fast_ab(z, &smooth_part),
1266            None => smooth_part.to_owned(),
1267        }
1268    }
1269
1270    /// Compute (∂X/∂ψ_d)^T v for a given axis d and vector v of length n.
1271    ///
1272    /// Returns a vector of length p_out (total basis dimension after all transforms).
1273    ///
1274    /// Formula in raw knot space:
1275    ///   [raw]_j = Σ_i v_i · q_{ij} · s_{d,ij}
1276    /// then project through Z and pad.
1277    ///
1278    /// Note: q = φ_r/r and s_d = exp(2ψ_d)·h_d² are UNNORMALIZED axis components.
1279    /// With this convention, q·s_d = (φ_r/r)·(exp(2ψ_d)·h_d²) = φ_r·(s_d/r),
1280    /// which equals the correct ∂φ/∂ψ_d = φ_r·∂r/∂ψ_d = φ_r·s_d/r.
1281    /// No r² correction is needed — that would be required only if s_d were
1282    /// the fractional quantity s_d/r².
1283    pub fn transpose_mul(
1284        &self,
1285        axis: usize,
1286        v: &ArrayView1<f64>,
1287    ) -> Result<Array1<f64>, BasisError> {
1288        assert!(
1289            axis < self.n_axes(),
1290            "implicit psi first transpose axis out of bounds: axis={axis}, n_axes={}",
1291            self.n_axes()
1292        );
1293        assert_eq!(
1294            v.len(),
1295            self.n,
1296            "implicit psi first transpose row-adjoint length mismatch"
1297        );
1298        if self.axis_combinations.is_some() {
1299            let combo = self.transformed_axis_combination(axis);
1300            let combo_sum = Self::transformed_combo_sum(combo);
1301            if self.is_streaming() {
1302                let c = self.psi_scale_share;
1303                let raw = self.streaming_accumulate_knot_vector(v, |phi, q, _, sb| {
1304                    let s_combo = combo
1305                        .iter()
1306                        .map(|(raw_axis, coeff)| coeff * sb[*raw_axis])
1307                        .sum();
1308                    Self::transformed_first_kernel_value(phi, q, s_combo, combo_sum, c)
1309                })?;
1310                return Ok(self.project_and_pad(&raw));
1311            }
1312            let c = self.psi_scale_share;
1313            let raw = self.accumulate_knot_vector(v, |idx| {
1314                let s_combo = self.transformed_combo_axis_value_materialized(idx, combo);
1315                Self::transformed_first_kernel_value(
1316                    self.phi_values[idx],
1317                    self.q_values[idx],
1318                    s_combo,
1319                    combo_sum,
1320                    c,
1321                )
1322            });
1323            return Ok(self.project_and_pad(&raw));
1324        }
1325        if self.is_streaming() {
1326            let c = self.psi_scale_share;
1327            let raw =
1328                self.streaming_accumulate_knot_vector(v, |phi, q, _, sb| q * sb[axis] + c * phi)?;
1329            return Ok(self.project_and_pad(&raw));
1330        }
1331        let c = self.psi_scale_share;
1332        let af = &self.axis_components;
1333        let pv = &self.phi_values;
1334        let qv = &self.q_values;
1335        let raw = self.accumulate_knot_vector(v, |idx| qv[idx] * af[[idx, axis]] + c * pv[idx]);
1336        Ok(self.project_and_pad(&raw))
1337    }
1338
1339    /// Compute (∂X/∂ψ_d) u for a given axis d and vector u of length p_out.
1340    ///
1341    /// Returns a vector of length n.
1342    ///
1343    /// Formula: for each data point i,
1344    ///   result_i = Σ_j q_{ij} · s_{d,ij} · u_knot_j
1345    /// where u_knot = Z · u_smooth (unprojected back to knot space).
1346    pub fn forward_mul(&self, axis: usize, u: &ArrayView1<f64>) -> Result<Array1<f64>, BasisError> {
1347        assert!(
1348            axis < self.n_axes(),
1349            "implicit psi first forward axis out of bounds: axis={axis}, n_axes={}",
1350            self.n_axes()
1351        );
1352        assert_eq!(
1353            u.len(),
1354            self.p_out(),
1355            "implicit psi first forward coefficient length mismatch"
1356        );
1357        let u_knot = self.unproject(u);
1358        if self.axis_combinations.is_some() {
1359            let combo = self.transformed_axis_combination(axis);
1360            let combo_sum = Self::transformed_combo_sum(combo);
1361            if self.is_streaming() {
1362                let c = self.psi_scale_share;
1363                return self.streaming_forward_mul(&u_knot, |phi, q, _, sb| {
1364                    let s_combo = combo
1365                        .iter()
1366                        .map(|(raw_axis, coeff)| coeff * sb[*raw_axis])
1367                        .sum();
1368                    Self::transformed_first_kernel_value(phi, q, s_combo, combo_sum, c)
1369                });
1370            }
1371            let n = self.n;
1372            let k = self.n_knots;
1373            let c = self.psi_scale_share;
1374            if n >= IMPLICIT_MATVEC_PAR_THRESHOLD {
1375                let mut result = Array1::<f64>::zeros(n);
1376                let n_chunks = n.div_ceil(IMPLICIT_MATVEC_CHUNK_SIZE);
1377                let chunk_results: Vec<(usize, Vec<f64>)> = (0..n_chunks)
1378                    .into_par_iter()
1379                    .map(|chunk_idx| {
1380                        let start = chunk_idx * IMPLICIT_MATVEC_CHUNK_SIZE;
1381                        let end = (start + IMPLICIT_MATVEC_CHUNK_SIZE).min(n);
1382                        let mut local = vec![0.0; end - start];
1383                        for i in start..end {
1384                            let base = i * k;
1385                            let mut val = 0.0;
1386                            for j in 0..k {
1387                                let idx = base + j;
1388                                let s_combo =
1389                                    self.transformed_combo_axis_value_materialized(idx, combo);
1390                                val += Self::transformed_first_kernel_value(
1391                                    self.phi_values[idx],
1392                                    self.q_values[idx],
1393                                    s_combo,
1394                                    combo_sum,
1395                                    c,
1396                                ) * u_knot[j];
1397                            }
1398                            local[i - start] = val;
1399                        }
1400                        (start, local)
1401                    })
1402                    .collect();
1403                for (start, vals) in chunk_results {
1404                    for (offset, &v) in vals.iter().enumerate() {
1405                        result[start + offset] = v;
1406                    }
1407                }
1408                return Ok(result);
1409            }
1410            let mut result = Array1::<f64>::zeros(n);
1411            for i in 0..n {
1412                let base = i * k;
1413                let mut val = 0.0;
1414                for j in 0..k {
1415                    let idx = base + j;
1416                    let s_combo = self.transformed_combo_axis_value_materialized(idx, combo);
1417                    val += Self::transformed_first_kernel_value(
1418                        self.phi_values[idx],
1419                        self.q_values[idx],
1420                        s_combo,
1421                        combo_sum,
1422                        c,
1423                    ) * u_knot[j];
1424                }
1425                result[i] = val;
1426            }
1427            return Ok(result);
1428        }
1429        if self.is_streaming() {
1430            let c = self.psi_scale_share;
1431            return self.streaming_forward_mul(&u_knot, |phi, q, _, sb| q * sb[axis] + c * phi);
1432        }
1433        let n = self.n;
1434        let k = self.n_knots;
1435        let c = self.psi_scale_share;
1436        let af = &self.axis_components;
1437        let pv = &self.phi_values;
1438        let qv = &self.q_values;
1439
1440        if n >= IMPLICIT_MATVEC_PAR_THRESHOLD {
1441            let mut result = Array1::<f64>::zeros(n);
1442            // Parallel over chunks of data points.
1443            let n_chunks = n.div_ceil(IMPLICIT_MATVEC_CHUNK_SIZE);
1444            let chunk_results: Vec<(usize, Vec<f64>)> = (0..n_chunks)
1445                .into_par_iter()
1446                .map(|chunk_idx| {
1447                    let start = chunk_idx * IMPLICIT_MATVEC_CHUNK_SIZE;
1448                    let end = (start + IMPLICIT_MATVEC_CHUNK_SIZE).min(n);
1449                    let mut local = vec![0.0; end - start];
1450                    for i in start..end {
1451                        let base = i * k;
1452                        let mut val = 0.0;
1453                        for j in 0..k {
1454                            val += (qv[base + j] * af[[base + j, axis]] + c * pv[base + j])
1455                                * u_knot[j];
1456                        }
1457                        local[i - start] = val;
1458                    }
1459                    (start, local)
1460                })
1461                .collect();
1462            for (start, vals) in chunk_results {
1463                for (offset, &v) in vals.iter().enumerate() {
1464                    result[start + offset] = v;
1465                }
1466            }
1467            Ok(result)
1468        } else {
1469            let mut result = Array1::<f64>::zeros(n);
1470            for i in 0..n {
1471                let base = i * k;
1472                let mut val = 0.0;
1473                for j in 0..k {
1474                    val += (qv[base + j] * af[[base + j, axis]] + c * pv[base + j]) * u_knot[j];
1475                }
1476                result[i] = val;
1477            }
1478            Ok(result)
1479        }
1480    }
1481
1482    /// Compute (∂²X/∂ψ_d²)^T v — diagonal second derivative, same axis.
1483    ///
1484    /// Matrix-free variant of `materialize_second_diag`: avoids forming the
1485    /// full (n × p_out) matrix when only a single adjoint matvec is needed.
1486    pub fn transpose_mul_second_diag(
1487        &self,
1488        axis: usize,
1489        v: &ArrayView1<f64>,
1490    ) -> Result<Array1<f64>, BasisError> {
1491        assert!(
1492            axis < self.n_axes(),
1493            "implicit psi second diagonal transpose axis out of bounds: axis={axis}, n_axes={}",
1494            self.n_axes()
1495        );
1496        assert_eq!(
1497            v.len(),
1498            self.n,
1499            "implicit psi second diagonal transpose row-adjoint length mismatch"
1500        );
1501        if self.axis_combinations.is_some() {
1502            let combo = self.transformed_axis_combination(axis);
1503            let combo_sum = Self::transformed_combo_sum(combo);
1504            if self.is_streaming() {
1505                let c = self.psi_scale_share;
1506                let raw = self.streaming_accumulate_knot_vector(v, |phi, q, t, sb| {
1507                    let s_combo = combo
1508                        .iter()
1509                        .map(|(raw_axis, coeff)| coeff * sb[*raw_axis])
1510                        .sum();
1511                    let overlap_s = Self::transformed_combo_overlap_streaming(combo, combo, sb);
1512                    Self::transformed_second_kernel_value(
1513                        phi, q, t, s_combo, combo_sum, s_combo, combo_sum, overlap_s, c,
1514                    )
1515                })?;
1516                return Ok(self.project_and_pad(&raw));
1517            }
1518            let c = self.psi_scale_share;
1519            let raw = self.accumulate_knot_vector(v, |idx| {
1520                let s_combo = self.transformed_combo_axis_value_materialized(idx, combo);
1521                let overlap_s = self.transformed_combo_overlap_materialized(idx, combo, combo);
1522                Self::transformed_second_kernel_value(
1523                    self.phi_values[idx],
1524                    self.q_values[idx],
1525                    self.t_values[idx],
1526                    s_combo,
1527                    combo_sum,
1528                    s_combo,
1529                    combo_sum,
1530                    overlap_s,
1531                    c,
1532                )
1533            });
1534            return Ok(self.project_and_pad(&raw));
1535        }
1536        if self.is_streaming() {
1537            let c = self.psi_scale_share;
1538            let raw = self.streaming_accumulate_knot_vector(v, |phi, q, t, sb| {
1539                let s = sb[axis];
1540                2.0 * q * s + t * s * s + 2.0 * c * q * s + c * c * phi
1541            })?;
1542            return Ok(self.project_and_pad(&raw));
1543        }
1544        let c = self.psi_scale_share;
1545        let af = &self.axis_components;
1546        let pv = &self.phi_values;
1547        let qv = &self.q_values;
1548        let tv = &self.t_values;
1549        let raw = self.accumulate_knot_vector(v, |idx| {
1550            let s = af[[idx, axis]];
1551            2.0 * qv[idx] * s + tv[idx] * s * s + 2.0 * c * qv[idx] * s + c * c * pv[idx]
1552        });
1553        Ok(self.project_and_pad(&raw))
1554    }
1555
1556    /// Compute (∂²X/∂ψ_d∂ψ_e)^T v — cross second derivative (d ≠ e).
1557    pub fn transpose_mul_second_cross(
1558        &self,
1559        axis_d: usize,
1560        axis_e: usize,
1561        v: &ArrayView1<f64>,
1562    ) -> Result<Array1<f64>, BasisError> {
1563        assert!(
1564            axis_d < self.n_axes(),
1565            "implicit psi second cross transpose first axis out of bounds: axis_d={axis_d}, n_axes={}",
1566            self.n_axes()
1567        );
1568        assert!(
1569            axis_e < self.n_axes(),
1570            "implicit psi second cross transpose second axis out of bounds: axis_e={axis_e}, n_axes={}",
1571            self.n_axes()
1572        );
1573        assert_ne!(
1574            axis_d, axis_e,
1575            "implicit psi second cross transpose requires distinct axes: axis_d={axis_d}, axis_e={axis_e}"
1576        );
1577        assert_eq!(
1578            v.len(),
1579            self.n,
1580            "implicit psi second cross transpose row-adjoint length mismatch"
1581        );
1582        if self.axis_combinations.is_some() {
1583            let combo_d = self.transformed_axis_combination(axis_d);
1584            let combo_e = self.transformed_axis_combination(axis_e);
1585            let sum_d = Self::transformed_combo_sum(combo_d);
1586            let sum_e = Self::transformed_combo_sum(combo_e);
1587            if self.is_streaming() {
1588                let c = self.psi_scale_share;
1589                let raw = self.streaming_accumulate_knot_vector(v, |phi, q, t, sb| {
1590                    let s_d = combo_d
1591                        .iter()
1592                        .map(|(raw_axis, coeff)| coeff * sb[*raw_axis])
1593                        .sum();
1594                    let s_e = combo_e
1595                        .iter()
1596                        .map(|(raw_axis, coeff)| coeff * sb[*raw_axis])
1597                        .sum();
1598                    let overlap_s = Self::transformed_combo_overlap_streaming(combo_d, combo_e, sb);
1599                    Self::transformed_second_kernel_value(
1600                        phi, q, t, s_d, sum_d, s_e, sum_e, overlap_s, c,
1601                    )
1602                })?;
1603                return Ok(self.project_and_pad(&raw));
1604            }
1605            let c = self.psi_scale_share;
1606            let raw = self.accumulate_knot_vector(v, |idx| {
1607                let s_d = self.transformed_combo_axis_value_materialized(idx, combo_d);
1608                let s_e = self.transformed_combo_axis_value_materialized(idx, combo_e);
1609                let overlap_s = self.transformed_combo_overlap_materialized(idx, combo_d, combo_e);
1610                Self::transformed_second_kernel_value(
1611                    self.phi_values[idx],
1612                    self.q_values[idx],
1613                    self.t_values[idx],
1614                    s_d,
1615                    sum_d,
1616                    s_e,
1617                    sum_e,
1618                    overlap_s,
1619                    c,
1620                )
1621            });
1622            return Ok(self.project_and_pad(&raw));
1623        }
1624        if self.is_streaming() {
1625            let c = self.psi_scale_share;
1626            let raw = self.streaming_accumulate_knot_vector(v, |phi, q, t, sb| {
1627                t * sb[axis_d] * sb[axis_e] + c * q * (sb[axis_d] + sb[axis_e]) + c * c * phi
1628            })?;
1629            return Ok(self.project_and_pad(&raw));
1630        }
1631        let c = self.psi_scale_share;
1632        let af = &self.axis_components;
1633        let pv = &self.phi_values;
1634        let qv = &self.q_values;
1635        let tv = &self.t_values;
1636        let raw = self.accumulate_knot_vector(v, |idx| {
1637            tv[idx] * af[[idx, axis_d]] * af[[idx, axis_e]]
1638                + c * qv[idx] * (af[[idx, axis_d]] + af[[idx, axis_e]])
1639                + c * c * pv[idx]
1640        });
1641        Ok(self.project_and_pad(&raw))
1642    }
1643
1644    /// Compute (∂²X/∂ψ_d²) u — forward diagonal second derivative.
1645    pub fn forward_mul_second_diag(
1646        &self,
1647        axis: usize,
1648        u: &ArrayView1<f64>,
1649    ) -> Result<Array1<f64>, BasisError> {
1650        assert!(
1651            axis < self.n_axes(),
1652            "implicit psi second diagonal forward axis out of bounds: axis={axis}, n_axes={}",
1653            self.n_axes()
1654        );
1655        assert_eq!(
1656            u.len(),
1657            self.p_out(),
1658            "implicit psi second diagonal forward coefficient length mismatch"
1659        );
1660        let u_knot = self.unproject(u);
1661        if self.axis_combinations.is_some() {
1662            let combo = self.transformed_axis_combination(axis);
1663            let combo_sum = Self::transformed_combo_sum(combo);
1664            if self.is_streaming() {
1665                let c = self.psi_scale_share;
1666                return self.streaming_forward_mul(&u_knot, |phi, q, t, sb| {
1667                    let s_combo = combo
1668                        .iter()
1669                        .map(|(raw_axis, coeff)| coeff * sb[*raw_axis])
1670                        .sum();
1671                    let overlap_s = Self::transformed_combo_overlap_streaming(combo, combo, sb);
1672                    Self::transformed_second_kernel_value(
1673                        phi, q, t, s_combo, combo_sum, s_combo, combo_sum, overlap_s, c,
1674                    )
1675                });
1676            }
1677            let n = self.n;
1678            let k = self.n_knots;
1679            let c = self.psi_scale_share;
1680            let compute_row = |i: usize| -> f64 {
1681                let base = i * k;
1682                let mut val = 0.0;
1683                for j in 0..k {
1684                    let idx = base + j;
1685                    let s_combo = self.transformed_combo_axis_value_materialized(idx, combo);
1686                    let overlap_s = self.transformed_combo_overlap_materialized(idx, combo, combo);
1687                    val += Self::transformed_second_kernel_value(
1688                        self.phi_values[idx],
1689                        self.q_values[idx],
1690                        self.t_values[idx],
1691                        s_combo,
1692                        combo_sum,
1693                        s_combo,
1694                        combo_sum,
1695                        overlap_s,
1696                        c,
1697                    ) * u_knot[j];
1698                }
1699                val
1700            };
1701            if n >= IMPLICIT_MATVEC_PAR_THRESHOLD {
1702                let n_chunks = n.div_ceil(IMPLICIT_MATVEC_CHUNK_SIZE);
1703                let mut result = Array1::<f64>::zeros(n);
1704                let chunk_results: Vec<(usize, Vec<f64>)> = (0..n_chunks)
1705                    .into_par_iter()
1706                    .map(|chunk_idx| {
1707                        let start = chunk_idx * IMPLICIT_MATVEC_CHUNK_SIZE;
1708                        let end = (start + IMPLICIT_MATVEC_CHUNK_SIZE).min(n);
1709                        let local: Vec<f64> = (start..end).map(compute_row).collect();
1710                        (start, local)
1711                    })
1712                    .collect();
1713                for (start, vals) in chunk_results {
1714                    for (offset, &value) in vals.iter().enumerate() {
1715                        result[start + offset] = value;
1716                    }
1717                }
1718                return Ok(result);
1719            }
1720            return Ok(Array1::from_vec((0..n).map(compute_row).collect()));
1721        }
1722        if self.is_streaming() {
1723            let c = self.psi_scale_share;
1724            return self.streaming_forward_mul(&u_knot, |phi, q, t, sb| {
1725                let s = sb[axis];
1726                2.0 * q * s + t * s * s + 2.0 * c * q * s + c * c * phi
1727            });
1728        }
1729        let n = self.n;
1730        let k = self.n_knots;
1731        let c = self.psi_scale_share;
1732        let af = &self.axis_components;
1733        let pv = &self.phi_values;
1734        let qv = &self.q_values;
1735        let tv = &self.t_values;
1736        let compute_row = |i: usize| -> f64 {
1737            let base = i * k;
1738            let mut val = 0.0;
1739            for j in 0..k {
1740                let s = af[[base + j, axis]];
1741                val += (2.0 * qv[base + j] * s
1742                    + tv[base + j] * s * s
1743                    + 2.0 * c * qv[base + j] * s
1744                    + c * c * pv[base + j])
1745                    * u_knot[j];
1746            }
1747            val
1748        };
1749
1750        if n >= IMPLICIT_MATVEC_PAR_THRESHOLD {
1751            let n_chunks = n.div_ceil(IMPLICIT_MATVEC_CHUNK_SIZE);
1752            let mut result = Array1::<f64>::zeros(n);
1753            let chunk_results: Vec<(usize, Vec<f64>)> = (0..n_chunks)
1754                .into_par_iter()
1755                .map(|chunk_idx| {
1756                    let start = chunk_idx * IMPLICIT_MATVEC_CHUNK_SIZE;
1757                    let end = (start + IMPLICIT_MATVEC_CHUNK_SIZE).min(n);
1758                    let local: Vec<f64> = (start..end).map(compute_row).collect();
1759                    (start, local)
1760                })
1761                .collect();
1762            for (start, vals) in chunk_results {
1763                for (offset, &value) in vals.iter().enumerate() {
1764                    result[start + offset] = value;
1765                }
1766            }
1767            Ok(result)
1768        } else {
1769            Ok(Array1::from_vec((0..n).map(compute_row).collect()))
1770        }
1771    }
1772
1773    /// Compute (∂²X/∂ψ_d∂ψ_e) u — forward cross second derivative.
1774    pub fn forward_mul_second_cross(
1775        &self,
1776        axis_d: usize,
1777        axis_e: usize,
1778        u: &ArrayView1<f64>,
1779    ) -> Result<Array1<f64>, BasisError> {
1780        assert!(
1781            axis_d < self.n_axes(),
1782            "implicit psi second cross forward first axis out of bounds: axis_d={axis_d}, n_axes={}",
1783            self.n_axes()
1784        );
1785        assert!(
1786            axis_e < self.n_axes(),
1787            "implicit psi second cross forward second axis out of bounds: axis_e={axis_e}, n_axes={}",
1788            self.n_axes()
1789        );
1790        assert_ne!(
1791            axis_d, axis_e,
1792            "implicit psi second cross forward requires distinct axes: axis_d={axis_d}, axis_e={axis_e}"
1793        );
1794        assert_eq!(
1795            u.len(),
1796            self.p_out(),
1797            "implicit psi second cross forward coefficient length mismatch"
1798        );
1799        let u_knot = self.unproject(u);
1800        if self.axis_combinations.is_some() {
1801            let combo_d = self.transformed_axis_combination(axis_d);
1802            let combo_e = self.transformed_axis_combination(axis_e);
1803            let sum_d = Self::transformed_combo_sum(combo_d);
1804            let sum_e = Self::transformed_combo_sum(combo_e);
1805            if self.is_streaming() {
1806                let c = self.psi_scale_share;
1807                return self.streaming_forward_mul(&u_knot, |phi, q, t, sb| {
1808                    let s_d = combo_d
1809                        .iter()
1810                        .map(|(raw_axis, coeff)| coeff * sb[*raw_axis])
1811                        .sum();
1812                    let s_e = combo_e
1813                        .iter()
1814                        .map(|(raw_axis, coeff)| coeff * sb[*raw_axis])
1815                        .sum();
1816                    let overlap_s = Self::transformed_combo_overlap_streaming(combo_d, combo_e, sb);
1817                    Self::transformed_second_kernel_value(
1818                        phi, q, t, s_d, sum_d, s_e, sum_e, overlap_s, c,
1819                    )
1820                });
1821            }
1822            let n = self.n;
1823            let k = self.n_knots;
1824            let c = self.psi_scale_share;
1825            let compute_row = |i: usize| -> f64 {
1826                let base = i * k;
1827                let mut val = 0.0;
1828                for j in 0..k {
1829                    let idx = base + j;
1830                    let s_d = self.transformed_combo_axis_value_materialized(idx, combo_d);
1831                    let s_e = self.transformed_combo_axis_value_materialized(idx, combo_e);
1832                    let overlap_s =
1833                        self.transformed_combo_overlap_materialized(idx, combo_d, combo_e);
1834                    val += Self::transformed_second_kernel_value(
1835                        self.phi_values[idx],
1836                        self.q_values[idx],
1837                        self.t_values[idx],
1838                        s_d,
1839                        sum_d,
1840                        s_e,
1841                        sum_e,
1842                        overlap_s,
1843                        c,
1844                    ) * u_knot[j];
1845                }
1846                val
1847            };
1848            if n >= IMPLICIT_MATVEC_PAR_THRESHOLD {
1849                let n_chunks = n.div_ceil(IMPLICIT_MATVEC_CHUNK_SIZE);
1850                let mut result = Array1::<f64>::zeros(n);
1851                let chunk_results: Vec<(usize, Vec<f64>)> = (0..n_chunks)
1852                    .into_par_iter()
1853                    .map(|chunk_idx| {
1854                        let start = chunk_idx * IMPLICIT_MATVEC_CHUNK_SIZE;
1855                        let end = (start + IMPLICIT_MATVEC_CHUNK_SIZE).min(n);
1856                        let local: Vec<f64> = (start..end).map(compute_row).collect();
1857                        (start, local)
1858                    })
1859                    .collect();
1860                for (start, vals) in chunk_results {
1861                    for (offset, &value) in vals.iter().enumerate() {
1862                        result[start + offset] = value;
1863                    }
1864                }
1865                return Ok(result);
1866            }
1867            return Ok(Array1::from_vec((0..n).map(compute_row).collect()));
1868        }
1869        if self.is_streaming() {
1870            let c = self.psi_scale_share;
1871            return self.streaming_forward_mul(&u_knot, |phi, q, t, sb| {
1872                t * sb[axis_d] * sb[axis_e] + c * q * (sb[axis_d] + sb[axis_e]) + c * c * phi
1873            });
1874        }
1875        let n = self.n;
1876        let k = self.n_knots;
1877        let c = self.psi_scale_share;
1878        let af = &self.axis_components;
1879        let pv = &self.phi_values;
1880        let qv = &self.q_values;
1881        let tv = &self.t_values;
1882        let compute_row = |i: usize| -> f64 {
1883            let base = i * k;
1884            let mut val = 0.0;
1885            for j in 0..k {
1886                val += (tv[base + j] * af[[base + j, axis_d]] * af[[base + j, axis_e]]
1887                    + c * qv[base + j] * (af[[base + j, axis_d]] + af[[base + j, axis_e]])
1888                    + c * c * pv[base + j])
1889                    * u_knot[j];
1890            }
1891            val
1892        };
1893
1894        if n >= IMPLICIT_MATVEC_PAR_THRESHOLD {
1895            let n_chunks = n.div_ceil(IMPLICIT_MATVEC_CHUNK_SIZE);
1896            let mut result = Array1::<f64>::zeros(n);
1897            let chunk_results: Vec<(usize, Vec<f64>)> = (0..n_chunks)
1898                .into_par_iter()
1899                .map(|chunk_idx| {
1900                    let start = chunk_idx * IMPLICIT_MATVEC_CHUNK_SIZE;
1901                    let end = (start + IMPLICIT_MATVEC_CHUNK_SIZE).min(n);
1902                    let local: Vec<f64> = (start..end).map(compute_row).collect();
1903                    (start, local)
1904                })
1905                .collect();
1906            for (start, vals) in chunk_results {
1907                for (offset, &value) in vals.iter().enumerate() {
1908                    result[start + offset] = value;
1909                }
1910            }
1911            Ok(result)
1912        } else {
1913            Ok(Array1::from_vec((0..n).map(compute_row).collect()))
1914        }
1915    }
1916
1917    /// Materialize the full (n × p_out) first-derivative matrix for axis d.
1918    ///
1919    /// Efficient O(n * k) construction: builds the raw (n × k) kernel derivative
1920    /// matrix directly, then projects through identifiability transforms.
1921    /// This is used when the dense matrix is needed temporarily (e.g., for
1922    /// HyperCoord construction) while avoiding simultaneous storage of all D axes.
1923    pub fn materialize_first(&self, axis: usize) -> Result<Array2<f64>, BasisError> {
1924        assert!(
1925            axis < self.n_axes(),
1926            "implicit psi first materialization axis out of bounds: axis={axis}, n_axes={}",
1927            self.n_axes()
1928        );
1929        if self.enforces_dense_materialization_budget() {
1930            assert_no_dense_derivative_materialization(self.n, self.p_out(), self.n_axes());
1931        }
1932        if self.axis_combinations.is_some() {
1933            let combo = self.transformed_axis_combination(axis);
1934            let combo_sum = Self::transformed_combo_sum(combo);
1935            if self.is_streaming() {
1936                let c = self.psi_scale_share;
1937                return self.streaming_materialize(|phi, q, _, sb| {
1938                    let s_combo = combo
1939                        .iter()
1940                        .map(|(raw_axis, coeff)| coeff * sb[*raw_axis])
1941                        .sum();
1942                    Self::transformed_first_kernel_value(phi, q, s_combo, combo_sum, c)
1943                });
1944            }
1945            let n = self.n;
1946            let k = self.n_knots;
1947            let c = self.psi_scale_share;
1948            let mut raw = Array2::<f64>::zeros((n, k));
1949            for i in 0..n {
1950                let base = i * k;
1951                for j in 0..k {
1952                    let idx = base + j;
1953                    let s_combo = self.transformed_combo_axis_value_materialized(idx, combo);
1954                    raw[[i, j]] = Self::transformed_first_kernel_value(
1955                        self.phi_values[idx],
1956                        self.q_values[idx],
1957                        s_combo,
1958                        combo_sum,
1959                        c,
1960                    );
1961                }
1962            }
1963            return Ok(self.project_matrix(raw));
1964        }
1965        if self.is_streaming() {
1966            let c = self.psi_scale_share;
1967            return self.streaming_materialize(|phi, q, _, sb| q * sb[axis] + c * phi);
1968        }
1969        let n = self.n;
1970        let k = self.n_knots;
1971        let c = self.psi_scale_share;
1972        let mut raw = Array2::<f64>::zeros((n, k));
1973        for i in 0..n {
1974            let base = i * k;
1975            for j in 0..k {
1976                raw[[i, j]] = self.q_values[base + j] * self.axis_components[[base + j, axis]]
1977                    + c * self.phi_values[base + j];
1978            }
1979        }
1980        Ok(self.project_matrix(raw))
1981    }
1982
1983    /// Materialize the full (n × p_out) second diagonal derivative matrix for axis d.
1984    pub fn materialize_second_diag(&self, axis: usize) -> Result<Array2<f64>, BasisError> {
1985        assert!(
1986            axis < self.n_axes(),
1987            "implicit psi second diagonal materialization axis out of bounds: axis={axis}, n_axes={}",
1988            self.n_axes()
1989        );
1990        if self.enforces_dense_materialization_budget() {
1991            assert_no_dense_derivative_materialization(self.n, self.p_out(), self.n_axes());
1992        }
1993        if self.axis_combinations.is_some() {
1994            let combo = self.transformed_axis_combination(axis);
1995            let combo_sum = Self::transformed_combo_sum(combo);
1996            if self.is_streaming() {
1997                let c = self.psi_scale_share;
1998                return self.streaming_materialize(|phi, q, t, sb| {
1999                    let s_combo = combo
2000                        .iter()
2001                        .map(|(raw_axis, coeff)| coeff * sb[*raw_axis])
2002                        .sum();
2003                    let overlap_s = Self::transformed_combo_overlap_streaming(combo, combo, sb);
2004                    Self::transformed_second_kernel_value(
2005                        phi, q, t, s_combo, combo_sum, s_combo, combo_sum, overlap_s, c,
2006                    )
2007                });
2008            }
2009            let n = self.n;
2010            let k = self.n_knots;
2011            let c = self.psi_scale_share;
2012            let mut raw = Array2::<f64>::zeros((n, k));
2013            for i in 0..n {
2014                let base = i * k;
2015                for j in 0..k {
2016                    let idx = base + j;
2017                    let s_combo = self.transformed_combo_axis_value_materialized(idx, combo);
2018                    let overlap_s = self.transformed_combo_overlap_materialized(idx, combo, combo);
2019                    raw[[i, j]] = Self::transformed_second_kernel_value(
2020                        self.phi_values[idx],
2021                        self.q_values[idx],
2022                        self.t_values[idx],
2023                        s_combo,
2024                        combo_sum,
2025                        s_combo,
2026                        combo_sum,
2027                        overlap_s,
2028                        c,
2029                    );
2030                }
2031            }
2032            return Ok(self.project_matrix(raw));
2033        }
2034        if self.is_streaming() {
2035            let c = self.psi_scale_share;
2036            return self.streaming_materialize(|phi, q, t, sb| {
2037                let s = sb[axis];
2038                2.0 * q * s + t * s * s + 2.0 * c * q * s + c * c * phi
2039            });
2040        }
2041        let n = self.n;
2042        let k = self.n_knots;
2043        let c = self.psi_scale_share;
2044        let mut raw = Array2::<f64>::zeros((n, k));
2045        for i in 0..n {
2046            let base = i * k;
2047            for j in 0..k {
2048                let s = self.axis_components[[base + j, axis]];
2049                raw[[i, j]] = 2.0 * self.q_values[base + j] * s
2050                    + self.t_values[base + j] * s * s
2051                    + 2.0 * c * self.q_values[base + j] * s
2052                    + c * c * self.phi_values[base + j];
2053            }
2054        }
2055        Ok(self.project_matrix(raw))
2056    }
2057
2058    /// Materialize the full (n × p_out) cross second derivative matrix for axes (d, e).
2059    ///
2060    /// Dense materialization of the t · s_d · s_e cross coupling.
2061    pub fn materialize_second_cross(
2062        &self,
2063        axis_d: usize,
2064        axis_e: usize,
2065    ) -> Result<Array2<f64>, BasisError> {
2066        assert!(
2067            axis_d < self.n_axes(),
2068            "implicit psi second cross materialization first axis out of bounds: axis_d={axis_d}, n_axes={}",
2069            self.n_axes()
2070        );
2071        assert!(
2072            axis_e < self.n_axes(),
2073            "implicit psi second cross materialization second axis out of bounds: axis_e={axis_e}, n_axes={}",
2074            self.n_axes()
2075        );
2076        assert_ne!(
2077            axis_d, axis_e,
2078            "implicit psi second cross materialization requires distinct axes: axis_d={axis_d}, axis_e={axis_e}"
2079        );
2080        if self.enforces_dense_materialization_budget() {
2081            assert_no_dense_derivative_materialization(self.n, self.p_out(), self.n_axes());
2082        }
2083        if self.axis_combinations.is_some() {
2084            let combo_d = self.transformed_axis_combination(axis_d);
2085            let combo_e = self.transformed_axis_combination(axis_e);
2086            let sum_d = Self::transformed_combo_sum(combo_d);
2087            let sum_e = Self::transformed_combo_sum(combo_e);
2088            if self.is_streaming() {
2089                let c = self.psi_scale_share;
2090                return self.streaming_materialize(|phi, q, t, sb| {
2091                    let s_d = combo_d
2092                        .iter()
2093                        .map(|(raw_axis, coeff)| coeff * sb[*raw_axis])
2094                        .sum();
2095                    let s_e = combo_e
2096                        .iter()
2097                        .map(|(raw_axis, coeff)| coeff * sb[*raw_axis])
2098                        .sum();
2099                    let overlap_s = Self::transformed_combo_overlap_streaming(combo_d, combo_e, sb);
2100                    Self::transformed_second_kernel_value(
2101                        phi, q, t, s_d, sum_d, s_e, sum_e, overlap_s, c,
2102                    )
2103                });
2104            }
2105            let n = self.n;
2106            let k = self.n_knots;
2107            let c = self.psi_scale_share;
2108            let mut raw = Array2::<f64>::zeros((n, k));
2109            for i in 0..n {
2110                let base = i * k;
2111                for j in 0..k {
2112                    let idx = base + j;
2113                    let s_d = self.transformed_combo_axis_value_materialized(idx, combo_d);
2114                    let s_e = self.transformed_combo_axis_value_materialized(idx, combo_e);
2115                    let overlap_s =
2116                        self.transformed_combo_overlap_materialized(idx, combo_d, combo_e);
2117                    raw[[i, j]] = Self::transformed_second_kernel_value(
2118                        self.phi_values[idx],
2119                        self.q_values[idx],
2120                        self.t_values[idx],
2121                        s_d,
2122                        sum_d,
2123                        s_e,
2124                        sum_e,
2125                        overlap_s,
2126                        c,
2127                    );
2128                }
2129            }
2130            return Ok(self.project_matrix(raw));
2131        }
2132        if self.is_streaming() {
2133            let c = self.psi_scale_share;
2134            return self.streaming_materialize(|phi, q, t, sb| {
2135                t * sb[axis_d] * sb[axis_e] + c * q * (sb[axis_d] + sb[axis_e]) + c * c * phi
2136            });
2137        }
2138        let n = self.n;
2139        let k = self.n_knots;
2140        let c = self.psi_scale_share;
2141        let mut raw = Array2::<f64>::zeros((n, k));
2142        for i in 0..n {
2143            let base = i * k;
2144            for j in 0..k {
2145                raw[[i, j]] = self.t_values[base + j]
2146                    * self.axis_components[[base + j, axis_d]]
2147                    * self.axis_components[[base + j, axis_e]]
2148                    + c * self.q_values[base + j]
2149                        * (self.axis_components[[base + j, axis_d]]
2150                            + self.axis_components[[base + j, axis_e]])
2151                    + c * c * self.phi_values[base + j];
2152            }
2153        }
2154        Ok(self.project_matrix(raw))
2155    }
2156
2157    /// Project a raw (n × k) kernel-space matrix through all transforms to
2158    /// produce an (n × p_out) matrix: Z_kernel → pad poly → full ident.
2159    pub(crate) fn project_matrix(&self, raw: Array2<f64>) -> Array2<f64> {
2160        // Step 1: kernel constraint projection.
2161        let constrained = match &self.ident_transform {
2162            Some(z) => fast_ab(&raw, z),
2163            None => raw,
2164        };
2165
2166        // Step 2: polynomial padding.
2167        let padded = if self.n_poly > 0 {
2168            let cols = constrained.ncols();
2169            let mut out = Array2::<f64>::zeros((self.n, cols + self.n_poly));
2170            out.slice_mut(s![.., ..cols]).assign(&constrained);
2171            out
2172        } else {
2173            constrained
2174        };
2175
2176        // Step 3: full identifiability transform.
2177        match &self.full_ident_transform {
2178            Some(zf) => fast_ab(&padded, zf),
2179            None => padded,
2180        }
2181    }
2182
2183    pub(crate) fn project_matrix_rows(&self, raw: Array2<f64>) -> Array2<f64> {
2184        let nrows = raw.nrows();
2185        let constrained = match &self.ident_transform {
2186            Some(z) => fast_ab(&raw, z),
2187            None => raw,
2188        };
2189        let padded = if self.n_poly > 0 {
2190            let cols = constrained.ncols();
2191            let mut out = Array2::<f64>::zeros((nrows, cols + self.n_poly));
2192            out.slice_mut(s![.., ..cols]).assign(&constrained);
2193            out
2194        } else {
2195            constrained
2196        };
2197        match &self.full_ident_transform {
2198            Some(zf) => fast_ab(&padded, zf),
2199            None => padded,
2200        }
2201    }
2202
2203    pub(crate) fn row_chunk_with_kernel<G>(
2204        &self,
2205        rows: std::ops::Range<usize>,
2206        deriv_fn: G,
2207    ) -> Result<Array2<f64>, BasisError>
2208    where
2209        G: Fn(f64, f64, f64, &[f64], usize) -> f64,
2210    {
2211        let raw = self.row_chunk_with_kernel_raw(rows, deriv_fn)?;
2212        Ok(self.project_matrix_rows(raw))
2213    }
2214
2215    /// Like `row_chunk_with_kernel` but returns the raw (chunk × n_knots)
2216    /// kernel scalars without the identifiability/padding projection. Used
2217    /// by `forward_mul_matrix`, which does the projection on the rank side
2218    /// instead (`unproject_matrix(F)`) so the (n × p_out) projected
2219    /// derivative is never materialized for large-scale row counts.
2220    pub(crate) fn row_chunk_with_kernel_raw<G>(
2221        &self,
2222        rows: std::ops::Range<usize>,
2223        deriv_fn: G,
2224    ) -> Result<Array2<f64>, BasisError>
2225    where
2226        G: Fn(f64, f64, f64, &[f64], usize) -> f64,
2227    {
2228        let mut raw = Array2::<f64>::zeros((rows.end - rows.start, self.n_knots));
2229        if let Some(st) = self.streaming.as_ref() {
2230            let mut sb = vec![0.0; self.n_axes];
2231            if let Some(cache) = st.ensure_triplet_cache() {
2232                for (local, i) in rows.enumerate() {
2233                    let base = i * self.n_knots;
2234                    for j in 0..self.n_knots {
2235                        let idx = base + j;
2236                        st.fill_s_buf(i, j, &mut sb);
2237                        raw[[local, j]] =
2238                            deriv_fn(cache.phi[idx], cache.q[idx], cache.t[idx], &sb, idx);
2239                    }
2240                }
2241            } else {
2242                for (local, i) in rows.enumerate() {
2243                    for j in 0..self.n_knots {
2244                        let (phi, q, t) = st.compute_pair(i, j, &mut sb)?;
2245                        raw[[local, j]] = deriv_fn(phi, q, t, &sb, i * self.n_knots + j);
2246                    }
2247                }
2248            }
2249        } else {
2250            for (local, i) in rows.enumerate() {
2251                let base = i * self.n_knots;
2252                for j in 0..self.n_knots {
2253                    let idx = base + j;
2254                    raw[[local, j]] = deriv_fn(
2255                        self.phi_values[idx],
2256                        self.q_values[idx],
2257                        self.t_values[idx],
2258                        &[],
2259                        idx,
2260                    );
2261                }
2262            }
2263        }
2264        Ok(raw)
2265    }
2266
2267    pub fn row_chunk_first(
2268        &self,
2269        axis: usize,
2270        rows: std::ops::Range<usize>,
2271    ) -> Result<Array2<f64>, BasisError> {
2272        assert!(
2273            axis < self.n_axes(),
2274            "implicit psi first row chunk axis out of bounds: axis={axis}, n_axes={}",
2275            self.n_axes()
2276        );
2277        let c = self.psi_scale_share;
2278        if self.axis_combinations.is_some() {
2279            let combo = self.transformed_axis_combination(axis);
2280            let combo_sum = Self::transformed_combo_sum(combo);
2281            return self.row_chunk_with_kernel(rows, |phi, q, _, sb, idx| {
2282                let s_combo = if sb.is_empty() {
2283                    self.transformed_combo_axis_value_materialized(idx, combo)
2284                } else {
2285                    combo
2286                        .iter()
2287                        .map(|(raw_axis, coeff)| coeff * sb[*raw_axis])
2288                        .sum()
2289                };
2290                Self::transformed_first_kernel_value(phi, q, s_combo, combo_sum, c)
2291            });
2292        }
2293        self.row_chunk_with_kernel(rows, |phi, q, _, sb, idx| {
2294            let s = if sb.is_empty() {
2295                self.axis_components[[idx, axis]]
2296            } else {
2297                sb[axis]
2298            };
2299            q * s + c * phi
2300        })
2301    }
2302
2303    /// Raw (chunk × n_knots) first-order kernel scalars for axis d, without
2304    /// the identifiability/padding projection. Pairs with `unproject_matrix`
2305    /// in `forward_mul_matrix`: the kernel scalars stay in raw knot space
2306    /// while the rank side (F) is unprojected to knot space, so the per-chunk
2307    /// GEMM is (chunk × n_knots) · (n_knots × rank) rather than (chunk × p_out)
2308    /// · (p_out × rank). Saves both flops and a (chunk × p_out) intermediate.
2309    pub fn row_chunk_first_raw(
2310        &self,
2311        axis: usize,
2312        rows: std::ops::Range<usize>,
2313    ) -> Result<Array2<f64>, BasisError> {
2314        assert!(
2315            axis < self.n_axes(),
2316            "implicit psi first raw row chunk axis out of bounds: axis={axis}, n_axes={}",
2317            self.n_axes()
2318        );
2319        let c = self.psi_scale_share;
2320        if self.axis_combinations.is_some() {
2321            let combo = self.transformed_axis_combination(axis);
2322            let combo_sum = Self::transformed_combo_sum(combo);
2323            return self.row_chunk_with_kernel_raw(rows, |phi, q, _, sb, idx| {
2324                let s_combo = if sb.is_empty() {
2325                    self.transformed_combo_axis_value_materialized(idx, combo)
2326                } else {
2327                    combo
2328                        .iter()
2329                        .map(|(raw_axis, coeff)| coeff * sb[*raw_axis])
2330                        .sum()
2331                };
2332                Self::transformed_first_kernel_value(phi, q, s_combo, combo_sum, c)
2333            });
2334        }
2335        self.row_chunk_with_kernel_raw(rows, |phi, q, _, sb, idx| {
2336            let s = if sb.is_empty() {
2337                self.axis_components[[idx, axis]]
2338            } else {
2339                sb[axis]
2340            };
2341            q * s + c * phi
2342        })
2343    }
2344
2345    pub fn row_chunk_second_diag(
2346        &self,
2347        axis: usize,
2348        rows: std::ops::Range<usize>,
2349    ) -> Result<Array2<f64>, BasisError> {
2350        assert!(
2351            axis < self.n_axes(),
2352            "implicit psi second diagonal row chunk axis out of bounds: axis={axis}, n_axes={}",
2353            self.n_axes()
2354        );
2355        let c = self.psi_scale_share;
2356        if self.axis_combinations.is_some() {
2357            let combo = self.transformed_axis_combination(axis);
2358            let combo_sum = Self::transformed_combo_sum(combo);
2359            return self.row_chunk_with_kernel(rows, |phi, q, t, sb, idx| {
2360                let s_combo = if sb.is_empty() {
2361                    self.transformed_combo_axis_value_materialized(idx, combo)
2362                } else {
2363                    combo
2364                        .iter()
2365                        .map(|(raw_axis, coeff)| coeff * sb[*raw_axis])
2366                        .sum()
2367                };
2368                let overlap = if sb.is_empty() {
2369                    self.transformed_combo_overlap_materialized(idx, combo, combo)
2370                } else {
2371                    Self::transformed_combo_overlap_streaming(combo, combo, sb)
2372                };
2373                Self::transformed_second_kernel_value(
2374                    phi, q, t, s_combo, combo_sum, s_combo, combo_sum, overlap, c,
2375                )
2376            });
2377        }
2378        self.row_chunk_with_kernel(rows, |phi, q, t, sb, idx| {
2379            let s = if sb.is_empty() {
2380                self.axis_components[[idx, axis]]
2381            } else {
2382                sb[axis]
2383            };
2384            2.0 * q * s + t * s * s + 2.0 * c * q * s + c * c * phi
2385        })
2386    }
2387
2388    pub fn row_chunk_second_cross(
2389        &self,
2390        axis_d: usize,
2391        axis_e: usize,
2392        rows: std::ops::Range<usize>,
2393    ) -> Result<Array2<f64>, BasisError> {
2394        assert!(
2395            axis_d < self.n_axes(),
2396            "implicit psi second cross row chunk first axis out of bounds: axis_d={axis_d}, n_axes={}",
2397            self.n_axes()
2398        );
2399        assert!(
2400            axis_e < self.n_axes(),
2401            "implicit psi second cross row chunk second axis out of bounds: axis_e={axis_e}, n_axes={}",
2402            self.n_axes()
2403        );
2404        assert_ne!(
2405            axis_d, axis_e,
2406            "implicit psi second cross row chunk requires distinct axes: axis_d={axis_d}, axis_e={axis_e}"
2407        );
2408        let c = self.psi_scale_share;
2409        if self.axis_combinations.is_some() {
2410            let combo_d = self.transformed_axis_combination(axis_d);
2411            let combo_e = self.transformed_axis_combination(axis_e);
2412            let sum_d = Self::transformed_combo_sum(combo_d);
2413            let sum_e = Self::transformed_combo_sum(combo_e);
2414            return self.row_chunk_with_kernel(rows, |phi, q, t, sb, idx| {
2415                let s_d = if sb.is_empty() {
2416                    self.transformed_combo_axis_value_materialized(idx, combo_d)
2417                } else {
2418                    combo_d
2419                        .iter()
2420                        .map(|(raw_axis, coeff)| coeff * sb[*raw_axis])
2421                        .sum()
2422                };
2423                let s_e = if sb.is_empty() {
2424                    self.transformed_combo_axis_value_materialized(idx, combo_e)
2425                } else {
2426                    combo_e
2427                        .iter()
2428                        .map(|(raw_axis, coeff)| coeff * sb[*raw_axis])
2429                        .sum()
2430                };
2431                let overlap = if sb.is_empty() {
2432                    self.transformed_combo_overlap_materialized(idx, combo_d, combo_e)
2433                } else {
2434                    Self::transformed_combo_overlap_streaming(combo_d, combo_e, sb)
2435                };
2436                Self::transformed_second_kernel_value(phi, q, t, s_d, sum_d, s_e, sum_e, overlap, c)
2437            });
2438        }
2439        self.row_chunk_with_kernel(rows, |phi, q, t, sb, idx| {
2440            let sd = if sb.is_empty() {
2441                self.axis_components[[idx, axis_d]]
2442            } else {
2443                sb[axis_d]
2444            };
2445            let se = if sb.is_empty() {
2446                self.axis_components[[idx, axis_e]]
2447            } else {
2448                sb[axis_e]
2449            };
2450            t * sd * se + c * q * (sd + se) + c * c * phi
2451        })
2452    }
2453
2454    /// Single-row specialization of `row_chunk_first(axis, row..row+1)` that
2455    /// writes the length-`p_out` row directly into the caller-provided buffer.
2456    ///
2457    /// This is the row-local API used by `CustomFamilyPsiLinearMapRef::row_vector`
2458    /// for survival rowwise exact-Hessian paths, which previously applied a
2459    /// unit-vector `transpose_mul` trick (O(n·K) per row) to recover a single
2460    /// row. Avoids allocating a temporary (1 × p_out) matrix per row call.
2461    pub fn row_vector_first_into(
2462        &self,
2463        axis: usize,
2464        row: usize,
2465        mut out: ArrayViewMut1<'_, f64>,
2466    ) -> Result<(), BasisError> {
2467        assert!(
2468            row < self.n,
2469            "implicit psi row-vector request out of bounds: row={row}, n={}",
2470            self.n
2471        );
2472        assert_eq!(
2473            out.len(),
2474            self.p_out(),
2475            "implicit psi row-vector output length mismatch"
2476        );
2477        let chunk = self.row_chunk_first(axis, row..row + 1)?;
2478        out.assign(&chunk.row(0));
2479        Ok(())
2480    }
2481
2482    pub(crate) fn transformed_axis_combination(&self, axis: usize) -> &[(usize, f64)] {
2483        self.axis_combinations
2484            .as_ref()
2485            .expect("transformed axis combinations")
2486            .get(axis)
2487            .map(Vec::as_slice)
2488            .expect("transformed axis index")
2489    }
2490
2491    #[inline]
2492    pub(crate) fn transformed_combo_sum(combo: &[(usize, f64)]) -> f64 {
2493        combo.iter().map(|(_, coeff)| *coeff).sum()
2494    }
2495
2496    #[inline]
2497    pub(crate) fn transformed_combo_axis_value_materialized(
2498        &self,
2499        idx: usize,
2500        combo: &[(usize, f64)],
2501    ) -> f64 {
2502        combo
2503            .iter()
2504            .map(|(raw_axis, coeff)| coeff * self.axis_components[[idx, *raw_axis]])
2505            .sum()
2506    }
2507
2508    #[inline]
2509    pub(crate) fn transformed_combo_overlap_streaming(
2510        combo_left: &[(usize, f64)],
2511        combo_right: &[(usize, f64)],
2512        sb: &[f64],
2513    ) -> f64 {
2514        let mut overlap = 0.0;
2515        for &(left_axis, left_coeff) in combo_left {
2516            for &(right_axis, right_coeff) in combo_right {
2517                if left_axis == right_axis {
2518                    overlap += left_coeff * right_coeff * sb[left_axis];
2519                }
2520            }
2521        }
2522        overlap
2523    }
2524
2525    #[inline]
2526    pub(crate) fn transformed_combo_overlap_materialized(
2527        &self,
2528        idx: usize,
2529        combo_left: &[(usize, f64)],
2530        combo_right: &[(usize, f64)],
2531    ) -> f64 {
2532        let mut overlap = 0.0;
2533        for &(left_axis, left_coeff) in combo_left {
2534            for &(right_axis, right_coeff) in combo_right {
2535                if left_axis == right_axis {
2536                    overlap += left_coeff * right_coeff * self.axis_components[[idx, left_axis]];
2537                }
2538            }
2539        }
2540        overlap
2541    }
2542
2543    #[inline]
2544    pub(crate) fn transformed_first_kernel_value(
2545        phi: f64,
2546        q: f64,
2547        s_combo: f64,
2548        coeff_sum: f64,
2549        psi_scale_share: f64,
2550    ) -> f64 {
2551        q * s_combo + psi_scale_share * coeff_sum * phi
2552    }
2553
2554    #[inline]
2555    pub(crate) fn transformed_second_kernel_value(
2556        phi: f64,
2557        q: f64,
2558        t: f64,
2559        s_left: f64,
2560        left_sum: f64,
2561        s_right: f64,
2562        right_sum: f64,
2563        overlap_s: f64,
2564        psi_scale_share: f64,
2565    ) -> f64 {
2566        t * s_left * s_right
2567            + 2.0 * q * overlap_s
2568            + psi_scale_share * q * (right_sum * s_left + left_sum * s_right)
2569            + psi_scale_share * psi_scale_share * left_sum * right_sum * phi
2570    }
2571}
2572
2573pub(crate) fn build_aniso_design_psi_derivatives_shared(
2574    data: ArrayView2<'_, f64>,
2575    centers: ArrayView2<'_, f64>,
2576    eta: &[f64],
2577    p_final: usize,
2578    ident_transform: Option<Array2<f64>>,
2579    full_ident_transform: Option<Array2<f64>>,
2580    n_poly: usize,
2581    radial_kind: RadialScalarKind,
2582) -> Result<AnisoBasisPsiDerivatives, BasisError> {
2583    let n = data.nrows();
2584    let k = centers.nrows();
2585    let dim = data.ncols();
2586    if eta.len() != dim {
2587        crate::bail_dim_basis!(
2588            "aniso design derivatives: eta.len()={} != data dimension {dim}",
2589            eta.len()
2590        );
2591    }
2592
2593    let policy = gam_runtime::resource::ResourcePolicy::default_library();
2594    let force_operator = radial_kind.is_duchon_family();
2595    let dense_derivatives_exceed_budget =
2596        should_use_implicit_operators_with_policy(n, p_final, dim, &policy);
2597    let operator_only = force_operator || dense_derivatives_exceed_budget;
2598    let cache_radial_components = should_cache_implicit_radial_components(n, k, dim, &policy);
2599    // gam#1376 — the per-axis ψ derivatives this operator produces are ALREADY
2600    // the derivatives w.r.t. the κ-optimizer's raw coordinate, so NO cross-axis
2601    // centering projection is installed (for any family). The optimizer's per-
2602    // axis coordinate `psi_a` is decoded into both the global length scale
2603    // `ℓ = exp(−mean(psi))` and the centered contrast `eta_a = psi_a − mean(psi)`
2604    // simultaneously; in the kernel argument `x² = r²/ℓ² = Σ_a exp(2·psi_a)·h_a²`
2605    // the `mean(psi)` cancels, so the effective per-axis exponent is the raw
2606    // `psi_a` and `∂φ/∂psi_a = q·s_a` is the native per-axis ψ derivative. The
2607    // earlier `with_raw_eta_centering` projection annihilated the all-ones
2608    // (global-scale) direction and broke the analytic↔FD match (rel≈0.85). The
2609    // dense path (`build_matern_basis_log_kappa_aniso_derivatives`) is corrected
2610    // identically — it no longer centers downstream.
2611
2612    // ── Streaming path: large scale ─────────────────────────────────────
2613    // When even the compact radial cache would exceed the operator-cache
2614    // budget, store only data/centers/eta/radial_kind and recompute
2615    // (q, t, s_a) chunkwise during each matvec. Otherwise the operator-only
2616    // path below caches phi/q/t/s_a without materializing dense derivative
2617    // matrices.
2618    if operator_only && !cache_radial_components {
2619        let op = ImplicitDesignPsiDerivative::new_streaming(
2620            shared_owned_data_matrix_from_view(data),
2621            shared_owned_centers_matrix_from_view(centers),
2622            eta.to_vec(),
2623            radial_kind,
2624            ident_transform,
2625            full_ident_transform,
2626            n_poly,
2627        );
2628        return Ok(AnisoBasisPsiDerivatives {
2629            design_first: Vec::new(),
2630            design_second_diag: Vec::new(),
2631            design_second_cross: Vec::new(),
2632            design_second_cross_pairs: Vec::new(),
2633            penalties_first: vec![Vec::new(); dim],
2634            penalties_second_diag: vec![Vec::new(); dim],
2635            penalties_cross_pairs: Vec::new(),
2636            penalties_cross_provider: None,
2637            implicit_operator: Some(op),
2638        });
2639    }
2640
2641    // ── Materialized radial-cache path ────────────────────────────────────
2642    // Allocate O(n*k) arrays up front and fill with parallel chunks that
2643    // write directly into preallocated storage via raw pointers. No
2644    // intermediate Vec<(i, q_row, t_row, s_row)> collection.
2645    let nk = n.checked_mul(k).ok_or_else(|| {
2646        BasisError::InvalidInput("aniso radial cache has too many data-center pairs".to_string())
2647    })?;
2648    if nk.checked_mul(dim).is_none() {
2649        crate::bail_invalid_basis!("aniso radial cache axis component storage is too large");
2650    }
2651    let mut phi_values = Array1::<f64>::zeros(nk);
2652    let mut q_values = Array1::<f64>::zeros(nk);
2653    let mut t_values = Array1::<f64>::zeros(nk);
2654    let mut axis_components = Array2::<f64>::zeros((nk, dim));
2655
2656    let psi_scale_share = radial_kind.raw_psi_isotropic_share();
2657
2658    let cs = IMPLICIT_MATVEC_CHUNK_SIZE;
2659    let nc = n.div_ceil(cs);
2660    // Capture the *first* underlying radial-evaluation error rather than a
2661    // bare boolean: at an extreme trial hyperparameter the anisotropic
2662    // distance `r` can push the Duchon/Matérn radial kernel out of its
2663    // evaluable range, and the caller (the spatial-κ optimizer) needs the
2664    // real cause to decide whether the trial point is merely infeasible
2665    // (retreat) versus a genuine invariant violation (abort). Swallowing it
2666    // as "radial scalar evaluation failed" hid both the cause and the
2667    // recoverability.
2668    let first_err: std::sync::Mutex<Option<BasisError>> = std::sync::Mutex::new(None);
2669    // For large sweeps, replace per-pair exact radial evaluation with a
2670    // certified 1-D Chebyshev profile built once from a distance-only
2671    // pre-pass over the radius range (see `radial_profile`): at the 16-D
2672    // power-9 hybrid Duchon configuration a single exact triplet costs tens
2673    // of microseconds across its partial-fraction blocks, and this n·k
2674    // sweep was the dominant per-κ-trial cost of large-scale fits (#979).
2675    // Out-of-range radii and uncertified builds fall back to the exact
2676    // evaluator per pair.
2677    let profile = if nk >= RADIAL_PROFILE_MIN_PAIRS {
2678        let mut r_lo = f64::INFINITY;
2679        let mut r_hi = 0.0_f64;
2680        let mut drb = vec![0.0; dim];
2681        let mut cb = vec![0.0; dim];
2682        for i in 0..n {
2683            for a in 0..dim {
2684                drb[a] = data[[i, a]];
2685            }
2686            for j in 0..k {
2687                for a in 0..dim {
2688                    cb[a] = centers[[j, a]];
2689                }
2690                let (r, _) = aniso_distance_and_components(&drb, &cb, eta);
2691                if r > 0.0 {
2692                    r_lo = r_lo.min(r);
2693                    r_hi = r_hi.max(r);
2694                }
2695            }
2696        }
2697        if r_lo.is_finite() && r_hi > r_lo {
2698            radial_profile::RadialProfile::build(&radial_kind, r_lo, r_hi)
2699        } else {
2700            None
2701        }
2702    } else {
2703        None
2704    };
2705    {
2706        let pp = SendPtr(phi_values.as_mut_ptr());
2707        let qp = SendPtr(q_values.as_mut_ptr());
2708        let tp = SendPtr(t_values.as_mut_ptr());
2709        let ap = SendPtr(axis_components.as_mut_ptr());
2710        let ferr = &first_err;
2711        let profile_ref = profile.as_ref();
2712        (0..nc).into_par_iter().for_each(move |ci| {
2713            let start = ci * cs;
2714            let end = start.saturating_add(cs).min(n);
2715            let mut drb = vec![0.0; dim];
2716            let mut cb = vec![0.0; dim];
2717            for i in start..end {
2718                for a in 0..dim {
2719                    drb[a] = data[[i, a]];
2720                }
2721                for j in 0..k {
2722                    for a in 0..dim {
2723                        cb[a] = centers[[j, a]];
2724                    }
2725                    let (r, sv) = aniso_distance_and_components(&drb, &cb, eta);
2726                    let triplet = match profile_ref {
2727                        Some(profile) => profile.eval_or_exact(&radial_kind, r),
2728                        None => radial_kind.eval_design_triplet(r),
2729                    };
2730                    let (phi, q, t) = match triplet {
2731                        Ok(p) => p,
2732                        Err(e) => {
2733                            let mut slot = ferr.lock().unwrap_or_else(|p| p.into_inner());
2734                            if slot.is_none() {
2735                                *slot = Some(e);
2736                            }
2737                            return;
2738                        }
2739                    };
2740                    let flat = i * k + j;
2741                    // SAFETY: each Rayon chunk owns a disjoint i-row range,
2742                    // so flat=i*k+j stays in 0..nk for phi/q/t and
2743                    // flat*dim+a stays in 0..nk*dim for axis_components.
2744                    unsafe {
2745                        *pp.add(flat) = phi;
2746                        *qp.add(flat) = q;
2747                        *tp.add(flat) = t;
2748                        for a in 0..dim {
2749                            *ap.add(flat * dim + a) = sv[a];
2750                        }
2751                    }
2752                }
2753            }
2754        });
2755    }
2756    if let Some(cause) = first_err.into_inner().unwrap_or_else(|p| p.into_inner()) {
2757        return Err(BasisError::InvalidInput(format!(
2758            "radial scalar evaluation failed during aniso derivative construction \
2759             (eta={eta:?}): {cause}"
2760        )));
2761    }
2762
2763    let op = ImplicitDesignPsiDerivative::new(
2764        phi_values,
2765        q_values,
2766        t_values,
2767        axis_components,
2768        ident_transform,
2769        full_ident_transform,
2770        n,
2771        k,
2772        n_poly,
2773        dim,
2774    )
2775    .with_psi_scale_share(psi_scale_share);
2776
2777    // gam#1376 — the operator stays in the NATIVE per-axis ψ frame (no
2778    // `with_raw_eta_centering`): the κ-optimizer coordinate `psi_a` already maps
2779    // to the effective per-axis exponent `psi_a` of the kernel argument (the
2780    // `mean(psi)` it injects into the centered contrast is exactly cancelled by
2781    // the `ℓ = exp(−mean(psi))` it injects into the length scale), so the native
2782    // `∂φ/∂psi_a` produced by `materialize_first`/`materialize_second_*` (and by
2783    // the operator matvecs) is the correct raw-coordinate derivative. The
2784    // earlier centering broke the analytic↔FD match — see the comment above.
2785
2786    if operator_only {
2787        return Ok(AnisoBasisPsiDerivatives {
2788            design_first: Vec::new(),
2789            design_second_diag: Vec::new(),
2790            design_second_cross: Vec::new(),
2791            design_second_cross_pairs: Vec::new(),
2792            penalties_first: vec![Vec::new(); dim],
2793            penalties_second_diag: vec![Vec::new(); dim],
2794            penalties_cross_pairs: Vec::new(),
2795            penalties_cross_provider: None,
2796            implicit_operator: Some(op),
2797        });
2798    }
2799
2800    let design_first = (0..dim)
2801        .map(|a| op.materialize_first(a))
2802        .collect::<Result<Vec<_>, _>>()?;
2803    let design_second_diag = (0..dim)
2804        .map(|a| op.materialize_second_diag(a))
2805        .collect::<Result<Vec<_>, _>>()?;
2806
2807    Ok(AnisoBasisPsiDerivatives {
2808        design_first,
2809        design_second_diag,
2810        design_second_cross: Vec::new(),
2811        design_second_cross_pairs: Vec::new(),
2812        penalties_first: vec![Vec::new(); dim],
2813        penalties_second_diag: vec![Vec::new(); dim],
2814        penalties_cross_pairs: Vec::new(),
2815        penalties_cross_provider: None,
2816        implicit_operator: Some(op),
2817    })
2818}
2819
/// Result of `build_scalar_design_psi_derivatives_shared`: design
/// derivatives with respect to a single scalar ψ coordinate.
#[derive(Debug, Clone)]
pub(crate) struct ScalarDesignPsiDerivatives {
    /// Dense first ψ-derivative of the design (`materialize_first(0)` of the
    /// operator). The builder stores a 0×0 placeholder here when it returns
    /// the operator only.
    pub(crate) design_first: Array2<f64>,
    /// Dense second ψ-derivative of the design (`materialize_second_diag(0)`
    /// of the operator). The builder stores a 0×0 placeholder here when it
    /// returns the operator only.
    pub(crate) design_second_diag: Array2<f64>,
    /// Implicit operator over the same derivatives. The builder in this file
    /// always sets it to `Some`.
    pub(crate) implicit_operator: Option<ImplicitDesignPsiDerivative>,
}
2826
2827pub(crate) fn build_scalar_design_psi_derivatives_shared(
2828    data: ArrayView2<'_, f64>,
2829    centers: ArrayView2<'_, f64>,
2830    fixed_eta: Option<&[f64]>,
2831    p_final: usize,
2832    ident_transform: Option<Array2<f64>>,
2833    full_ident_transform: Option<Array2<f64>>,
2834    n_poly: usize,
2835    radial_kind: RadialScalarKind,
2836    psi_scale_share: f64,
2837) -> Result<ScalarDesignPsiDerivatives, BasisError> {
2838    let n = data.nrows();
2839    let k = centers.nrows();
2840    let dim = data.ncols();
2841    if let Some(eta) = fixed_eta
2842        && eta.len() != dim
2843    {
2844        crate::bail_dim_basis!(
2845            "scalar design derivatives: eta.len()={} != data dimension {dim}",
2846            eta.len()
2847        );
2848    }
2849
2850    let policy = gam_runtime::resource::ResourcePolicy::default_library();
2851    let force_operator = radial_kind.is_duchon_family();
2852    let dense_derivatives_exceed_budget =
2853        should_use_implicit_operators_with_policy(n, p_final, 1, &policy);
2854    let operator_only = force_operator || dense_derivatives_exceed_budget;
2855    let cache_radial_components = should_cache_implicit_radial_components(n, k, 1, &policy);
2856    if operator_only && !cache_radial_components {
2857        let metric_eta = fixed_eta
2858            .map(|eta| eta.to_vec())
2859            .unwrap_or_else(|| vec![0.0; dim]);
2860        let op = ImplicitDesignPsiDerivative::new_streaming_scalar(
2861            shared_owned_data_matrix_from_view(data),
2862            shared_owned_centers_matrix_from_view(centers),
2863            metric_eta,
2864            radial_kind,
2865            ident_transform,
2866            full_ident_transform,
2867            n_poly,
2868        )
2869        .with_psi_scale_share(psi_scale_share);
2870        return Ok(ScalarDesignPsiDerivatives {
2871            design_first: Array2::<f64>::zeros((0, 0)),
2872            design_second_diag: Array2::<f64>::zeros((0, 0)),
2873            implicit_operator: Some(op),
2874        });
2875    }
2876
2877    let nk = n.checked_mul(k).ok_or_else(|| {
2878        BasisError::InvalidInput("scalar radial cache has too many data-center pairs".to_string())
2879    })?;
2880    let mut phi_values = Array1::<f64>::zeros(nk);
2881    let mut q_values = Array1::<f64>::zeros(nk);
2882    let mut t_values = Array1::<f64>::zeros(nk);
2883    let mut axis_components = Array2::<f64>::zeros((nk, 1));
2884
2885    let cs = IMPLICIT_MATVEC_CHUNK_SIZE;
2886    let nc = n.div_ceil(cs);
2887    let first_err: std::sync::Mutex<Option<BasisError>> = std::sync::Mutex::new(None);
2888    // Same certified radial-profile amortization as the per-axis sweep
2889    // above: one distance-only pre-pass for the radius range, one profile
2890    // build, Clenshaw per pair, exact fallback out of range (#979).
2891    let pair_r = |i: usize, j: usize, drb: &mut [f64], cb: &mut [f64]| -> f64 {
2892        if let Some(eta) = fixed_eta {
2893            for a in 0..dim {
2894                drb[a] = data[[i, a]];
2895                cb[a] = centers[[j, a]];
2896            }
2897            aniso_distance_and_components(drb, cb, eta).0
2898        } else {
2899            stable_euclidean_norm((0..dim).map(|a| data[[i, a]] - centers[[j, a]]))
2900        }
2901    };
2902    let profile = if nk >= RADIAL_PROFILE_MIN_PAIRS {
2903        let mut r_lo = f64::INFINITY;
2904        let mut r_hi = 0.0_f64;
2905        let mut drb = vec![0.0; dim];
2906        let mut cb = vec![0.0; dim];
2907        for i in 0..n {
2908            for j in 0..k {
2909                let r = pair_r(i, j, &mut drb, &mut cb);
2910                if r > 0.0 {
2911                    r_lo = r_lo.min(r);
2912                    r_hi = r_hi.max(r);
2913                }
2914            }
2915        }
2916        if r_lo.is_finite() && r_hi > r_lo {
2917            radial_profile::RadialProfile::build(&radial_kind, r_lo, r_hi)
2918        } else {
2919            None
2920        }
2921    } else {
2922        None
2923    };
2924    {
2925        let pp = SendPtr(phi_values.as_mut_ptr());
2926        let qp = SendPtr(q_values.as_mut_ptr());
2927        let tp = SendPtr(t_values.as_mut_ptr());
2928        let ap = SendPtr(axis_components.as_mut_ptr());
2929        let ferr = &first_err;
2930        let profile_ref = profile.as_ref();
2931        (0..nc).into_par_iter().for_each(move |ci| {
2932            let start = ci * cs;
2933            let end = start.saturating_add(cs).min(n);
2934            let mut data_row_buf = vec![0.0; dim];
2935            let mut center_buf = vec![0.0; dim];
2936            for i in start..end {
2937                for a in 0..dim {
2938                    data_row_buf[a] = data[[i, a]];
2939                }
2940                for j in 0..k {
2941                    let (r, scalar_component) = if let Some(eta) = fixed_eta {
2942                        for a in 0..dim {
2943                            center_buf[a] = centers[[j, a]];
2944                        }
2945                        let (r, components) =
2946                            aniso_distance_and_components(&data_row_buf, &center_buf, eta);
2947                        (r, components.into_iter().sum::<f64>())
2948                    } else {
2949                        let r =
2950                            stable_euclidean_norm((0..dim).map(|a| data[[i, a]] - centers[[j, a]]));
2951                        (r, r * r)
2952                    };
2953                    let triplet = match profile_ref {
2954                        Some(profile) => profile.eval_or_exact(&radial_kind, r),
2955                        None => radial_kind.eval_design_triplet(r),
2956                    };
2957                    let (phi, q, t) = match triplet {
2958                        Ok(p) => p,
2959                        Err(e) => {
2960                            let mut slot = ferr.lock().unwrap_or_else(|p| p.into_inner());
2961                            if slot.is_none() {
2962                                *slot = Some(e);
2963                            }
2964                            return;
2965                        }
2966                    };
2967                    let flat = i * k + j;
2968                    // SAFETY: each Rayon chunk owns a disjoint i-row range
2969                    // of the nk-long phi/q/t/axis buffers, so flat=i*k+j is
2970                    // in-bounds for every write and never aliases another worker.
2971                    unsafe {
2972                        *pp.add(flat) = phi;
2973                        *qp.add(flat) = q;
2974                        *tp.add(flat) = t;
2975                        *ap.add(flat) = scalar_component;
2976                    }
2977                }
2978            }
2979        });
2980    }
2981    if let Some(cause) = first_err.into_inner().unwrap_or_else(|p| p.into_inner()) {
2982        return Err(BasisError::InvalidInput(format!(
2983            "radial scalar evaluation failed during scalar derivative construction: {cause}"
2984        )));
2985    }
2986
2987    let op = ImplicitDesignPsiDerivative::new(
2988        phi_values,
2989        q_values,
2990        t_values,
2991        axis_components,
2992        ident_transform,
2993        full_ident_transform,
2994        n,
2995        k,
2996        n_poly,
2997        1,
2998    )
2999    .with_psi_scale_share(psi_scale_share);
3000
3001    if operator_only {
3002        return Ok(ScalarDesignPsiDerivatives {
3003            design_first: Array2::<f64>::zeros((0, 0)),
3004            design_second_diag: Array2::<f64>::zeros((0, 0)),
3005            implicit_operator: Some(op),
3006        });
3007    }
3008
3009    Ok(ScalarDesignPsiDerivatives {
3010        design_first: op.materialize_first(0)?,
3011        design_second_diag: op.materialize_second_diag(0)?,
3012        implicit_operator: Some(op),
3013    })
3014}