1use faer::Side;
2use gam_linalg::faer_ndarray::{FaerCholesky, FaerEigh};
3use ndarray::{Array1, Array2, ArrayView1, ArrayView2, Axis, s};
4
/// Default for `LinearDictionaryConfig::max_iter` (cap on fitting sweeps).
const DEFAULT_MAX_ITER: usize = 30;
/// Default for `LinearDictionaryConfig::top_k` (atoms active per row).
const DEFAULT_TOP_K: usize = 1;
/// Default for `LinearDictionaryConfig::temperature` (softmax score divisor).
const DEFAULT_TEMPERATURE: f64 = 0.25;
/// Default for `LinearDictionaryConfig::code_ridge`.
const DEFAULT_CODE_RIDGE: f64 = 1.0e-8;
/// Default for `LinearDictionaryConfig::tolerance` (explained-variance change threshold).
const DEFAULT_TOLERANCE: f64 = 1.0e-7;
/// Value stored in `lambdas` for atoms that have not been fitted or were zeroed out.
const INACTIVE_LAMBDA: f64 = 1.0e30;
/// Squared-norm / sum-of-squares threshold at or below which a quantity is treated as zero.
const MIN_NORM2: f64 = 1.0e-24;
12
/// Rule used to turn atom correlations into per-row codes.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum LinearDictionaryAssignment {
    /// Ridge-regularized solve over the selected atoms.
    TopK,
    /// Softmax-weighted projections over the selected atoms.
    Softmax,
}

impl LinearDictionaryAssignment {
    /// Parses an assignment name, ignoring surrounding whitespace and ASCII case.
    ///
    /// Accepted spellings are `top_k`, `topk`, `hard` (→ [`Self::TopK`]) and
    /// `softmax`, `soft` (→ [`Self::Softmax`]).
    ///
    /// # Errors
    /// Returns a message quoting the trimmed, lowercased input when it matches
    /// none of the accepted spellings.
    pub fn parse(value: &str) -> Result<Self, String> {
        let lowered = value.trim().to_ascii_lowercase();
        let other = lowered.as_str();
        if matches!(other, "top_k" | "topk" | "hard") {
            return Ok(Self::TopK);
        }
        if matches!(other, "softmax" | "soft") {
            return Ok(Self::Softmax);
        }
        Err(format!(
            "linear dictionary assignment must be 'top_k' or 'softmax'; got {other:?}"
        ))
    }

