1use crate::error::FdarError;
4use crate::explain_generic::{FpcPredictor, TaskType};
5use crate::matrix::FdMatrix;
6
7use super::knn::knn_predict_loo;
8use super::lda::{lda_params, lda_predict};
9use super::qda::{build_qda_params, qda_predict};
10use super::{
11 build_feature_matrix, compute_accuracy, confusion_matrix, remap_labels, ClassifCvResult,
12 ClassifResult,
13};
14use crate::linalg::{cholesky_d, mahalanobis_sq};
15
16use super::cv::fclassif_cv;
17
/// Fitted per-method classifier parameters, stored inside [`ClassifFit`].
///
/// One variant per supported classification method; all operate on FPCA
/// score vectors of dimension `d = ClassifFit::ncomp`.
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub enum ClassifMethod {
    /// Linear discriminant analysis: one mean per class, a single pooled
    /// covariance shared by all classes.
    Lda {
        /// Per-class mean score vectors (one `Vec<f64>` of length d per class).
        class_means: Vec<Vec<f64>>,
        /// Cholesky factor of the pooled covariance, as produced by
        /// `cholesky_d` (flattened d x d — TODO confirm packing/layout).
        cov_chol: Vec<f64>,
        /// Class prior probabilities, indexed by remapped class label.
        priors: Vec<f64>,
        /// Number of distinct classes (g >= 2).
        n_classes: usize,
    },
    /// Quadratic discriminant analysis: a separate covariance per class.
    Qda {
        /// Per-class mean score vectors.
        class_means: Vec<Vec<f64>>,
        /// Per-class Cholesky factors of the class covariances.
        class_chols: Vec<Vec<f64>>,
        /// Per-class log-determinants of the class covariances.
        class_log_dets: Vec<f64>,
        /// Class prior probabilities.
        priors: Vec<f64>,
        /// Number of distinct classes (g >= 2).
        n_classes: usize,
    },
    /// k-nearest-neighbours in FPCA score space (Euclidean distance).
    Knn {
        /// Training-set score matrix used as the neighbour pool.
        training_scores: FdMatrix,
        /// Remapped (0..n_classes) labels of the training rows.
        training_labels: Vec<usize>,
        /// Number of neighbours to vote (clamped to the training size at
        /// prediction time).
        k: usize,
        /// Number of distinct classes (g >= 2).
        n_classes: usize,
    },
}
45
/// Result of fitting an FPCA-based classifier (see the `fclassif_*_fit`
/// functions): the in-sample fit summary plus everything needed to score
/// new observations.
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub struct ClassifFit {
    /// In-sample (resubstitution) predictions, accuracy and confusion matrix.
    pub result: ClassifResult,
    /// Mean function subtracted before projecting onto the FPCA basis.
    pub fpca_mean: Vec<f64>,
    /// FPCA rotation (loadings) used to compute score vectors.
    pub fpca_rotation: FdMatrix,
    /// Training-set FPCA scores (n rows, `ncomp` columns).
    pub fpca_scores: FdMatrix,
    /// Number of retained components, i.e. the feature dimension d.
    pub ncomp: usize,
    /// Method-specific fitted parameters.
    pub method: ClassifMethod,
}
63
64#[must_use = "expensive computation whose result should not be discarded"]
74pub fn fclassif_lda_fit(
75 data: &FdMatrix,
76 y: &[usize],
77 scalar_covariates: Option<&FdMatrix>,
78 ncomp: usize,
79) -> Result<ClassifFit, FdarError> {
80 let n = data.nrows();
81 if n == 0 || y.len() != n {
82 return Err(FdarError::InvalidDimension {
83 parameter: "data/y",
84 expected: "n > 0 and y.len() == n".to_string(),
85 actual: format!("n={}, y.len()={}", n, y.len()),
86 });
87 }
88 if ncomp == 0 {
89 return Err(FdarError::InvalidParameter {
90 parameter: "ncomp",
91 message: "must be > 0".to_string(),
92 });
93 }
94
95 let (labels, g) = remap_labels(y);
96 if g < 2 {
97 return Err(FdarError::InvalidParameter {
98 parameter: "y",
99 message: format!("need at least 2 classes, got {g}"),
100 });
101 }
102
103 let (features, mean, rotation) = build_feature_matrix(data, None, ncomp)?;
106 let _ = scalar_covariates; let d = features.ncols();
108 let (class_means, cov, priors) = lda_params(&features, &labels, g);
109 let chol = cholesky_d(&cov, d)?;
110
111 let predicted = lda_predict(&features, &class_means, &chol, &priors, g);
112 let accuracy = compute_accuracy(&labels, &predicted);
113 let confusion = confusion_matrix(&labels, &predicted, g);
114
115 Ok(ClassifFit {
116 result: ClassifResult {
117 predicted,
118 probabilities: None,
119 accuracy,
120 confusion,
121 n_classes: g,
122 ncomp: d,
123 },
124 fpca_mean: mean.clone(),
125 fpca_rotation: rotation,
126 fpca_scores: features,
127 ncomp: d,
128 method: ClassifMethod::Lda {
129 class_means,
130 cov_chol: chol,
131 priors,
132 n_classes: g,
133 },
134 })
135}
136
137#[must_use = "expensive computation whose result should not be discarded"]
147pub fn fclassif_qda_fit(
148 data: &FdMatrix,
149 y: &[usize],
150 scalar_covariates: Option<&FdMatrix>,
151 ncomp: usize,
152) -> Result<ClassifFit, FdarError> {
153 let n = data.nrows();
154 if n == 0 || y.len() != n {
155 return Err(FdarError::InvalidDimension {
156 parameter: "data/y",
157 expected: "n > 0 and y.len() == n".to_string(),
158 actual: format!("n={}, y.len()={}", n, y.len()),
159 });
160 }
161 if ncomp == 0 {
162 return Err(FdarError::InvalidParameter {
163 parameter: "ncomp",
164 message: "must be > 0".to_string(),
165 });
166 }
167
168 let (labels, g) = remap_labels(y);
169 if g < 2 {
170 return Err(FdarError::InvalidParameter {
171 parameter: "y",
172 message: format!("need at least 2 classes, got {g}"),
173 });
174 }
175
176 let (features, mean, rotation) = build_feature_matrix(data, None, ncomp)?;
178 let _ = scalar_covariates;
179 let (class_means, class_chols, class_log_dets, priors) =
180 build_qda_params(&features, &labels, g)?;
181
182 let predicted = qda_predict(
183 &features,
184 &class_means,
185 &class_chols,
186 &class_log_dets,
187 &priors,
188 g,
189 );
190 let accuracy = compute_accuracy(&labels, &predicted);
191 let confusion = confusion_matrix(&labels, &predicted, g);
192 let d = features.ncols();
193
194 Ok(ClassifFit {
195 result: ClassifResult {
196 predicted,
197 probabilities: None,
198 accuracy,
199 confusion,
200 n_classes: g,
201 ncomp: d,
202 },
203 fpca_mean: mean.clone(),
204 fpca_rotation: rotation,
205 fpca_scores: features,
206 ncomp: d,
207 method: ClassifMethod::Qda {
208 class_means,
209 class_chols,
210 class_log_dets,
211 priors,
212 n_classes: g,
213 },
214 })
215}
216
217#[must_use = "expensive computation whose result should not be discarded"]
227pub fn fclassif_knn_fit(
228 data: &FdMatrix,
229 y: &[usize],
230 scalar_covariates: Option<&FdMatrix>,
231 ncomp: usize,
232 k_nn: usize,
233) -> Result<ClassifFit, FdarError> {
234 let n = data.nrows();
235 if n == 0 || y.len() != n {
236 return Err(FdarError::InvalidDimension {
237 parameter: "data/y",
238 expected: "n > 0 and y.len() == n".to_string(),
239 actual: format!("n={}, y.len()={}", n, y.len()),
240 });
241 }
242 if ncomp == 0 {
243 return Err(FdarError::InvalidParameter {
244 parameter: "ncomp",
245 message: "must be > 0".to_string(),
246 });
247 }
248 if k_nn == 0 {
249 return Err(FdarError::InvalidParameter {
250 parameter: "k_nn",
251 message: "must be > 0".to_string(),
252 });
253 }
254
255 let (labels, g) = remap_labels(y);
256 if g < 2 {
257 return Err(FdarError::InvalidParameter {
258 parameter: "y",
259 message: format!("need at least 2 classes, got {g}"),
260 });
261 }
262
263 let (features, mean, rotation) = build_feature_matrix(data, None, ncomp)?;
265 let _ = scalar_covariates;
266 let d = features.ncols();
267
268 let predicted = knn_predict_loo(&features, &labels, g, d, k_nn);
269 let accuracy = compute_accuracy(&labels, &predicted);
270 let confusion = confusion_matrix(&labels, &predicted, g);
271
272 Ok(ClassifFit {
273 result: ClassifResult {
274 predicted,
275 probabilities: None,
276 accuracy,
277 confusion,
278 n_classes: g,
279 ncomp: d,
280 },
281 fpca_mean: mean.clone(),
282 fpca_rotation: rotation,
283 fpca_scores: features.clone(),
284 ncomp: d,
285 method: ClassifMethod::Knn {
286 training_scores: features,
287 training_labels: labels,
288 k: k_nn,
289 n_classes: g,
290 },
291 })
292}
293
impl FpcPredictor for ClassifFit {
    /// Mean function subtracted before projecting new curves onto the basis.
    fn fpca_mean(&self) -> &[f64] {
        &self.fpca_mean
    }

