// ferrolearn_decomp/dictionary_learning.rs

//! Dictionary Learning.
//!
//! [`DictionaryLearning`] learns a dictionary `D` and sparse codes `A` such
//! that `X ~ A * D`. The dictionary atoms form an overcomplete basis, and
//! the codes are encouraged to be sparse via an L1 penalty.
//!
//! # Algorithm
//!
//! Alternating optimisation:
//!
//! 1. **Sparse coding step**: with `D` fixed, solve for `A` using coordinate
//!    descent (lasso) or orthogonal matching pursuit (OMP).
//! 2. **Dictionary update step**: with `A` fixed, update `D` by solving a
//!    least-squares problem and normalising the atoms.
//!
//! # Examples
//!
//! ```
//! use ferrolearn_decomp::DictionaryLearning;
//! use ferrolearn_core::traits::{Fit, Transform};
//! use ndarray::Array2;
//!
//! let x = Array2::<f64>::from_shape_fn((20, 10), |(i, j)| {
//!     ((i * 7 + j * 3) % 11) as f64
//! });
//! let dl = DictionaryLearning::new(5)
//!     .with_max_iter(50)
//!     .with_random_state(42);
//! let fitted = dl.fit(&x, &()).unwrap();
//! let codes = fitted.transform(&x).unwrap();
//! assert_eq!(codes.dim(), (20, 5));
//! ```

