Skip to main content

gam_terms/
dictionary.rs

1use faer::Side;
2use gam_linalg::faer_ndarray::{FaerCholesky, FaerEigh};
3use ndarray::{Array1, Array2, ArrayView1, ArrayView2, Axis, s};
4
/// Default cap on coordinate-descent sweeps for the multi-atom solver.
const DEFAULT_MAX_ITER: usize = 30;
/// Default number of atoms each row is allowed to use.
const DEFAULT_TOP_K: usize = 1;
/// Default softmax temperature (read only by the softmax routing rule).
const DEFAULT_TEMPERATURE: f64 = 0.25;
/// Default ridge added to the code and atom least-squares solves.
const DEFAULT_CODE_RIDGE: f64 = 1e-8;
/// Default tolerance on the sweep-to-sweep change in explained variance.
const DEFAULT_TOLERANCE: f64 = 1e-7;
/// Value recorded in `lambdas` for atoms that are currently inactive.
const INACTIVE_LAMBDA: f64 = 1e30;
/// Squared-norm floor below which a vector or code column is treated as zero.
const MIN_NORM2: f64 = 1e-24;
12
/// Rule used to route each row of the data to dictionary atoms.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum LinearDictionaryAssignment {
    /// Hard routing: each row keeps its `top_k` atoms with the largest
    /// |row · atom| and gets ridge least-squares codes over that active set.
    TopK,
    /// Soft routing: the same `top_k` active set, with each atom's ridge
    /// projection weighted by a temperature-scaled softmax of |row · atom|/‖atom‖.
    Softmax,
}

impl LinearDictionaryAssignment {
    /// Parses an assignment name, ignoring surrounding whitespace and ASCII case.
    ///
    /// `top_k`, `topk` and `hard` map to [`Self::TopK`]; `softmax` and `soft`
    /// map to [`Self::Softmax`].
    ///
    /// # Errors
    /// Returns a message quoting the trimmed, lowercased input when it matches
    /// neither rule.
    pub fn parse(value: &str) -> Result<Self, String> {
        let normalized = value.trim().to_ascii_lowercase();
        match normalized.as_str() {
            "top_k" | "topk" | "hard" => Ok(Self::TopK),
            "softmax" | "soft" => Ok(Self::Softmax),
            other => Err(format!(
                "linear dictionary assignment must be 'top_k' or 'softmax'; got {other:?}"
            )),
        }
    }

