use ndarray::Array1;
/// One additive contribution ("atom") to an [`SaeCriterion`].
///
/// Every atom carries a gradient vector. All variants except
/// [`SaeCriterionAtom::ImplicitStationarityCorrection`] also carry a scalar
/// value that enters the criterion total.
#[derive(Debug, Clone)]
pub enum SaeCriterionAtom {
    /// Atom built by `SaeCriterion::assemble` from `data_fit_priors_value`
    /// (value) and `explicit` (gradient).
    DataFitPriors {
        /// Scalar contribution to the criterion value.
        value: f64,
        /// Contribution to the criterion gradient.
        grad: Array1<f64>,
    },
    /// Atom built by `SaeCriterion::assemble` with value `0.5 * log_det` and
    /// gradient `logdet_trace`.
    LaplaceLogdet {
        /// Scalar contribution to the criterion value.
        value: f64,
        /// Contribution to the criterion gradient.
        grad: Array1<f64>,
    },
    /// Atom built by `SaeCriterion::assemble` with value `-occam` and
    /// gradient `occam_grad`.
    Occam {
        /// Scalar contribution to the criterion value.
        value: f64,
        /// Contribution to the criterion gradient.
        grad: Array1<f64>,
    },
    /// Gradient-only atom: it has no stored value, and
    /// `SaeCriterionAtom::value` reports `0.0` for it.
    ImplicitStationarityCorrection {
        /// Contribution to the criterion gradient.
        grad: Array1<f64>,
    },
}
impl SaeCriterionAtom {
    /// Returns this atom's scalar contribution to the criterion value.
    ///
    /// The gradient-only `ImplicitStationarityCorrection` atom stores no
    /// value and therefore contributes `0.0`.
    #[must_use]
    pub fn value(&self) -> f64 {
        // One arm per variant keeps the match exhaustive, so adding a variant
        // forces a decision about its value here.
        match *self {
            Self::ImplicitStationarityCorrection { .. } => 0.0,
            Self::DataFitPriors { value, .. } => value,
            Self::LaplaceLogdet { value, .. } => value,
            Self::Occam { value, .. } => value,
        }
    }

    /// Borrows this atom's gradient contribution.
    #[must_use]
    pub fn grad(&self) -> &Array1<f64> {
        // Every variant has a `grad` field, so the or-pattern is irrefutable.
        let (Self::DataFitPriors { grad, .. }
        | Self::LaplaceLogdet { grad, .. }
        | Self::Occam { grad, .. }
        | Self::ImplicitStationarityCorrection { grad }) = self;
        grad
    }

    /// Returns a fixed snake_case identifier naming this atom's variant.
    #[must_use]
    pub fn label(&self) -> &'static str {
        let name = match self {
            Self::ImplicitStationarityCorrection { .. } => "implicit_stationarity_correction",
            Self::Occam { .. } => "occam",
            Self::LaplaceLogdet { .. } => "laplace_logdet",
            Self::DataFitPriors { .. } => "data_fit_priors",
        };
        name
    }
}
/// A criterion expressed as a list of [`SaeCriterionAtom`]s.
///
/// The total value and total gradient are obtained by summing the
/// corresponding parts of every atom.
#[derive(Debug, Clone)]
pub struct SaeCriterion {
    /// The individual contributions, in the order they were assembled.
    atoms: Vec<SaeCriterionAtom>,
    /// Length of the vector returned by `gradient`; `assemble` sets it to the
    /// length of its `explicit` argument.
    n_rho: usize,
}
impl SaeCriterion {
    /// Builds a criterion from its scalar pieces and per-channel gradients.
    ///
    /// Four atoms are stored, in this order:
    ///
    /// 1. `DataFitPriors` with value `data_fit_priors_value` and gradient `explicit`;
    /// 2. `LaplaceLogdet` with value `0.5 * log_det` and gradient `logdet_trace`;
    /// 3. `Occam` with value `-occam` and gradient `occam_grad`;
    /// 4. `ImplicitStationarityCorrection` with gradient `implicit_correction`
    ///    and no value.
    ///
    /// `n_rho` is taken from the length of `explicit`.
    ///
    /// # Panics
    ///
    /// Panics if `logdet_trace`, `occam_grad`, or `implicit_correction` does
    /// not have the same length as `explicit`.
    #[must_use]
    pub fn assemble(
        data_fit_priors_value: f64,
        log_det: f64,
        occam: f64,
        explicit: Array1<f64>,
        logdet_trace: Array1<f64>,
        occam_grad: Array1<f64>,
        implicit_correction: Array1<f64>,
    ) -> Self {
        let n_rho = explicit.len();

        // Validate up front: otherwise a mismatched channel would only surface
        // later inside `gradient()` as an ndarray shape panic, and a length-1
        // channel would be silently broadcast across every coordinate.
        for (channel, len) in [
            ("logdet_trace", logdet_trace.len()),
            ("occam_grad", occam_grad.len()),
            ("implicit_correction", implicit_correction.len()),
        ] {
            assert_eq!(
                len, n_rho,
                "gradient channel `{channel}` has length {len}, \
                 but `explicit` has length {n_rho}"
            );
        }

        let atoms = vec![
            SaeCriterionAtom::DataFitPriors {
                value: data_fit_priors_value,
                grad: explicit,
            },
            SaeCriterionAtom::LaplaceLogdet {
                value: 0.5 * log_det,
                grad: logdet_trace,
            },
            SaeCriterionAtom::Occam {
                value: -occam,
                grad: occam_grad,
            },
            SaeCriterionAtom::ImplicitStationarityCorrection {
                grad: implicit_correction,
            },
        ];
        Self { atoms, n_rho }
    }

