1use ferrolearn_core::error::FerroError;
35use ferrolearn_core::traits::{Fit, Transform};
36use ndarray::Array2;
37use rand::SeedableRng;
38use rand_distr::{Distribution, Normal};
39use rand_xoshiro::Xoshiro256PlusPlus;
40
/// Algorithm used to update the dictionary during fitting.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DictFitAlgorithm {
    /// Alternate lasso coordinate-descent sparse coding with a
    /// least-squares dictionary update (the only algorithm implemented).
    CoordinateDescent,
}
51
/// Sparse-coding algorithm used by the fitted model's `transform`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DictTransformAlgorithm {
    /// Orthogonal Matching Pursuit: greedy atom selection with a hard cap
    /// on the number of non-zero coefficients per sample.
    Omp,
    /// Lasso solved by cyclic coordinate descent with L1 penalty `alpha`.
    LassoCd,
}
60
/// Dictionary learning (sparse coding) estimator.
///
/// Learns a dictionary `D` of shape `(n_components, n_features)` such that
/// each sample is approximated by a sparse linear combination of its rows
/// (atoms).
#[derive(Debug, Clone)]
pub struct DictionaryLearning {
    /// Number of dictionary atoms to learn.
    n_components: usize,
    /// L1 sparsity penalty used during fitting (and by the LassoCd transform).
    alpha: f64,
    /// Maximum number of alternating-minimisation iterations in `fit`.
    max_iter: usize,
    /// Convergence threshold on the change in reconstruction error.
    tol: f64,
    /// Dictionary-update algorithm (currently only coordinate descent).
    fit_algorithm: DictFitAlgorithm,
    /// Sparse-coding algorithm the fitted model will use in `transform`.
    transform_algorithm: DictTransformAlgorithm,
    /// Max non-zero coefficients per sample for OMP; `None` means `n_components`.
    transform_n_nonzero_coefs: Option<usize>,
    /// Seed for dictionary initialisation; `None` falls back to 0 (deterministic).
    random_state: Option<u64>,
}
88
impl DictionaryLearning {
    /// Create an estimator with `n_components` atoms and default
    /// hyper-parameters: `alpha = 1.0`, `max_iter = 1000`, `tol = 1e-8`,
    /// coordinate-descent fitting, OMP transform, unseeded RNG.
    #[must_use]
    pub fn new(n_components: usize) -> Self {
        Self {
            n_components,
            alpha: 1.0,
            max_iter: 1000,
            tol: 1e-8,
            fit_algorithm: DictFitAlgorithm::CoordinateDescent,
            transform_algorithm: DictTransformAlgorithm::Omp,
            transform_n_nonzero_coefs: None,
            random_state: None,
        }
    }

    /// Set the L1 sparsity penalty (builder style).
    #[must_use]
    pub fn with_alpha(mut self, alpha: f64) -> Self {
        self.alpha = alpha;
        self
    }

    /// Set the maximum number of fit iterations (builder style).
    #[must_use]
    pub fn with_max_iter(mut self, n: usize) -> Self {
        self.max_iter = n;
        self
    }

    /// Set the convergence tolerance on the reconstruction-error change.
    #[must_use]
    pub fn with_tol(mut self, tol: f64) -> Self {
        self.tol = tol;
        self
    }

    /// Choose the dictionary-update algorithm.
    #[must_use]
    pub fn with_fit_algorithm(mut self, algo: DictFitAlgorithm) -> Self {
        self.fit_algorithm = algo;
        self
    }

    /// Choose the sparse-coding algorithm used by the fitted model's `transform`.
    #[must_use]
    pub fn with_transform_algorithm(mut self, algo: DictTransformAlgorithm) -> Self {
        self.transform_algorithm = algo;
        self
    }

    /// Cap the number of non-zero coefficients per sample for the OMP transform.
    #[must_use]
    pub fn with_transform_n_nonzero_coefs(mut self, n: usize) -> Self {
        self.transform_n_nonzero_coefs = Some(n);
        self
    }

    /// Seed the RNG used for dictionary initialisation (reproducible fits).
    #[must_use]
    pub fn with_random_state(mut self, seed: u64) -> Self {
        self.random_state = Some(seed);
        self
    }

