Skip to main content

gam_terms/
dictionary.rs

1use faer::Side;
2use gam_linalg::faer_ndarray::{FaerCholesky, FaerEigh};
3use ndarray::{Array1, Array2, ArrayView1, ArrayView2, Axis, s};
4
/// Default cap on alternating coding / atom-update sweeps.
const DEFAULT_MAX_ITER: usize = 30;
/// Default number of atoms each row may load on per coding step.
const DEFAULT_TOP_K: usize = 1;
/// Default softmax temperature (only read by the softmax assignment rule).
const DEFAULT_TEMPERATURE: f64 = 0.25;
/// Default ridge added to the per-row code Gram system and to the per-atom update denominator.
const DEFAULT_CODE_RIDGE: f64 = 1.0e-8;
/// Default early-stop threshold on the sweep-to-sweep change in explained variance.
const DEFAULT_TOLERANCE: f64 = 1.0e-7;
/// Sentinel `lambdas` value for atoms that were never fitted or were left empty with nothing
/// left to explain.
const INACTIVE_LAMBDA: f64 = 1.0e30;
/// Floor for squared norms / sums of squares: values at or below it are treated as zero
/// (empty code columns, fully reconstructed rows, degenerate atoms, negligible TSS).
const MIN_NORM2: f64 = 1.0e-24;
12
/// How each row of the data is routed to dictionary atoms during the coding step.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum LinearDictionaryAssignment {
    /// Hard routing: ridge least-squares codes restricted to the `top_k` atoms with the largest
    /// `|x·a|`.
    TopK,
    /// Soft routing: temperature-scaled softmax weights over the `top_k` atoms with the largest
    /// `|x·a|`, each multiplied by that atom's ridge projection coefficient.
    Softmax,
}

impl LinearDictionaryAssignment {
    /// Parses a user-facing assignment name, ignoring case and surrounding whitespace.
    ///
    /// Accepts `top_k` / `topk` / `hard` for [`Self::TopK`] and `softmax` / `soft` for
    /// [`Self::Softmax`].
    ///
    /// # Errors
    ///
    /// Returns a descriptive message naming the rejected spelling for anything else.
    pub fn parse(value: &str) -> Result<Self, String> {
        match value.trim().to_ascii_lowercase().as_str() {
            "top_k" | "topk" | "hard" => Ok(Self::TopK),
            "softmax" | "soft" => Ok(Self::Softmax),
            other => Err(format!(
                "linear dictionary assignment must be 'top_k' or 'softmax'; got {other:?}"
            )),
        }
    }

    /// Canonical name of the rule; round-trips through [`Self::parse`].
    pub const fn as_str(self) -> &'static str {
        match self {
            Self::TopK => "top_k",
            Self::Softmax => "softmax",
        }
    }
}

/// Enables `"softmax".parse::<LinearDictionaryAssignment>()`; same rules as the inherent
/// [`LinearDictionaryAssignment::parse`].
impl std::str::FromStr for LinearDictionaryAssignment {
    type Err = String;

    fn from_str(value: &str) -> Result<Self, Self::Err> {
        Self::parse(value)
    }
}

