1use ferrolearn_core::error::FerroError;
37use ferrolearn_core::introspection::HasClasses;
38use ferrolearn_core::traits::{Fit, Predict};
39use ndarray::{Array1, Array2, ScalarOperand};
40use num_traits::Float;
41
/// Unfitted quadratic discriminant analysis classifier.
///
/// Fitting estimates a separate mean vector and covariance matrix for every
/// class label found in the training targets.
#[derive(Debug, Clone)]
pub struct QDA<F> {
    /// Covariance regularisation strength.
    ///
    /// When greater than zero, each class covariance `S` is replaced during
    /// fitting by `(1 - reg_param) * S + reg_param * I`. Fitting rejects
    /// values below 0 or above 1.
    pub reg_param: F,
}
59
60impl<F: Float> QDA<F> {
61 #[must_use]
65 pub fn new() -> Self {
66 Self {
67 reg_param: F::zero(),
68 }
69 }
70
71 #[must_use]
73 pub fn with_reg_param(mut self, reg_param: F) -> Self {
74 self.reg_param = reg_param;
75 self
76 }
77}
78
79impl<F: Float> Default for QDA<F> {
80 fn default() -> Self {
81 Self::new()
82 }
83}
84
/// Fitted Gaussian parameters for a single class.
#[derive(Debug, Clone)]
struct QDAClass<F> {
    /// Per-feature mean of the training samples belonging to this class.
    mean: Array1<F>,
    /// Inverse of the (possibly regularised) class covariance matrix.
    cov_inv: Array2<F>,
    /// Natural log of the determinant of that covariance matrix.
    log_det: F,
    /// Natural log of the class frequency in the training set (`ln(n_k / n)`).
    log_prior: F,
}
97
/// A fitted QDA model, produced by fitting a [`QDA`] estimator.
#[derive(Debug, Clone)]
pub struct FittedQDA<F> {
    /// One set of Gaussian parameters per class, in the same order as `classes`.
    class_models: Vec<QDAClass<F>>,
    /// Sorted, de-duplicated class labels observed in the training targets.
    classes: Vec<usize>,
    /// Number of feature columns in the training matrix.
    n_features: usize,
}
111
112impl<F: Float> FittedQDA<F> {
113 #[must_use]
115 pub fn means(&self) -> Vec<&Array1<F>> {
116 self.class_models.iter().map(|m| &m.mean).collect()
117 }
118}
119
120impl<F: Float + ndarray::ScalarOperand + Send + Sync + 'static> FittedQDA<F> {
121 pub fn predict_proba(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
133 let n_features = x.ncols();
134 if n_features != self.n_features {
135 return Err(FerroError::ShapeMismatch {
136 expected: vec![self.n_features],
137 actual: vec![n_features],
138 context: "number of features must match fitted model".into(),
139 });
140 }
141 let n_samples = x.nrows();
142 let n_classes = self.classes.len();
143 let half = F::from(0.5).unwrap();
144 let mut proba = Array2::<F>::zeros((n_samples, n_classes));
145 for i in 0..n_samples {
146 let xi = x.row(i);
147 let mut logits = vec![F::neg_infinity(); n_classes];
148 for (c, model) in self.class_models.iter().enumerate() {
149 let diff: Array1<F> = xi.to_owned() - &model.mean;
150 let mahal = diff.dot(&model.cov_inv.dot(&diff));
151 logits[c] = -half * model.log_det - half * mahal + model.log_prior;
152 }
153 let max_l = logits
154 .iter()
155 .copied()
156 .fold(F::neg_infinity(), |a, b| if b > a { b } else { a });
157 let mut sum_exp = F::zero();
158 for c in 0..n_classes {
159 let e = (logits[c] - max_l).exp();
160 proba[[i, c]] = e;
161 sum_exp = sum_exp + e;
162 }
163 for c in 0..n_classes {
164 proba[[i, c]] = proba[[i, c]] / sum_exp;
165 }
166 }
167 Ok(proba)
168 }
169
170 pub fn predict_log_proba(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
176 let proba = self.predict_proba(x)?;
177 Ok(crate::log_proba(&proba))
178 }
179
180 pub fn decision_function(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
190 let n_features = x.ncols();
191 if n_features != self.n_features {
192 return Err(FerroError::ShapeMismatch {
193 expected: vec![self.n_features],
194 actual: vec![n_features],
195 context: "number of features must match fitted model".into(),
196 });
197 }
198 let n_samples = x.nrows();
199 let n_classes = self.classes.len();
200 let half = F::from(0.5).unwrap();
201 let mut out = Array2::<F>::zeros((n_samples, n_classes));
202 for i in 0..n_samples {
203 let xi = x.row(i);
204 for (c, model) in self.class_models.iter().enumerate() {
205 let diff: Array1<F> = xi.to_owned() - &model.mean;
206 let mahal = diff.dot(&model.cov_inv.dot(&diff));
207 out[[i, c]] = -half * model.log_det - half * mahal + model.log_prior;
208 }
209 }
210 Ok(out)
211 }
212}
213
214fn cholesky_inv_and_logdet<F: Float + 'static>(
217 a: &Array2<F>,
218) -> Result<(Array2<F>, F), FerroError> {
219 let n = a.nrows();
220 let mut l = Array2::<F>::zeros((n, n));
221
222 for i in 0..n {
224 for j in 0..=i {
225 let mut s = a[[i, j]];
226 for k in 0..j {
227 s = s - l[[i, k]] * l[[j, k]];
228 }
229 if i == j {
230 if s <= F::zero() {
231 return Err(FerroError::NumericalInstability {
232 message: "covariance matrix is not positive definite".into(),
233 });
234 }
235 l[[i, j]] = s.sqrt();
236 } else {
237 l[[i, j]] = s / l[[j, j]];
238 }
239 }
240 }
241
242 let two = F::from(2.0).unwrap();
244 let log_det = (0..n)
245 .map(|i| l[[i, i]].ln())
246 .fold(F::zero(), |a, b| a + b)
247 * two;
248
249 let mut l_inv = Array2::<F>::zeros((n, n));
251 for col in 0..n {
252 l_inv[[col, col]] = F::one() / l[[col, col]];
253 for i in (col + 1)..n {
254 let mut s = F::zero();
255 for k in col..i {
256 s = s + l[[i, k]] * l_inv[[k, col]];
257 }
258 l_inv[[i, col]] = -s / l[[i, i]];
259 }
260 }
261
262 let a_inv = l_inv.t().dot(&l_inv);
264
265 Ok((a_inv, log_det))
266}
267
268impl<F: Float + Send + Sync + ScalarOperand + 'static> Fit<Array2<F>, Array1<usize>>
269 for QDA<F>
270{
271 type Fitted = FittedQDA<F>;
272 type Error = FerroError;
273
274 fn fit(
284 &self,
285 x: &Array2<F>,
286 y: &Array1<usize>,
287 ) -> Result<FittedQDA<F>, FerroError> {
288 let (n_samples, n_features) = x.dim();
289
290 if n_samples != y.len() {
291 return Err(FerroError::ShapeMismatch {
292 expected: vec![n_samples],
293 actual: vec![y.len()],
294 context: "y length must match number of samples in X".into(),
295 });
296 }
297
298 if self.reg_param < F::zero() || self.reg_param > F::one() {
299 return Err(FerroError::InvalidParameter {
300 name: "reg_param".into(),
301 reason: "must be in [0, 1]".into(),
302 });
303 }
304
305 let mut classes: Vec<usize> = y.to_vec();
306 classes.sort_unstable();
307 classes.dedup();
308
309 if classes.len() < 2 {
310 return Err(FerroError::InsufficientSamples {
311 required: 2,
312 actual: classes.len(),
313 context: "QDA requires at least 2 distinct classes".into(),
314 });
315 }
316
317 let n_f = F::from(n_samples).unwrap();
318 let mut class_models = Vec::with_capacity(classes.len());
319
320 for &cls in &classes {
321 let indices: Vec<usize> = y
323 .iter()
324 .enumerate()
325 .filter(|&(_, label)| *label == cls)
326 .map(|(i, _)| i)
327 .collect();
328
329 let n_k = indices.len();
330 if n_k < 2 {
331 return Err(FerroError::InsufficientSamples {
332 required: 2,
333 actual: n_k,
334 context: format!("class {cls} needs at least 2 samples for QDA"),
335 });
336 }
337
338 let n_k_f = F::from(n_k).unwrap();
339
340 let mut mean = Array1::<F>::zeros(n_features);
342 for &i in &indices {
343 for j in 0..n_features {
344 mean[j] = mean[j] + x[[i, j]];
345 }
346 }
347 mean.mapv_inplace(|v| v / n_k_f);
348
349 let mut cov = Array2::<F>::zeros((n_features, n_features));
351 for &i in &indices {
352 let diff: Array1<F> = x.row(i).to_owned() - &mean;
353 for r in 0..n_features {
354 for c in 0..n_features {
355 cov[[r, c]] = cov[[r, c]] + diff[r] * diff[c];
356 }
357 }
358 }
359 let divisor = F::from(n_k - 1).unwrap();
361 cov.mapv_inplace(|v| v / divisor);
362
363 if self.reg_param > F::zero() {
365 let one_minus = F::one() - self.reg_param;
366 for r in 0..n_features {
367 for c in 0..n_features {
368 cov[[r, c]] = cov[[r, c]] * one_minus;
369 }
370 cov[[r, r]] = cov[[r, r]] + self.reg_param;
371 }
372 }
373
374 let (cov_inv, log_det) = cholesky_inv_and_logdet(&cov)?;
376
377 let log_prior = (n_k_f / n_f).ln();
378
379 class_models.push(QDAClass {
380 mean,
381 cov_inv,
382 log_det,
383 log_prior,
384 });
385 }
386
387 Ok(FittedQDA {
388 class_models,
389 classes,
390 n_features,
391 })
392 }
393}
394
395impl<F: Float + Send + Sync + ScalarOperand + 'static> Predict<Array2<F>>
396 for FittedQDA<F>
397{
398 type Output = Array1<usize>;
399 type Error = FerroError;
400
401 fn predict(&self, x: &Array2<F>) -> Result<Array1<usize>, FerroError> {
410 let n_features = x.ncols();
411 if n_features != self.n_features {
412 return Err(FerroError::ShapeMismatch {
413 expected: vec![self.n_features],
414 actual: vec![n_features],
415 context: "number of features must match fitted model".into(),
416 });
417 }
418
419 let n_samples = x.nrows();
420 let mut predictions = Array1::<usize>::zeros(n_samples);
421 let half = F::from(0.5).unwrap();
422
423 for i in 0..n_samples {
424 let xi = x.row(i);
425 let mut best_class = 0;
426 let mut best_score = F::neg_infinity();
427
428 for (c, model) in self.class_models.iter().enumerate() {
429 let diff: Array1<F> = xi.to_owned() - &model.mean;
430 let mahal = diff.dot(&model.cov_inv.dot(&diff));
432 let score = -half * model.log_det - half * mahal + model.log_prior;
433
434 if score > best_score {
435 best_score = score;
436 best_class = c;
437 }
438 }
439
440 predictions[i] = self.classes[best_class];
441 }
442
443 Ok(predictions)
444 }
445}
446
447impl<F: Float + Send + Sync + ScalarOperand + 'static> HasClasses for FittedQDA<F> {
448 fn classes(&self) -> &[usize] {
449 &self.classes
450 }
451
452 fn n_classes(&self) -> usize {
453 self.classes.len()
454 }
455}
456
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    /// Two well-separated 2-D clusters of four points each, labelled 0 and 1.
    fn two_blobs() -> (Array2<f64>, Array1<usize>) {
        let x = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 1.0, 1.0, 2.0, 2.0, 1.0, 2.0, 2.0, // class 0
                8.0, 8.0, 8.0, 9.0, 9.0, 8.0, 9.0, 9.0, // class 1
            ],
        )
        .unwrap();
        let y = array![0, 0, 0, 0, 1, 1, 1, 1];
        (x, y)
    }

    #[test]
    fn test_default_constructor() {
        let m = QDA::<f64>::new();
        assert!(m.reg_param == 0.0);
        assert!(QDA::<f64>::default().reg_param == 0.0);
    }

    #[test]
    fn test_builder() {
        let m = QDA::<f64>::new().with_reg_param(0.5);
        assert!(m.reg_param == 0.5);
    }

    #[test]
    fn test_binary_classification() {
        let (x, y) = two_blobs();

        let model = QDA::<f64>::new();
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        let correct: usize = preds.iter().zip(y.iter()).filter(|(p, a)| p == a).count();
        assert!(correct >= 6, "expected at least 6 correct, got {correct}");
    }

    #[test]
    fn test_multiclass_classification() {
        let x = Array2::from_shape_vec(
            (12, 2),
            vec![
                0.0, 0.0, 0.5, 0.0, 0.0, 0.5, 0.5, 0.5, // class 0
                10.0, 0.0, 10.5, 0.0, 10.0, 0.5, 10.5, 0.5, // class 1
                0.0, 10.0, 0.5, 10.0, 0.0, 10.5, 0.5, 10.5, // class 2
            ],
        )
        .unwrap();
        let y = array![0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2];

        let model = QDA::<f64>::new();
        let fitted = model.fit(&x, &y).unwrap();

        assert_eq!(fitted.n_classes(), 3);
        assert_eq!(fitted.classes(), &[0, 1, 2]);

        let preds = fitted.predict(&x).unwrap();
        let correct: usize = preds.iter().zip(y.iter()).filter(|(p, a)| p == a).count();
        assert!(correct >= 10, "expected at least 10 correct, got {correct}");
    }

    #[test]
    fn test_regularization() {
        let (x, y) = two_blobs();

        let model = QDA::<f64>::new().with_reg_param(0.5);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 8);
        // The clusters are far apart, so shrinkage must not change the labels.
        assert_eq!(preds, y);
    }

    #[test]
    fn test_predict_proba_rows_sum_to_one() {
        let (x, y) = two_blobs();

        let fitted = QDA::<f64>::new().fit(&x, &y).unwrap();
        let proba = fitted.predict_proba(&x).unwrap();
        assert_eq!(proba.dim(), (8, 2));

        for (row, &label) in proba.outer_iter().zip(y.iter()) {
            let total: f64 = row.sum();
            assert!((total - 1.0).abs() < 1e-12, "row sums to {total}");
            assert!(row[label] > 0.5, "true class should dominate");
        }
    }

    #[test]
    fn test_decision_function_agrees_with_predict() {
        let (x, y) = two_blobs();

        let fitted = QDA::<f64>::new().fit(&x, &y).unwrap();
        let scores = fitted.decision_function(&x).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(scores.dim(), (8, 2));

        // Labels are 0 and 1 here, so they double as column indices.
        for (i, &p) in preds.iter().enumerate() {
            assert!(scores[[i, p]] >= scores[[i, 1 - p]]);
        }

        let x_bad = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        assert!(fitted.decision_function(&x_bad).is_err());
        assert!(fitted.predict_proba(&x_bad).is_err());
    }

    #[test]
    fn test_shape_mismatch() {
        // Three samples but only two labels.
        let x = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        let y = array![0, 1];

        let model = QDA::<f64>::new();
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_single_class_error() {
        let x = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        let y = array![0, 0, 0];

        let model = QDA::<f64>::new();
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_invalid_reg_param() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![0, 0, 1, 1];

        let model = QDA::<f64>::new().with_reg_param(-0.1);
        assert!(model.fit(&x, &y).is_err());

        let model2 = QDA::<f64>::new().with_reg_param(1.5);
        assert!(model2.fit(&x, &y).is_err());
    }

    #[test]
    fn test_predict_feature_mismatch() {
        let (x, y) = two_blobs();

        let fitted = QDA::<f64>::new().fit(&x, &y).unwrap();

        let x_bad = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        assert!(fitted.predict(&x_bad).is_err());
    }

    #[test]
    fn test_has_classes() {
        let (x, y) = two_blobs();

        let fitted = QDA::<f64>::new().fit(&x, &y).unwrap();
        assert_eq!(fitted.classes(), &[0, 1]);
        assert_eq!(fitted.n_classes(), 2);
    }

    #[test]
    fn test_means() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 5.0, 6.0]).unwrap();
        let y = array![0, 0, 1, 1];

        let fitted = QDA::<f64>::new().with_reg_param(0.1).fit(&x, &y).unwrap();
        let means = fitted.means();
        assert_eq!(means.len(), 2);
        // Regularisation only touches the covariance, never the means.
        assert!((means[0][0] - 1.5).abs() < 1e-12);
        assert!((means[1][0] - 5.5).abs() < 1e-12);
    }

    #[test]
    fn test_class_with_too_few_samples() {
        // Class 0 has a single sample, so its covariance cannot be estimated.
        let x = Array2::from_shape_vec((3, 1), vec![1.0, 5.0, 6.0]).unwrap();
        let y = array![0, 1, 1];

        let model = QDA::<f64>::new();
        assert!(model.fit(&x, &y).is_err());
    }
}