1use ndarray::{Array2, ArrayView1};
42
/// A single dense penalty matrix plus the metadata carried alongside it.
///
/// The struct itself enforces nothing; call [`JointPenaltySpec::validate`] to
/// check that the matrix is square, finite and symmetric, that
/// `initial_log_lambda` is finite, and that `nullspace_dim` does not exceed
/// the matrix dimension.
#[derive(Debug, Clone)]
pub struct JointPenaltySpec {
    /// Optional human-readable name for this penalty.
    pub label: Option<String>,
    /// Dense penalty matrix `S`. `validate` requires it to be square, to have
    /// only finite entries, and to satisfy `|S[i,j] - S[j,i]| <= 1e-10`.
    pub matrix: Array2<f64>,
    /// Log of the penalty weight to start from; `validate` requires it to be
    /// finite.
    // NOTE(review): nothing in this file reads this field except `validate`;
    // presumably it seeds the log-lambda optimisation — confirm with callers.
    pub initial_log_lambda: f64,
    /// Declared dimension of the nullspace of `matrix`. `validate` only checks
    /// that it does not exceed the row count; it is not compared against the
    /// actual spectrum. `pseudo_rank` is computed as `dim - nullspace_dim`.
    pub nullspace_dim: usize,
}
65
/// Reasons a joint penalty specification can be rejected.
///
/// Each variant carries the offending values so the `Display` message can
/// point at the exact location of the problem.
#[derive(Debug, Clone, PartialEq)]
pub enum JointPenaltyError {
    /// The penalty matrix has a different number of rows and columns.
    NotSquare {
        /// Row count of the rejected matrix.
        nrows: usize,
        /// Column count of the rejected matrix.
        ncols: usize,
    },
    /// An entry of the penalty matrix is NaN or infinite.
    NonFiniteEntry {
        /// Row index of the entry.
        row: usize,
        /// Column index of the entry.
        col: usize,
        /// The non-finite value that was found.
        value: f64,
    },
    /// The initial log-lambda is NaN or infinite.
    NonFiniteInitialLogLambda {
        /// The non-finite value that was found.
        value: f64,
    },
    /// Two mirrored entries differ by more than the accepted tolerance.
    NotSymmetric {
        /// Row index of the upper-triangle entry.
        row: usize,
        /// Column index of the upper-triangle entry.
        col: usize,
        /// Absolute difference between the entry and its mirror.
        asymmetry: f64,
    },
    /// The declared nullspace dimension is larger than the matrix dimension.
    NullspaceTooLarge {
        /// Dimension of the matrix.
        total: usize,
        /// Declared nullspace dimension.
        nullspace_dim: usize,
    },
}

impl std::fmt::Display for JointPenaltyError {
    /// Writes a one-line, human-readable description of the error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Every field is `Copy`, so the variants are destructured by value.
        match *self {
            Self::NotSquare { nrows, ncols } => f.write_fmt(format_args!(
                "joint penalty matrix is not square: {nrows}x{ncols}"
            )),
            Self::NonFiniteEntry { row, col, value } => f.write_fmt(format_args!(
                "joint penalty matrix has non-finite entry at ({row},{col}): {value}"
            )),
            Self::NonFiniteInitialLogLambda { value } => f.write_fmt(format_args!(
                "joint penalty initial_log_lambda is non-finite: {value}"
            )),
            Self::NotSymmetric { row, col, asymmetry } => f.write_fmt(format_args!(
                "joint penalty matrix is not symmetric at ({row},{col}): |S - Sᵀ|={asymmetry:.3e}"
            )),
            Self::NullspaceTooLarge { total, nullspace_dim } => f.write_fmt(format_args!(
                "joint penalty nullspace_dim={nullspace_dim} exceeds dim={total}"
            )),
        }
    }
}

