Skip to main content

gam_models/transformation_normal/
custom_family.rs

1use super::*;
2
3// ---------------------------------------------------------------------------
4// CustomFamily implementation
5// ---------------------------------------------------------------------------
6
7impl CustomFamily for TransformationNormalFamily {
8    fn evaluate(&self, block_states: &[ParameterBlockState]) -> Result<FamilyEvaluation, String> {
9        crate::block_layout::block_count::validate_block_count::<
10            TransformationNormalError,
11        >("TransformationNormalFamily", 1, block_states.len())?;
12        let evaluate_start = std::time::Instant::now();
13        let beta = &block_states[0].beta;
14        let row_q_start = std::time::Instant::now();
15        let row_quantities = self.row_quantities(beta)?;
16        log::info!(
17            "[STAGE] CTN row_quantities (h, h', 1/h', powers) n={} elapsed={:.3}s",
18            row_quantities.h.len(),
19            row_q_start.elapsed().as_secs_f64(),
20        );
21        let h = row_quantities.h.as_ref();
22        let n = h.len();
23
24        let log_likelihood = row_quantities.log_likelihood;
25        // SCOP gradient and exact negative Hessian. Response column 0 is the
26        // linear location component b(x); response columns >=1 are squared
27        // γ_k(x)^2 shape components.
28        let grad_start = std::time::Instant::now();
29        let (grad, hessian) = self.scop_gradient_and_negative_hessian(beta, &row_quantities)?;
30        log::info!(
31            "[STAGE] CTN gradient terms n={} p={} elapsed={:.3}s",
32            n,
33            grad.len(),
34            grad_start.elapsed().as_secs_f64(),
35        );
36
37        let hess_start = std::time::Instant::now();
38        let p_dim = hessian.nrows() as u64;
39        let n_u64 = n as u64;
40        log::info!(
41            "[STAGE] CTN hessian terms (SCOP exact dense) n={} p={} flops~{} elapsed={:.3}s",
42            n,
43            p_dim,
44            n_u64.saturating_mul(p_dim).saturating_mul(p_dim),
45            hess_start.elapsed().as_secs_f64(),
46        );
47        log::info!(
48            "[STAGE] CTN evaluate end n={} p={} elapsed={:.3}s",
49            n,
50            p_dim,
51            evaluate_start.elapsed().as_secs_f64(),
52        );
53
54        Ok(FamilyEvaluation {
55            log_likelihood,
56            blockworking_sets: vec![BlockWorkingSet::ExactNewton {
57                gradient: grad,
58                hessian: SymmetricMatrix::Dense(hessian),
59            }],
60        })
61    }
62
63    fn log_likelihood_only(&self, block_states: &[ParameterBlockState]) -> Result<f64, String> {
64        crate::block_layout::block_count::validate_block_count::<
65            TransformationNormalError,
66        >("TransformationNormalFamily", 1, block_states.len())?;
67        // The line search uses NEG_INFINITY as the barrier-violation signal,
68        // so we can't propagate the row_quantities Err here. Translate any
69        // h' validation failure back into the NEG_INFINITY rejection contract.
70        let row_quantities = match self.row_quantities(&block_states[0].beta) {
71            Ok(rq) => rq,
72            Err(_) => return Ok(f64::NEG_INFINITY),
73        };
74        Ok(row_quantities.log_likelihood)
75    }
76
77    fn log_likelihood_only_with_options(
78        &self,
79        block_states: &[ParameterBlockState],
80        options: &BlockwiseFitOptions,
81    ) -> Result<f64, String> {
82        // When an outer-score subsample is installed, route through a
83        // mask-aware family clone whose `effective_weights()` returns the
84        // HT-weighted per-row weights. Because every term inside
85        // `build_transformation_row_derived` is linear in `wᵢ`, the row-LL
86        // accumulator yields `Σᵢ (mᵢ · wᵢ) · row_ll_i` — the unbiased
87        // Horvitz-Thompson estimator of the full-data LL.
88        match self.maybe_with_outer_subsample_from_options(options) {
89            Ok(Some(masked)) => masked.log_likelihood_only(block_states),
90            Ok(None) => self.log_likelihood_only(block_states),
91            Err(e) => Err(e.into()),
92        }
93    }
94
95    /// Log-likelihood + flat joint gradient without building the dense Hessian.
96    ///
97    /// The default trait implementation returns `None`, so the joint-Newton
98    /// inner solver falls back to `evaluate()` to obtain the gradient — and
99    /// that side-effects a full `Θ(n p²)` `weighted_gram` Hessian build at
100    /// every inner iteration. CTN's gradient is structurally
101    ///
102    ///   `∇ℓ = -X_val^T (w·h) + X_deriv^T (w/h')`,
103    ///
104    /// which is two `transpose_mul`s through the existing Khatri-Rao operators
105    /// and one `Θ(n)` row reduction — `Θ(n p)` total. At large scale that is
106    /// ~10⁷ FLOPs per call versus ~3·10¹⁰ for the full `evaluate`, so wiring
107    /// this override is the gating condition for routing CTN's inner solve
108    /// through the matrix-free joint-Newton path without paying the dense H
109    /// tax on every gradient refresh.
110    fn exact_newton_joint_gradient_evaluation(
111        &self,
112        block_states: &[ParameterBlockState],
113        _: &[ParameterBlockSpec],
114    ) -> Result<Option<ExactNewtonJointGradientEvaluation>, String> {
115        crate::block_layout::block_count::validate_block_count::<
116            TransformationNormalError,
117        >("TransformationNormalFamily", 1, block_states.len())?;
118        let beta = &block_states[0].beta;
119        let row_quantities = self.row_quantities(beta)?;
120        let log_likelihood = row_quantities.log_likelihood;
121        let gradient = self.scop_gradient(beta, &row_quantities)?;
122        Ok(Some(ExactNewtonJointGradientEvaluation {
123            log_likelihood,
124            gradient,
125        }))
126    }
127
128    fn exact_newton_joint_hessian_beta_dependent(&self) -> bool {
129        // The Hessian depends on β through 1/h'² where h' = X_deriv · β.
130        true
131    }
132
133    fn joint_jeffreys_term_required(&self) -> bool {
134        // CTN models a continuous response through a monotone transformation
135        // `h(Y|x) ~ N(0,1)`; there is no separation/under-identification
136        // regime to bound. The Fisher information is `O(n)` on every
137        // identified direction at every working point, so the conditioning
138        // gate inside `joint_jeffreys_term` smooth-steps the contribution to
139        // zero as soon as `λ_min ≥ 16`. The construction up to that gate is
140        // not free though: each evaluation runs `p` SCOP directional
141        // derivatives of the joint Hessian, called three times per inner
142        // cycle (head-KKT gradient, joint Newton step RHS, post-step KKT
143        // residual) and once per outer evaluation. At large scale —
144        // `bench/large_scale` `rust_margslope_aniso_duchon16d_*` with
145        // `p=144`, `n=20000` — that single source dominates each inner
146        // cycle (~230 s/cycle observed in CI; ~5 700 cycles × 5.7 min
147        // ⇒ multi-hour hang) and exhausts the 40-minute CI budget before
148        // the inner solve converges. Disabling the term here keeps the
149        // un-augmented inner Newton path (still consistent with the
150        // outer LAML logdet, which also drops the `H_Φ` contribution
151        // through the same family gate).
152        false
153    }
154
155    fn coefficient_hessian_cost(&self, specs: &[ParameterBlockSpec]) -> u64 {
156        // Khatri–Rao tensor design: the coefficient block is X = R ⊙ C with
157        // rows length p_resp · p_cov. Two regimes:
158        //
159        // * **Dense regime** (small enough that the unified evaluator builds
160        //   `weighted_gram` directly): per-evaluation cost is the dense
161        //   `n · (p_resp · p_cov)²` Khatri–Rao gram build.
162        //
163        // * **Matrix-free regime** (large enough that
164        //   `use_joint_matrix_free_path` returns true and the evaluator
165        //   factors `H v` through `forward_mul` / `transpose_mul` on the
166        //   Khatri–Rao operands): per-`Hv` matvec cost is just
167        //   `n · (p_resp + p_cov)` flops — see `ctn_matrix_free_workspace`.
168        //   This is only an inner coefficient-space cost estimate; outer
169        //   θθ Hessian availability is declared separately.
170        let n_usize = self.response_val_basis.nrows();
171        let p_resp = self.response_val_basis.ncols() as u64;
172        let p_cov = self.covariate_design.ncols() as u64;
173        let expected_p_total = p_resp.saturating_mul(p_cov);
174        // Block-spec preview is optional. Callers without an assembled
175        // ParameterBlockSpec — cost estimators, planners, the
176        // BlockwiseFitOptions screen, every code path that asks "how
177        // expensive would the Hessian be on *this* family?" — pass `&[]`.
178        // The Khatri–Rao layout is fully determined by `p_resp · p_cov`
179        // from the family state, so fall back to `expected_p_total` for
180        // the empty-specs preview rather than returning the `u64::MAX`
181        // unreachable sentinel that would dominate every cost comparison.
182        // When specs IS supplied we still enforce the structural
183        // expectation `spec.design.ncols() == p_resp · p_cov`; a mismatch
184        // is the only condition that legitimately surfaces the sentinel.
185        let p_total = match specs {
186            [] => expected_p_total,
187            [spec] if spec.design.ncols() as u64 == expected_p_total => spec.design.ncols() as u64,
188            _ => return u64::MAX,
189        };
190        let n = n_usize as u64;
191        // Shared operator-aware gate (see `coefficient_cost`): matrix-free Hv
192        // streams the Khatri–Rao operands at `n · (p_resp + p_cov)`; the dense
193        // fallback is the `n · p_total²` Khatri–Rao gram build. The dense count
194        // is supplied inline rather than via `joint_coupled_coefficient_hessian_cost`
195        // because the empty-specs preview must still report `n · p_total²` from
196        // the family-derived `p_total`, not the `n · 0²` an empty `specs` sum yields.
197        crate::coefficient_cost::operator_aware_hessian_cost(
198            p_total,
199            n,
200            n.saturating_mul(p_resp.saturating_add(p_cov)),
201            n.saturating_mul(p_total.saturating_mul(p_total)),
202        )
203    }
204
205    fn coefficient_gradient_cost(&self, specs: &[ParameterBlockSpec]) -> u64 {
206        // One row-quantity pass plus two transpose products. The SCOP derivative
207        // is structurally positive, so coefficient line searches no longer run a
208        // full derivative-grid fraction-to-boundary scan on every attempt.
209        self.coefficient_hessian_cost(specs) / 2
210    }
211
212    fn outer_derivative_policy(
213        &self,
214        specs: &[crate::custom_family::ParameterBlockSpec],
215        psi_dim: usize,
216        options: &crate::custom_family::BlockwiseFitOptions,
217    ) -> crate::custom_family::OuterDerivativePolicy {
218        // The generic default model in `CustomFamily::outer_derivative_policy`
219        // uses `coefficient_hessian_cost × (rho_dim + psi_dim)`, which
220        // overstates CTN's actual per-eval Hessian work because the SCOP
221        // joint-Hessian path is row-streaming through the Khatri-Rao jet
222        // (its `O(n · p)` matrix-free HVP, not `O(n · p²)` dense build).
223        // Use a CTN-specific shape:
224        //
225        // * gradient ≈ `n · (rho_dim + psi_dim) · p_total`
226        //   (one directional jet sweep per outer coordinate, row-streamed)
227        // * Hessian  ≈ min(dense build, matrix-free HVP loop)
228        //   * dense  ≈ `n · (rho_dim + psi_dim) · p_total^2`
229        //   * mfree  ≈ `n · (rho_dim + psi_dim) · p_total · rho_dim`
230        let capability = self.exact_outer_derivative_order(specs, options);
231        let n = specs.first().map_or(0u128, |s| s.design.nrows() as u128);
232        let p_total: u128 = specs
233            .iter()
234            .map(|s| s.design.ncols() as u128)
235            .fold(0u128, |acc, x| acc.saturating_add(x));
236        let rho_dim: u128 = specs
237            .iter()
238            .map(|s| s.penalties.len() as u128)
239            .fold(0u128, |acc, x| acc.saturating_add(x));
240        let k = rho_dim.saturating_add(psi_dim as u128).max(1);
241        let p_eff = p_total.max(1);
242        // Gradient work: one row sweep per outer coordinate.
243        let work_grad = n.saturating_mul(k).saturating_mul(p_eff);
244        // Hessian work: pick whichever access shape would dominate. The
245        // amortization gate in `should_build_dense` (P2.2) picks the
246        // cheaper path at execution time; the policy budget mirrors that
247        // by taking the min so that genuinely Hessian-prohibitive
248        // problems still downgrade through the budget ceiling.
249        let dense_hess = work_grad.saturating_mul(p_eff);
250        let mfree_hess = work_grad.saturating_mul(rho_dim.max(1));
251        let work_hess = dense_hess.min(mfree_hess);
252        crate::custom_family::OuterDerivativePolicy {
253            capability,
254            predicted_hessian_work: work_hess,
255            predicted_gradient_work: work_grad,
256            // CTN's outer-score reductions are mathematically per-row
257            // sums whose contributions are linear in `wᵢ` at every assembly
258            // site (gradient, joint Hessian dense / matvec / diagonal, ψ,
259            // ψ-ψ, log-likelihood). The `_with_options` overrides install a
260            // mask-aware family clone whose `effective_weights()` returns
261            // `wᵢ · mᵢ` (HT-weighted), yielding an unbiased estimator
262            // `E[score_subsample] = score_full`. The persistent
263            // dense-Hessian cache is keyed on the mask hash so subsampled
264            // and full-data builds at the same β do not alias.
265            subsample_capable: true,
266        }
267    }
268
269    fn outer_seed_config(&self, n_params: usize) -> gam_solve::seeding::SeedConfig {
270        gam_solve::seeding::SeedConfig {
271            bounds: (-12.0, 12.0),
272            max_seeds: if n_params <= 8 { 1 } else { 2 },
273            seed_budget: 1,
274            screen_max_inner_iterations: 2,
275            risk_profile: gam_solve::seeding::SeedRiskProfile::Gaussian,
276            num_auxiliary_trailing: 0,
277            over_smoothing_probe_rho: None,
278        }
279    }
280
281    fn max_feasible_step_size(
282        &self,
283        block_states: &[ParameterBlockState],
284        block_index: usize,
285        delta: &Array1<f64>,
286    ) -> Result<Option<f64>, String> {
287        if block_index != 0 {
288            return Ok(None);
289        }
290        crate::block_layout::block_count::validate_block_count::<
291            TransformationNormalError,
292        >("TransformationNormalFamily", 1, block_states.len())?;
293        if delta.len() != block_states[0].beta.len() {
294            return Err(TransformationNormalError::InvalidInput {
295                reason: format!(
296                    "CTN line-search step length {} != beta length {}",
297                    delta.len(),
298                    block_states[0].beta.len()
299                ),
300            }
301            .into());
302        }
303        // SCOP encodes monotonicity as
304        //   h'(y, x) = epsilon + sum_k M_k(y) * gamma_k(x)^2.
305        // With nonnegative M-spline derivative basis rows, every finite beta is
306        // interior-feasible. A derivative-grid fraction-to-boundary scan is pure
307        // overhead and was the dominant CTN large-scale line-search cost.
308        Ok(None)
309    }
310
311    fn block_linear_constraints(
312        &self,
313        _: &[ParameterBlockState],
314        block_index: usize,
315        block_spec: &ParameterBlockSpec,
316    ) -> Result<Option<LinearInequalityConstraints>, String> {
317        assert!(!block_spec.name.is_empty());
318        if block_index != 0 {
319            return Ok(None);
320        }
321        // The CTN tensor design is intentionally factored. Strict monotonicity
322        // is encoded structurally as `h' = ε + Σ M_r γ_r²`, so there are no
323        // dense active-set constraints to expose here.
324        Ok(None)
325    }
326
327    fn exact_newton_hessian_directional_derivative(
328        &self,
329        block_states: &[ParameterBlockState],
330        block_index: usize,
331        d_beta: &Array1<f64>,
332    ) -> Result<Option<Array2<f64>>, String> {
333        if block_index != 0 {
334            return Ok(None);
335        }
336        let beta = &block_states[0].beta;
337        let row_quantities = self.row_quantities(beta)?;
338        let dd = self.scop_hessian_directional_derivative(beta, d_beta, &row_quantities)?;
339        Ok(Some(dd))
340    }
341
342    fn exact_newton_joint_hessian(
343        &self,
344        block_states: &[ParameterBlockState],
345    ) -> Result<Option<Array2<f64>>, String> {
346        // Single block: joint Hessian = block Hessian.
347        let beta = &block_states[0].beta;
348        let row_quantities = self.row_quantities(beta)?;
349        let (_, hessian) = self.scop_gradient_and_negative_hessian(beta, &row_quantities)?;
350        Ok(Some(hessian))
351    }
352
353    fn exact_newton_joint_hessian_directional_derivative(
354        &self,
355        block_states: &[ParameterBlockState],
356        d_beta_flat: &Array1<f64>,
357    ) -> Result<Option<Array2<f64>>, String> {
358        self.exact_newton_hessian_directional_derivative(block_states, 0, d_beta_flat)
359    }
360
361    fn exact_newton_joint_hessiansecond_directional_derivative(
362        &self,
363        block_states: &[ParameterBlockState],
364        d_beta_u_flat: &Array1<f64>,
365        d_beta_v_flat: &Array1<f64>,
366    ) -> Result<Option<Array2<f64>>, String> {
367        let beta = &block_states[0].beta;
368        let row_quantities = self.row_quantities(beta)?;
369        let d2 = self.scop_hessian_second_directional_derivative(
370            beta,
371            d_beta_u_flat,
372            d_beta_v_flat,
373            &row_quantities,
374        )?;
375        Ok(Some(d2))
376    }
377
378    fn exact_newton_joint_psi_terms(
379        &self,
380        block_states: &[ParameterBlockState],
381        _: &[ParameterBlockSpec],
382        psi_derivs: &[Vec<CustomFamilyBlockPsiDerivative>],
383        psi_index: usize,
384    ) -> Result<Option<ExactNewtonJointPsiTerms>, String> {
385        if psi_derivs.is_empty() || psi_index >= psi_derivs[0].len() {
386            return Ok(None);
387        }
388        let psi_first_start = std::time::Instant::now();
389        let deriv = &psi_derivs[0][psi_index];
390        let beta = &block_states[0].beta;
391        let row = self.row_quantities(beta)?;
392        let op = deriv
393            .implicit_operator
394            .as_ref()
395            .and_then(|op| op.as_any().downcast_ref::<TensorKroneckerPsiOperator>())
396            .ok_or_else(|| {
397                "TransformationNormalFamily requires tensor psi derivatives to remain operator-backed"
398                    .to_string()
399            })?;
400        let axis = deriv.implicit_axis;
401        let op_arc = Arc::clone(
402            deriv
403                .implicit_operator
404                .as_ref()
405                .expect("validated CTN psi derivative operator disappeared"),
406        );
407        let terms = self.scop_psi_terms(beta, &row, op, op_arc, axis)?;
408
409        log::info!(
410            "[STAGE] CTN psi first-order terms axis={} psi_index={} elapsed={:.3}s",
411            deriv.implicit_axis,
412            psi_index,
413            psi_first_start.elapsed().as_secs_f64(),
414        );
415
416        Ok(Some(terms))
417    }
418
419    fn exact_newton_joint_psisecond_order_terms(
420        &self,
421        block_states: &[ParameterBlockState],
422        _: &[ParameterBlockSpec],
423        psi_derivs: &[Vec<CustomFamilyBlockPsiDerivative>],
424        psi_i: usize,
425        psi_j: usize,
426    ) -> Result<Option<ExactNewtonJointPsiSecondOrderTerms>, String> {
427        if psi_derivs.is_empty() || psi_i >= psi_derivs[0].len() || psi_j >= psi_derivs[0].len() {
428            return Ok(None);
429        }
430        let psi_pair_start = std::time::Instant::now();
431        let deriv_i = &psi_derivs[0][psi_i];
432        let deriv_j = &psi_derivs[0][psi_j];
433        let beta = &block_states[0].beta;
434        let row = self.row_quantities(beta)?;
435        let p_resp = self.response_val_basis.ncols();
436        let p_cov = self.covariate_design.ncols();
437        let p_total = p_resp * p_cov;
438        if beta.len() != p_total {
439            return Err(TransformationNormalError::InvalidInput {
440                reason: format!(
441                    "SCOP psi-psi terms beta length {} != p_resp({p_resp}) * p_cov({p_cov})",
442                    beta.len()
443                ),
444            }
445            .into());
446        }
447
448        let op = deriv_i
449            .implicit_operator
450            .as_ref()
451            .and_then(|op| op.as_any().downcast_ref::<TensorKroneckerPsiOperator>())
452            .ok_or_else(|| {
453                "TransformationNormalFamily requires tensor psi derivatives to remain operator-backed"
454                    .to_string()
455            })?;
456        let axis_i = deriv_i.implicit_axis;
457        let axis_j = deriv_j.implicit_axis;
458
459        let (objective_psi_psi, score_psi_psi, _) = self
460            .scop_psi_psi_value_score_hvp_from_operator(
461                beta,
462                op,
463                axis_i,
464                axis_j,
465                row.gamma.view(),
466                row.h.view(),
467                row.h_prime.view(),
468                row.endpoint_q.as_slice(),
469                None,
470            )?;
471        let hessian_psi_psi_operator: Box<dyn HyperOperator> =
472            Box::new(TransformationNormalPsiPsiHessianOperator::new(
473                Arc::new(self.clone()),
474                beta.clone(),
475                Arc::clone(
476                    deriv_i
477                        .implicit_operator
478                        .as_ref()
479                        .expect("validated CTN psi derivative has an implicit operator"),
480                ),
481                axis_i,
482                axis_j,
483                Arc::clone(&row.gamma),
484                Arc::clone(&row.h),
485                Arc::clone(&row.h_prime),
486                Arc::clone(&row.endpoint_q),
487            ));
488
489        // Result-validation gate. A trial point can still make the SCOP row
490        // terms non-finite through an invalid h' or an exploding ψ second
491        // derivative in the covariate basis. Surface that as an infeasible
492        // exact-Newton evaluation instead of passing NaNs into the unified
493        // outer evaluator.
494        if !objective_psi_psi.is_finite() || !score_psi_psi.iter().all(|v| v.is_finite()) {
495            return Err(TransformationNormalError::NonFinite {
496                reason: format!(
497                    "TransformationNormalFamily exact ψ-ψ second-order terms produced \
498                 non-finite values at psi_i={psi_i}, psi_j={psi_j}: \
499                 obj_finite={}, score_all_finite={}. \
500                 The outer evaluator should retreat from this trial point.",
501                    objective_psi_psi.is_finite(),
502                    score_psi_psi.iter().all(|v| v.is_finite()),
503                ),
504            }
505            .into());
506        }
507
508        log::info!(
509            "[STAGE] CTN psi-psi pair (psi_i={}, psi_j={}, axes={},{}) elapsed={:.3}s",
510            psi_i,
511            psi_j,
512            deriv_i.implicit_axis,
513            deriv_j.implicit_axis,
514            psi_pair_start.elapsed().as_secs_f64(),
515        );
516
517        Ok(Some(ExactNewtonJointPsiSecondOrderTerms {
518            objective_psi_psi,
519            score_psi_psi,
520            hessian_psi_psi: Array2::zeros((0, 0)),
521            hessian_psi_psi_operator: Some(hessian_psi_psi_operator),
522        }))
523    }
524
525    fn exact_newton_joint_psihessian_directional_derivative(
526        &self,
527        block_states: &[ParameterBlockState],
528        _: &[ParameterBlockSpec],
529        psi_derivs: &[Vec<CustomFamilyBlockPsiDerivative>],
530        psi_index: usize,
531        d_beta_flat: &Array1<f64>,
532    ) -> Result<Option<Array2<f64>>, String> {
533        if psi_derivs.is_empty() || psi_index >= psi_derivs[0].len() {
534            return Ok(None);
535        }
536        let deriv = &psi_derivs[0][psi_index];
537        let beta = &block_states[0].beta;
538        let op = deriv
539            .implicit_operator
540            .as_ref()
541            .and_then(|op| op.as_any().downcast_ref::<TensorKroneckerPsiOperator>())
542            .ok_or_else(|| {
543                "TransformationNormalFamily requires tensor psi derivatives to remain operator-backed"
544                    .to_string()
545            })?;
546        let axis = deriv.implicit_axis;
547        let row = self.row_quantities(beta)?;
548        let hess =
549            self.scop_psi_hessian_directional_derivative(beta, d_beta_flat, &row, op, axis)?;
550        Ok(Some(hess))
551    }
552
553    fn exact_newton_joint_hessian_workspace(
554        &self,
555        block_states: &[ParameterBlockState],
556        specs: &[ParameterBlockSpec],
557    ) -> Result<Option<Arc<dyn ExactNewtonJointHessianWorkspace>>, String> {
558        crate::block_layout::block_count::validate_block_count::<
559            TransformationNormalError,
560        >("TransformationNormalFamily", 1, block_states.len())?;
561        if !self.inner_coefficient_hessian_hvp_available(specs) {
562            return Err(TransformationNormalError::InvalidInput {
563                reason: "TransformationNormalFamily joint Hessian workspace received incompatible block specs"
564                    .to_string(),
565            }
566            .into());
567        }
568        let beta = &block_states[0].beta;
569        let row_quantities = self.row_quantities(beta)?;
570        // Expected HVP reuse this workspace will service before its
571        // `(β, row_quantities)` key advances. The outer-eval trace path
572        // performs ~`2·rho_dim` HVPs plus one diagonal call against the
573        let workspace = TransformationNormalJointHessianWorkspace::new(
574            Arc::new(self.clone()),
575            beta.clone(),
576            row_quantities.clone(),
577        )?;
578        Ok(Some(
579            Arc::new(workspace) as Arc<dyn ExactNewtonJointHessianWorkspace>
580        ))
581    }
582
583    fn exact_newton_joint_psi_workspace(
584        &self,
585        block_states: &[ParameterBlockState],
586        specs: &[ParameterBlockSpec],
587        derivative_blocks: &[Vec<CustomFamilyBlockPsiDerivative>],
588    ) -> Result<Option<Arc<dyn ExactNewtonJointPsiWorkspace>>, String> {
589        if !self.inner_coefficient_hessian_hvp_available(specs) {
590            return Err(TransformationNormalError::InvalidInput {
591                reason: "TransformationNormalFamily joint psi workspace received incompatible block specs"
592                    .to_string(),
593            }
594            .into());
595        }
596        Ok(Some(Arc::new(TransformationNormalPsiWorkspace::new(
597            self.clone(),
598            block_states.to_vec(),
599            derivative_blocks.to_vec(),
600        ))))
601    }
602
603    fn exact_newton_joint_hessian_workspace_with_options(
604        &self,
605        block_states: &[ParameterBlockState],
606        specs: &[ParameterBlockSpec],
607        options: &BlockwiseFitOptions,
608    ) -> Result<Option<Arc<dyn ExactNewtonJointHessianWorkspace>>, String> {
609        // Route through a mask-aware family clone when an outer-score
610        // subsample is active. The cloned family's `effective_weights()`
611        // returns `wᵢ · mᵢ` (`mᵢ = 1/πᵢ` on sampled rows, `0` elsewhere),
612        // and every CTN assembly site reads weights through that accessor.
613        // Each per-row contribution is linear in `wᵢ`, so the workspace's
614        // gradient / dense Hessian / matrix-free HVP / diagonal are exact
615        // Horvitz-Thompson estimators of the full-data quantities.
616        match self.maybe_with_outer_subsample_from_options(options)? {
617            Some(masked) => masked.exact_newton_joint_hessian_workspace(block_states, specs),
618            None => self.exact_newton_joint_hessian_workspace(block_states, specs),
619        }
620    }
621
622    fn exact_newton_joint_psi_workspace_with_options(
623        &self,
624        block_states: &[ParameterBlockState],
625        specs: &[ParameterBlockSpec],
626        derivative_blocks: &[Vec<CustomFamilyBlockPsiDerivative>],
627        options: &BlockwiseFitOptions,
628    ) -> Result<Option<Arc<dyn ExactNewtonJointPsiWorkspace>>, String> {
629        if !self.inner_coefficient_hessian_hvp_available(specs) {
630            return Err(TransformationNormalError::InvalidInput {
631                reason: "TransformationNormalFamily joint psi workspace received incompatible block specs"
632                    .to_string(),
633            }
634            .into());
635        }
636        // Route through a mask-aware family clone when an outer-score
637        // subsample is active. Every CTN ψ assembly site — including the
638        // workspace's `compute_all_axes` (per-row reduction near line ~13916)
639        // and `compute_pair_cache` (per-row reduction near line ~14263) —
640        // reads its row weight via `self.family.effective_weights()`, which
641        // on the cloned family returns `wᵢ · mᵢ`. Because each per-row
642        // contribution is linear in `wᵢ`, the workspace's per-axis ψ and
643        // per-axis-pair ψ-ψ outputs are exact Horvitz-Thompson estimators
644        // of the full-data quantities. The persistent dense-Hessian cache
645        // and `row_quantity_cache` on the cloned family are fresh, so
646        // subsampled builds cannot alias a later full-data probe at the
647        // same β.
648        let family = match self.maybe_with_outer_subsample_from_options(options)? {
649            Some(masked) => masked,
650            None => self.clone(),
651        };
652        Ok(Some(Arc::new(TransformationNormalPsiWorkspace::new(
653            family,
654            block_states.to_vec(),
655            derivative_blocks.to_vec(),
656        ))))
657    }
658
659    fn exact_newton_joint_psi_workspace_for_first_order_terms(&self) -> bool {
660        // CTN's per-axis [`scop_psi_terms`] kernel walks all `n` rows serially
661        // and is invoked once per ψ axis. Opting in here amortizes the per-row
662        // state load across axes and parallelizes the row walk via the
663        // workspace's [`compute_all_axes`] kernel — the dominant outer
664        // gradient-evaluation cost at large scale.
665        true
666    }
667
668    fn inner_coefficient_hessian_hvp_available(&self, specs: &[ParameterBlockSpec]) -> bool {
669        // CTN's SCOP coefficient-space joint Hessian is supplied as a
670        // row-streaming matrix-free Hv operator.
671        matches!(specs, [spec] if spec.design.ncols()
672            == self.response_val_basis.ncols().saturating_mul(self.covariate_design.ncols()))
673    }
674
675    fn outer_hyper_hessian_hvp_available(&self, specs: &[ParameterBlockSpec]) -> bool {
676        self.inner_coefficient_hessian_hvp_available(specs)
677    }
678
679    fn outer_hyper_hessian_dense_available(&self, specs: &[ParameterBlockSpec]) -> bool {
680        // Dense materialization remains mathematically available through the
681        // outer-HVP operator, but SCOP's primary production path is the
682        // matrix-free θθ operator above.
683        self.inner_coefficient_hessian_hvp_available(specs)
684    }
685}