gam_terms/decoders/
skip_transcoder.rs1use faer::Side;
2use gam_linalg::faer_ndarray::FaerCholesky;
3use ndarray::{Array2, ArrayView2};
4
5pub struct SkipTranscoderRemlInputs<'a> {
23 pub y_out: ArrayView2<'a, f64>,
24 pub y_hat: ArrayView2<'a, f64>,
25 pub z: ArrayView2<'a, f64>,
26 pub w_dec: ArrayView2<'a, f64>,
27 pub lambda_sparse: f64,
28 pub skip_proj: Option<ArrayView2<'a, f64>>,
30 pub skip_u: Option<ArrayView2<'a, f64>>,
32}
33
/// Diagnostics produced by [`skip_transcoder_reml_metrics`].
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct SkipTranscoderRemlMetrics {
    /// `0.5 * (N * ln(max(mse, 1e-12)) + ln det(G + lambda_sparse * I))`, where `N` is the
    /// number of entries in `y_out` and `G` is the feature Gram matrix.
    pub reml_score: f64,
    /// Mean squared residual `(y_out - y_hat)^2` over all entries of `y_out`.
    pub mse: f64,
    /// Fraction of entries of `z` whose magnitude exceeds `1e-8`.
    ///
    /// Note: despite the name this is a *density* — larger values mean less sparse.
    pub sparsity: f64,
    /// `1 - SSE / SST`; `1.0` when both are zero and `0.0` when only `SST` is zero.
    pub explained_variance: f64,
    /// Number of columns of `z` containing at least one entry above the activity tolerance.
    pub active_atoms: usize,
    /// Number of features in the Gram matrix: active atoms plus skip rank.
    pub effective_rank: usize,
}
42
43pub fn skip_transcoder_reml_metrics(
44 inputs: SkipTranscoderRemlInputs<'_>,
45) -> Result<SkipTranscoderRemlMetrics, String> {
46 let y_out = inputs.y_out;
47 let y_hat = inputs.y_hat;
48 let z = inputs.z;
49 let w_dec = inputs.w_dec;
50 let lambda_sparse = inputs.lambda_sparse;
51 let skip_proj = inputs.skip_proj;
52 let skip_u = inputs.skip_u;
53
54 let mut active_atoms = Vec::new();
55 let mut nonzero_entries = 0_usize;
56 for atom in 0..z.ncols() {
57 let mut active = false;
58 for row in 0..z.nrows() {
59 if z[[row, atom]].abs() > 1.0e-8 {
60 active = true;
61 nonzero_entries += 1;
62 }
63 }
64 if active {
65 active_atoms.push(atom);
66 }
67 }
68
69 let skip_rank = skip_u.as_ref().map_or(0, |value| value.ncols());
70 let feature_count = active_atoms.len() + skip_rank;
71
72 let mut gram = Array2::<f64>::zeros((feature_count, feature_count));
85 let n_active = active_atoms.len();
86 let activation_inner = |feature: usize, other: usize| -> f64 {
87 let act_col = |idx: usize| {
88 if idx < n_active {
89 z.column(active_atoms[idx])
90 } else {
91 skip_proj
92 .as_ref()
93 .expect("skip_proj present whenever skip ranks exist")
94 .column(idx - n_active)
95 }
96 };
97 act_col(feature).dot(&act_col(other))
98 };
99 let loading_inner = |feature: usize, other: usize| -> f64 {
100 let load_vec = |idx: usize| {
101 if idx < n_active {
102 w_dec.row(active_atoms[idx])
103 } else {
104 skip_u
105 .as_ref()
106 .expect("skip_u present whenever skip ranks exist")
107 .column(idx - n_active)
108 }
109 };
110 load_vec(feature).dot(&load_vec(other))
111 };
112
113 for i in 0..feature_count {
114 for j in 0..=i {
115 let value = activation_inner(i, j) * loading_inner(i, j);
116 gram[[i, j]] = value;
117 gram[[j, i]] = value;
118 }
119 }
120 for diag in 0..feature_count {
121 gram[[diag, diag]] += lambda_sparse;
122 }
123
124 let logdet = if feature_count == 0 {
125 0.0
126 } else {
127 let sym = (&gram + &gram.t()) * 0.5;
128 let chol = sym
129 .cholesky(Side::Lower)
130 .map_err(|err| format!("skip_transcoder_reml_metrics logdet failed: {err}"))?;
131 let value = 2.0 * chol.diag().iter().map(|diag| diag.ln()).sum::<f64>();
132 if !value.is_finite() {
133 return Err(format!(
134 "skip_transcoder_reml_metrics logdet is not finite: {value}"
135 ));
136 }
137 value
138 };
139
140 let (n_rows, out_dim) = y_out.dim();
141 let n_total = y_out.len() as f64;
142 let mut sse = 0.0_f64;
143 let mut y_sum = 0.0_f64;
144 for row in 0..n_rows {
145 for col in 0..out_dim {
146 let resid = y_out[[row, col]] - y_hat[[row, col]];
147 sse += resid * resid;
148 y_sum += y_out[[row, col]];
149 }
150 }
151 let mse = sse / n_total;
152 let sigma2 = mse.max(1.0e-12);
153 let y_mean = y_sum / n_total;
154 let mut sst = 0.0_f64;
155 for value in y_out.iter() {
156 let centered = value - y_mean;
157 sst += centered * centered;
158 }
159 let explained_variance = if sst > 0.0 {
160 1.0 - sse / sst
161 } else if sse == 0.0 {
162 1.0
163 } else {
164 0.0
165 };
166 let sparsity = if z.is_empty() {
167 0.0
168 } else {
169 nonzero_entries as f64 / z.len() as f64
170 };
171 let reml_score = 0.5 * (n_total * sigma2.ln() + logdet);
172
173 Ok(SkipTranscoderRemlMetrics {
174 reml_score,
175 mse,
176 sparsity,
177 explained_variance,
178 active_atoms: active_atoms.len(),
179 effective_rank: feature_count,
180 })
181}