use ferrolearn_core::error::FerroError;
use ferrolearn_core::introspection::HasClasses;
use ferrolearn_core::traits::{Fit, Predict};
use ndarray::{Array1, Array2, ScalarOperand};
use num_traits::Float;
/// Quadratic Discriminant Analysis (QDA) estimator configuration.
///
/// Each class is modeled with its own Gaussian (mean and covariance),
/// which yields quadratic decision boundaries. Call `fit` to obtain a
/// [`FittedQDA`].
#[derive(Debug, Clone)]
pub struct QDA<F> {
    /// Covariance shrinkage strength in `[0, 1]`: each class covariance is
    /// blended as `(1 - reg_param) * cov + reg_param * I` during `fit`
    /// (the range is validated there).
    pub reg_param: F,
}
impl<F: Float> QDA<F> {
    /// Creates a QDA estimator with no regularization (`reg_param = 0`).
    #[must_use]
    pub fn new() -> Self {
        Self { reg_param: F::zero() }
    }

    /// Builder-style setter for the covariance shrinkage strength.
    ///
    /// The value must lie in `[0, 1]`; it is validated when `fit` runs.
    #[must_use]
    pub fn with_reg_param(self, reg_param: F) -> Self {
        Self { reg_param }
    }
}
impl<F: Float> Default for QDA<F> {
fn default() -> Self {
Self::new()
}
}
/// Per-class sufficient statistics precomputed at fit time.
#[derive(Debug, Clone)]
struct QDAClass<F> {
    /// Class mean vector (length `n_features`).
    mean: Array1<F>,
    /// Inverse of the (optionally regularized) class covariance matrix.
    cov_inv: Array2<F>,
    /// Log-determinant of the class covariance matrix.
    log_det: F,
    /// Natural log of the empirical class prior `n_k / n_samples`.
    log_prior: F,
}
/// A fitted QDA model produced by fitting a [`QDA`] estimator.
#[derive(Debug, Clone)]
pub struct FittedQDA<F> {
    /// One model per class, ordered to match `classes`.
    class_models: Vec<QDAClass<F>>,
    /// Sorted, deduplicated class labels observed during fitting.
    classes: Vec<usize>,
    /// Number of features the model was fitted with.
    n_features: usize,
}
impl<F: Float> FittedQDA<F> {
    /// Borrowed per-class mean vectors, in the same order as `classes`.
    #[must_use]
    pub fn means(&self) -> Vec<&Array1<F>> {
        let mut means = Vec::with_capacity(self.class_models.len());
        for model in &self.class_models {
            means.push(&model.mean);
        }
        means
    }
}
impl<F: Float + ndarray::ScalarOperand + Send + Sync + 'static> FittedQDA<F> {
pub fn predict_proba(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
let n_features = x.ncols();
if n_features != self.n_features {
return Err(FerroError::ShapeMismatch {
expected: vec![self.n_features],
actual: vec![n_features],
context: "number of features must match fitted model".into(),
});
}
let n_samples = x.nrows();
let n_classes = self.classes.len();
let half = F::from(0.5).unwrap();
let mut proba = Array2::<F>::zeros((n_samples, n_classes));
for i in 0..n_samples {
let xi = x.row(i);
let mut logits = vec![F::neg_infinity(); n_classes];
for (c, model) in self.class_models.iter().enumerate() {
let diff: Array1<F> = xi.to_owned() - &model.mean;
let mahal = diff.dot(&model.cov_inv.dot(&diff));
logits[c] = -half * model.log_det - half * mahal + model.log_prior;
}
let max_l = logits
.iter()
.copied()
.fold(F::neg_infinity(), |a, b| if b > a { b } else { a });
let mut sum_exp = F::zero();
for c in 0..n_classes {
let e = (logits[c] - max_l).exp();
proba[[i, c]] = e;
sum_exp = sum_exp + e;
}
for c in 0..n_classes {
proba[[i, c]] = proba[[i, c]] / sum_exp;
}
}
Ok(proba)
}
pub fn predict_log_proba(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
let proba = self.predict_proba(x)?;
Ok(crate::log_proba(&proba))
}
pub fn decision_function(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
let n_features = x.ncols();
if n_features != self.n_features {
return Err(FerroError::ShapeMismatch {
expected: vec![self.n_features],
actual: vec![n_features],
context: "number of features must match fitted model".into(),
});
}
let n_samples = x.nrows();
let n_classes = self.classes.len();
let half = F::from(0.5).unwrap();
let mut out = Array2::<F>::zeros((n_samples, n_classes));
for i in 0..n_samples {
let xi = x.row(i);
for (c, model) in self.class_models.iter().enumerate() {
let diff: Array1<F> = xi.to_owned() - &model.mean;
let mahal = diff.dot(&model.cov_inv.dot(&diff));
out[[i, c]] = -half * model.log_det - half * mahal + model.log_prior;
}
}
Ok(out)
}
}
/// Inverts a symmetric positive-definite matrix via Cholesky factorization
/// and returns `(A^{-1}, ln det A)`.
///
/// # Errors
/// Returns `FerroError::NumericalInstability` when `a` is not positive
/// definite (a non-positive pivot appears during factorization).
fn cholesky_inv_and_logdet<F: Float + 'static>(
    a: &Array2<F>,
) -> Result<(Array2<F>, F), FerroError> {
    let n = a.nrows();
    // Cholesky factorization: A = L * L^T with L lower triangular.
    let mut l = Array2::<F>::zeros((n, n));
    for i in 0..n {
        for j in 0..=i {
            let mut s = a[[i, j]];
            for k in 0..j {
                s = s - l[[i, k]] * l[[j, k]];
            }
            if i == j {
                // Diagonal pivot must be strictly positive for an SPD matrix.
                if s <= F::zero() {
                    return Err(FerroError::NumericalInstability {
                        message: "covariance matrix is not positive definite".into(),
                    });
                }
                l[[i, j]] = s.sqrt();
            } else {
                l[[i, j]] = s / l[[j, j]];
            }
        }
    }
    // ln det A = 2 * sum(ln L[i,i]) since det A = (prod L[i,i])^2.
    let two = F::from(2.0).unwrap();
    let log_det = (0..n)
        .map(|i| l[[i, i]].ln())
        .fold(F::zero(), |a, b| a + b)
        * two;
    // Invert L column by column using forward substitution.
    let mut l_inv = Array2::<F>::zeros((n, n));
    for col in 0..n {
        l_inv[[col, col]] = F::one() / l[[col, col]];
        for i in (col + 1)..n {
            let mut s = F::zero();
            for k in col..i {
                s = s + l[[i, k]] * l_inv[[k, col]];
            }
            l_inv[[i, col]] = -s / l[[i, i]];
        }
    }
    // A^{-1} = (L L^T)^{-1} = L^{-T} L^{-1}.
    let a_inv = l_inv.t().dot(&l_inv);
    Ok((a_inv, log_det))
}
impl<F: Float + Send + Sync + ScalarOperand + 'static> Fit<Array2<F>, Array1<usize>>
    for QDA<F>
{
    type Fitted = FittedQDA<F>;
    type Error = FerroError;

    /// Fits one Gaussian (mean, covariance, prior) per class label in `y`.
    ///
    /// # Errors
    /// - `ShapeMismatch` when `y.len() != x.nrows()`.
    /// - `InvalidParameter` when `reg_param` is outside `[0, 1]`.
    /// - `InsufficientSamples` when fewer than 2 distinct classes are present
    ///   or a class has fewer than 2 samples.
    /// - `NumericalInstability` when a class covariance is not positive
    ///   definite (propagated from the Cholesky inversion).
    fn fit(
        &self,
        x: &Array2<F>,
        y: &Array1<usize>,
    ) -> Result<FittedQDA<F>, FerroError> {
        let (n_samples, n_features) = x.dim();
        if n_samples != y.len() {
            return Err(FerroError::ShapeMismatch {
                expected: vec![n_samples],
                actual: vec![y.len()],
                context: "y length must match number of samples in X".into(),
            });
        }
        if self.reg_param < F::zero() || self.reg_param > F::one() {
            return Err(FerroError::InvalidParameter {
                name: "reg_param".into(),
                reason: "must be in [0, 1]".into(),
            });
        }
        // Sorted, deduplicated class labels define the class ordering.
        let mut classes: Vec<usize> = y.to_vec();
        classes.sort_unstable();
        classes.dedup();
        if classes.len() < 2 {
            return Err(FerroError::InsufficientSamples {
                required: 2,
                actual: classes.len(),
                context: "QDA requires at least 2 distinct classes".into(),
            });
        }
        let n_f = F::from(n_samples).unwrap();
        let mut class_models = Vec::with_capacity(classes.len());
        for &cls in &classes {
            // Row indices belonging to this class.
            let indices: Vec<usize> = y
                .iter()
                .enumerate()
                .filter(|&(_, label)| *label == cls)
                .map(|(i, _)| i)
                .collect();
            let n_k = indices.len();
            // At least 2 samples are needed for a (n_k - 1)-divisor covariance.
            if n_k < 2 {
                return Err(FerroError::InsufficientSamples {
                    required: 2,
                    actual: n_k,
                    context: format!("class {cls} needs at least 2 samples for QDA"),
                });
            }
            let n_k_f = F::from(n_k).unwrap();
            // Class mean: accumulate then divide by the class count.
            let mut mean = Array1::<F>::zeros(n_features);
            for &i in &indices {
                for j in 0..n_features {
                    mean[j] = mean[j] + x[[i, j]];
                }
            }
            mean.mapv_inplace(|v| v / n_k_f);
            // Unbiased class covariance: sum of outer products / (n_k - 1).
            let mut cov = Array2::<F>::zeros((n_features, n_features));
            for &i in &indices {
                let diff: Array1<F> = x.row(i).to_owned() - &mean;
                for r in 0..n_features {
                    for c in 0..n_features {
                        cov[[r, c]] = cov[[r, c]] + diff[r] * diff[c];
                    }
                }
            }
            let divisor = F::from(n_k - 1).unwrap();
            cov.mapv_inplace(|v| v / divisor);
            // Shrinkage toward the identity: (1 - reg) * cov + reg * I.
            if self.reg_param > F::zero() {
                let one_minus = F::one() - self.reg_param;
                for r in 0..n_features {
                    for c in 0..n_features {
                        cov[[r, c]] = cov[[r, c]] * one_minus;
                    }
                    cov[[r, r]] = cov[[r, r]] + self.reg_param;
                }
            }
            let (cov_inv, log_det) = cholesky_inv_and_logdet(&cov)?;
            // Empirical log-prior: ln(n_k / n_samples).
            let log_prior = (n_k_f / n_f).ln();
            class_models.push(QDAClass {
                mean,
                cov_inv,
                log_det,
                log_prior,
            });
        }
        Ok(FittedQDA {
            class_models,
            classes,
            n_features,
        })
    }
}
impl<F: Float + Send + Sync + ScalarOperand + 'static> Predict<Array2<F>>
    for FittedQDA<F>
{
    type Output = Array1<usize>;
    type Error = FerroError;

    /// Predicts the class label for each row of `x` as the argmax of the
    /// per-class QDA discriminant scores.
    ///
    /// # Errors
    /// Returns `FerroError::ShapeMismatch` when `x` has a different number of
    /// features than the fitted model.
    fn predict(&self, x: &Array2<F>) -> Result<Array1<usize>, FerroError> {
        if x.ncols() != self.n_features {
            return Err(FerroError::ShapeMismatch {
                expected: vec![self.n_features],
                actual: vec![x.ncols()],
                context: "number of features must match fitted model".into(),
            });
        }
        let half = F::from(0.5).unwrap();
        let labels: Vec<usize> = (0..x.nrows())
            .map(|i| {
                let xi = x.row(i);
                // Fold over class models keeping the first strictly greatest
                // score, matching strict `>` tie-breaking (earliest class wins).
                let (winner, _) = self.class_models.iter().enumerate().fold(
                    (0usize, F::neg_infinity()),
                    |(best_c, best_s), (c, model)| {
                        let diff: Array1<F> = xi.to_owned() - &model.mean;
                        let mahal = diff.dot(&model.cov_inv.dot(&diff));
                        let score =
                            -half * model.log_det - half * mahal + model.log_prior;
                        if score > best_s {
                            (c, score)
                        } else {
                            (best_c, best_s)
                        }
                    },
                );
                self.classes[winner]
            })
            .collect();
        Ok(Array1::from(labels))
    }
}
/// Exposes the fitted class labels.
///
/// The original impl required `Float + Send + Sync + ScalarOperand + 'static`,
/// but these methods only touch the `classes: Vec<usize>` field, so the impl
/// is deliberately unconstrained — a strictly wider (backward-compatible) set
/// of `F` now gets `HasClasses`.
impl<F> HasClasses for FittedQDA<F> {
    /// Sorted, deduplicated class labels seen at fit time.
    fn classes(&self) -> &[usize] {
        &self.classes
    }