/// Formats the canonical name returned by [`LinearDictionaryAssignment::as_str`].
impl std::fmt::Display for LinearDictionaryAssignment {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.as_str())
    }
}
37
38#[derive(Clone, Debug)]
39pub struct LinearDictionaryConfig {
40    pub n_atoms: usize,
41    pub max_iter: usize,
42    pub top_k: usize,
43    pub assignment: LinearDictionaryAssignment,
44    pub temperature: f64,
45    pub code_ridge: f64,
46    pub tolerance: f64,
47}
48
49impl LinearDictionaryConfig {
50    pub fn new(n_atoms: usize) -> Self {
51        Self {
52            n_atoms,
53            ..Self::default()
54        }
55    }
56}
57
58impl Default for LinearDictionaryConfig {
59    fn default() -> Self {
60        Self {
61            n_atoms: 1,
62            max_iter: DEFAULT_MAX_ITER,
63            top_k: DEFAULT_TOP_K,
64            assignment: LinearDictionaryAssignment::TopK,
65            temperature: DEFAULT_TEMPERATURE,
66            code_ridge: DEFAULT_CODE_RIDGE,
67            tolerance: DEFAULT_TOLERANCE,
68        }
69    }
70}
71
/// Result of [`fit_linear_dictionary`].
#[derive(Clone, Debug)]
pub struct LinearDictionaryFit {
    /// Dictionary with one atom per row (`K × p`). Fitted atoms are normally rescaled to unit
    /// length; atoms left inactive are zero rows.
    pub atoms: Array2<f64>,
    /// Per-row codes (`n × K`); `fitted` is (up to rounding) `assignments · atoms`.
    pub assignments: Array2<f64>,
    /// Reconstruction of the input (`n × p`).
    pub fitted: Array2<f64>,
    /// Per-atom ridge penalty actually applied; `INACTIVE_LAMBDA` marks atoms that were never
    /// fitted or were left empty.
    pub lambdas: Array1<f64>,
    /// Per-atom penalized reconstruction loss recorded when the atom was last updated
    /// (residual sum of squares plus ridge times the dictionary's total squared entries).
    /// Despite the name it is not a REML criterion.
    pub reml_scores: Array1<f64>,
    /// `1 − RSS/TSS`, with TSS taken about the column means of the input.
    pub explained_variance: f64,
    /// Number of sweeps executed (1 for the `K = 1` PCA path).
    pub iterations: usize,
    /// Whether the explained-variance change fell within tolerance before `max_iter`
    /// (always `true` for the `K = 1` PCA path).
    pub converged: bool,
    /// Routing rule that was configured for the fit.
    pub assignment: LinearDictionaryAssignment,
    /// Effective per-row atom budget (1 for the `K = 1` PCA path).
    pub top_k: usize,
}
85
86pub fn fit_linear_dictionary(
87    x: ArrayView2<'_, f64>,
88    config: &LinearDictionaryConfig,
89) -> Result<LinearDictionaryFit, String> {
90    validate_inputs(x, config)?;
91    if config.n_atoms == 1 {
92        return fit_rank_one_pca_lane(x, config);
93    }
94
95    let top_k = config.top_k.min(config.n_atoms).max(1);
96    let mut atoms = initialize_atoms(x, config.n_atoms);
97    let mut assignments = Array2::<f64>::zeros((x.nrows(), config.n_atoms));
98    let mut fitted = Array2::<f64>::zeros(x.dim());
99    let mut lambdas = Array1::<f64>::from_elem(config.n_atoms, INACTIVE_LAMBDA);
100    let mut reml_scores = Array1::<f64>::zeros(config.n_atoms);
101    let mut previous_ev = f64::NEG_INFINITY;
102    let mut converged = false;
103    let mut completed_iterations = 0usize;
104
105    for iter in 0..config.max_iter {
106        assignments = match config.assignment {
107            LinearDictionaryAssignment::TopK => {
108                top_k_assignments(x, atoms.view(), top_k, config.code_ridge)?
109            }
110            LinearDictionaryAssignment::Softmax => softmax_assignments(
111                x,
112                atoms.view(),
113                top_k,
114                config.temperature,
115                config.code_ridge,
116            )?,
117        };
118
119        fitted = assignments.dot(&atoms);
120        let mut any_reseeded = false;
121        for atom_idx in 0..config.n_atoms {
122            any_reseeded |= fit_one_atom_penalized_ls(
123                x,
124                &mut atoms,
125                &mut assignments,
126                &mut fitted,
127                &mut lambdas,
128                &mut reml_scores,
129                atom_idx,
130                config.code_ridge,
131            )?;
132        }
133
134        completed_iterations = iter + 1;
135        let ev = explained_variance(x, fitted.view());
136        // #1500: never declare convergence on an iteration that re-seeded a dead
137        // atom — its revived direction carries no code yet, so EV is momentarily
138        // flat; one more sweep lets the assignment step route rows to it.
139        if !any_reseeded && (ev - previous_ev).abs() <= config.tolerance.max(0.0) {
140            converged = true;
141            break;
142        }
143        previous_ev = ev;
144    }
145
146    let final_ev = explained_variance(x, fitted.view());
147    Ok(LinearDictionaryFit {
148        atoms,
149        assignments,
150        fitted,
151        lambdas,
152        reml_scores,
153        explained_variance: final_ev,
154        iterations: completed_iterations,
155        converged,
156        assignment: config.assignment,
157        top_k,
158    })
159}
160
161fn validate_inputs(x: ArrayView2<'_, f64>, config: &LinearDictionaryConfig) -> Result<(), String> {
162    if x.nrows() == 0 || x.ncols() == 0 {
163        return Err("linear_dictionary_fit requires a non-empty 2-D matrix".to_string());
164    }
165    if !x.iter().all(|value| value.is_finite()) {
166        return Err("linear_dictionary_fit input must be finite".to_string());
167    }
168    if config.n_atoms == 0 {
169        return Err("linear_dictionary_fit requires K >= 1".to_string());
170    }
171    if config.max_iter == 0 {
172        return Err("linear_dictionary_fit requires max_iter >= 1".to_string());
173    }
174    if config.top_k == 0 || config.top_k > config.n_atoms {
175        return Err(format!(
176            "linear_dictionary_fit top_k must be in [1, K={}]; got {}",
177            config.n_atoms, config.top_k
178        ));
179    }
180    if !(config.temperature.is_finite() && config.temperature > 0.0) {
181        return Err(format!(
182            "linear_dictionary_fit temperature must be finite and positive; got {}",
183            config.temperature
184        ));
185    }
186    if !(config.code_ridge.is_finite() && config.code_ridge > 0.0) {
187        return Err(format!(
188            "linear_dictionary_fit code_ridge must be finite and positive; got {}",
189            config.code_ridge
190        ));
191    }
192    if !config.tolerance.is_finite() {
193        return Err("linear_dictionary_fit tolerance must be finite".to_string());
194    }
195    Ok(())
196}
197
198fn fit_rank_one_pca_lane(
199    x: ArrayView2<'_, f64>,
200    config: &LinearDictionaryConfig,
201) -> Result<LinearDictionaryFit, String> {
202    let covariance = x.t().dot(&x);
203    let (evals, evecs) = covariance
204        .eigh(Side::Lower)
205        .map_err(|err| format!("linear_dictionary_fit PCA eigensolve failed: {err}"))?;
206    let last = evals.len() - 1;
207    let mut atom = evecs.column(last).to_owned();
208    orient_vector(&mut atom);
209    let mut assignments = Array2::<f64>::zeros((x.nrows(), 1));
210    for row in 0..x.nrows() {
211        assignments[[row, 0]] = x.row(row).dot(&atom) / (1.0 + config.code_ridge);
212    }
213    let mut atoms = atom.insert_axis(Axis(0)).to_owned();
214    normalize_atom_and_assignments(&mut atoms, &mut assignments, 0);
215    let fitted = assignments.dot(&atoms);
216    let score = penalized_reconstruction_loss(x, fitted.view(), config.code_ridge, atoms.view());
217    Ok(LinearDictionaryFit {
218        atoms,
219        assignments,
220        fitted: fitted.clone(),
221        lambdas: Array1::from_elem(1, config.code_ridge),
222        reml_scores: Array1::from_elem(1, score),
223        explained_variance: explained_variance(x, fitted.view()),
224        iterations: 1.min(config.max_iter),
225        converged: true,
226        assignment: config.assignment,
227        top_k: 1,
228    })
229}
230
231fn initialize_atoms(x: ArrayView2<'_, f64>, n_atoms: usize) -> Array2<f64> {
232    let mut atoms = Array2::<f64>::zeros((n_atoms, x.ncols()));
233    let first = max_norm_row(x);
234    atoms.row_mut(0).assign(&x.row(first));
235    normalize_row(atoms.slice_mut(s![0, ..]));
236    let mut min_dist2 = Array1::<f64>::from_elem(x.nrows(), f64::INFINITY);
237
238    for atom_idx in 1..n_atoms {
239        let prev = atoms.row(atom_idx - 1);
240        for row in 0..x.nrows() {
241            let dist2 = squared_distance(x.row(row), prev);
242            if dist2 < min_dist2[row] {
243                min_dist2[row] = dist2;
244            }
245        }
246        let chosen = if atom_idx < x.nrows() {
247            max_index(min_dist2.view())
248        } else {
249            atom_idx % x.nrows()
250        };
251        atoms.row_mut(atom_idx).assign(&x.row(chosen));
252        normalize_row(atoms.slice_mut(s![atom_idx, ..]));
253    }
254    atoms
255}
256
/// Refits atom `atom_idx` (one row of `atoms`) by ridge-penalized least squares with its code
/// column held fixed, then keeps `fitted` in sync with the new atom.
///
/// Let `c = assignments[.., atom_idx]`. Assuming `fitted` equals `assignments · atoms` on
/// entry (the caller maintains this), the partial residual is `R = x − fitted + c ⊗ a_old`,
/// i.e. the data minus every other atom's contribution. The update is
/// `a = cᵀR / (‖c‖² + atom_ridge)`. The atom is then rescaled to unit norm, with the code
/// column absorbing the norm, and sign-oriented by `normalize_atom_and_assignments`. Finally
/// `fitted` is rebuilt as `x − R + c_new ⊗ a_new`.
///
/// Writes `atoms[atom_idx, ..]`, `assignments[.., atom_idx]` (rescale / sign flip only),
/// `fitted`, `lambdas[atom_idx]` and `reml_scores[atom_idx]`. The stored "REML score" is
/// `penalized_reconstruction_loss` over the whole dictionary, not a restricted-likelihood
/// criterion.
///
/// Returns `Ok(true)` only when the code column was empty and the atom was re-seeded;
/// otherwise `Ok(false)`. No path in the current body returns `Err`.
///
/// NOTE(review): a re-seed does not change `fitted`. So when several atoms are empty in the
/// same sweep, they all find the same worst-reconstructed row and are re-seeded to the same
/// direction. Under hard routing with `top_k = 1`, ties in `top_indices_by_abs` go to the
/// lower index, so only the lowest-index copy can win those rows at the next coding step.
/// Confirm whether this one-revival-per-sweep behavior is intended.
fn fit_one_atom_penalized_ls(
    x: ArrayView2<'_, f64>,
    atoms: &mut Array2<f64>,
    assignments: &mut Array2<f64>,
    fitted: &mut Array2<f64>,
    lambdas: &mut Array1<f64>,
    reml_scores: &mut Array1<f64>,
    atom_idx: usize,
    atom_ridge: f64,
) -> Result<bool, String> {
    let code = assignments.column(atom_idx).to_owned();
    let code_norm2 = code.dot(&code);
    if code_norm2 <= MIN_NORM2 {
        // #1500: this atom's cluster is EMPTY (no rows routed to it by the
        // assignment step). Zeroing it here made the atom permanently DEAD — a
        // zero atom has zero similarity to every row, so `top_k_assignments`
        // never routes anything back to it, the dictionary collapses to < K live
        // atoms, and it under-explains variance even when the data is exactly K
        // rank-1 atoms a K-atom dictionary could reconstruct perfectly. Instead
        // RE-SEED the atom into the worst-currently-reconstructed direction (the
        // standard k-means empty-cluster cure): point it at the largest-residual
        // row's UNEXPLAINED component so the next assignment sweep can route that
        // row's cluster to it and revive it. Returns `true` so the outer loop
        // suppresses convergence this iteration (the revived atom has no code
        // yet, so EV is momentarily flat — converging now would strand it).
        let mut worst_row = 0usize;
        let mut worst_res2 = -1.0_f64;
        // Find the row with the largest squared reconstruction residual (first one on ties).
        for row in 0..x.nrows() {
            let mut res2 = 0.0_f64;
            for col in 0..x.ncols() {
                let d = x[[row, col]] - fitted[[row, col]];
                res2 += d * d;
            }
            if res2 > worst_res2 {
                worst_res2 = res2;
                worst_row = row;
            }
        }
        if worst_res2 <= MIN_NORM2 {
            // Every row is already fully reconstructed by the other atoms: there
            // is no unexplained direction to seed, so this atom is genuinely
            // redundant capacity. Leave it inactive (this is not the bug).
            atoms.row_mut(atom_idx).fill(0.0);
            lambdas[atom_idx] = INACTIVE_LAMBDA;
            reml_scores[atom_idx] = 0.0;
            return Ok(false);
        }
        // New atom = that row's unexplained component, scaled to unit length. Its code column
        // stays zero, so `fitted` needs no update here.
        for col in 0..x.ncols() {
            atoms[[atom_idx, col]] = x[[worst_row, col]] - fitted[[worst_row, col]];
        }
        normalize_row(atoms.slice_mut(s![atom_idx, ..]));
        lambdas[atom_idx] = atom_ridge;
        reml_scores[atom_idx] =
            penalized_reconstruction_loss(x, fitted.view(), atom_ridge, atoms.view());
        return Ok(true);
    }

    // Partial residual: add this atom's current contribution back into x − fitted.
    let old_atom = atoms.row(atom_idx).to_owned();
    let mut residual = x.to_owned() - fitted.view();
    residual += &code
        .view()
        .insert_axis(Axis(1))
        .dot(&old_atom.view().insert_axis(Axis(0)));

    // Ridge least-squares atom update with the code column held fixed.
    let denominator = code_norm2 + atom_ridge;
    for col in 0..x.ncols() {
        atoms[[atom_idx, col]] = code.dot(&residual.column(col)) / denominator;
    }
    lambdas[atom_idx] = atom_ridge;
    // Unit-normalize the atom; the code column absorbs the scale and any sign flip.
    normalize_atom_and_assignments(atoms, assignments, atom_idx);
    let updated_code = assignments.column(atom_idx).to_owned();
    // fitted = (x − residual) + updated_code ⊗ new_atom, i.e. the other atoms' contribution
    // plus this atom's new one.
    fitted.assign(&x);
    *fitted -= &residual;
    *fitted += &updated_code
        .view()
        .insert_axis(Axis(1))
        .dot(&atoms.row(atom_idx).insert_axis(Axis(0)));
    reml_scores[atom_idx] =
        penalized_reconstruction_loss(x, fitted.view(), atom_ridge, atoms.view());
    Ok(false)
}
338
339fn top_k_assignments(
340    x: ArrayView2<'_, f64>,
341    atoms: ArrayView2<'_, f64>,
342    top_k: usize,
343    code_ridge: f64,
344) -> Result<Array2<f64>, String> {
345    let cross = x.dot(&atoms.t());
346    let mut assignments = Array2::<f64>::zeros((x.nrows(), atoms.nrows()));
347    for row in 0..x.nrows() {
348        let active = top_indices_by_abs(cross.row(row), top_k);
349        let coeffs = solve_active_coefficients(atoms, cross.row(row), &active, code_ridge)?;
350        for pos in 0..active.len() {
351            assignments[[row, active[pos]]] = coeffs[pos];
352        }
353    }
354    Ok(assignments)
355}
356
357fn softmax_assignments(
358    x: ArrayView2<'_, f64>,
359    atoms: ArrayView2<'_, f64>,
360    top_k: usize,
361    temperature: f64,
362    code_ridge: f64,
363) -> Result<Array2<f64>, String> {
364    let cross = x.dot(&atoms.t());
365    let atom_norm2 = atoms.map_axis(Axis(1), |row| row.dot(&row).max(MIN_NORM2));
366    let mut assignments = Array2::<f64>::zeros((x.nrows(), atoms.nrows()));
367    for row in 0..x.nrows() {
368        let active = top_indices_by_abs(cross.row(row), top_k);
369        let mut max_score = f64::NEG_INFINITY;
370        for &atom_idx in &active {
371            let score = cross[[row, atom_idx]].abs() / (atom_norm2[atom_idx].sqrt() * temperature);
372            if score > max_score {
373                max_score = score;
374            }
375        }
376        let mut denom = 0.0;
377        for &atom_idx in &active {
378            let score = cross[[row, atom_idx]].abs() / (atom_norm2[atom_idx].sqrt() * temperature);
379            let mass = (score - max_score).exp();
380            assignments[[row, atom_idx]] = mass;
381            denom += mass;
382        }
383        if denom <= 0.0 || !denom.is_finite() {
384            return Err("linear_dictionary_fit softmax assignment underflowed".to_string());
385        }
386        for &atom_idx in &active {
387            let projection = cross[[row, atom_idx]] / (atom_norm2[atom_idx] + code_ridge);
388            assignments[[row, atom_idx]] = assignments[[row, atom_idx]] * projection / denom;
389        }
390    }
391    Ok(assignments)
392}
393
394fn solve_active_coefficients(
395    atoms: ArrayView2<'_, f64>,
396    cross_row: ArrayView1<'_, f64>,
397    active: &[usize],
398    code_ridge: f64,
399) -> Result<Array1<f64>, String> {
400    let m = active.len();
401    let mut system = Array2::<f64>::zeros((m, m));
402    let mut rhs = Array2::<f64>::zeros((m, 1));
403    for i in 0..m {
404        rhs[[i, 0]] = cross_row[active[i]];
405        for j in 0..m {
406            system[[i, j]] = atoms.row(active[i]).dot(&atoms.row(active[j]));
407        }
408        system[[i, i]] += code_ridge;
409    }
410    let factor = system
411        .cholesky(Side::Lower)
412        .map_err(|err| format!("linear_dictionary_fit sparse-code solve failed: {err}"))?;
413    let mut solution = rhs;
414    factor.solve_mat_in_place(&mut solution);
415    Ok(solution.column(0).to_owned())
416}
417
418fn top_indices_by_abs(row: ArrayView1<'_, f64>, top_k: usize) -> Vec<usize> {
419    let mut selected: Vec<(usize, f64)> = Vec::with_capacity(top_k);
420    for idx in 0..row.len() {
421        let score = row[idx].abs();
422        if selected.len() < top_k {
423            selected.push((idx, score));
424            continue;
425        }
426        let mut worst_pos = 0usize;
427        for pos in 1..selected.len() {
428            if selected[pos].1 < selected[worst_pos].1
429                || (selected[pos].1 == selected[worst_pos].1
430                    && selected[pos].0 > selected[worst_pos].0)
431            {
432                worst_pos = pos;
433            }
434        }
435        let worst = selected[worst_pos];
436        if score > worst.1 || (score == worst.1 && idx < worst.0) {
437            selected[worst_pos] = (idx, score);
438        }
439    }
440    selected.sort_by(|a, b| {
441        b.1.partial_cmp(&a.1)
442            .unwrap_or(std::cmp::Ordering::Equal)
443            .then_with(|| a.0.cmp(&b.0))
444    });
445    selected.into_iter().map(|(idx, _)| idx).collect()
446}
447
448fn normalize_atom_and_assignments(
449    atoms: &mut Array2<f64>,
450    assignments: &mut Array2<f64>,
451    atom_idx: usize,
452) {
453    let norm = atoms.row(atom_idx).dot(&atoms.row(atom_idx)).sqrt();
454    if norm > MIN_NORM2.sqrt() {
455        atoms.row_mut(atom_idx).mapv_inplace(|value| value / norm);
456        assignments
457            .column_mut(atom_idx)
458            .mapv_inplace(|value| value * norm);
459    }
460    orient_atom_and_code(atoms, assignments, atom_idx);
461}
462
463fn orient_atom_and_code(atoms: &mut Array2<f64>, assignments: &mut Array2<f64>, atom_idx: usize) {
464    let sign = first_nonzero_sign(atoms.row(atom_idx));
465    if sign < 0.0 {
466        atoms.row_mut(atom_idx).mapv_inplace(|value| -value);
467        assignments
468            .column_mut(atom_idx)
469            .mapv_inplace(|value| -value);
470    }
471}
472
473fn orient_vector(vector: &mut Array1<f64>) {
474    if first_nonzero_sign(vector.view()) < 0.0 {
475        vector.mapv_inplace(|value| -value);
476    }
477}
478
479fn first_nonzero_sign(row: ndarray::ArrayView1<'_, f64>) -> f64 {
480    for &value in row {
481        if value.abs() > 1.0e-12 {
482            return value.signum();
483        }
484    }
485    1.0
486}
487
488fn normalize_row(mut row: ndarray::ArrayViewMut1<'_, f64>) {
489    let norm = row.dot(&row).sqrt();
490    if norm > MIN_NORM2.sqrt() {
491        row.mapv_inplace(|value| value / norm);
492    }
493}
494
495fn max_norm_row(x: ArrayView2<'_, f64>) -> usize {
496    let mut best = 0usize;
497    let mut best_norm = f64::NEG_INFINITY;
498    for row in 0..x.nrows() {
499        let norm = x.row(row).dot(&x.row(row));
500        if norm > best_norm {
501            best = row;
502            best_norm = norm;
503        }
504    }
505    best
506}
507
508fn max_index(values: ndarray::ArrayView1<'_, f64>) -> usize {
509    let mut best = 0usize;
510    let mut best_value = f64::NEG_INFINITY;
511    for idx in 0..values.len() {
512        if values[idx] > best_value {
513            best = idx;
514            best_value = values[idx];
515        }
516    }
517    best
518}
519
520fn squared_distance(a: ndarray::ArrayView1<'_, f64>, b: ndarray::ArrayView1<'_, f64>) -> f64 {
521    a.iter()
522        .zip(b.iter())
523        .map(|(av, bv)| {
524            let diff = av - bv;
525            diff * diff
526        })
527        .sum()
528}
529
530fn explained_variance(x: ArrayView2<'_, f64>, fitted: ArrayView2<'_, f64>) -> f64 {
531    let mut rss = 0.0;
532    for row in 0..x.nrows() {
533        for col in 0..x.ncols() {
534            let residual = x[[row, col]] - fitted[[row, col]];
535            rss += residual * residual;
536        }
537    }
538    let means = x.mean_axis(Axis(0)).expect("non-empty input has means");
539    let mut tss = 0.0;
540    for row in 0..x.nrows() {
541        for col in 0..x.ncols() {
542            let centered = x[[row, col]] - means[col];
543            tss += centered * centered;
544        }
545    }
546    if tss <= MIN_NORM2 {
547        if rss <= MIN_NORM2 { 1.0 } else { 0.0 }
548    } else {
549        1.0 - rss / tss
550    }
551}
552
553fn penalized_reconstruction_loss(
554    x: ArrayView2<'_, f64>,
555    fitted: ArrayView2<'_, f64>,
556    ridge: f64,
557    atoms: ArrayView2<'_, f64>,
558) -> f64 {
559    let mut loss = 0.0;
560    for row in 0..x.nrows() {
561        for col in 0..x.ncols() {
562            let residual = x[[row, col]] - fitted[[row, col]];
563            loss += residual * residual;
564        }
565    }
566    loss + ridge * atoms.iter().map(|value| value * value).sum::<f64>()
567}
568
// Tests: planted-dictionary recovery, the K = 1 PCA oracle, the #1500 dead-atom regression, and
// a large-K sparsity / scaling check.
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::{Array2, array};

    // Planted identity dictionary: each row loads 0.7+ on one axis and 0.2 on the next, so a
    // 4-atom, top_k = 2 fit should explain more than 95% of the variance.
    #[test]
    fn planted_sparse_linear_dictionary_reaches_high_explained_variance() {
        let truth = array![
            [1.0, 0.0, 0.0, 0.0],
            [0.0, 1.0, 0.0, 0.0],
            [0.0, 0.0, 1.0, 0.0],
            [0.0, 0.0, 0.0, 1.0],
        ];
        let mut assignments = Array2::<f64>::zeros((160, 4));
        for row in 0..160 {
            let atom = row % 4;
            assignments[[row, atom]] = 0.7 + 0.01 * ((row / 4) as f64);
            assignments[[row, (atom + 1) % 4]] = 0.2;
        }
        let x = assignments.dot(&truth);
        let config = LinearDictionaryConfig {
            n_atoms: 4,
            max_iter: 40,
            top_k: 2,
            assignment: LinearDictionaryAssignment::TopK,
            temperature: DEFAULT_TEMPERATURE,
            code_ridge: DEFAULT_CODE_RIDGE,
            tolerance: 1.0e-9,
        };

        let fit = fit_linear_dictionary(x.view(), &config).expect("linear dictionary fit");

        assert!(
            fit.explained_variance > 0.95,
            "expected EV > 0.95, got {}",
            fit.explained_variance
        );
    }

    // K = 1 takes the PCA path. Its EV is compared with a closed form built from the
    // eigenvalues of xᵀx and the 1 / (1 + code_ridge) code shrink. The oracle divides by the
    // uncentered total (sum of eigenvalues), while the fit's EV uses a column-centered total;
    // that is presumably why the comparison allows 2e-4 rather than exact equality.
    #[test]
    fn single_atom_matches_penalized_pca_oracle() {
        let mut x = Array2::<f64>::zeros((80, 3));
        for row in 0..80 {
            let t = (row as f64 - 39.5) / 20.0;
            x[[row, 0]] = 2.0 * t;
            x[[row, 1]] = -t;
            x[[row, 2]] = 0.05 * (row as f64).sin();
        }
        let config = LinearDictionaryConfig {
            n_atoms: 1,
            max_iter: 5,
            top_k: 1,
            assignment: LinearDictionaryAssignment::TopK,
            temperature: DEFAULT_TEMPERATURE,
            code_ridge: DEFAULT_CODE_RIDGE,
            tolerance: DEFAULT_TOLERANCE,
        };

        let fit = fit_linear_dictionary(x.view(), &config).expect("rank-one fit");
        let covariance = x.t().dot(&x);
        let (evals, _) = covariance.eigh(Side::Lower).expect("PCA eigensolve");
        let shrink = 1.0 / (1.0 + DEFAULT_CODE_RIDGE);
        let oracle_ev = 1.0
            - ((1.0 - shrink) * (1.0 - shrink) * evals[evals.len() - 1]
                + evals.slice(s![..evals.len() - 1]).sum())
                / evals.sum();

        assert!(fit.explained_variance > 0.99);
        assert_abs_diff_eq!(fit.explained_variance, oracle_ev, epsilon = 2.0e-4);
    }

    #[test]
    fn orthonormal_rank_one_atoms_all_revived_no_dead_collapse_1500() {
        // #1500: rows lie on K mutually ORTHONORMAL rank-1 directions, so a
        // K-atom top_k=1 dictionary that recovers them reconstructs every row
        // exactly (EV → 1). The dead-atom bug emptied a cluster, zeroed that atom
        // permanently, and returned < K live atoms with badly under-explained
        // variance. With empty-cluster re-seeding every atom stays live.
        let (k, p, n) = (4usize, 8usize, 400usize);
        // Deterministic orthonormal directions: eigenvectors of a fixed symmetric
        // matrix are orthonormal, so no RNG is needed for a stable regression.
        let mut a = Array2::<f64>::zeros((p, p));
        for i in 0..p {
            for j in 0..p {
                a[[i, j]] = ((i * 7 + j * 3 + 1) % 11) as f64 - 5.0;
            }
        }
        let sym = &a + &a.t();
        let (_evals, evecs) = sym.eigh(Side::Lower).expect("orthonormal directions");
        let dirs = evecs.slice(s![.., ..k]).t().to_owned(); // k×p, orthonormal rows
        let mut x = Array2::<f64>::zeros((n, p));
        // Alternate positive and negative scales so each direction also appears negated,
        // plus a small deterministic perturbation.
        for row in 0..n {
            let atom = row % k;
            let scale = if row % 2 == 0 { 2.0 } else { -1.5 } + 0.01 * (row / k) as f64;
            for col in 0..p {
                let noise = 1.0e-3 * (((row * p + col) % 13) as f64 - 6.0);
                x[[row, col]] = scale * dirs[[atom, col]] + noise;
            }
        }
        let config = LinearDictionaryConfig {
            n_atoms: k,
            max_iter: 40,
            top_k: 1,
            assignment: LinearDictionaryAssignment::TopK,
            temperature: DEFAULT_TEMPERATURE,
            code_ridge: DEFAULT_CODE_RIDGE,
            tolerance: 1.0e-9,
        };
        let fit = fit_linear_dictionary(x.view(), &config).expect("orthonormal dictionary fit");
        // An atom counts as live when any entry is non-negligible.
        let live = fit
            .atoms
            .axis_iter(Axis(0))
            .filter(|atom| atom.iter().any(|value| value.abs() > 1.0e-12))
            .count();
        assert_eq!(
            live, k,
            "all {k} atoms must stay live (no dead-atom collapse); got {live} live"
        );
        assert!(
            fit.explained_variance > 0.99,
            "K orthonormal rank-1 atoms must be reconstructed at EV > 0.99; got {}",
            fit.explained_variance
        );
    }

    // K = 1024 atoms for 256 rows lying on the 8 coordinate axes: every row must use exactly
    // one atom (top_k = 1) and EV must stay above 0.95 despite the heavily over-complete
    // dictionary.
    #[test]
    fn sparse_assignment_scales_to_thousand_atom_dictionary() {
        let active_atoms = array![
            [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
            [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
            [0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0],
            [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0],
            [0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0],
            [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0],
            [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0],
            [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0],
        ];
        let mut x = Array2::<f64>::zeros((256, 8));
        for row in 0..x.nrows() {
            let atom = row % active_atoms.nrows();
            let scale = 0.7 + 0.003 * row as f64;
            x.row_mut(row).assign(&(&active_atoms.row(atom) * scale));
        }
        let config = LinearDictionaryConfig {
            n_atoms: 1024,
            max_iter: 8,
            top_k: 1,
            assignment: LinearDictionaryAssignment::TopK,
            temperature: DEFAULT_TEMPERATURE,
            code_ridge: DEFAULT_CODE_RIDGE,
            tolerance: 1.0e-9,
        };

        let fit = fit_linear_dictionary(x.view(), &config).expect("large-K linear dictionary fit");
        // Largest number of non-negligible codes in any single row.
        let max_active = fit
            .assignments
            .axis_iter(Axis(0))
            .map(|row| row.iter().filter(|value| value.abs() > 1.0e-10).count())
            .max()
            .unwrap();

        assert_eq!(max_active, 1);
        assert!(
            fit.explained_variance > 0.95,
            "expected EV > 0.95 at K=1024, got {}",
            fit.explained_variance
        );
    }
}