use std::sync::Arc;
use faer::Side;
use ndarray::{Array1, Array2, ArrayView1, ArrayViewMut1};
use crate::linalg::faer_ndarray::FaerEigh;
use crate::terms::closed_form_operator::ClosedFormPenaltyOperator;
/// Abstract penalty matrix `S`, exposed through the operations callers need so
/// that implementors are not forced to keep `S` stored densely.
///
/// The `Send + Sync` bound lets operators be shared across threads behind an
/// `Arc<dyn PenaltyOp>`. The default [`PenaltyOp::eigendecompose`] uses a
/// symmetric eigensolver (`eigh` with `Side::Lower`), so implementors are
/// presumably expected to represent a symmetric matrix.
pub trait PenaltyOp: Send + Sync {
    /// Side length `n` of the `n × n` operator.
    fn dim(&self) -> usize;

    /// Writes the product `S · w` into `out`.
    fn matvec(&self, w: ArrayView1<'_, f64>, out: ArrayViewMut1<'_, f64>);

    /// Returns the main diagonal of `S` as an owned vector.
    fn diag(&self) -> Array1<f64>;

    /// Returns `tr(S)`.
    ///
    /// The default materializes [`PenaltyOp::diag`] and sums it.
    fn trace(&self) -> f64 {
        let diagonal = self.diag();
        diagonal.sum()
    }

    /// Returns `log det(S + λI)` for the given `lambda`.
    ///
    /// # Errors
    /// Implementation-defined; returns a human-readable message on failure.
    fn log_det_plus_lambda_i(&self, lambda: f64) -> Result<f64, String>;

    /// Returns `(eigenvalues, eigenvectors)` of `S`.
    ///
    /// The default densifies through [`PenaltyOp::as_dense`] and runs the
    /// symmetric eigensolver with `Side::Lower`.
    ///
    /// # Errors
    /// Returns the solver's error message prefixed with the method name.
    fn eigendecompose(&self) -> Result<(Array1<f64>, Array2<f64>), String> {
        match self.as_dense().eigh(Side::Lower) {
            Ok(decomposition) => Ok(decomposition),
            Err(e) => Err(format!("PenaltyOp::eigendecompose: {e}")),
        }
    }

    /// Materializes `S` as a dense `n × n` matrix.
    fn as_dense(&self) -> Array2<f64>;
}
/// A dense matrix acts as its own penalty operator.
///
/// The matrix is expected to be square (checked in debug builds only) and
/// symmetric (not checked; the eigensolver is called with `Side::Lower`).
impl PenaltyOp for Array2<f64> {
    /// Side length of the matrix; asserts squareness in debug builds.
    fn dim(&self) -> usize {
        debug_assert_eq!(
            self.nrows(),
            self.ncols(),
            "PenaltyOp matrix must be square"
        );
        self.nrows()
    }

    /// Dense product `out = self · w`; `w` and `out` are expected to have
    /// length `dim()`.
    fn matvec(&self, w: ArrayView1<'_, f64>, mut out: ArrayViewMut1<'_, f64>) {
        out.assign(&self.dot(&w));
    }

    /// Owned copy of the main diagonal (length `nrows`).
    fn diag(&self) -> Array1<f64> {
        Array1::from_shape_fn(self.nrows(), |i| self[[i, i]])
    }

    /// `log det(S + λI)`, computed as the sum of the logarithms of the
    /// eigenvalues of the shifted matrix.
    ///
    /// # Errors
    /// - `lambda` is not strictly positive (NaN included);
    /// - the eigendecomposition of `S + λI` fails;
    /// - some eigenvalue of `S + λI` is non-finite or non-positive, i.e. the
    ///   shifted matrix is not numerically SPD.
    fn log_det_plus_lambda_i(&self, lambda: f64) -> Result<f64, String> {
        // The signature already promises a `Result`, so an invalid shift is
        // reported as an error rather than aborting the caller with a panic.
        if lambda.is_nan() || lambda <= 0.0 {
            return Err(format!(
                "PenaltyOp::log_det_plus_lambda_i requires λ > 0, got {lambda:.3e}"
            ));
        }
        debug_assert_eq!(
            self.nrows(),
            self.ncols(),
            "PenaltyOp matrix must be square"
        );
        let mut regularized = self.clone();
        regularized.diag_mut().map_inplace(|d| *d += lambda);
        let (evals, _) = regularized.eigh(Side::Lower).map_err(|e| {
            format!("PenaltyOp::log_det_plus_lambda_i eigendecomposition failed: {e}")
        })?;
        // log det = Σ ln(eigenvalue), stopping at the first eigenvalue whose
        // logarithm would be undefined or infinite.
        evals
            .iter()
            .enumerate()
            .try_fold(0.0_f64, |logdet, (idx, &ev)| {
                if ev.is_finite() && ev > 0.0 {
                    Ok(logdet + ev.ln())
                } else {
                    Err(format!(
                        "PenaltyOp::log_det_plus_lambda_i expected SPD S+λI, \
                         eigenvalue {idx} is {ev:.3e}"
                    ))
                }
            })
    }

