1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
use super::*;
/// Per-term penalized-loss values, kept separate for diagnostics and evidence
/// ranking.
///
/// The four `f64` fields are the terms summed by `SaeManifoldLoss::total`; the
/// remaining field is a directional count carried alongside them.
#[derive(Debug, Clone, Copy)]
pub struct SaeManifoldLoss {
    /// Data-fit term of the penalized loss.
    pub data_fit: f64,
    /// Assignment-sparsity term of the penalized loss.
    pub assignment_sparsity: f64,
    /// Smoothness term of the penalized loss.
    pub smoothness: f64,
    /// ARD term of the penalized loss.
    pub ard: f64,
    /// Count of evidence-gauge-deflated directions recorded with this loss.
    /// This is a count, not a loss term, and is not summed into the total.
    pub evidence_gauge_deflated_directions: usize,
}
impl SaeManifoldLoss {
    /// Total penalized loss: `data_fit + assignment_sparsity + smoothness + ard`.
    ///
    /// `evidence_gauge_deflated_directions` is a count, not a loss term, and
    /// does not contribute to the sum.
    #[must_use]
    pub const fn total(&self) -> f64 {
        self.data_fit + self.assignment_sparsity + self.smoothness + self.ard
    }

    /// Negative penalized loss `−(data_fit + assignment_sparsity + smoothness +
    /// ard)`. Larger is "less penalized loss", so Laplace/REML wrappers that rank
    /// larger-is-better can sort on it — but this is **not** a REML / marginal
    /// likelihood: it omits the Hessian log-determinant, the Occam log-λ term,
    /// any extra analytic penalties, the co-training fold, the top-k projection
    /// effect, and hybrid-collapse effects (#1231). Callers must surface it under
    /// an honest name (`penalized_loss_score`, or `oos_penalized_loss` on the
    /// fixed-decoder OOS path), never `reml_score`.
    #[must_use]
    pub const fn penalized_loss_score(&self) -> f64 {
        -self.total()
    }

