gam_problem/psi_design_contract.rs
1//! The neutral ψ (hyperparameter) design-derivative contract carriers and
2//! operator traits shared by the `CustomFamily` trait layer (`gam-model-api`)
3//! and the solver: the per-block ψ-derivative carrier, the matrix-free
4//! `CustomFamilyPsiDerivativeOperator` trait (+ its dense-materialization
5//! extension), and the joint-Hessian source-preference / materialization-intent
6//! enums.
7//!
8//! These carry no dependency on the `CustomFamily` trait itself, so they live
9//! in the neutral `gam-problem` crate and are re-exported upward, keeping a
10//! single definition shared across crates.
11
12use crate::{BasisError, PenaltyMatrix};
13use ndarray::{Array1, Array2, ArrayView1, ArrayViewMut1};
14use std::any::Any;
15use std::ops::Range;
16use std::sync::Arc;
17
18#[derive(Clone)]
19pub struct CustomFamilyBlockPsiDerivative {
20 pub penalty_index: Option<usize>,
21 pub x_psi: Array2<f64>,
22 pub s_psi: Array2<f64>,
23 pub s_psi_components: Option<Vec<(usize, Array2<f64>)>>,
24 pub s_psi_penalty_components: Option<Vec<(usize, PenaltyMatrix)>>,
25 pub x_psi_psi: Option<Vec<Array2<f64>>>,
26 pub s_psi_psi: Option<Vec<Array2<f64>>>,
27 pub s_psi_psi_components: Option<Vec<Vec<(usize, Array2<f64>)>>>,
28 pub s_psi_psi_penalty_components: Option<Vec<Vec<(usize, PenaltyMatrix)>>>,
29 pub implicit_operator: Option<Arc<dyn CustomFamilyPsiDerivativeOperator>>,
30 pub implicit_axis: usize,
31 pub implicit_group_id: Option<usize>,
32}
33
34pub type SharedDerivativeBlocks = Arc<Vec<Vec<CustomFamilyBlockPsiDerivative>>>;
35
36impl CustomFamilyBlockPsiDerivative {
37 /// Public constructor for use in tests and external consumers.
38 /// Sets `implicit_operator` to `None`.
39 pub fn new(
40 penalty_index: Option<usize>,
41 x_psi: Array2<f64>,
42 s_psi: Array2<f64>,
43 s_psi_components: Option<Vec<(usize, Array2<f64>)>>,
44 x_psi_psi: Option<Vec<Array2<f64>>>,
45 s_psi_psi: Option<Vec<Array2<f64>>>,
46 s_psi_psi_components: Option<Vec<Vec<(usize, Array2<f64>)>>>,
47 ) -> Self {
48 Self {
49 penalty_index,
50 x_psi,
51 s_psi,
52 s_psi_components,
53 s_psi_penalty_components: None,
54 x_psi_psi,
55 s_psi_psi,
56 s_psi_psi_components,
57 s_psi_psi_penalty_components: None,
58 implicit_operator: None,
59 implicit_axis: 0,
60 implicit_group_id: None,
61 }
62 }
63}
64
65pub trait CustomFamilyPsiDerivativeOperator: Send + Sync + Any {
66 fn as_any(&self) -> &dyn Any;
67 fn n_data(&self) -> usize;
68 fn p_out(&self) -> usize;
69 fn transpose_mul(
70 &self,
71 axis: usize,
72 v: &ArrayView1<'_, f64>,
73 ) -> Result<Array1<f64>, BasisError>;
74 fn forward_mul(&self, axis: usize, u: &ArrayView1<'_, f64>) -> Result<Array1<f64>, BasisError>;
75 fn transpose_mul_second_diag(
76 &self,
77 axis: usize,
78 v: &ArrayView1<'_, f64>,
79 ) -> Result<Array1<f64>, BasisError>;
80 fn transpose_mul_second_cross(
81 &self,
82 axis_d: usize,
83 axis_e: usize,
84 v: &ArrayView1<'_, f64>,
85 ) -> Result<Array1<f64>, BasisError>;
86 fn forward_mul_second_diag(
87 &self,
88 axis: usize,
89 u: &ArrayView1<'_, f64>,
90 ) -> Result<Array1<f64>, BasisError>;
91 fn forward_mul_second_cross(
92 &self,
93 axis_d: usize,
94 axis_e: usize,
95 u: &ArrayView1<'_, f64>,
96 ) -> Result<Array1<f64>, BasisError>;
97 fn row_chunk_first(&self, axis: usize, rows: Range<usize>) -> Result<Array2<f64>, BasisError>;
98 /// Single-row specialization of `row_chunk_first`. Default implementation
99 /// delegates to `row_chunk_first(axis, row..row+1)` and copies the
100 /// resulting row into the output buffer; implementations that can avoid
101 /// the temporary matrix allocation should override this method.
102 fn row_vector_first_into(
103 &self,
104 axis: usize,
105 row: usize,
106 mut out: ArrayViewMut1<'_, f64>,
107 ) -> Result<(), BasisError> {
108 let chunk = self.row_chunk_first(axis, row..row + 1)?;
109 out.assign(&chunk.row(0));
110 Ok(())
111 }
112 fn row_chunk_second_diag(
113 &self,
114 axis: usize,
115 rows: Range<usize>,
116 ) -> Result<Array2<f64>, BasisError>;
117 fn row_chunk_second_cross(
118 &self,
119 axis_d: usize,
120 axis_e: usize,
121 rows: Range<usize>,
122 ) -> Result<Array2<f64>, BasisError>;
123
124 /// Optional upcast to the dense materialization surface. Production exact
125 /// paths should prefer the analytic matvec / row-chunk methods above and
126 /// avoid forming the full derivative matrix; implementations that *do*
127 /// support dense materialization (used by diagnostics, tests, and
128 /// small-data fallbacks) should override this to return `Some(self)`.
129 fn as_materializable(&self) -> Option<&dyn MaterializablePsiDerivativeOperator> {
130 None
131 }
132}
133
134/// Diagnostic / small-data extension that exposes dense materialization of
135/// `\partial X / \partial \psi`. Production exact-Hessian code MUST NOT depend
136/// on dense second-derivative materialization; second-order paths use the
137/// row-chunk and matvec methods on [`CustomFamilyPsiDerivativeOperator`].
138pub trait MaterializablePsiDerivativeOperator: CustomFamilyPsiDerivativeOperator {
139 fn materialize_first(&self, axis: usize) -> Result<Array2<f64>, BasisError>;
140}
141
/// Which representation of the joint Hessian a workspace hands out.
///
/// `Dense` is an explicitly built matrix; `Operator` is a matrix-free
/// Hessian-vector-product surface. Which one a consumer should receive depends
/// on what it does with the Hessian — see [`MaterializationIntent`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum JointHessianSourcePreference {
    /// Build and hand out the full dense joint Hessian.
    Dense,
    /// Serve the joint Hessian matrix-free through `H · v` products.
    Operator,
}
147
/// What a consumer intends to *do* with the joint Hessian.
///
/// This is the intent half of #738's capability-vs-representation split.
/// Previously a single per-workspace preference was applied uniformly to
/// every consumer. Now each call site states what it needs, and the workspace
/// picks the cheapest representation that serves that need.
///
/// This matters because one workspace serves consumers whose ideal
/// representations conflict:
/// - the inner Newton/PCG solve only applies `H · v`, so a matrix-free HVP
///   (`Operator`) is ideal and building a dense matrix is wasted work;
/// - the REML logdet term factorizes `H + S_λ` (Cholesky / eigendecomposition)
///   and therefore needs a dense matrix anyway. Handing it an `Operator` only
///   forces an immediate column-basis (or `dense_forced`) re-materialization,
///   so a workspace with a structural direct-dense build should answer `Dense`
///   for this intent and skip the operator wrapper entirely.
///
/// Workspaces refine their representation choice per intent through
/// `ExactNewtonJointHessianWorkspace::hessian_source_preference_for_intent`.
/// Its default keeps the legacy single-preference behaviour, so existing
/// workspaces are unchanged.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MaterializationIntent {
    /// Inner Newton / PCG solve: only applies `H · v`, so matrix-free is ideal.
    InnerSolve,
    /// REML/LAML logdet term: factorizes `H + S_λ`, so it needs a dense matrix.
    LogdetFactorization,
    /// Outer-Hessian / EFS evaluation that builds the joint hyper terms;
    /// currently routed through the same source as the gradient path.
    OuterEvaluation,
    /// Outer-gradient / IFT term assembly.
    OuterGradient,
}