    /// Number of dictionary atoms this estimator will learn.
    #[must_use]
    pub fn n_components(&self) -> usize {
        self.n_components
    }

    /// L1 sparsity penalty.
    #[must_use]
    pub fn alpha(&self) -> f64 {
        self.alpha
    }

    /// Maximum number of fit iterations.
    #[must_use]
    pub fn max_iter(&self) -> usize {
        self.max_iter
    }

    /// Convergence tolerance.
    #[must_use]
    pub fn tol(&self) -> f64 {
        self.tol
    }

    /// Configured dictionary-update algorithm.
    #[must_use]
    pub fn fit_algorithm(&self) -> DictFitAlgorithm {
        self.fit_algorithm
    }

    /// Configured transform-time sparse-coding algorithm.
    #[must_use]
    pub fn transform_algorithm(&self) -> DictTransformAlgorithm {
        self.transform_algorithm
    }

    /// Configured RNG seed, if any.
    #[must_use]
    pub fn random_state(&self) -> Option<u64> {
        self.random_state
    }
}
199
/// A `DictionaryLearning` model after fitting: the learned dictionary plus
/// the settings needed to sparse-code new data.
#[derive(Debug, Clone)]
pub struct FittedDictionaryLearning {
    /// Learned dictionary, shape `(n_components, n_features)`; rows are
    /// unit-norm atoms (zero atoms excepted).
    components_: Array2<f64>,
    /// Sparsity penalty carried over from the estimator (used by LassoCd transform).
    alpha_: f64,
    /// Number of alternating-minimisation iterations actually performed.
    n_iter_: usize,
    /// Frobenius-norm reconstruction error of the training data at the final codes.
    reconstruction_err_: f64,
    /// Sparse-coding algorithm used by `transform`.
    transform_algorithm_: DictTransformAlgorithm,
    /// Max non-zero coefficients per sample for the OMP transform.
    transform_n_nonzero_coefs_: usize,
}
225
impl FittedDictionaryLearning {
    /// The learned dictionary, shape `(n_components, n_features)`.
    #[must_use]
    pub fn components(&self) -> &Array2<f64> {
        &self.components_
    }

    /// Number of fit iterations that were run before convergence or the cap.
    #[must_use]
    pub fn n_iter(&self) -> usize {
        self.n_iter_
    }

