use ndarray::{Array2, ArrayView1};
/// Specification of one dense quadratic penalty matrix `S`.
///
/// When held in a [`JointPenaltyBundle`], the penalty contributes `½ exp(log_lambda) βᵀ S β`
/// to the penalised objective. Call [`JointPenaltySpec::validate`] to check its invariants.
#[derive(Debug, Clone)]
pub struct JointPenaltySpec {
/// Optional descriptive name; not read by any of the computations in this file.
pub label: Option<String>,
/// Penalty matrix `S`; `validate` requires it to be square, finite and symmetric.
pub matrix: Array2<f64>,
/// Initial log smoothing parameter; `validate` requires it to be finite.
pub initial_log_lambda: f64,
/// Declared dimension of the nullspace of `matrix`. It must not exceed the matrix
/// dimension and is used by `pseudo_rank`; it is NOT checked against the actual spectrum.
pub nullspace_dim: usize,
}
/// Reasons [`JointPenaltySpec::validate`] can reject a penalty specification.
#[derive(Debug, Clone, PartialEq)]
pub enum JointPenaltyError {
/// `matrix` has a different number of rows and columns.
NotSquare {
nrows: usize,
ncols: usize,
},
/// `matrix[[row, col]]` is NaN or infinite (the first such entry in row-major order).
NonFiniteEntry {
row: usize,
col: usize,
value: f64,
},
/// `initial_log_lambda` is NaN or infinite.
NonFiniteInitialLogLambda {
value: f64,
},
/// `|matrix[[row, col]] - matrix[[col, row]]|` exceeds the symmetry tolerance;
/// `(row, col)` is the first offending upper-triangle pair (`col > row`).
NotSymmetric {
row: usize,
col: usize,
asymmetry: f64,
},
/// `nullspace_dim` is larger than the matrix dimension `total`.
NullspaceTooLarge {
total: usize,
nullspace_dim: usize,
},
}
impl std::fmt::Display for JointPenaltyError {
    /// Renders a one-line, human-readable description of the validation failure.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Every payload field is `Copy`, so variants are destructured by value.
        match *self {
            Self::NotSquare { nrows, ncols } => {
                write!(f, "joint penalty matrix is not square: {nrows}x{ncols}")
            }
            Self::NonFiniteEntry { row, col, value } => {
                write!(
                    f,
                    "joint penalty matrix has non-finite entry at ({row},{col}): {value}"
                )
            }
            Self::NonFiniteInitialLogLambda { value } => {
                write!(f, "joint penalty initial_log_lambda is non-finite: {value}")
            }
            Self::NotSymmetric { row, col, asymmetry } => {
                write!(
                    f,
                    "joint penalty matrix is not symmetric at ({row},{col}): |S - Sᵀ|={asymmetry:.3e}"
                )
            }
            Self::NullspaceTooLarge { total, nullspace_dim } => {
                write!(
                    f,
                    "joint penalty nullspace_dim={nullspace_dim} exceeds dim={total}"
                )
            }
        }
    }
}

/// Validation errors have no underlying cause, so the default `source()` (`None`) applies.
impl std::error::Error for JointPenaltyError {}
impl JointPenaltySpec {
    /// Base tolerance for the pairwise symmetry check in [`JointPenaltySpec::validate`].
    ///
    /// The effective tolerance is `SYMMETRY_TOL * max(1, max_ij |S_ij|)`. It is absolute for
    /// matrices whose entries are of order one or smaller. It is relative for large-magnitude
    /// matrices, whose floating-point assembly error grows with the entry scale; a purely
    /// absolute threshold would reject them even when they are symmetric up to round-off.
    const SYMMETRY_TOL: f64 = 1e-10;

    /// Number of coefficients the penalty acts on, i.e. the row count of `matrix`.
    #[inline]
    pub fn dim(&self) -> usize {
        self.matrix.nrows()
    }

