use super::*;
/// Term-by-term decomposition of the SAE manifold penalized loss.
///
/// The four `f64` terms are summed by [`SaeManifoldLoss::total`]; the
/// `evidence_gauge_deflated_directions` count is carried alongside but does not
/// enter that sum.
#[derive(Debug, Clone, Copy)]
pub struct SaeManifoldLoss {
    /// Data-fit term.
    pub data_fit: f64,
    /// Assignment-sparsity penalty term.
    pub assignment_sparsity: f64,
    /// Smoothness penalty term.
    pub smoothness: f64,
    /// ARD penalty term.
    pub ard: f64,
    /// Number of evidence-gauge deflated directions; a count reported with the
    /// loss, not an additive term.
    pub evidence_gauge_deflated_directions: usize,
}

impl SaeManifoldLoss {
    /// Total penalized loss: `data_fit + assignment_sparsity + smoothness + ard`.
    ///
    /// `evidence_gauge_deflated_directions` does not contribute to the total.
    #[must_use]
    pub const fn total(&self) -> f64 {
        self.data_fit + self.assignment_sparsity + self.smoothness + self.ard
    }

    /// Score form of the loss: the negation of [`Self::total`], so a larger
    /// score corresponds to a smaller penalized loss.
    #[must_use]
    pub const fn penalized_loss_score(&self) -> f64 {
        -self.total()
    }

    /// Flat snapshot of every term together with the derived total and score.
    ///
    /// `total_penalized_loss` and `penalized_loss_score` in the result are
    /// computed through [`Self::total`] and [`Self::penalized_loss_score`], so
    /// they always agree with the individual terms copied alongside them.
    #[must_use]
    pub const fn breakdown(&self) -> SaeManifoldLossBreakdown {
        SaeManifoldLossBreakdown {
            data_fit: self.data_fit,
            assignment_sparsity: self.assignment_sparsity,
            smoothness: self.smoothness,
            ard: self.ard,
            total_penalized_loss: self.total(),
            penalized_loss_score: self.penalized_loss_score(),
            evidence_gauge_deflated_directions: self.evidence_gauge_deflated_directions,
        }
    }
}

/// Flattened view of a [`SaeManifoldLoss`] including its derived total and
/// score, as returned by [`SaeManifoldLoss::breakdown`].
#[derive(Debug, Clone, Copy)]
pub struct SaeManifoldLossBreakdown {
    /// Copied from [`SaeManifoldLoss::data_fit`].
    pub data_fit: f64,
    /// Copied from [`SaeManifoldLoss::assignment_sparsity`].
    pub assignment_sparsity: f64,
    /// Copied from [`SaeManifoldLoss::smoothness`].
    pub smoothness: f64,
    /// Copied from [`SaeManifoldLoss::ard`].
    pub ard: f64,
    /// Sum of the four terms (see [`SaeManifoldLoss::total`]).
    pub total_penalized_loss: f64,
    /// Negated total (see [`SaeManifoldLoss::penalized_loss_score`]).
    pub penalized_loss_score: f64,
    /// Copied from [`SaeManifoldLoss::evidence_gauge_deflated_directions`].
    pub evidence_gauge_deflated_directions: usize,
}
/// Additive components of the outer gradient with respect to `rho`.
///
/// NOTE(review): `rho` is presumably the vector of outer hyperparameters and each
/// component is expected to have one entry per `rho` coordinate — confirm against
/// the code that builds this value.
#[derive(Debug, Clone)]
pub struct SaeOuterRhoGradientComponents {
    /// Explicit term.
    pub explicit: Array1<f64>,
    /// Log-determinant trace term.
    pub logdet_trace: Array1<f64>,
    /// Occam term.
    pub occam: Array1<f64>,
    /// Third-order correction term.
    pub third_order_correction: Array1<f64>,
}

impl SaeOuterRhoGradientComponents {
    /// Full gradient: the elementwise sum
    /// `explicit + logdet_trace + occam + third_order_correction`,
    /// accumulated left to right in exactly that order.
    ///
    /// # Panics
    ///
    /// Panics if the component shapes are incompatible, as `ndarray` arithmetic
    /// does for non-broadcastable operands.
    #[must_use]
    pub fn gradient(&self) -> Array1<f64> {
        // `&a + &b` allocates the result once; each following `owned + &x`
        // consumes that owned array and reuses its buffer when the shapes match,
        // instead of allocating a new temporary for every addition.
        &self.explicit + &self.logdet_trace + &self.occam + &self.third_order_correction
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for floating-point comparisons in these tests.
    const TOL: f64 = 1e-12;

    /// Asserts `actual` is within [`TOL`] of `expected`, reporting both on failure.
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < TOL,
            "expected {expected}, got {actual}"
        );
    }

    #[test]
    fn penalized_loss_score_is_negative_total_with_breakdown() {
        let loss = SaeManifoldLoss {
            data_fit: 1.5,
            assignment_sparsity: 0.25,
            smoothness: 0.5,
            ard: 0.75,
            evidence_gauge_deflated_directions: 3,
        };
        let total = 1.5 + 0.25 + 0.5 + 0.75;
        assert_close(loss.total(), total);
        assert_close(loss.penalized_loss_score(), -total);

        let b = loss.breakdown();
        assert_close(b.data_fit, 1.5);
        assert_close(b.assignment_sparsity, 0.25);
        assert_close(b.smoothness, 0.5);
        assert_close(b.ard, 0.75);
        assert_close(b.total_penalized_loss, total);
        assert_close(b.penalized_loss_score, -total);

        // The breakdown's total must be reproducible from its own terms.
        let summed = b.data_fit + b.assignment_sparsity + b.smoothness + b.ard;
        assert_close(summed, b.total_penalized_loss);
        assert_eq!(b.evidence_gauge_deflated_directions, 3);
    }

    #[test]
    fn outer_rho_gradient_sums_all_components() {
        let components = SaeOuterRhoGradientComponents {
            explicit: Array1::from(vec![1.0, -2.0]),
            logdet_trace: Array1::from(vec![0.5, 0.25]),
            occam: Array1::from(vec![-0.125, 1.0]),
            third_order_correction: Array1::from(vec![0.0625, 0.5]),
        };
        let g = components.gradient();
        assert_eq!(g.len(), 2);
        assert_close(g[0], 1.0 + 0.5 - 0.125 + 0.0625);
        assert_close(g[1], -2.0 + 0.25 + 1.0 + 0.5);
    }
}