    /// Returns a copy of the matrix.
    fn as_dense(&self) -> Array2<f64> {
        self.clone()
    }
}
// Adapter: exposes `ClosedFormPenaltyOperator` through the `PenaltyOp` trait by
// forwarding each method to the operator's own method of the same name
// (`as_dense` forwards to `dense_form`).
//
// NOTE(review): every `self.xxx(..)` call below relies on Rust preferring an
// inherent method over this trait method when both have the same name. If one
// of those inherent methods is removed or renamed, the call would resolve back
// to the trait method here and recurse unconditionally. Confirm against
// `crate::terms::closed_form_operator`.
//
// `eigendecompose` is not overridden, so it uses the trait default
// (densify via `dense_form`, then a dense symmetric eigensolver).
impl PenaltyOp for ClosedFormPenaltyOperator {
    fn dim(&self) -> usize {
        self.dim()
    }
    fn matvec(&self, w: ArrayView1<'_, f64>, out: ArrayViewMut1<'_, f64>) {
        self.matvec(w, out)
    }
    fn diag(&self) -> Array1<f64> {
        self.diag()
    }
    // Overridden so the trait default (summing `diag()`) is not used.
    fn trace(&self) -> f64 {
        self.trace()
    }
    fn log_det_plus_lambda_i(&self, lambda: f64) -> Result<f64, String> {
        self.log_det_plus_lambda_i(lambda)
    }
    fn as_dense(&self) -> Array2<f64> {
        self.dense_form()
    }
}
/// A penalty operator multiplied by a constant: it represents `scale · S`,
/// where `S` is the wrapped `inner` operator.
///
/// Cloning is cheap: it bumps the `Arc` reference count and copies the scale.
#[derive(Clone)]
pub struct ScaledPenaltyOp {
    inner: Arc<dyn PenaltyOp>,
    scale: f64,
}

impl ScaledPenaltyOp {
    /// Wraps `inner` so that the resulting operator is `scale · inner`.
    pub fn new(inner: Arc<dyn PenaltyOp>, scale: f64) -> Self {
        Self { inner, scale }
    }
}

impl std::fmt::Debug for ScaledPenaltyOp {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `dyn PenaltyOp` is not `Debug`, so summarize the inner operator by
        // its dimension (mirroring how `PenaltyForm::Operator` is printed).
        f.debug_struct("ScaledPenaltyOp")
            .field("dim", &self.inner.dim())
            .field("scale", &self.scale)
            .finish()
    }
}
/// Every operation reports the inner operator's result multiplied by `scale`.
impl PenaltyOp for ScaledPenaltyOp {
    /// Same dimension as the inner operator.
    fn dim(&self) -> usize {
        self.inner.dim()
    }

    /// `out = scale · (inner · w)`: the inner operator writes into `out`,
    /// which is then rescaled in place.
    fn matvec(&self, w: ArrayView1<'_, f64>, mut out: ArrayViewMut1<'_, f64>) {
        self.inner.matvec(w, out.view_mut());
        out.mapv_inplace(|v| v * self.scale);
    }

    /// Inner diagonal multiplied by `scale`.
    fn diag(&self) -> Array1<f64> {
        let mut d = self.inner.diag();
        d.mapv_inplace(|v| v * self.scale);
        d
    }

    /// Inner trace multiplied by `scale`.
    fn trace(&self) -> f64 {
        self.inner.trace() * self.scale
    }

    /// `log det(scale · S + λI)`.
    ///
    /// For a finite positive scale `c` and `λ > 0` this uses the identity
    /// `log det(cS + λI) = n·ln c + log det(S + (λ/c)·I)` and delegates to the
    /// inner operator. A structured implementation is therefore used instead of
    /// materializing and eigendecomposing the dense matrix. In every other case
    /// (non-positive or non-finite scale, non-positive λ, or a shift `λ/c` that
    /// under/overflows) it falls back to the dense path on `scale · S`.
    ///
    /// # Errors
    /// Propagates the inner operator's error, or the dense fallback's error.
    fn log_det_plus_lambda_i(&self, lambda: f64) -> Result<f64, String> {
        let c = self.scale;
        let shifted = lambda / c;
        if c.is_finite() && c > 0.0 && lambda > 0.0 && shifted.is_finite() && shifted > 0.0 {
            let n = self.inner.dim() as f64;
            let inner_logdet = self.inner.log_det_plus_lambda_i(shifted)?;
            return Ok(n * c.ln() + inner_logdet);
        }
        let dense = self.as_dense();
        <Array2<f64> as PenaltyOp>::log_det_plus_lambda_i(&dense, lambda)
    }