    /// Canonical name of the variant; `parse` accepts every string returned here.
    pub const fn as_str(self) -> &'static str {
        match self {
            LinearDictionaryAssignment::TopK => "top_k",
            LinearDictionaryAssignment::Softmax => "softmax",
        }
    }
}
37
/// Settings for `fit_linear_dictionary`.
#[derive(Clone, Debug)]
pub struct LinearDictionaryConfig {
    /// Number of dictionary atoms (`K`); validation requires `K >= 1`.
    pub n_atoms: usize,
    /// Maximum number of assignment/atom-update sweeps; validation requires `>= 1`.
    pub max_iter: usize,
    /// Number of atoms selected per row; validation requires a value in `[1, K]`.
    pub top_k: usize,
    /// Rule used to compute per-row codes from the selected atoms.
    pub assignment: LinearDictionaryAssignment,
    /// Divisor applied to softmax scores; validation requires finite and positive.
    pub temperature: f64,
    /// Ridge term used in the code solve and the atom update; validation requires
    /// finite and positive.
    pub code_ridge: f64,
    /// Threshold on the change in explained variance between sweeps; must be finite,
    /// and negative values are clamped to zero by the fitting loop.
    pub tolerance: f64,
}
48
49impl LinearDictionaryConfig {
50 pub fn new(n_atoms: usize) -> Self {
51 Self {
52 n_atoms,
53 ..Self::default()
54 }
55 }
56}
57
58impl Default for LinearDictionaryConfig {
59 fn default() -> Self {
60 Self {
61 n_atoms: 1,
62 max_iter: DEFAULT_MAX_ITER,
63 top_k: DEFAULT_TOP_K,
64 assignment: LinearDictionaryAssignment::TopK,
65 temperature: DEFAULT_TEMPERATURE,
66 code_ridge: DEFAULT_CODE_RIDGE,
67 tolerance: DEFAULT_TOLERANCE,
68 }
69 }
70}
71
/// Result of `fit_linear_dictionary`.
#[derive(Clone, Debug)]
pub struct LinearDictionaryFit {
    /// Dictionary atoms, one per row (`n_atoms` x number of input columns).
    pub atoms: Array2<f64>,
    /// Per-row codes (number of input rows x `n_atoms`).
    pub assignments: Array2<f64>,
    /// Reconstruction of the input, maintained as `assignments` times `atoms`.
    pub fitted: Array2<f64>,
    /// Per-atom ridge value recorded at the atom's last update; `INACTIVE_LAMBDA`
    /// for atoms that were never fitted or were zeroed out.
    pub lambdas: Array1<f64>,
    /// Per-atom penalized reconstruction loss recorded at the atom's last update.
    /// NOTE(review): the name suggests a REML criterion, but the visible code stores
    /// `penalized_reconstruction_loss` here — confirm the intended meaning.
    pub reml_scores: Array1<f64>,
    /// `1 - RSS / TSS` of `fitted` against the input, with TSS taken about column means.
    pub explained_variance: f64,
    /// Number of sweeps that were completed.
    pub iterations: usize,
    /// Whether the explained-variance stopping rule fired before `max_iter` ran out.
    pub converged: bool,
    /// Assignment rule copied from the configuration.
    pub assignment: LinearDictionaryAssignment,
    /// Number of atoms that were selected per row during fitting.
    pub top_k: usize,
}
85
86pub fn fit_linear_dictionary(
87 x: ArrayView2<'_, f64>,
88 config: &LinearDictionaryConfig,
89) -> Result<LinearDictionaryFit, String> {
90 validate_inputs(x, config)?;
91 if config.n_atoms == 1 {
92 return fit_rank_one_pca_lane(x, config);
93 }
94
95 let top_k = config.top_k.min(config.n_atoms).max(1);
96 let mut atoms = initialize_atoms(x, config.n_atoms);
97 let mut assignments = Array2::<f64>::zeros((x.nrows(), config.n_atoms));
98 let mut fitted = Array2::<f64>::zeros(x.dim());
99 let mut lambdas = Array1::<f64>::from_elem(config.n_atoms, INACTIVE_LAMBDA);
100 let mut reml_scores = Array1::<f64>::zeros(config.n_atoms);
101 let mut previous_ev = f64::NEG_INFINITY;
102 let mut converged = false;
103 let mut completed_iterations = 0usize;
104
105 for iter in 0..config.max_iter {
106 assignments = match config.assignment {
107 LinearDictionaryAssignment::TopK => {
108 top_k_assignments(x, atoms.view(), top_k, config.code_ridge)?
109 }
110 LinearDictionaryAssignment::Softmax => softmax_assignments(
111 x,
112 atoms.view(),
113 top_k,
114 config.temperature,
115 config.code_ridge,
116 )?,
117 };
118
119 fitted = assignments.dot(&atoms);
120 let mut any_reseeded = false;
121 for atom_idx in 0..config.n_atoms {
122 any_reseeded |= fit_one_atom_penalized_ls(
123 x,
124 &mut atoms,
125 &mut assignments,
126 &mut fitted,
127 &mut lambdas,
128 &mut reml_scores,
129 atom_idx,
130 config.code_ridge,
131 )?;
132 }
133
134 completed_iterations = iter + 1;
135 let ev = explained_variance(x, fitted.view());
136 if !any_reseeded && (ev - previous_ev).abs() <= config.tolerance.max(0.0) {
140 converged = true;
141 break;
142 }
143 previous_ev = ev;
144 }
145
146 let final_ev = explained_variance(x, fitted.view());
147 Ok(LinearDictionaryFit {
148 atoms,
149 assignments,
150 fitted,
151 lambdas,
152 reml_scores,
153 explained_variance: final_ev,
154 iterations: completed_iterations,
155 converged,
156 assignment: config.assignment,
157 top_k,
158 })
159}
160
161fn validate_inputs(x: ArrayView2<'_, f64>, config: &LinearDictionaryConfig) -> Result<(), String> {
162 if x.nrows() == 0 || x.ncols() == 0 {
163 return Err("linear_dictionary_fit requires a non-empty 2-D matrix".to_string());
164 }
165 if !x.iter().all(|value| value.is_finite()) {
166 return Err("linear_dictionary_fit input must be finite".to_string());
167 }
168 if config.n_atoms == 0 {
169 return Err("linear_dictionary_fit requires K >= 1".to_string());
170 }
171 if config.max_iter == 0 {
172 return Err("linear_dictionary_fit requires max_iter >= 1".to_string());
173 }
174 if config.top_k == 0 || config.top_k > config.n_atoms {
175 return Err(format!(
176 "linear_dictionary_fit top_k must be in [1, K={}]; got {}",
177 config.n_atoms, config.top_k
178 ));
179 }
180 if !(config.temperature.is_finite() && config.temperature > 0.0) {
181 return Err(format!(
182 "linear_dictionary_fit temperature must be finite and positive; got {}",
183 config.temperature
184 ));
185 }
186 if !(config.code_ridge.is_finite() && config.code_ridge > 0.0) {
187 return Err(format!(
188 "linear_dictionary_fit code_ridge must be finite and positive; got {}",
189 config.code_ridge
190 ));
191 }
192 if !config.tolerance.is_finite() {
193 return Err("linear_dictionary_fit tolerance must be finite".to_string());
194 }
195 Ok(())
196}
197
198fn fit_rank_one_pca_lane(
199 x: ArrayView2<'_, f64>,
200 config: &LinearDictionaryConfig,
201) -> Result<LinearDictionaryFit, String> {
202 let covariance = x.t().dot(&x);
203 let (evals, evecs) = covariance
204 .eigh(Side::Lower)
205 .map_err(|err| format!("linear_dictionary_fit PCA eigensolve failed: {err}"))?;
206 let last = evals.len() - 1;
207 let mut atom = evecs.column(last).to_owned();
208 orient_vector(&mut atom);
209 let mut assignments = Array2::<f64>::zeros((x.nrows(), 1));
210 for row in 0..x.nrows() {
211 assignments[[row, 0]] = x.row(row).dot(&atom) / (1.0 + config.code_ridge);
212 }
213 let mut atoms = atom.insert_axis(Axis(0)).to_owned();
214 normalize_atom_and_assignments(&mut atoms, &mut assignments, 0);
215 let fitted = assignments.dot(&atoms);
216 let score = penalized_reconstruction_loss(x, fitted.view(), config.code_ridge, atoms.view());
217 Ok(LinearDictionaryFit {
218 atoms,
219 assignments,
220 fitted: fitted.clone(),
221 lambdas: Array1::from_elem(1, config.code_ridge),
222 reml_scores: Array1::from_elem(1, score),
223 explained_variance: explained_variance(x, fitted.view()),
224 iterations: 1.min(config.max_iter),
225 converged: true,
226 assignment: config.assignment,
227 top_k: 1,
228 })
229}
230
231fn initialize_atoms(x: ArrayView2<'_, f64>, n_atoms: usize) -> Array2<f64> {
232 let mut atoms = Array2::<f64>::zeros((n_atoms, x.ncols()));
233 let first = max_norm_row(x);
234 atoms.row_mut(0).assign(&x.row(first));
235 normalize_row(atoms.slice_mut(s![0, ..]));
236 let mut min_dist2 = Array1::<f64>::from_elem(x.nrows(), f64::INFINITY);
237
238 for atom_idx in 1..n_atoms {
239 let prev = atoms.row(atom_idx - 1);
240 for row in 0..x.nrows() {
241 let dist2 = squared_distance(x.row(row), prev);
242 if dist2 < min_dist2[row] {
243 min_dist2[row] = dist2;
244 }
245 }
246 let chosen = if atom_idx < x.nrows() {
247 max_index(min_dist2.view())
248 } else {
249 atom_idx % x.nrows()
250 };
251 atoms.row_mut(atom_idx).assign(&x.row(chosen));
252 normalize_row(atoms.slice_mut(s![atom_idx, ..]));
253 }
254 atoms
255}
256
257fn fit_one_atom_penalized_ls(
258 x: ArrayView2<'_, f64>,
259 atoms: &mut Array2<f64>,
260 assignments: &mut Array2<f64>,
261 fitted: &mut Array2<f64>,
262 lambdas: &mut Array1<f64>,
263 reml_scores: &mut Array1<f64>,
264 atom_idx: usize,
265 atom_ridge: f64,
266) -> Result<bool, String> {
267 let code = assignments.column(atom_idx).to_owned();
268 let code_norm2 = code.dot(&code);
269 if code_norm2 <= MIN_NORM2 {
270 let mut worst_row = 0usize;
283 let mut worst_res2 = -1.0_f64;
284 for row in 0..x.nrows() {
285 let mut res2 = 0.0_f64;
286 for col in 0..x.ncols() {
287 let d = x[[row, col]] - fitted[[row, col]];
288 res2 += d * d;
289 }
290 if res2 > worst_res2 {
291 worst_res2 = res2;
292 worst_row = row;
293 }
294 }
295 if worst_res2 <= MIN_NORM2 {
296 atoms.row_mut(atom_idx).fill(0.0);
300 lambdas[atom_idx] = INACTIVE_LAMBDA;
301 reml_scores[atom_idx] = 0.0;
302 return Ok(false);
303 }
304 for col in 0..x.ncols() {
305 atoms[[atom_idx, col]] = x[[worst_row, col]] - fitted[[worst_row, col]];
306 }
307 normalize_row(atoms.slice_mut(s![atom_idx, ..]));
308 lambdas[atom_idx] = atom_ridge;
309 reml_scores[atom_idx] =
310 penalized_reconstruction_loss(x, fitted.view(), atom_ridge, atoms.view());
311 return Ok(true);
312 }
313
314 let old_atom = atoms.row(atom_idx).to_owned();
315 let mut residual = x.to_owned() - fitted.view();
316 residual += &code
317 .view()
318 .insert_axis(Axis(1))
319 .dot(&old_atom.view().insert_axis(Axis(0)));
320
321 let denominator = code_norm2 + atom_ridge;
322 for col in 0..x.ncols() {
323 atoms[[atom_idx, col]] = code.dot(&residual.column(col)) / denominator;
324 }
325 lambdas[atom_idx] = atom_ridge;
326 normalize_atom_and_assignments(atoms, assignments, atom_idx);
327 let updated_code = assignments.column(atom_idx).to_owned();
328 fitted.assign(&x);
329 *fitted -= &residual;
330 *fitted += &updated_code
331 .view()
332 .insert_axis(Axis(1))
333 .dot(&atoms.row(atom_idx).insert_axis(Axis(0)));
334 reml_scores[atom_idx] =
335 penalized_reconstruction_loss(x, fitted.view(), atom_ridge, atoms.view());
336 Ok(false)
337}
338
339fn top_k_assignments(
340 x: ArrayView2<'_, f64>,
341 atoms: ArrayView2<'_, f64>,
342 top_k: usize,
343 code_ridge: f64,
344) -> Result<Array2<f64>, String> {
345 let cross = x.dot(&atoms.t());
346 let mut assignments = Array2::<f64>::zeros((x.nrows(), atoms.nrows()));
347 for row in 0..x.nrows() {
348 let active = top_indices_by_abs(cross.row(row), top_k);
349 let coeffs = solve_active_coefficients(atoms, cross.row(row), &active, code_ridge)?;
350 for pos in 0..active.len() {
351 assignments[[row, active[pos]]] = coeffs[pos];
352 }
353 }
354 Ok(assignments)
355}
356
357fn softmax_assignments(
358 x: ArrayView2<'_, f64>,
359 atoms: ArrayView2<'_, f64>,
360 top_k: usize,
361 temperature: f64,
362 code_ridge: f64,
363) -> Result<Array2<f64>, String> {
364 let cross = x.dot(&atoms.t());
365 let atom_norm2 = atoms.map_axis(Axis(1), |row| row.dot(&row).max(MIN_NORM2));
366 let mut assignments = Array2::<f64>::zeros((x.nrows(), atoms.nrows()));
367 for row in 0..x.nrows() {
368 let active = top_indices_by_abs(cross.row(row), top_k);
369 let mut max_score = f64::NEG_INFINITY;
370 for &atom_idx in &active {
371 let score = cross[[row, atom_idx]].abs() / (atom_norm2[atom_idx].sqrt() * temperature);
372 if score > max_score {
373 max_score = score;
374 }
375 }
376 let mut denom = 0.0;
377 for &atom_idx in &active {
378 let score = cross[[row, atom_idx]].abs() / (atom_norm2[atom_idx].sqrt() * temperature);
379 let mass = (score - max_score).exp();
380 assignments[[row, atom_idx]] = mass;
381 denom += mass;
382 }
383 if denom <= 0.0 || !denom.is_finite() {
384 return Err("linear_dictionary_fit softmax assignment underflowed".to_string());
385 }
386 for &atom_idx in &active {
387 let projection = cross[[row, atom_idx]] / (atom_norm2[atom_idx] + code_ridge);
388 assignments[[row, atom_idx]] = assignments[[row, atom_idx]] * projection / denom;
389 }
390 }
391 Ok(assignments)
392}
393
394fn solve_active_coefficients(
395 atoms: ArrayView2<'_, f64>,
396 cross_row: ArrayView1<'_, f64>,
397 active: &[usize],
398 code_ridge: f64,
399) -> Result<Array1<f64>, String> {
400 let m = active.len();
401 let mut system = Array2::<f64>::zeros((m, m));
402 let mut rhs = Array2::<f64>::zeros((m, 1));
403 for i in 0..m {
404 rhs[[i, 0]] = cross_row[active[i]];
405 for j in 0..m {
406 system[[i, j]] = atoms.row(active[i]).dot(&atoms.row(active[j]));
407 }
408 system[[i, i]] += code_ridge;
409 }
410 let factor = system
411 .cholesky(Side::Lower)
412 .map_err(|err| format!("linear_dictionary_fit sparse-code solve failed: {err}"))?;
413 let mut solution = rhs;
414 factor.solve_mat_in_place(&mut solution);
415 Ok(solution.column(0).to_owned())
416}
417
418fn top_indices_by_abs(row: ArrayView1<'_, f64>, top_k: usize) -> Vec<usize> {
419 let mut selected: Vec<(usize, f64)> = Vec::with_capacity(top_k);
420 for idx in 0..row.len() {
421 let score = row[idx].abs();
422 if selected.len() < top_k {
423 selected.push((idx, score));
424 continue;
425 }
426 let mut worst_pos = 0usize;
427 for pos in 1..selected.len() {
428 if selected[pos].1 < selected[worst_pos].1
429 || (selected[pos].1 == selected[worst_pos].1
430 && selected[pos].0 > selected[worst_pos].0)
431 {
432 worst_pos = pos;
433 }
434 }
435 let worst = selected[worst_pos];
436 if score > worst.1 || (score == worst.1 && idx < worst.0) {
437 selected[worst_pos] = (idx, score);
438 }
439 }
440 selected.sort_by(|a, b| {
441 b.1.partial_cmp(&a.1)
442 .unwrap_or(std::cmp::Ordering::Equal)
443 .then_with(|| a.0.cmp(&b.0))
444 });
445 selected.into_iter().map(|(idx, _)| idx).collect()
446}
447
448fn normalize_atom_and_assignments(
449 atoms: &mut Array2<f64>,
450 assignments: &mut Array2<f64>,
451 atom_idx: usize,
452) {
453 let norm = atoms.row(atom_idx).dot(&atoms.row(atom_idx)).sqrt();
454 if norm > MIN_NORM2.sqrt() {
455 atoms.row_mut(atom_idx).mapv_inplace(|value| value / norm);
456 assignments
457 .column_mut(atom_idx)
458 .mapv_inplace(|value| value * norm);
459 }
460 orient_atom_and_code(atoms, assignments, atom_idx);
461}
462
463fn orient_atom_and_code(atoms: &mut Array2<f64>, assignments: &mut Array2<f64>, atom_idx: usize) {
464 let sign = first_nonzero_sign(atoms.row(atom_idx));
465 if sign < 0.0 {
466 atoms.row_mut(atom_idx).mapv_inplace(|value| -value);
467 assignments
468 .column_mut(atom_idx)
469 .mapv_inplace(|value| -value);
470 }
471}
472
473fn orient_vector(vector: &mut Array1<f64>) {
474 if first_nonzero_sign(vector.view()) < 0.0 {
475 vector.mapv_inplace(|value| -value);
476 }
477}
478
479fn first_nonzero_sign(row: ndarray::ArrayView1<'_, f64>) -> f64 {
480 for &value in row {
481 if value.abs() > 1.0e-12 {
482 return value.signum();
483 }
484 }
485 1.0
486}
487
488fn normalize_row(mut row: ndarray::ArrayViewMut1<'_, f64>) {
489 let norm = row.dot(&row).sqrt();
490 if norm > MIN_NORM2.sqrt() {
491 row.mapv_inplace(|value| value / norm);
492 }
493}
494
495fn max_norm_row(x: ArrayView2<'_, f64>) -> usize {
496 let mut best = 0usize;
497 let mut best_norm = f64::NEG_INFINITY;
498 for row in 0..x.nrows() {
499 let norm = x.row(row).dot(&x.row(row));
500 if norm > best_norm {
501 best = row;
502 best_norm = norm;
503 }
504 }
505 best
506}
507
508fn max_index(values: ndarray::ArrayView1<'_, f64>) -> usize {
509 let mut best = 0usize;
510 let mut best_value = f64::NEG_INFINITY;
511 for idx in 0..values.len() {
512 if values[idx] > best_value {
513 best = idx;
514 best_value = values[idx];
515 }
516 }
517 best
518}
519
520fn squared_distance(a: ndarray::ArrayView1<'_, f64>, b: ndarray::ArrayView1<'_, f64>) -> f64 {
521 a.iter()
522 .zip(b.iter())
523 .map(|(av, bv)| {
524 let diff = av - bv;
525 diff * diff
526 })
527 .sum()
528}
529
530fn explained_variance(x: ArrayView2<'_, f64>, fitted: ArrayView2<'_, f64>) -> f64 {
531 let mut rss = 0.0;
532 for row in 0..x.nrows() {
533 for col in 0..x.ncols() {
534 let residual = x[[row, col]] - fitted[[row, col]];
535 rss += residual * residual;
536 }
537 }
538 let means = x.mean_axis(Axis(0)).expect("non-empty input has means");
539 let mut tss = 0.0;
540 for row in 0..x.nrows() {
541 for col in 0..x.ncols() {
542 let centered = x[[row, col]] - means[col];
543 tss += centered * centered;
544 }
545 }
546 if tss <= MIN_NORM2 {
547 if rss <= MIN_NORM2 { 1.0 } else { 0.0 }
548 } else {
549 1.0 - rss / tss
550 }
551}
552
553fn penalized_reconstruction_loss(
554 x: ArrayView2<'_, f64>,
555 fitted: ArrayView2<'_, f64>,
556 ridge: f64,
557 atoms: ArrayView2<'_, f64>,
558) -> f64 {
559 let mut loss = 0.0;
560 for row in 0..x.nrows() {
561 for col in 0..x.ncols() {
562 let residual = x[[row, col]] - fitted[[row, col]];
563 loss += residual * residual;
564 }
565 }
566 loss + ridge * atoms.iter().map(|value| value * value).sum::<f64>()
567}
568
// Unit tests for the linear dictionary fitter, all on small deterministic data.
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::{Array2, array};

    /// Data built from four axis-aligned atoms with two active codes per row
    /// should be reconstructed with explained variance above 0.95.
    #[test]
    fn planted_sparse_linear_dictionary_reaches_high_explained_variance() {
        let truth = array![
            [1.0, 0.0, 0.0, 0.0],
            [0.0, 1.0, 0.0, 0.0],
            [0.0, 0.0, 1.0, 0.0],
            [0.0, 0.0, 0.0, 1.0],
        ];
        // Each row mixes a dominant atom with a fixed 0.2 share of the next atom.
        let mut assignments = Array2::<f64>::zeros((160, 4));
        for row in 0..160 {
            let atom = row % 4;
            assignments[[row, atom]] = 0.7 + 0.01 * ((row / 4) as f64);
            assignments[[row, (atom + 1) % 4]] = 0.2;
        }
        let x = assignments.dot(&truth);
        let config = LinearDictionaryConfig {
            n_atoms: 4,
            max_iter: 40,
            top_k: 2,
            assignment: LinearDictionaryAssignment::TopK,
            temperature: DEFAULT_TEMPERATURE,
            code_ridge: DEFAULT_CODE_RIDGE,
            tolerance: 1.0e-9,
        };

        let fit = fit_linear_dictionary(x.view(), &config).expect("linear dictionary fit");

        assert!(
            fit.explained_variance > 0.95,
            "expected EV > 0.95, got {}",
            fit.explained_variance
        );
    }

    /// With a single atom, the reported explained variance should match an
    /// oracle computed from the eigenvalues of `xᵀx` with ridge shrinkage applied.
    #[test]
    fn single_atom_matches_penalized_pca_oracle() {
        // Nearly rank-one data: two collinear columns plus a small sinusoid.
        let mut x = Array2::<f64>::zeros((80, 3));
        for row in 0..80 {
            let t = (row as f64 - 39.5) / 20.0;
            x[[row, 0]] = 2.0 * t;
            x[[row, 1]] = -t;
            x[[row, 2]] = 0.05 * (row as f64).sin();
        }
        let config = LinearDictionaryConfig {
            n_atoms: 1,
            max_iter: 5,
            top_k: 1,
            assignment: LinearDictionaryAssignment::TopK,
            temperature: DEFAULT_TEMPERATURE,
            code_ridge: DEFAULT_CODE_RIDGE,
            tolerance: DEFAULT_TOLERANCE,
        };

        let fit = fit_linear_dictionary(x.view(), &config).expect("rank-one fit");
        let covariance = x.t().dot(&x);
        let (evals, _) = covariance.eigh(Side::Lower).expect("PCA eigensolve");
        // Oracle residual: shrinkage loss on the last eigenvalue plus all the others.
        let shrink = 1.0 / (1.0 + DEFAULT_CODE_RIDGE);
        let oracle_ev = 1.0
            - ((1.0 - shrink) * (1.0 - shrink) * evals[evals.len() - 1]
                + evals.slice(s![..evals.len() - 1]).sum())
                / evals.sum();

        assert!(fit.explained_variance > 0.99);
        assert_abs_diff_eq!(fit.explained_variance, oracle_ev, epsilon = 2.0e-4);
    }

    /// Rows drawn from `k` orthonormal directions with top-1 assignment: every
    /// atom must end up non-zero and explained variance must exceed 0.99.
    #[test]
    fn orthonormal_rank_one_atoms_all_revived_no_dead_collapse_1500() {
        let (k, p, n) = (4usize, 8usize, 400usize);
        // Deterministic matrix, symmetrized so its eigenvectors give orthonormal directions.
        let mut a = Array2::<f64>::zeros((p, p));
        for i in 0..p {
            for j in 0..p {
                a[[i, j]] = ((i * 7 + j * 3 + 1) % 11) as f64 - 5.0;
            }
        }
        let sym = &a + &a.t();
        let (_evals, evecs) = sym.eigh(Side::Lower).expect("orthonormal directions");
        let dirs = evecs.slice(s![.., ..k]).t().to_owned();
        let mut x = Array2::<f64>::zeros((n, p));
        for row in 0..n {
            let atom = row % k;
            let scale = if row % 2 == 0 { 2.0 } else { -1.5 } + 0.01 * (row / k) as f64;
            for col in 0..p {
                // Small deterministic perturbation on top of the planted direction.
                let noise = 1.0e-3 * (((row * p + col) % 13) as f64 - 6.0);
                x[[row, col]] = scale * dirs[[atom, col]] + noise;
            }
        }
        let config = LinearDictionaryConfig {
            n_atoms: k,
            max_iter: 40,
            top_k: 1,
            assignment: LinearDictionaryAssignment::TopK,
            temperature: DEFAULT_TEMPERATURE,
            code_ridge: DEFAULT_CODE_RIDGE,
            tolerance: 1.0e-9,
        };
        let fit = fit_linear_dictionary(x.view(), &config).expect("orthonormal dictionary fit");
        // An atom counts as live when any of its entries is non-negligible.
        let live = fit
            .atoms
            .axis_iter(Axis(0))
            .filter(|atom| atom.iter().any(|value| value.abs() > 1.0e-12))
            .count();
        assert_eq!(
            live, k,
            "all {k} atoms must stay live (no dead-atom collapse); got {live} live"
        );
        assert!(
            fit.explained_variance > 0.99,
            "K orthonormal rank-1 atoms must be reconstructed at EV > 0.99; got {}",
            fit.explained_variance
        );
    }

    /// A 1024-atom dictionary fitted to data spanning eight axis-aligned
    /// directions must keep at most one active code per row and still reach
    /// explained variance above 0.95.
    #[test]
    fn sparse_assignment_scales_to_thousand_atom_dictionary() {
        let active_atoms = array![
            [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
            [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
            [0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0],
            [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0],
            [0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0],
            [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0],
            [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0],
            [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0],
        ];
        // Each row is one axis direction scaled by a slowly growing factor.
        let mut x = Array2::<f64>::zeros((256, 8));
        for row in 0..x.nrows() {
            let atom = row % active_atoms.nrows();
            let scale = 0.7 + 0.003 * row as f64;
            x.row_mut(row).assign(&(&active_atoms.row(atom) * scale));
        }
        let config = LinearDictionaryConfig {
            n_atoms: 1024,
            max_iter: 8,
            top_k: 1,
            assignment: LinearDictionaryAssignment::TopK,
            temperature: DEFAULT_TEMPERATURE,
            code_ridge: DEFAULT_CODE_RIDGE,
            tolerance: 1.0e-9,
        };

        let fit = fit_linear_dictionary(x.view(), &config).expect("large-K linear dictionary fit");
        // Largest number of non-negligible codes found in any single row.
        let max_active = fit
            .assignments
            .axis_iter(Axis(0))
            .map(|row| row.iter().filter(|value| value.abs() > 1.0e-10).count())
            .max()
            .unwrap();

        assert_eq!(max_active, 1);
        assert!(
            fit.explained_variance > 0.95,
            "expected EV > 0.95 at K=1024, got {}",
            fit.explained_variance
        );
    }
}