    /// Sum of the diagonal entries of `matrix`.
    pub fn trace(&self) -> f64 {
        self.matrix.diag().iter().copied().sum()
    }

    /// Rank implied by the declared nullspace: `dim() - nullspace_dim`, saturating at zero.
    ///
    /// This trusts `nullspace_dim`; the spectrum of `matrix` is not inspected.
    #[inline]
    pub fn pseudo_rank(&self) -> usize {
        self.dim().saturating_sub(self.nullspace_dim)
    }

    /// Evaluates the quadratic form `βᵀ S β` with `S = matrix`.
    ///
    /// # Panics
    /// Panics if `beta.len() != self.dim()`.
    pub fn quadratic_form(&self, beta: ArrayView1<'_, f64>) -> f64 {
        assert_eq!(
            beta.len(),
            self.dim(),
            "joint penalty quadratic form: beta length {} != dim {}",
            beta.len(),
            self.dim()
        );
        beta.dot(&self.matrix.dot(&beta))
    }

    /// Checks the invariants that the penalty operations rely on.
    ///
    /// The checks run in this order and the first failure is returned:
    /// 1. `matrix` is square.
    /// 2. `initial_log_lambda` is finite.
    /// 3. `nullspace_dim <= dim()`.
    /// 4. Every entry of `matrix` is finite.
    /// 5. `matrix` is symmetric to within `SYMMETRY_TOL * max(1, max_ij |S_ij|)`. The first
    ///    offending upper-triangle pair, in row-major order, is reported.
    ///
    /// # Errors
    /// Returns the [`JointPenaltyError`] variant describing the first violated invariant.
    pub fn validate(&self) -> Result<(), JointPenaltyError> {
        let (nrows, ncols) = self.matrix.dim();
        if nrows != ncols {
            return Err(JointPenaltyError::NotSquare { nrows, ncols });
        }
        if !self.initial_log_lambda.is_finite() {
            return Err(JointPenaltyError::NonFiniteInitialLogLambda {
                value: self.initial_log_lambda,
            });
        }
        if self.nullspace_dim > nrows {
            return Err(JointPenaltyError::NullspaceTooLarge {
                total: nrows,
                nullspace_dim: self.nullspace_dim,
            });
        }
        // Reject NaN/inf before the symmetry pass: a NaN difference compares false against
        // the tolerance and would slip through. Record the entry scale on the way.
        let mut max_abs = 0.0_f64;
        for ((row, col), &value) in self.matrix.indexed_iter() {
            if !value.is_finite() {
                return Err(JointPenaltyError::NonFiniteEntry { row, col, value });
            }
            max_abs = max_abs.max(value.abs());
        }
        // Never stricter than the plain absolute tolerance; grows with the entry scale.
        let tolerance = Self::SYMMETRY_TOL * max_abs.max(1.0);
        for row in 0..nrows {
            for col in (row + 1)..ncols {
                let asymmetry = (self.matrix[[row, col]] - self.matrix[[col, row]]).abs();
                if asymmetry > tolerance {
                    return Err(JointPenaltyError::NotSymmetric {
                        row,
                        col,
                        asymmetry,
                    });
                }
            }
        }
        Ok(())
    }
}
/// A set of joint penalties paired with their current log smoothing parameters.
///
/// Every operation weights `specs[k]` by `λ_k = exp(log_lambdas[k])`.
/// [`JointPenaltyBundle::new`] checks that the two have equal length and that each spec's
/// `dim()` equals the total compiled coefficient count.
#[derive(Clone, Debug)]
pub struct JointPenaltyBundle {
/// Penalty specifications, shared via `Arc` so that cloning the bundle does not copy them.
pub specs: std::sync::Arc<Vec<JointPenaltySpec>>,
/// Log smoothing parameters, one per entry of `specs`.
pub log_lambdas: Vec<f64>,
}
impl JointPenaltyBundle {
    /// Pairs `specs` with their log smoothing parameters after checking shapes.
    ///
    /// # Errors
    /// Returns a descriptive message if `specs.len() != log_lambdas.len()`, or if any spec's
    /// matrix is not `total_compiled x total_compiled`.
    pub fn new(
        specs: std::sync::Arc<Vec<JointPenaltySpec>>,
        log_lambdas: Vec<f64>,
        total_compiled: usize,
    ) -> Result<Self, String> {
        if specs.len() != log_lambdas.len() {
            return Err(format!(
                "joint penalty bundle: {} specs vs {} log_lambdas",
                specs.len(),
                log_lambdas.len(),
            ));
        }
        for (i, spec) in specs.iter().enumerate() {
            if spec.dim() != total_compiled {
                return Err(format!(
                    "joint penalty {i}: dim {} != total_compiled {}",
                    spec.dim(),
                    total_compiled,
                ));
            }
            // `dim()` is only the row count. A matrix with the wrong column count would
            // otherwise reach `scaled_add`, which silently broadcasts mismatched shapes.
            let ncols = spec.matrix.ncols();
            if ncols != total_compiled {
                return Err(format!(
                    "joint penalty {i}: {ncols} columns != total_compiled {total_compiled}"
                ));
            }
        }
        Ok(Self { specs, log_lambdas })
    }

