// ferrolearn_linear/qda.rs
//! Quadratic Discriminant Analysis (QDA).
//!
//! This module provides [`QDA`], a classifier that models each class with its
//! own covariance matrix, yielding quadratic decision boundaries. Unlike
//! [`LDA`](crate::lda::LDA), which assumes a shared covariance matrix, QDA
//! fits a separate covariance per class.
//!
//! # Algorithm
//!
//! For each class `k`:
//! 1. Compute the class mean `mu_k` and covariance `Sigma_k`.
//! 2. Optionally regularize: `Sigma_k = (1 - reg) * Sigma_k + reg * I`.
//! 3. Compute the discriminant score (the log-posterior up to an additive
//!    constant shared by all classes):
//!    `delta_k(x) = -0.5 * log|Sigma_k| - 0.5 * (x - mu_k)^T Sigma_k^{-1} (x - mu_k) + log(prior_k)`.
//! 4. Predict the class with the largest `delta_k`.
//!
//! # Examples
//!
//! ```
//! use ferrolearn_linear::qda::QDA;
//! use ferrolearn_core::{Fit, Predict};
//! use ndarray::{array, Array2};
//!
//! let x = Array2::from_shape_vec(
//!     (6, 2),
//!     vec![1.0, 1.0, 1.5, 1.2, 1.2, 0.8, 5.0, 5.0, 5.5, 4.8, 4.8, 5.2],
//! ).unwrap();
//! let y = array![0usize, 0, 0, 1, 1, 1];
//!
//! let qda = QDA::<f64>::new();
//! let fitted = qda.fit(&x, &y).unwrap();
//! let preds = fitted.predict(&x).unwrap();
//! assert_eq!(preds.len(), 6);
//! ```
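//!
//! Regularization blends each class covariance toward the identity, which
//! keeps the fit well-conditioned when a class has few samples relative to
//! the number of features. A minimal sketch on the same data:
//!
//! ```
//! # use ferrolearn_linear::qda::QDA;
//! # use ferrolearn_core::Fit;
//! # use ndarray::{array, Array2};
//! # let x = Array2::from_shape_vec(
//! #     (6, 2),
//! #     vec![1.0, 1.0, 1.5, 1.2, 1.2, 0.8, 5.0, 5.0, 5.5, 4.8, 4.8, 5.2],
//! # ).unwrap();
//! # let y = array![0usize, 0, 0, 1, 1, 1];
//! let qda = QDA::<f64>::new().with_reg_param(0.1);
//! let fitted = qda.fit(&x, &y).unwrap();
//! let proba = fitted.predict_proba(&x).unwrap();
//! assert_eq!(proba.dim(), (6, 2));
//! // Each row is a probability distribution over the classes.
//! for row in proba.outer_iter() {
//!     assert!((row.sum() - 1.0).abs() < 1e-9);
//! }
//! ```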

use ferrolearn_core::error::FerroError;
use ferrolearn_core::introspection::HasClasses;
use ferrolearn_core::traits::{Fit, Predict};
use ndarray::{Array1, Array2, ScalarOperand};
use num_traits::Float;

/// Quadratic Discriminant Analysis configuration.
///
/// Holds hyperparameters. Calling [`Fit::fit`] computes per-class means
/// and covariance matrices and returns a [`FittedQDA`].
///
/// # Type Parameters
///
/// - `F`: The floating-point scalar type (`f32` or `f64`).
#[derive(Debug, Clone)]
pub struct QDA<F> {
    /// Regularization parameter for covariance matrices.
    ///
    /// Blends each class covariance toward the identity:
    /// `Sigma_k = (1 - reg) * Sigma_k + reg * I`.
    /// Must be in `[0, 1]`. Default: `0.0`.
    pub reg_param: F,
}

impl<F: Float> QDA<F> {
    /// Create a new `QDA` with default settings.
    ///
    /// Default: `reg_param = 0.0`.
    #[must_use]
    pub fn new() -> Self {
        Self {
            reg_param: F::zero(),
        }
    }

    /// Set the regularization parameter.
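    ///
    /// A minimal usage sketch (values outside `[0, 1]` are rejected at fit
    /// time, not here):
    ///
    /// ```
    /// use ferrolearn_linear::qda::QDA;
    ///
    /// let qda = QDA::<f64>::new().with_reg_param(0.1);
    /// assert_eq!(qda.reg_param, 0.1);
    /// ```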
    #[must_use]
    pub fn with_reg_param(mut self, reg_param: F) -> Self {
        self.reg_param = reg_param;
        self
    }
}

impl<F: Float> Default for QDA<F> {
    fn default() -> Self {
        Self::new()
    }
}

/// Per-class model component for QDA.
#[derive(Debug, Clone)]
struct QDAClass<F> {
    /// Class mean, shape `(n_features,)`.
    mean: Array1<F>,
    /// Inverse of the (regularized) covariance matrix, shape `(n_features, n_features)`.
    cov_inv: Array2<F>,
    /// Log-determinant of the covariance matrix.
    log_det: F,
    /// Log-prior probability for this class.
    log_prior: F,
}

/// Fitted Quadratic Discriminant Analysis model.
///
/// Stores per-class means, covariance inverses, and priors. Implements
/// [`Predict`] to produce class labels.
#[derive(Debug, Clone)]
pub struct FittedQDA<F> {
    /// Per-class QDA models.
    class_models: Vec<QDAClass<F>>,
    /// Sorted unique class labels.
    classes: Vec<usize>,
    /// Number of features seen during fitting.
    n_features: usize,
}

impl<F: Float> FittedQDA<F> {
    /// Returns the class means, one per class, in the order of
    /// [`HasClasses::classes`].
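    ///
    /// A minimal sketch on 1-D data:
    ///
    /// ```
    /// use ferrolearn_linear::qda::QDA;
    /// use ferrolearn_core::Fit;
    /// use ndarray::{array, Array2};
    ///
    /// let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 5.0, 6.0]).unwrap();
    /// let y = array![0usize, 0, 1, 1];
    /// let fitted = QDA::<f64>::new().fit(&x, &y).unwrap();
    /// let means = fitted.means();
    /// assert!((means[0][0] - 1.5).abs() < 1e-12);
    /// assert!((means[1][0] - 5.5).abs() < 1e-12);
    /// ```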
    #[must_use]
    pub fn means(&self) -> Vec<&Array1<F>> {
        self.class_models.iter().map(|m| &m.mean).collect()
    }
}

impl<F: Float + ScalarOperand + Send + Sync + 'static> FittedQDA<F> {
    /// Predict per-class probabilities. Mirrors sklearn's
    /// `QuadraticDiscriminantAnalysis.predict_proba`.
    ///
    /// Computes a softmax over the per-class quadratic discriminants
    /// `δ_c(x) = -½ log|Σ_c| - ½ (x-μ_c)ᵀ Σ_c⁻¹ (x-μ_c) + log π_c`.
    /// Returns shape `(n_samples, n_classes)`; rows sum to 1.
    ///
    /// # Errors
    ///
    /// Returns [`FerroError::ShapeMismatch`] if the number of features
    /// does not match the fitted model.
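    ///
    /// # Examples
    ///
    /// A minimal sketch on well-separated 1-D data:
    ///
    /// ```
    /// use ferrolearn_linear::qda::QDA;
    /// use ferrolearn_core::Fit;
    /// use ndarray::{array, Array2};
    ///
    /// let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 9.0, 10.0]).unwrap();
    /// let y = array![0usize, 0, 1, 1];
    /// let fitted = QDA::<f64>::new().fit(&x, &y).unwrap();
    /// let proba = fitted.predict_proba(&x).unwrap();
    /// // Rows are normalized, and the first sample strongly favors class 0.
    /// assert!((proba.row(0).sum() - 1.0).abs() < 1e-12);
    /// assert!(proba[[0, 0]] > 0.99);
    /// ```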
    pub fn predict_proba(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
        let n_features = x.ncols();
        if n_features != self.n_features {
            return Err(FerroError::ShapeMismatch {
                expected: vec![self.n_features],
                actual: vec![n_features],
                context: "number of features must match fitted model".into(),
            });
        }
        let n_samples = x.nrows();
        let n_classes = self.classes.len();
        let half = F::from(0.5).unwrap();
        let mut proba = Array2::<F>::zeros((n_samples, n_classes));
        for i in 0..n_samples {
            let xi = x.row(i);
            let mut logits = vec![F::neg_infinity(); n_classes];
            for (c, model) in self.class_models.iter().enumerate() {
                let diff: Array1<F> = xi.to_owned() - &model.mean;
                let mahal = diff.dot(&model.cov_inv.dot(&diff));
                logits[c] = -half * model.log_det - half * mahal + model.log_prior;
            }
            // Numerically stable softmax: subtract the row maximum before
            // exponentiating.
            let max_l = logits
                .iter()
                .copied()
                .fold(F::neg_infinity(), |a, b| if b > a { b } else { a });
            let mut sum_exp = F::zero();
            for c in 0..n_classes {
                let e = (logits[c] - max_l).exp();
                proba[[i, c]] = e;
                sum_exp = sum_exp + e;
            }
            for c in 0..n_classes {
                proba[[i, c]] = proba[[i, c]] / sum_exp;
            }
        }
        Ok(proba)
    }

    /// Element-wise log of [`predict_proba`](Self::predict_proba).
    ///
    /// # Errors
    ///
    /// Forwards any error from [`predict_proba`](Self::predict_proba).
    pub fn predict_log_proba(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
        let proba = self.predict_proba(x)?;
        Ok(crate::log_proba(&proba))
    }

    /// Per-class quadratic discriminant scores. Mirrors sklearn's
    /// `QuadraticDiscriminantAnalysis.decision_function`. Returns shape
    /// `(n_samples, n_classes)` with `δ_c(x) = -½ log|Σ_c| - ½ (x-μ_c)ᵀ
    /// Σ_c⁻¹ (x-μ_c) + log π_c`. The argmax of each row agrees with
    /// [`Predict`].
    ///
    /// # Errors
    ///
    /// Returns [`FerroError::ShapeMismatch`] if the number of features
    /// does not match the fitted model.
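    ///
    /// # Examples
    ///
    /// A quick consistency sketch: the row-wise argmax reproduces the labels
    /// from [`Predict::predict`].
    ///
    /// ```
    /// use ferrolearn_linear::qda::QDA;
    /// use ferrolearn_core::{Fit, Predict};
    /// use ndarray::{array, Array2};
    ///
    /// let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 9.0, 10.0]).unwrap();
    /// let y = array![0usize, 0, 1, 1];
    /// let fitted = QDA::<f64>::new().fit(&x, &y).unwrap();
    /// let scores = fitted.decision_function(&x).unwrap();
    /// let preds = fitted.predict(&x).unwrap();
    /// for (row, &p) in scores.outer_iter().zip(preds.iter()) {
    ///     let argmax = if row[0] >= row[1] { 0 } else { 1 };
    ///     assert_eq!(argmax, p);
    /// }
    /// ```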
    pub fn decision_function(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
        let n_features = x.ncols();
        if n_features != self.n_features {
            return Err(FerroError::ShapeMismatch {
                expected: vec![self.n_features],
                actual: vec![n_features],
                context: "number of features must match fitted model".into(),
            });
        }
        let n_samples = x.nrows();
        let n_classes = self.classes.len();
        let half = F::from(0.5).unwrap();
        let mut out = Array2::<F>::zeros((n_samples, n_classes));
        for i in 0..n_samples {
            let xi = x.row(i);
            for (c, model) in self.class_models.iter().enumerate() {
                let diff: Array1<F> = xi.to_owned() - &model.mean;
                let mahal = diff.dot(&model.cov_inv.dot(&diff));
                out[[i, c]] = -half * model.log_det - half * mahal + model.log_prior;
            }
        }
        Ok(out)
    }
}

/// Compute the inverse and log-determinant of a symmetric positive-definite
/// matrix via Cholesky decomposition.
fn cholesky_inv_and_logdet<F: Float + 'static>(
    a: &Array2<F>,
) -> Result<(Array2<F>, F), FerroError> {
    let n = a.nrows();
    let mut l = Array2::<F>::zeros((n, n));

    // Cholesky decomposition: A = L L^T.
    for i in 0..n {
        for j in 0..=i {
            let mut s = a[[i, j]];
            for k in 0..j {
                s = s - l[[i, k]] * l[[j, k]];
            }
            if i == j {
                if s <= F::zero() {
                    return Err(FerroError::NumericalInstability {
                        message: "covariance matrix is not positive definite".into(),
                    });
                }
                l[[i, j]] = s.sqrt();
            } else {
                l[[i, j]] = s / l[[j, j]];
            }
        }
    }

    // Log-determinant: log|A| = 2 * sum(log(diag(L))).
    let two = F::from(2.0).unwrap();
    let log_det = (0..n)
        .map(|i| l[[i, i]].ln())
        .fold(F::zero(), |a, b| a + b)
        * two;

    // Compute L^{-1} by forward substitution on each column of I.
    let mut l_inv = Array2::<F>::zeros((n, n));
    for col in 0..n {
        l_inv[[col, col]] = F::one() / l[[col, col]];
        for i in (col + 1)..n {
            let mut s = F::zero();
            for k in col..i {
                s = s + l[[i, k]] * l_inv[[k, col]];
            }
            l_inv[[i, col]] = -s / l[[i, i]];
        }
    }

    // A^{-1} = L^{-T} L^{-1}.
    let a_inv = l_inv.t().dot(&l_inv);

    Ok((a_inv, log_det))
}

impl<F: Float + Send + Sync + ScalarOperand + 'static> Fit<Array2<F>, Array1<usize>>
    for QDA<F>
{
    type Fitted = FittedQDA<F>;
    type Error = FerroError;

    /// Fit the QDA model by computing per-class means and covariances.
    ///
    /// # Errors
    ///
    /// - [`FerroError::ShapeMismatch`] — sample count mismatch.
    /// - [`FerroError::InsufficientSamples`] — fewer than 2 distinct classes
    ///   or a class has too few samples.
    /// - [`FerroError::InvalidParameter`] — `reg_param` not in `[0, 1]`.
    /// - [`FerroError::NumericalInstability`] — a class covariance matrix is
    ///   not positive definite.
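    ///
    /// # Examples
    ///
    /// A shape mismatch is reported before any computation:
    ///
    /// ```
    /// use ferrolearn_linear::qda::QDA;
    /// use ferrolearn_core::Fit;
    /// use ndarray::{array, Array2};
    ///
    /// let x = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
    /// let y = array![0usize, 1]; // wrong length
    /// assert!(QDA::<f64>::new().fit(&x, &y).is_err());
    /// ```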
    fn fit(
        &self,
        x: &Array2<F>,
        y: &Array1<usize>,
    ) -> Result<FittedQDA<F>, FerroError> {
        let (n_samples, n_features) = x.dim();

        if n_samples != y.len() {
            return Err(FerroError::ShapeMismatch {
                expected: vec![n_samples],
                actual: vec![y.len()],
                context: "y length must match number of samples in X".into(),
            });
        }

        if self.reg_param < F::zero() || self.reg_param > F::one() {
            return Err(FerroError::InvalidParameter {
                name: "reg_param".into(),
                reason: "must be in [0, 1]".into(),
            });
        }

        let mut classes: Vec<usize> = y.to_vec();
        classes.sort_unstable();
        classes.dedup();

        if classes.len() < 2 {
            return Err(FerroError::InsufficientSamples {
                required: 2,
                actual: classes.len(),
                context: "QDA requires at least 2 distinct classes".into(),
            });
        }

        let n_f = F::from(n_samples).unwrap();
        let mut class_models = Vec::with_capacity(classes.len());

        for &cls in &classes {
            // Extract samples for this class.
            let indices: Vec<usize> = y
                .iter()
                .enumerate()
                .filter(|&(_, label)| *label == cls)
                .map(|(i, _)| i)
                .collect();

            let n_k = indices.len();
            if n_k < 2 {
                return Err(FerroError::InsufficientSamples {
                    required: 2,
                    actual: n_k,
                    context: format!("class {cls} needs at least 2 samples for QDA"),
                });
            }

            let n_k_f = F::from(n_k).unwrap();

            // Compute class mean.
            let mut mean = Array1::<F>::zeros(n_features);
            for &i in &indices {
                for j in 0..n_features {
                    mean[j] = mean[j] + x[[i, j]];
                }
            }
            mean.mapv_inplace(|v| v / n_k_f);

            // Compute class covariance.
            let mut cov = Array2::<F>::zeros((n_features, n_features));
            for &i in &indices {
                let diff: Array1<F> = x.row(i).to_owned() - &mean;
                for r in 0..n_features {
                    for c in 0..n_features {
                        cov[[r, c]] = cov[[r, c]] + diff[r] * diff[c];
                    }
                }
            }
            // Use the unbiased estimator: divide by (n_k - 1).
            let divisor = F::from(n_k - 1).unwrap();
            cov.mapv_inplace(|v| v / divisor);

            // Regularize: Sigma_k = (1 - reg) * Sigma_k + reg * I.
            if self.reg_param > F::zero() {
                let one_minus = F::one() - self.reg_param;
                for r in 0..n_features {
                    for c in 0..n_features {
                        cov[[r, c]] = cov[[r, c]] * one_minus;
                    }
                    cov[[r, r]] = cov[[r, r]] + self.reg_param;
                }
            }

            // Compute inverse and log-determinant.
            let (cov_inv, log_det) = cholesky_inv_and_logdet(&cov)?;

            // Empirical prior: the class frequency in the training data.
            let log_prior = (n_k_f / n_f).ln();

            class_models.push(QDAClass {
                mean,
                cov_inv,
                log_det,
                log_prior,
            });
        }

        Ok(FittedQDA {
            class_models,
            classes,
            n_features,
        })
    }
}

impl<F: Float + Send + Sync + ScalarOperand + 'static> Predict<Array2<F>>
    for FittedQDA<F>
{
    type Output = Array1<usize>;
    type Error = FerroError;

    /// Predict class labels for the given feature matrix.
    ///
    /// Selects the class with the largest discriminant score for each sample.
    ///
    /// # Errors
    ///
    /// Returns [`FerroError::ShapeMismatch`] if the number of features
    /// does not match the fitted model.
    fn predict(&self, x: &Array2<F>) -> Result<Array1<usize>, FerroError> {
        let n_features = x.ncols();
        if n_features != self.n_features {
            return Err(FerroError::ShapeMismatch {
                expected: vec![self.n_features],
                actual: vec![n_features],
                context: "number of features must match fitted model".into(),
            });
        }

        let n_samples = x.nrows();
        let mut predictions = Array1::<usize>::zeros(n_samples);
        let half = F::from(0.5).unwrap();

        for i in 0..n_samples {
            let xi = x.row(i);
            let mut best_class = 0;
            let mut best_score = F::neg_infinity();

            for (c, model) in self.class_models.iter().enumerate() {
                let diff: Array1<F> = xi.to_owned() - &model.mean;
                // Mahalanobis term: diff^T * cov_inv * diff.
                let mahal = diff.dot(&model.cov_inv.dot(&diff));
                let score = -half * model.log_det - half * mahal + model.log_prior;

                if score > best_score {
                    best_score = score;
                    best_class = c;
                }
            }

            // Map the internal class index back to the original label.
            predictions[i] = self.classes[best_class];
        }

        Ok(predictions)
    }
}

impl<F: Float + Send + Sync + ScalarOperand + 'static> HasClasses for FittedQDA<F> {
    fn classes(&self) -> &[usize] {
        &self.classes
    }

    fn n_classes(&self) -> usize {
        self.classes.len()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    #[test]
    fn test_default_constructor() {
        let m = QDA::<f64>::new();
        assert_eq!(m.reg_param, 0.0);
    }

    #[test]
    fn test_builder() {
        let m = QDA::<f64>::new().with_reg_param(0.5);
        assert_eq!(m.reg_param, 0.5);
    }

    #[test]
    fn test_binary_classification() {
        let x = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 1.0, 1.0, 2.0, 2.0, 1.0, 2.0, 2.0,
                8.0, 8.0, 8.0, 9.0, 9.0, 8.0, 9.0, 9.0,
            ],
        )
        .unwrap();
        let y = array![0, 0, 0, 0, 1, 1, 1, 1];

        let model = QDA::<f64>::new();
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        let correct: usize = preds.iter().zip(y.iter()).filter(|(p, a)| p == a).count();
        assert!(correct >= 6, "expected at least 6 correct, got {correct}");
    }

    #[test]
    fn test_multiclass_classification() {
        let x = Array2::from_shape_vec(
            (12, 2),
            vec![
                0.0, 0.0, 0.5, 0.0, 0.0, 0.5, 0.5, 0.5,
                10.0, 0.0, 10.5, 0.0, 10.0, 0.5, 10.5, 0.5,
                0.0, 10.0, 0.5, 10.0, 0.0, 10.5, 0.5, 10.5,
            ],
        )
        .unwrap();
        let y = array![0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2];

        let model = QDA::<f64>::new();
        let fitted = model.fit(&x, &y).unwrap();

        assert_eq!(fitted.n_classes(), 3);
        assert_eq!(fitted.classes(), &[0, 1, 2]);

        let preds = fitted.predict(&x).unwrap();
        let correct: usize = preds.iter().zip(y.iter()).filter(|(p, a)| p == a).count();
        assert!(correct >= 10, "expected at least 10 correct, got {correct}");
    }

    #[test]
    fn test_regularization() {
        let x = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 1.0, 1.0, 2.0, 2.0, 1.0, 2.0, 2.0,
                8.0, 8.0, 8.0, 9.0, 9.0, 8.0, 9.0, 9.0,
            ],
        )
        .unwrap();
        let y = array![0, 0, 0, 0, 1, 1, 1, 1];

        // Fitting with regularization should still succeed.
        let model = QDA::<f64>::new().with_reg_param(0.5);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 8);
    }

    #[test]
    fn test_shape_mismatch() {
        let x = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        let y = array![0, 1]; // Wrong length.

        let model = QDA::<f64>::new();
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_single_class_error() {
        let x = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        let y = array![0, 0, 0];

        let model = QDA::<f64>::new();
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_invalid_reg_param() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![0, 0, 1, 1];

        let model = QDA::<f64>::new().with_reg_param(-0.1);
        assert!(model.fit(&x, &y).is_err());

        let model2 = QDA::<f64>::new().with_reg_param(1.5);
        assert!(model2.fit(&x, &y).is_err());
    }

    #[test]
    fn test_predict_feature_mismatch() {
        let x = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 1.0, 1.0, 2.0, 2.0, 1.0, 2.0, 2.0,
                8.0, 8.0, 8.0, 9.0, 9.0, 8.0, 9.0, 9.0,
            ],
        )
        .unwrap();
        let y = array![0, 0, 0, 0, 1, 1, 1, 1];

        let fitted = QDA::<f64>::new().fit(&x, &y).unwrap();

        let x_bad = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        assert!(fitted.predict(&x_bad).is_err());
    }

    #[test]
    fn test_has_classes() {
        let x = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 1.0, 1.0, 2.0, 2.0, 1.0, 2.0, 2.0,
                8.0, 8.0, 8.0, 9.0, 9.0, 8.0, 9.0, 9.0,
            ],
        )
        .unwrap();
        let y = array![0, 0, 0, 0, 1, 1, 1, 1];

        let fitted = QDA::<f64>::new().fit(&x, &y).unwrap();
        assert_eq!(fitted.classes(), &[0, 1]);
        assert_eq!(fitted.n_classes(), 2);
    }

    #[test]
    fn test_means() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 5.0, 6.0]).unwrap();
        let y = array![0, 0, 1, 1];

        let fitted = QDA::<f64>::new().with_reg_param(0.1).fit(&x, &y).unwrap();
        let means = fitted.means();
        assert_eq!(means.len(), 2);
    }

    #[test]
    fn test_class_with_too_few_samples() {
        let x = Array2::from_shape_vec((3, 1), vec![1.0, 5.0, 6.0]).unwrap();
        let y = array![0, 1, 1]; // Class 0 has only 1 sample.

        let model = QDA::<f64>::new();
        assert!(model.fit(&x, &y).is_err());
    }
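
    // Sanity check for the private Cholesky helper on a hand-computed
    // 2x2 SPD matrix: |A| = 3, so log|A| = ln(3), and
    // A^{-1} = (1/3) * [[2, -1], [-1, 2]].
    #[test]
    fn test_cholesky_inv_and_logdet() {
        let a = array![[2.0_f64, 1.0], [1.0, 2.0]];
        let (a_inv, log_det) = cholesky_inv_and_logdet(&a).unwrap();
        assert!((log_det - 3.0_f64.ln()).abs() < 1e-12);
        let expected = array![[2.0 / 3.0, -1.0 / 3.0], [-1.0 / 3.0, 2.0 / 3.0]];
        for r in 0..2 {
            for c in 0..2 {
                assert!((a_inv[[r, c]] - expected[[r, c]]).abs() < 1e-12);
            }
        }

        // A symmetric but indefinite matrix must be rejected.
        let bad = array![[1.0_f64, 2.0], [2.0, 1.0]];
        assert!(cholesky_inv_and_logdet(&bad).is_err());
    }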
}