    /// Frobenius-norm reconstruction error of the training data.
    #[must_use]
    pub fn reconstruction_err(&self) -> f64 {
        self.reconstruction_err_
    }
}
245
246fn normalise_dictionary(d: &mut Array2<f64>) {
252 let n_components = d.nrows();
253 let n_features = d.ncols();
254 for k in 0..n_components {
255 let mut norm = 0.0;
256 for j in 0..n_features {
257 norm += d[[k, j]] * d[[k, j]];
258 }
259 let norm = norm.sqrt();
260 if norm > 1e-16 {
261 for j in 0..n_features {
262 d[[k, j]] /= norm;
263 }
264 }
265 }
266}
267
/// Sparse-code one sample `x_row` against dictionary `d` (rows = atoms) by
/// cyclic coordinate descent on the lasso objective
/// `0.5 * ||x - a D||^2 + alpha * ||a||_1`.
///
/// `max_iter` caps the number of full coordinate sweeps; sweeping also
/// stops early once no coordinate moves by more than 1e-6.
fn lasso_cd_single(x_row: &[f64], d: &Array2<f64>, alpha: f64, max_iter: usize) -> Vec<f64> {
    let n_components = d.nrows();
    let n_features = d.ncols();
    let mut a = vec![0.0; n_components];

    // Precompute the Gram matrix G = D * D^T and the correlations D * x once,
    // so each sweep costs O(n_components^2) independent of n_features.
    let mut gram = vec![vec![0.0; n_components]; n_components];
    let mut dx = vec![0.0; n_components];
    for k in 0..n_components {
        for j in 0..n_features {
            dx[k] += d[[k, j]] * x_row[j];
        }
        // G is symmetric: compute the upper triangle and mirror it.
        for l in k..n_components {
            let mut val = 0.0;
            for j in 0..n_features {
                val += d[[k, j]] * d[[l, j]];
            }
            gram[k][l] = val;
            gram[l][k] = val;
        }
    }

    for _iter in 0..max_iter {
        let mut max_change = 0.0;
        for k in 0..n_components {
            // rho: correlation of atom k with the residual that excludes
            // atom k's own current contribution.
            let mut rho = dx[k];
            for l in 0..n_components {
                if l != k {
                    rho -= gram[k][l] * a[l];
                }
            }

            // Closed-form one-coordinate lasso update; an atom with
            // (near-)zero self-inner-product gets a zero coefficient.
            let gram_kk = gram[k][k];
            let new_a = if gram_kk.abs() < 1e-16 {
                0.0
            } else {
                soft_threshold(rho, alpha) / gram_kk
            };

            let change = (new_a - a[k]).abs();
            if change > max_change {
                max_change = change;
            }
            a[k] = new_a;
        }
        // Converged: the largest coordinate move in this sweep was negligible.
        if max_change < 1e-6 {
            break;
        }
    }

    a
}
325
/// Sparse-code one sample with Orthogonal Matching Pursuit.
///
/// Greedily selects, one atom per step, the dictionary row with the largest
/// absolute correlation against the current residual, re-fits all selected
/// coefficients by least squares, and rebuilds the residual.
///
/// Stops after `max_nonzero` atoms (capped at both matrix dimensions), when
/// the best correlation is negligible, when the least-squares solve fails,
/// or when the residual norm is effectively zero.
fn omp_single(x_row: &[f64], d: &Array2<f64>, max_nonzero: usize) -> Vec<f64> {
    let n_components = d.nrows();
    let n_features = d.ncols();
    let mut a = vec![0.0; n_components];
    let mut residual: Vec<f64> = x_row.to_vec();
    let mut selected: Vec<usize> = Vec::new();
    let max_k = max_nonzero.min(n_components).min(n_features);

    for _step in 0..max_k {
        // Pick the not-yet-selected atom with the largest |<atom, residual>|.
        let mut best_idx = 0;
        let mut best_corr = 0.0;
        for k in 0..n_components {
            if selected.contains(&k) {
                continue;
            }
            let mut corr = 0.0;
            for j in 0..n_features {
                corr += d[[k, j]] * residual[j];
            }
            if corr.abs() > best_corr {
                best_corr = corr.abs();
                best_idx = k;
            }
        }

        // Nothing left that meaningfully correlates with the residual.
        if best_corr < 1e-12 {
            break;
        }

        selected.push(best_idx);

        // Least-squares re-fit on the active set S:
        // solve (D_S D_S^T) c = D_S x for the selected coefficients.
        let m = selected.len();
        let mut gram = vec![vec![0.0; m]; m];
        let mut rhs = vec![0.0; m];
        for (ii, &ki) in selected.iter().enumerate() {
            for j in 0..n_features {
                rhs[ii] += d[[ki, j]] * x_row[j];
            }
            for (jj, &kj) in selected.iter().enumerate() {
                let mut val = 0.0;
                for f in 0..n_features {
                    val += d[[ki, f]] * d[[kj, f]];
                }
                gram[ii][jj] = val;
            }
        }

        if let Some(coefs) = solve_symmetric(&gram, &rhs) {
            // Rebuild the residual from the full sample each step so
            // floating-point error does not accumulate across iterations.
            residual = x_row.to_vec();
            for (ii, &ki) in selected.iter().enumerate() {
                a[ki] = coefs[ii];
                for j in 0..n_features {
                    residual[j] -= coefs[ii] * d[[ki, j]];
                }
            }
        } else {
            // Singular Gram matrix: keep the coefficients found so far.
            break;
        }

        let res_norm: f64 = residual.iter().map(|v| v * v).sum::<f64>().sqrt();
        if res_norm < 1e-10 {
            break;
        }
    }

    a
}
400
/// Solve the dense linear system `a * x = b` by Gaussian elimination with
/// partial pivoting, returning `None` when the matrix is numerically
/// singular (best pivot magnitude below 1e-14).
///
/// Only the leading `b.len()` rows and columns of `a` are used.
#[allow(clippy::needless_range_loop)]
fn solve_symmetric(a: &[Vec<f64>], b: &[f64]) -> Option<Vec<f64>> {
    let n = b.len();
    if n == 0 {
        return Some(Vec::new());
    }

    // Build the augmented matrix [A | b].
    let mut m: Vec<Vec<f64>> = a
        .iter()
        .take(n)
        .zip(b.iter())
        .map(|(row, &rhs)| {
            let mut r = row.clone();
            r.push(rhs);
            r
        })
        .collect();

    // Forward elimination.
    for c in 0..n {
        // Partial pivoting: move the row with the largest |entry| in this
        // column (first such row on ties) into the pivot position.
        let mut pivot_row = c;
        let mut pivot_mag = m[c][c].abs();
        for r in (c + 1)..n {
            let mag = m[r][c].abs();
            if mag > pivot_mag {
                pivot_mag = mag;
                pivot_row = r;
            }
        }
        if pivot_mag < 1e-14 {
            return None;
        }
        m.swap(c, pivot_row);

        let pivot = m[c][c];
        for r in (c + 1)..n {
            let factor = m[r][c] / pivot;
            for j in c..=n {
                let upper = m[c][j];
                m[r][j] -= factor * upper;
            }
        }
    }

    // Back substitution on the upper-triangular system.
    let mut x = vec![0.0; n];
    for i in (0..n).rev() {
        let tail: f64 = ((i + 1)..n).map(|j| m[i][j] * x[j]).sum();
        x[i] = (m[i][n] - tail) / m[i][i];
    }

    Some(x)
}
456
/// Soft-thresholding (shrinkage) operator: moves `x` toward zero by
/// `lambda`, mapping anything in `[-lambda, lambda]` to exactly zero.
fn soft_threshold(x: f64, lambda: f64) -> f64 {
    match x {
        v if v > lambda => v - lambda,
        v if v < -lambda => v + lambda,
        _ => 0.0,
    }
}
467
468fn reconstruction_error(x: &Array2<f64>, a: &Array2<f64>, d: &Array2<f64>) -> f64 {
470 let ad = a.dot(d);
471 let mut err = 0.0;
472 for (xi, adi) in x.iter().zip(ad.iter()) {
473 let diff = xi - adi;
474 err += diff * diff;
475 }
476 err.sqrt()
477}
478
impl Fit<Array2<f64>, ()> for DictionaryLearning {
    type Fitted = FittedDictionaryLearning;
    type Error = FerroError;

