1use crate::error::FdarError;
4use crate::explain_generic::{FpcPredictor, TaskType};
5use crate::matrix::FdMatrix;
6
7use super::knn::knn_predict_loo;
8use super::lda::{lda_params, lda_predict};
9use super::qda::{build_qda_params, qda_predict};
10use super::{
11 build_feature_matrix, compute_accuracy, confusion_matrix, remap_labels, ClassifCvResult,
12 ClassifResult,
13};
14use crate::linalg::{cholesky_d, mahalanobis_sq};
15
16use super::cv::fclassif_cv;
17
/// Fitted per-method classifier parameters, all expressed in the reduced
/// FPCA score space produced by `build_feature_matrix`.
#[derive(Debug, Clone, PartialEq)]
pub enum ClassifMethod {
    /// Linear discriminant analysis: all classes share one pooled covariance.
    Lda {
        /// Per-class mean score vector (one entry per class).
        class_means: Vec<Vec<f64>>,
        /// Cholesky factor of the pooled covariance, flattened to a `Vec`
        /// (layout defined by `cholesky_d` / consumed by `mahalanobis_sq` —
        /// presumably row-major lower-triangular; confirm in `linalg`).
        cov_chol: Vec<f64>,
        /// Class prior probabilities (indexed by remapped class id).
        priors: Vec<f64>,
        /// Number of classes `g` (class ids are `0..g` after `remap_labels`).
        n_classes: usize,
    },
    /// Quadratic discriminant analysis: each class keeps its own covariance.
    Qda {
        /// Per-class mean score vector.
        class_means: Vec<Vec<f64>>,
        /// Per-class flattened Cholesky factors (same layout as `Lda::cov_chol`).
        class_chols: Vec<Vec<f64>>,
        /// Per-class log-determinants of the covariance, used in the
        /// discriminant score.
        class_log_dets: Vec<f64>,
        /// Class prior probabilities.
        priors: Vec<f64>,
        /// Number of classes.
        n_classes: usize,
    },
    /// k-nearest-neighbors in score space; prediction needs the full
    /// training sample, so it is stored here.
    Knn {
        /// FPCA scores of the training curves (rows = observations).
        training_scores: FdMatrix,
        /// Remapped training labels in `0..n_classes`.
        training_labels: Vec<usize>,
        /// Neighbor count requested at fit time (capped at the training size
        /// when predicting).
        k: usize,
        /// Number of classes.
        n_classes: usize,
    },
}
44
/// Result of fitting a functional classifier: in-sample predictions plus
/// everything needed to project and classify new curves.
#[derive(Debug, Clone, PartialEq)]
pub struct ClassifFit {
    /// In-sample prediction summary (predicted labels, accuracy, confusion
    /// matrix, ...).
    pub result: ClassifResult,
    /// FPCA mean function, subtracted from a curve before projection
    /// (exposed via `FpcPredictor::fpca_mean`).
    pub fpca_mean: Vec<f64>,
    /// FPCA rotation (loadings) matrix used to project curves to scores.
    pub fpca_rotation: FdMatrix,
    /// Training-sample scores in the FPCA basis (rows = observations).
    pub fpca_scores: FdMatrix,
    /// Number of retained components, i.e. the column count of `fpca_scores`.
    pub ncomp: usize,
    /// Method-specific fitted parameters (LDA / QDA / KNN).
    pub method: ClassifMethod,
}
61
62#[must_use = "expensive computation whose result should not be discarded"]
72pub fn fclassif_lda_fit(
73 data: &FdMatrix,
74 y: &[usize],
75 scalar_covariates: Option<&FdMatrix>,
76 ncomp: usize,
77) -> Result<ClassifFit, FdarError> {
78 let n = data.nrows();
79 if n == 0 || y.len() != n {
80 return Err(FdarError::InvalidDimension {
81 parameter: "data/y",
82 expected: "n > 0 and y.len() == n".to_string(),
83 actual: format!("n={}, y.len()={}", n, y.len()),
84 });
85 }
86 if ncomp == 0 {
87 return Err(FdarError::InvalidParameter {
88 parameter: "ncomp",
89 message: "must be > 0".to_string(),
90 });
91 }
92
93 let (labels, g) = remap_labels(y);
94 if g < 2 {
95 return Err(FdarError::InvalidParameter {
96 parameter: "y",
97 message: format!("need at least 2 classes, got {g}"),
98 });
99 }
100
101 let (features, mean, rotation) = build_feature_matrix(data, None, ncomp)?;
104 let _ = scalar_covariates; let d = features.ncols();
106 let (class_means, cov, priors) = lda_params(&features, &labels, g);
107 let chol = cholesky_d(&cov, d)?;
108
109 let predicted = lda_predict(&features, &class_means, &chol, &priors, g);
110 let accuracy = compute_accuracy(&labels, &predicted);
111 let confusion = confusion_matrix(&labels, &predicted, g);
112
113 Ok(ClassifFit {
114 result: ClassifResult {
115 predicted,
116 probabilities: None,
117 accuracy,
118 confusion,
119 n_classes: g,
120 ncomp: d,
121 },
122 fpca_mean: mean.clone(),
123 fpca_rotation: rotation,
124 fpca_scores: features,
125 ncomp: d,
126 method: ClassifMethod::Lda {
127 class_means,
128 cov_chol: chol,
129 priors,
130 n_classes: g,
131 },
132 })
133}
134
135#[must_use = "expensive computation whose result should not be discarded"]
145pub fn fclassif_qda_fit(
146 data: &FdMatrix,
147 y: &[usize],
148 scalar_covariates: Option<&FdMatrix>,
149 ncomp: usize,
150) -> Result<ClassifFit, FdarError> {
151 let n = data.nrows();
152 if n == 0 || y.len() != n {
153 return Err(FdarError::InvalidDimension {
154 parameter: "data/y",
155 expected: "n > 0 and y.len() == n".to_string(),
156 actual: format!("n={}, y.len()={}", n, y.len()),
157 });
158 }
159 if ncomp == 0 {
160 return Err(FdarError::InvalidParameter {
161 parameter: "ncomp",
162 message: "must be > 0".to_string(),
163 });
164 }
165
166 let (labels, g) = remap_labels(y);
167 if g < 2 {
168 return Err(FdarError::InvalidParameter {
169 parameter: "y",
170 message: format!("need at least 2 classes, got {g}"),
171 });
172 }
173
174 let (features, mean, rotation) = build_feature_matrix(data, None, ncomp)?;
176 let _ = scalar_covariates;
177 let (class_means, class_chols, class_log_dets, priors) =
178 build_qda_params(&features, &labels, g)?;
179
180 let predicted = qda_predict(
181 &features,
182 &class_means,
183 &class_chols,
184 &class_log_dets,
185 &priors,
186 g,
187 );
188 let accuracy = compute_accuracy(&labels, &predicted);
189 let confusion = confusion_matrix(&labels, &predicted, g);
190 let d = features.ncols();
191
192 Ok(ClassifFit {
193 result: ClassifResult {
194 predicted,
195 probabilities: None,
196 accuracy,
197 confusion,
198 n_classes: g,
199 ncomp: d,
200 },
201 fpca_mean: mean.clone(),
202 fpca_rotation: rotation,
203 fpca_scores: features,
204 ncomp: d,
205 method: ClassifMethod::Qda {
206 class_means,
207 class_chols,
208 class_log_dets,
209 priors,
210 n_classes: g,
211 },
212 })
213}
214
215#[must_use = "expensive computation whose result should not be discarded"]
225pub fn fclassif_knn_fit(
226 data: &FdMatrix,
227 y: &[usize],
228 scalar_covariates: Option<&FdMatrix>,
229 ncomp: usize,
230 k_nn: usize,
231) -> Result<ClassifFit, FdarError> {
232 let n = data.nrows();
233 if n == 0 || y.len() != n {
234 return Err(FdarError::InvalidDimension {
235 parameter: "data/y",
236 expected: "n > 0 and y.len() == n".to_string(),
237 actual: format!("n={}, y.len()={}", n, y.len()),
238 });
239 }
240 if ncomp == 0 {
241 return Err(FdarError::InvalidParameter {
242 parameter: "ncomp",
243 message: "must be > 0".to_string(),
244 });
245 }
246 if k_nn == 0 {
247 return Err(FdarError::InvalidParameter {
248 parameter: "k_nn",
249 message: "must be > 0".to_string(),
250 });
251 }
252
253 let (labels, g) = remap_labels(y);
254 if g < 2 {
255 return Err(FdarError::InvalidParameter {
256 parameter: "y",
257 message: format!("need at least 2 classes, got {g}"),
258 });
259 }
260
261 let (features, mean, rotation) = build_feature_matrix(data, None, ncomp)?;
263 let _ = scalar_covariates;
264 let d = features.ncols();
265
266 let predicted = knn_predict_loo(&features, &labels, g, d, k_nn);
267 let accuracy = compute_accuracy(&labels, &predicted);
268 let confusion = confusion_matrix(&labels, &predicted, g);
269
270 Ok(ClassifFit {
271 result: ClassifResult {
272 predicted,
273 probabilities: None,
274 accuracy,
275 confusion,
276 n_classes: g,
277 ncomp: d,
278 },
279 fpca_mean: mean.clone(),
280 fpca_rotation: rotation,
281 fpca_scores: features.clone(),
282 ncomp: d,
283 method: ClassifMethod::Knn {
284 training_scores: features,
285 training_labels: labels,
286 k: k_nn,
287 n_classes: g,
288 },
289 })
290}
291
// Hooks ClassifFit into the generic FPC-based prediction machinery
// (`FpcPredictor`): exposes the FPCA projection pieces and performs
// per-observation prediction from raw score vectors.
impl FpcPredictor for ClassifFit {
    fn fpca_mean(&self) -> &[f64] {
        &self.fpca_mean
    }

