// fdars_core/classification/fit.rs
1//! ClassifFit: fitted classification model with explainability support.
2
3use crate::error::FdarError;
4use crate::explain_generic::{FpcPredictor, TaskType};
5use crate::matrix::FdMatrix;
6
7use super::knn::knn_predict_loo;
8use super::lda::{lda_params, lda_predict};
9use super::qda::{build_qda_params, qda_predict};
10use super::{
11    build_feature_matrix, compute_accuracy, confusion_matrix, remap_labels, ClassifCvResult,
12    ClassifResult,
13};
14use crate::linalg::{cholesky_d, mahalanobis_sq};
15
16use super::cv::fclassif_cv;
17
/// Classification method with stored parameters for prediction.
#[derive(Debug, Clone, PartialEq)]
pub enum ClassifMethod {
    /// Linear Discriminant Analysis.
    Lda {
        /// Per-class mean vectors in FPC-score space (g vectors of length d).
        class_means: Vec<Vec<f64>>,
        /// Cholesky factor of the pooled covariance, as returned by `cholesky_d`.
        cov_chol: Vec<f64>,
        /// Class prior probabilities (length g).
        priors: Vec<f64>,
        /// Number of classes g.
        n_classes: usize,
    },
    /// Quadratic Discriminant Analysis.
    Qda {
        /// Per-class mean vectors in FPC-score space (g vectors of length d).
        class_means: Vec<Vec<f64>>,
        /// Per-class covariance Cholesky factors (one per class).
        class_chols: Vec<Vec<f64>>,
        /// Per-class log-determinant terms added to each quadratic discriminant.
        class_log_dets: Vec<f64>,
        /// Class prior probabilities (length g).
        priors: Vec<f64>,
        /// Number of classes g.
        n_classes: usize,
    },
    /// k-Nearest Neighbors.
    Knn {
        /// FPC scores of the training observations (n × d).
        training_scores: FdMatrix,
        /// Remapped training labels in `0..n_classes`.
        training_labels: Vec<usize>,
        /// Number of neighbors to vote (clamped to n at prediction time).
        k: usize,
        /// Number of classes g.
        n_classes: usize,
    },
}
44
/// A fitted classification model that retains FPCA components for explainability.
///
/// Implements [`FpcPredictor`] so that generic explainability tooling can
/// project new curves onto the stored FPCA basis and predict from scores.
#[derive(Debug, Clone, PartialEq)]
pub struct ClassifFit {
    /// Classification result (predicted labels, accuracy, confusion matrix).
    pub result: ClassifResult,
    /// FPCA mean function (length m).
    pub fpca_mean: Vec<f64>,
    /// FPCA rotation matrix (m × ncomp).
    pub fpca_rotation: FdMatrix,
    /// FPCA scores of the training data (n × ncomp).
    pub fpca_scores: FdMatrix,
    /// Number of FPC components used.
    pub ncomp: usize,
    /// The classification method with stored parameters.
    pub method: ClassifMethod,
}
61
62/// FPC + LDA classification, retaining FPCA and LDA parameters for explainability.
63///
64/// # Errors
65///
66/// Returns [`FdarError::InvalidDimension`] if `data` has zero rows or `y.len() != n`.
67/// Returns [`FdarError::InvalidParameter`] if `ncomp` is zero.
68/// Returns [`FdarError::InvalidParameter`] if `y` contains fewer than 2 distinct classes.
69/// Returns [`FdarError::ComputationFailed`] if the SVD decomposition in FPCA fails.
70/// Returns [`FdarError::ComputationFailed`] if the pooled covariance Cholesky factorization fails.
71#[must_use = "expensive computation whose result should not be discarded"]
72pub fn fclassif_lda_fit(
73    data: &FdMatrix,
74    y: &[usize],
75    scalar_covariates: Option<&FdMatrix>,
76    ncomp: usize,
77) -> Result<ClassifFit, FdarError> {
78    let n = data.nrows();
79    if n == 0 || y.len() != n {
80        return Err(FdarError::InvalidDimension {
81            parameter: "data/y",
82            expected: "n > 0 and y.len() == n".to_string(),
83            actual: format!("n={}, y.len()={}", n, y.len()),
84        });
85    }
86    if ncomp == 0 {
87        return Err(FdarError::InvalidParameter {
88            parameter: "ncomp",
89            message: "must be > 0".to_string(),
90        });
91    }
92
93    let (labels, g) = remap_labels(y);
94    if g < 2 {
95        return Err(FdarError::InvalidParameter {
96            parameter: "y",
97            message: format!("need at least 2 classes, got {g}"),
98        });
99    }
100
101    // _fit variants use FPCA-only features (no scalar_covariates) so that stored
102    // dimensions are consistent with FpcPredictor::project() / predict_from_scores().
103    let (features, mean, rotation) = build_feature_matrix(data, None, ncomp)?;
104    let _ = scalar_covariates; // acknowledged but not used — see docstring
105    let d = features.ncols();
106    let (class_means, cov, priors) = lda_params(&features, &labels, g);
107    let chol = cholesky_d(&cov, d)?;
108
109    let predicted = lda_predict(&features, &class_means, &chol, &priors, g);
110    let accuracy = compute_accuracy(&labels, &predicted);
111    let confusion = confusion_matrix(&labels, &predicted, g);
112
113    Ok(ClassifFit {
114        result: ClassifResult {
115            predicted,
116            probabilities: None,
117            accuracy,
118            confusion,
119            n_classes: g,
120            ncomp: d,
121        },
122        fpca_mean: mean.clone(),
123        fpca_rotation: rotation,
124        fpca_scores: features,
125        ncomp: d,
126        method: ClassifMethod::Lda {
127            class_means,
128            cov_chol: chol,
129            priors,
130            n_classes: g,
131        },
132    })
133}
134
135/// FPC + QDA classification, retaining FPCA and QDA parameters for explainability.
136///
137/// # Errors
138///
139/// Returns [`FdarError::InvalidDimension`] if `data` has zero rows or `y.len() != n`.
140/// Returns [`FdarError::InvalidParameter`] if `ncomp` is zero.
141/// Returns [`FdarError::InvalidParameter`] if `y` contains fewer than 2 distinct classes.
142/// Returns [`FdarError::ComputationFailed`] if the SVD decomposition in FPCA fails.
143/// Returns [`FdarError::ComputationFailed`] if a per-class covariance Cholesky factorization fails.
144#[must_use = "expensive computation whose result should not be discarded"]
145pub fn fclassif_qda_fit(
146    data: &FdMatrix,
147    y: &[usize],
148    scalar_covariates: Option<&FdMatrix>,
149    ncomp: usize,
150) -> Result<ClassifFit, FdarError> {
151    let n = data.nrows();
152    if n == 0 || y.len() != n {
153        return Err(FdarError::InvalidDimension {
154            parameter: "data/y",
155            expected: "n > 0 and y.len() == n".to_string(),
156            actual: format!("n={}, y.len()={}", n, y.len()),
157        });
158    }
159    if ncomp == 0 {
160        return Err(FdarError::InvalidParameter {
161            parameter: "ncomp",
162            message: "must be > 0".to_string(),
163        });
164    }
165
166    let (labels, g) = remap_labels(y);
167    if g < 2 {
168        return Err(FdarError::InvalidParameter {
169            parameter: "y",
170            message: format!("need at least 2 classes, got {g}"),
171        });
172    }
173
174    // _fit variants use FPCA-only features — see fclassif_lda_fit comment.
175    let (features, mean, rotation) = build_feature_matrix(data, None, ncomp)?;
176    let _ = scalar_covariates;
177    let (class_means, class_chols, class_log_dets, priors) =
178        build_qda_params(&features, &labels, g)?;
179
180    let predicted = qda_predict(
181        &features,
182        &class_means,
183        &class_chols,
184        &class_log_dets,
185        &priors,
186        g,
187    );
188    let accuracy = compute_accuracy(&labels, &predicted);
189    let confusion = confusion_matrix(&labels, &predicted, g);
190    let d = features.ncols();
191
192    Ok(ClassifFit {
193        result: ClassifResult {
194            predicted,
195            probabilities: None,
196            accuracy,
197            confusion,
198            n_classes: g,
199            ncomp: d,
200        },
201        fpca_mean: mean.clone(),
202        fpca_rotation: rotation,
203        fpca_scores: features,
204        ncomp: d,
205        method: ClassifMethod::Qda {
206            class_means,
207            class_chols,
208            class_log_dets,
209            priors,
210            n_classes: g,
211        },
212    })
213}
214
215/// FPC + k-NN classification, retaining FPCA and training data for explainability.
216///
217/// # Errors
218///
219/// Returns [`FdarError::InvalidDimension`] if `data` has zero rows or `y.len() != n`.
220/// Returns [`FdarError::InvalidParameter`] if `ncomp` is zero.
221/// Returns [`FdarError::InvalidParameter`] if `k_nn` is zero.
222/// Returns [`FdarError::InvalidParameter`] if `y` contains fewer than 2 distinct classes.
223/// Returns [`FdarError::ComputationFailed`] if the SVD decomposition in FPCA fails.
224#[must_use = "expensive computation whose result should not be discarded"]
225pub fn fclassif_knn_fit(
226    data: &FdMatrix,
227    y: &[usize],
228    scalar_covariates: Option<&FdMatrix>,
229    ncomp: usize,
230    k_nn: usize,
231) -> Result<ClassifFit, FdarError> {
232    let n = data.nrows();
233    if n == 0 || y.len() != n {
234        return Err(FdarError::InvalidDimension {
235            parameter: "data/y",
236            expected: "n > 0 and y.len() == n".to_string(),
237            actual: format!("n={}, y.len()={}", n, y.len()),
238        });
239    }
240    if ncomp == 0 {
241        return Err(FdarError::InvalidParameter {
242            parameter: "ncomp",
243            message: "must be > 0".to_string(),
244        });
245    }
246    if k_nn == 0 {
247        return Err(FdarError::InvalidParameter {
248            parameter: "k_nn",
249            message: "must be > 0".to_string(),
250        });
251    }
252
253    let (labels, g) = remap_labels(y);
254    if g < 2 {
255        return Err(FdarError::InvalidParameter {
256            parameter: "y",
257            message: format!("need at least 2 classes, got {g}"),
258        });
259    }
260
261    // _fit variants use FPCA-only features — see fclassif_lda_fit comment.
262    let (features, mean, rotation) = build_feature_matrix(data, None, ncomp)?;
263    let _ = scalar_covariates;
264    let d = features.ncols();
265
266    let predicted = knn_predict_loo(&features, &labels, g, d, k_nn);
267    let accuracy = compute_accuracy(&labels, &predicted);
268    let confusion = confusion_matrix(&labels, &predicted, g);
269
270    Ok(ClassifFit {
271        result: ClassifResult {
272            predicted,
273            probabilities: None,
274            accuracy,
275            confusion,
276            n_classes: g,
277            ncomp: d,
278        },
279        fpca_mean: mean.clone(),
280        fpca_rotation: rotation,
281        fpca_scores: features.clone(),
282        ncomp: d,
283        method: ClassifMethod::Knn {
284            training_scores: features,
285            training_labels: labels,
286            k: k_nn,
287            n_classes: g,
288        },
289    })
290}
291
292// ---------------------------------------------------------------------------
293// FpcPredictor impl for ClassifFit
294// ---------------------------------------------------------------------------
295
impl FpcPredictor for ClassifFit {
    /// FPCA mean function (length m), subtracted before projection.
    fn fpca_mean(&self) -> &[f64] {
        &self.fpca_mean
    }