    /// Dense inner matrix multiplied element-wise by `scale`.
    fn as_dense(&self) -> Array2<f64> {
        let mut m = self.inner.as_dense();
        m.mapv_inplace(|v| v * self.scale);
        m
    }
}
/// A penalty matrix held either as an explicit dense matrix or behind a shared
/// [`PenaltyOp`] trait object.
#[derive(Clone)]
pub enum PenaltyForm {
    /// Explicitly stored matrix; expected to be square.
    Dense(Array2<f64>),
    /// Shared operator; cloning the form only bumps the `Arc` count.
    Operator(Arc<dyn PenaltyOp>),
}

impl PenaltyForm {
    /// Side length of the penalty: `nrows` for a dense matrix (squareness is
    /// asserted in debug builds), otherwise the operator's `dim()`.
    pub fn dim(&self) -> usize {
        match self {
            Self::Operator(op) => op.dim(),
            Self::Dense(matrix) => {
                let rows = matrix.nrows();
                debug_assert_eq!(rows, matrix.ncols());
                rows
            }
        }
    }

    /// Owned dense copy of the penalty (clones a dense matrix, materializes
    /// an operator via `as_dense`).
    pub fn to_dense(&self) -> Array2<f64> {
        match self {
            Self::Operator(op) => op.as_dense(),
            Self::Dense(matrix) => matrix.clone(),
        }
    }

    /// Returns the penalty as a shared operator.
    ///
    /// For `Operator` this is a cheap `Arc` clone. For `Dense` it copies the
    /// matrix into a fresh `Arc`, despite the `as_` prefix.
    pub fn as_op_arc(&self) -> Arc<dyn PenaltyOp> {
        match self {
            Self::Operator(op) => Arc::clone(op),
            Self::Dense(matrix) => {
                let shared: Arc<dyn PenaltyOp> = Arc::new(matrix.clone());
                shared
            }
        }
    }
}