    fn fpca_rotation(&self) -> &FdMatrix {
        &self.fpca_rotation
    }

    fn ncomp(&self) -> usize {
        self.ncomp
    }

    fn training_scores(&self) -> &FdMatrix {
        &self.fpca_scores
    }

    /// Binary vs. multiclass is decided purely by the fitted class count;
    /// all three method variants carry `n_classes`.
    fn task_type(&self) -> TaskType {
        match &self.method {
            ClassifMethod::Lda { n_classes, .. }
            | ClassifMethod::Qda { n_classes, .. }
            | ClassifMethod::Knn { n_classes, .. } => {
                if *n_classes == 2 {
                    TaskType::BinaryClassification
                } else {
                    TaskType::MulticlassClassification(*n_classes)
                }
            }
        }
    }

    /// Classify a single observation given its FPCA score vector.
    ///
    /// Return-value convention:
    /// - binary (`n_classes == 2`): the posterior probability of class 1
    ///   (for KNN, the fraction of the k nearest neighbors voting class 1);
    /// - multiclass: the winning class index cast to `f64`.
    ///
    /// `_scalar_covariates` is ignored -- the fitted methods use scores only.
    fn predict_from_scores(&self, scores: &[f64], _scalar_covariates: Option<&[f64]>) -> f64 {
        match &self.method {
            ClassifMethod::Lda {
                class_means,
                cov_chol,
                priors,
                n_classes,
            } => {
                let g = *n_classes;
                // Dimension is taken from the input; assumes it matches the
                // fitted `class_means`/`cov_chol` dimension -- TODO confirm
                // callers guarantee this.
                let d = scores.len();
                if g == 2 {
                    // Discriminant: log prior - 0.5 * Mahalanobis^2 under the
                    // pooled covariance. Priors are clamped at 1e-15 so empty
                    // classes don't produce ln(0) = -inf.
                    let score0 = priors[0].max(1e-15).ln()
                        - 0.5 * mahalanobis_sq(scores, &class_means[0], cov_chol, d);
                    let score1 = priors[1].max(1e-15).ln()
                        - 0.5 * mahalanobis_sq(scores, &class_means[1], cov_chol, d);
                    // Two-class softmax with max-shift for numerical stability;
                    // yields P(class 1).
                    let max_s = score0.max(score1);
                    let exp0 = (score0 - max_s).exp();
                    let exp1 = (score1 - max_s).exp();
                    exp1 / (exp0 + exp1)
                } else {
                    // Argmax over discriminant scores; first class wins ties
                    // because of the strict `>` comparison.
                    let mut best_class = 0;
                    let mut best_score = f64::NEG_INFINITY;
                    for c in 0..g {
                        let maha = mahalanobis_sq(scores, &class_means[c], cov_chol, d);
                        let s = priors[c].max(1e-15).ln() - 0.5 * maha;
                        if s > best_score {
                            best_score = s;
                            best_class = c;
                        }
                    }
                    best_class as f64
                }
            }
            ClassifMethod::Qda {
                class_means,
                class_chols,
                class_log_dets,
                priors,
                n_classes,
            } => {
                let g = *n_classes;
                let d = scores.len();
                if g == 2 {
                    // QDA discriminant adds the per-class log-determinant to the
                    // Mahalanobis term (each class has its own covariance).
                    let score0 = priors[0].max(1e-15).ln()
                        - 0.5
                            * (class_log_dets[0]
                                + mahalanobis_sq(scores, &class_means[0], &class_chols[0], d));
                    let score1 = priors[1].max(1e-15).ln()
                        - 0.5
                            * (class_log_dets[1]
                                + mahalanobis_sq(scores, &class_means[1], &class_chols[1], d));
                    // Stable two-class softmax; yields P(class 1).
                    let max_s = score0.max(score1);
                    let exp0 = (score0 - max_s).exp();
                    let exp1 = (score1 - max_s).exp();
                    exp1 / (exp0 + exp1)
                } else {
                    // Argmax over per-class discriminants; ties go to the
                    // lowest class index.
                    let mut best_class = 0;
                    let mut best_score = f64::NEG_INFINITY;
                    for c in 0..g {
                        let maha = mahalanobis_sq(scores, &class_means[c], &class_chols[c], d);
                        let s = priors[c].max(1e-15).ln() - 0.5 * (class_log_dets[c] + maha);
                        if s > best_score {
                            best_score = s;
                            best_class = c;
                        }
                    }
                    best_class as f64
                }
            }
            ClassifMethod::Knn {
                training_scores,
                training_labels,
                k,
                n_classes,
            } => {
                let g = *n_classes;
                let n_train = training_scores.nrows();
                let d = scores.len();
                // Cap k at the training-set size. NOTE(review): if the fit
                // ever held zero training rows, k_nn would be 0 and the
                // binary branch below would divide by zero; the fit
                // constructors reject n == 0, so this shouldn't occur.
                let k_nn = (*k).min(n_train);

                // Squared Euclidean distance to every training observation.
                let mut dists: Vec<(f64, usize)> = (0..n_train)
                    .map(|j| {
                        let d_sq: f64 = (0..d)
                            .map(|c| (scores[c] - training_scores[(j, c)]).powi(2))
                            .sum();
                        (d_sq, training_labels[j])
                    })
                    .collect();
                // Stable sort: equal distances keep training order, which
                // fixes tie-breaking deterministically. NaN distances compare
                // as Equal and so don't panic.
                dists.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));

                // Vote among the k nearest; labels >= g are skipped defensively.
                let mut votes = vec![0usize; g];
                for &(_, label) in dists.iter().take(k_nn) {
                    if label < g {
                        votes[label] += 1;
                    }
                }

                if g == 2 {
                    // Fraction of neighbors voting class 1 = P(class 1).
                    votes[1] as f64 / k_nn as f64
                } else {
                    // Most-voted class index; `max_by_key` returns the LAST
                    // maximum, so vote ties go to the highest class index.
                    votes
                        .iter()
                        .enumerate()
                        .max_by_key(|&(_, &v)| v)
                        .map_or(0.0, |(c, _)| c as f64)
                }
            }
        }
    }
}
442
443pub(crate) fn classif_predict_probs(fit: &ClassifFit, scores: &FdMatrix) -> Vec<Vec<f64>> {
452 let n = scores.nrows();
453 let d = scores.ncols();
454 match &fit.method {
455 ClassifMethod::Lda {
456 class_means,
457 cov_chol,
458 priors,
459 n_classes,
460 } => {
461 let g = *n_classes;
462 (0..n)
463 .map(|i| {
464 let x: Vec<f64> = (0..d).map(|j| scores[(i, j)]).collect();
465 let disc: Vec<f64> = (0..g)
466 .map(|c| {
467 priors[c].max(1e-15).ln()
468 - 0.5 * mahalanobis_sq(&x, &class_means[c], cov_chol, d)
469 })
470 .collect();
471 softmax(&disc)
472 })
473 .collect()
474 }
475 ClassifMethod::Qda {
476 class_means,
477 class_chols,
478 class_log_dets,
479 priors,
480 n_classes,
481 } => {
482 let g = *n_classes;
483 (0..n)
484 .map(|i| {
485 let x: Vec<f64> = (0..d).map(|j| scores[(i, j)]).collect();
486 let disc: Vec<f64> = (0..g)
487 .map(|c| {
488 priors[c].max(1e-15).ln()
489 - 0.5
490 * (class_log_dets[c]
491 + mahalanobis_sq(&x, &class_means[c], &class_chols[c], d))
492 })
493 .collect();
494 softmax(&disc)
495 })
496 .collect()
497 }
498 ClassifMethod::Knn {
499 training_scores,
500 training_labels,
501 k,
502 n_classes,
503 } => {
504 let g = *n_classes;
505 let n_train = training_scores.nrows();
506 let k_nn = (*k).min(n_train);
507 (0..n)
508 .map(|i| {
509 let x: Vec<f64> = (0..d).map(|j| scores[(i, j)]).collect();
510 let mut dists: Vec<(f64, usize)> = (0..n_train)
511 .map(|j| {
512 let d_sq: f64 = (0..d)
513 .map(|c| (x[c] - training_scores[(j, c)]).powi(2))
514 .sum();
515 (d_sq, training_labels[j])
516 })
517 .collect();
518 dists
519 .sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
520 let mut votes = vec![0usize; g];
521 for &(_, label) in dists.iter().take(k_nn) {
522 if label < g {
523 votes[label] += 1;
524 }
525 }
526 votes.iter().map(|&v| v as f64 / k_nn as f64).collect()
527 })
528 .collect()
529 }
530 }
531}
532
/// Numerically stable softmax: shift by the maximum before exponentiating
/// so large inputs cannot overflow. Returns an empty vector for empty input.
fn softmax(scores: &[f64]) -> Vec<f64> {
    let mut peak = f64::NEG_INFINITY;
    for &s in scores {
        peak = peak.max(s);
    }
    let mut out = Vec::with_capacity(scores.len());
    let mut total = 0.0;
    for &s in scores {
        let e = (s - peak).exp();
        total += e;
        out.push(e);
    }
    for v in &mut out {
        *v /= total;
    }
    out
}
540
/// Configuration bundle for cross-validated classifier selection
/// (consumed by [`fclassif_cv_with_config`]).
#[derive(Debug, Clone, PartialEq)]
pub struct ClassifCvConfig {
    /// Classifier name forwarded to `fclassif_cv` (e.g. "lda"; the accepted
    /// set is defined by `fclassif_cv` -- see `cv` module).
    pub method: String,
    /// Number of FPCA components to retain.
    pub ncomp: usize,
    /// Number of cross-validation folds.
    pub nfold: usize,
    /// RNG seed for fold assignment, making CV splits reproducible.
    pub seed: u64,
}
556
557impl Default for ClassifCvConfig {
558 fn default() -> Self {
559 Self {
560 method: "lda".to_string(),
561 ncomp: 3,
562 nfold: 5,
563 seed: 42,
564 }
565 }
566}
567
568#[must_use = "expensive computation whose result should not be discarded"]
577pub fn fclassif_cv_with_config(
578 data: &FdMatrix,
579 argvals: &[f64],
580 y: &[usize],
581 scalar_covariates: Option<&FdMatrix>,
582 config: &ClassifCvConfig,
583) -> Result<ClassifCvResult, FdarError> {
584 fclassif_cv(
585 data,
586 argvals,
587 y,
588 scalar_covariates,
589 &config.method,
590 config.ncomp,
591 config.nfold,
592 config.seed,
593 )
594}