    /// Honest component breakdown of [`Self::total`] — the four penalized-loss
    /// terms this struct actually carries — so a consumer can see exactly what
    /// the score is (and what it is *not*: it is missing the evidence pieces
    /// listed on [`Self::penalized_loss_score`]).
    ///
    /// The component values are the raw (positive) loss contributions.
    /// `penalized_loss_score` equals the negated sum of the four component
    /// fields, and `total_penalized_loss` equals their sum.
    #[must_use]
    pub const fn breakdown(&self) -> SaeManifoldLossBreakdown {
        SaeManifoldLossBreakdown {
            data_fit: self.data_fit,
            assignment_sparsity: self.assignment_sparsity,
            smoothness: self.smoothness,
            ard: self.ard,
            total_penalized_loss: self.total(),
            penalized_loss_score: self.penalized_loss_score(),
            evidence_gauge_deflated_directions: self.evidence_gauge_deflated_directions,
        }
    }
}
/// Fully itemized, honest view of [`SaeManifoldLoss`] for the model output.
///
/// Lists only the penalized-loss components the score is actually built from.
/// It is intentionally NOT named or shaped like a REML / evidence breakdown:
/// the Hessian log-determinant, the Occam log-λ term, extra penalties, the
/// co-training fold, and top-k / hybrid-collapse effects are absent from this
/// object (#1231).
#[derive(Debug, Clone, Copy)]
pub struct SaeManifoldLossBreakdown {
    /// Copied from [`SaeManifoldLoss::data_fit`].
    pub data_fit: f64,
    /// Copied from [`SaeManifoldLoss::assignment_sparsity`].
    pub assignment_sparsity: f64,
    /// Copied from [`SaeManifoldLoss::smoothness`].
    pub smoothness: f64,
    /// Copied from [`SaeManifoldLoss::ard`].
    pub ard: f64,
    /// Sum of the four components:
    /// `data_fit + assignment_sparsity + smoothness + ard`.
    pub total_penalized_loss: f64,
    /// Negation of `total_penalized_loss`; larger means less penalized loss.
    pub penalized_loss_score: f64,
    /// Number of evidence-gauge-deflated directions recorded on the loss.
    pub evidence_gauge_deflated_directions: usize,
}
/// Analytic derivative of the SAE REML criterion with respect to the flat
/// [`SaeManifoldRho`] layout, split into its individual contributions.
///
/// Production objective and certificate paths consume this value object so that
/// the criterion value and its gradient are assembled from the same converged
/// cache.
#[derive(Debug, Clone)]
pub struct SaeOuterRhoGradientComponents {
    /// Direct derivative of `loss.total() + extra_penalty_energy` with respect
    /// to the log-strength coordinates; the Hessian logdet and Occam terms are
    /// excluded here.
    pub explicit: Array1<f64>,
    /// `0.5 * tr(H^{-1} dH/d rho_j)` over the penalty blocks currently
    /// available.
    pub logdet_trace: Array1<f64>,
    /// Contribution from differentiating `-occam`.
    pub occam: Array1<f64>,
    /// `0.5 * tr(H^{-1} (dH/dtheta * dtheta_hat/d rho_j))`.
    pub third_order_correction: Array1<f64>,
}
impl SaeOuterRhoGradientComponents {
    /// Assembled outer gradient: the element-wise sum
    /// `explicit + logdet_trace + occam + third_order_correction`.
    ///
    /// The additions are evaluated left to right, which fixes the
    /// floating-point association order.
    #[must_use]
    pub fn gradient(&self) -> Array1<f64> {
        let explicit_and_logdet = &self.explicit + &self.logdet_trace;
        let with_occam = &explicit_and_logdet + &self.occam;
        &with_occam + &self.third_order_correction
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// #1231 — the public score is the NEGATIVE penalized loss of the four loss
    /// components, and the breakdown itemizes exactly those components. It is not
    /// (and must not be presented as) a REML criterion.
    #[test]
    fn penalized_loss_score_is_negative_total_with_breakdown() {
        let loss = SaeManifoldLoss {
            data_fit: 1.5,
            assignment_sparsity: 0.25,
            smoothness: 0.5,
            ard: 0.75,
            evidence_gauge_deflated_directions: 3,
        };
        let total = 1.5 + 0.25 + 0.5 + 0.75;
        assert!((loss.total() - total).abs() < 1e-12);
        assert!((loss.penalized_loss_score() - (-total)).abs() < 1e-12);
        let b = loss.breakdown();
        assert!((b.data_fit - 1.5).abs() < 1e-12);
        assert!((b.assignment_sparsity - 0.25).abs() < 1e-12);
        assert!((b.smoothness - 0.5).abs() < 1e-12);
        assert!((b.ard - 0.75).abs() < 1e-12);
        assert!((b.total_penalized_loss - total).abs() < 1e-12);
        assert!((b.penalized_loss_score - (-total)).abs() < 1e-12);
        // The breakdown's four components must sum to the reported total — the
        // score is fully explained by what the breakdown lists, with no hidden
        // evidence pieces folded into it.
        let summed = b.data_fit + b.assignment_sparsity + b.smoothness + b.ard;
        assert!((summed - b.total_penalized_loss).abs() < 1e-12);
        assert_eq!(b.evidence_gauge_deflated_directions, 3);
    }

    /// `gradient()` must be the element-wise sum of all four components, so no
    /// piece of the assembled outer derivative (explicit, logdet trace, Occam,
    /// third-order correction) is silently dropped or double-counted.
    #[test]
    fn gradient_sums_all_four_components() {
        let components = SaeOuterRhoGradientComponents {
            explicit: Array1::from(vec![1.0, 2.0]),
            logdet_trace: Array1::from(vec![0.5, -1.0]),
            occam: Array1::from(vec![0.25, 0.0]),
            third_order_correction: Array1::from(vec![-0.75, 4.0]),
        };
        let g = components.gradient();
        assert_eq!(g.len(), 2);
        // 1.0 + 0.5 + 0.25 - 0.75 = 1.0 ; 2.0 - 1.0 + 0.0 + 4.0 = 5.0
        assert!((g[0] - 1.0).abs() < 1e-12);
        assert!((g[1] - 5.0).abs() < 1e-12);
    }
}