gam_sae/manifold/loss.rs
1use super::*;
2
/// Loss breakdown for diagnostics and evidence ranking.
///
/// Carries the four additive penalized-loss terms plus one bookkeeping count.
/// Only the four `f64` terms enter [`SaeManifoldLoss::total`]; the count is
/// copied through to [`SaeManifoldLossBreakdown`] unchanged.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct SaeManifoldLoss {
    /// Data-fit term of the penalized loss.
    pub data_fit: f64,
    /// Assignment-sparsity penalty term.
    pub assignment_sparsity: f64,
    /// Smoothness penalty term.
    pub smoothness: f64,
    /// ARD penalty term.
    pub ard: f64,
    /// Count of evidence-gauge-deflated directions. Not part of the total.
    pub evidence_gauge_deflated_directions: usize,
}

impl SaeManifoldLoss {
    /// Sum of the four loss terms:
    /// `data_fit + assignment_sparsity + smoothness + ard`.
    #[must_use]
    pub const fn total(&self) -> f64 {
        self.data_fit + self.assignment_sparsity + self.smoothness + self.ard
    }

    /// Negative penalized loss `−(data_fit + assignment_sparsity + smoothness +
    /// ard)`. Larger is "less penalized loss", so Laplace/REML wrappers that rank
    /// larger-is-better can sort on it — but this is **not** a REML / marginal
    /// likelihood: it omits the Hessian log-determinant, the Occam log-λ term,
    /// any extra analytic penalties, the co-training fold, the top-k projection
    /// effect, and hybrid-collapse effects (#1231). Callers must surface it under
    /// an honest name (`penalized_loss_score`, or `oos_penalized_loss` on the
    /// fixed-decoder OOS path), never `reml_score`.
    #[must_use]
    pub const fn penalized_loss_score(&self) -> f64 {
        -self.total()
    }

    /// Honest component breakdown of [`Self::total`] — the four penalized-loss
    /// terms this struct actually carries — so a consumer can see exactly what
    /// the score is (and what it is *not*: it is missing the evidence pieces
    /// listed on [`Self::penalized_loss_score`]). The four component values are
    /// copied as stored; `total_penalized_loss` is their sum and
    /// `penalized_loss_score` is the negation of that sum.
    #[must_use]
    pub const fn breakdown(&self) -> SaeManifoldLossBreakdown {
        SaeManifoldLossBreakdown {
            data_fit: self.data_fit,
            assignment_sparsity: self.assignment_sparsity,
            smoothness: self.smoothness,
            ard: self.ard,
            // Both derived fields go through the accessors so the breakdown can
            // never drift from what `total` / `penalized_loss_score` report.
            total_penalized_loss: self.total(),
            penalized_loss_score: self.penalized_loss_score(),
            evidence_gauge_deflated_directions: self.evidence_gauge_deflated_directions,
        }
    }
}

/// Honest, fully-itemized view of [`SaeManifoldLoss`] for the model output. It
/// reports the penalized-loss components that the score is actually built from,
/// and is deliberately NOT named or shaped like a REML / evidence breakdown:
/// the Hessian log-determinant, Occam log-λ, extra penalties, co-training fold,
/// and top-k / hybrid-collapse effects are not part of this object (#1231).
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct SaeManifoldLossBreakdown {
    /// Data-fit term, copied from the loss.
    pub data_fit: f64,
    /// Assignment-sparsity penalty term, copied from the loss.
    pub assignment_sparsity: f64,
    /// Smoothness penalty term, copied from the loss.
    pub smoothness: f64,
    /// ARD penalty term, copied from the loss.
    pub ard: f64,
    /// `data_fit + assignment_sparsity + smoothness + ard`.
    pub total_penalized_loss: f64,
    /// `−total_penalized_loss` (larger = less penalized loss).
    pub penalized_loss_score: f64,
    /// Count of evidence-gauge-deflated directions recorded on the loss.
    pub evidence_gauge_deflated_directions: usize,
}
67
68/// Componentized analytic derivative of the SAE REML criterion with respect to
69/// the flat [`SaeManifoldRho`] layout.
70///
71/// Production objective and certificate paths consume this value object so the
72/// criterion value and gradient are assembled from the same converged cache.
73#[derive(Debug, Clone)]
74pub struct SaeOuterRhoGradientComponents {
75 /// Direct derivative of `loss.total() + extra_penalty_energy` with respect to
76 /// log-strength coordinates, excluding the Hessian logdet and Occam terms.
77 pub explicit: Array1<f64>,
78 /// `0.5 * tr(H^{-1} dH/d rho_j)` for the currently available penalty blocks.
79 pub logdet_trace: Array1<f64>,
80 /// Derivative contribution of `-occam`.
81 pub occam: Array1<f64>,
82 /// `0.5 * tr(H^{-1} (dH/dtheta * dtheta_hat/d rho_j))`.
83 pub third_order_correction: Array1<f64>,
84}
85
86impl SaeOuterRhoGradientComponents {
87 #[must_use]
88 pub fn gradient(&self) -> Array1<f64> {
89 &(&(&self.explicit + &self.logdet_trace) + &self.occam) + &self.third_order_correction
90 }
91}
92
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance shared by every floating-point comparison below.
    const TOL: f64 = 1e-12;

    /// #1231 — the public score is the NEGATIVE penalized loss of the four loss
    /// components, and the breakdown itemizes exactly those components. It is not
    /// (and must not be presented as) a REML criterion.
    #[test]
    fn penalized_loss_score_is_negative_total_with_breakdown() {
        let close = |actual: f64, expected: f64| (actual - expected).abs() < TOL;

        let (data_fit, assignment_sparsity, smoothness, ard) = (1.5, 0.25, 0.5, 0.75);
        let expected_total = data_fit + assignment_sparsity + smoothness + ard;
        let loss = SaeManifoldLoss {
            data_fit,
            assignment_sparsity,
            smoothness,
            ard,
            evidence_gauge_deflated_directions: 3,
        };

        // Scalar accessors on the loss itself.
        assert!(close(loss.total(), expected_total));
        assert!(close(loss.penalized_loss_score(), -expected_total));

        // Every itemized field mirrors the value it was built from.
        let itemized = loss.breakdown();
        assert!(close(itemized.data_fit, data_fit));
        assert!(close(itemized.assignment_sparsity, assignment_sparsity));
        assert!(close(itemized.smoothness, smoothness));
        assert!(close(itemized.ard, ard));
        assert!(close(itemized.total_penalized_loss, expected_total));
        assert!(close(itemized.penalized_loss_score, -expected_total));

        // The breakdown's four components must sum to the reported total — the
        // score is fully explained by what the breakdown lists, with no hidden
        // evidence pieces folded into it.
        let component_sum =
            itemized.data_fit + itemized.assignment_sparsity + itemized.smoothness + itemized.ard;
        assert!(close(component_sum, itemized.total_penalized_loss));
        assert_eq!(itemized.evidence_gauge_deflated_directions, 3);
    }
}