    /// Canonical lowercase name; [`Self::parse`] accepts it back.
    pub const fn as_str(self) -> &'static str {
        match self {
            Self::Softmax => "softmax",
            Self::TopK => "top_k",
        }
    }
}
37
38#[derive(Clone, Debug)]
39pub struct LinearDictionaryConfig {
40    pub n_atoms: usize,
41    pub max_iter: usize,
42    pub top_k: usize,
43    pub assignment: LinearDictionaryAssignment,
44    pub temperature: f64,
45    pub code_ridge: f64,
46    pub tolerance: f64,
47    /// K=1 lane only. When `false` (default) the rank-one lane takes the leading
48    /// eigenvector of the UNCENTERED second-moment matrix `XᵀX` (byte-identical to
49    /// historical behavior), which is only a true centered-PCA ceiling when `x` is
50    /// already mean-centered. When `true` the lane subtracts the column mean, takes
51    /// the leading eigenvector of the CENTERED second-moment matrix, fits the
52    /// rank-1 code on the centered data, and adds the mean back — so the reported
53    /// EV (measured against the crate's centered denominator) is a genuine
54    /// centered-PCA ceiling even on uncentered input. Because the reconstruction is
55    /// then affine (mean + rank-1), the returned `fitted` INCLUDES the mean and is
56    /// NOT equal to `assignments.dot(atoms)` in this mode.
57    pub center_rank_one: bool,
58}
59
60impl LinearDictionaryConfig {
61    pub fn new(n_atoms: usize) -> Self {
62        Self {
63            n_atoms,
64            ..Self::default()
65        }
66    }
67}
68
69impl Default for LinearDictionaryConfig {
70    fn default() -> Self {
71        Self {
72            n_atoms: 1,
73            max_iter: DEFAULT_MAX_ITER,
74            top_k: DEFAULT_TOP_K,
75            assignment: LinearDictionaryAssignment::TopK,
76            temperature: DEFAULT_TEMPERATURE,
77            code_ridge: DEFAULT_CODE_RIDGE,
78            tolerance: DEFAULT_TOLERANCE,
79            center_rank_one: false,
80        }
81    }
82}
83
84#[derive(Clone, Debug)]
85pub struct LinearDictionaryFit {
86    pub atoms: Array2<f64>,
87    pub assignments: Array2<f64>,
88    pub fitted: Array2<f64>,
89    pub lambdas: Array1<f64>,
90    pub reml_scores: Array1<f64>,
91    pub explained_variance: f64,
92    pub iterations: usize,
93    pub converged: bool,
94    pub assignment: LinearDictionaryAssignment,
95    pub top_k: usize,
96}
97
98/// Fit a linear (flat) dictionary by block coordinate descent: each sweep
99/// re-routes rows to atoms (the assignment step) and then refines every atom and
100/// its assignment column by a penalized least-squares update against the residual.
101///
102/// CONTRACT: this is a heuristic coordinate-descent dictionary learner, not a
103/// globally-optimal linear SAE. The coordinate-descent sweep leaves `assignments`
104/// as the per-atom-refined routing from the final sweep (each atom's column is the
105/// LS solve against the residual of the *then-current* dictionary), which is NOT a
106/// fresh global routing against the FINAL atoms. After the loop we therefore run a
107/// FINAL REROUTE (see [`reroute_against_atoms`]): a single fresh global assignment
108/// of every row against the final atoms using the configured rule. We ADOPT that
109/// rerouted routing only when it does not lower EV, so the returned model is the
110/// better of {coordinate-descent routing, fresh global reroute} and is never worse
111/// than before this step. `fitted` and `explained_variance` are always recomputed
112/// from the adopted `assignments`, so the reported EV is exactly the EV of the
113/// model that is returned (honest, and now the better of the two cheap routings).
114pub fn fit_linear_dictionary(
115    x: ArrayView2<'_, f64>,
116    config: &LinearDictionaryConfig,
117) -> Result<LinearDictionaryFit, String> {
118    validate_inputs(x, config)?;
119    if config.n_atoms == 1 {
120        return fit_rank_one_pca_lane(x, config);
121    }
122    Ok(fit_multi_atom_dictionary(x, config)?.fit)
123}
124
125/// Diagnostics returned by the internal multi-atom solver: the fitted model plus
126/// the EV of the coordinate-descent routing as it stood *before* the final reroute
127/// adoption decision. The reroute-never-regresses invariant is exactly
128/// `fit.explained_variance >= pre_reroute_ev`; exposing both lets the unit tests
129/// assert it without re-running the private routing logic.
130struct MultiAtomDictionaryFit {
131    fit: LinearDictionaryFit,
132    // Read only by the reroute-never-regresses unit test; the production caller
133    // takes `.fit` and discards the diagnostic.
134    #[cfg_attr(not(test), allow(dead_code))]
135    pre_reroute_ev: f64,
136}
137
138fn fit_multi_atom_dictionary(
139    x: ArrayView2<'_, f64>,
140    config: &LinearDictionaryConfig,
141) -> Result<MultiAtomDictionaryFit, String> {
142    let top_k = config.top_k.min(config.n_atoms).max(1);
143    let mut atoms = initialize_atoms(x, config.n_atoms);
144    let mut assignments = Array2::<f64>::zeros((x.nrows(), config.n_atoms));
145    let mut fitted = Array2::<f64>::zeros(x.dim());
146    let mut lambdas = Array1::<f64>::from_elem(config.n_atoms, INACTIVE_LAMBDA);
147    let mut reml_scores = Array1::<f64>::zeros(config.n_atoms);
148    let mut previous_ev = f64::NEG_INFINITY;
149    let mut converged = false;
150    let mut completed_iterations = 0usize;
151
152    for iter in 0..config.max_iter {
153        assignments = reroute_against_atoms(x, atoms.view(), top_k, config)?;
154
155        fitted = assignments.dot(&atoms);
156        let mut any_reseeded = false;
157        for atom_idx in 0..config.n_atoms {
158            any_reseeded |= fit_one_atom_penalized_ls(
159                x,
160                &mut atoms,
161                &mut assignments,
162                &mut fitted,
163                &mut lambdas,
164                &mut reml_scores,
165                atom_idx,
166                config.code_ridge,
167            )?;
168        }
169
170        completed_iterations = iter + 1;
171        let ev = explained_variance(x, fitted.view());
172        // #1500: never declare convergence on an iteration that re-seeded a dead
173        // atom — its revived direction carries no code yet, so EV is momentarily
174        // flat; one more sweep lets the assignment step route rows to it.
175        if !any_reseeded && (ev - previous_ev).abs() <= config.tolerance.max(0.0) {
176            converged = true;
177            break;
178        }
179        previous_ev = ev;
180    }
181
182    // FINAL REROUTE: the loop's last assignment step routed rows against the atoms
183    // as they were BEFORE that sweep's per-atom refinement, and the atoms have
184    // since moved. Recompute a fresh global routing of every row against the FINAL
185    // atoms with the configured rule, and ADOPT it only when it does not lower EV —
186    // guaranteeing no regression and keeping assignments / fitted / EV consistent.
187    let pre_reroute_ev = explained_variance(x, fitted.view());
188    let rerouted = reroute_against_atoms(x, atoms.view(), top_k, config)?;
189    let rerouted_fitted = rerouted.dot(&atoms);
190    let rerouted_ev = explained_variance(x, rerouted_fitted.view());
191    let (assignments, fitted, final_ev) = if rerouted_ev >= pre_reroute_ev {
192        (rerouted, rerouted_fitted, rerouted_ev)
193    } else {
194        (assignments, fitted, pre_reroute_ev)
195    };
196
197    Ok(MultiAtomDictionaryFit {
198        fit: LinearDictionaryFit {
199            atoms,
200            assignments,
201            fitted,
202            lambdas,
203            reml_scores,
204            explained_variance: final_ev,
205            iterations: completed_iterations,
206            converged,
207            assignment: config.assignment,
208            top_k,
209        },
210        pre_reroute_ev,
211    })
212}
213
214/// Fresh global routing of every row against `atoms` using the configured
215/// assignment rule. This is the single source of truth shared by the
216/// coordinate-descent assignment step and the post-loop final reroute, so both
217/// route identically and the reroute is a true global re-assignment against the
218/// final atoms.
219fn reroute_against_atoms(
220    x: ArrayView2<'_, f64>,
221    atoms: ArrayView2<'_, f64>,
222    top_k: usize,
223    config: &LinearDictionaryConfig,
224) -> Result<Array2<f64>, String> {
225    match config.assignment {
226        LinearDictionaryAssignment::TopK => top_k_assignments(x, atoms, top_k, config.code_ridge),
227        LinearDictionaryAssignment::Softmax => {
228            softmax_assignments(x, atoms, top_k, config.temperature, config.code_ridge)
229        }
230    }
231}
232
233fn validate_inputs(x: ArrayView2<'_, f64>, config: &LinearDictionaryConfig) -> Result<(), String> {
234    if x.nrows() == 0 || x.ncols() == 0 {
235        return Err("linear_dictionary_fit requires a non-empty 2-D matrix".to_string());
236    }
237    if !x.iter().all(|value| value.is_finite()) {
238        return Err("linear_dictionary_fit input must be finite".to_string());
239    }
240    if config.n_atoms == 0 {
241        return Err("linear_dictionary_fit requires K >= 1".to_string());
242    }
243    if config.max_iter == 0 {
244        return Err("linear_dictionary_fit requires max_iter >= 1".to_string());
245    }
246    if config.top_k == 0 || config.top_k > config.n_atoms {
247        return Err(format!(
248            "linear_dictionary_fit top_k must be in [1, K={}]; got {}",
249            config.n_atoms, config.top_k
250        ));
251    }
252    if !(config.temperature.is_finite() && config.temperature > 0.0) {
253        return Err(format!(
254            "linear_dictionary_fit temperature must be finite and positive; got {}",
255            config.temperature
256        ));
257    }
258    if !(config.code_ridge.is_finite() && config.code_ridge > 0.0) {
259        return Err(format!(
260            "linear_dictionary_fit code_ridge must be finite and positive; got {}",
261            config.code_ridge
262        ));
263    }
264    if !config.tolerance.is_finite() {
265        return Err("linear_dictionary_fit tolerance must be finite".to_string());
266    }
267    Ok(())
268}
269
270/// K=1 closed-form lane.
271///
272/// Default (`config.center_rank_one == false`): the leading eigenvector of the
273/// UNCENTERED second-moment matrix `XᵀX`. This is only a true centered-PCA ceiling
274/// when `x` is already mean-centered upstream; the `explained_variance` denominator
275/// IS centered, so on uncentered input the leading `XᵀX` eigenvector can absorb the
276/// mean direction and this lane is a second-moment rank-1 fit rather than the
277/// centered principal component. This branch is byte-identical to historical
278/// behavior.
279///
280/// Centered (`config.center_rank_one == true`): delegates to
281/// [`fit_rank_one_centered_lane`], which subtracts the column mean, takes the
282/// leading eigenvector of the CENTERED second-moment matrix, and adds the mean
283/// back, so the reported EV is a genuine centered-PCA ceiling even on uncentered
284/// input. See that function and [`rank_one_centered_pca_ceiling`] for details.
285fn fit_rank_one_pca_lane(
286    x: ArrayView2<'_, f64>,
287    config: &LinearDictionaryConfig,
288) -> Result<LinearDictionaryFit, String> {
289    if config.center_rank_one {
290        return fit_rank_one_centered_lane(x, config);
291    }
292    let covariance = x.t().dot(&x);
293    let (evals, evecs) = covariance
294        .eigh(Side::Lower)
295        .map_err(|err| format!("linear_dictionary_fit PCA eigensolve failed: {err}"))?;
296    let last = evals.len() - 1;
297    let mut atom = evecs.column(last).to_owned();
298    orient_vector(&mut atom);
299    let mut assignments = Array2::<f64>::zeros((x.nrows(), 1));
300    for row in 0..x.nrows() {
301        assignments[[row, 0]] = x.row(row).dot(&atom) / (1.0 + config.code_ridge);
302    }
303    let mut atoms = atom.insert_axis(Axis(0)).to_owned();
304    normalize_atom_and_assignments(&mut atoms, &mut assignments, 0);
305    let fitted = assignments.dot(&atoms);
306    let score = penalized_reconstruction_loss(x, fitted.view(), config.code_ridge, atoms.view());
307    Ok(LinearDictionaryFit {
308        atoms,
309        assignments,
310        fitted: fitted.clone(),
311        lambdas: Array1::from_elem(1, config.code_ridge),
312        reml_scores: Array1::from_elem(1, score),
313        explained_variance: explained_variance(x, fitted.view()),
314        iterations: 1.min(config.max_iter),
315        converged: true,
316        assignment: config.assignment,
317        top_k: 1,
318    })
319}
320
321/// Centered K=1 lane (`config.center_rank_one == true`): a genuine centered-PCA
322/// ceiling. Builds a full [`LinearDictionaryFit`] from the shared centered
323/// components — `atoms` is the unit-norm centered principal direction,
324/// `assignments` are the centered rank-1 codes, and `fitted` is the AFFINE
325/// reconstruction `mean + code·atom`, so `explained_variance` (centered
326/// denominator) is a true ceiling. Because the reconstruction is affine, `fitted`
327/// INCLUDES the mean and is NOT `assignments.dot(atoms)` in this mode.
328fn fit_rank_one_centered_lane(
329    x: ArrayView2<'_, f64>,
330    config: &LinearDictionaryConfig,
331) -> Result<LinearDictionaryFit, String> {
332    let CenteredRankOne {
333        atom,
334        codes,
335        fitted,
336        explained_variance: ev,
337    } = centered_rank_one_components(x, config.code_ridge)?;
338    let atoms = atom.insert_axis(Axis(0)).to_owned();
339    let assignments = codes.insert_axis(Axis(1)).to_owned();
340    let score = penalized_reconstruction_loss(x, fitted.view(), config.code_ridge, atoms.view());
341    Ok(LinearDictionaryFit {
342        atoms,
343        assignments,
344        fitted,
345        lambdas: Array1::from_elem(1, config.code_ridge),
346        reml_scores: Array1::from_elem(1, score),
347        explained_variance: ev,
348        iterations: 1.min(config.max_iter),
349        converged: true,
350        assignment: config.assignment,
351        top_k: 1,
352    })
353}
354
355/// Shared components of the centered rank-1 fit, so the public ceiling helper and
356/// the centered K=1 lane compute exactly the same principal direction / codes.
357struct CenteredRankOne {
358    /// Unit-norm centered principal direction (length `p`).
359    atom: Array1<f64>,
360    /// Centered rank-1 codes with the ridge shrink applied (length `n`).
361    codes: Array1<f64>,
362    /// Affine reconstruction `mean + code·atom` (shape `n × p`).
363    fitted: Array2<f64>,
364    /// EV of `fitted` against the crate's centered denominator.
365    explained_variance: f64,
366}
367
368fn centered_rank_one_components(
369    x: ArrayView2<'_, f64>,
370    code_ridge: f64,
371) -> Result<CenteredRankOne, String> {
372    if x.nrows() == 0 || x.ncols() == 0 {
373        return Err("rank_one_centered_pca_ceiling requires a non-empty 2-D matrix".to_string());
374    }
375    if !(code_ridge.is_finite() && code_ridge > 0.0) {
376        return Err(format!(
377            "rank_one_centered_pca_ceiling code_ridge must be finite and positive; got {code_ridge}"
378        ));
379    }
380    let means = x.mean_axis(Axis(0)).expect("non-empty input has means");
381    let centered = &x.to_owned() - &means;
382    let covariance = centered.t().dot(&centered);
383    let (evals, evecs) = covariance
384        .eigh(Side::Lower)
385        .map_err(|err| format!("rank_one_centered_pca_ceiling eigensolve failed: {err}"))?;
386    let last = evals.len() - 1;
387    let mut atom = evecs.column(last).to_owned();
388    orient_vector(&mut atom);
389    let shrink = 1.0 / (1.0 + code_ridge);
390    let mut codes = Array1::<f64>::zeros(x.nrows());
391    let mut fitted = Array2::<f64>::zeros(x.dim());
392    for row in 0..x.nrows() {
393        let code = centered.row(row).dot(&atom) * shrink;
394        codes[row] = code;
395        for col in 0..x.ncols() {
396            fitted[[row, col]] = means[col] + code * atom[col];
397        }
398    }
399    let ev = explained_variance(x, fitted.view());
400    Ok(CenteredRankOne {
401        atom,
402        codes,
403        fitted,
404        explained_variance: ev,
405    })
406}
407
408/// Centered rank-1 PCA ceiling for the K=1 lane, exposed for callers that want the
409/// ceiling reconstruction/EV directly. Subtracts the column means, takes the
410/// leading eigenvector of the CENTERED second-moment matrix, fits the rank-1 code
411/// on the centered data with the same ridge shrink the uncentered lane uses, then
412/// adds the mean back so the reconstruction lives in the original space. Returns
413/// `(fitted, explained_variance)`; the EV is measured against the same centered
414/// denominator as the rest of the crate, so it is directly comparable to (and an
415/// upper bound on) the uncentered lane's EV. Prefer setting
416/// `LinearDictionaryConfig::center_rank_one = true` to route the K=1 lane through
417/// this computation as part of a full [`LinearDictionaryFit`].
418pub fn rank_one_centered_pca_ceiling(
419    x: ArrayView2<'_, f64>,
420    code_ridge: f64,
421) -> Result<(Array2<f64>, f64), String> {
422    let components = centered_rank_one_components(x, code_ridge)?;
423    Ok((components.fitted, components.explained_variance))
424}
425
426fn initialize_atoms(x: ArrayView2<'_, f64>, n_atoms: usize) -> Array2<f64> {
427    let mut atoms = Array2::<f64>::zeros((n_atoms, x.ncols()));
428    let first = max_norm_row(x);
429    atoms.row_mut(0).assign(&x.row(first));
430    normalize_row(atoms.slice_mut(s![0, ..]));
431    let mut min_dist2 = Array1::<f64>::from_elem(x.nrows(), f64::INFINITY);
432
433    for atom_idx in 1..n_atoms {
434        let prev = atoms.row(atom_idx - 1);
435        for row in 0..x.nrows() {
436            let dist2 = squared_distance(x.row(row), prev);
437            if dist2 < min_dist2[row] {
438                min_dist2[row] = dist2;
439            }
440        }
441        let chosen = if atom_idx < x.nrows() {
442            max_index(min_dist2.view())
443        } else {
444            atom_idx % x.nrows()
445        };
446        atoms.row_mut(atom_idx).assign(&x.row(chosen));
447        normalize_row(atoms.slice_mut(s![atom_idx, ..]));
448    }
449    atoms
450}
451
452fn fit_one_atom_penalized_ls(
453    x: ArrayView2<'_, f64>,
454    atoms: &mut Array2<f64>,
455    assignments: &mut Array2<f64>,
456    fitted: &mut Array2<f64>,
457    lambdas: &mut Array1<f64>,
458    reml_scores: &mut Array1<f64>,
459    atom_idx: usize,
460    atom_ridge: f64,
461) -> Result<bool, String> {
462    let code = assignments.column(atom_idx).to_owned();
463    let code_norm2 = code.dot(&code);
464    if code_norm2 <= MIN_NORM2 {
465        // #1500: this atom's cluster is EMPTY (no rows routed to it by the
466        // assignment step). Zeroing it here made the atom permanently DEAD — a
467        // zero atom has zero similarity to every row, so `top_k_assignments`
468        // never routes anything back to it, the dictionary collapses to < K live
469        // atoms, and it under-explains variance even when the data is exactly K
470        // rank-1 atoms a K-atom dictionary could reconstruct perfectly. Instead
471        // RE-SEED the atom into the worst-currently-reconstructed direction (the
472        // standard k-means empty-cluster cure): point it at the largest-residual
473        // row's UNEXPLAINED component so the next assignment sweep can route that
474        // row's cluster to it and revive it. Returns `true` so the outer loop
475        // suppresses convergence this iteration (the revived atom has no code
476        // yet, so EV is momentarily flat — converging now would strand it).
477        let mut worst_row = 0usize;
478        let mut worst_res2 = -1.0_f64;
479        for row in 0..x.nrows() {
480            let mut res2 = 0.0_f64;
481            for col in 0..x.ncols() {
482                let d = x[[row, col]] - fitted[[row, col]];
483                res2 += d * d;
484            }
485            if res2 > worst_res2 {
486                worst_res2 = res2;
487                worst_row = row;
488            }
489        }
490        if worst_res2 <= MIN_NORM2 {
491            // Every row is already fully reconstructed by the other atoms: there
492            // is no unexplained direction to seed, so this atom is genuinely
493            // redundant capacity. Leave it inactive (this is not the bug).
494            atoms.row_mut(atom_idx).fill(0.0);
495            lambdas[atom_idx] = INACTIVE_LAMBDA;
496            reml_scores[atom_idx] = 0.0;
497            return Ok(false);
498        }
499        for col in 0..x.ncols() {
500            atoms[[atom_idx, col]] = x[[worst_row, col]] - fitted[[worst_row, col]];
501        }
502        normalize_row(atoms.slice_mut(s![atom_idx, ..]));
503        lambdas[atom_idx] = atom_ridge;
504        reml_scores[atom_idx] =
505            penalized_reconstruction_loss(x, fitted.view(), atom_ridge, atoms.view());
506        return Ok(true);
507    }
508
509    let old_atom = atoms.row(atom_idx).to_owned();
510    let mut residual = x.to_owned() - fitted.view();
511    residual += &code
512        .view()
513        .insert_axis(Axis(1))
514        .dot(&old_atom.view().insert_axis(Axis(0)));
515
516    let denominator = code_norm2 + atom_ridge;
517    for col in 0..x.ncols() {
518        atoms[[atom_idx, col]] = code.dot(&residual.column(col)) / denominator;
519    }
520    lambdas[atom_idx] = atom_ridge;
521    normalize_atom_and_assignments(atoms, assignments, atom_idx);
522    let updated_code = assignments.column(atom_idx).to_owned();
523    fitted.assign(&x);
524    *fitted -= &residual;
525    *fitted += &updated_code
526        .view()
527        .insert_axis(Axis(1))
528        .dot(&atoms.row(atom_idx).insert_axis(Axis(0)));
529    reml_scores[atom_idx] =
530        penalized_reconstruction_loss(x, fitted.view(), atom_ridge, atoms.view());
531    Ok(false)
532}
533
534fn top_k_assignments(
535    x: ArrayView2<'_, f64>,
536    atoms: ArrayView2<'_, f64>,
537    top_k: usize,
538    code_ridge: f64,
539) -> Result<Array2<f64>, String> {
540    let cross = x.dot(&atoms.t());
541    let mut assignments = Array2::<f64>::zeros((x.nrows(), atoms.nrows()));
542    for row in 0..x.nrows() {
543        let active = top_indices_by_abs(cross.row(row), top_k);
544        let coeffs = solve_active_coefficients(atoms, cross.row(row), &active, code_ridge)?;
545        for pos in 0..active.len() {
546            assignments[[row, active[pos]]] = coeffs[pos];
547        }
548    }
549    Ok(assignments)
550}
551
552/// Encode held-out rows `x` (`M x P`) against a frozen dictionary `atoms`
553/// (`K x P`) using the same top-`top_k` ridge least-squares routing the fit
554/// uses against its final atoms. Returns the `(M, K)` sparse code matrix.
555///
556/// This is the out-of-sample `transform`/encode step for a fitted linear
557/// dictionary; the math (top-k selection + active-set ridge solve) lives in
558/// the Rust core so the Python facade stays a thin wrapper.
559pub fn linear_dictionary_transform(
560    x: ArrayView2<'_, f64>,
561    atoms: ArrayView2<'_, f64>,
562    top_k: usize,
563    code_ridge: f64,
564) -> Result<Array2<f64>, String> {
565    let k = atoms.nrows();
566    if k == 0 {
567        return Err("linear_dictionary_transform: dictionary has no atoms".to_string());
568    }
569    if x.ncols() != atoms.ncols() {
570        return Err(format!(
571            "linear_dictionary_transform: X has P={} columns but atoms have P={}",
572            x.ncols(),
573            atoms.ncols()
574        ));
575    }
576    let effective_k = top_k.min(k).max(1);
577    top_k_assignments(x, atoms, effective_k, code_ridge)
578}
579
580fn softmax_assignments(
581    x: ArrayView2<'_, f64>,
582    atoms: ArrayView2<'_, f64>,
583    top_k: usize,
584    temperature: f64,
585    code_ridge: f64,
586) -> Result<Array2<f64>, String> {
587    let cross = x.dot(&atoms.t());
588    let atom_norm2 = atoms.map_axis(Axis(1), |row| row.dot(&row).max(MIN_NORM2));
589    let mut assignments = Array2::<f64>::zeros((x.nrows(), atoms.nrows()));
590    for row in 0..x.nrows() {
591        let active = top_indices_by_abs(cross.row(row), top_k);
592        let mut max_score = f64::NEG_INFINITY;
593        for &atom_idx in &active {
594            let score = cross[[row, atom_idx]].abs() / (atom_norm2[atom_idx].sqrt() * temperature);
595            if score > max_score {
596                max_score = score;
597            }
598        }
599        let mut denom = 0.0;
600        for &atom_idx in &active {
601            let score = cross[[row, atom_idx]].abs() / (atom_norm2[atom_idx].sqrt() * temperature);
602            let mass = (score - max_score).exp();
603            assignments[[row, atom_idx]] = mass;
604            denom += mass;
605        }
606        if denom <= 0.0 || !denom.is_finite() {
607            return Err("linear_dictionary_fit softmax assignment underflowed".to_string());
608        }
609        for &atom_idx in &active {
610            let projection = cross[[row, atom_idx]] / (atom_norm2[atom_idx] + code_ridge);
611            assignments[[row, atom_idx]] = assignments[[row, atom_idx]] * projection / denom;
612        }
613    }
614    Ok(assignments)
615}
616
617fn solve_active_coefficients(
618    atoms: ArrayView2<'_, f64>,
619    cross_row: ArrayView1<'_, f64>,
620    active: &[usize],
621    code_ridge: f64,
622) -> Result<Array1<f64>, String> {
623    let m = active.len();
624    let mut system = Array2::<f64>::zeros((m, m));
625    let mut rhs = Array2::<f64>::zeros((m, 1));
626    for i in 0..m {
627        rhs[[i, 0]] = cross_row[active[i]];
628        for j in 0..m {
629            system[[i, j]] = atoms.row(active[i]).dot(&atoms.row(active[j]));
630        }
631        system[[i, i]] += code_ridge;
632    }
633    let factor = system
634        .cholesky(Side::Lower)
635        .map_err(|err| format!("linear_dictionary_fit sparse-code solve failed: {err}"))?;
636    let mut solution = rhs;
637    factor.solve_mat_in_place(&mut solution);
638    Ok(solution.column(0).to_owned())
639}
640
641fn top_indices_by_abs(row: ArrayView1<'_, f64>, top_k: usize) -> Vec<usize> {
642    let mut selected: Vec<(usize, f64)> = Vec::with_capacity(top_k);
643    for idx in 0..row.len() {
644        let score = row[idx].abs();
645        if selected.len() < top_k {
646            selected.push((idx, score));
647            continue;
648        }
649        let mut worst_pos = 0usize;
650        for pos in 1..selected.len() {
651            if selected[pos].1 < selected[worst_pos].1
652                || (selected[pos].1 == selected[worst_pos].1
653                    && selected[pos].0 > selected[worst_pos].0)
654            {
655                worst_pos = pos;
656            }
657        }
658        let worst = selected[worst_pos];
659        if score > worst.1 || (score == worst.1 && idx < worst.0) {
660            selected[worst_pos] = (idx, score);
661        }
662    }
663    selected.sort_by(|a, b| {
664        b.1.partial_cmp(&a.1)
665            .unwrap_or(std::cmp::Ordering::Equal)
666            .then_with(|| a.0.cmp(&b.0))
667    });
668    selected.into_iter().map(|(idx, _)| idx).collect()
669}
670
671fn normalize_atom_and_assignments(
672    atoms: &mut Array2<f64>,
673    assignments: &mut Array2<f64>,
674    atom_idx: usize,
675) {
676    let norm = atoms.row(atom_idx).dot(&atoms.row(atom_idx)).sqrt();
677    if norm > MIN_NORM2.sqrt() {
678        atoms.row_mut(atom_idx).mapv_inplace(|value| value / norm);
679        assignments
680            .column_mut(atom_idx)
681            .mapv_inplace(|value| value * norm);
682    }
683    orient_atom_and_code(atoms, assignments, atom_idx);
684}
685
686fn orient_atom_and_code(atoms: &mut Array2<f64>, assignments: &mut Array2<f64>, atom_idx: usize) {
687    let sign = first_nonzero_sign(atoms.row(atom_idx));
688    if sign < 0.0 {
689        atoms.row_mut(atom_idx).mapv_inplace(|value| -value);
690        assignments
691            .column_mut(atom_idx)
692            .mapv_inplace(|value| -value);
693    }
694}
695
696fn orient_vector(vector: &mut Array1<f64>) {
697    if first_nonzero_sign(vector.view()) < 0.0 {
698        vector.mapv_inplace(|value| -value);
699    }
700}
701
702fn first_nonzero_sign(row: ndarray::ArrayView1<'_, f64>) -> f64 {
703    for &value in row {
704        if value.abs() > 1.0e-12 {
705            return value.signum();
706        }
707    }
708    1.0
709}
710
711fn normalize_row(mut row: ndarray::ArrayViewMut1<'_, f64>) {
712    let norm = row.dot(&row).sqrt();
713    if norm > MIN_NORM2.sqrt() {
714        row.mapv_inplace(|value| value / norm);
715    }
716}
717
718fn max_norm_row(x: ArrayView2<'_, f64>) -> usize {
719    let mut best = 0usize;
720    let mut best_norm = f64::NEG_INFINITY;
721    for row in 0..x.nrows() {
722        let norm = x.row(row).dot(&x.row(row));
723        if norm > best_norm {
724            best = row;
725            best_norm = norm;
726        }
727    }
728    best
729}
730
731fn max_index(values: ndarray::ArrayView1<'_, f64>) -> usize {
732    let mut best = 0usize;
733    let mut best_value = f64::NEG_INFINITY;
734    for idx in 0..values.len() {
735        if values[idx] > best_value {
736            best = idx;
737            best_value = values[idx];
738        }
739    }
740    best
741}
742
743fn squared_distance(a: ndarray::ArrayView1<'_, f64>, b: ndarray::ArrayView1<'_, f64>) -> f64 {
744    a.iter()
745        .zip(b.iter())
746        .map(|(av, bv)| {
747            let diff = av - bv;
748            diff * diff
749        })
750        .sum()
751}
752
753fn explained_variance(x: ArrayView2<'_, f64>, fitted: ArrayView2<'_, f64>) -> f64 {
754    let mut rss = 0.0;
755    for row in 0..x.nrows() {
756        for col in 0..x.ncols() {
757            let residual = x[[row, col]] - fitted[[row, col]];
758            rss += residual * residual;
759        }
760    }
761    let means = x.mean_axis(Axis(0)).expect("non-empty input has means");
762    let mut tss = 0.0;
763    for row in 0..x.nrows() {
764        for col in 0..x.ncols() {
765            let centered = x[[row, col]] - means[col];
766            tss += centered * centered;
767        }
768    }
769    if tss <= MIN_NORM2 {
770        if rss <= MIN_NORM2 { 1.0 } else { 0.0 }
771    } else {
772        1.0 - rss / tss
773    }
774}
775
776fn penalized_reconstruction_loss(
777    x: ArrayView2<'_, f64>,
778    fitted: ArrayView2<'_, f64>,
779    ridge: f64,
780    atoms: ArrayView2<'_, f64>,
781) -> f64 {
782    let mut loss = 0.0;
783    for row in 0..x.nrows() {
784        for col in 0..x.ncols() {
785            let residual = x[[row, col]] - fitted[[row, col]];
786            loss += residual * residual;
787        }
788    }
789    loss + ridge * atoms.iter().map(|value| value * value).sum::<f64>()
790}
791
// Regression and oracle tests for the linear-dictionary fitter. They exercise
// `fit_linear_dictionary`, the internal `fit_multi_atom_dictionary` diagnostics,
// `rank_one_centered_pca_ceiling`, and the private `explained_variance` helper.
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::{Array2, array};

    // Planted sparse problem with identity atoms. Each row puts weight
    // 0.7 + 0.01·(row/4) on atom `row % 4` and 0.2 on the next atom, so two
    // active atoms per row (top_k = 2) are enough to reconstruct it.
    #[test]
    fn planted_sparse_linear_dictionary_reaches_high_explained_variance() {
        let truth = array![
            [1.0, 0.0, 0.0, 0.0],
            [0.0, 1.0, 0.0, 0.0],
            [0.0, 0.0, 1.0, 0.0],
            [0.0, 0.0, 0.0, 1.0],
        ];
        let mut assignments = Array2::<f64>::zeros((160, 4));
        for row in 0..160 {
            let atom = row % 4;
            assignments[[row, atom]] = 0.7 + 0.01 * ((row / 4) as f64);
            assignments[[row, (atom + 1) % 4]] = 0.2;
        }
        let x = assignments.dot(&truth);
        let config = LinearDictionaryConfig {
            n_atoms: 4,
            max_iter: 40,
            top_k: 2,
            assignment: LinearDictionaryAssignment::TopK,
            temperature: DEFAULT_TEMPERATURE,
            code_ridge: DEFAULT_CODE_RIDGE,
            tolerance: 1.0e-9,
            center_rank_one: false,
        };

        let fit = fit_linear_dictionary(x.view(), &config).expect("linear dictionary fit");

        assert!(
            fit.explained_variance > 0.95,
            "expected EV > 0.95, got {}",
            fit.explained_variance
        );
    }

    // K = 1 lane versus a closed-form oracle built from the eigenvalues of the
    // uncentered XᵀX. The oracle's residual is the unexplained remainder of the
    // top eigenvalue after the code is shrunk by 1/(1 + ridge) (presumably how the
    // rank-one lane applies `code_ridge`; confirm against the solver), plus every
    // remaining eigenvalue. Columns 0 and 1 are exactly mean-zero and column 2 is
    // nearly so, which is why an uncentered oracle can be compared to the
    // centered EV within the 2e-4 tolerance.
    #[test]
    fn single_atom_matches_penalized_pca_oracle() {
        let mut x = Array2::<f64>::zeros((80, 3));
        for row in 0..80 {
            let t = (row as f64 - 39.5) / 20.0;
            x[[row, 0]] = 2.0 * t;
            x[[row, 1]] = -t;
            x[[row, 2]] = 0.05 * (row as f64).sin();
        }
        let config = LinearDictionaryConfig {
            n_atoms: 1,
            max_iter: 5,
            top_k: 1,
            assignment: LinearDictionaryAssignment::TopK,
            temperature: DEFAULT_TEMPERATURE,
            code_ridge: DEFAULT_CODE_RIDGE,
            tolerance: DEFAULT_TOLERANCE,
            center_rank_one: false,
        };

        let fit = fit_linear_dictionary(x.view(), &config).expect("rank-one fit");
        let covariance = x.t().dot(&x);
        let (evals, _) = covariance.eigh(Side::Lower).expect("PCA eigensolve");
        let shrink = 1.0 / (1.0 + DEFAULT_CODE_RIDGE);
        // The last eigenvalue is used as the largest, i.e. this assumes `eigh`
        // returns eigenvalues in ascending order.
        let oracle_ev = 1.0
            - ((1.0 - shrink) * (1.0 - shrink) * evals[evals.len() - 1]
                + evals.slice(s![..evals.len() - 1]).sum())
                / evals.sum();

        assert!(fit.explained_variance > 0.99);
        assert_abs_diff_eq!(fit.explained_variance, oracle_ev, epsilon = 2.0e-4);
    }

    #[test]
    fn orthonormal_rank_one_atoms_all_revived_no_dead_collapse_1500() {
        // #1500: rows lie on K mutually ORTHONORMAL rank-1 directions, so a
        // K-atom top_k=1 dictionary that recovers them reconstructs every row
        // exactly (EV → 1). The dead-atom bug emptied a cluster, zeroed that atom
        // permanently, and returned < K live atoms with badly under-explained
        // variance. With empty-cluster re-seeding every atom stays live.
        let (k, p, n) = (4usize, 8usize, 400usize);
        // Deterministic orthonormal directions: eigenvectors of a fixed symmetric
        // matrix are orthonormal, so no RNG is needed for a stable regression.
        let mut a = Array2::<f64>::zeros((p, p));
        for i in 0..p {
            for j in 0..p {
                a[[i, j]] = ((i * 7 + j * 3 + 1) % 11) as f64 - 5.0;
            }
        }
        let sym = &a + &a.t();
        let (_evals, evecs) = sym.eigh(Side::Lower).expect("orthonormal directions");
        let dirs = evecs.slice(s![.., ..k]).t().to_owned(); // k×p, orthonormal rows
        let mut x = Array2::<f64>::zeros((n, p));
        for row in 0..n {
            let atom = row % k;
            let scale = if row % 2 == 0 { 2.0 } else { -1.5 } + 0.01 * (row / k) as f64;
            for col in 0..p {
                // Small deterministic perturbation in [-6e-3, 6e-3].
                let noise = 1.0e-3 * (((row * p + col) % 13) as f64 - 6.0);
                x[[row, col]] = scale * dirs[[atom, col]] + noise;
            }
        }
        let config = LinearDictionaryConfig {
            n_atoms: k,
            max_iter: 40,
            top_k: 1,
            assignment: LinearDictionaryAssignment::TopK,
            temperature: DEFAULT_TEMPERATURE,
            code_ridge: DEFAULT_CODE_RIDGE,
            tolerance: 1.0e-9,
            center_rank_one: false,
        };
        let fit = fit_linear_dictionary(x.view(), &config).expect("orthonormal dictionary fit");
        // An atom counts as "live" if any of its entries is non-negligible.
        let live = fit
            .atoms
            .axis_iter(Axis(0))
            .filter(|atom| atom.iter().any(|value| value.abs() > 1.0e-12))
            .count();
        assert_eq!(
            live, k,
            "all {k} atoms must stay live (no dead-atom collapse); got {live} live"
        );
        assert!(
            fit.explained_variance > 0.99,
            "K orthonormal rank-1 atoms must be reconstructed at EV > 0.99; got {}",
            fit.explained_variance
        );
    }

    #[test]
    fn final_reroute_never_regresses_and_stays_consistent() {
        // Planted sparse problem where the coordinate-descent routing and a fresh
        // global reroute against the final atoms generally differ.
        let truth = array![
            [1.0, 0.0, 0.0, 0.0],
            [0.0, 1.0, 0.0, 0.0],
            [0.0, 0.0, 1.0, 0.0],
            [0.0, 0.0, 0.0, 1.0],
        ];
        let mut assignments = Array2::<f64>::zeros((160, 4));
        for row in 0..160 {
            let atom = row % 4;
            assignments[[row, atom]] = 0.7 + 0.01 * ((row / 4) as f64);
            assignments[[row, (atom + 1) % 4]] = 0.2;
        }
        let x = assignments.dot(&truth);
        let config = LinearDictionaryConfig {
            n_atoms: 4,
            max_iter: 40,
            top_k: 2,
            assignment: LinearDictionaryAssignment::TopK,
            temperature: DEFAULT_TEMPERATURE,
            code_ridge: DEFAULT_CODE_RIDGE,
            tolerance: 1.0e-9,
            center_rank_one: false,
        };

        // Internal solver exposes the pre-reroute (coordinate-descent) EV so we can
        // assert the no-regression invariant directly.
        let diag =
            fit_multi_atom_dictionary(x.view(), &config).expect("multi-atom dictionary fit");
        assert!(
            diag.fit.explained_variance >= diag.pre_reroute_ev - 1.0e-12,
            "final reroute regressed EV: pre={}, returned={}",
            diag.pre_reroute_ev,
            diag.fit.explained_variance
        );

        // Returned fitted must be exactly assignments.dot(atoms) for the adopted
        // routing, and the reported EV must match that fitted.
        let recomputed_fitted = diag.fit.assignments.dot(&diag.fit.atoms);
        for (a, b) in diag.fit.fitted.iter().zip(recomputed_fitted.iter()) {
            assert_abs_diff_eq!(*a, *b, epsilon = 1.0e-10);
        }
        assert_abs_diff_eq!(
            diag.fit.explained_variance,
            explained_variance(x.view(), diag.fit.fitted.view()),
            epsilon = 1.0e-10
        );

        // Public entry point returns the adopted result and is also self-consistent.
        let public = fit_linear_dictionary(x.view(), &config).expect("public fit");
        let public_fitted = public.assignments.dot(&public.atoms);
        for (a, b) in public.fitted.iter().zip(public_fitted.iter()) {
            assert_abs_diff_eq!(*a, *b, epsilon = 1.0e-10);
        }
    }

    #[test]
    fn centered_rank_one_ceiling_agrees_when_data_already_centered() {
        // Build correlated data, then explicitly mean-center it. On centered input
        // the uncentered XᵀX lane and the centered helper see the same second-moment
        // matrix, so their EVs must agree.
        let mut x = Array2::<f64>::zeros((90, 3));
        for row in 0..90 {
            let t = (row as f64 - 44.5) / 25.0;
            x[[row, 0]] = 1.5 * t;
            x[[row, 1]] = -0.8 * t + 0.02 * (row as f64).cos();
            x[[row, 2]] = 0.6 * t;
        }
        let means = x.mean_axis(Axis(0)).unwrap();
        let centered = &x - &means;

        let config = LinearDictionaryConfig::new(1);
        let uncentered = fit_linear_dictionary(centered.view(), &config).expect("rank-one fit");
        let (_fitted, centered_ev) =
            rank_one_centered_pca_ceiling(centered.view(), DEFAULT_CODE_RIDGE)
                .expect("centered ceiling");

        assert_abs_diff_eq!(
            uncentered.explained_variance,
            centered_ev,
            epsilon = 1.0e-9
        );
    }

    #[test]
    fn centered_rank_one_ceiling_beats_uncentered_with_strong_mean() {
        // Strong column mean (offset) plus a low-variance signal direction: the
        // uncentered XᵀX lane wastes its single rank on the mean direction and
        // under-explains the CENTERED variance, while the centered helper recovers
        // the true principal component and is a genuine, higher centered-PCA ceiling.
        let mut x = Array2::<f64>::zeros((120, 2));
        for row in 0..120 {
            let t = (row as f64 - 59.5) / 60.0; // small spread around the offset
            x[[row, 0]] = 50.0 + 0.3 * t;
            x[[row, 1]] = 50.0 - 0.3 * t;
        }
        let config = LinearDictionaryConfig::new(1);
        let uncentered = fit_linear_dictionary(x.view(), &config).expect("rank-one fit");
        let (fitted, centered_ev) =
            rank_one_centered_pca_ceiling(x.view(), DEFAULT_CODE_RIDGE).expect("centered ceiling");

        assert!(
            centered_ev > uncentered.explained_variance + 1.0e-6,
            "centered ceiling ({centered_ev}) should beat uncentered lane ({}) on strong-mean data",
            uncentered.explained_variance
        );
        // The centered helper's reported EV is consistent with its returned fitted.
        assert_abs_diff_eq!(
            centered_ev,
            explained_variance(x.view(), fitted.view()),
            epsilon = 1.0e-10
        );
    }

    #[test]
    fn center_rank_one_config_flag_routes_k1_lane_to_centered_ceiling() {
        // Strong-mean, low-variance-signal data: the default (uncentered) K=1 lane
        // wastes its single rank on the mean, so setting `center_rank_one = true`
        // must route the lane through the centered computation and report the
        // genuine (higher) centered-PCA ceiling — matching the standalone helper.
        let mut x = Array2::<f64>::zeros((100, 3));
        for row in 0..100 {
            let t = (row as f64 - 49.5) / 50.0;
            x[[row, 0]] = 30.0 + 0.2 * t;
            x[[row, 1]] = 30.0 - 0.2 * t;
            x[[row, 2]] = 30.0 + 0.05 * t;
        }

        let default_config = LinearDictionaryConfig::new(1);
        assert!(!default_config.center_rank_one, "flag must default to false");
        let uncentered = fit_linear_dictionary(x.view(), &default_config).expect("uncentered lane");

        let mut centered_config = LinearDictionaryConfig::new(1);
        centered_config.center_rank_one = true;
        let centered = fit_linear_dictionary(x.view(), &centered_config).expect("centered lane");

        // The flag actually routes to the centered lane: its EV equals the helper's
        // centered ceiling and strictly beats the default uncentered lane.
        let (_fitted, helper_ev) =
            rank_one_centered_pca_ceiling(x.view(), DEFAULT_CODE_RIDGE).expect("helper ceiling");
        assert_abs_diff_eq!(centered.explained_variance, helper_ev, epsilon = 1.0e-10);
        assert!(
            centered.explained_variance > uncentered.explained_variance + 1.0e-6,
            "center_rank_one=true ({}) must beat default ({}) on strong-mean data",
            centered.explained_variance,
            uncentered.explained_variance
        );
        // Centered lane reports the affine reconstruction directly, so its EV is
        // consistent with the returned `fitted` (which INCLUDES the mean and is not
        // assignments.dot(atoms) in this mode).
        assert_abs_diff_eq!(
            centered.explained_variance,
            explained_variance(x.view(), centered.fitted.view()),
            epsilon = 1.0e-10
        );
    }

    // Large-K stress test. The data contains only 8 distinct (identity) directions
    // with row-dependent scales, but 1024 atoms are requested. With top_k = 1, no
    // row may carry more than one non-negligible code, and EV must still exceed 0.95.
    #[test]
    fn sparse_assignment_scales_to_thousand_atom_dictionary() {
        let active_atoms = array![
            [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
            [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
            [0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0],
            [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0],
            [0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0],
            [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0],
            [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0],
            [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0],
        ];
        let mut x = Array2::<f64>::zeros((256, 8));
        for row in 0..x.nrows() {
            let atom = row % active_atoms.nrows();
            let scale = 0.7 + 0.003 * row as f64;
            x.row_mut(row).assign(&(&active_atoms.row(atom) * scale));
        }
        let config = LinearDictionaryConfig {
            n_atoms: 1024,
            max_iter: 8,
            top_k: 1,
            assignment: LinearDictionaryAssignment::TopK,
            temperature: DEFAULT_TEMPERATURE,
            code_ridge: DEFAULT_CODE_RIDGE,
            tolerance: 1.0e-9,
            center_rank_one: false,
        };

        let fit = fit_linear_dictionary(x.view(), &config).expect("large-K linear dictionary fit");
        // Largest number of non-negligible codes on any single row.
        let max_active = fit
            .assignments
            .axis_iter(Axis(0))
            .map(|row| row.iter().filter(|value| value.abs() > 1.0e-10).count())
            .max()
            .unwrap();

        assert_eq!(max_active, 1);
        assert!(
            fit.explained_variance > 0.95,
            "expected EV > 0.95 at K=1024, got {}",
            fit.explained_variance
        );
    }
}