Skip to main content

gam_terms/
kronecker.rs

1use crate::basis::BasisError;
2use faer::{Mat, MatRef, Side};
3use gam_linalg::faer_ndarray::FaerLinalgError;
4use ndarray::{Array1, Array2, Axis};
5use rayon::prelude::*;
6use std::sync::Arc;
7
8fn array_to_faer(array: &Array2<f64>) -> Mat<f64> {
9    let (rows, cols) = array.dim();
10    Mat::from_fn(rows, cols, |i, j| array[[i, j]])
11}
12
13fn mat_to_array(mat: &Mat<f64>) -> Array2<f64> {
14    let mut out = Array2::<f64>::zeros((mat.nrows(), mat.ncols()));
15    for i in 0..mat.nrows() {
16        for j in 0..mat.ncols() {
17            out[[i, j]] = mat[(i, j)];
18        }
19    }
20    out
21}
22
23fn mat_max_abs_element(matrix: MatRef<'_, f64>) -> f64 {
24    let (rows, cols) = matrix.shape();
25    let mut maxval = 0.0_f64;
26    for i in 0..rows {
27        for j in 0..cols {
28            let val = matrix[(i, j)];
29            if val.is_finite() {
30                maxval = maxval.max(val.abs());
31            }
32        }
33    }
34    maxval
35}
36
37fn sanitize_symmetric_faer(matrix: &Mat<f64>) -> Mat<f64> {
38    let (rows, cols) = matrix.as_ref().shape();
39    assert_eq!(rows, cols, "Matrix must be square for sanitization");
40
41    let mut sanitized = matrix.clone();
42
43    for i in 0..rows {
44        let diag = sanitized[(i, i)];
45        if !diag.is_finite() {
46            sanitized[(i, i)] = 0.0;
47        }
48        for j in (i + 1)..cols {
49            let mut upper = sanitized[(i, j)];
50            let mut lower = sanitized[(j, i)];
51            if !upper.is_finite() {
52                upper = 0.0;
53            }
54            if !lower.is_finite() {
55                lower = 0.0;
56            }
57            let avg = 0.5 * (upper + lower);
58            sanitized[(i, j)] = avg;
59            sanitized[(j, i)] = avg;
60        }
61    }
62
63    let scale = mat_max_abs_element(sanitized.as_ref());
64    let tiny = (scale * 1e-14).max(1e-30);
65    for i in 0..rows {
66        for j in 0..cols {
67            let val = sanitized[(i, j)];
68            if !val.is_finite() {
69                sanitized[(i, j)] = 0.0;
70            } else if val.abs() < tiny {
71                sanitized[(i, j)] = 0.0;
72            }
73        }
74    }
75
76    sanitized
77}
78
79/// Strict spectral classifier used as a final guard on penalty eigendecompositions.
80///
81/// Penalty matrices fed to the GAM solver are required to be PSD by construction.
82/// This routine snaps roundoff-zero eigenvalues to exact zero, accepts strictly
83/// positive eigenvalues, and rejects materially-indefinite or non-finite spectra
84/// with a hard error rather than silently rewriting them. The previous behaviour
85/// (mass-zeroing negative or non-finite eigenvalues) hid construction bugs and
86/// changed the optimisation objective downstream.
87///
88/// `C_EPS_P_FACTOR = 64` chooses the multiplier `c` in
89/// `tol = c * eps_machine * p * scale`: 64 absorbs the rounding accumulated in a
90/// symmetric eigendecomposition of a moderate-dimension matrix while still
91/// rejecting the 1e-12 * scale magnitudes that previously slipped through.
92fn classify_eigenvalues_strict(eigenvalues: &mut [f64], context: &str) -> Result<(), BasisError> {
93    const C_EPS_P_FACTOR: f64 = 64.0;
94    let p = eigenvalues.len();
95
96    let mut scale = 0.0_f64;
97    for (idx, &val) in eigenvalues.iter().enumerate() {
98        if !val.is_finite() {
99            return Err(BasisError::Other(format!(
100                "Penalty spectrum check failed in '{context}': non-finite eigenvalue {value:?} at index {index}",
101                value = val,
102                index = idx
103            )));
104        }
105        scale = scale.max(val.abs());
106    }
107
108    // p * eps captures the rounding floor of a symmetric eigendecomposition of a
109    // p-dimensional matrix; multiplying by `scale` lifts the floor to the actual
110    // magnitude of the spectrum. The constant `C_EPS_P_FACTOR` provides headroom
111    // for the residual rounding in subsequent matmuls.
112    let tolerance =
113        (C_EPS_P_FACTOR * f64::EPSILON * (p.max(1) as f64) * scale).max(f64::MIN_POSITIVE);
114
115    for (idx, val) in eigenvalues.iter_mut().enumerate() {
116        if val.abs() <= tolerance {
117            *val = 0.0;
118        } else if *val < 0.0 {
119            return Err(BasisError::Other(format!(
120                "Penalty spectrum check failed in '{context}': indefinite eigenvalue {value:.3e} at index {index} (tolerance {tolerance:.3e}, scale {scale:.3e})",
121                value = *val,
122                index = idx
123            )));
124        }
125    }
126    Ok(())
127}
128
129fn robust_eighwith_policy<M, V, E, Validate, Sanitize, EigCall, MapErr>(
130    matrix: &M,
131    context: &str,
132    validate_input: Validate,
133    sanitize: Sanitize,
134    mut eig_call: EigCall,
135    map_error: MapErr,
136) -> Result<(Vec<f64>, V), BasisError>
137where
138    Validate: Fn(&M, &str) -> Result<(), BasisError>,
139    Sanitize: Fn(&M) -> M,
140    EigCall: FnMut(&M) -> Result<(Vec<f64>, V), E>,
141    MapErr: Fn(E, &str) -> BasisError,
142{
143    validate_input(matrix, context)?;
144
145    // The sanitize step only enforces exact symmetry by averaging M and M^T and
146    // zeros sub-eps noise; it never adds a diagonal ridge. Adding ridge changes
147    // the matrix being decomposed, which silently changes the optimisation
148    // objective downstream. If eigh genuinely fails on a finite symmetric input,
149    // surface the error instead of mutating the spectrum.
150    let candidate = sanitize(matrix);
151    match eig_call(&candidate) {
152        Ok((mut eigenvalues, eigenvectors)) => {
153            classify_eigenvalues_strict(&mut eigenvalues, context)?;
154            Ok((eigenvalues, eigenvectors))
155        }
156        Err(err) => Err(map_error(err, context)),
157    }
158}
159
160fn robust_eigh_faer(
161    matrix: &Mat<f64>,
162    side: Side,
163    context: &str,
164) -> Result<(Vec<f64>, Mat<f64>), BasisError> {
165    robust_eighwith_policy(
166        matrix,
167        context,
168        |mat, ctx| {
169            let (rows, cols) = mat.as_ref().shape();
170            for i in 0..rows {
171                for j in 0..cols {
172                    let val = mat[(i, j)];
173                    if !val.is_finite() {
174                        let max_abs = mat_max_abs_element(mat.as_ref());
175                        return Err(BasisError::Other(format!(
176                            "{} contains non-finite entries (max finite magnitude {:.3e})",
177                            ctx, max_abs
178                        )));
179                    }
180                }
181            }
182            Ok(())
183        },
184        sanitize_symmetric_faer,
185        |candidate| {
186            let eig = candidate.as_ref().self_adjoint_eigen(side)?;
187            let diag = eig.S();
188            let mut eigenvalues = Vec::with_capacity(diag.dim());
189            for idx in 0..diag.dim() {
190                eigenvalues.push(diag[idx]);
191            }
192
193            let vectors_ref = eig.U();
194            let mut eigenvectors = Mat::<f64>::zeros(vectors_ref.nrows(), vectors_ref.ncols());
195            for i in 0..vectors_ref.nrows() {
196                for j in 0..vectors_ref.ncols() {
197                    eigenvectors[(i, j)] = vectors_ref[(i, j)];
198                }
199            }
200            Ok((eigenvalues, eigenvectors))
201        },
202        |err, _ctx| {
203            BasisError::Other(format!(
204                "Eigendecomposition failed: {}",
205                FaerLinalgError::SelfAdjointEigen(err)
206            ))
207        },
208    )
209}
210
211fn robust_eigh(
212    matrix: &Array2<f64>,
213    side: Side,
214    context: &str,
215) -> Result<(Array1<f64>, Array2<f64>), BasisError> {
216    let matrix_faer = array_to_faer(matrix);
217    let (eigenvalues, eigenvectors) = robust_eigh_faer(&matrix_faer, side, context)?;
218    Ok((Array1::from_vec(eigenvalues), mat_to_array(&eigenvectors)))
219}
220
221fn kronecker_marginal_eigensystems(
222    marginal_penalties: &[Array2<f64>],
223    context: &str,
224) -> Result<Vec<(Array1<f64>, Array2<f64>)>, BasisError> {
225    let mut eigensystems = Vec::with_capacity(marginal_penalties.len());
226    for (k, penalty) in marginal_penalties.iter().enumerate() {
227        eigensystems.push(robust_eigh(
228            penalty,
229            Side::Lower,
230            &format!("{context} marginal {k}"),
231        )?);
232    }
233    Ok(eigensystems)
234}
235
236/// Computes the Kronecker product A ⊗ B for penalty matrix construction.
237/// This is used to create tensor product penalties that enforce smoothness
238/// in multiple dimensions for interaction terms.
239pub fn kronecker_product(a: &Array2<f64>, b: &Array2<f64>) -> Array2<f64> {
240    let (arows, a_cols) = a.dim();
241    let (brows, b_cols) = b.dim();
242    if arows == 0 || a_cols == 0 || brows == 0 || b_cols == 0 {
243        return Array2::zeros((arows * brows, a_cols * b_cols));
244    }
245    let mut result = Array2::zeros((arows * brows, a_cols * b_cols));
246
247    result
248        .axis_chunks_iter_mut(Axis(0), brows)
249        .into_par_iter()
250        .enumerate()
251        .for_each(|(i, mut row_block)| {
252            let arow = a.row(i);
253            let col_chunks = row_block.axis_chunks_iter_mut(Axis(1), b_cols);
254            for (j, mut block) in col_chunks.into_iter().enumerate() {
255                let aval = arow[j];
256                if aval == 0.0 {
257                    continue;
258                }
259                for (dest, &src) in block.iter_mut().zip(b.iter()) {
260                    *dest = aval * src;
261                }
262            }
263        });
264
265    result
266}
267
/// Steps a row-major multi-index one position through the grid described by `dims`.
///
/// The last coordinate varies fastest. When a coordinate reaches its extent,
/// it wraps to zero and the increment carries into the next-slower coordinate.
///
/// Returns `true` once the carry runs out of the slowest coordinate. At that
/// point the grid is exhausted and every coordinate covered by `dims` is back
/// at zero. An empty `dims` is reported as exhausted immediately.
#[inline]
fn kronecker_multi_index_advance(multi_idx: &mut [usize], dims: &[usize]) -> bool {
    for axis in (0..dims.len()).rev() {
        multi_idx[axis] += 1;
        if multi_idx[axis] < dims[axis] {
            // No carry out of this coordinate: the step is complete.
            return false;
        }
        multi_idx[axis] = 0;
    }
    true
}
285
/// λ-invariant Kronecker tensor structure: everything in a tensor-product fit
/// that depends ONLY on the marginal designs/penalties (which are fixed for the
/// whole fit) and NOT on the smoothing parameters λ = exp(ρ).
///
/// The marginal eigendecompositions (`O(Σ q_k³)`), the reparameterized
/// marginals `B_k · U_k`, and the balanced-penalty shrinkage scale `max_bal`
/// are all functions of the fixed marginal data alone. Computing them once per
/// fit with [`KroneckerInvariantStructure::compute`] lets every outer REML
/// iterate (50+ per fit on the #1082 tensor cases) skip the repeated `eigh()`
/// calls and `B_k U_k` GEMMs. Only the cheap `kronecker_logdet_and_derivatives`
/// λ-grid sweep is redone per iterate.
///
/// When built by `compute`, the three vectors are indexed by marginal `k` in
/// the order the marginals were supplied. Because they are `Arc`-wrapped, the
/// derived `Clone` shares the arrays instead of copying them.
#[derive(Clone, Debug)]
pub struct KroneckerInvariantStructure {
    /// Marginal eigenvalues from each marginal penalty eigendecomposition.
    /// As produced by `compute`, these have passed the strict spectrum check,
    /// so roundoff-sized values are exactly `0.0` and none are negative.
    ///
    /// `Arc`-shared so handing this structure to the per-iterate memoized
    /// engine is an O(1) refcount bump, not a deep array copy.
    pub marginal_eigenvalues: Arc<Vec<Array1<f64>>>,
    /// Marginal eigenvector matrices U_k, returned by the same decomposition
    /// that produced `marginal_eigenvalues[k]`.
    pub marginal_qs: Arc<Vec<Array2<f64>>>,
    /// Reparameterized marginal designs: `B_k · U_k` for each marginal k.
    pub reparameterized_marginals: Arc<Vec<Array2<f64>>>,
    /// Max balanced-penalty eigenvalue scale `max_k-grid Σ_k μ_{k,j_k}/||S_k||_F`
    /// (each Frobenius norm floored at 1e-12 by `compute`), used to form the
    /// shrinkage ridge `floor * max_bal`. λ-independent.
    pub max_balanced_eigenvalue: f64,
}
311
312impl KroneckerInvariantStructure {
313    /// Compute the λ-invariant tensor structure once from the fixed marginal data.
314    pub fn compute(
315        marginal_designs: &[Array2<f64>],
316        marginal_penalties: &[Array2<f64>],
317        marginal_dims: &[usize],
318    ) -> Result<Self, BasisError> {
319        let d = marginal_dims.len();
320        // Eigendecompose each marginal penalty once through the same robust path
321        // used by KroneckerPenaltySystem so every Kronecker caller sees the same
322        // eigensystem and pseudo-logdet surface.
323        let mut marginal_eigenvalues = Vec::with_capacity(d);
324        let mut marginal_qs = Vec::with_capacity(d);
325        for (evals, evecs) in kronecker_marginal_eigensystems(
326            marginal_penalties,
327            "kronecker_reparameterization_engine",
328        )? {
329            marginal_eigenvalues.push(evals);
330            marginal_qs.push(evecs);
331        }
332
333        // Reparameterized marginals: B_k · U_k.
334        let reparameterized_marginals: Vec<Array2<f64>> = marginal_designs
335            .iter()
336            .zip(marginal_qs.iter())
337            .map(|(b_k, u_k)| gam_linalg::faer_ndarray::fast_ab(b_k, u_k))
338            .collect();
339
340        // Max balanced eigenvalue: for Kronecker, the balanced penalty's max
341        // eigenvalue is the max over multi-indices of Σ_k (1/||S_k||_F) μ_{k,j_k}.
342        let mut max_balanced_eigenvalue = 0.0_f64;
343        let mut multi_idx = vec![0usize; d];
344        let frob_norms: Vec<f64> = marginal_penalties
345            .iter()
346            .map(|s| s.iter().map(|v| v * v).sum::<f64>().sqrt().max(1e-12))
347            .collect();
348        loop {
349            let mut sigma = 0.0;
350            for k in 0..d {
351                sigma += marginal_eigenvalues[k][multi_idx[k]] / frob_norms[k];
352            }
353            max_balanced_eigenvalue = max_balanced_eigenvalue.max(sigma);
354
355            if kronecker_multi_index_advance(&mut multi_idx, marginal_dims) {
356                break;
357            }
358        }
359
360        Ok(Self {
361            marginal_eigenvalues: Arc::new(marginal_eigenvalues),
362            marginal_qs: Arc::new(marginal_qs),
363            reparameterized_marginals: Arc::new(reparameterized_marginals),
364            max_balanced_eigenvalue,
365        })
366    }
367}