1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
use super::*;
/// Loss breakdown for diagnostics and evidence ranking.
///
/// The four `f64` fields are the additive terms of the penalized objective
/// (see [`SaeManifoldLoss::total`]); the `usize` field is bookkeeping that is
/// carried alongside them and never enters the sum.
#[derive(Debug, Clone, Copy)]
pub struct SaeManifoldLoss {
    /// Data-fit term of the objective.
    pub data_fit: f64,
    /// Assignment-sparsity penalty term.
    pub assignment_sparsity: f64,
    /// Smoothness penalty term.
    pub smoothness: f64,
    /// ARD penalty term.
    // NOTE(review): presumably "automatic relevance determination" — confirm.
    pub ard: f64,
    /// Direction count reported with the loss; it is not part of `total()`.
    // NOTE(review): presumably the number of directions deflated by the
    // evidence gauge — confirm against the code that fills this in.
    pub evidence_gauge_deflated_directions: usize,
}

impl SaeManifoldLoss {
    /// Returns the penalized objective: `data_fit + assignment_sparsity +
    /// smoothness + ard`, accumulated left to right.
    ///
    /// `evidence_gauge_deflated_directions` does not contribute.
    #[must_use]
    pub const fn total(&self) -> f64 {
        self.data_fit + self.assignment_sparsity + self.smoothness + self.ard
    }

    /// Returns the negated [`total`](Self::total).
    ///
    /// Laplace/REML wrappers rank larger evidence higher. This local score is
    /// the negative penalized objective, used when a full `RemlState` is not
    /// driving the term yet.
    #[must_use]
    pub const fn evidence_proxy(&self) -> f64 {
        -self.total()
    }
}
/// Per-channel analytic derivative of the SAE REML criterion, taken with
/// respect to the flat [`SaeManifoldRho`] layout.
///
/// Deliberately a plain value object for tests and derivation gates. It stays
/// out of [`SaeManifoldOuterObjective`] capability planning until the
/// third-order logdet correction is available behind its own oracle.
#[derive(Debug, Clone)]
pub struct SaeOuterRhoGradientComponents {
    /// Derivative of `loss.total() + extra_penalty_energy` in the log-strength
    /// coordinates alone; the Hessian logdet and Occam terms are excluded.
    pub explicit: Array1<f64>,
    /// `0.5 * tr(H^{-1} dH/d rho_j)` over the penalty blocks currently available.
    pub logdet_trace: Array1<f64>,
    /// Contribution obtained by differentiating `-occam`.
    pub occam: Array1<f64>,
    /// Channel reserved for `0.5 * tr(H^{-1} (dH/dtheta * dtheta_hat/d rho_j))`.
    pub third_order_correction: Array1<f64>,
    /// `true` when `third_order_correction` was filled from analytic channels.
    pub third_order_correction_available: bool,
}

impl SaeOuterRhoGradientComponents {
    /// Sums the `explicit`, `logdet_trace` and `occam` channels.
    ///
    /// `third_order_correction` is never read here, whatever the value of
    /// `third_order_correction_available`.
    #[must_use]
    pub fn gradient_excluding_unavailable_correction(&self) -> Array1<f64> {
        let explicit_and_logdet = &self.explicit + &self.logdet_trace;
        &explicit_and_logdet + &self.occam
    }

    /// Returns the uncorrected gradient plus `third_order_correction`.
    ///
    /// # Panics
    ///
    /// Panics when `third_order_correction_available` is `false`.
    #[must_use]
    pub fn gradient_with_available_correction(&self) -> Array1<f64> {
        // The method name is a promise: a caller asking for the corrected
        // gradient must never be handed the uncorrected one. A correction
        // channel left at zero by omission is precisely the
        // objective↔gradient desync class, so refuse loudly rather than
        // return a quietly wrong answer.
        assert!(
            self.third_order_correction_available,
            "gradient_with_available_correction: third-order correction channel \
             is not populated for this fit; use \
             gradient_excluding_unavailable_correction() and account for the \
             missing term explicitly"
        );
        let uncorrected = self.gradient_excluding_unavailable_correction();
        &uncorrected + &self.third_order_correction
    }
}