fdars_core/classification/qda.rs

use crate::error::FdarError;
use crate::linalg::{cholesky_d, log_det_from_cholesky, mahalanobis_sq};
use crate::matrix::FdMatrix;

use super::{
    build_feature_matrix, class_means_and_priors, compute_accuracy, confusion_matrix, remap_labels,
    ClassifResult,
};

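/// Accumulates the unnormalized scatter matrix of the `members` rows of
/// `features` around `mean`. Only the upper triangle is computed; each term
/// is mirrored so the result is a full symmetric `d x d` matrix in
/// row-major order.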
fn accumulate_class_cov(
    features: &FdMatrix,
    members: &[usize],
    mean: &[f64],
    d: usize,
) -> Vec<f64> {
    let mut cov = vec![0.0; d * d];
    for &i in members {
        for r in 0..d {
            let dr = features[(i, r)] - mean[r];
            for s in r..d {
                let val = dr * (features[(i, s)] - mean[s]);
                cov[r * d + s] += val;
                if r != s {
                    cov[s * d + r] += val;
                }
            }
        }
    }
    cov
}

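/// Computes one covariance matrix per class: the sample covariance of that
/// class's rows, lightly ridge-regularized on the diagonal.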
fn qda_class_covariances(
    features: &FdMatrix,
    labels: &[usize],
    class_means: &[Vec<f64>],
    g: usize,
) -> Vec<Vec<f64>> {
    let n = features.nrows();
    let d = features.ncols();

    (0..g)
        .map(|c| {
            let members: Vec<usize> = (0..n).filter(|&i| labels[i] == c).collect();
            let nc = members.len();
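            // Unbiased divisor n_c - 1, clamped to 1 so a singleton class
            // does not divide by zero.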
            let divisor = (nc.saturating_sub(1)).max(1) as f64;
            let mut cov = accumulate_class_cov(features, &members, &class_means[c], d);
            for v in &mut cov {
                *v /= divisor;
            }
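            // A small diagonal ridge nudges each covariance toward positive
            // definiteness for the Cholesky factorization in `build_qda_params`.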
            for j in 0..d {
                cov[j * d + j] += 1e-6;
            }
            cov
        })
        .collect()
}

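/// Fits the QDA parameters on the training features: per-class means, the
/// Cholesky factor and log-determinant of each regularized class covariance,
/// and the class priors. Errors from `cholesky_d` are propagated.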
pub(crate) fn build_qda_params(
    features: &FdMatrix,
    labels: &[usize],
    g: usize,
) -> Result<(Vec<Vec<f64>>, Vec<Vec<f64>>, Vec<f64>, Vec<f64>), FdarError> {
    let d = features.ncols();
    let (class_means, _counts, priors) = class_means_and_priors(features, labels, g);
    let class_covs = qda_class_covariances(features, labels, &class_means, g);
    let mut class_chols = Vec::with_capacity(g);
    let mut class_log_dets = Vec::with_capacity(g);
    for cov in &class_covs {
        let chol = cholesky_d(cov, d)?;
        class_log_dets.push(log_det_from_cholesky(&chol, d));
        class_chols.push(chol);
    }
    Ok((class_means, class_chols, class_log_dets, priors))
}

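/// Assigns each row of `features` to the class with the highest quadratic
/// discriminant score.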
pub(crate) fn qda_predict(
    features: &FdMatrix,
    class_means: &[Vec<f64>],
    class_chols: &[Vec<f64>],
    class_log_dets: &[f64],
    priors: &[f64],
    g: usize,
) -> Vec<usize> {
    let n = features.nrows();
    let d = features.ncols();

    (0..n)
        .map(|i| {
            let xi: Vec<f64> = (0..d).map(|j| features[(i, j)]).collect();
            let mut best_class = 0;
            let mut best_score = f64::NEG_INFINITY;
            for c in 0..g {
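                // Quadratic discriminant score: ln(prior) - 0.5 * (log|Sigma_c| + maha).
                // The prior is clamped away from zero before the log, and the
                // constant -d/2 * ln(2*pi) is dropped since all classes share it.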
                let maha = mahalanobis_sq(&xi, &class_means[c], &class_chols[c], d);
                let score = priors[c].max(1e-15).ln() - 0.5 * (class_log_dets[c] + maha);
                if score > best_score {
                    best_score = score;
                    best_class = c;
                }
            }
            best_class
        })
        .collect()
}

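/// Quadratic discriminant analysis for functional data. Builds an
/// `ncomp`-dimensional feature representation of `data` via
/// `build_feature_matrix` (optionally augmented with `scalar_covariates`),
/// fits one Gaussian per class, and classifies each observation by its
/// highest discriminant score. Errors if `data` is empty, `y` does not match
/// `data`'s row count, `ncomp` is zero, or fewer than two classes appear in `y`.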
111#[must_use = "expensive computation whose result should not be discarded"]
pub fn fclassif_qda(
    data: &FdMatrix,
    y: &[usize],
    scalar_covariates: Option<&FdMatrix>,
    ncomp: usize,
) -> Result<ClassifResult, FdarError> {
    let n = data.nrows();
    if n == 0 || y.len() != n {
        return Err(FdarError::InvalidDimension {
            parameter: "data/y",
            expected: "n > 0 and y.len() == n".to_string(),
            actual: format!("n={}, y.len()={}", n, y.len()),
        });
    }
    if ncomp == 0 {
        return Err(FdarError::InvalidParameter {
            parameter: "ncomp",
            message: "must be > 0".to_string(),
        });
    }

    let (labels, g) = remap_labels(y);
    if g < 2 {
        return Err(FdarError::InvalidParameter {
            parameter: "y",
            message: format!("need at least 2 classes, got {g}"),
        });
    }

    let (features, _mean, _rotation) = build_feature_matrix(data, scalar_covariates, ncomp)?;

    let (class_means, class_chols, class_log_dets, priors) =
        build_qda_params(&features, &labels, g)?;

    let predicted = qda_predict(
        &features,
        &class_means,
        &class_chols,
        &class_log_dets,
        &priors,
        g,
    );
    let accuracy = compute_accuracy(&labels, &predicted);
    let confusion = confusion_matrix(&labels, &predicted, g);

    Ok(ClassifResult {
        predicted,
        probabilities: None,
        accuracy,
        confusion,
        n_classes: g,
        ncomp: features.ncols().min(ncomp),
    })
}