Skip to main content

gam_terms/decoders/
skip_transcoder.rs

1use faer::Side;
2use gam_linalg::faer_ndarray::FaerCholesky;
3use ndarray::{Array2, ArrayView2};
4
5/// Inputs to the closed-form Gaussian REML/Laplace score of a trained
6/// skip-transcoder.
7///
8/// Every feature of the effective design is a rank-1 outer product of a
9/// per-observation activation column with an output-space loading vector,
10/// flattened over the `(observation, output)` pair:
11///
12/// * sparse atom `a`: `D_a = vec(z[:, a] · W_dec[a, :]^T)`,
13/// * skip component `r`: `D_r = vec((XV)[:, r] · U[:, r]^T)`,
14///
15/// where `XV = x_in · skip_V` is the skip's data-dependent projection onto its
16/// `rank_skip`-dimensional input subspace and `U = skip_U` is its output
17/// loading. Because the skip map enters the prediction only through the
18/// products `XV` and `U` (the prediction is `(x_in · V) · U^T`), passing those
19/// two products — rather than `V` and `U` separately — makes the score
20/// invariant to the unidentifiable balancing gauge `U -> cU, V -> V/c`, which
21/// leaves the represented function unchanged.
22pub struct SkipTranscoderRemlInputs<'a> {
23    pub y_out: ArrayView2<'a, f64>,
24    pub y_hat: ArrayView2<'a, f64>,
25    pub z: ArrayView2<'a, f64>,
26    pub w_dec: ArrayView2<'a, f64>,
27    pub lambda_sparse: f64,
28    /// Skip input projection `XV = x_in · skip_V`, shape `(n_obs, rank_skip)`.
29    pub skip_proj: Option<ArrayView2<'a, f64>>,
30    /// Skip output loading `U = skip_U`, shape `(out_dim, rank_skip)`.
31    pub skip_u: Option<ArrayView2<'a, f64>>,
32}
33
/// Summary metrics returned by `skip_transcoder_reml_metrics`.
///
/// The type is plain data (floats and counts), so it can be printed, copied
/// and compared directly.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct SkipTranscoderRemlMetrics {
    /// `0.5 * (n · ln(sigma²) + logdet)`, where `n` is the number of entries of
    /// the observed output, `sigma²` is the floored mean squared error and
    /// `logdet` is the log-determinant of the ridge-penalised feature Gram
    /// matrix.
    pub reml_score: f64,
    /// Mean squared residual between observed and predicted outputs, averaged
    /// over every output entry.
    pub mse: f64,
    /// Fraction of activation entries whose magnitude exceeds the activity
    /// tolerance (`1e-8`), i.e. the share of *nonzero* entries; `0.0` when the
    /// activation matrix is empty.
    pub sparsity: f64,
    /// `1 - SSE / SST`, with `SST` taken about the grand mean of the observed
    /// output. Reported as `1.0` when both sums are zero and `0.0` when only
    /// `SST` is zero.
    pub explained_variance: f64,
    /// Number of atoms with at least one activation above the tolerance.
    pub active_atoms: usize,
    /// Number of features entering the Gram matrix: active atoms plus skip
    /// components.
    pub effective_rank: usize,
}
42
43pub fn skip_transcoder_reml_metrics(
44    inputs: SkipTranscoderRemlInputs<'_>,
45) -> Result<SkipTranscoderRemlMetrics, String> {
46    let y_out = inputs.y_out;
47    let y_hat = inputs.y_hat;
48    let z = inputs.z;
49    let w_dec = inputs.w_dec;
50    let lambda_sparse = inputs.lambda_sparse;
51    let skip_proj = inputs.skip_proj;
52    let skip_u = inputs.skip_u;
53
54    let mut active_atoms = Vec::new();
55    let mut nonzero_entries = 0_usize;
56    for atom in 0..z.ncols() {
57        let mut active = false;
58        for row in 0..z.nrows() {
59            if z[[row, atom]].abs() > 1.0e-8 {
60                active = true;
61                nonzero_entries += 1;
62            }
63        }
64        if active {
65            active_atoms.push(atom);
66        }
67    }
68
69    let skip_rank = skip_u.as_ref().map_or(0, |value| value.ncols());
70    let feature_count = active_atoms.len() + skip_rank;
71
72    // The effective design column of every feature is the flattened outer
73    // product of a per-observation activation column with an output-space
74    // loading vector. Hence the Gram entry between two features is the
75    // elementwise (Hadamard) product of their activation inner product and
76    // their loading inner product:
77    //
78    //     G_{pq} = (act_p^T act_q) · (load_p^T load_q).
79    //
80    // For atoms the activation is z[:, a] and the loading is W_dec[a, :]; for
81    // skip components the activation is (XV)[:, r] and the loading is U[:, r].
82    // The features are stacked in the ordering [active atoms .. | skip ranks ..]
83    // so the sparse circuit and the bypass are scored on equal footing.
84    let mut gram = Array2::<f64>::zeros((feature_count, feature_count));
85    let n_active = active_atoms.len();
86    let activation_inner = |feature: usize, other: usize| -> f64 {
87        let act_col = |idx: usize| {
88            if idx < n_active {
89                z.column(active_atoms[idx])
90            } else {
91                skip_proj
92                    .as_ref()
93                    .expect("skip_proj present whenever skip ranks exist")
94                    .column(idx - n_active)
95            }
96        };
97        act_col(feature).dot(&act_col(other))
98    };
99    let loading_inner = |feature: usize, other: usize| -> f64 {
100        let load_vec = |idx: usize| {
101            if idx < n_active {
102                w_dec.row(active_atoms[idx])
103            } else {
104                skip_u
105                    .as_ref()
106                    .expect("skip_u present whenever skip ranks exist")
107                    .column(idx - n_active)
108            }
109        };
110        load_vec(feature).dot(&load_vec(other))
111    };
112
113    for i in 0..feature_count {
114        for j in 0..=i {
115            let value = activation_inner(i, j) * loading_inner(i, j);
116            gram[[i, j]] = value;
117            gram[[j, i]] = value;
118        }
119    }
120    for diag in 0..feature_count {
121        gram[[diag, diag]] += lambda_sparse;
122    }
123
124    let logdet = if feature_count == 0 {
125        0.0
126    } else {
127        let sym = (&gram + &gram.t()) * 0.5;
128        let chol = sym
129            .cholesky(Side::Lower)
130            .map_err(|err| format!("skip_transcoder_reml_metrics logdet failed: {err}"))?;
131        let value = 2.0 * chol.diag().iter().map(|diag| diag.ln()).sum::<f64>();
132        if !value.is_finite() {
133            return Err(format!(
134                "skip_transcoder_reml_metrics logdet is not finite: {value}"
135            ));
136        }
137        value
138    };
139
140    let (n_rows, out_dim) = y_out.dim();
141    let n_total = y_out.len() as f64;
142    let mut sse = 0.0_f64;
143    let mut y_sum = 0.0_f64;
144    for row in 0..n_rows {
145        for col in 0..out_dim {
146            let resid = y_out[[row, col]] - y_hat[[row, col]];
147            sse += resid * resid;
148            y_sum += y_out[[row, col]];
149        }
150    }
151    let mse = sse / n_total;
152    let sigma2 = mse.max(1.0e-12);
153    let y_mean = y_sum / n_total;
154    let mut sst = 0.0_f64;
155    for value in y_out.iter() {
156        let centered = value - y_mean;
157        sst += centered * centered;
158    }
159    let explained_variance = if sst > 0.0 {
160        1.0 - sse / sst
161    } else if sse == 0.0 {
162        1.0
163    } else {
164        0.0
165    };
166    let sparsity = if z.is_empty() {
167        0.0
168    } else {
169        nonzero_entries as f64 / z.len() as f64
170    };
171    let reml_score = 0.5 * (n_total * sigma2.ln() + logdet);
172
173    Ok(SkipTranscoderRemlMetrics {
174        reml_score,
175        mse,
176        sparsity,
177        explained_variance,
178        active_atoms: active_atoms.len(),
179        effective_rank: feature_count,
180    })
181}