    /// Number of penalties in the bundle.
    #[inline]
    pub fn len(&self) -> usize {
        self.specs.len()
    }

    /// Whether the bundle holds no penalties.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.specs.is_empty()
    }

    /// Iterates `(spec, λ)` pairs with `λ = exp(log_lambda)`.
    ///
    /// # Panics
    /// Panics if `specs` and `log_lambdas` have different lengths. Both fields are public and
    /// can diverge after construction; a plain `zip` would silently drop penalties.
    fn weighted_specs(&self) -> impl Iterator<Item = (&JointPenaltySpec, f64)> + '_ {
        assert_eq!(
            self.specs.len(),
            self.log_lambdas.len(),
            "joint penalty bundle: {} specs vs {} log_lambdas",
            self.specs.len(),
            self.log_lambdas.len(),
        );
        self.specs
            .iter()
            .zip(self.log_lambdas.iter())
            .map(|(spec, &log_lambda)| (spec, log_lambda.exp()))
    }

    /// Total penalty `Σ_k ½ λ_k βᵀ S_k β`.
    ///
    /// # Panics
    /// Panics if `beta.len()` differs from a spec's dimension, or if `specs` and
    /// `log_lambdas` have different lengths.
    pub fn quadratic(&self, beta: ArrayView1<'_, f64>) -> f64 {
        self.weighted_specs().fold(0.0, |total, (spec, lam)| {
            total + 0.5 * lam * spec.quadratic_form(beta)
        })
    }

    /// Accumulates `out += Σ_k λ_k S_k · vector` in place.
    ///
    /// # Panics
    /// Panics if `out.len() != vector.len()`, if a spec's shape is incompatible with
    /// `vector`/`out`, or if `specs` and `log_lambdas` have different lengths.
    pub fn add_apply_into(&self, vector: ArrayView1<'_, f64>, out: &mut ndarray::Array1<f64>) {
        assert_eq!(out.len(), vector.len());
        for (spec, lam) in self.weighted_specs() {
            // out ← lam·S·vector + 1·out. This accumulates without a per-spec temporary and,
            // unlike `scaled_add`, panics on a shape mismatch instead of broadcasting.
            ndarray::linalg::general_mat_vec_mul(lam, &spec.matrix, &vector, 1.0, out);
        }
    }

    /// Adds `Σ_k λ_k diag(S_k)` to `diag` in place.
    ///
    /// # Panics
    /// Panics if `diag.len()` differs from a spec's diagonal length, or if `specs` and
    /// `log_lambdas` have different lengths.
    pub fn add_diag(&self, diag: &mut ndarray::Array1<f64>) {
        for (spec, lam) in self.weighted_specs() {
            let spec_diag = spec.matrix.diag();
            // Without this check, a longer `diag` would be updated only partially.
            assert_eq!(
                diag.len(),
                spec_diag.len(),
                "joint penalty add_diag: diag length {} != penalty diagonal length {}",
                diag.len(),
                spec_diag.len()
            );
            for (slot, &value) in diag.iter_mut().zip(spec_diag.iter()) {
                *slot += lam * value;
            }
        }
    }

    /// Adds `Σ_k λ_k S_k` to the square `matrix` in place.
    ///
    /// # Panics
    /// Panics if `matrix` is not square, if its shape differs from a spec's matrix, or if
    /// `specs` and `log_lambdas` have different lengths.
    pub fn add_to_matrix(&self, matrix: &mut Array2<f64>) {
        assert_eq!(matrix.nrows(), matrix.ncols());
        for (spec, lam) in self.weighted_specs() {
            // `scaled_add` broadcasts a smaller right-hand side (e.g. 1x1 or n x 1), which
            // would silently add the penalty to the wrong entries.
            assert_eq!(
                matrix.dim(),
                spec.matrix.dim(),
                "joint penalty add_to_matrix: target shape {:?} != penalty shape {:?}",
                matrix.dim(),
                spec.matrix.dim()
            );
            matrix.scaled_add(lam, &spec.matrix);
        }
    }

    /// Writes `½ λ_k βᵀ S_k β` into `out[k]`. This is the partial derivative of
    /// [`JointPenaltyBundle::quadratic`] with respect to `log_lambdas[k]` at fixed `beta`.
    ///
    /// # Panics
    /// Panics if `out.len() != self.len()`, if `beta.len()` differs from a spec's dimension,
    /// or if `specs` and `log_lambdas` have different lengths.
    pub fn rho_objective_gradient(&self, beta: ArrayView1<'_, f64>, out: &mut [f64]) {
        assert_eq!(out.len(), self.specs.len());
        for (slot, (spec, lam)) in out.iter_mut().zip(self.weighted_specs()) {
            *slot = 0.5 * lam * spec.quadratic_form(beta);
        }
    }
}
// Unit tests for penalty validation, the spec's linear-algebra helpers, and the bundle's
// weighted accumulation routines.
#[cfg(test)]
mod tests {
use super::*;
use ndarray::{Array1, Array2, array};
// Builds S = v vᵀ + w wᵀ with v = (1, 0, -1, 0) and w = (0, 1, 0, -1). These are two
// orthogonal vectors, so S is a 4x4 symmetric PSD matrix of rank 2, hence nullspace_dim = 2.
fn cross_block_spec() -> JointPenaltySpec {
let v: Array1<f64> = array![1.0, 0.0, -1.0, 0.0];
let w: Array1<f64> = array![0.0, 1.0, 0.0, -1.0];
let mut matrix: Array2<f64> = Array2::zeros((4, 4));
for i in 0..4 {
for j in 0..4 {
matrix[[i, j]] = v[i] * v[j] + w[i] * w[j];
}
}
JointPenaltySpec {
label: Some("cross_block_pullback".to_string()),
matrix,
initial_log_lambda: -1.5,
nullspace_dim: 2,
}
}
#[test]
fn cross_block_dense_validates() {
let result = cross_block_spec().validate();
assert!(
result.is_ok(),
"valid cross-block spec rejected: {result:?}"
);
}
// trace(S) = |v|² + |w|² = 2 + 2 = 4.
#[test]
fn trace_matches_diagonal_sum() {
let spec = cross_block_spec();
assert!((spec.trace() - 4.0).abs() < 1e-12);
}
#[test]
fn pseudo_rank_uses_declared_nullspace() {
let spec = cross_block_spec();
assert_eq!(spec.dim(), 4);
assert_eq!(spec.pseudo_rank(), 2);
}
// βᵀSβ = (vᵀβ)² + (wᵀβ)² = (-0.5)² + (-1.0)² = 1.25.
#[test]
fn quadratic_form_matches_explicit_mat_vec() {
let spec = cross_block_spec();
let beta: Array1<f64> = array![0.5, -0.25, 1.0, 0.75];
let q = spec.quadratic_form(beta.view());
assert!((q - 1.25).abs() < 1e-12, "got {q}");
}
// Cross-checks the declared nullspace_dim against the spectrum computed by the crate's
// faer-backed symmetric eigensolver: the number of near-zero eigenvalues must match.
#[test]
fn determinant_zero_for_rank_deficient_matches_nullspace() {
use crate::faer_ndarray::FaerEigh;
let spec = cross_block_spec();
let (eigvals, _) =
FaerEigh::eigh(&spec.matrix, faer::Side::Lower).expect("symmetric eigh succeeds");
let mut sorted: Vec<f64> = eigvals.iter().copied().collect();
sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
// Ascending order, so the near-zero eigenvalues of this PSD matrix come first.
let zeros = sorted.iter().take_while(|&&v| v.abs() < 1e-10).count();
assert_eq!(
zeros, spec.nullspace_dim,
"spectrum {sorted:?} should have {} near-zeros",
spec.nullspace_dim
);
let det: f64 = sorted.iter().product();
assert!(det.abs() < 1e-10, "expected ~0 determinant, got {det}");
}
#[test]
fn validate_rejects_non_square() {
let spec = JointPenaltySpec {
label: None,
matrix: Array2::zeros((3, 4)),
initial_log_lambda: 0.0,
nullspace_dim: 0,
};
assert!(matches!(
spec.validate(),
Err(JointPenaltyError::NotSquare { nrows: 3, ncols: 4 })
));
}
// An antisymmetric pair (+1, -1) gives asymmetry 2, far above any tolerance.
#[test]
fn validate_rejects_non_symmetric() {
let mut matrix = Array2::<f64>::zeros((3, 3));
matrix[[0, 1]] = 1.0;
matrix[[1, 0]] = -1.0;
let spec = JointPenaltySpec {
label: None,
matrix,
initial_log_lambda: 0.0,
nullspace_dim: 0,
};
assert!(matches!(
spec.validate(),
Err(JointPenaltyError::NotSymmetric { .. })
));
}
#[test]
fn validate_rejects_oversized_nullspace() {
let spec = JointPenaltySpec {
label: None,
matrix: Array2::zeros((3, 3)),
initial_log_lambda: 0.0,
nullspace_dim: 4,
};
assert!(matches!(
spec.validate(),
Err(JointPenaltyError::NullspaceTooLarge {
total: 3,
nullspace_dim: 4
})
));
}
#[test]
fn validate_rejects_non_finite_initial_log_lambda() {
let spec = JointPenaltySpec {
label: None,
matrix: Array2::zeros((2, 2)),
initial_log_lambda: f64::NAN,
nullspace_dim: 0,
};
assert!(matches!(
spec.validate(),
Err(JointPenaltyError::NonFiniteInitialLogLambda { .. })
));
}
// Toy problem: minimise ½|β - b|² + ½ λ βᵀSβ with S = [[2, 1], [1, 2]]. The minimiser solves
// (I + λS) β = b. This exercises add_to_matrix (builds I + λS), add_apply_into (gradient
// β - b + λSβ must vanish at the solution), quadratic, add_diag and rho_objective_gradient
// against hand-computed values.
#[test]
fn bundle_two_block_minimiser_matches_analytic_solution() {
use crate::faer_ndarray::FaerCholesky;
use ndarray::Array2;
let spec = JointPenaltySpec {
label: Some("toy_cross_block".to_string()),
matrix: array![[2.0_f64, 1.0], [1.0, 2.0]],
initial_log_lambda: 0.0,
nullspace_dim: 0,
};
let log_lambda = -0.4_f64;
let lam = log_lambda.exp();
let bundle = JointPenaltyBundle::new(std::sync::Arc::new(vec![spec]), vec![log_lambda], 2)
.expect("valid bundle");
// lhs = I + λS.
let mut lhs = Array2::<f64>::eye(2);
bundle.add_to_matrix(&mut lhs);
let expected_lhs = array![[1.0 + lam * 2.0, lam], [lam, 1.0 + lam * 2.0]];
for r in 0..2 {
for c in 0..2 {
assert!(
(lhs[[r, c]] - expected_lhs[[r, c]]).abs() < 1e-12,
"lhs[{r}, {c}] = {} expected {}",
lhs[[r, c]],
expected_lhs[[r, c]]
);
}
}
// Solve (I + λS) β = b via the crate's Cholesky wrapper (right-hand side as a 2x1 matrix).
let b: Array1<f64> = array![1.0, -0.5];
let chol = lhs.cholesky(faer::Side::Lower).expect("SPD");
let mut rhs_mat = Array2::<f64>::zeros((2, 1));
rhs_mat[[0, 0]] = b[0];
rhs_mat[[1, 0]] = b[1];
let mut beta_mat = rhs_mat.clone();
chol.solve_mat_in_place(&mut beta_mat);
let beta_hat: Array1<f64> = array![beta_mat[[0, 0]], beta_mat[[1, 0]]];
// Gradient of the objective: (β̂ - b) + λSβ̂.
let mut grad = &beta_hat - &b;
bundle.add_apply_into(beta_hat.view(), &mut grad);
let grad_inf = grad.iter().map(|v: &f64| v.abs()).fold(0.0_f64, f64::max);
assert!(
grad_inf < 1e-12,
"penalised gradient at analytic minimiser must vanish: {grad_inf:.3e}"
);
let resid = &beta_hat - &b;
let unpen = 0.5 * resid.dot(&resid);
let pen = bundle.quadratic(beta_hat.view());
let expected_obj = 0.5 * resid.dot(&resid)
+ 0.5 * lam * beta_hat.dot(&array![[2.0, 1.0], [1.0, 2.0]].dot(&beta_hat));
assert!(
(unpen + pen - expected_obj).abs() < 1e-12,
"objective sum {} mismatched expected {}",
unpen + pen,
expected_obj
);
// diag(I + λS) = 1 + 2λ on both entries.
let mut diag = ndarray::Array1::<f64>::from_elem(2, 1.0);
bundle.add_diag(&mut diag);
assert!((diag[0] - (1.0 + lam * 2.0)).abs() < 1e-12);
assert!((diag[1] - (1.0 + lam * 2.0)).abs() < 1e-12);
// d/dρ of ½ exp(ρ) βᵀSβ at fixed β is ½ λ βᵀSβ.
let mut rho_grad = vec![0.0_f64];
bundle.rho_objective_gradient(beta_hat.view(), &mut rho_grad);
let expected_rho_grad =
0.5 * lam * beta_hat.dot(&array![[2.0, 1.0], [1.0, 2.0]].dot(&beta_hat));
assert!(
(rho_grad[0] - expected_rho_grad).abs() < 1e-12,
"rho-grad {} expected {}",
rho_grad[0],
expected_rho_grad
);
}
#[test]
fn bundle_rejects_dim_mismatch() {
let spec = JointPenaltySpec {
label: None,
matrix: Array2::<f64>::eye(3),
initial_log_lambda: 0.0,
nullspace_dim: 0,
};
let err = JointPenaltyBundle::new(std::sync::Arc::new(vec![spec]), vec![0.0], 4)
.expect_err("dim mismatch must reject");
assert!(err.contains("total_compiled"));
}
#[test]
fn bundle_rejects_lambda_count_mismatch() {
let spec = JointPenaltySpec {
label: None,
matrix: Array2::<f64>::eye(2),
initial_log_lambda: 0.0,
nullspace_dim: 0,
};
let err = JointPenaltyBundle::new(std::sync::Arc::new(vec![spec]), vec![], 2)
.expect_err("count mismatch must reject");
assert!(err.contains("specs vs"));
}
}