Skip to main content

gam_problem/
psi_design_contract.rs

//! The neutral ψ (hyperparameter) design-derivative contract: carrier types
//! and operator traits shared by the `CustomFamily` trait layer
//! (`gam-model-api`) and the solver. It comprises the per-block ψ-derivative
//! carrier ([`CustomFamilyBlockPsiDerivative`], shared via
//! [`SharedDerivativeBlocks`]), the matrix-free
//! [`CustomFamilyPsiDerivativeOperator`] trait (plus its dense-materialization
//! extension, [`MaterializablePsiDerivativeOperator`]), and the joint-Hessian
//! source-preference / materialization-intent enums
//! ([`JointHessianSourcePreference`], [`MaterializationIntent`]).
//!
//! None of these items depends on the `CustomFamily` trait itself, so they
//! live in the neutral `gam-problem` crate and are re-exported upward,
//! keeping a single definition shared across crates.
11
12use crate::{BasisError, PenaltyMatrix};
13use ndarray::{Array1, Array2, ArrayView1, ArrayViewMut1};
14use std::any::Any;
15use std::ops::Range;
16use std::sync::Arc;
17
/// Per-block ψ-derivative carrier.
///
/// Holds dense first- and (optional) second-order ψ-derivatives of a block's
/// design matrix (`x_*` fields) and penalty (`s_*` fields), plus an optional
/// matrix-free [`CustomFamilyPsiDerivativeOperator`].
///
/// NOTE(review): the field descriptions below are inferred from naming and
/// the module docs. Confirm shapes and indexing against the producers.
// Only `Clone` is derived: the trait-object field `implicit_operator` has no
// `Debug` bound (the operator trait requires only `Send + Sync + Any`).
// Cloning shares the operator through its `Arc` rather than copying it.
#[derive(Clone)]
pub struct CustomFamilyBlockPsiDerivative {
    /// Index of the penalty this ψ coordinate is tied to, if any
    /// (presumably an index into the block's penalty list — confirm).
    pub penalty_index: Option<usize>,
    /// Dense first ψ-derivative of the design matrix, `∂X/∂ψ`.
    pub x_psi: Array2<f64>,
    /// Dense first ψ-derivative of the penalty (presumably `∂S/∂ψ`).
    pub s_psi: Array2<f64>,
    /// Optional decomposition of `s_psi` into `(index, matrix)` pieces
    /// (index presumably a penalty index — confirm).
    pub s_psi_components: Option<Vec<(usize, Array2<f64>)>>,
    /// Same decomposition as `s_psi_components`, but as `PenaltyMatrix`
    /// values. `new` leaves this `None`.
    pub s_psi_penalty_components: Option<Vec<(usize, PenaltyMatrix)>>,
    /// Optional dense second ψ-derivatives of the design matrix, one matrix
    /// per entry (presumably one per second ψ coordinate — confirm).
    pub x_psi_psi: Option<Vec<Array2<f64>>>,
    /// Optional dense second ψ-derivatives of the penalty, one per entry.
    pub s_psi_psi: Option<Vec<Array2<f64>>>,
    /// Optional per-entry decomposition of `s_psi_psi` into
    /// `(index, matrix)` pieces.
    pub s_psi_psi_components: Option<Vec<Vec<(usize, Array2<f64>)>>>,
    /// Same as `s_psi_psi_components`, but as `PenaltyMatrix` values.
    /// `new` leaves this `None`.
    pub s_psi_psi_penalty_components: Option<Vec<Vec<(usize, PenaltyMatrix)>>>,
    /// Optional matrix-free derivative operator, shared via `Arc`
    /// (presumably an alternative to the dense `x_*` fields — confirm).
    pub implicit_operator: Option<Arc<dyn CustomFamilyPsiDerivativeOperator>>,
    /// ψ axis to pass to `implicit_operator`'s `axis` arguments
    /// (presumably; `new` sets 0).
    pub implicit_axis: usize,
    /// Optional group identifier for the implicit operator (inferred from
    /// the name — confirm semantics).
    pub implicit_group_id: Option<usize>,
}
33
/// Reference-counted nested collection of [`CustomFamilyBlockPsiDerivative`]
/// carriers. Cloning the `Arc` shares the data without copying any matrices.
/// NOTE(review): the meaning of the two `Vec` levels is not fixed here
/// (presumably outer = block, inner = ψ coordinate) — confirm against
/// producers.
pub type SharedDerivativeBlocks = Arc<Vec<Vec<CustomFamilyBlockPsiDerivative>>>;
35
36impl CustomFamilyBlockPsiDerivative {
37    /// Public constructor for use in tests and external consumers.
38    /// Sets `implicit_operator` to `None`.
39    pub fn new(
40        penalty_index: Option<usize>,
41        x_psi: Array2<f64>,
42        s_psi: Array2<f64>,
43        s_psi_components: Option<Vec<(usize, Array2<f64>)>>,
44        x_psi_psi: Option<Vec<Array2<f64>>>,
45        s_psi_psi: Option<Vec<Array2<f64>>>,
46        s_psi_psi_components: Option<Vec<Vec<(usize, Array2<f64>)>>>,
47    ) -> Self {
48        Self {
49            penalty_index,
50            x_psi,
51            s_psi,
52            s_psi_components,
53            s_psi_penalty_components: None,
54            x_psi_psi,
55            s_psi_psi,
56            s_psi_psi_components,
57            s_psi_psi_penalty_components: None,
58            implicit_operator: None,
59            implicit_axis: 0,
60            implicit_group_id: None,
61        }
62    }
63}
64
65pub trait CustomFamilyPsiDerivativeOperator: Send + Sync + Any {
66    fn as_any(&self) -> &dyn Any;
67    fn n_data(&self) -> usize;
68    fn p_out(&self) -> usize;
69    fn transpose_mul(
70        &self,
71        axis: usize,
72        v: &ArrayView1<'_, f64>,
73    ) -> Result<Array1<f64>, BasisError>;
74    fn forward_mul(&self, axis: usize, u: &ArrayView1<'_, f64>) -> Result<Array1<f64>, BasisError>;
75    fn transpose_mul_second_diag(
76        &self,
77        axis: usize,
78        v: &ArrayView1<'_, f64>,
79    ) -> Result<Array1<f64>, BasisError>;
80    fn transpose_mul_second_cross(
81        &self,
82        axis_d: usize,
83        axis_e: usize,
84        v: &ArrayView1<'_, f64>,
85    ) -> Result<Array1<f64>, BasisError>;
86    fn forward_mul_second_diag(
87        &self,
88        axis: usize,
89        u: &ArrayView1<'_, f64>,
90    ) -> Result<Array1<f64>, BasisError>;
91    fn forward_mul_second_cross(
92        &self,
93        axis_d: usize,
94        axis_e: usize,
95        u: &ArrayView1<'_, f64>,
96    ) -> Result<Array1<f64>, BasisError>;
97    fn row_chunk_first(&self, axis: usize, rows: Range<usize>) -> Result<Array2<f64>, BasisError>;
98    /// Single-row specialization of `row_chunk_first`. Default implementation
99    /// delegates to `row_chunk_first(axis, row..row+1)` and copies the
100    /// resulting row into the output buffer; implementations that can avoid
101    /// the temporary matrix allocation should override this method.
102    fn row_vector_first_into(
103        &self,
104        axis: usize,
105        row: usize,
106        mut out: ArrayViewMut1<'_, f64>,
107    ) -> Result<(), BasisError> {
108        let chunk = self.row_chunk_first(axis, row..row + 1)?;
109        out.assign(&chunk.row(0));
110        Ok(())
111    }
112    fn row_chunk_second_diag(
113        &self,
114        axis: usize,
115        rows: Range<usize>,
116    ) -> Result<Array2<f64>, BasisError>;
117    fn row_chunk_second_cross(
118        &self,
119        axis_d: usize,
120        axis_e: usize,
121        rows: Range<usize>,
122    ) -> Result<Array2<f64>, BasisError>;
123
124    /// Optional upcast to the dense materialization surface. Production exact
125    /// paths should prefer the analytic matvec / row-chunk methods above and
126    /// avoid forming the full derivative matrix; implementations that *do*
127    /// support dense materialization (used by diagnostics, tests, and
128    /// small-data fallbacks) should override this to return `Some(self)`.
129    fn as_materializable(&self) -> Option<&dyn MaterializablePsiDerivativeOperator> {
130        None
131    }
132}
133
/// Diagnostic / small-data extension that exposes dense materialization of
/// `∂X/∂ψ`. Production exact-Hessian code MUST NOT depend on dense
/// second-derivative materialization. Second-order paths use the row-chunk
/// and matvec methods on [`CustomFamilyPsiDerivativeOperator`].
///
/// A trait object reaches this surface through
/// [`CustomFamilyPsiDerivativeOperator::as_materializable`], whose default
/// returns `None` unless an implementation opts in.
pub trait MaterializablePsiDerivativeOperator: CustomFamilyPsiDerivativeOperator {
    /// Materializes the full first ψ-derivative `∂X/∂ψ_axis` as a dense
    /// matrix (presumably `n_data() × p_out()` — confirm against
    /// implementations).
    fn materialize_first(&self, axis: usize) -> Result<Array2<f64>, BasisError>;
}
141
/// Which representation of the joint Hessian a workspace should provide.
///
/// See `MaterializationIntent` for how the preferred representation depends
/// on what the consumer does with the Hessian.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum JointHessianSourcePreference {
    /// A fully formed dense matrix (needed e.g. when the consumer
    /// factorizes it).
    Dense,
    /// A matrix-free operator that applies `H · v` without forming `H`.
    Operator,
}
147
/// What the consumer is going to *do* with the joint Hessian. This is the
/// intent half of #738's capability-vs-representation split: the call site
/// states what it needs, and the workspace picks the cheapest representation
/// that serves that need. A single per-workspace preference would instead be
/// applied uniformly, regardless of how the result is consumed.
///
/// The distinction matters because the same workspace serves several
/// consumers with opposite ideal representations:
/// - the inner Newton/PCG solve only ever applies `H · v`, so a matrix-free
///   HVP ([`JointHessianSourcePreference::Operator`]) is ideal and a dense
///   build is pure waste;
/// - the REML logdet term factorizes `H + S_λ` (Cholesky / eigendecomposition),
///   so it must hold a dense matrix anyway. Handing it an `Operator` only
///   forces an immediate column-basis (or `dense_forced`) re-materialization,
///   so a workspace with a structural direct-dense build should answer
///   [`JointHessianSourcePreference::Dense`] here and skip the operator
///   wrapper entirely.
///
/// Workspaces refine their representation choice per intent via
/// `ExactNewtonJointHessianWorkspace::hessian_source_preference_for_intent`.
/// That name is not in scope in this crate (presumably it is defined
/// downstream), so it is written as plain code rather than an intra-doc link.
/// The default keeps the legacy single-preference behaviour, so existing
/// workspaces are unchanged.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum MaterializationIntent {
    /// Inner Newton / PCG solve — only applies `H · v`. Matrix-free is ideal.
    InnerSolve,
    /// REML/LAML logdet term — factorizes `H + S_λ`, needs a dense matrix.
    LogdetFactorization,
    /// Outer-Hessian / EFS evaluation — builds the joint hyper terms. Today
    /// these route through the same source as the gradient path.
    OuterEvaluation,
    /// Outer-gradient / IFT term assembly.
    OuterGradient,
}