    /// Number of distinct classes.
    fn n_classes(&self) -> usize {
        self.classes.len()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    #[test]
    fn test_default_constructor() {
        let m = QDA::<f64>::new();
        // `assert_eq!` prints both sides on failure, unlike a bare `assert!`.
        assert_eq!(m.reg_param, 0.0);
    }

    #[test]
    fn test_builder() {
        let m = QDA::<f64>::new().with_reg_param(0.5);
        assert_eq!(m.reg_param, 0.5);
    }

    #[test]
    fn test_binary_classification() {
        // Two well-separated 2D clusters of 4 points each.
        let x = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 1.0, 1.0, 2.0, 2.0, 1.0, 2.0, 2.0,
                8.0, 8.0, 8.0, 9.0, 9.0, 8.0, 9.0, 9.0,
            ],
        )
        .unwrap();
        let y = array![0, 0, 0, 0, 1, 1, 1, 1];
        let model = QDA::<f64>::new();
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        let correct: usize = preds.iter().zip(y.iter()).filter(|(p, a)| p == a).count();
        assert!(correct >= 6, "expected at least 6 correct, got {correct}");
    }

    #[test]
    fn test_multiclass_classification() {
        // Three well-separated 2D clusters of 4 points each.
        let x = Array2::from_shape_vec(
            (12, 2),
            vec![
                0.0, 0.0, 0.5, 0.0, 0.0, 0.5, 0.5, 0.5,
                10.0, 0.0, 10.5, 0.0, 10.0, 0.5, 10.5, 0.5,
                0.0, 10.0, 0.5, 10.0, 0.0, 10.5, 0.5, 10.5,
            ],
        )
        .unwrap();
        let y = array![0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2];
        let model = QDA::<f64>::new();
        let fitted = model.fit(&x, &y).unwrap();
        assert_eq!(fitted.n_classes(), 3);
        assert_eq!(fitted.classes(), &[0, 1, 2]);
        let preds = fitted.predict(&x).unwrap();
        let correct: usize = preds.iter().zip(y.iter()).filter(|(p, a)| p == a).count();
        assert!(correct >= 10, "expected at least 10 correct, got {correct}");
    }

    #[test]
    fn test_regularization() {
        // Regularized fit should still produce one prediction per sample.
        let x = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 1.0, 1.0, 2.0, 2.0, 1.0, 2.0, 2.0,
                8.0, 8.0, 8.0, 9.0, 9.0, 8.0, 9.0, 9.0,
            ],
        )
        .unwrap();
        let y = array![0, 0, 0, 0, 1, 1, 1, 1];
        let model = QDA::<f64>::new().with_reg_param(0.5);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 8);
    }

    #[test]
    fn test_shape_mismatch() {
        // y shorter than the number of rows in X must be rejected.
        let x = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        let y = array![0, 1];
        let model = QDA::<f64>::new();
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_single_class_error() {
        // QDA requires at least two distinct classes.
        let x = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        let y = array![0, 0, 0];
        let model = QDA::<f64>::new();
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_invalid_reg_param() {
        // reg_param must lie in [0, 1]; both sides of the range are checked.
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![0, 0, 1, 1];
        let model = QDA::<f64>::new().with_reg_param(-0.1);
        assert!(model.fit(&x, &y).is_err());
        let model2 = QDA::<f64>::new().with_reg_param(1.5);
        assert!(model2.fit(&x, &y).is_err());
    }

    #[test]
    fn test_predict_feature_mismatch() {
        // Predicting with a different feature count must error.
        let x = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 1.0, 1.0, 2.0, 2.0, 1.0, 2.0, 2.0,
                8.0, 8.0, 8.0, 9.0, 9.0, 8.0, 9.0, 9.0,
            ],
        )
        .unwrap();
        let y = array![0, 0, 0, 0, 1, 1, 1, 1];
        let fitted = QDA::<f64>::new().fit(&x, &y).unwrap();
        let x_bad = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        assert!(fitted.predict(&x_bad).is_err());
    }

    #[test]
    fn test_has_classes() {
        let x = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 1.0, 1.0, 2.0, 2.0, 1.0, 2.0, 2.0,
                8.0, 8.0, 8.0, 9.0, 9.0, 8.0, 9.0, 9.0,
            ],
        )
        .unwrap();
        let y = array![0, 0, 0, 0, 1, 1, 1, 1];
        let fitted = QDA::<f64>::new().fit(&x, &y).unwrap();
        assert_eq!(fitted.classes(), &[0, 1]);
        assert_eq!(fitted.n_classes(), 2);
    }

    #[test]
    fn test_means() {
        // One mean per class should be exposed after fitting.
        let x = Array2::from_shape_vec(
            (4, 1),
            vec![1.0, 2.0, 5.0, 6.0],
        )
        .unwrap();
        let y = array![0, 0, 1, 1];
        let fitted = QDA::<f64>::new().with_reg_param(0.1).fit(&x, &y).unwrap();
        let means = fitted.means();
        assert_eq!(means.len(), 2);
    }

    #[test]
    fn test_class_with_too_few_samples() {
        // Class 0 has a single sample, which is insufficient for a covariance.
        let x = Array2::from_shape_vec(
            (3, 1),
            vec![1.0, 5.0, 6.0],
        )
        .unwrap();
        let y = array![0, 1, 1];
        let model = QDA::<f64>::new();
        assert!(model.fit(&x, &y).is_err());
    }
}