1use crate::error::FdarError;
4use crate::explain_generic::{FpcPredictor, TaskType};
5use crate::matrix::FdMatrix;
6
7use super::knn::knn_predict_loo;
8use super::lda::{lda_params, lda_predict};
9use super::qda::{build_qda_params, qda_predict};
10use super::{
11 build_feature_matrix, compute_accuracy, confusion_matrix, remap_labels, ClassifCvResult,
12 ClassifResult,
13};
14use crate::linalg::{cholesky_d, mahalanobis_sq};
15
16use super::cv::fclassif_cv;
17
/// Fitted per-method classifier parameters, all expressed in FPCA score
/// space (dimension `d` = number of retained components).
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub enum ClassifMethod {
    /// Linear discriminant analysis: one covariance pooled across classes.
    Lda {
        /// Per-class mean vectors in score space (one `Vec<f64>` per class).
        class_means: Vec<Vec<f64>>,
        /// Cholesky factor of the pooled covariance, as produced by `cholesky_d`.
        cov_chol: Vec<f64>,
        /// Class prior probabilities (floored at 1e-15 before taking logs).
        priors: Vec<f64>,
        /// Number of distinct classes after label remapping.
        n_classes: usize,
    },
    /// Quadratic discriminant analysis: a separate covariance per class.
    Qda {
        /// Per-class mean vectors in score space.
        class_means: Vec<Vec<f64>>,
        /// Per-class Cholesky factors, as produced by `build_qda_params`.
        class_chols: Vec<Vec<f64>>,
        /// Per-class log-determinant terms entering the discriminant score.
        class_log_dets: Vec<f64>,
        /// Class prior probabilities.
        priors: Vec<f64>,
        /// Number of distinct classes after label remapping.
        n_classes: usize,
    },
    /// k-nearest-neighbours in score space (squared Euclidean distance).
    Knn {
        /// FPCA scores of the training curves (rows = observations).
        training_scores: FdMatrix,
        /// Remapped (0-based, contiguous) labels of the training curves.
        training_labels: Vec<usize>,
        /// Requested neighbour count; clamped to the training size at predict time.
        k: usize,
        /// Number of distinct classes after label remapping.
        n_classes: usize,
    },
}
45
/// A fitted functional classifier: in-sample results plus everything
/// needed to project and classify new curves.
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub struct ClassifFit {
    /// In-sample (resubstitution) predictions, accuracy and confusion matrix.
    pub result: ClassifResult,
    /// FPCA mean function returned by `build_feature_matrix`.
    pub fpca_mean: Vec<f64>,
    /// FPCA rotation matrix returned by `build_feature_matrix`.
    pub fpca_rotation: FdMatrix,
    /// FPCA scores of the training curves (the classifier's feature matrix).
    pub fpca_scores: FdMatrix,
    /// Realized number of components (taken from the feature matrix width,
    /// which may differ from the requested `ncomp`).
    pub ncomp: usize,
    /// Method-specific fitted parameters.
    pub method: ClassifMethod,
    /// FPCA integration weights returned by `build_feature_matrix`.
    pub fpca_int_weights: Vec<f64>,
}
65
66#[must_use = "expensive computation whose result should not be discarded"]
76pub fn fclassif_lda_fit(
77 data: &FdMatrix,
78 y: &[usize],
79 scalar_covariates: Option<&FdMatrix>,
80 ncomp: usize,
81) -> Result<ClassifFit, FdarError> {
82 let n = data.nrows();
83 if n == 0 || y.len() != n {
84 return Err(FdarError::InvalidDimension {
85 parameter: "data/y",
86 expected: "n > 0 and y.len() == n".to_string(),
87 actual: format!("n={}, y.len()={}", n, y.len()),
88 });
89 }
90 if ncomp == 0 {
91 return Err(FdarError::InvalidParameter {
92 parameter: "ncomp",
93 message: "must be > 0".to_string(),
94 });
95 }
96
97 let (labels, g) = remap_labels(y);
98 if g < 2 {
99 return Err(FdarError::InvalidParameter {
100 parameter: "y",
101 message: format!("need at least 2 classes, got {g}"),
102 });
103 }
104
105 let (features, mean, rotation, int_weights) = build_feature_matrix(data, None, ncomp)?;
108 let _ = scalar_covariates; let d = features.ncols();
110 let (class_means, cov, priors) = lda_params(&features, &labels, g);
111 let chol = cholesky_d(&cov, d)?;
112
113 let predicted = lda_predict(&features, &class_means, &chol, &priors, g);
114 let accuracy = compute_accuracy(&labels, &predicted);
115 let confusion = confusion_matrix(&labels, &predicted, g);
116
117 Ok(ClassifFit {
118 result: ClassifResult {
119 predicted,
120 probabilities: None,
121 accuracy,
122 confusion,
123 n_classes: g,
124 ncomp: d,
125 },
126 fpca_mean: mean.clone(),
127 fpca_rotation: rotation,
128 fpca_scores: features,
129 ncomp: d,
130 method: ClassifMethod::Lda {
131 class_means,
132 cov_chol: chol,
133 priors,
134 n_classes: g,
135 },
136 fpca_int_weights: int_weights,
137 })
138}
139
140#[must_use = "expensive computation whose result should not be discarded"]
150pub fn fclassif_qda_fit(
151 data: &FdMatrix,
152 y: &[usize],
153 scalar_covariates: Option<&FdMatrix>,
154 ncomp: usize,
155) -> Result<ClassifFit, FdarError> {
156 let n = data.nrows();
157 if n == 0 || y.len() != n {
158 return Err(FdarError::InvalidDimension {
159 parameter: "data/y",
160 expected: "n > 0 and y.len() == n".to_string(),
161 actual: format!("n={}, y.len()={}", n, y.len()),
162 });
163 }
164 if ncomp == 0 {
165 return Err(FdarError::InvalidParameter {
166 parameter: "ncomp",
167 message: "must be > 0".to_string(),
168 });
169 }
170
171 let (labels, g) = remap_labels(y);
172 if g < 2 {
173 return Err(FdarError::InvalidParameter {
174 parameter: "y",
175 message: format!("need at least 2 classes, got {g}"),
176 });
177 }
178
179 let (features, mean, rotation, int_weights) = build_feature_matrix(data, None, ncomp)?;
181 let _ = scalar_covariates;
182 let (class_means, class_chols, class_log_dets, priors) =
183 build_qda_params(&features, &labels, g)?;
184
185 let predicted = qda_predict(
186 &features,
187 &class_means,
188 &class_chols,
189 &class_log_dets,
190 &priors,
191 g,
192 );
193 let accuracy = compute_accuracy(&labels, &predicted);
194 let confusion = confusion_matrix(&labels, &predicted, g);
195 let d = features.ncols();
196
197 Ok(ClassifFit {
198 result: ClassifResult {
199 predicted,
200 probabilities: None,
201 accuracy,
202 confusion,
203 n_classes: g,
204 ncomp: d,
205 },
206 fpca_mean: mean.clone(),
207 fpca_rotation: rotation,
208 fpca_scores: features,
209 ncomp: d,
210 method: ClassifMethod::Qda {
211 class_means,
212 class_chols,
213 class_log_dets,
214 priors,
215 n_classes: g,
216 },
217 fpca_int_weights: int_weights,
218 })
219}
220
221#[must_use = "expensive computation whose result should not be discarded"]
231pub fn fclassif_knn_fit(
232 data: &FdMatrix,
233 y: &[usize],
234 scalar_covariates: Option<&FdMatrix>,
235 ncomp: usize,
236 k_nn: usize,
237) -> Result<ClassifFit, FdarError> {
238 let n = data.nrows();
239 if n == 0 || y.len() != n {
240 return Err(FdarError::InvalidDimension {
241 parameter: "data/y",
242 expected: "n > 0 and y.len() == n".to_string(),
243 actual: format!("n={}, y.len()={}", n, y.len()),
244 });
245 }
246 if ncomp == 0 {
247 return Err(FdarError::InvalidParameter {
248 parameter: "ncomp",
249 message: "must be > 0".to_string(),
250 });
251 }
252 if k_nn == 0 {
253 return Err(FdarError::InvalidParameter {
254 parameter: "k_nn",
255 message: "must be > 0".to_string(),
256 });
257 }
258
259 let (labels, g) = remap_labels(y);
260 if g < 2 {
261 return Err(FdarError::InvalidParameter {
262 parameter: "y",
263 message: format!("need at least 2 classes, got {g}"),
264 });
265 }
266
267 let (features, mean, rotation, int_weights) = build_feature_matrix(data, None, ncomp)?;
269 let _ = scalar_covariates;
270 let d = features.ncols();
271
272 let predicted = knn_predict_loo(&features, &labels, g, d, k_nn);
273 let accuracy = compute_accuracy(&labels, &predicted);
274 let confusion = confusion_matrix(&labels, &predicted, g);
275
276 Ok(ClassifFit {
277 result: ClassifResult {
278 predicted,
279 probabilities: None,
280 accuracy,
281 confusion,
282 n_classes: g,
283 ncomp: d,
284 },
285 fpca_mean: mean.clone(),
286 fpca_rotation: rotation,
287 fpca_scores: features.clone(),
288 ncomp: d,
289 method: ClassifMethod::Knn {
290 training_scores: features,
291 training_labels: labels,
292 k: k_nn,
293 n_classes: g,
294 },
295 fpca_int_weights: int_weights,
296 })
297}
298
impl FpcPredictor for ClassifFit {
    /// FPCA mean stored by the fit.
    fn fpca_mean(&self) -> &[f64] {
        &self.fpca_mean
    }

