1use faer::Side;
2use gam_linalg::faer_ndarray::{FaerCholesky, FaerEigh};
3use ndarray::{Array1, Array2, ArrayView1, ArrayView2, Axis, s};
4
/// Default for `LinearDictionaryConfig::max_iter` (cap on outer fitting iterations).
const DEFAULT_MAX_ITER: usize = 30;
/// Default for `LinearDictionaryConfig::top_k` (atoms allowed to be active per row).
const DEFAULT_TOP_K: usize = 1;
/// Default for `LinearDictionaryConfig::temperature` (softmax temperature).
const DEFAULT_TEMPERATURE: f64 = 0.25;
/// Default for `LinearDictionaryConfig::code_ridge` (ridge added to code and atom solves).
const DEFAULT_CODE_RIDGE: f64 = 1.0e-8;
/// Default for `LinearDictionaryConfig::tolerance` (change in explained variance that
/// counts as converged).
const DEFAULT_TOLERANCE: f64 = 1.0e-7;
/// Sentinel written into `lambdas` for atoms that have not been updated or were zeroed out.
const INACTIVE_LAMBDA: f64 = 1.0e30;
/// Squared-norm / sum-of-squares threshold below which a quantity is treated as zero.
const MIN_NORM2: f64 = 1.0e-24;
12
/// Rule used to turn row/atom correlations into per-row codes.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum LinearDictionaryAssignment {
    /// Canonical name `"top_k"`.
    TopK,
    /// Canonical name `"softmax"`.
    Softmax,
}

impl LinearDictionaryAssignment {
    /// Parses an assignment name, ignoring surrounding whitespace and ASCII case.
    ///
    /// Accepts `"top_k"`, `"topk"` and `"hard"` for [`Self::TopK`], and `"softmax"` and
    /// `"soft"` for [`Self::Softmax`].
    ///
    /// # Errors
    ///
    /// Returns a message quoting the trimmed, lower-cased input when it matches none of
    /// the accepted names.
    pub fn parse(value: &str) -> Result<Self, String> {
        let normalized = value.trim().to_ascii_lowercase();
        let other = normalized.as_str();
        if matches!(other, "top_k" | "topk" | "hard") {
            Ok(Self::TopK)
        } else if matches!(other, "softmax" | "soft") {
            Ok(Self::Softmax)
        } else {
            Err(format!(
                "linear dictionary assignment must be 'top_k' or 'softmax'; got {other:?}"
            ))
        }
    }