    /// Returns the criterion value: the sum of every atom's `value()`.
    ///
    /// Gradient-only atoms contribute `0.0`.
    #[must_use]
    pub fn value(&self) -> f64 {
        self.atoms.iter().map(SaeCriterionAtom::value).sum()
    }

    /// Returns the criterion gradient: the element-wise sum of every atom's
    /// gradient, as a freshly allocated vector of length `n_rho`.
    #[must_use]
    pub fn gradient(&self) -> Array1<f64> {
        let mut out = Array1::<f64>::zeros(self.n_rho);
        for atom in &self.atoms {
            // `assemble` guarantees each gradient has length `n_rho`, so this
            // is a plain element-wise accumulation.
            out += atom.grad();
        }
        out
    }

    /// Borrows the atoms in the order they were assembled.
    #[must_use]
    pub fn atoms(&self) -> &[SaeCriterionAtom] {
        &self.atoms
    }

    /// Returns the length of the gradient vector.
    #[must_use]
    pub fn n_rho(&self) -> usize {
        self.n_rho
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    /// Absolute tolerance used for floating-point comparisons.
    const TOL: f64 = 1e-12;

    /// Builds a three-coordinate criterion with distinct entries per channel.
    fn sample_criterion() -> SaeCriterion {
        let explicit = array![0.10, -0.20, 0.05];
        let logdet_trace = array![0.01, 0.02, -0.03];
        let occam_grad = array![-0.04, 0.00, 0.06];
        let implicit_correction = array![0.07, -0.01, 0.00];
        SaeCriterion::assemble(
            3.0,
            2.0,
            0.5,
            explicit,
            logdet_trace,
            occam_grad,
            implicit_correction,
        )
    }

    /// The total value equals the sum of the per-atom values.
    #[test]
    fn value_is_atom_sum() {
        let crit = sample_criterion();
        let expected = 3.0 + 0.5 * 2.0 - 0.5;

        let total = crit.value();
        assert!((total - expected).abs() < TOL);

        let by_atom: f64 = crit.atoms().iter().map(|atom| atom.value()).sum();
        assert!((by_atom - expected).abs() < TOL);
    }

    /// The total gradient sums all four channels, correction included.
    #[test]
    fn gradient_is_channel_sum_including_correction() {
        let g = sample_criterion().gradient();
        let expected = array![
            0.10 + 0.01 - 0.04 + 0.07,
            -0.20 + 0.02 + 0.00 - 0.01,
            0.05 - 0.03 + 0.06 + 0.00
        ];
        for (i, want) in expected.iter().enumerate() {
            assert!(
                (g[i] - want).abs() < TOL,
                "coord {i}: {} vs {}",
                g[i],
                want
            );
        }
    }

    /// The correction atom has a gradient and a label but a zero value.
    #[test]
    fn implicit_correction_atom_is_gradient_only() {
        let grad = array![1.0, 2.0, 3.0];
        let atom = SaeCriterionAtom::ImplicitStationarityCorrection { grad };
        assert_eq!(atom.label(), "implicit_stationarity_correction");
        assert_eq!(atom.grad().sum(), 6.0);
        assert_eq!(atom.value(), 0.0);
    }

    /// No two atoms of an assembled criterion share a label.
    #[test]
    fn atoms_have_distinct_labels() {
        let crit = sample_criterion();
        let labels: Vec<&str> = crit.atoms().iter().map(|atom| atom.label()).collect();
        let unique: std::collections::BTreeSet<&str> = labels.iter().copied().collect();
        assert_eq!(unique.len(), labels.len(), "labels must be distinct");
    }
}