    /// FPCA rotation (loadings) used to compute score vectors.
    fn fpca_rotation(&self) -> &FdMatrix {
        &self.fpca_rotation
    }

    /// Feature dimension d (number of retained components).
    fn ncomp(&self) -> usize {
        self.ncomp
    }

    /// Training-set FPCA scores.
    fn training_scores(&self) -> &FdMatrix {
        &self.fpca_scores
    }

    /// Binary vs. multiclass, derived from the fitted class count.
    fn task_type(&self) -> TaskType {
        match &self.method {
            ClassifMethod::Lda { n_classes, .. }
            | ClassifMethod::Qda { n_classes, .. }
            | ClassifMethod::Knn { n_classes, .. } => {
                if *n_classes == 2 {
                    TaskType::BinaryClassification
                } else {
                    TaskType::MulticlassClassification(*n_classes)
                }
            }
        }
    }

    /// Predicts from a single FPCA score vector.
    ///
    /// Returns `P(class 1)` for binary problems and the winning class index
    /// (as `f64`) for multiclass problems. Scalar covariates are ignored.
    fn predict_from_scores(&self, scores: &[f64], _scalar_covariates: Option<&[f64]>) -> f64 {
        match &self.method {
            ClassifMethod::Lda {
                class_means,
                cov_chol,
                priors,
                n_classes,
            } => {
                let g = *n_classes;
                let d = scores.len();
                if g == 2 {
                    // Log-discriminant per class: log prior minus half the
                    // Mahalanobis distance under the pooled covariance.
                    // Priors are clamped away from zero before the log.
                    let score0 = priors[0].max(1e-15).ln()
                        - 0.5 * mahalanobis_sq(scores, &class_means[0], cov_chol, d);
                    let score1 = priors[1].max(1e-15).ln()
                        - 0.5 * mahalanobis_sq(scores, &class_means[1], cov_chol, d);
                    // Numerically stable two-class softmax -> P(class 1).
                    let max_s = score0.max(score1);
                    let exp0 = (score0 - max_s).exp();
                    let exp1 = (score1 - max_s).exp();
                    exp1 / (exp0 + exp1)
                } else {
                    // Multiclass: return the arg-max class index.
                    let mut best_class = 0;
                    let mut best_score = f64::NEG_INFINITY;
                    for c in 0..g {
                        let maha = mahalanobis_sq(scores, &class_means[c], cov_chol, d);
                        let s = priors[c].max(1e-15).ln() - 0.5 * maha;
                        if s > best_score {
                            best_score = s;
                            best_class = c;
                        }
                    }
                    best_class as f64
                }
            }
            ClassifMethod::Qda {
                class_means,
                class_chols,
                class_log_dets,
                priors,
                n_classes,
            } => {
                let g = *n_classes;
                let d = scores.len();
                if g == 2 {
                    // Same as LDA but with per-class covariances, so each
                    // discriminant also carries the class log-determinant.
                    let score0 = priors[0].max(1e-15).ln()
                        - 0.5
                            * (class_log_dets[0]
                                + mahalanobis_sq(scores, &class_means[0], &class_chols[0], d));
                    let score1 = priors[1].max(1e-15).ln()
                        - 0.5
                            * (class_log_dets[1]
                                + mahalanobis_sq(scores, &class_means[1], &class_chols[1], d));
                    // Numerically stable two-class softmax -> P(class 1).
                    let max_s = score0.max(score1);
                    let exp0 = (score0 - max_s).exp();
                    let exp1 = (score1 - max_s).exp();
                    exp1 / (exp0 + exp1)
                } else {
                    // Multiclass: return the arg-max class index.
                    let mut best_class = 0;
                    let mut best_score = f64::NEG_INFINITY;
                    for c in 0..g {
                        let maha = mahalanobis_sq(scores, &class_means[c], &class_chols[c], d);
                        let s = priors[c].max(1e-15).ln() - 0.5 * (class_log_dets[c] + maha);
                        if s > best_score {
                            best_score = s;
                            best_class = c;
                        }
                    }
                    best_class as f64
                }
            }
            ClassifMethod::Knn {
                training_scores,
                training_labels,
                k,
                n_classes,
            } => {
                let g = *n_classes;
                let n_train = training_scores.nrows();
                let d = scores.len();
                // Never ask for more neighbours than training points exist.
                let k_nn = (*k).min(n_train);

                // Squared Euclidean distance to every training score row.
                let mut dists: Vec<(f64, usize)> = (0..n_train)
                    .map(|j| {
                        let d_sq: f64 = (0..d)
                            .map(|c| (scores[c] - training_scores[(j, c)]).powi(2))
                            .sum();
                        (d_sq, training_labels[j])
                    })
                    .collect();
                // NaN distances compare as Equal so the sort cannot panic.
                dists.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));

                // Majority vote among the k nearest neighbours; out-of-range
                // labels are skipped defensively.
                let mut votes = vec![0usize; g];
                for &(_, label) in dists.iter().take(k_nn) {
                    if label < g {
                        votes[label] += 1;
                    }
                }

                if g == 2 {
                    // Binary: fraction of neighbours voting for class 1.
                    votes[1] as f64 / k_nn as f64
                } else {
                    // Multiclass: index of the most-voted class.
                    votes
                        .iter()
                        .enumerate()
                        .max_by_key(|&(_, &v)| v)
                        .map_or(0.0, |(c, _)| c as f64)
                }
            }
        }
    }
}
444
/// Computes per-class probabilities for each row of `scores`.
///
/// For LDA/QDA the probabilities are the softmax of the class
/// log-discriminants; for KNN they are the vote fractions among the k
/// nearest training neighbours. Returns one `Vec<f64>` of length
/// `n_classes` per input row.
pub(crate) fn classif_predict_probs(fit: &ClassifFit, scores: &FdMatrix) -> Vec<Vec<f64>> {
    let n = scores.nrows();
    let d = scores.ncols();
    match &fit.method {
        ClassifMethod::Lda {
            class_means,
            cov_chol,
            priors,
            n_classes,
        } => {
            let g = *n_classes;
            (0..n)
                .map(|i| {
                    let x: Vec<f64> = (0..d).map(|j| scores[(i, j)]).collect();
                    // Log-discriminant per class: log prior minus half the
                    // Mahalanobis distance under the pooled covariance.
                    let disc: Vec<f64> = (0..g)
                        .map(|c| {
                            priors[c].max(1e-15).ln()
                                - 0.5 * mahalanobis_sq(&x, &class_means[c], cov_chol, d)
                        })
                        .collect();
                    softmax(&disc)
                })
                .collect()
        }
        ClassifMethod::Qda {
            class_means,
            class_chols,
            class_log_dets,
            priors,
            n_classes,
        } => {
            let g = *n_classes;
            (0..n)
                .map(|i| {
                    let x: Vec<f64> = (0..d).map(|j| scores[(i, j)]).collect();
                    // As for LDA but with per-class covariances, so the
                    // class log-determinant enters each discriminant.
                    let disc: Vec<f64> = (0..g)
                        .map(|c| {
                            priors[c].max(1e-15).ln()
                                - 0.5
                                    * (class_log_dets[c]
                                        + mahalanobis_sq(&x, &class_means[c], &class_chols[c], d))
                        })
                        .collect();
                    softmax(&disc)
                })
                .collect()
        }
        ClassifMethod::Knn {
            training_scores,
            training_labels,
            k,
            n_classes,
        } => {
            let g = *n_classes;
            let n_train = training_scores.nrows();
            // Never ask for more neighbours than training points exist.
            let k_nn = (*k).min(n_train);
            (0..n)
                .map(|i| {
                    let x: Vec<f64> = (0..d).map(|j| scores[(i, j)]).collect();
                    // Squared Euclidean distance to every training row.
                    let mut dists: Vec<(f64, usize)> = (0..n_train)
                        .map(|j| {
                            let d_sq: f64 = (0..d)
                                .map(|c| (x[c] - training_scores[(j, c)]).powi(2))
                                .sum();
                            (d_sq, training_labels[j])
                        })
                        .collect();
                    // NaN distances compare as Equal so the sort cannot panic.
                    dists
                        .sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
                    // Vote among the k nearest; out-of-range labels skipped.
                    let mut votes = vec![0usize; g];
                    for &(_, label) in dists.iter().take(k_nn) {
                        if label < g {
                            votes[label] += 1;
                        }
                    }
                    // Probability = fraction of the k votes per class.
                    votes.iter().map(|&v| v as f64 / k_nn as f64).collect()
                })
                .collect()
        }
    }
}
534
/// Numerically stable softmax: shifts every score by the maximum before
/// exponentiating so large inputs cannot overflow, then normalizes to sum
/// to one.
fn softmax(scores: &[f64]) -> Vec<f64> {
    let peak = scores.iter().copied().fold(f64::NEG_INFINITY, f64::max);
    let mut out: Vec<f64> = scores.iter().map(|&s| (s - peak).exp()).collect();
    let total: f64 = out.iter().sum();
    for p in &mut out {
        *p /= total;
    }
    out
}
542
/// Configuration for cross-validated classification.
#[derive(Debug, Clone, PartialEq)]
pub struct ClassifCvConfig {
    /// Classifier name forwarded to the CV driver (e.g. "lda" — see
    /// `fclassif_cv` for the accepted values; TODO confirm full list).
    pub method: String,
    /// Number of FPCA components to retain.
    pub ncomp: usize,
    /// Number of cross-validation folds.
    pub nfold: usize,
    /// RNG seed controlling fold assignment.
    pub seed: u64,
}
558
559impl Default for ClassifCvConfig {
560 fn default() -> Self {
561 Self {
562 method: "lda".to_string(),
563 ncomp: 3,
564 nfold: 5,
565 seed: 42,
566 }
567 }
568}
569
570#[must_use = "expensive computation whose result should not be discarded"]
579pub fn fclassif_cv_with_config(
580 data: &FdMatrix,
581 argvals: &[f64],
582 y: &[usize],
583 scalar_covariates: Option<&FdMatrix>,
584 config: &ClassifCvConfig,
585) -> Result<ClassifCvResult, FdarError> {
586 fclassif_cv(
587 data,
588 argvals,
589 y,
590 scalar_covariates,
591 &config.method,
592 config.ncomp,
593 config.nfold,
594 config.seed,
595 )
596}