    /// FPCA rotation matrix (m × ncomp) used to project curves onto FPCs.
    fn fpca_rotation(&self) -> &FdMatrix {
        &self.fpca_rotation
    }

    /// Number of FPC components the model was fitted with.
    fn ncomp(&self) -> usize {
        self.ncomp
    }

    /// FPC scores of the training observations (n × ncomp).
    fn training_scores(&self) -> &FdMatrix {
        &self.fpca_scores
    }

    /// Binary vs. multiclass, derived from the stored method's `n_classes`.
    fn task_type(&self) -> TaskType {
        match &self.method {
            ClassifMethod::Lda { n_classes, .. }
            | ClassifMethod::Qda { n_classes, .. }
            | ClassifMethod::Knn { n_classes, .. } => {
                if *n_classes == 2 {
                    TaskType::BinaryClassification
                } else {
                    TaskType::MulticlassClassification(*n_classes)
                }
            }
        }
    }

    /// Predict from a single FPC-score vector.
    ///
    /// Return convention: for binary models (g == 2) the value is P(Y = 1);
    /// for multiclass models it is the predicted class index cast to `f64`.
    /// `_scalar_covariates` is ignored — the `_fit` constructors build
    /// FPCA-only feature matrices, so no covariate columns are stored.
    fn predict_from_scores(&self, scores: &[f64], _scalar_covariates: Option<&[f64]>) -> f64 {
        match &self.method {
            ClassifMethod::Lda {
                class_means,
                cov_chol,
                priors,
                n_classes,
            } => {
                let g = *n_classes;
                let d = scores.len();
                if g == 2 {
                    // Return P(Y=1) via softmax of discriminant scores.
                    // Discriminant: log prior − ½·Mahalanobis² (shared covariance).
                    // Priors are floored at 1e-15 to avoid ln(0).
                    let score0 = priors[0].max(1e-15).ln()
                        - 0.5 * mahalanobis_sq(scores, &class_means[0], cov_chol, d);
                    let score1 = priors[1].max(1e-15).ln()
                        - 0.5 * mahalanobis_sq(scores, &class_means[1], cov_chol, d);
                    // Subtract the max before exponentiating for numerical stability.
                    let max_s = score0.max(score1);
                    let exp0 = (score0 - max_s).exp();
                    let exp1 = (score1 - max_s).exp();
                    exp1 / (exp0 + exp1)
                } else {
                    // Return predicted class (argmax of discriminant scores) as f64.
                    let mut best_class = 0;
                    let mut best_score = f64::NEG_INFINITY;
                    for c in 0..g {
                        let maha = mahalanobis_sq(scores, &class_means[c], cov_chol, d);
                        let s = priors[c].max(1e-15).ln() - 0.5 * maha;
                        if s > best_score {
                            best_score = s;
                            best_class = c;
                        }
                    }
                    best_class as f64
                }
            }
            ClassifMethod::Qda {
                class_means,
                class_chols,
                class_log_dets,
                priors,
                n_classes,
            } => {
                let g = *n_classes;
                let d = scores.len();
                if g == 2 {
                    // Return P(Y=1) via softmax of discriminant scores.
                    // QDA adds the per-class log-determinant term to the
                    // Mahalanobis distance (per-class covariances).
                    let score0 = priors[0].max(1e-15).ln()
                        - 0.5
                            * (class_log_dets[0]
                                + mahalanobis_sq(scores, &class_means[0], &class_chols[0], d));
                    let score1 = priors[1].max(1e-15).ln()
                        - 0.5
                            * (class_log_dets[1]
                                + mahalanobis_sq(scores, &class_means[1], &class_chols[1], d));
                    let max_s = score0.max(score1);
                    let exp0 = (score0 - max_s).exp();
                    let exp1 = (score1 - max_s).exp();
                    exp1 / (exp0 + exp1)
                } else {
                    // Argmax of the quadratic discriminant scores.
                    let mut best_class = 0;
                    let mut best_score = f64::NEG_INFINITY;
                    for c in 0..g {
                        let maha = mahalanobis_sq(scores, &class_means[c], &class_chols[c], d);
                        let s = priors[c].max(1e-15).ln() - 0.5 * (class_log_dets[c] + maha);
                        if s > best_score {
                            best_score = s;
                            best_class = c;
                        }
                    }
                    best_class as f64
                }
            }
            ClassifMethod::Knn {
                training_scores,
                training_labels,
                k,
                n_classes,
            } => {
                let g = *n_classes;
                let n_train = training_scores.nrows();
                let d = scores.len();
                // Clamp k so we never ask for more neighbors than training rows.
                let k_nn = (*k).min(n_train);

                // Squared Euclidean distance in FPC-score space to every
                // training observation (no sqrt needed — ordering is the same).
                let mut dists: Vec<(f64, usize)> = (0..n_train)
                    .map(|j| {
                        let d_sq: f64 = (0..d)
                            .map(|c| (scores[c] - training_scores[(j, c)]).powi(2))
                            .sum();
                        (d_sq, training_labels[j])
                    })
                    .collect();
                // NaN distances compare as Equal, keeping the sort total.
                dists.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));

                // Majority vote among the k nearest neighbors.
                let mut votes = vec![0usize; g];
                for &(_, label) in dists.iter().take(k_nn) {
                    if label < g {
                        votes[label] += 1;
                    }
                }

                if g == 2 {
                    // Return proportion voting for class 1 as probability
                    votes[1] as f64 / k_nn as f64
                } else {
                    // Return majority vote class as f64
                    votes
                        .iter()
                        .enumerate()
                        .max_by_key(|&(_, &v)| v)
                        .map_or(0.0, |(c, _)| c as f64)
                }
            }
        }
    }
}
442
443// ---------------------------------------------------------------------------
444// Class probability vectors (for conformal prediction sets)
445// ---------------------------------------------------------------------------
446
/// Compute full class probability vectors for each observation.
///
/// Returns `n × g` probability vectors suitable for conformal classification.
/// For each observation, the probabilities sum to 1.
///
/// LDA/QDA probabilities are the softmax of the discriminant scores;
/// k-NN probabilities are the vote fractions among the k nearest neighbors.
pub(crate) fn classif_predict_probs(fit: &ClassifFit, scores: &FdMatrix) -> Vec<Vec<f64>> {
    let n = scores.nrows();
    let d = scores.ncols();
    match &fit.method {
        ClassifMethod::Lda {
            class_means,
            cov_chol,
            priors,
            n_classes,
        } => {
            let g = *n_classes;
            (0..n)
                .map(|i| {
                    // Row i of the score matrix as a contiguous vector.
                    let x: Vec<f64> = (0..d).map(|j| scores[(i, j)]).collect();
                    // Linear discriminant: log prior − ½·Mahalanobis² (shared
                    // covariance); priors floored at 1e-15 to avoid ln(0).
                    let disc: Vec<f64> = (0..g)
                        .map(|c| {
                            priors[c].max(1e-15).ln()
                                - 0.5 * mahalanobis_sq(&x, &class_means[c], cov_chol, d)
                        })
                        .collect();
                    softmax(&disc)
                })
                .collect()
        }
        ClassifMethod::Qda {
            class_means,
            class_chols,
            class_log_dets,
            priors,
            n_classes,
        } => {
            let g = *n_classes;
            (0..n)
                .map(|i| {
                    let x: Vec<f64> = (0..d).map(|j| scores[(i, j)]).collect();
                    // Quadratic discriminant: adds the per-class log-determinant
                    // term (per-class covariances) to the Mahalanobis distance.
                    let disc: Vec<f64> = (0..g)
                        .map(|c| {
                            priors[c].max(1e-15).ln()
                                - 0.5
                                    * (class_log_dets[c]
                                        + mahalanobis_sq(&x, &class_means[c], &class_chols[c], d))
                        })
                        .collect();
                    softmax(&disc)
                })
                .collect()
        }
        ClassifMethod::Knn {
            training_scores,
            training_labels,
            k,
            n_classes,
        } => {
            let g = *n_classes;
            let n_train = training_scores.nrows();
            // Clamp k so we never ask for more neighbors than training rows.
            let k_nn = (*k).min(n_train);
            (0..n)
                .map(|i| {
                    let x: Vec<f64> = (0..d).map(|j| scores[(i, j)]).collect();
                    // Squared Euclidean distances to all training observations.
                    let mut dists: Vec<(f64, usize)> = (0..n_train)
                        .map(|j| {
                            let d_sq: f64 = (0..d)
                                .map(|c| (x[c] - training_scores[(j, c)]).powi(2))
                                .sum();
                            (d_sq, training_labels[j])
                        })
                        .collect();
                    // NaN distances compare as Equal, keeping the sort total.
                    dists
                        .sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
                    // Vote fractions among the k nearest are the probabilities.
                    let mut votes = vec![0usize; g];
                    for &(_, label) in dists.iter().take(k_nn) {
                        if label < g {
                            votes[label] += 1;
                        }
                    }
                    votes.iter().map(|&v| v as f64 / k_nn as f64).collect()
                })
                .collect()
        }
    }
}
532
/// Numerically stable softmax: maps a vector of log-scores to probabilities
/// summing to 1, subtracting the maximum score before exponentiating to
/// avoid overflow.
fn softmax(scores: &[f64]) -> Vec<f64> {
    let max_s = scores.iter().copied().fold(f64::NEG_INFINITY, f64::max);
    let mut probs: Vec<f64> = scores.iter().map(|&s| (s - max_s).exp()).collect();
    let total: f64 = probs.iter().sum();
    for p in &mut probs {
        *p /= total;
    }
    probs
}
540
541// ---------------------------------------------------------------------------
542// ─── Config-based API ───────────────────────────────────────────────────────
543
/// Configuration for [`fclassif_cv`].
#[derive(Debug, Clone, PartialEq)]
pub struct ClassifCvConfig {
    /// Classification method name (one of "lda", "qda", "knn", "kernel", "dd").
    pub method: String,
    /// Number of FPC components.
    pub ncomp: usize,
    /// Number of cross-validation folds (must be in `2..=n`).
    pub nfold: usize,
    /// Random seed for fold assignment (same seed → same folds).
    pub seed: u64,
}
556
557impl Default for ClassifCvConfig {
558    fn default() -> Self {
559        Self {
560            method: "lda".to_string(),
561            ncomp: 3,
562            nfold: 5,
563            seed: 42,
564        }
565    }
566}
567
568/// Cross-validated classification using a configuration struct.
569///
570/// Equivalent to [`fclassif_cv`] but bundles method parameters in [`ClassifCvConfig`].
571///
572/// # Errors
573///
574/// Returns [`FdarError::InvalidParameter`] if `config.nfold < 2` or `config.nfold > n`.
575/// Returns [`FdarError::InvalidParameter`] if `y` contains fewer than 2 distinct classes.
576#[must_use = "expensive computation whose result should not be discarded"]
577pub fn fclassif_cv_with_config(
578    data: &FdMatrix,
579    argvals: &[f64],
580    y: &[usize],
581    scalar_covariates: Option<&FdMatrix>,
582    config: &ClassifCvConfig,
583) -> Result<ClassifCvResult, FdarError> {
584    fclassif_cv(
585        data,
586        argvals,
587        y,
588        scalar_covariates,
589        &config.method,
590        config.ncomp,
591        config.nfold,
592        config.seed,
593    )
594}