    /// FPCA rotation matrix stored by the fit.
    fn fpca_rotation(&self) -> &FdMatrix {
        &self.fpca_rotation
    }

    /// Realized number of FPCA components.
    fn ncomp(&self) -> usize {
        self.ncomp
    }

    /// FPCA scores of the training curves.
    fn training_scores(&self) -> &FdMatrix {
        &self.fpca_scores
    }

    /// FPCA integration weights stored by the fit.
    fn fpca_weights(&self) -> &[f64] {
        &self.fpca_int_weights
    }

    /// Binary vs multiclass, derived from the stored class count.
    fn task_type(&self) -> TaskType {
        match &self.method {
            ClassifMethod::Lda { n_classes, .. }
            | ClassifMethod::Qda { n_classes, .. }
            | ClassifMethod::Knn { n_classes, .. } => {
                if *n_classes == 2 {
                    TaskType::BinaryClassification
                } else {
                    TaskType::MulticlassClassification(*n_classes)
                }
            }
        }
    }

    /// Predicts from already-projected FPCA scores.
    ///
    /// Binary (2 classes): returns the probability of class 1 — LDA/QDA
    /// use a softmax over the two discriminant scores, KNN the vote
    /// fraction among the k nearest neighbours. Multiclass: returns the
    /// winning class index cast to `f64`.
    fn predict_from_scores(&self, scores: &[f64], _scalar_covariates: Option<&[f64]>) -> f64 {
        match &self.method {
            ClassifMethod::Lda {
                class_means,
                cov_chol,
                priors,
                n_classes,
            } => {
                let g = *n_classes;
                let d = scores.len();
                if g == 2 {
                    // Discriminant: log prior - 0.5 * Mahalanobis^2, with the
                    // prior floored at 1e-15 so ln() never sees zero.
                    let score0 = priors[0].max(1e-15).ln()
                        - 0.5 * mahalanobis_sq(scores, &class_means[0], cov_chol, d);
                    let score1 = priors[1].max(1e-15).ln()
                        - 0.5 * mahalanobis_sq(scores, &class_means[1], cov_chol, d);
                    // Subtract the max before exp() for numerical stability.
                    let max_s = score0.max(score1);
                    let exp0 = (score0 - max_s).exp();
                    let exp1 = (score1 - max_s).exp();
                    exp1 / (exp0 + exp1)
                } else {
                    // Multiclass: pick the class with the largest discriminant.
                    let mut best_class = 0;
                    let mut best_score = f64::NEG_INFINITY;
                    for c in 0..g {
                        let maha = mahalanobis_sq(scores, &class_means[c], cov_chol, d);
                        let s = priors[c].max(1e-15).ln() - 0.5 * maha;
                        if s > best_score {
                            best_score = s;
                            best_class = c;
                        }
                    }
                    best_class as f64
                }
            }
            ClassifMethod::Qda {
                class_means,
                class_chols,
                class_log_dets,
                priors,
                n_classes,
            } => {
                let g = *n_classes;
                let d = scores.len();
                if g == 2 {
                    // QDA adds the per-class log-determinant term to the
                    // Mahalanobis distance (each class has its own Cholesky).
                    let score0 = priors[0].max(1e-15).ln()
                        - 0.5
                            * (class_log_dets[0]
                                + mahalanobis_sq(scores, &class_means[0], &class_chols[0], d));
                    let score1 = priors[1].max(1e-15).ln()
                        - 0.5
                            * (class_log_dets[1]
                                + mahalanobis_sq(scores, &class_means[1], &class_chols[1], d));
                    // Subtract the max before exp() for numerical stability.
                    let max_s = score0.max(score1);
                    let exp0 = (score0 - max_s).exp();
                    let exp1 = (score1 - max_s).exp();
                    exp1 / (exp0 + exp1)
                } else {
                    // Multiclass: pick the class with the largest discriminant.
                    let mut best_class = 0;
                    let mut best_score = f64::NEG_INFINITY;
                    for c in 0..g {
                        let maha = mahalanobis_sq(scores, &class_means[c], &class_chols[c], d);
                        let s = priors[c].max(1e-15).ln() - 0.5 * (class_log_dets[c] + maha);
                        if s > best_score {
                            best_score = s;
                            best_class = c;
                        }
                    }
                    best_class as f64
                }
            }
            ClassifMethod::Knn {
                training_scores,
                training_labels,
                k,
                n_classes,
            } => {
                let g = *n_classes;
                let n_train = training_scores.nrows();
                let d = scores.len();
                // Never ask for more neighbours than training rows exist.
                let k_nn = (*k).min(n_train);

                // Squared Euclidean distance to every training point.
                let mut dists: Vec<(f64, usize)> = (0..n_train)
                    .map(|j| {
                        let d_sq: f64 = (0..d)
                            .map(|c| (scores[c] - training_scores[(j, c)]).powi(2))
                            .sum();
                        (d_sq, training_labels[j])
                    })
                    .collect();
                // NaN distances compare as Equal so the sort never panics.
                dists.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));

                // Majority vote among the k nearest; out-of-range labels are
                // skipped defensively.
                let mut votes = vec![0usize; g];
                for &(_, label) in dists.iter().take(k_nn) {
                    if label < g {
                        votes[label] += 1;
                    }
                }

                if g == 2 {
                    // Binary: fraction of neighbours voting for class 1.
                    votes[1] as f64 / k_nn as f64
                } else {
                    // Multiclass: index of the most-voted class.
                    votes
                        .iter()
                        .enumerate()
                        .max_by_key(|&(_, &v)| v)
                        .map_or(0.0, |(c, _)| c as f64)
                }
            }
        }
    }
}
453
/// Computes a per-class probability vector for each row of `scores`
/// (already-projected FPCA scores; one observation per row).
///
/// LDA/QDA: softmax over the class discriminant scores. KNN: fraction of
/// the k nearest training points voting for each class.
pub(crate) fn classif_predict_probs(fit: &ClassifFit, scores: &FdMatrix) -> Vec<Vec<f64>> {
    let n = scores.nrows();
    let d = scores.ncols();
    match &fit.method {
        ClassifMethod::Lda {
            class_means,
            cov_chol,
            priors,
            n_classes,
        } => {
            let g = *n_classes;
            (0..n)
                .map(|i| {
                    let x: Vec<f64> = (0..d).map(|j| scores[(i, j)]).collect();
                    // Discriminant per class: log prior (floored at 1e-15 so
                    // ln() never sees zero) minus half the Mahalanobis^2.
                    let disc: Vec<f64> = (0..g)
                        .map(|c| {
                            priors[c].max(1e-15).ln()
                                - 0.5 * mahalanobis_sq(&x, &class_means[c], cov_chol, d)
                        })
                        .collect();
                    softmax(&disc)
                })
                .collect()
        }
        ClassifMethod::Qda {
            class_means,
            class_chols,
            class_log_dets,
            priors,
            n_classes,
        } => {
            let g = *n_classes;
            (0..n)
                .map(|i| {
                    let x: Vec<f64> = (0..d).map(|j| scores[(i, j)]).collect();
                    // QDA discriminant adds the per-class log-determinant term
                    // and uses a per-class Cholesky factor.
                    let disc: Vec<f64> = (0..g)
                        .map(|c| {
                            priors[c].max(1e-15).ln()
                                - 0.5
                                    * (class_log_dets[c]
                                        + mahalanobis_sq(&x, &class_means[c], &class_chols[c], d))
                        })
                        .collect();
                    softmax(&disc)
                })
                .collect()
        }
        ClassifMethod::Knn {
            training_scores,
            training_labels,
            k,
            n_classes,
        } => {
            let g = *n_classes;
            let n_train = training_scores.nrows();
            // Never ask for more neighbours than training rows exist.
            let k_nn = (*k).min(n_train);
            (0..n)
                .map(|i| {
                    let x: Vec<f64> = (0..d).map(|j| scores[(i, j)]).collect();
                    // Squared Euclidean distance to every training point.
                    let mut dists: Vec<(f64, usize)> = (0..n_train)
                        .map(|j| {
                            let d_sq: f64 = (0..d)
                                .map(|c| (x[c] - training_scores[(j, c)]).powi(2))
                                .sum();
                            (d_sq, training_labels[j])
                        })
                        .collect();
                    // NaN distances compare as Equal so the sort never panics.
                    dists
                        .sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
                    // Count votes among the k nearest; out-of-range labels are
                    // skipped defensively.
                    let mut votes = vec![0usize; g];
                    for &(_, label) in dists.iter().take(k_nn) {
                        if label < g {
                            votes[label] += 1;
                        }
                    }
                    // Probability = vote fraction per class.
                    votes.iter().map(|&v| v as f64 / k_nn as f64).collect()
                })
                .collect()
        }
    }
}
543
/// Numerically stable softmax: shifts by the maximum before
/// exponentiating so large inputs cannot overflow.
fn softmax(scores: &[f64]) -> Vec<f64> {
    let peak = scores.iter().copied().fold(f64::NEG_INFINITY, f64::max);
    let mut probs = Vec::with_capacity(scores.len());
    for &s in scores {
        probs.push((s - peak).exp());
    }
    let total: f64 = probs.iter().sum();
    for p in &mut probs {
        *p /= total;
    }
    probs
}
551
/// Configuration bundle for cross-validated functional classification
/// (see `fclassif_cv_with_config`).
#[derive(Debug, Clone, PartialEq)]
pub struct ClassifCvConfig {
    /// Method name forwarded to `fclassif_cv` (default "lda").
    pub method: String,
    /// Number of FPCA components requested.
    pub ncomp: usize,
    /// Number of cross-validation folds.
    pub nfold: usize,
    /// Seed forwarded to `fclassif_cv` (presumably controls fold
    /// assignment — confirm against `fclassif_cv`).
    pub seed: u64,
}
567
568impl Default for ClassifCvConfig {
569 fn default() -> Self {
570 Self {
571 method: "lda".to_string(),
572 ncomp: 3,
573 nfold: 5,
574 seed: 42,
575 }
576 }
577}
578
579#[must_use = "expensive computation whose result should not be discarded"]
588pub fn fclassif_cv_with_config(
589 data: &FdMatrix,
590 argvals: &[f64],
591 y: &[usize],
592 scalar_covariates: Option<&FdMatrix>,
593 config: &ClassifCvConfig,
594) -> Result<ClassifCvResult, FdarError> {
595 fclassif_cv(
596 data,
597 argvals,
598 y,
599 scalar_covariates,
600 &config.method,
601 config.ncomp,
602 config.nfold,
603 config.seed,
604 )
605}