    /// Returns the canonical name of the variant; each name is accepted by [`Self::parse`].
    pub const fn as_str(self) -> &'static str {
        match self {
            Self::Softmax => "softmax",
            Self::TopK => "top_k",
        }
    }
}
37
/// Settings for `fit_linear_dictionary`; `validate_inputs` enforces the stated ranges.
#[derive(Clone, Debug)]
pub struct LinearDictionaryConfig {
    /// Number of atoms K (must be >= 1). K == 1 is routed to the rank-one PCA lane.
    pub n_atoms: usize,
    /// Maximum number of outer iterations of the multi-atom fit (must be >= 1).
    pub max_iter: usize,
    /// Number of atoms that may be active per row (must lie in `[1, n_atoms]`).
    pub top_k: usize,
    /// Assignment rule used when routing rows to atoms.
    pub assignment: LinearDictionaryAssignment,
    /// Softmax temperature (finite and positive); only passed to the softmax assignment.
    pub temperature: f64,
    /// Ridge term used in the code solves and atom updates (finite and positive).
    pub code_ridge: f64,
    /// Convergence threshold on the absolute change in explained variance between
    /// iterations (must be finite; negative values are clamped to zero when compared).
    pub tolerance: f64,
    /// When `true` and `n_atoms == 1`, the rank-one fit uses the mean-centered lane.
    pub center_rank_one: bool,
}
59
impl LinearDictionaryConfig {
    /// Builds a configuration with the given atom count and every other field taken
    /// from `Default::default()`.
    pub fn new(n_atoms: usize) -> Self {
        Self {
            n_atoms,
            ..Self::default()
        }
    }
}
68
impl Default for LinearDictionaryConfig {
    /// One atom, top-k assignment, uncentered rank-one lane, and the `DEFAULT_*`
    /// constants for the numeric settings.
    fn default() -> Self {
        Self {
            n_atoms: 1,
            max_iter: DEFAULT_MAX_ITER,
            top_k: DEFAULT_TOP_K,
            assignment: LinearDictionaryAssignment::TopK,
            temperature: DEFAULT_TEMPERATURE,
            code_ridge: DEFAULT_CODE_RIDGE,
            tolerance: DEFAULT_TOLERANCE,
            center_rank_one: false,
        }
    }
}
83
/// Output of `fit_linear_dictionary`.
#[derive(Clone, Debug)]
pub struct LinearDictionaryFit {
    /// Dictionary atoms, one per row (`n_atoms x n_columns_of_x`).
    pub atoms: Array2<f64>,
    /// Per-row codes, one column per atom (`n_rows_of_x x n_atoms`).
    pub assignments: Array2<f64>,
    /// Reconstruction of the input, same shape as the input.
    pub fitted: Array2<f64>,
    /// Per-atom ridge value; `INACTIVE_LAMBDA` marks atoms that were never updated or
    /// were zeroed out.
    pub lambdas: Array1<f64>,
    /// Per-atom value of the penalized reconstruction loss recorded at that atom's most
    /// recent update (0 for zeroed-out atoms).
    pub reml_scores: Array1<f64>,
    /// Explained variance of `fitted` relative to the input.
    pub explained_variance: f64,
    /// Number of outer iterations that were run.
    pub iterations: usize,
    /// Whether the fit stopped because the convergence test passed (always `true` for
    /// the rank-one lanes).
    pub converged: bool,
    /// Assignment rule copied from the configuration.
    pub assignment: LinearDictionaryAssignment,
    /// Effective number of active atoms per row used by the fit.
    pub top_k: usize,
}
97
98pub fn fit_linear_dictionary(
115 x: ArrayView2<'_, f64>,
116 config: &LinearDictionaryConfig,
117) -> Result<LinearDictionaryFit, String> {
118 validate_inputs(x, config)?;
119 if config.n_atoms == 1 {
120 return fit_rank_one_pca_lane(x, config);
121 }
122 Ok(fit_multi_atom_dictionary(x, config)?.fit)
123}
124
/// Result of `fit_multi_atom_dictionary`: the public fit plus one diagnostic value.
struct MultiAtomDictionaryFit {
    /// The fit handed back to callers of `fit_linear_dictionary`.
    fit: LinearDictionaryFit,
    /// Explained variance of the loop's final state, measured before the closing reroute.
    /// The attribute silences the `dead_code` lint in non-test builds.
    #[cfg_attr(not(test), allow(dead_code))]
    pre_reroute_ev: f64,
}
137
/// Alternating fit for dictionaries with more than one atom.
///
/// Each outer iteration re-routes every row against the current atoms, rebuilds the
/// reconstruction, and then updates the atoms one at a time. The loop stops early when no
/// atom was reseeded in the iteration and the explained variance moved by at most the
/// (non-negative) tolerance. After the loop, one more reroute is evaluated and kept only
/// if it does not lower the explained variance.
///
/// # Errors
///
/// Propagates errors from the routing step and from the per-atom update.
fn fit_multi_atom_dictionary(
    x: ArrayView2<'_, f64>,
    config: &LinearDictionaryConfig,
) -> Result<MultiAtomDictionaryFit, String> {
    // Clamp to [1, n_atoms] so the routing helpers always select at least one atom.
    let top_k = config.top_k.min(config.n_atoms).max(1);
    let mut atoms = initialize_atoms(x, config.n_atoms);
    let mut assignments = Array2::<f64>::zeros((x.nrows(), config.n_atoms));
    let mut fitted = Array2::<f64>::zeros(x.dim());
    // Every atom starts out marked inactive until its first update.
    let mut lambdas = Array1::<f64>::from_elem(config.n_atoms, INACTIVE_LAMBDA);
    let mut reml_scores = Array1::<f64>::zeros(config.n_atoms);
    let mut previous_ev = f64::NEG_INFINITY;
    let mut converged = false;
    let mut completed_iterations = 0usize;

    for iter in 0..config.max_iter {
        // Routing step: recompute the codes against the current atoms.
        assignments = reroute_against_atoms(x, atoms.view(), top_k, config)?;

        fitted = assignments.dot(&atoms);
        // Atom step: update atoms one by one; each call keeps `fitted` in sync.
        let mut any_reseeded = false;
        for atom_idx in 0..config.n_atoms {
            any_reseeded |= fit_one_atom_penalized_ls(
                x,
                &mut atoms,
                &mut assignments,
                &mut fitted,
                &mut lambdas,
                &mut reml_scores,
                atom_idx,
                config.code_ridge,
            )?;
        }

        completed_iterations = iter + 1;
        let ev = explained_variance(x, fitted.view());
        // A reseeded atom has not been routed to yet, so never declare convergence in
        // an iteration that reseeded one.
        if !any_reseeded && (ev - previous_ev).abs() <= config.tolerance.max(0.0) {
            converged = true;
            break;
        }
        previous_ev = ev;
    }

    // Final reroute against the last atoms; fall back to the loop state if it is worse.
    let pre_reroute_ev = explained_variance(x, fitted.view());
    let rerouted = reroute_against_atoms(x, atoms.view(), top_k, config)?;
    let rerouted_fitted = rerouted.dot(&atoms);
    let rerouted_ev = explained_variance(x, rerouted_fitted.view());
    let (assignments, fitted, final_ev) = if rerouted_ev >= pre_reroute_ev {
        (rerouted, rerouted_fitted, rerouted_ev)
    } else {
        (assignments, fitted, pre_reroute_ev)
    };

    Ok(MultiAtomDictionaryFit {
        fit: LinearDictionaryFit {
            atoms,
            assignments,
            fitted,
            lambdas,
            reml_scores,
            explained_variance: final_ev,
            iterations: completed_iterations,
            converged,
            assignment: config.assignment,
            top_k,
        },
        pre_reroute_ev,
    })
}
213
214fn reroute_against_atoms(
220 x: ArrayView2<'_, f64>,
221 atoms: ArrayView2<'_, f64>,
222 top_k: usize,
223 config: &LinearDictionaryConfig,
224) -> Result<Array2<f64>, String> {
225 match config.assignment {
226 LinearDictionaryAssignment::TopK => top_k_assignments(x, atoms, top_k, config.code_ridge),
227 LinearDictionaryAssignment::Softmax => {
228 softmax_assignments(x, atoms, top_k, config.temperature, config.code_ridge)
229 }
230 }
231}
232
233fn validate_inputs(x: ArrayView2<'_, f64>, config: &LinearDictionaryConfig) -> Result<(), String> {
234 if x.nrows() == 0 || x.ncols() == 0 {
235 return Err("linear_dictionary_fit requires a non-empty 2-D matrix".to_string());
236 }
237 if !x.iter().all(|value| value.is_finite()) {
238 return Err("linear_dictionary_fit input must be finite".to_string());
239 }
240 if config.n_atoms == 0 {
241 return Err("linear_dictionary_fit requires K >= 1".to_string());
242 }
243 if config.max_iter == 0 {
244 return Err("linear_dictionary_fit requires max_iter >= 1".to_string());
245 }
246 if config.top_k == 0 || config.top_k > config.n_atoms {
247 return Err(format!(
248 "linear_dictionary_fit top_k must be in [1, K={}]; got {}",
249 config.n_atoms, config.top_k
250 ));
251 }
252 if !(config.temperature.is_finite() && config.temperature > 0.0) {
253 return Err(format!(
254 "linear_dictionary_fit temperature must be finite and positive; got {}",
255 config.temperature
256 ));
257 }
258 if !(config.code_ridge.is_finite() && config.code_ridge > 0.0) {
259 return Err(format!(
260 "linear_dictionary_fit code_ridge must be finite and positive; got {}",
261 config.code_ridge
262 ));
263 }
264 if !config.tolerance.is_finite() {
265 return Err("linear_dictionary_fit tolerance must be finite".to_string());
266 }
267 Ok(())
268}
269
270fn fit_rank_one_pca_lane(
286 x: ArrayView2<'_, f64>,
287 config: &LinearDictionaryConfig,
288) -> Result<LinearDictionaryFit, String> {
289 if config.center_rank_one {
290 return fit_rank_one_centered_lane(x, config);
291 }
292 let covariance = x.t().dot(&x);
293 let (evals, evecs) = covariance
294 .eigh(Side::Lower)
295 .map_err(|err| format!("linear_dictionary_fit PCA eigensolve failed: {err}"))?;
296 let last = evals.len() - 1;
297 let mut atom = evecs.column(last).to_owned();
298 orient_vector(&mut atom);
299 let mut assignments = Array2::<f64>::zeros((x.nrows(), 1));
300 for row in 0..x.nrows() {
301 assignments[[row, 0]] = x.row(row).dot(&atom) / (1.0 + config.code_ridge);
302 }
303 let mut atoms = atom.insert_axis(Axis(0)).to_owned();
304 normalize_atom_and_assignments(&mut atoms, &mut assignments, 0);
305 let fitted = assignments.dot(&atoms);
306 let score = penalized_reconstruction_loss(x, fitted.view(), config.code_ridge, atoms.view());
307 Ok(LinearDictionaryFit {
308 atoms,
309 assignments,
310 fitted: fitted.clone(),
311 lambdas: Array1::from_elem(1, config.code_ridge),
312 reml_scores: Array1::from_elem(1, score),
313 explained_variance: explained_variance(x, fitted.view()),
314 iterations: 1.min(config.max_iter),
315 converged: true,
316 assignment: config.assignment,
317 top_k: 1,
318 })
319}
320
321fn fit_rank_one_centered_lane(
329 x: ArrayView2<'_, f64>,
330 config: &LinearDictionaryConfig,
331) -> Result<LinearDictionaryFit, String> {
332 let CenteredRankOne {
333 atom,
334 codes,
335 fitted,
336 explained_variance: ev,
337 } = centered_rank_one_components(x, config.code_ridge)?;
338 let atoms = atom.insert_axis(Axis(0)).to_owned();
339 let assignments = codes.insert_axis(Axis(1)).to_owned();
340 let score = penalized_reconstruction_loss(x, fitted.view(), config.code_ridge, atoms.view());
341 Ok(LinearDictionaryFit {
342 atoms,
343 assignments,
344 fitted,
345 lambdas: Array1::from_elem(1, config.code_ridge),
346 reml_scores: Array1::from_elem(1, score),
347 explained_variance: ev,
348 iterations: 1.min(config.max_iter),
349 converged: true,
350 assignment: config.assignment,
351 top_k: 1,
352 })
353}
354
/// Pieces produced by the mean-centered rank-one fit.
struct CenteredRankOne {
    /// Sign-oriented eigenvector used as the single atom (length = columns of the input).
    atom: Array1<f64>,
    /// Shrunk projection of each centered row onto `atom` (length = rows of the input).
    codes: Array1<f64>,
    /// Reconstruction `mean + code * atom`, same shape as the input.
    fitted: Array2<f64>,
    /// Explained variance of `fitted` relative to the input.
    explained_variance: f64,
}
367
368fn centered_rank_one_components(
369 x: ArrayView2<'_, f64>,
370 code_ridge: f64,
371) -> Result<CenteredRankOne, String> {
372 if x.nrows() == 0 || x.ncols() == 0 {
373 return Err("rank_one_centered_pca_ceiling requires a non-empty 2-D matrix".to_string());
374 }
375 if !(code_ridge.is_finite() && code_ridge > 0.0) {
376 return Err(format!(
377 "rank_one_centered_pca_ceiling code_ridge must be finite and positive; got {code_ridge}"
378 ));
379 }
380 let means = x.mean_axis(Axis(0)).expect("non-empty input has means");
381 let centered = &x.to_owned() - &means;
382 let covariance = centered.t().dot(¢ered);
383 let (evals, evecs) = covariance
384 .eigh(Side::Lower)
385 .map_err(|err| format!("rank_one_centered_pca_ceiling eigensolve failed: {err}"))?;
386 let last = evals.len() - 1;
387 let mut atom = evecs.column(last).to_owned();
388 orient_vector(&mut atom);
389 let shrink = 1.0 / (1.0 + code_ridge);
390 let mut codes = Array1::<f64>::zeros(x.nrows());
391 let mut fitted = Array2::<f64>::zeros(x.dim());
392 for row in 0..x.nrows() {
393 let code = centered.row(row).dot(&atom) * shrink;
394 codes[row] = code;
395 for col in 0..x.ncols() {
396 fitted[[row, col]] = means[col] + code * atom[col];
397 }
398 }
399 let ev = explained_variance(x, fitted.view());
400 Ok(CenteredRankOne {
401 atom,
402 codes,
403 fitted,
404 explained_variance: ev,
405 })
406}
407
408pub fn rank_one_centered_pca_ceiling(
419 x: ArrayView2<'_, f64>,
420 code_ridge: f64,
421) -> Result<(Array2<f64>, f64), String> {
422 let components = centered_rank_one_components(x, code_ridge)?;
423 Ok((components.fitted, components.explained_variance))
424}
425
426fn initialize_atoms(x: ArrayView2<'_, f64>, n_atoms: usize) -> Array2<f64> {
427 let mut atoms = Array2::<f64>::zeros((n_atoms, x.ncols()));
428 let first = max_norm_row(x);
429 atoms.row_mut(0).assign(&x.row(first));
430 normalize_row(atoms.slice_mut(s![0, ..]));
431 let mut min_dist2 = Array1::<f64>::from_elem(x.nrows(), f64::INFINITY);
432
433 for atom_idx in 1..n_atoms {
434 let prev = atoms.row(atom_idx - 1);
435 for row in 0..x.nrows() {
436 let dist2 = squared_distance(x.row(row), prev);
437 if dist2 < min_dist2[row] {
438 min_dist2[row] = dist2;
439 }
440 }
441 let chosen = if atom_idx < x.nrows() {
442 max_index(min_dist2.view())
443 } else {
444 atom_idx % x.nrows()
445 };
446 atoms.row_mut(atom_idx).assign(&x.row(chosen));
447 normalize_row(atoms.slice_mut(s![atom_idx, ..]));
448 }
449 atoms
450}
451
/// Updates a single atom by ridge-penalized least squares, keeping `fitted` in sync.
///
/// * If the atom's code column has negligible energy, the atom is unused. It is reseeded
///   from the row with the largest squared residual and `Ok(true)` is returned. If every
///   residual is negligible too, the atom is zeroed, marked inactive, and `Ok(false)` is
///   returned. `fitted` is left untouched on both of these paths.
/// * Otherwise the atom is replaced by `codeᵀ · residual / (‖code‖² + atom_ridge)`, where
///   `residual` is the reconstruction error with this atom's own contribution added back.
///   The atom is then normalized (the code column absorbs the scale), `fitted` is rebuilt,
///   and `Ok(false)` is returned.
///
/// `lambdas[atom_idx]` and `reml_scores[atom_idx]` are refreshed on every path. No path in
/// this function produces an `Err`.
fn fit_one_atom_penalized_ls(
    x: ArrayView2<'_, f64>,
    atoms: &mut Array2<f64>,
    assignments: &mut Array2<f64>,
    fitted: &mut Array2<f64>,
    lambdas: &mut Array1<f64>,
    reml_scores: &mut Array1<f64>,
    atom_idx: usize,
    atom_ridge: f64,
) -> Result<bool, String> {
    let code = assignments.column(atom_idx).to_owned();
    let code_norm2 = code.dot(&code);
    if code_norm2 <= MIN_NORM2 {
        // Dead atom: find the worst-reconstructed row.
        let mut worst_row = 0usize;
        // Starts below zero so the first row always becomes the initial candidate.
        let mut worst_res2 = -1.0_f64;
        for row in 0..x.nrows() {
            let mut res2 = 0.0_f64;
            for col in 0..x.ncols() {
                let d = x[[row, col]] - fitted[[row, col]];
                res2 += d * d;
            }
            if res2 > worst_res2 {
                worst_res2 = res2;
                worst_row = row;
            }
        }
        if worst_res2 <= MIN_NORM2 {
            // Nothing left to explain: retire the atom.
            atoms.row_mut(atom_idx).fill(0.0);
            lambdas[atom_idx] = INACTIVE_LAMBDA;
            reml_scores[atom_idx] = 0.0;
            return Ok(false);
        }
        // Reseed the atom with the normalized residual of the worst row.
        for col in 0..x.ncols() {
            atoms[[atom_idx, col]] = x[[worst_row, col]] - fitted[[worst_row, col]];
        }
        normalize_row(atoms.slice_mut(s![atom_idx, ..]));
        lambdas[atom_idx] = atom_ridge;
        reml_scores[atom_idx] =
            penalized_reconstruction_loss(x, fitted.view(), atom_ridge, atoms.view());
        return Ok(true);
    }

    // Residual with this atom's current contribution (code ⊗ old_atom) added back.
    let old_atom = atoms.row(atom_idx).to_owned();
    let mut residual = x.to_owned() - fitted.view();
    residual += &code
        .view()
        .insert_axis(Axis(1))
        .dot(&old_atom.view().insert_axis(Axis(0)));

    // Ridge least-squares solution for the atom, one column at a time.
    let denominator = code_norm2 + atom_ridge;
    for col in 0..x.ncols() {
        atoms[[atom_idx, col]] = code.dot(&residual.column(col)) / denominator;
    }
    lambdas[atom_idx] = atom_ridge;
    normalize_atom_and_assignments(atoms, assignments, atom_idx);
    // fitted = (x - residual) + updated_code ⊗ updated_atom.
    let updated_code = assignments.column(atom_idx).to_owned();
    fitted.assign(&x);
    *fitted -= &residual;
    *fitted += &updated_code
        .view()
        .insert_axis(Axis(1))
        .dot(&atoms.row(atom_idx).insert_axis(Axis(0)));
    reml_scores[atom_idx] =
        penalized_reconstruction_loss(x, fitted.view(), atom_ridge, atoms.view());
    Ok(false)
}
533
534fn top_k_assignments(
535 x: ArrayView2<'_, f64>,
536 atoms: ArrayView2<'_, f64>,
537 top_k: usize,
538 code_ridge: f64,
539) -> Result<Array2<f64>, String> {
540 let cross = x.dot(&atoms.t());
541 let mut assignments = Array2::<f64>::zeros((x.nrows(), atoms.nrows()));
542 for row in 0..x.nrows() {
543 let active = top_indices_by_abs(cross.row(row), top_k);
544 let coeffs = solve_active_coefficients(atoms, cross.row(row), &active, code_ridge)?;
545 for pos in 0..active.len() {
546 assignments[[row, active[pos]]] = coeffs[pos];
547 }
548 }
549 Ok(assignments)
550}
551
552pub fn linear_dictionary_transform(
560 x: ArrayView2<'_, f64>,
561 atoms: ArrayView2<'_, f64>,
562 top_k: usize,
563 code_ridge: f64,
564) -> Result<Array2<f64>, String> {
565 let k = atoms.nrows();
566 if k == 0 {
567 return Err("linear_dictionary_transform: dictionary has no atoms".to_string());
568 }
569 if x.ncols() != atoms.ncols() {
570 return Err(format!(
571 "linear_dictionary_transform: X has P={} columns but atoms have P={}",
572 x.ncols(),
573 atoms.ncols()
574 ));
575 }
576 let effective_k = top_k.min(k).max(1);
577 top_k_assignments(x, atoms, effective_k, code_ridge)
578}
579
580fn softmax_assignments(
581 x: ArrayView2<'_, f64>,
582 atoms: ArrayView2<'_, f64>,
583 top_k: usize,
584 temperature: f64,
585 code_ridge: f64,
586) -> Result<Array2<f64>, String> {
587 let cross = x.dot(&atoms.t());
588 let atom_norm2 = atoms.map_axis(Axis(1), |row| row.dot(&row).max(MIN_NORM2));
589 let mut assignments = Array2::<f64>::zeros((x.nrows(), atoms.nrows()));
590 for row in 0..x.nrows() {
591 let active = top_indices_by_abs(cross.row(row), top_k);
592 let mut max_score = f64::NEG_INFINITY;
593 for &atom_idx in &active {
594 let score = cross[[row, atom_idx]].abs() / (atom_norm2[atom_idx].sqrt() * temperature);
595 if score > max_score {
596 max_score = score;
597 }
598 }
599 let mut denom = 0.0;
600 for &atom_idx in &active {
601 let score = cross[[row, atom_idx]].abs() / (atom_norm2[atom_idx].sqrt() * temperature);
602 let mass = (score - max_score).exp();
603 assignments[[row, atom_idx]] = mass;
604 denom += mass;
605 }
606 if denom <= 0.0 || !denom.is_finite() {
607 return Err("linear_dictionary_fit softmax assignment underflowed".to_string());
608 }
609 for &atom_idx in &active {
610 let projection = cross[[row, atom_idx]] / (atom_norm2[atom_idx] + code_ridge);
611 assignments[[row, atom_idx]] = assignments[[row, atom_idx]] * projection / denom;
612 }
613 }
614 Ok(assignments)
615}
616
617fn solve_active_coefficients(
618 atoms: ArrayView2<'_, f64>,
619 cross_row: ArrayView1<'_, f64>,
620 active: &[usize],
621 code_ridge: f64,
622) -> Result<Array1<f64>, String> {
623 let m = active.len();
624 let mut system = Array2::<f64>::zeros((m, m));
625 let mut rhs = Array2::<f64>::zeros((m, 1));
626 for i in 0..m {
627 rhs[[i, 0]] = cross_row[active[i]];
628 for j in 0..m {
629 system[[i, j]] = atoms.row(active[i]).dot(&atoms.row(active[j]));
630 }
631 system[[i, i]] += code_ridge;
632 }
633 let factor = system
634 .cholesky(Side::Lower)
635 .map_err(|err| format!("linear_dictionary_fit sparse-code solve failed: {err}"))?;
636 let mut solution = rhs;
637 factor.solve_mat_in_place(&mut solution);
638 Ok(solution.column(0).to_owned())
639}
640
641fn top_indices_by_abs(row: ArrayView1<'_, f64>, top_k: usize) -> Vec<usize> {
642 let mut selected: Vec<(usize, f64)> = Vec::with_capacity(top_k);
643 for idx in 0..row.len() {
644 let score = row[idx].abs();
645 if selected.len() < top_k {
646 selected.push((idx, score));
647 continue;
648 }
649 let mut worst_pos = 0usize;
650 for pos in 1..selected.len() {
651 if selected[pos].1 < selected[worst_pos].1
652 || (selected[pos].1 == selected[worst_pos].1
653 && selected[pos].0 > selected[worst_pos].0)
654 {
655 worst_pos = pos;
656 }
657 }
658 let worst = selected[worst_pos];
659 if score > worst.1 || (score == worst.1 && idx < worst.0) {
660 selected[worst_pos] = (idx, score);
661 }
662 }
663 selected.sort_by(|a, b| {
664 b.1.partial_cmp(&a.1)
665 .unwrap_or(std::cmp::Ordering::Equal)
666 .then_with(|| a.0.cmp(&b.0))
667 });
668 selected.into_iter().map(|(idx, _)| idx).collect()
669}
670
671fn normalize_atom_and_assignments(
672 atoms: &mut Array2<f64>,
673 assignments: &mut Array2<f64>,
674 atom_idx: usize,
675) {
676 let norm = atoms.row(atom_idx).dot(&atoms.row(atom_idx)).sqrt();
677 if norm > MIN_NORM2.sqrt() {
678 atoms.row_mut(atom_idx).mapv_inplace(|value| value / norm);
679 assignments
680 .column_mut(atom_idx)
681 .mapv_inplace(|value| value * norm);
682 }
683 orient_atom_and_code(atoms, assignments, atom_idx);
684}
685
686fn orient_atom_and_code(atoms: &mut Array2<f64>, assignments: &mut Array2<f64>, atom_idx: usize) {
687 let sign = first_nonzero_sign(atoms.row(atom_idx));
688 if sign < 0.0 {
689 atoms.row_mut(atom_idx).mapv_inplace(|value| -value);
690 assignments
691 .column_mut(atom_idx)
692 .mapv_inplace(|value| -value);
693 }
694}
695
696fn orient_vector(vector: &mut Array1<f64>) {
697 if first_nonzero_sign(vector.view()) < 0.0 {
698 vector.mapv_inplace(|value| -value);
699 }
700}
701
702fn first_nonzero_sign(row: ndarray::ArrayView1<'_, f64>) -> f64 {
703 for &value in row {
704 if value.abs() > 1.0e-12 {
705 return value.signum();
706 }
707 }
708 1.0
709}
710
711fn normalize_row(mut row: ndarray::ArrayViewMut1<'_, f64>) {
712 let norm = row.dot(&row).sqrt();
713 if norm > MIN_NORM2.sqrt() {
714 row.mapv_inplace(|value| value / norm);
715 }
716}
717
718fn max_norm_row(x: ArrayView2<'_, f64>) -> usize {
719 let mut best = 0usize;
720 let mut best_norm = f64::NEG_INFINITY;
721 for row in 0..x.nrows() {
722 let norm = x.row(row).dot(&x.row(row));
723 if norm > best_norm {
724 best = row;
725 best_norm = norm;
726 }
727 }
728 best
729}
730
731fn max_index(values: ndarray::ArrayView1<'_, f64>) -> usize {
732 let mut best = 0usize;
733 let mut best_value = f64::NEG_INFINITY;
734 for idx in 0..values.len() {
735 if values[idx] > best_value {
736 best = idx;
737 best_value = values[idx];
738 }
739 }
740 best
741}
742
743fn squared_distance(a: ndarray::ArrayView1<'_, f64>, b: ndarray::ArrayView1<'_, f64>) -> f64 {
744 a.iter()
745 .zip(b.iter())
746 .map(|(av, bv)| {
747 let diff = av - bv;
748 diff * diff
749 })
750 .sum()
751}
752
753fn explained_variance(x: ArrayView2<'_, f64>, fitted: ArrayView2<'_, f64>) -> f64 {
754 let mut rss = 0.0;
755 for row in 0..x.nrows() {
756 for col in 0..x.ncols() {
757 let residual = x[[row, col]] - fitted[[row, col]];
758 rss += residual * residual;
759 }
760 }
761 let means = x.mean_axis(Axis(0)).expect("non-empty input has means");
762 let mut tss = 0.0;
763 for row in 0..x.nrows() {
764 for col in 0..x.ncols() {
765 let centered = x[[row, col]] - means[col];
766 tss += centered * centered;
767 }
768 }
769 if tss <= MIN_NORM2 {
770 if rss <= MIN_NORM2 { 1.0 } else { 0.0 }
771 } else {
772 1.0 - rss / tss
773 }
774}
775
776fn penalized_reconstruction_loss(
777 x: ArrayView2<'_, f64>,
778 fitted: ArrayView2<'_, f64>,
779 ridge: f64,
780 atoms: ArrayView2<'_, f64>,
781) -> f64 {
782 let mut loss = 0.0;
783 for row in 0..x.nrows() {
784 for col in 0..x.ncols() {
785 let residual = x[[row, col]] - fitted[[row, col]];
786 loss += residual * residual;
787 }
788 }
789 loss + ridge * atoms.iter().map(|value| value * value).sum::<f64>()
790}
791
792#[cfg(test)]
793mod tests {
794 use super::*;
795 use approx::assert_abs_diff_eq;
796 use ndarray::{Array2, array};
797
    // Data generated from four axis-aligned atoms with two active codes per row must be
    // reconstructed with explained variance above 0.95 by a 4-atom, top-2 fit.
    #[test]
    fn planted_sparse_linear_dictionary_reaches_high_explained_variance() {
        let truth = array![
            [1.0, 0.0, 0.0, 0.0],
            [0.0, 1.0, 0.0, 0.0],
            [0.0, 0.0, 1.0, 0.0],
            [0.0, 0.0, 0.0, 1.0],
        ];
        // Each row mixes a dominant atom with a fixed 0.2 share of the next atom.
        let mut assignments = Array2::<f64>::zeros((160, 4));
        for row in 0..160 {
            let atom = row % 4;
            assignments[[row, atom]] = 0.7 + 0.01 * ((row / 4) as f64);
            assignments[[row, (atom + 1) % 4]] = 0.2;
        }
        let x = assignments.dot(&truth);
        let config = LinearDictionaryConfig {
            n_atoms: 4,
            max_iter: 40,
            top_k: 2,
            assignment: LinearDictionaryAssignment::TopK,
            temperature: DEFAULT_TEMPERATURE,
            code_ridge: DEFAULT_CODE_RIDGE,
            tolerance: 1.0e-9,
            center_rank_one: false,
        };

        let fit = fit_linear_dictionary(x.view(), &config).expect("linear dictionary fit");

        assert!(
            fit.explained_variance > 0.95,
            "expected EV > 0.95, got {}",
            fit.explained_variance
        );
    }
832
    // The single-atom lane must agree with an explained-variance value computed directly
    // from the eigenvalues of xᵀx and the ridge shrink factor.
    #[test]
    fn single_atom_matches_penalized_pca_oracle() {
        let mut x = Array2::<f64>::zeros((80, 3));
        for row in 0..80 {
            let t = (row as f64 - 39.5) / 20.0;
            x[[row, 0]] = 2.0 * t;
            x[[row, 1]] = -t;
            x[[row, 2]] = 0.05 * (row as f64).sin();
        }
        let config = LinearDictionaryConfig {
            n_atoms: 1,
            max_iter: 5,
            top_k: 1,
            assignment: LinearDictionaryAssignment::TopK,
            temperature: DEFAULT_TEMPERATURE,
            code_ridge: DEFAULT_CODE_RIDGE,
            tolerance: DEFAULT_TOLERANCE,
            center_rank_one: false,
        };

        let fit = fit_linear_dictionary(x.view(), &config).expect("rank-one fit");
        let covariance = x.t().dot(&x);
        let (evals, _) = covariance.eigh(Side::Lower).expect("PCA eigensolve");
        let shrink = 1.0 / (1.0 + DEFAULT_CODE_RIDGE);
        // Oracle: the last eigenvalue is only partly lost (through shrinkage); all the
        // other eigenvalues are lost entirely.
        let oracle_ev = 1.0
            - ((1.0 - shrink) * (1.0 - shrink) * evals[evals.len() - 1]
                + evals.slice(s![..evals.len() - 1]).sum())
                / evals.sum();

        assert!(fit.explained_variance > 0.99);
        assert_abs_diff_eq!(fit.explained_variance, oracle_ev, epsilon = 2.0e-4);
    }
865
    // Rows built from k planted directions (plus small deterministic noise) must leave
    // every atom non-zero and reach explained variance above 0.99.
    #[test]
    fn orthonormal_rank_one_atoms_all_revived_no_dead_collapse_1500() {
        let (k, p, n) = (4usize, 8usize, 400usize);
        // Deterministic matrix; its symmetrized form supplies the planted directions.
        let mut a = Array2::<f64>::zeros((p, p));
        for i in 0..p {
            for j in 0..p {
                a[[i, j]] = ((i * 7 + j * 3 + 1) % 11) as f64 - 5.0;
            }
        }
        let sym = &a + &a.t();
        let (_evals, evecs) = sym.eigh(Side::Lower).expect("orthonormal directions");
        // The first k eigenvector columns, transposed so each direction is a row.
        let dirs = evecs.slice(s![.., ..k]).t().to_owned();
        let mut x = Array2::<f64>::zeros((n, p));
        for row in 0..n {
            let atom = row % k;
            let scale = if row % 2 == 0 { 2.0 } else { -1.5 } + 0.01 * (row / k) as f64;
            for col in 0..p {
                let noise = 1.0e-3 * (((row * p + col) % 13) as f64 - 6.0);
                x[[row, col]] = scale * dirs[[atom, col]] + noise;
            }
        }
        let config = LinearDictionaryConfig {
            n_atoms: k,
            max_iter: 40,
            top_k: 1,
            assignment: LinearDictionaryAssignment::TopK,
            temperature: DEFAULT_TEMPERATURE,
            code_ridge: DEFAULT_CODE_RIDGE,
            tolerance: 1.0e-9,
            center_rank_one: false,
        };
        let fit = fit_linear_dictionary(x.view(), &config).expect("orthonormal dictionary fit");
        // An atom counts as live when at least one entry is non-negligible.
        let live = fit
            .atoms
            .axis_iter(Axis(0))
            .filter(|atom| atom.iter().any(|value| value.abs() > 1.0e-12))
            .count();
        assert_eq!(
            live, k,
            "all {k} atoms must stay live (no dead-atom collapse); got {live} live"
        );
        assert!(
            fit.explained_variance > 0.99,
            "K orthonormal rank-1 atoms must be reconstructed at EV > 0.99; got {}",
            fit.explained_variance
        );
    }
920
    // The closing reroute must never lower the explained variance, and the returned
    // `fitted`, `assignments`, `atoms` and `explained_variance` must agree with each other.
    #[test]
    fn final_reroute_never_regresses_and_stays_consistent() {
        let truth = array![
            [1.0, 0.0, 0.0, 0.0],
            [0.0, 1.0, 0.0, 0.0],
            [0.0, 0.0, 1.0, 0.0],
            [0.0, 0.0, 0.0, 1.0],
        ];
        let mut assignments = Array2::<f64>::zeros((160, 4));
        for row in 0..160 {
            let atom = row % 4;
            assignments[[row, atom]] = 0.7 + 0.01 * ((row / 4) as f64);
            assignments[[row, (atom + 1) % 4]] = 0.2;
        }
        let x = assignments.dot(&truth);
        let config = LinearDictionaryConfig {
            n_atoms: 4,
            max_iter: 40,
            top_k: 2,
            assignment: LinearDictionaryAssignment::TopK,
            temperature: DEFAULT_TEMPERATURE,
            code_ridge: DEFAULT_CODE_RIDGE,
            tolerance: 1.0e-9,
            center_rank_one: false,
        };

        // Call the internal entry point to get access to the pre-reroute diagnostic.
        let diag =
            fit_multi_atom_dictionary(x.view(), &config).expect("multi-atom dictionary fit");
        assert!(
            diag.fit.explained_variance >= diag.pre_reroute_ev - 1.0e-12,
            "final reroute regressed EV: pre={}, returned={}",
            diag.pre_reroute_ev,
            diag.fit.explained_variance
        );

        // `fitted` must equal assignments · atoms.
        let recomputed_fitted = diag.fit.assignments.dot(&diag.fit.atoms);
        for (a, b) in diag.fit.fitted.iter().zip(recomputed_fitted.iter()) {
            assert_abs_diff_eq!(*a, *b, epsilon = 1.0e-10);
        }
        // The reported EV must match the EV recomputed from the returned `fitted`.
        assert_abs_diff_eq!(
            diag.fit.explained_variance,
            explained_variance(x.view(), diag.fit.fitted.view()),
            epsilon = 1.0e-10
        );

        // The same consistency must hold through the public entry point.
        let public = fit_linear_dictionary(x.view(), &config).expect("public fit");
        let public_fitted = public.assignments.dot(&public.atoms);
        for (a, b) in public.fitted.iter().zip(public_fitted.iter()) {
            assert_abs_diff_eq!(*a, *b, epsilon = 1.0e-10);
        }
    }
979
    // On data whose column means were already removed, the uncentered rank-one lane and
    // the centered ceiling must report the same explained variance.
    #[test]
    fn centered_rank_one_ceiling_agrees_when_data_already_centered() {
        let mut x = Array2::<f64>::zeros((90, 3));
        for row in 0..90 {
            let t = (row as f64 - 44.5) / 25.0;
            x[[row, 0]] = 1.5 * t;
            x[[row, 1]] = -0.8 * t + 0.02 * (row as f64).cos();
            x[[row, 2]] = 0.6 * t;
        }
        // Remove the column means explicitly before fitting.
        let means = x.mean_axis(Axis(0)).unwrap();
        let centered = &x - &means;

        let config = LinearDictionaryConfig::new(1);
        let uncentered = fit_linear_dictionary(centered.view(), &config).expect("rank-one fit");
        let (_fitted, centered_ev) =
            rank_one_centered_pca_ceiling(centered.view(), DEFAULT_CODE_RIDGE)
                .expect("centered ceiling");

        assert_abs_diff_eq!(
            uncentered.explained_variance,
            centered_ev,
            epsilon = 1.0e-9
        );
    }
1007
    // With a large common offset in both columns, the centered ceiling must strictly beat
    // the default (uncentered) rank-one lane, and its reported EV must match its `fitted`.
    #[test]
    fn centered_rank_one_ceiling_beats_uncentered_with_strong_mean() {
        let mut x = Array2::<f64>::zeros((120, 2));
        for row in 0..120 {
            let t = (row as f64 - 59.5) / 60.0;
            x[[row, 0]] = 50.0 + 0.3 * t;
            x[[row, 1]] = 50.0 - 0.3 * t;
        }
        let config = LinearDictionaryConfig::new(1);
        let uncentered = fit_linear_dictionary(x.view(), &config).expect("rank-one fit");
        let (fitted, centered_ev) =
            rank_one_centered_pca_ceiling(x.view(), DEFAULT_CODE_RIDGE).expect("centered ceiling");

        assert!(
            centered_ev > uncentered.explained_variance + 1.0e-6,
            "centered ceiling ({centered_ev}) should beat uncentered lane ({}) on strong-mean data",
            uncentered.explained_variance
        );
        assert_abs_diff_eq!(
            centered_ev,
            explained_variance(x.view(), fitted.view()),
            epsilon = 1.0e-10
        );
    }
1037
1038 #[test]
1039 fn center_rank_one_config_flag_routes_k1_lane_to_centered_ceiling() {
1040 let mut x = Array2::<f64>::zeros((100, 3));
1045 for row in 0..100 {
1046 let t = (row as f64 - 49.5) / 50.0;
1047 x[[row, 0]] = 30.0 + 0.2 * t;
1048 x[[row, 1]] = 30.0 - 0.2 * t;
1049 x[[row, 2]] = 30.0 + 0.05 * t;
1050 }
1051
1052 let default_config = LinearDictionaryConfig::new(1);
1053 assert!(!default_config.center_rank_one, "flag must default to false");
1054 let uncentered = fit_linear_dictionary(x.view(), &default_config).expect("uncentered lane");
1055
1056 let mut centered_config = LinearDictionaryConfig::new(1);
1057 centered_config.center_rank_one = true;
1058 let centered = fit_linear_dictionary(x.view(), ¢ered_config).expect("centered lane");
1059
1060 let (_fitted, helper_ev) =
1063 rank_one_centered_pca_ceiling(x.view(), DEFAULT_CODE_RIDGE).expect("helper ceiling");
1064 assert_abs_diff_eq!(centered.explained_variance, helper_ev, epsilon = 1.0e-10);
1065 assert!(
1066 centered.explained_variance > uncentered.explained_variance + 1.0e-6,
1067 "center_rank_one=true ({}) must beat default ({}) on strong-mean data",
1068 centered.explained_variance,
1069 uncentered.explained_variance
1070 );
1071 assert_abs_diff_eq!(
1075 centered.explained_variance,
1076 explained_variance(x.view(), centered.fitted.view()),
1077 epsilon = 1.0e-10
1078 );
1079 }
1080
    // A 1024-atom, top-1 fit on data spanned by eight axis-aligned directions must keep at
    // most one active code per row and still reach explained variance above 0.95.
    #[test]
    fn sparse_assignment_scales_to_thousand_atom_dictionary() {
        let active_atoms = array![
            [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
            [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
            [0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0],
            [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0],
            [0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0],
            [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0],
            [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0],
            [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0],
        ];
        // Every row is one planted direction times a row-dependent scale.
        let mut x = Array2::<f64>::zeros((256, 8));
        for row in 0..x.nrows() {
            let atom = row % active_atoms.nrows();
            let scale = 0.7 + 0.003 * row as f64;
            x.row_mut(row).assign(&(&active_atoms.row(atom) * scale));
        }
        let config = LinearDictionaryConfig {
            n_atoms: 1024,
            max_iter: 8,
            top_k: 1,
            assignment: LinearDictionaryAssignment::TopK,
            temperature: DEFAULT_TEMPERATURE,
            code_ridge: DEFAULT_CODE_RIDGE,
            tolerance: 1.0e-9,
            center_rank_one: false,
        };

        let fit = fit_linear_dictionary(x.view(), &config).expect("large-K linear dictionary fit");
        // Largest number of non-negligible codes found in any single row.
        let max_active = fit
            .assignments
            .axis_iter(Axis(0))
            .map(|row| row.iter().filter(|value| value.abs() > 1.0e-10).count())
            .max()
            .unwrap();

        assert_eq!(max_active, 1);
        assert!(
            fit.explained_variance > 0.95,
            "expected EV > 0.95 at K=1024, got {}",
            fit.explained_variance
        );
    }
1125}