    /// Learn the dictionary by alternating minimisation. Each iteration:
    /// (1) sparse-codes every sample with lasso coordinate descent,
    /// (2) updates the dictionary by solving the ridge-stabilised normal
    ///     equations `(AᵀA + 1e-10·I)·D = AᵀX` one feature column at a time,
    /// (3) re-normalises the atoms to unit norm.
    /// Iteration stops when the reconstruction-error change drops below
    /// `tol`, or after `max_iter` iterations.
    ///
    /// # Errors
    /// * `InvalidParameter` — `n_components == 0`, `x` has zero features,
    ///   or `alpha < 0`.
    /// * `InsufficientSamples` — `x` has zero rows.
    fn fit(&self, x: &Array2<f64>, _y: &()) -> Result<FittedDictionaryLearning, FerroError> {
        let (n_samples, n_features) = x.dim();

        // ---- parameter validation -------------------------------------
        if self.n_components == 0 {
            return Err(FerroError::InvalidParameter {
                name: "n_components".into(),
                reason: "must be at least 1".into(),
            });
        }
        if n_samples == 0 {
            return Err(FerroError::InsufficientSamples {
                required: 1,
                actual: 0,
                context: "DictionaryLearning::fit".into(),
            });
        }
        if n_features == 0 {
            return Err(FerroError::InvalidParameter {
                name: "X".into(),
                reason: "must have at least 1 feature".into(),
            });
        }
        if self.alpha < 0.0 {
            return Err(FerroError::InvalidParameter {
                name: "alpha".into(),
                reason: "must be non-negative".into(),
            });
        }

        let n_components = self.n_components;
        let seed = self.random_state.unwrap_or(0);
        let transform_n_nonzero = self.transform_n_nonzero_coefs.unwrap_or(n_components);

        // Initialise the dictionary with unit-norm Gaussian random atoms.
        let mut rng = Xoshiro256PlusPlus::seed_from_u64(seed);
        let normal = Normal::new(0.0, 1.0).unwrap();
        let mut d = Array2::<f64>::zeros((n_components, n_features));
        for elem in d.iter_mut() {
            *elem = normal.sample(&mut rng);
        }
        normalise_dictionary(&mut d);

        let mut prev_err = f64::MAX;
        let mut n_iter = 0;

        for iteration in 0..self.max_iter {
            n_iter = iteration + 1;

            // Step 1: sparse coding — lasso CD per sample (inner cap: 200 sweeps).
            let mut a = Array2::<f64>::zeros((n_samples, n_components));
            for i in 0..n_samples {
                let x_row: Vec<f64> = (0..n_features).map(|j| x[[i, j]]).collect();
                let codes = lasso_cd_single(&x_row, &d, self.alpha, 200);
                for k in 0..n_components {
                    a[[i, k]] = codes[k];
                }
            }

            // Step 2: dictionary update via the normal equations AᵀA·D = AᵀX.
            let ata = a.t().dot(&a);
            let atx = a.t().dot(x);

            let gram: Vec<Vec<f64>> = (0..n_components)
                .map(|i| (0..n_components).map(|j| ata[[i, j]]).collect())
                .collect();

            // Tiny ridge on the diagonal keeps the solve well-posed when an
            // atom is unused (its row/column of AᵀA would be all zeros).
            let mut gram_reg = gram.clone();
            for (k, row) in gram_reg.iter_mut().enumerate() {
                row[k] += 1e-10;
            }

            // Solve each feature column independently; a column whose solve
            // fails keeps its previous values.
            for j in 0..n_features {
                let rhs: Vec<f64> = (0..n_components).map(|k| atx[[k, j]]).collect();
                if let Some(sol) = solve_symmetric(&gram_reg, &rhs) {
                    for k in 0..n_components {
                        d[[k, j]] = sol[k];
                    }
                }
            }

            normalise_dictionary(&mut d);

            // Step 3: convergence check on the reconstruction error.
            let err = reconstruction_error(x, &a, &d);
            if (prev_err - err).abs() < self.tol {
                break;
            }
            prev_err = err;
        }

        // Final sparse codes against the converged dictionary, used only to
        // report the training reconstruction error.
        let mut a_final = Array2::<f64>::zeros((n_samples, n_components));
        for i in 0..n_samples {
            let x_row: Vec<f64> = (0..n_features).map(|j| x[[i, j]]).collect();
            let codes = lasso_cd_single(&x_row, &d, self.alpha, 200);
            for k in 0..n_components {
                a_final[[i, k]] = codes[k];
            }
        }
        let final_err = reconstruction_error(x, &a_final, &d);

        Ok(FittedDictionaryLearning {
            components_: d,
            alpha_: self.alpha,
            n_iter_: n_iter,
            reconstruction_err_: final_err,
            transform_algorithm_: self.transform_algorithm,
            transform_n_nonzero_coefs_: transform_n_nonzero,
        })
    }
}
611
612impl Transform<Array2<f64>> for FittedDictionaryLearning {
613 type Output = Array2<f64>;
614 type Error = FerroError;
615
616 fn transform(&self, x: &Array2<f64>) -> Result<Array2<f64>, FerroError> {
623 let n_features = self.components_.ncols();
624 if x.ncols() != n_features {
625 return Err(FerroError::ShapeMismatch {
626 expected: vec![x.nrows(), n_features],
627 actual: vec![x.nrows(), x.ncols()],
628 context: "FittedDictionaryLearning::transform".into(),
629 });
630 }
631
632 let n_samples = x.nrows();
633 let n_components = self.components_.nrows();
634 let mut codes = Array2::<f64>::zeros((n_samples, n_components));
635
636 for i in 0..n_samples {
637 let x_row: Vec<f64> = (0..n_features).map(|j| x[[i, j]]).collect();
638 let a = match self.transform_algorithm_ {
639 DictTransformAlgorithm::Omp => {
640 omp_single(&x_row, &self.components_, self.transform_n_nonzero_coefs_)
641 }
642 DictTransformAlgorithm::LassoCd => {
643 lasso_cd_single(&x_row, &self.components_, self.alpha_, 200)
644 }
645 };
646 for k in 0..n_components {
647 codes[[i, k]] = a[k];
648 }
649 }
650
651 Ok(codes)
652 }
653}
654
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array2;

    /// Deterministic 20x10 matrix with values cycling through `0..11`.
    fn test_data() -> Array2<f64> {
        Array2::<f64>::from_shape_fn((20, 10), |(i, j)| ((i * 7 + j * 3) % 11) as f64)
    }

    #[test]
    fn test_dictlearn_basic_shape() {
        let x = test_data();
        let dl = DictionaryLearning::new(5)
            .with_max_iter(20)
            .with_random_state(42);
        let fitted = dl.fit(&x, &()).unwrap();
        assert_eq!(fitted.components().dim(), (5, 10));
    }

    #[test]
    fn test_dictlearn_transform_shape() {
        let x = test_data();
        let dl = DictionaryLearning::new(5)
            .with_max_iter(20)
            .with_random_state(42);
        let fitted = dl.fit(&x, &()).unwrap();
        let codes = fitted.transform(&x).unwrap();
        assert_eq!(codes.dim(), (20, 5));
    }

    #[test]
    fn test_dictlearn_reconstruction_error_decreases() {
        let x = test_data();
        let dl_few = DictionaryLearning::new(5)
            .with_max_iter(5)
            .with_random_state(42);
        let dl_many = DictionaryLearning::new(5)
            .with_max_iter(50)
            .with_random_state(42);
        let fitted_few = dl_few.fit(&x, &()).unwrap();
        let fitted_many = dl_many.fit(&x, &()).unwrap();
        // Slack of 1.0 tolerates the non-monotone wobble of alternating
        // minimisation while still catching gross regressions.
        assert!(
            fitted_many.reconstruction_err() <= fitted_few.reconstruction_err() + 1.0,
            "more iterations should reduce error: few={}, many={}",
            fitted_few.reconstruction_err(),
            fitted_many.reconstruction_err()
        );
    }

    #[test]
    fn test_dictlearn_dictionary_atoms_normalised() {
        let x = test_data();
        let dl = DictionaryLearning::new(5)
            .with_max_iter(20)
            .with_random_state(42);
        let fitted = dl.fit(&x, &()).unwrap();
        let d = fitted.components();
        for k in 0..d.nrows() {
            let norm: f64 = d.row(k).iter().map(|v| v * v).sum::<f64>().sqrt();
            assert!(
                (norm - 1.0).abs() < 1e-6,
                "atom {k} should be unit norm, got {norm}"
            );
        }
    }

    #[test]
    fn test_dictlearn_sparsity_of_codes() {
        let x = test_data();
        let dl = DictionaryLearning::new(8)
            .with_alpha(2.0)
            .with_max_iter(20)
            .with_random_state(42);
        let fitted = dl.fit(&x, &()).unwrap();
        let codes = fitted.transform(&x).unwrap();
        let total = codes.len();
        let zeros = codes.iter().filter(|&&v| v.abs() < 1e-10).count();
        let sparsity = zeros as f64 / total as f64;
        assert!(
            sparsity > 0.1,
            "codes should have some sparsity, got {:.1}%",
            sparsity * 100.0
        );
    }

    #[test]
    fn test_dictlearn_omp_transform() {
        let x = test_data();
        let dl = DictionaryLearning::new(5)
            .with_max_iter(20)
            .with_transform_algorithm(DictTransformAlgorithm::Omp)
            .with_random_state(42);
        let fitted = dl.fit(&x, &()).unwrap();
        let codes = fitted.transform(&x).unwrap();
        assert_eq!(codes.dim(), (20, 5));
    }

    #[test]
    fn test_dictlearn_lasso_cd_transform() {
        let x = test_data();
        let dl = DictionaryLearning::new(5)
            .with_max_iter(20)
            .with_transform_algorithm(DictTransformAlgorithm::LassoCd)
            .with_random_state(42);
        let fitted = dl.fit(&x, &()).unwrap();
        let codes = fitted.transform(&x).unwrap();
        assert_eq!(codes.dim(), (20, 5));
    }

    #[test]
    fn test_dictlearn_transform_shape_mismatch() {
        let x = test_data();
        let dl = DictionaryLearning::new(5)
            .with_max_iter(10)
            .with_random_state(42);
        let fitted = dl.fit(&x, &()).unwrap();
        let x_bad = Array2::<f64>::zeros((5, 3));
        assert!(fitted.transform(&x_bad).is_err());
    }

    #[test]
    fn test_dictlearn_invalid_n_components_zero() {
        let x = test_data();
        let dl = DictionaryLearning::new(0);
        assert!(dl.fit(&x, &()).is_err());
    }

    #[test]
    fn test_dictlearn_invalid_alpha_negative() {
        let x = test_data();
        let dl = DictionaryLearning::new(5).with_alpha(-1.0);
        assert!(dl.fit(&x, &()).is_err());
    }

    #[test]
    fn test_dictlearn_empty_data() {
        let x = Array2::<f64>::zeros((0, 5));
        let dl = DictionaryLearning::new(2);
        assert!(dl.fit(&x, &()).is_err());
    }

    #[test]
    fn test_dictlearn_zero_features() {
        let x = Array2::<f64>::zeros((10, 0));
        let dl = DictionaryLearning::new(2);
        assert!(dl.fit(&x, &()).is_err());
    }

    #[test]
    fn test_dictlearn_getters() {
        let dl = DictionaryLearning::new(5)
            .with_alpha(0.5)
            .with_max_iter(100)
            .with_tol(1e-6)
            .with_fit_algorithm(DictFitAlgorithm::CoordinateDescent)
            .with_transform_algorithm(DictTransformAlgorithm::LassoCd)
            .with_random_state(99);
        assert_eq!(dl.n_components(), 5);
        assert!((dl.alpha() - 0.5).abs() < 1e-10);
        assert_eq!(dl.max_iter(), 100);
        assert!((dl.tol() - 1e-6).abs() < 1e-12);
        assert_eq!(dl.fit_algorithm(), DictFitAlgorithm::CoordinateDescent);
        assert_eq!(dl.transform_algorithm(), DictTransformAlgorithm::LassoCd);
        assert_eq!(dl.random_state(), Some(99));
    }

    #[test]
    fn test_dictlearn_fitted_accessors() {
        let x = test_data();
        let dl = DictionaryLearning::new(5)
            .with_max_iter(10)
            .with_random_state(42);
        let fitted = dl.fit(&x, &()).unwrap();
        assert!(fitted.n_iter() > 0);
        assert!(fitted.reconstruction_err() >= 0.0);
    }

    #[test]
    fn test_dictlearn_single_component() {
        let x = test_data();
        let dl = DictionaryLearning::new(1)
            .with_max_iter(20)
            .with_random_state(42);
        let fitted = dl.fit(&x, &()).unwrap();
        assert_eq!(fitted.components().nrows(), 1);
        let codes = fitted.transform(&x).unwrap();
        assert_eq!(codes.ncols(), 1);
    }

    #[test]
    fn test_dictlearn_omp_nonzero_coefs() {
        let x = test_data();
        let dl = DictionaryLearning::new(5)
            .with_max_iter(20)
            .with_transform_algorithm(DictTransformAlgorithm::Omp)
            .with_transform_n_nonzero_coefs(2)
            .with_random_state(42);
        let fitted = dl.fit(&x, &()).unwrap();
        let codes = fitted.transform(&x).unwrap();
        // OMP must respect the configured per-sample non-zero cap.
        for i in 0..codes.nrows() {
            let nnz = codes.row(i).iter().filter(|&&v| v.abs() > 1e-10).count();
            assert!(nnz <= 2, "row {i} has {nnz} non-zeros, expected at most 2");
        }
    }

    #[test]
    fn test_soft_threshold() {
        assert!((soft_threshold(5.0, 2.0) - 3.0).abs() < 1e-10);
        assert!((soft_threshold(-5.0, 2.0) - (-3.0)).abs() < 1e-10);
        assert!((soft_threshold(1.0, 2.0)).abs() < 1e-10);
        assert!((soft_threshold(0.0, 2.0)).abs() < 1e-10);
    }
}
874}