Skip to main content

gam_sae/manifold/
loss.rs

1use super::*;
2
/// Loss breakdown for diagnostics and evidence ranking.
///
/// The four `f64` fields are the penalized-loss terms that
/// [`SaeManifoldLoss::total`] adds together. `evidence_gauge_deflated_directions`
/// is a diagnostic count carried alongside them; it is not part of the total.
#[derive(Clone, Copy, Debug)]
pub struct SaeManifoldLoss {
    /// Data-fit term of the penalized loss.
    pub data_fit: f64,
    /// Assignment-sparsity term of the penalized loss.
    pub assignment_sparsity: f64,
    /// Smoothness term of the penalized loss.
    pub smoothness: f64,
    /// ARD term of the penalized loss.
    pub ard: f64,
    /// Count of evidence-gauge-deflated directions recorded on the loss
    /// (diagnostic only; never summed into the loss).
    pub evidence_gauge_deflated_directions: usize,
}
12
13impl SaeManifoldLoss {
14    pub const fn total(&self) -> f64 {
15        self.data_fit + self.assignment_sparsity + self.smoothness + self.ard
16    }
17
18    /// Negative penalized loss `−(data_fit + assignment_sparsity + smoothness +
19    /// ard)`. Larger is "less penalized loss", so Laplace/REML wrappers that rank
20    /// larger-is-better can sort on it — but this is **not** a REML / marginal
21    /// likelihood: it omits the Hessian log-determinant, the Occam log-λ term,
22    /// any extra analytic penalties, the co-training fold, the top-k projection
23    /// effect, and hybrid-collapse effects (#1231). Callers must surface it under
24    /// an honest name (`penalized_loss_score`, or `oos_penalized_loss` on the
25    /// fixed-decoder OOS path), never `reml_score`.
26    pub const fn penalized_loss_score(&self) -> f64 {
27        -self.total()
28    }
29
30    /// Honest component breakdown of [`Self::total`] — the four penalized-loss
31    /// terms this struct actually carries — so a consumer can see exactly what
32    /// the score is (and what it is *not*: it is missing the evidence pieces
33    /// listed on [`Self::penalized_loss_score`]). The values are the raw
34    /// (positive) loss contributions; `penalized_loss_score == −Σ` of the first
35    /// four.
36    pub const fn breakdown(&self) -> SaeManifoldLossBreakdown {
37        SaeManifoldLossBreakdown {
38            data_fit: self.data_fit,
39            assignment_sparsity: self.assignment_sparsity,
40            smoothness: self.smoothness,
41            ard: self.ard,
42            total_penalized_loss: self.total(),
43            penalized_loss_score: self.penalized_loss_score(),
44            evidence_gauge_deflated_directions: self.evidence_gauge_deflated_directions,
45        }
46    }
47}
48
/// Honest, fully itemized view of [`SaeManifoldLoss`] for the model output.
///
/// It reports only the penalized-loss components the score is actually built
/// from. It is deliberately NOT named or shaped like a REML / evidence
/// breakdown. The following are not part of this object (#1231):
/// - the Hessian log-determinant;
/// - Occam log-λ;
/// - extra penalties;
/// - the co-training fold;
/// - top-k / hybrid-collapse effects.
#[derive(Clone, Copy, Debug)]
pub struct SaeManifoldLossBreakdown {
    /// Data-fit component (raw, positive loss contribution).
    pub data_fit: f64,
    /// Assignment-sparsity component (raw, positive loss contribution).
    pub assignment_sparsity: f64,
    /// Smoothness component (raw, positive loss contribution).
    pub smoothness: f64,
    /// ARD component (raw, positive loss contribution).
    pub ard: f64,
    /// Sum of the four components:
    /// `data_fit + assignment_sparsity + smoothness + ard`.
    pub total_penalized_loss: f64,
    /// Negation of `total_penalized_loss`; larger means less penalized loss.
    pub penalized_loss_score: f64,
    /// Count of evidence-gauge-deflated directions recorded on the loss.
    pub evidence_gauge_deflated_directions: usize,
}
67
68/// Componentized analytic derivative of the SAE REML criterion with respect to
69/// the flat [`SaeManifoldRho`] layout.
70///
71/// Production objective and certificate paths consume this value object so the
72/// criterion value and gradient are assembled from the same converged cache.
73#[derive(Debug, Clone)]
74pub struct SaeOuterRhoGradientComponents {
75    /// Direct derivative of `loss.total() + extra_penalty_energy` with respect to
76    /// log-strength coordinates, excluding the Hessian logdet and Occam terms.
77    pub explicit: Array1<f64>,
78    /// `0.5 * tr(H^{-1} dH/d rho_j)` for the currently available penalty blocks.
79    pub logdet_trace: Array1<f64>,
80    /// Derivative contribution of `-occam`.
81    pub occam: Array1<f64>,
82    /// `0.5 * tr(H^{-1} (dH/dtheta * dtheta_hat/d rho_j))`.
83    pub third_order_correction: Array1<f64>,
84}
85
86impl SaeOuterRhoGradientComponents {
87    #[must_use]
88    pub fn gradient(&self) -> Array1<f64> {
89        &(&(&self.explicit + &self.logdet_trace) + &self.occam) + &self.third_order_correction
90    }
91}
92
#[cfg(test)]
mod tests {
    use super::*;

    /// #1231 — the public score is the NEGATIVE penalized loss of the four loss
    /// components, and the breakdown itemizes exactly those components. It is not
    /// (and must not be presented as) a REML criterion.
    #[test]
    fn penalized_loss_score_is_negative_total_with_breakdown() {
        let loss = SaeManifoldLoss {
            data_fit: 1.5,
            assignment_sparsity: 0.25,
            smoothness: 0.5,
            ard: 0.75,
            evidence_gauge_deflated_directions: 3,
        };
        let total = 1.5 + 0.25 + 0.5 + 0.75;
        assert!((loss.total() - total).abs() < 1e-12);
        assert!((loss.penalized_loss_score() - (-total)).abs() < 1e-12);

        let b = loss.breakdown();
        assert!((b.data_fit - 1.5).abs() < 1e-12);
        assert!((b.assignment_sparsity - 0.25).abs() < 1e-12);
        assert!((b.smoothness - 0.5).abs() < 1e-12);
        assert!((b.ard - 0.75).abs() < 1e-12);
        assert!((b.total_penalized_loss - total).abs() < 1e-12);
        assert!((b.penalized_loss_score - (-total)).abs() < 1e-12);
        // The breakdown's four components must sum to the reported total — the
        // score is fully explained by what the breakdown lists, with no hidden
        // evidence pieces folded into it.
        let summed = b.data_fit + b.assignment_sparsity + b.smoothness + b.ard;
        assert!((summed - b.total_penalized_loss).abs() < 1e-12);
        assert_eq!(b.evidence_gauge_deflated_directions, 3);
    }

    /// `gradient()` must be the element-wise sum of all four components.
    /// Dropping any one of them (e.g. the third-order correction) would
    /// silently change the outer optimizer's search direction.
    #[test]
    fn outer_rho_gradient_sums_all_components() {
        let components = SaeOuterRhoGradientComponents {
            explicit: Array1::from(vec![1.0, 2.0]),
            logdet_trace: Array1::from(vec![0.5, 0.5]),
            occam: Array1::from(vec![-1.0, 0.0]),
            third_order_correction: Array1::from(vec![0.25, -0.25]),
        };
        let g = components.gradient();
        assert_eq!(g.len(), 2);
        assert!((g[0] - 0.75).abs() < 1e-12);
        assert!((g[1] - 2.25).abs() < 1e-12);
    }
}
127}