34use ferrolearn_core::error::FerroError;
35use ferrolearn_core::traits::{Fit, Transform};
36use ndarray::Array2;
37use rand::SeedableRng;
38use rand_distr::{Distribution, Normal};
39use rand_xoshiro::Xoshiro256PlusPlus;
40
41// ---------------------------------------------------------------------------
42// Algorithm enums
43// ---------------------------------------------------------------------------
44
/// The algorithm for the sparse coding step during fitting.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DictFitAlgorithm {
    /// Coordinate descent on the lasso objective
    /// `0.5 * ||x - D^T a||^2 + alpha * ||a||_1`.
    CoordinateDescent,
}
51
/// The algorithm for the sparse coding step during transform.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DictTransformAlgorithm {
    /// Orthogonal Matching Pursuit: greedy atom selection with a hard
    /// limit on the number of non-zero coefficients per sample.
    Omp,
    /// Coordinate descent (lasso): soft L1 penalty, no hard sparsity cap.
    LassoCd,
}
60
61// ---------------------------------------------------------------------------
62// DictionaryLearning (unfitted)
63// ---------------------------------------------------------------------------
64
/// Dictionary Learning configuration.
///
/// Holds hyperparameters for the dictionary learning algorithm. Calling
/// [`Fit::fit`] learns a dictionary and returns a [`FittedDictionaryLearning`].
#[derive(Debug, Clone)]
pub struct DictionaryLearning {
    /// Number of dictionary atoms (components) to learn. Validated > 0 at fit.
    n_components: usize,
    /// Sparsity penalty (L1 coefficient). Default 1.0. Validated
    /// non-negative at fit time.
    alpha: f64,
    /// Maximum number of alternating optimisation iterations. Default 1000.
    max_iter: usize,
    /// Convergence tolerance on the change in reconstruction error.
    /// Default 1e-8.
    tol: f64,
    /// Algorithm for fitting. Default coordinate descent (currently the
    /// only variant).
    fit_algorithm: DictFitAlgorithm,
    /// Algorithm for transform. Default OMP.
    transform_algorithm: DictTransformAlgorithm,
    /// Maximum atoms per sample for OMP. Default `n_components` (set at
    /// fit time when `None`).
    transform_n_nonzero_coefs: Option<usize>,
    /// Optional random seed for dictionary initialisation; 0 when unset.
    random_state: Option<u64>,
}
88
89impl DictionaryLearning {
90    /// Create a new `DictionaryLearning` with `n_components` atoms.
91    ///
92    /// Defaults: `alpha=1.0`, `max_iter=1000`, `tol=1e-8`,
93    /// `fit_algorithm=CoordinateDescent`, `transform_algorithm=Omp`.
94    #[must_use]
95    pub fn new(n_components: usize) -> Self {
96        Self {
97            n_components,
98            alpha: 1.0,
99            max_iter: 1000,
100            tol: 1e-8,
101            fit_algorithm: DictFitAlgorithm::CoordinateDescent,
102            transform_algorithm: DictTransformAlgorithm::Omp,
103            transform_n_nonzero_coefs: None,
104            random_state: None,
105        }
106    }
107
108    /// Set the sparsity penalty.
109    #[must_use]
110    pub fn with_alpha(mut self, alpha: f64) -> Self {
111        self.alpha = alpha;
112        self
113    }
114
115    /// Set the maximum number of iterations.
116    #[must_use]
117    pub fn with_max_iter(mut self, n: usize) -> Self {
118        self.max_iter = n;
119        self
120    }
121
122    /// Set the convergence tolerance.
123    #[must_use]
124    pub fn with_tol(mut self, tol: f64) -> Self {
125        self.tol = tol;
126        self
127    }
128
129    /// Set the fit algorithm.
130    #[must_use]
131    pub fn with_fit_algorithm(mut self, algo: DictFitAlgorithm) -> Self {
132        self.fit_algorithm = algo;
133        self
134    }
135
136    /// Set the transform algorithm.
137    #[must_use]
138    pub fn with_transform_algorithm(mut self, algo: DictTransformAlgorithm) -> Self {
139        self.transform_algorithm = algo;
140        self
141    }
142
143    /// Set the maximum number of non-zero coefficients for OMP transform.
144    #[must_use]
145    pub fn with_transform_n_nonzero_coefs(mut self, n: usize) -> Self {
146        self.transform_n_nonzero_coefs = Some(n);
147        self
148    }
149
150    /// Set the random seed.
151    #[must_use]
152    pub fn with_random_state(mut self, seed: u64) -> Self {
153        self.random_state = Some(seed);
154        self
155    }
156
157    /// Return the configured number of components.
158    #[must_use]
159    pub fn n_components(&self) -> usize {
160        self.n_components
161    }
162
163    /// Return the configured alpha.
164    #[must_use]
165    pub fn alpha(&self) -> f64 {
166        self.alpha
167    }
168
169    /// Return the configured maximum iterations.
170    #[must_use]
171    pub fn max_iter(&self) -> usize {
172        self.max_iter
173    }
174
175    /// Return the configured tolerance.
176    #[must_use]
177    pub fn tol(&self) -> f64 {
178        self.tol
179    }
180
181    /// Return the configured fit algorithm.
182    #[must_use]
183    pub fn fit_algorithm(&self) -> DictFitAlgorithm {
184        self.fit_algorithm
185    }
186
187    /// Return the configured transform algorithm.
188    #[must_use]
189    pub fn transform_algorithm(&self) -> DictTransformAlgorithm {
190        self.transform_algorithm
191    }
192
193    /// Return the configured random state, if any.
194    #[must_use]
195    pub fn random_state(&self) -> Option<u64> {
196        self.random_state
197    }
198}
199
200// ---------------------------------------------------------------------------
201// FittedDictionaryLearning
202// ---------------------------------------------------------------------------
203
/// A fitted dictionary learning model.
///
/// Created by calling [`Fit::fit`] on a [`DictionaryLearning`]. The learned
/// dictionary is accessible via [`FittedDictionaryLearning::components`].
/// Implements [`Transform<Array2<f64>>`] to compute sparse codes for new data.
#[derive(Debug, Clone)]
pub struct FittedDictionaryLearning {
    /// Learned dictionary, shape `(n_components, n_features)`.
    /// Each row is a unit-L2-norm dictionary atom.
    components_: Array2<f64>,
    /// Sparsity penalty used during fitting (reused by the lasso transform).
    alpha_: f64,
    /// Number of alternating iterations actually performed.
    n_iter_: usize,
    /// Final reconstruction error `||X - A D||_F` after the last coding pass.
    reconstruction_err_: f64,
    /// Transform algorithm to use (copied from the unfitted config).
    transform_algorithm_: DictTransformAlgorithm,
    /// Max non-zero coefs for OMP; defaults to `n_components` when the
    /// config left it unset.
    transform_n_nonzero_coefs_: usize,
}
225
226impl FittedDictionaryLearning {
227    /// The learned dictionary, shape `(n_components, n_features)`.
228    #[must_use]
229    pub fn components(&self) -> &Array2<f64> {
230        &self.components_
231    }
232
233    /// Number of iterations performed.
234    #[must_use]
235    pub fn n_iter(&self) -> usize {
236        self.n_iter_
237    }
238
239    /// The reconstruction error at convergence.
240    #[must_use]
241    pub fn reconstruction_err(&self) -> f64 {
242        self.reconstruction_err_
243    }
244}
245
246// ---------------------------------------------------------------------------
247// Internal helpers
248// ---------------------------------------------------------------------------
249
250/// Normalise dictionary rows to unit L2 norm.
251fn normalise_dictionary(d: &mut Array2<f64>) {
252    let n_components = d.nrows();
253    let n_features = d.ncols();
254    for k in 0..n_components {
255        let mut norm = 0.0;
256        for j in 0..n_features {
257            norm += d[[k, j]] * d[[k, j]];
258        }
259        let norm = norm.sqrt();
260        if norm > 1e-16 {
261            for j in 0..n_features {
262                d[[k, j]] /= norm;
263            }
264        }
265    }
266}
267
268/// Lasso coordinate descent for a single sample: solve
269///   min_a 0.5 * ||x - D^T a||^2 + alpha * ||a||_1
270/// where x is (n_features,), D is (n_components, n_features), a is (n_components,).
271fn lasso_cd_single(x_row: &[f64], d: &Array2<f64>, alpha: f64, max_iter: usize) -> Vec<f64> {
272    let n_components = d.nrows();
273    let n_features = d.ncols();
274    let mut a = vec![0.0; n_components];
275
276    // Pre-compute D * D^T (Gram matrix) and D * x.
277    let mut gram = vec![vec![0.0; n_components]; n_components];
278    let mut dx = vec![0.0; n_components];
279    for k in 0..n_components {
280        for j in 0..n_features {
281            dx[k] += d[[k, j]] * x_row[j];
282        }
283        for l in k..n_components {
284            let mut val = 0.0;
285            for j in 0..n_features {
286                val += d[[k, j]] * d[[l, j]];
287            }
288            gram[k][l] = val;
289            gram[l][k] = val;
290        }
291    }
292
293    for _iter in 0..max_iter {
294        let mut max_change = 0.0;
295        for k in 0..n_components {
296            // Compute partial residual: dx[k] - sum_{l!=k} gram[k][l] * a[l]
297            let mut rho = dx[k];
298            for l in 0..n_components {
299                if l != k {
300                    rho -= gram[k][l] * a[l];
301                }
302            }
303
304            // Soft threshold.
305            let gram_kk = gram[k][k];
306            let new_a = if gram_kk.abs() < 1e-16 {
307                0.0
308            } else {
309                soft_threshold(rho, alpha) / gram_kk
310            };
311
312            let change = (new_a - a[k]).abs();
313            if change > max_change {
314                max_change = change;
315            }
316            a[k] = new_a;
317        }
318        if max_change < 1e-6 {
319            break;
320        }
321    }
322
323    a
324}
325
326/// Orthogonal Matching Pursuit for a single sample.
327fn omp_single(x_row: &[f64], d: &Array2<f64>, max_nonzero: usize) -> Vec<f64> {
328    let n_components = d.nrows();
329    let n_features = d.ncols();
330    let mut a = vec![0.0; n_components];
331    let mut residual: Vec<f64> = x_row.to_vec();
332    let mut selected: Vec<usize> = Vec::new();
333    let max_k = max_nonzero.min(n_components).min(n_features);
334
335    for _step in 0..max_k {
336        // Find the atom most correlated with the residual.
337        let mut best_idx = 0;
338        let mut best_corr = 0.0;
339        for k in 0..n_components {
340            if selected.contains(&k) {
341                continue;
342            }
343            let mut corr = 0.0;
344            for j in 0..n_features {
345                corr += d[[k, j]] * residual[j];
346            }
347            if corr.abs() > best_corr {
348                best_corr = corr.abs();
349                best_idx = k;
350            }
351        }
352
353        if best_corr < 1e-12 {
354            break;
355        }
356
357        selected.push(best_idx);
358
359        // Solve least squares: x = D_selected^T * a_selected
360        // Use normal equations: (D_s D_s^T) a_s = D_s x
361        let m = selected.len();
362        let mut gram = vec![vec![0.0; m]; m];
363        let mut rhs = vec![0.0; m];
364        for (ii, &ki) in selected.iter().enumerate() {
365            for j in 0..n_features {
366                rhs[ii] += d[[ki, j]] * x_row[j];
367            }
368            for (jj, &kj) in selected.iter().enumerate() {
369                let mut val = 0.0;
370                for f in 0..n_features {
371                    val += d[[ki, f]] * d[[kj, f]];
372                }
373                gram[ii][jj] = val;
374            }
375        }
376
377        // Solve gram * coefs = rhs via Cholesky-like.
378        if let Some(coefs) = solve_symmetric(&gram, &rhs) {
379            // Update residual.
380            residual = x_row.to_vec();
381            for (ii, &ki) in selected.iter().enumerate() {
382                a[ki] = coefs[ii];
383                for j in 0..n_features {
384                    residual[j] -= coefs[ii] * d[[ki, j]];
385                }
386            }
387        } else {
388            break;
389        }
390
391        // Check if residual is small enough.
392        let res_norm: f64 = residual.iter().map(|v| v * v).sum::<f64>().sqrt();
393        if res_norm < 1e-10 {
394            break;
395        }
396    }
397
398    a
399}
400
/// Solve a small symmetric positive definite system `Ax = b` using
/// Gaussian elimination with partial pivoting. Returns `None` if the
/// matrix is (numerically) singular.
#[allow(clippy::needless_range_loop)]
fn solve_symmetric(a: &[Vec<f64>], b: &[f64]) -> Option<Vec<f64>> {
    let n = b.len();
    if n == 0 {
        return Some(Vec::new());
    }

    // Build the augmented matrix [A | b].
    let mut aug: Vec<Vec<f64>> = a
        .iter()
        .take(n)
        .zip(b.iter())
        .map(|(row, &bi)| {
            let mut r = row.clone();
            r.push(bi);
            r
        })
        .collect();

    // Forward elimination with partial pivoting.
    for col in 0..n {
        // Choose the row with the largest magnitude entry in this column.
        let mut pivot_row = col;
        for row in (col + 1)..n {
            if aug[row][col].abs() > aug[pivot_row][col].abs() {
                pivot_row = row;
            }
        }
        // Numerically singular column: bail out.
        if aug[pivot_row][col].abs() < 1e-14 {
            return None;
        }
        aug.swap(col, pivot_row);

        let pivot = aug[col][col];
        for row in (col + 1)..n {
            let factor = aug[row][col] / pivot;
            for j in col..=n {
                let val = aug[col][j];
                aug[row][j] -= factor * val;
            }
        }
    }

    // Back substitution (sequential subtraction preserves FP behaviour).
    let mut x = vec![0.0; n];
    for i in (0..n).rev() {
        let rhs = ((i + 1)..n).fold(aug[i][n], |acc, j| acc - aug[i][j] * x[j]);
        x[i] = rhs / aug[i][i];
    }

    Some(x)
}
456
/// Soft-thresholding operator: `sign(x) * max(|x| - lambda, 0)`.
///
/// Shrinks `x` toward zero by `lambda` and clamps values inside
/// `[-lambda, lambda]` to exactly zero.
fn soft_threshold(x: f64, lambda: f64) -> f64 {
    match x {
        v if v > lambda => v - lambda,
        v if v < -lambda => v + lambda,
        _ => 0.0,
    }
}
467
468/// Compute Frobenius norm of X - A * D.
469fn reconstruction_error(x: &Array2<f64>, a: &Array2<f64>, d: &Array2<f64>) -> f64 {
470    let ad = a.dot(d);
471    let mut err = 0.0;
472    for (xi, adi) in x.iter().zip(ad.iter()) {
473        let diff = xi - adi;
474        err += diff * diff;
475    }
476    err.sqrt()
477}
478
479// ---------------------------------------------------------------------------
480// Trait implementations
481// ---------------------------------------------------------------------------
482
483impl Fit<Array2<f64>, ()> for DictionaryLearning {
484    type Fitted = FittedDictionaryLearning;
485    type Error = FerroError;
486
487    /// Fit the dictionary learning model.
488    ///
489    /// # Errors
490    ///
491    /// - [`FerroError::InvalidParameter`] if `n_components` is zero or
492    ///   `alpha` is negative.
493    /// - [`FerroError::InsufficientSamples`] if there are zero samples or
494    ///   zero features.
495    fn fit(&self, x: &Array2<f64>, _y: &()) -> Result<FittedDictionaryLearning, FerroError> {
496        let (n_samples, n_features) = x.dim();
497
498        // Validate.
499        if self.n_components == 0 {
500            return Err(FerroError::InvalidParameter {
501                name: "n_components".into(),
502                reason: "must be at least 1".into(),
503            });
504        }
505        if n_samples == 0 {
506            return Err(FerroError::InsufficientSamples {
507                required: 1,
508                actual: 0,
509                context: "DictionaryLearning::fit".into(),
510            });
511        }
512        if n_features == 0 {
513            return Err(FerroError::InvalidParameter {
514                name: "X".into(),
515                reason: "must have at least 1 feature".into(),
516            });
517        }
518        if self.alpha < 0.0 {
519            return Err(FerroError::InvalidParameter {
520                name: "alpha".into(),
521                reason: "must be non-negative".into(),
522            });
523        }
524
525        let n_components = self.n_components;
526        let seed = self.random_state.unwrap_or(0);
527        let transform_n_nonzero = self.transform_n_nonzero_coefs.unwrap_or(n_components);
528
529        // Initialise dictionary from random Gaussian, then normalise.
530        let mut rng = Xoshiro256PlusPlus::seed_from_u64(seed);
531        let normal = Normal::new(0.0, 1.0).unwrap();
532        let mut d = Array2::<f64>::zeros((n_components, n_features));
533        for elem in d.iter_mut() {
534            *elem = normal.sample(&mut rng);
535        }
536        normalise_dictionary(&mut d);
537
538        let mut prev_err = f64::MAX;
539        let mut n_iter = 0;
540
541        for iteration in 0..self.max_iter {
542            n_iter = iteration + 1;
543
544            // Sparse coding step: compute codes A.
545            let mut a = Array2::<f64>::zeros((n_samples, n_components));
546            for i in 0..n_samples {
547                let x_row: Vec<f64> = (0..n_features).map(|j| x[[i, j]]).collect();
548                let codes = lasso_cd_single(&x_row, &d, self.alpha, 200);
549                for k in 0..n_components {
550                    a[[i, k]] = codes[k];
551                }
552            }
553
554            // Dictionary update step: D = (A^T A)^{-1} A^T X
555            // We solve the normal equations for each atom.
556            let ata = a.t().dot(&a);
557            let atx = a.t().dot(x);
558
559            // Solve K x K system for each feature column of D.
560            // Build the Gram matrix as Vec<Vec<f64>>.
561            let gram: Vec<Vec<f64>> = (0..n_components)
562                .map(|i| (0..n_components).map(|j| ata[[i, j]]).collect())
563                .collect();
564
565            // Add small regularisation for stability.
566            let mut gram_reg = gram.clone();
567            for (k, row) in gram_reg.iter_mut().enumerate() {
568                row[k] += 1e-10;
569            }
570
571            for j in 0..n_features {
572                let rhs: Vec<f64> = (0..n_components).map(|k| atx[[k, j]]).collect();
573                if let Some(sol) = solve_symmetric(&gram_reg, &rhs) {
574                    for k in 0..n_components {
575                        d[[k, j]] = sol[k];
576                    }
577                }
578            }
579
580            normalise_dictionary(&mut d);
581
582            // Check convergence.
583            let err = reconstruction_error(x, &a, &d);
584            if (prev_err - err).abs() < self.tol {
585                break;
586            }
587            prev_err = err;
588        }
589
590        // Final sparse coding for reconstruction error.
591        let mut a_final = Array2::<f64>::zeros((n_samples, n_components));
592        for i in 0..n_samples {
593            let x_row: Vec<f64> = (0..n_features).map(|j| x[[i, j]]).collect();
594            let codes = lasso_cd_single(&x_row, &d, self.alpha, 200);
595            for k in 0..n_components {
596                a_final[[i, k]] = codes[k];
597            }
598        }
599        let final_err = reconstruction_error(x, &a_final, &d);
600
601        Ok(FittedDictionaryLearning {
602            components_: d,
603            alpha_: self.alpha,
604            n_iter_: n_iter,
605            reconstruction_err_: final_err,
606            transform_algorithm_: self.transform_algorithm,
607            transform_n_nonzero_coefs_: transform_n_nonzero,
608        })
609    }
610}
611
612impl Transform<Array2<f64>> for FittedDictionaryLearning {
613    type Output = Array2<f64>;
614    type Error = FerroError;
615
616    /// Compute sparse codes for new data using the learned dictionary.
617    ///
618    /// # Errors
619    ///
620    /// Returns [`FerroError::ShapeMismatch`] if the number of features does
621    /// not match the dictionary.
622    fn transform(&self, x: &Array2<f64>) -> Result<Array2<f64>, FerroError> {
623        let n_features = self.components_.ncols();
624        if x.ncols() != n_features {
625            return Err(FerroError::ShapeMismatch {
626                expected: vec![x.nrows(), n_features],
627                actual: vec![x.nrows(), x.ncols()],
628                context: "FittedDictionaryLearning::transform".into(),
629            });
630        }
631
632        let n_samples = x.nrows();
633        let n_components = self.components_.nrows();
634        let mut codes = Array2::<f64>::zeros((n_samples, n_components));
635
636        for i in 0..n_samples {
637            let x_row: Vec<f64> = (0..n_features).map(|j| x[[i, j]]).collect();
638            let a = match self.transform_algorithm_ {
639                DictTransformAlgorithm::Omp => {
640                    omp_single(&x_row, &self.components_, self.transform_n_nonzero_coefs_)
641                }
642                DictTransformAlgorithm::LassoCd => {
643                    lasso_cd_single(&x_row, &self.components_, self.alpha_, 200)
644                }
645            };
646            for k in 0..n_components {
647                codes[[i, k]] = a[k];
648            }
649        }
650
651        Ok(codes)
652    }
653}
654
655// ---------------------------------------------------------------------------
656// Tests
657// ---------------------------------------------------------------------------
658
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array2;

    /// Create a simple deterministic 20x10 test dataset.
    fn test_data() -> Array2<f64> {
        Array2::<f64>::from_shape_fn((20, 10), |(i, j)| ((i * 7 + j * 3) % 11) as f64)
    }

    #[test]
    fn test_dictlearn_basic_shape() {
        let x = test_data();
        let dl = DictionaryLearning::new(5)
            .with_max_iter(20)
            .with_random_state(42);
        let fitted = dl.fit(&x, &()).unwrap();
        // Dictionary is (n_components, n_features).
        assert_eq!(fitted.components().dim(), (5, 10));
    }

    #[test]
    fn test_dictlearn_transform_shape() {
        let x = test_data();
        let dl = DictionaryLearning::new(5)
            .with_max_iter(20)
            .with_random_state(42);
        let fitted = dl.fit(&x, &()).unwrap();
        let codes = fitted.transform(&x).unwrap();
        // Codes are (n_samples, n_components).
        assert_eq!(codes.dim(), (20, 5));
    }

    #[test]
    fn test_dictlearn_reconstruction_error_decreases() {
        let x = test_data();
        let dl_few = DictionaryLearning::new(5)
            .with_max_iter(5)
            .with_random_state(42);
        let dl_many = DictionaryLearning::new(5)
            .with_max_iter(50)
            .with_random_state(42);
        let fitted_few = dl_few.fit(&x, &()).unwrap();
        let fitted_many = dl_many.fit(&x, &()).unwrap();
        // Loose bound (+1.0) tolerates small non-monotone wiggles in the
        // alternating optimisation.
        assert!(
            fitted_many.reconstruction_err() <= fitted_few.reconstruction_err() + 1.0,
            "more iterations should reduce error: few={}, many={}",
            fitted_few.reconstruction_err(),
            fitted_many.reconstruction_err()
        );
    }

    #[test]
    fn test_dictlearn_dictionary_atoms_normalised() {
        let x = test_data();
        let dl = DictionaryLearning::new(5)
            .with_max_iter(20)
            .with_random_state(42);
        let fitted = dl.fit(&x, &()).unwrap();
        let d = fitted.components();
        // Every atom (row) must have unit L2 norm after fitting.
        for k in 0..d.nrows() {
            let norm: f64 = d.row(k).iter().map(|v| v * v).sum::<f64>().sqrt();
            assert!(
                (norm - 1.0).abs() < 1e-6,
                "atom {k} should be unit norm, got {norm}"
            );
        }
    }

    #[test]
    fn test_dictlearn_sparsity_of_codes() {
        let x = test_data();
        let dl = DictionaryLearning::new(8)
            .with_alpha(2.0) // Higher alpha = more sparsity.
            .with_max_iter(20)
            .with_random_state(42);
        let fitted = dl.fit(&x, &()).unwrap();
        let codes = fitted.transform(&x).unwrap();
        // Count zero entries.
        let total = codes.len();
        let zeros = codes.iter().filter(|&&v| v.abs() < 1e-10).count();
        let sparsity = zeros as f64 / total as f64;
        assert!(
            sparsity > 0.1,
            "codes should have some sparsity, got {:.1}%",
            sparsity * 100.0
        );
    }

    #[test]
    fn test_dictlearn_omp_transform() {
        let x = test_data();
        let dl = DictionaryLearning::new(5)
            .with_max_iter(20)
            .with_transform_algorithm(DictTransformAlgorithm::Omp)
            .with_random_state(42);
        let fitted = dl.fit(&x, &()).unwrap();
        let codes = fitted.transform(&x).unwrap();
        assert_eq!(codes.dim(), (20, 5));
    }

    #[test]
    fn test_dictlearn_lasso_cd_transform() {
        let x = test_data();
        let dl = DictionaryLearning::new(5)
            .with_max_iter(20)
            .with_transform_algorithm(DictTransformAlgorithm::LassoCd)
            .with_random_state(42);
        let fitted = dl.fit(&x, &()).unwrap();
        let codes = fitted.transform(&x).unwrap();
        assert_eq!(codes.dim(), (20, 5));
    }

    #[test]
    fn test_dictlearn_transform_shape_mismatch() {
        let x = test_data();
        let dl = DictionaryLearning::new(5)
            .with_max_iter(10)
            .with_random_state(42);
        let fitted = dl.fit(&x, &()).unwrap();
        let x_bad = Array2::<f64>::zeros((5, 3)); // wrong number of features
        assert!(fitted.transform(&x_bad).is_err());
    }

    #[test]
    fn test_dictlearn_invalid_n_components_zero() {
        let x = test_data();
        let dl = DictionaryLearning::new(0);
        assert!(dl.fit(&x, &()).is_err());
    }

    #[test]
    fn test_dictlearn_invalid_alpha_negative() {
        let x = test_data();
        let dl = DictionaryLearning::new(5).with_alpha(-1.0);
        assert!(dl.fit(&x, &()).is_err());
    }

    #[test]
    fn test_dictlearn_empty_data() {
        let x = Array2::<f64>::zeros((0, 5));
        let dl = DictionaryLearning::new(2);
        assert!(dl.fit(&x, &()).is_err());
    }

    #[test]
    fn test_dictlearn_zero_features() {
        let x = Array2::<f64>::zeros((10, 0));
        let dl = DictionaryLearning::new(2);
        assert!(dl.fit(&x, &()).is_err());
    }

    #[test]
    fn test_dictlearn_getters() {
        let dl = DictionaryLearning::new(5)
            .with_alpha(0.5)
            .with_max_iter(100)
            .with_tol(1e-6)
            .with_fit_algorithm(DictFitAlgorithm::CoordinateDescent)
            .with_transform_algorithm(DictTransformAlgorithm::LassoCd)
            .with_random_state(99);
        assert_eq!(dl.n_components(), 5);
        assert!((dl.alpha() - 0.5).abs() < 1e-10);
        assert_eq!(dl.max_iter(), 100);
        assert!((dl.tol() - 1e-6).abs() < 1e-12);
        assert_eq!(dl.fit_algorithm(), DictFitAlgorithm::CoordinateDescent);
        assert_eq!(dl.transform_algorithm(), DictTransformAlgorithm::LassoCd);
        assert_eq!(dl.random_state(), Some(99));
    }

    #[test]
    fn test_dictlearn_fitted_accessors() {
        let x = test_data();
        let dl = DictionaryLearning::new(5)
            .with_max_iter(10)
            .with_random_state(42);
        let fitted = dl.fit(&x, &()).unwrap();
        assert!(fitted.n_iter() > 0);
        assert!(fitted.reconstruction_err() >= 0.0);
    }

    #[test]
    fn test_dictlearn_single_component() {
        let x = test_data();
        let dl = DictionaryLearning::new(1)
            .with_max_iter(20)
            .with_random_state(42);
        let fitted = dl.fit(&x, &()).unwrap();
        assert_eq!(fitted.components().nrows(), 1);
        let codes = fitted.transform(&x).unwrap();
        assert_eq!(codes.ncols(), 1);
    }

    #[test]
    fn test_dictlearn_omp_nonzero_coefs() {
        let x = test_data();
        let dl = DictionaryLearning::new(5)
            .with_max_iter(20)
            .with_transform_algorithm(DictTransformAlgorithm::Omp)
            .with_transform_n_nonzero_coefs(2)
            .with_random_state(42);
        let fitted = dl.fit(&x, &()).unwrap();
        let codes = fitted.transform(&x).unwrap();
        // Each row should have at most 2 non-zero entries.
        for i in 0..codes.nrows() {
            let nnz = codes.row(i).iter().filter(|&&v| v.abs() > 1e-10).count();
            assert!(nnz <= 2, "row {i} has {nnz} non-zeros, expected at most 2");
        }
    }

    #[test]
    fn test_soft_threshold() {
        assert!((soft_threshold(5.0, 2.0) - 3.0).abs() < 1e-10);
        assert!((soft_threshold(-5.0, 2.0) - (-3.0)).abs() < 1e-10);
        assert!((soft_threshold(1.0, 2.0)).abs() < 1e-10);
        assert!((soft_threshold(0.0, 2.0)).abs() < 1e-10);
    }
}