/// Compact debug output: the dense shape or the operator dimension, never the
/// matrix entries.
impl std::fmt::Debug for PenaltyForm {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Dense(matrix) => {
                let mut tuple = f.debug_tuple("PenaltyForm::Dense");
                tuple.field(&format_args!("{}×{}", matrix.nrows(), matrix.ncols()));
                tuple.finish()
            }
            Self::Operator(op) => {
                let mut tuple = f.debug_tuple("PenaltyForm::Operator");
                tuple.field(&format_args!("dim={}", op.dim()));
                tuple.finish()
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::Array;

    /// 4×4 PSD matrix `BᵀB` built from a 3×4 `B`, so its rank is at most 3
    /// and one eigenvalue is (numerically) zero.
    fn psd_fixture() -> Array2<f64> {
        let b = Array::from_shape_vec(
            (3, 4),
            vec![
                1.0, -0.3, 0.7, 0.1, 0.2, 1.1, -0.5, 0.4, -0.1, 0.6, 0.9, -0.2,
            ],
        )
        .unwrap();
        b.t().dot(&b)
    }

    #[test]
    fn array2_impl_matvec_matches_dot() {
        let s = psd_fixture();
        let v = Array1::from_vec(vec![0.7, -0.4, 0.2, 1.1]);
        let mut out = Array1::<f64>::zeros(4);
        s.matvec(v.view(), out.view_mut());
        let want = s.dot(&v);
        for i in 0..4 {
            assert_abs_diff_eq!(out[i], want[i], epsilon = 1e-12);
        }
    }

    #[test]
    fn array2_impl_diag_and_trace() {
        let s = psd_fixture();
        let d = <Array2<f64> as PenaltyOp>::diag(&s);
        for i in 0..4 {
            assert_abs_diff_eq!(d[i], s[[i, i]], epsilon = 0.0);
        }
        let tr = <Array2<f64> as PenaltyOp>::trace(&s);
        assert_abs_diff_eq!(tr, s.diag().sum(), epsilon = 0.0);
    }

    #[test]
    fn array2_impl_eigendecompose_matches_eigh() {
        let s = psd_fixture();
        let (evals_op, evecs_op) = <Array2<f64> as PenaltyOp>::eigendecompose(&s).expect("eigh");
        let (evals_ref, _) = s.eigh(Side::Lower).expect("eigh ref");
        assert_eq!(evals_op.len(), evals_ref.len());
        for i in 0..evals_op.len() {
            assert_abs_diff_eq!(evals_op[i], evals_ref[i], epsilon = 1e-12);
        }
        // `V·Vᵀ = I` holds for *any* orthonormal basis, so comparing projectors
        // says nothing about the eigenvectors. Check orthonormality and the
        // eigen-equation `S·vⱼ = λⱼ·vⱼ` column by column instead.
        let gram = evecs_op.t().dot(&evecs_op);
        for i in 0..gram.nrows() {
            for j in 0..gram.ncols() {
                let want = if i == j { 1.0 } else { 0.0 };
                assert_abs_diff_eq!(gram[[i, j]], want, epsilon = 1e-10);
            }
        }
        for (j, &lam) in evals_op.iter().enumerate() {
            let v = evecs_op.column(j);
            let sv = s.dot(&v);
            for i in 0..v.len() {
                assert_abs_diff_eq!(sv[i], lam * v[i], epsilon = 1e-10);
            }
        }
    }

    #[test]
    fn array2_impl_log_det_plus_lambda_matches_eigenvalues() {
        let s = psd_fixture();
        let lambda: f64 = 0.3;
        let (evals, _) = s.eigh(Side::Lower).expect("eigh ref");
        let want: f64 = evals.iter().map(|&e| (e + lambda).ln()).sum();
        let got = <Array2<f64> as PenaltyOp>::log_det_plus_lambda_i(&s, lambda).expect("logdet");
        assert_abs_diff_eq!(got, want, epsilon = 1e-10);
    }

    #[test]
    fn scaled_op_matches_scaled_dense() {
        let s = psd_fixture();
        let scale: f64 = 2.5;
        let op = ScaledPenaltyOp::new(Arc::new(s.clone()), scale);
        let scaled = s.mapv(|x| x * scale);
        assert_eq!(op.dim(), 4);

        let v = Array1::from_vec(vec![0.7, -0.4, 0.2, 1.1]);
        let mut out = Array1::<f64>::zeros(4);
        op.matvec(v.view(), out.view_mut());
        let want = scaled.dot(&v);
        for i in 0..4 {
            assert_abs_diff_eq!(out[i], want[i], epsilon = 1e-12);
        }

        let d = op.diag();
        for i in 0..4 {
            assert_abs_diff_eq!(d[i], scaled[[i, i]], epsilon = 1e-12);
        }
        assert_abs_diff_eq!(op.trace(), scaled.diag().sum(), epsilon = 1e-12);

        let dense = op.as_dense();
        for i in 0..4 {
            for j in 0..4 {
                assert_abs_diff_eq!(dense[[i, j]], scaled[[i, j]], epsilon = 1e-12);
            }
        }

        let lambda: f64 = 0.3;
        let (evals, _) = s.eigh(Side::Lower).expect("eigh ref");
        let want_logdet: f64 = evals.iter().map(|&e| (scale * e + lambda).ln()).sum();
        let got = op.log_det_plus_lambda_i(lambda).expect("logdet");
        assert_abs_diff_eq!(got, want_logdet, epsilon = 1e-10);
    }

    #[test]
    fn scaled_op_with_zero_scale_is_lambda_identity() {
        let op = ScaledPenaltyOp::new(Arc::new(psd_fixture()), 0.0);
        let lambda: f64 = 0.7;
        let got = op.log_det_plus_lambda_i(lambda).expect("logdet");
        assert_abs_diff_eq!(got, 4.0 * lambda.ln(), epsilon = 1e-10);
    }

    #[test]
    fn penalty_form_dim_and_to_dense_round_trip() {
        let s = psd_fixture();
        let form = PenaltyForm::Dense(s.clone());
        assert_eq!(form.dim(), 4);
        let m = form.to_dense();
        for i in 0..4 {
            for j in 0..4 {
                assert_abs_diff_eq!(m[[i, j]], s[[i, j]], epsilon = 0.0);
            }
        }
        assert_eq!(form.as_op_arc().dim(), 4);

        let arc: Arc<dyn PenaltyOp> = Arc::new(s.clone());
        let op_form = PenaltyForm::Operator(arc);
        assert_eq!(op_form.dim(), 4);
        let m2 = op_form.to_dense();
        for i in 0..4 {
            for j in 0..4 {
                assert_abs_diff_eq!(m2[[i, j]], s[[i, j]], epsilon = 0.0);
            }
        }
    }
}