Skip to main content

gam_problem/
psi_terms.rs

1//! Exact-Newton joint-ψ term carriers and the joint-ψ workspace trait.
2//!
3//! These ψ-hyperparameter term types and the [`ExactNewtonJointPsiWorkspace`]
4//! trait are neutral carriers in the criterion contract: they reference only
5//! [`HyperOperator`] / [`DriftDerivResult`] (defined in this crate) and ndarray
6//! arrays, so they live in `gam-problem` and are re-exported by the
7//! `custom_family` layer for backward compatibility.
8
9use crate::{DriftDerivResult, HyperOperator};
10use ndarray::{Array1, Array2};
11use std::sync::Arc;
12
13#[derive(Clone)]
14pub struct ExactNewtonJointPsiTerms {
15    pub objective_psi: f64,
16    pub score_psi: Array1<f64>,
17    pub hessian_psi: Array2<f64>,
18    pub hessian_psi_operator: Option<Arc<dyn HyperOperator>>,
19}
20
21impl std::fmt::Debug for ExactNewtonJointPsiTerms {
22    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
23        f.debug_struct("ExactNewtonJointPsiTerms")
24            .field("objective_psi", &self.objective_psi)
25            .field("score_psi", &self.score_psi)
26            .field("hessian_psi", &self.hessian_psi)
27            .field(
28                "hessian_psi_operator",
29                &self.hessian_psi_operator.as_ref().map(|_| "<operator>"),
30            )
31            .finish()
32    }
33}
34
35impl ExactNewtonJointPsiTerms {
36    pub fn zeros(total: usize) -> Self {
37        Self {
38            objective_psi: 0.0,
39            score_psi: Array1::zeros(total),
40            hessian_psi: Array2::zeros((total, total)),
41            hessian_psi_operator: None,
42        }
43    }
44}
45
46pub struct ExactNewtonJointPsiSecondOrderTerms {
47    pub objective_psi_psi: f64,
48    pub score_psi_psi: Array1<f64>,
49    pub hessian_psi_psi: Array2<f64>,
50    pub hessian_psi_psi_operator: Option<Box<dyn HyperOperator>>,
51}
52
53/// Direction-contracted second-order ψ terms for the profiled θ-HVP (#740).
54///
55/// The per-pair [`ExactNewtonJointPsiSecondOrderTerms`] are the `(ψ_i, ψ_j)`
56/// entries of the joint hyper-Hessian; assembling the full outer Hessian from
57/// them costs one O(n) family row pass per pair, i.e. `K²·n`. A matrix-free
58/// profiled θ-HVP never needs the individual pairs — it needs, for one applied
59/// outer direction with ψ-weights `α_ψ`, the `α`-contraction of those pairs
60/// against the combined ψ-direction `ψ(α) = Σ_j α_j ψ_j`:
61///
62/// ```text
63///   objective[i] = Σ_j α_j V_{ψ_i ψ_j}
64///   score[i]     = Σ_j α_j g_{ψ_i ψ_j}          (a p-vector per output row i)
65///   hessian[i]   = Σ_j α_j D²_β H_L[ψ_i, ψ_j]
66///                = D²_β H_L[ψ_i, ψ(α)]            (bilinearity)
67/// ```
68///
69/// All `psi_dim` output rows share the SAME contracted second leg `ψ(α)`, so a
70/// family that streams its rows once over `ψ(α)` (carrying every fixed first
71/// leg `ψ_i` as a batched factor column) produces every row in a SINGLE n-pass.
72/// That is the cost the profiled θ-HVP turns into `K·n`-to-densify /
73/// `m·n`-in-CG instead of the dense path's `K²·n`.
74///
75/// Indexing is over the flattened ψ coordinates in the same order as
76/// [`ExactNewtonJointPsiWorkspace::second_order_terms`]; `hessian[i]` carries
77/// the `D²_β H_L[ψ_i, ψ(α)]` drift as a [`DriftDerivResult`] (dense or
78/// operator-backed) plus any block-local `S_{ψ_i ψ_j}` penalty motion folded by
79/// the family, exactly mirroring the per-pair `hessian_psi_psi(_operator)`.
80pub struct ExactNewtonJointPsiSecondOrderContracted {
81    /// `objective[i] = Σ_j α_j V_{ψ_i ψ_j}`, one scalar per ψ output row.
82    pub objective: Array1<f64>,
83    /// `score[i] = Σ_j α_j g_{ψ_i ψ_j}`, the `psi_dim × total` matrix whose
84    /// row `i` is the contracted fixed-β score derivative for output row `i`.
85    pub score: Array2<f64>,
86    /// `hessian[i] = D²_β H_L[ψ_i, ψ(α)]` for each ψ output row `i`.
87    pub hessian: Vec<DriftDerivResult>,
88}
89
90pub trait ExactNewtonJointPsiWorkspace: Send + Sync {
91    fn first_order_terms(&self, _: usize) -> Result<Option<ExactNewtonJointPsiTerms>, String> {
92        // Default implementation ignores this parameter.
93        Ok(None)
94    }
95
96    fn first_order_terms_all(&self) -> Result<Option<Vec<ExactNewtonJointPsiTerms>>, String> {
97        Ok(None)
98    }
99
100    fn second_order_terms(
101        &self,
102        psi_i: usize,
103        psi_j: usize,
104    ) -> Result<Option<ExactNewtonJointPsiSecondOrderTerms>, String>;
105
106    /// Direction-contracted second-order ψ terms for the profiled θ-HVP (#740).
107    ///
108    /// Given the ψ-block weights `alpha_psi` (length `psi_dim`, the ψ slice of
109    /// one applied outer direction α), return the `α`-contraction of every
110    /// `(ψ_i, ψ_j)` second-order term against the combined ψ-direction
111    /// `ψ(α) = Σ_j alpha_psi[j] · ψ_j`, as
112    /// [`ExactNewtonJointPsiSecondOrderContracted`]. A family that can stream
113    /// its rows once over `ψ(α)` overrides this so the profiled outer-Hessian
114    /// operator applies one combined-direction n-pass per matvec instead of the
115    /// dense path's `K²` per-pair [`Self::second_order_terms`] passes.
116    ///
117    /// Default returns `None`: the profiled θ-HVP operator is then not built and
118    /// the evaluator keeps the exact per-pair assembly (dense
119    /// `compute_outer_hessian` / `build_outer_hessian_operator`). Overriding
120    /// this method is purely a representation/cost choice — it must produce the
121    /// exact same contraction the per-pair terms would, which the
122    /// `profiled_theta_hvp_outer_hessian_fd` finite-difference cross-check
123    /// guards.
124    fn second_order_terms_contracted(
125        &self,
126        _: &[f64],
127    ) -> Result<Option<ExactNewtonJointPsiSecondOrderContracted>, String> {
128        // Default implementation ignores this parameter.
129        Ok(None)
130    }
131
132    fn hessian_directional_derivative(
133        &self,
134        psi_index: usize,
135        d_beta_flat: &Array1<f64>,
136    ) -> Result<Option<DriftDerivResult>, String>;
137}