gam_problem/psi_terms.rs
1//! Exact-Newton joint-ψ term carriers and the joint-ψ workspace trait.
2//!
3//! These ψ-hyperparameter term types and the [`ExactNewtonJointPsiWorkspace`]
4//! trait are neutral carriers in the criterion contract: they reference only
5//! [`HyperOperator`] / [`DriftDerivResult`] (defined in this crate) and ndarray
6//! arrays, so they live in `gam-problem` and are re-exported by the
7//! `custom_family` layer for backward compatibility.
8
9use crate::{DriftDerivResult, HyperOperator};
10use ndarray::{Array1, Array2};
11use std::sync::Arc;
12
13#[derive(Clone)]
14pub struct ExactNewtonJointPsiTerms {
15 pub objective_psi: f64,
16 pub score_psi: Array1<f64>,
17 pub hessian_psi: Array2<f64>,
18 pub hessian_psi_operator: Option<Arc<dyn HyperOperator>>,
19}
20
21impl std::fmt::Debug for ExactNewtonJointPsiTerms {
22 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
23 f.debug_struct("ExactNewtonJointPsiTerms")
24 .field("objective_psi", &self.objective_psi)
25 .field("score_psi", &self.score_psi)
26 .field("hessian_psi", &self.hessian_psi)
27 .field(
28 "hessian_psi_operator",
29 &self.hessian_psi_operator.as_ref().map(|_| "<operator>"),
30 )
31 .finish()
32 }
33}
34
35impl ExactNewtonJointPsiTerms {
36 pub fn zeros(total: usize) -> Self {
37 Self {
38 objective_psi: 0.0,
39 score_psi: Array1::zeros(total),
40 hessian_psi: Array2::zeros((total, total)),
41 hessian_psi_operator: None,
42 }
43 }
44}
45
46pub struct ExactNewtonJointPsiSecondOrderTerms {
47 pub objective_psi_psi: f64,
48 pub score_psi_psi: Array1<f64>,
49 pub hessian_psi_psi: Array2<f64>,
50 pub hessian_psi_psi_operator: Option<Box<dyn HyperOperator>>,
51}
52
53/// Direction-contracted second-order ψ terms for the profiled θ-HVP (#740).
54///
55/// The per-pair [`ExactNewtonJointPsiSecondOrderTerms`] are the `(ψ_i, ψ_j)`
56/// entries of the joint hyper-Hessian; assembling the full outer Hessian from
57/// them costs one O(n) family row pass per pair, i.e. `K²·n`. A matrix-free
58/// profiled θ-HVP never needs the individual pairs — it needs, for one applied
59/// outer direction with ψ-weights `α_ψ`, the `α`-contraction of those pairs
60/// against the combined ψ-direction `ψ(α) = Σ_j α_j ψ_j`:
61///
62/// ```text
63/// objective[i] = Σ_j α_j V_{ψ_i ψ_j}
64/// score[i] = Σ_j α_j g_{ψ_i ψ_j} (a p-vector per output row i)
65/// hessian[i] = Σ_j α_j D²_β H_L[ψ_i, ψ_j]
66/// = D²_β H_L[ψ_i, ψ(α)] (bilinearity)
67/// ```
68///
69/// All `psi_dim` output rows share the SAME contracted second leg `ψ(α)`, so a
70/// family that streams its rows once over `ψ(α)` (carrying every fixed first
71/// leg `ψ_i` as a batched factor column) produces every row in a SINGLE n-pass.
72/// That is the cost the profiled θ-HVP turns into `K·n`-to-densify /
73/// `m·n`-in-CG instead of the dense path's `K²·n`.
74///
75/// Indexing is over the flattened ψ coordinates in the same order as
76/// [`ExactNewtonJointPsiWorkspace::second_order_terms`]; `hessian[i]` carries
77/// the `D²_β H_L[ψ_i, ψ(α)]` drift as a [`DriftDerivResult`] (dense or
78/// operator-backed) plus any block-local `S_{ψ_i ψ_j}` penalty motion folded by
79/// the family, exactly mirroring the per-pair `hessian_psi_psi(_operator)`.
80pub struct ExactNewtonJointPsiSecondOrderContracted {
81 /// `objective[i] = Σ_j α_j V_{ψ_i ψ_j}`, one scalar per ψ output row.
82 pub objective: Array1<f64>,
83 /// `score[i] = Σ_j α_j g_{ψ_i ψ_j}`, the `psi_dim × total` matrix whose
84 /// row `i` is the contracted fixed-β score derivative for output row `i`.
85 pub score: Array2<f64>,
86 /// `hessian[i] = D²_β H_L[ψ_i, ψ(α)]` for each ψ output row `i`.
87 pub hessian: Vec<DriftDerivResult>,
88}
89
90pub trait ExactNewtonJointPsiWorkspace: Send + Sync {
91 fn first_order_terms(&self, _: usize) -> Result<Option<ExactNewtonJointPsiTerms>, String> {
92 // Default implementation ignores this parameter.
93 Ok(None)
94 }
95
96 fn first_order_terms_all(&self) -> Result<Option<Vec<ExactNewtonJointPsiTerms>>, String> {
97 Ok(None)
98 }
99
100 fn second_order_terms(
101 &self,
102 psi_i: usize,
103 psi_j: usize,
104 ) -> Result<Option<ExactNewtonJointPsiSecondOrderTerms>, String>;
105
106 /// Direction-contracted second-order ψ terms for the profiled θ-HVP (#740).
107 ///
108 /// Given the ψ-block weights `alpha_psi` (length `psi_dim`, the ψ slice of
109 /// one applied outer direction α), return the `α`-contraction of every
110 /// `(ψ_i, ψ_j)` second-order term against the combined ψ-direction
111 /// `ψ(α) = Σ_j alpha_psi[j] · ψ_j`, as
112 /// [`ExactNewtonJointPsiSecondOrderContracted`]. A family that can stream
113 /// its rows once over `ψ(α)` overrides this so the profiled outer-Hessian
114 /// operator applies one combined-direction n-pass per matvec instead of the
115 /// dense path's `K²` per-pair [`Self::second_order_terms`] passes.
116 ///
117 /// Default returns `None`: the profiled θ-HVP operator is then not built and
118 /// the evaluator keeps the exact per-pair assembly (dense
119 /// `compute_outer_hessian` / `build_outer_hessian_operator`). Overriding
120 /// this method is purely a representation/cost choice — it must produce the
121 /// exact same contraction the per-pair terms would, which the
122 /// `profiled_theta_hvp_outer_hessian_fd` finite-difference cross-check
123 /// guards.
124 fn second_order_terms_contracted(
125 &self,
126 _: &[f64],
127 ) -> Result<Option<ExactNewtonJointPsiSecondOrderContracted>, String> {
128 // Default implementation ignores this parameter.
129 Ok(None)
130 }
131
132 fn hessian_directional_derivative(
133 &self,
134 psi_index: usize,
135 d_beta_flat: &Array1<f64>,
136 ) -> Result<Option<DriftDerivResult>, String>;
137}