1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
use crate::faer_ndarray::FaerCholesky;
use faer::Side;
use ndarray::{Array2, ArrayView2};
/// Inputs to the closed-form Gaussian REML/Laplace score of a trained
/// skip-transcoder.
///
/// Every feature of the effective design is a rank-1 outer product of a
/// per-observation activation column with an output-space loading vector,
/// flattened over the `(observation, output)` pair:
///
/// * sparse atom `a`: `D_a = vec(z[:, a] · W_dec[a, :]^T)`,
/// * skip component `r`: `D_r = vec((XV)[:, r] · U[:, r]^T)`,
///
/// where `XV = x_in · skip_V` is the skip's data-dependent projection onto its
/// `rank_skip`-dimensional input subspace and `U = skip_U` is its output
/// loading. Because the skip map enters the prediction only through the
/// products `XV` and `U` (the prediction is `(x_in · V) · U^T`), passing those
/// two products — rather than `V` and `U` separately — makes the score
/// invariant to the unidentifiable balancing gauge `U -> cU, V -> V/c`, which
/// leaves the represented function unchanged.
pub struct SkipTranscoderRemlInputs<'a> {
pub y_out: ArrayView2<'a, f64>,
pub y_hat: ArrayView2<'a, f64>,
pub z: ArrayView2<'a, f64>,
pub w_dec: ArrayView2<'a, f64>,
pub lambda_sparse: f64,
/// Skip input projection `XV = x_in · skip_V`, shape `(n_obs, rank_skip)`.
pub skip_proj: Option<ArrayView2<'a, f64>>,
/// Skip output loading `U = skip_U`, shape `(out_dim, rank_skip)`.
pub skip_u: Option<ArrayView2<'a, f64>>,
}
pub struct SkipTranscoderRemlMetrics {
pub reml_score: f64,
pub mse: f64,
pub sparsity: f64,
pub explained_variance: f64,
pub active_atoms: usize,
pub effective_rank: usize,
}
pub fn skip_transcoder_reml_metrics(
inputs: SkipTranscoderRemlInputs<'_>,
) -> Result<SkipTranscoderRemlMetrics, String> {
let y_out = inputs.y_out;
let y_hat = inputs.y_hat;
let z = inputs.z;
let w_dec = inputs.w_dec;
let lambda_sparse = inputs.lambda_sparse;
let skip_proj = inputs.skip_proj;
let skip_u = inputs.skip_u;
let mut active_atoms = Vec::new();
let mut nonzero_entries = 0_usize;
for atom in 0..z.ncols() {
let mut active = false;
for row in 0..z.nrows() {
if z[[row, atom]].abs() > 1.0e-8 {
active = true;
nonzero_entries += 1;
}
}
if active {
active_atoms.push(atom);
}
}
let skip_rank = skip_u.as_ref().map_or(0, |value| value.ncols());
let feature_count = active_atoms.len() + skip_rank;
// The effective design column of every feature is the flattened outer
// product of a per-observation activation column with an output-space
// loading vector. Hence the Gram entry between two features is the
// elementwise (Hadamard) product of their activation inner product and
// their loading inner product:
//
// G_{pq} = (act_p^T act_q) · (load_p^T load_q).
//
// For atoms the activation is z[:, a] and the loading is W_dec[a, :]; for
// skip components the activation is (XV)[:, r] and the loading is U[:, r].
// The features are stacked in the ordering [active atoms .. | skip ranks ..]
// so the sparse circuit and the bypass are scored on equal footing.
let mut gram = Array2::<f64>::zeros((feature_count, feature_count));
let n_active = active_atoms.len();
let activation_inner = |feature: usize, other: usize| -> f64 {
let act_col = |idx: usize| {
if idx < n_active {
z.column(active_atoms[idx])
} else {
skip_proj
.as_ref()
.expect("skip_proj present whenever skip ranks exist")
.column(idx - n_active)
}
};
act_col(feature).dot(&act_col(other))
};
let loading_inner = |feature: usize, other: usize| -> f64 {
let load_vec = |idx: usize| {
if idx < n_active {
w_dec.row(active_atoms[idx])
} else {
skip_u
.as_ref()
.expect("skip_u present whenever skip ranks exist")
.column(idx - n_active)
}
};
load_vec(feature).dot(&load_vec(other))
};
for i in 0..feature_count {
for j in 0..=i {
let value = activation_inner(i, j) * loading_inner(i, j);
gram[[i, j]] = value;
gram[[j, i]] = value;
}
}
for diag in 0..feature_count {
gram[[diag, diag]] += lambda_sparse;
}
let logdet = if feature_count == 0 {
0.0
} else {
let sym = (&gram + &gram.t()) * 0.5;
let chol = sym
.cholesky(Side::Lower)
.map_err(|err| format!("skip_transcoder_reml_metrics logdet failed: {err}"))?;
let value = 2.0 * chol.diag().iter().map(|diag| diag.ln()).sum::<f64>();
if !value.is_finite() {
return Err(format!(
"skip_transcoder_reml_metrics logdet is not finite: {value}"
));
}
value
};
let (n_rows, out_dim) = y_out.dim();
let n_total = y_out.len() as f64;
let mut sse = 0.0_f64;
let mut y_sum = 0.0_f64;
for row in 0..n_rows {
for col in 0..out_dim {
let resid = y_out[[row, col]] - y_hat[[row, col]];
sse += resid * resid;
y_sum += y_out[[row, col]];
}
}
let mse = sse / n_total;
let sigma2 = mse.max(1.0e-12);
let y_mean = y_sum / n_total;
let mut sst = 0.0_f64;
for value in y_out.iter() {
let centered = value - y_mean;
sst += centered * centered;
}
let explained_variance = if sst > 0.0 {
1.0 - sse / sst
} else if sse == 0.0 {
1.0
} else {
0.0
};
let sparsity = if z.is_empty() {
0.0
} else {
nonzero_entries as f64 / z.len() as f64
};
let reml_score = 0.5 * (n_total * sigma2.ln() + logdet);
Ok(SkipTranscoderRemlMetrics {
reml_score,
mse,
sparsity,
explained_variance,
active_atoms: active_atoms.len(),
effective_rank: feature_count,
})
}