1use crate::basis::BasisError;
2use faer::{Mat, MatRef, Side};
3use gam_linalg::faer_ndarray::FaerLinalgError;
4use ndarray::{Array1, Array2, Axis};
5use rayon::prelude::*;
6use std::sync::Arc;
7
8fn array_to_faer(array: &Array2<f64>) -> Mat<f64> {
9 let (rows, cols) = array.dim();
10 Mat::from_fn(rows, cols, |i, j| array[[i, j]])
11}
12
13fn mat_to_array(mat: &Mat<f64>) -> Array2<f64> {
14 let mut out = Array2::<f64>::zeros((mat.nrows(), mat.ncols()));
15 for i in 0..mat.nrows() {
16 for j in 0..mat.ncols() {
17 out[[i, j]] = mat[(i, j)];
18 }
19 }
20 out
21}
22
23fn mat_max_abs_element(matrix: MatRef<'_, f64>) -> f64 {
24 let (rows, cols) = matrix.shape();
25 let mut maxval = 0.0_f64;
26 for i in 0..rows {
27 for j in 0..cols {
28 let val = matrix[(i, j)];
29 if val.is_finite() {
30 maxval = maxval.max(val.abs());
31 }
32 }
33 }
34 maxval
35}
36
37fn sanitize_symmetric_faer(matrix: &Mat<f64>) -> Mat<f64> {
38 let (rows, cols) = matrix.as_ref().shape();
39 assert_eq!(rows, cols, "Matrix must be square for sanitization");
40
41 let mut sanitized = matrix.clone();
42
43 for i in 0..rows {
44 let diag = sanitized[(i, i)];
45 if !diag.is_finite() {
46 sanitized[(i, i)] = 0.0;
47 }
48 for j in (i + 1)..cols {
49 let mut upper = sanitized[(i, j)];
50 let mut lower = sanitized[(j, i)];
51 if !upper.is_finite() {
52 upper = 0.0;
53 }
54 if !lower.is_finite() {
55 lower = 0.0;
56 }
57 let avg = 0.5 * (upper + lower);
58 sanitized[(i, j)] = avg;
59 sanitized[(j, i)] = avg;
60 }
61 }
62
63 let scale = mat_max_abs_element(sanitized.as_ref());
64 let tiny = (scale * 1e-14).max(1e-30);
65 for i in 0..rows {
66 for j in 0..cols {
67 let val = sanitized[(i, j)];
68 if !val.is_finite() {
69 sanitized[(i, j)] = 0.0;
70 } else if val.abs() < tiny {
71 sanitized[(i, j)] = 0.0;
72 }
73 }
74 }
75
76 sanitized
77}
78
79fn classify_eigenvalues_strict(eigenvalues: &mut [f64], context: &str) -> Result<(), BasisError> {
93 const C_EPS_P_FACTOR: f64 = 64.0;
94 let p = eigenvalues.len();
95
96 let mut scale = 0.0_f64;
97 for (idx, &val) in eigenvalues.iter().enumerate() {
98 if !val.is_finite() {
99 return Err(BasisError::Other(format!(
100 "Penalty spectrum check failed in '{context}': non-finite eigenvalue {value:?} at index {index}",
101 value = val,
102 index = idx
103 )));
104 }
105 scale = scale.max(val.abs());
106 }
107
108 let tolerance =
113 (C_EPS_P_FACTOR * f64::EPSILON * (p.max(1) as f64) * scale).max(f64::MIN_POSITIVE);
114
115 for (idx, val) in eigenvalues.iter_mut().enumerate() {
116 if val.abs() <= tolerance {
117 *val = 0.0;
118 } else if *val < 0.0 {
119 return Err(BasisError::Other(format!(
120 "Penalty spectrum check failed in '{context}': indefinite eigenvalue {value:.3e} at index {index} (tolerance {tolerance:.3e}, scale {scale:.3e})",
121 value = *val,
122 index = idx
123 )));
124 }
125 }
126 Ok(())
127}
128
129fn robust_eighwith_policy<M, V, E, Validate, Sanitize, EigCall, MapErr>(
130 matrix: &M,
131 context: &str,
132 validate_input: Validate,
133 sanitize: Sanitize,
134 mut eig_call: EigCall,
135 map_error: MapErr,
136) -> Result<(Vec<f64>, V), BasisError>
137where
138 Validate: Fn(&M, &str) -> Result<(), BasisError>,
139 Sanitize: Fn(&M) -> M,
140 EigCall: FnMut(&M) -> Result<(Vec<f64>, V), E>,
141 MapErr: Fn(E, &str) -> BasisError,
142{
143 validate_input(matrix, context)?;
144
145 let candidate = sanitize(matrix);
151 match eig_call(&candidate) {
152 Ok((mut eigenvalues, eigenvectors)) => {
153 classify_eigenvalues_strict(&mut eigenvalues, context)?;
154 Ok((eigenvalues, eigenvectors))
155 }
156 Err(err) => Err(map_error(err, context)),
157 }
158}
159
160fn robust_eigh_faer(
161 matrix: &Mat<f64>,
162 side: Side,
163 context: &str,
164) -> Result<(Vec<f64>, Mat<f64>), BasisError> {
165 robust_eighwith_policy(
166 matrix,
167 context,
168 |mat, ctx| {
169 let (rows, cols) = mat.as_ref().shape();
170 for i in 0..rows {
171 for j in 0..cols {
172 let val = mat[(i, j)];
173 if !val.is_finite() {
174 let max_abs = mat_max_abs_element(mat.as_ref());
175 return Err(BasisError::Other(format!(
176 "{} contains non-finite entries (max finite magnitude {:.3e})",
177 ctx, max_abs
178 )));
179 }
180 }
181 }
182 Ok(())
183 },
184 sanitize_symmetric_faer,
185 |candidate| {
186 let eig = candidate.as_ref().self_adjoint_eigen(side)?;
187 let diag = eig.S();
188 let mut eigenvalues = Vec::with_capacity(diag.dim());
189 for idx in 0..diag.dim() {
190 eigenvalues.push(diag[idx]);
191 }
192
193 let vectors_ref = eig.U();
194 let mut eigenvectors = Mat::<f64>::zeros(vectors_ref.nrows(), vectors_ref.ncols());
195 for i in 0..vectors_ref.nrows() {
196 for j in 0..vectors_ref.ncols() {
197 eigenvectors[(i, j)] = vectors_ref[(i, j)];
198 }
199 }
200 Ok((eigenvalues, eigenvectors))
201 },
202 |err, _ctx| {
203 BasisError::Other(format!(
204 "Eigendecomposition failed: {}",
205 FaerLinalgError::SelfAdjointEigen(err)
206 ))
207 },
208 )
209}
210
211fn robust_eigh(
212 matrix: &Array2<f64>,
213 side: Side,
214 context: &str,
215) -> Result<(Array1<f64>, Array2<f64>), BasisError> {
216 let matrix_faer = array_to_faer(matrix);
217 let (eigenvalues, eigenvectors) = robust_eigh_faer(&matrix_faer, side, context)?;
218 Ok((Array1::from_vec(eigenvalues), mat_to_array(&eigenvectors)))
219}
220
221fn kronecker_marginal_eigensystems(
222 marginal_penalties: &[Array2<f64>],
223 context: &str,
224) -> Result<Vec<(Array1<f64>, Array2<f64>)>, BasisError> {
225 let mut eigensystems = Vec::with_capacity(marginal_penalties.len());
226 for (k, penalty) in marginal_penalties.iter().enumerate() {
227 eigensystems.push(robust_eigh(
228 penalty,
229 Side::Lower,
230 &format!("{context} marginal {k}"),
231 )?);
232 }
233 Ok(eigensystems)
234}
235
236pub fn kronecker_product(a: &Array2<f64>, b: &Array2<f64>) -> Array2<f64> {
240 let (arows, a_cols) = a.dim();
241 let (brows, b_cols) = b.dim();
242 if arows == 0 || a_cols == 0 || brows == 0 || b_cols == 0 {
243 return Array2::zeros((arows * brows, a_cols * b_cols));
244 }
245 let mut result = Array2::zeros((arows * brows, a_cols * b_cols));
246
247 result
248 .axis_chunks_iter_mut(Axis(0), brows)
249 .into_par_iter()
250 .enumerate()
251 .for_each(|(i, mut row_block)| {
252 let arow = a.row(i);
253 let col_chunks = row_block.axis_chunks_iter_mut(Axis(1), b_cols);
254 for (j, mut block) in col_chunks.into_iter().enumerate() {
255 let aval = arow[j];
256 if aval == 0.0 {
257 continue;
258 }
259 for (dest, &src) in block.iter_mut().zip(b.iter()) {
260 *dest = aval * src;
261 }
262 }
263 });
264
265 result
266}
267
#[inline]
/// Advances `multi_idx` to the next tensor-product index in row-major order.
///
/// The last axis varies fastest. Each axis `k` counts modulo `dims[k]`. Returns `true` when
/// every axis has wrapped back to zero, meaning the enumeration is finished; otherwise it
/// returns `false`. An empty `dims` returns `true` immediately.
fn kronecker_multi_index_advance(multi_idx: &mut [usize], dims: &[usize]) -> bool {
    for axis in (0..dims.len()).rev() {
        let next = multi_idx[axis] + 1;
        if next < dims[axis] {
            multi_idx[axis] = next;
            return false;
        }
        // This axis overflowed: reset it and carry into the next-slower axis.
        multi_idx[axis] = 0;
    }
    true
}
285
286#[derive(Clone, Debug)]
297pub struct KroneckerInvariantStructure {
298 pub marginal_eigenvalues: Arc<Vec<Array1<f64>>>,
303 pub marginal_qs: Arc<Vec<Array2<f64>>>,
305 pub reparameterized_marginals: Arc<Vec<Array2<f64>>>,
307 pub max_balanced_eigenvalue: f64,
310}
311
312impl KroneckerInvariantStructure {
313 pub fn compute(
315 marginal_designs: &[Array2<f64>],
316 marginal_penalties: &[Array2<f64>],
317 marginal_dims: &[usize],
318 ) -> Result<Self, BasisError> {
319 let d = marginal_dims.len();
320 let mut marginal_eigenvalues = Vec::with_capacity(d);
324 let mut marginal_qs = Vec::with_capacity(d);
325 for (evals, evecs) in kronecker_marginal_eigensystems(
326 marginal_penalties,
327 "kronecker_reparameterization_engine",
328 )? {
329 marginal_eigenvalues.push(evals);
330 marginal_qs.push(evecs);
331 }
332
333 let reparameterized_marginals: Vec<Array2<f64>> = marginal_designs
335 .iter()
336 .zip(marginal_qs.iter())
337 .map(|(b_k, u_k)| gam_linalg::faer_ndarray::fast_ab(b_k, u_k))
338 .collect();
339
340 let mut max_balanced_eigenvalue = 0.0_f64;
343 let mut multi_idx = vec![0usize; d];
344 let frob_norms: Vec<f64> = marginal_penalties
345 .iter()
346 .map(|s| s.iter().map(|v| v * v).sum::<f64>().sqrt().max(1e-12))
347 .collect();
348 loop {
349 let mut sigma = 0.0;
350 for k in 0..d {
351 sigma += marginal_eigenvalues[k][multi_idx[k]] / frob_norms[k];
352 }
353 max_balanced_eigenvalue = max_balanced_eigenvalue.max(sigma);
354
355 if kronecker_multi_index_advance(&mut multi_idx, marginal_dims) {
356 break;
357 }
358 }
359
360 Ok(Self {
361 marginal_eigenvalues: Arc::new(marginal_eigenvalues),
362 marginal_qs: Arc::new(marginal_qs),
363 reparameterized_marginals: Arc::new(reparameterized_marginals),
364 max_balanced_eigenvalue,
365 })
366 }
367}