impl std::error::Error for JointPenaltyError {}
125
126impl JointPenaltySpec {
127 const SYMMETRY_TOL: f64 = 1e-10;
131
132 #[inline]
134 pub fn dim(&self) -> usize {
135 self.matrix.nrows()
136 }
137
138 pub fn trace(&self) -> f64 {
140 self.matrix.diag().iter().copied().sum()
141 }
142
143 #[inline]
147 pub fn pseudo_rank(&self) -> usize {
148 self.dim().saturating_sub(self.nullspace_dim)
149 }
150
151 pub fn quadratic_form(&self, beta: ArrayView1<'_, f64>) -> f64 {
155 assert_eq!(
156 beta.len(),
157 self.dim(),
158 "joint penalty quadratic form: beta length {} != dim {}",
159 beta.len(),
160 self.dim()
161 );
162 beta.dot(&self.matrix.dot(&beta))
163 }
164
165 pub fn validate(&self) -> Result<(), JointPenaltyError> {
167 let (nrows, ncols) = self.matrix.dim();
168 if nrows != ncols {
169 return Err(JointPenaltyError::NotSquare { nrows, ncols });
170 }
171 if !self.initial_log_lambda.is_finite() {
172 return Err(JointPenaltyError::NonFiniteInitialLogLambda {
173 value: self.initial_log_lambda,
174 });
175 }
176 if self.nullspace_dim > nrows {
177 return Err(JointPenaltyError::NullspaceTooLarge {
178 total: nrows,
179 nullspace_dim: self.nullspace_dim,
180 });
181 }
182 for ((row, col), &value) in self.matrix.indexed_iter() {
183 if !value.is_finite() {
184 return Err(JointPenaltyError::NonFiniteEntry { row, col, value });
185 }
186 }
187 for row in 0..nrows {
188 for col in (row + 1)..ncols {
189 let asymmetry = (self.matrix[[row, col]] - self.matrix[[col, row]]).abs();
190 if asymmetry > Self::SYMMETRY_TOL {
191 return Err(JointPenaltyError::NotSymmetric {
192 row,
193 col,
194 asymmetry,
195 });
196 }
197 }
198 }
199 Ok(())
200 }
201}
202
/// A set of joint penalties paired with the log-weights currently applied to
/// them.
///
/// `specs[i]` is weighted by `exp(log_lambdas[i])` in every operation of this
/// type. `JointPenaltyBundle::new` checks that the two collections have the
/// same length; the fields are public, so a bundle built by hand bypasses
/// that check.
#[derive(Clone, Debug)]
pub struct JointPenaltyBundle {
    /// Penalty specifications, held behind an `Arc` so cloning the bundle does
    /// not copy the matrices.
    pub specs: std::sync::Arc<Vec<JointPenaltySpec>>,
    /// Natural log of the weight for each entry of `specs`, in the same order.
    pub log_lambdas: Vec<f64>,
}
216
217impl JointPenaltyBundle {
218 pub fn new(
221 specs: std::sync::Arc<Vec<JointPenaltySpec>>,
222 log_lambdas: Vec<f64>,
223 total_compiled: usize,
224 ) -> Result<Self, String> {
225 if specs.len() != log_lambdas.len() {
226 return Err(format!(
227 "joint penalty bundle: {} specs vs {} log_lambdas",
228 specs.len(),
229 log_lambdas.len(),
230 ));
231 }
232 for (i, spec) in specs.iter().enumerate() {
233 if spec.dim() != total_compiled {
234 return Err(format!(
235 "joint penalty {i}: dim {} != total_compiled {}",
236 spec.dim(),
237 total_compiled,
238 ));
239 }
240 }
241 Ok(Self { specs, log_lambdas })
242 }
243
244 #[inline]
245 pub fn len(&self) -> usize {
246 self.specs.len()
247 }
248
249 #[inline]
250 pub fn is_empty(&self) -> bool {
251 self.specs.is_empty()
252 }
253
254 pub fn quadratic(&self, beta: ArrayView1<'_, f64>) -> f64 {
257 let mut total = 0.0;
258 for (spec, &log_lambda) in self.specs.iter().zip(self.log_lambdas.iter()) {
259 let lam = log_lambda.exp();
260 total += 0.5 * lam * spec.quadratic_form(beta);
261 }
262 total
263 }
264
265 pub fn add_apply_into(&self, vector: ArrayView1<'_, f64>, out: &mut ndarray::Array1<f64>) {
267 assert_eq!(out.len(), vector.len());
268 for (spec, &log_lambda) in self.specs.iter().zip(self.log_lambdas.iter()) {
269 let lam = log_lambda.exp();
270 let sv = spec.matrix.dot(&vector);
271 out.scaled_add(lam, &sv);
272 }
273 }
274
275 pub fn add_diag(&self, diag: &mut ndarray::Array1<f64>) {
277 for (spec, &log_lambda) in self.specs.iter().zip(self.log_lambdas.iter()) {
278 let lam = log_lambda.exp();
279 for (i, value) in spec.matrix.diag().iter().enumerate() {
280 diag[i] += lam * *value;
281 }
282 }
283 }
284
285 pub fn add_to_matrix(&self, matrix: &mut Array2<f64>) {
287 assert_eq!(matrix.nrows(), matrix.ncols());
288 for (spec, &log_lambda) in self.specs.iter().zip(self.log_lambdas.iter()) {
289 let lam = log_lambda.exp();
290 matrix.scaled_add(lam, &spec.matrix);
291 }
292 }
293
294 pub fn rho_objective_gradient(&self, beta: ArrayView1<'_, f64>, out: &mut [f64]) {
297 assert_eq!(out.len(), self.specs.len());
298 for (i, (spec, &log_lambda)) in self.specs.iter().zip(self.log_lambdas.iter()).enumerate() {
299 let lam = log_lambda.exp();
300 out[i] = 0.5 * lam * spec.quadratic_form(beta);
301 }
302 }
303}
304
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::{Array1, Array2, array};

    /// Builds the 4x4 fixture `S = v vᵀ + w wᵀ` with `v = (1, 0, -1, 0)` and
    /// `w = (0, 1, 0, -1)`. `v` and `w` are orthogonal, so `S` has rank 2 and
    /// the declared `nullspace_dim` of 2 matches it.
    fn cross_block_spec() -> JointPenaltySpec {
        let v: Array1<f64> = array![1.0, 0.0, -1.0, 0.0];
        let w: Array1<f64> = array![0.0, 1.0, 0.0, -1.0];
        let mut matrix: Array2<f64> = Array2::zeros((4, 4));
        for i in 0..4 {
            for j in 0..4 {
                matrix[[i, j]] = v[i] * v[j] + w[i] * w[j];
            }
        }
        JointPenaltySpec {
            label: Some("cross_block_pullback".to_string()),
            matrix,
            initial_log_lambda: -1.5,
            nullspace_dim: 2,
        }
    }

    /// The fixture satisfies every invariant checked by `validate`.
    #[test]
    fn cross_block_dense_validates() {
        let result = cross_block_spec().validate();
        assert!(
            result.is_ok(),
            "valid cross-block spec rejected: {result:?}"
        );
    }

    /// Each diagonal entry of the fixture is `v[i]^2 + w[i]^2 = 1`, so the
    /// trace is 4.
    #[test]
    fn trace_matches_diagonal_sum() {
        let spec = cross_block_spec();
        assert!((spec.trace() - 4.0).abs() < 1e-12);
    }

    /// `pseudo_rank` is `dim - nullspace_dim = 4 - 2`.
    #[test]
    fn pseudo_rank_uses_declared_nullspace() {
        let spec = cross_block_spec();
        assert_eq!(spec.dim(), 4);
        assert_eq!(spec.pseudo_rank(), 2);
    }

    /// For the fixture, `betaᵀ S beta = (v·beta)^2 + (w·beta)^2`
    /// `= (0.5 - 1.0)^2 + (-0.25 - 0.75)^2 = 1.25`.
    #[test]
    fn quadratic_form_matches_explicit_mat_vec() {
        let spec = cross_block_spec();
        let beta: Array1<f64> = array![0.5, -0.25, 1.0, 0.75];
        let q = spec.quadratic_form(beta.view());
        assert!((q - 1.25).abs() < 1e-12, "got {q}");
    }

    /// The number of near-zero eigenvalues of the fixture equals the declared
    /// nullspace dimension, and the product of all eigenvalues (the
    /// determinant) is therefore ~0.
    #[test]
    fn determinant_zero_for_rank_deficient_matches_nullspace() {
        use gam_linalg::faer_ndarray::FaerEigh;
        let spec = cross_block_spec();
        // NOTE(review): presumably `eigh` returns (eigenvalues, eigenvectors);
        // only the first element is used here — confirm against gam_linalg.
        let (eigvals, _) =
            FaerEigh::eigh(&spec.matrix, faer::Side::Lower).expect("symmetric eigh succeeds");
        let mut sorted: Vec<f64> = eigvals.iter().copied().collect();
        sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
        // After the ascending sort the near-zero eigenvalues come first, so
        // counting the leading run is enough.
        let zeros = sorted.iter().take_while(|&&v| v.abs() < 1e-10).count();
        assert_eq!(
            zeros, spec.nullspace_dim,
            "spectrum {sorted:?} should have {} near-zeros",
            spec.nullspace_dim
        );
        let det: f64 = sorted.iter().product();
        assert!(det.abs() < 1e-10, "expected ~0 determinant, got {det}");
    }

    /// A 3x4 matrix is reported as `NotSquare` with its exact shape.
    #[test]
    fn validate_rejects_non_square() {
        let spec = JointPenaltySpec {
            label: None,
            matrix: Array2::zeros((3, 4)),
            initial_log_lambda: 0.0,
            nullspace_dim: 0,
        };
        assert!(matches!(
            spec.validate(),
            Err(JointPenaltyError::NotSquare { nrows: 3, ncols: 4 })
        ));
    }

    /// Mirrored entries `+1` / `-1` differ by 2, far above the tolerance.
    #[test]
    fn validate_rejects_non_symmetric() {
        let mut matrix = Array2::<f64>::zeros((3, 3));
        matrix[[0, 1]] = 1.0;
        matrix[[1, 0]] = -1.0;
        let spec = JointPenaltySpec {
            label: None,
            matrix,
            initial_log_lambda: 0.0,
            nullspace_dim: 0,
        };
        assert!(matches!(
            spec.validate(),
            Err(JointPenaltyError::NotSymmetric { .. })
        ));
    }

    /// A declared nullspace of 4 cannot fit in a 3x3 matrix.
    #[test]
    fn validate_rejects_oversized_nullspace() {
        let spec = JointPenaltySpec {
            label: None,
            matrix: Array2::zeros((3, 3)),
            initial_log_lambda: 0.0,
            nullspace_dim: 4,
        };
        assert!(matches!(
            spec.validate(),
            Err(JointPenaltyError::NullspaceTooLarge {
                total: 3,
                nullspace_dim: 4
            })
        ));
    }

    /// A NaN `initial_log_lambda` is rejected.
    #[test]
    fn validate_rejects_non_finite_initial_log_lambda() {
        let spec = JointPenaltySpec {
            label: None,
            matrix: Array2::zeros((2, 2)),
            initial_log_lambda: f64::NAN,
            nullspace_dim: 0,
        };
        assert!(matches!(
            spec.validate(),
            Err(JointPenaltyError::NonFiniteInitialLogLambda { .. })
        ));
    }

    /// End-to-end check of every bundle operation on a 2x2 toy problem:
    /// minimise `0.5 * |beta - b|^2 + 0.5 * lam * betaᵀ S beta`, whose
    /// minimiser solves `(I + lam * S) beta = b`.
    #[test]
    fn bundle_two_block_minimiser_matches_analytic_solution() {
        use gam_linalg::faer_ndarray::FaerCholesky;
        use ndarray::Array2;

        let spec = JointPenaltySpec {
            label: Some("toy_cross_block".to_string()),
            matrix: array![[2.0_f64, 1.0], [1.0, 2.0]],
            initial_log_lambda: 0.0,
            nullspace_dim: 0,
        };
        let log_lambda = -0.4_f64;
        let lam = log_lambda.exp();
        let bundle = JointPenaltyBundle::new(std::sync::Arc::new(vec![spec]), vec![log_lambda], 2)
            .expect("valid bundle");

        // `add_to_matrix` on the identity must yield I + lam * S.
        let mut lhs = Array2::<f64>::eye(2);
        bundle.add_to_matrix(&mut lhs);
        let expected_lhs = array![[1.0 + lam * 2.0, lam], [lam, 1.0 + lam * 2.0]];
        for r in 0..2 {
            for c in 0..2 {
                assert!(
                    (lhs[[r, c]] - expected_lhs[[r, c]]).abs() < 1e-12,
                    "lhs[{r}, {c}] = {} expected {}",
                    lhs[[r, c]],
                    expected_lhs[[r, c]]
                );
            }
        }

        // Solve (I + lam * S) beta_hat = b through the Cholesky factor.
        let b: Array1<f64> = array![1.0, -0.5];
        let chol = lhs.cholesky(faer::Side::Lower).expect("SPD");
        let mut rhs_mat = Array2::<f64>::zeros((2, 1));
        rhs_mat[[0, 0]] = b[0];
        rhs_mat[[1, 0]] = b[1];
        let mut beta_mat = rhs_mat.clone();
        chol.solve_mat_in_place(&mut beta_mat);
        let beta_hat: Array1<f64> = array![beta_mat[[0, 0]], beta_mat[[1, 0]]];

        // Gradient of the penalised objective, (beta - b) + lam * S * beta,
        // must vanish at the minimiser; `add_apply_into` supplies the second term.
        let mut grad = &beta_hat - &b;
        bundle.add_apply_into(beta_hat.view(), &mut grad);
        let grad_inf = grad.iter().map(|v: &f64| v.abs()).fold(0.0_f64, f64::max);
        assert!(
            grad_inf < 1e-12,
            "penalised gradient at analytic minimiser must vanish: {grad_inf:.3e}"
        );

        // `quadratic` must equal the hand-written penalty term.
        let resid = &beta_hat - &b;
        let unpen = 0.5 * resid.dot(&resid);
        let pen = bundle.quadratic(beta_hat.view());
        let expected_obj = 0.5 * resid.dot(&resid)
            + 0.5 * lam * beta_hat.dot(&array![[2.0, 1.0], [1.0, 2.0]].dot(&beta_hat));
        assert!(
            (unpen + pen - expected_obj).abs() < 1e-12,
            "objective sum {} mismatched expected {}",
            unpen + pen,
            expected_obj
        );

        // `add_diag` on a vector of ones must yield 1 + lam * S[i, i].
        let mut diag = ndarray::Array1::<f64>::from_elem(2, 1.0);
        bundle.add_diag(&mut diag);
        assert!((diag[0] - (1.0 + lam * 2.0)).abs() < 1e-12);
        assert!((diag[1] - (1.0 + lam * 2.0)).abs() < 1e-12);

        // `rho_objective_gradient` must report 0.5 * lam * betaᵀ S beta.
        let mut rho_grad = vec![0.0_f64];
        bundle.rho_objective_gradient(beta_hat.view(), &mut rho_grad);
        let expected_rho_grad =
            0.5 * lam * beta_hat.dot(&array![[2.0, 1.0], [1.0, 2.0]].dot(&beta_hat));
        assert!(
            (rho_grad[0] - expected_rho_grad).abs() < 1e-12,
            "rho-grad {} expected {}",
            rho_grad[0],
            expected_rho_grad
        );
    }

    /// A 3x3 penalty cannot join a bundle compiled for dimension 4.
    #[test]
    fn bundle_rejects_dim_mismatch() {
        let spec = JointPenaltySpec {
            label: None,
            matrix: Array2::<f64>::eye(3),
            initial_log_lambda: 0.0,
            nullspace_dim: 0,
        };
        let err = JointPenaltyBundle::new(std::sync::Arc::new(vec![spec]), vec![0.0], 4)
            .expect_err("dim mismatch must reject");
        assert!(err.contains("total_compiled"));
    }

    /// One spec with zero log-lambdas is a count mismatch.
    #[test]
    fn bundle_rejects_lambda_count_mismatch() {
        let spec = JointPenaltySpec {
            label: None,
            matrix: Array2::<f64>::eye(2),
            initial_log_lambda: 0.0,
            nullspace_dim: 0,
        };
        let err = JointPenaltyBundle::new(std::sync::Arc::new(vec![spec]), vec![], 2)
            .expect_err("count mismatch must reject");
        assert!(err.contains("specs vs"));
    }
}