// fdars_core/classification/fit.rs
1//! ClassifFit: fitted classification model with explainability support.
2
3use crate::error::FdarError;
4use crate::explain_generic::{FpcPredictor, TaskType};
5use crate::matrix::FdMatrix;
6
7use super::knn::knn_predict_loo;
8use super::lda::{lda_params, lda_predict};
9use super::qda::{build_qda_params, qda_predict};
10use super::{
11    build_feature_matrix, compute_accuracy, confusion_matrix, remap_labels, ClassifCvResult,
12    ClassifResult,
13};
14use crate::linalg::{cholesky_d, mahalanobis_sq};
15
16use super::cv::fclassif_cv;
17
/// Classification method with stored parameters for prediction.
///
/// Each variant carries exactly the quantities needed to score a new
/// observation in FPC-score space (see `FpcPredictor::predict_from_scores`).
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub enum ClassifMethod {
    /// Linear Discriminant Analysis.
    Lda {
        /// Per-class mean vectors in FPC-score space (g entries, each length d).
        class_means: Vec<Vec<f64>>,
        /// Flattened d × d Cholesky factor of the pooled covariance,
        /// as produced by `cholesky_d` (layout per that helper).
        cov_chol: Vec<f64>,
        /// Class prior probabilities (length g); used as log-priors in the discriminant.
        priors: Vec<f64>,
        /// Number of classes g.
        n_classes: usize,
    },
    /// Quadratic Discriminant Analysis.
    Qda {
        /// Per-class mean vectors in FPC-score space (g entries, each length d).
        class_means: Vec<Vec<f64>>,
        /// Per-class flattened Cholesky factors of the class covariances (g entries).
        class_chols: Vec<Vec<f64>>,
        /// Per-class log-determinants of the class covariances (length g),
        /// added into the quadratic discriminant score.
        class_log_dets: Vec<f64>,
        /// Class prior probabilities (length g).
        priors: Vec<f64>,
        /// Number of classes g.
        n_classes: usize,
    },
    /// k-Nearest Neighbors.
    Knn {
        /// FPC scores of the training observations (n × d).
        training_scores: FdMatrix,
        /// Remapped training labels in 0..n_classes (length n).
        training_labels: Vec<usize>,
        /// Number of neighbors to vote (clamped to n at prediction time).
        k: usize,
        /// Number of classes g.
        n_classes: usize,
    },
}
45
/// A fitted classification model that retains FPCA components for explainability.
///
/// Produced by [`fclassif_lda_fit`], [`fclassif_qda_fit`], and
/// [`fclassif_knn_fit`]; implements [`FpcPredictor`] so the fitted model can
/// project new curves into FPC-score space and predict from those scores.
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub struct ClassifFit {
    /// Classification result (predicted labels, accuracy, confusion matrix).
    pub result: ClassifResult,
    /// FPCA mean function (length m).
    pub fpca_mean: Vec<f64>,
    /// FPCA rotation matrix (m × ncomp).
    pub fpca_rotation: FdMatrix,
    /// FPCA scores of the training data (n × ncomp).
    pub fpca_scores: FdMatrix,
    /// Number of FPC components used (may be fewer than requested if the
    /// data supports fewer — this is the column count of `fpca_scores`).
    pub ncomp: usize,
    /// The classification method with stored parameters.
    pub method: ClassifMethod,
    /// Integration weights from FPCA (length m).
    pub fpca_int_weights: Vec<f64>,
}
65
66/// FPC + LDA classification, retaining FPCA and LDA parameters for explainability.
67///
68/// # Errors
69///
70/// Returns [`FdarError::InvalidDimension`] if `data` has zero rows or `y.len() != n`.
71/// Returns [`FdarError::InvalidParameter`] if `ncomp` is zero.
72/// Returns [`FdarError::InvalidParameter`] if `y` contains fewer than 2 distinct classes.
73/// Returns [`FdarError::ComputationFailed`] if the SVD decomposition in FPCA fails.
74/// Returns [`FdarError::ComputationFailed`] if the pooled covariance Cholesky factorization fails.
75#[must_use = "expensive computation whose result should not be discarded"]
76pub fn fclassif_lda_fit(
77    data: &FdMatrix,
78    y: &[usize],
79    scalar_covariates: Option<&FdMatrix>,
80    ncomp: usize,
81) -> Result<ClassifFit, FdarError> {
82    let n = data.nrows();
83    if n == 0 || y.len() != n {
84        return Err(FdarError::InvalidDimension {
85            parameter: "data/y",
86            expected: "n > 0 and y.len() == n".to_string(),
87            actual: format!("n={}, y.len()={}", n, y.len()),
88        });
89    }
90    if ncomp == 0 {
91        return Err(FdarError::InvalidParameter {
92            parameter: "ncomp",
93            message: "must be > 0".to_string(),
94        });
95    }
96
97    let (labels, g) = remap_labels(y);
98    if g < 2 {
99        return Err(FdarError::InvalidParameter {
100            parameter: "y",
101            message: format!("need at least 2 classes, got {g}"),
102        });
103    }
104
105    // _fit variants use FPCA-only features (no scalar_covariates) so that stored
106    // dimensions are consistent with FpcPredictor::project() / predict_from_scores().
107    let (features, mean, rotation, int_weights) = build_feature_matrix(data, None, ncomp)?;
108    let _ = scalar_covariates; // acknowledged but not used — see docstring
109    let d = features.ncols();
110    let (class_means, cov, priors) = lda_params(&features, &labels, g);
111    let chol = cholesky_d(&cov, d)?;
112
113    let predicted = lda_predict(&features, &class_means, &chol, &priors, g);
114    let accuracy = compute_accuracy(&labels, &predicted);
115    let confusion = confusion_matrix(&labels, &predicted, g);
116
117    Ok(ClassifFit {
118        result: ClassifResult {
119            predicted,
120            probabilities: None,
121            accuracy,
122            confusion,
123            n_classes: g,
124            ncomp: d,
125        },
126        fpca_mean: mean.clone(),
127        fpca_rotation: rotation,
128        fpca_scores: features,
129        ncomp: d,
130        method: ClassifMethod::Lda {
131            class_means,
132            cov_chol: chol,
133            priors,
134            n_classes: g,
135        },
136        fpca_int_weights: int_weights,
137    })
138}
139
140/// FPC + QDA classification, retaining FPCA and QDA parameters for explainability.
141///
142/// # Errors
143///
144/// Returns [`FdarError::InvalidDimension`] if `data` has zero rows or `y.len() != n`.
145/// Returns [`FdarError::InvalidParameter`] if `ncomp` is zero.
146/// Returns [`FdarError::InvalidParameter`] if `y` contains fewer than 2 distinct classes.
147/// Returns [`FdarError::ComputationFailed`] if the SVD decomposition in FPCA fails.
148/// Returns [`FdarError::ComputationFailed`] if a per-class covariance Cholesky factorization fails.
149#[must_use = "expensive computation whose result should not be discarded"]
150pub fn fclassif_qda_fit(
151    data: &FdMatrix,
152    y: &[usize],
153    scalar_covariates: Option<&FdMatrix>,
154    ncomp: usize,
155) -> Result<ClassifFit, FdarError> {
156    let n = data.nrows();
157    if n == 0 || y.len() != n {
158        return Err(FdarError::InvalidDimension {
159            parameter: "data/y",
160            expected: "n > 0 and y.len() == n".to_string(),
161            actual: format!("n={}, y.len()={}", n, y.len()),
162        });
163    }
164    if ncomp == 0 {
165        return Err(FdarError::InvalidParameter {
166            parameter: "ncomp",
167            message: "must be > 0".to_string(),
168        });
169    }
170
171    let (labels, g) = remap_labels(y);
172    if g < 2 {
173        return Err(FdarError::InvalidParameter {
174            parameter: "y",
175            message: format!("need at least 2 classes, got {g}"),
176        });
177    }
178
179    // _fit variants use FPCA-only features — see fclassif_lda_fit comment.
180    let (features, mean, rotation, int_weights) = build_feature_matrix(data, None, ncomp)?;
181    let _ = scalar_covariates;
182    let (class_means, class_chols, class_log_dets, priors) =
183        build_qda_params(&features, &labels, g)?;
184
185    let predicted = qda_predict(
186        &features,
187        &class_means,
188        &class_chols,
189        &class_log_dets,
190        &priors,
191        g,
192    );
193    let accuracy = compute_accuracy(&labels, &predicted);
194    let confusion = confusion_matrix(&labels, &predicted, g);
195    let d = features.ncols();
196
197    Ok(ClassifFit {
198        result: ClassifResult {
199            predicted,
200            probabilities: None,
201            accuracy,
202            confusion,
203            n_classes: g,
204            ncomp: d,
205        },
206        fpca_mean: mean.clone(),
207        fpca_rotation: rotation,
208        fpca_scores: features,
209        ncomp: d,
210        method: ClassifMethod::Qda {
211            class_means,
212            class_chols,
213            class_log_dets,
214            priors,
215            n_classes: g,
216        },
217        fpca_int_weights: int_weights,
218    })
219}
220
221/// FPC + k-NN classification, retaining FPCA and training data for explainability.
222///
223/// # Errors
224///
225/// Returns [`FdarError::InvalidDimension`] if `data` has zero rows or `y.len() != n`.
226/// Returns [`FdarError::InvalidParameter`] if `ncomp` is zero.
227/// Returns [`FdarError::InvalidParameter`] if `k_nn` is zero.
228/// Returns [`FdarError::InvalidParameter`] if `y` contains fewer than 2 distinct classes.
229/// Returns [`FdarError::ComputationFailed`] if the SVD decomposition in FPCA fails.
230#[must_use = "expensive computation whose result should not be discarded"]
231pub fn fclassif_knn_fit(
232    data: &FdMatrix,
233    y: &[usize],
234    scalar_covariates: Option<&FdMatrix>,
235    ncomp: usize,
236    k_nn: usize,
237) -> Result<ClassifFit, FdarError> {
238    let n = data.nrows();
239    if n == 0 || y.len() != n {
240        return Err(FdarError::InvalidDimension {
241            parameter: "data/y",
242            expected: "n > 0 and y.len() == n".to_string(),
243            actual: format!("n={}, y.len()={}", n, y.len()),
244        });
245    }
246    if ncomp == 0 {
247        return Err(FdarError::InvalidParameter {
248            parameter: "ncomp",
249            message: "must be > 0".to_string(),
250        });
251    }
252    if k_nn == 0 {
253        return Err(FdarError::InvalidParameter {
254            parameter: "k_nn",
255            message: "must be > 0".to_string(),
256        });
257    }
258
259    let (labels, g) = remap_labels(y);
260    if g < 2 {
261        return Err(FdarError::InvalidParameter {
262            parameter: "y",
263            message: format!("need at least 2 classes, got {g}"),
264        });
265    }
266
267    // _fit variants use FPCA-only features — see fclassif_lda_fit comment.
268    let (features, mean, rotation, int_weights) = build_feature_matrix(data, None, ncomp)?;
269    let _ = scalar_covariates;
270    let d = features.ncols();
271
272    let predicted = knn_predict_loo(&features, &labels, g, d, k_nn);
273    let accuracy = compute_accuracy(&labels, &predicted);
274    let confusion = confusion_matrix(&labels, &predicted, g);
275
276    Ok(ClassifFit {
277        result: ClassifResult {
278            predicted,
279            probabilities: None,
280            accuracy,
281            confusion,
282            n_classes: g,
283            ncomp: d,
284        },
285        fpca_mean: mean.clone(),
286        fpca_rotation: rotation,
287        fpca_scores: features.clone(),
288        ncomp: d,
289        method: ClassifMethod::Knn {
290            training_scores: features,
291            training_labels: labels,
292            k: k_nn,
293            n_classes: g,
294        },
295        fpca_int_weights: int_weights,
296    })
297}
298
299// ---------------------------------------------------------------------------
300// FpcPredictor impl for ClassifFit
301// ---------------------------------------------------------------------------
302
impl FpcPredictor for ClassifFit {
    /// FPCA mean function (length m).
    fn fpca_mean(&self) -> &[f64] {
        &self.fpca_mean
    }

    /// FPCA rotation matrix (m × ncomp).
    fn fpca_rotation(&self) -> &FdMatrix {
        &self.fpca_rotation
    }

    /// Number of FPC components used by the fitted model.
    fn ncomp(&self) -> usize {
        self.ncomp
    }

    /// FPC scores of the training data (n × ncomp).
    fn training_scores(&self) -> &FdMatrix {
        &self.fpca_scores
    }

    /// Integration weights from FPCA (length m).
    fn fpca_weights(&self) -> &[f64] {
        &self.fpca_int_weights
    }

    /// Binary vs. multiclass, derived from the stored `n_classes`
    /// (identical across all three method variants).
    fn task_type(&self) -> TaskType {
        match &self.method {
            ClassifMethod::Lda { n_classes, .. }
            | ClassifMethod::Qda { n_classes, .. }
            | ClassifMethod::Knn { n_classes, .. } => {
                if *n_classes == 2 {
                    TaskType::BinaryClassification
                } else {
                    TaskType::MulticlassClassification(*n_classes)
                }
            }
        }
    }

    /// Predict from a single FPC score vector.
    ///
    /// Return convention: for binary problems (g == 2) this is P(Y = 1);
    /// for multiclass problems it is the predicted class index cast to `f64`.
    /// `_scalar_covariates` is ignored — the `_fit` constructors train on
    /// FPCA-only features (see `fclassif_lda_fit`).
    fn predict_from_scores(&self, scores: &[f64], _scalar_covariates: Option<&[f64]>) -> f64 {
        match &self.method {
            ClassifMethod::Lda {
                class_means,
                cov_chol,
                priors,
                n_classes,
            } => {
                let g = *n_classes;
                // Dimension is taken from the query vector; assumes scores.len()
                // matches the fitted d — TODO confirm callers guarantee this.
                let d = scores.len();
                if g == 2 {
                    // Return P(Y=1) via softmax of discriminant scores.
                    // LDA discriminant: log prior − ½·Mahalanobis² under the
                    // pooled covariance (priors floored at 1e-15 to avoid ln(0)).
                    let score0 = priors[0].max(1e-15).ln()
                        - 0.5 * mahalanobis_sq(scores, &class_means[0], cov_chol, d);
                    let score1 = priors[1].max(1e-15).ln()
                        - 0.5 * mahalanobis_sq(scores, &class_means[1], cov_chol, d);
                    // Subtract the max before exponentiating for numerical stability.
                    let max_s = score0.max(score1);
                    let exp0 = (score0 - max_s).exp();
                    let exp1 = (score1 - max_s).exp();
                    exp1 / (exp0 + exp1)
                } else {
                    // Return predicted class as f64 (argmax of discriminant scores;
                    // strict `>` means ties keep the lowest class index).
                    let mut best_class = 0;
                    let mut best_score = f64::NEG_INFINITY;
                    for c in 0..g {
                        let maha = mahalanobis_sq(scores, &class_means[c], cov_chol, d);
                        let s = priors[c].max(1e-15).ln() - 0.5 * maha;
                        if s > best_score {
                            best_score = s;
                            best_class = c;
                        }
                    }
                    best_class as f64
                }
            }
            ClassifMethod::Qda {
                class_means,
                class_chols,
                class_log_dets,
                priors,
                n_classes,
            } => {
                let g = *n_classes;
                let d = scores.len();
                if g == 2 {
                    // Return P(Y=1) via softmax of discriminant scores.
                    // QDA discriminant adds the per-class log-determinant term.
                    let score0 = priors[0].max(1e-15).ln()
                        - 0.5
                            * (class_log_dets[0]
                                + mahalanobis_sq(scores, &class_means[0], &class_chols[0], d));
                    let score1 = priors[1].max(1e-15).ln()
                        - 0.5
                            * (class_log_dets[1]
                                + mahalanobis_sq(scores, &class_means[1], &class_chols[1], d));
                    let max_s = score0.max(score1);
                    let exp0 = (score0 - max_s).exp();
                    let exp1 = (score1 - max_s).exp();
                    exp1 / (exp0 + exp1)
                } else {
                    // Argmax of discriminant scores; ties keep the lowest class index.
                    let mut best_class = 0;
                    let mut best_score = f64::NEG_INFINITY;
                    for c in 0..g {
                        let maha = mahalanobis_sq(scores, &class_means[c], &class_chols[c], d);
                        let s = priors[c].max(1e-15).ln() - 0.5 * (class_log_dets[c] + maha);
                        if s > best_score {
                            best_score = s;
                            best_class = c;
                        }
                    }
                    best_class as f64
                }
            }
            ClassifMethod::Knn {
                training_scores,
                training_labels,
                k,
                n_classes,
            } => {
                let g = *n_classes;
                let n_train = training_scores.nrows();
                let d = scores.len();
                // Clamp k to the number of training rows so `take(k_nn)` is safe.
                let k_nn = (*k).min(n_train);

                // Squared Euclidean distances to every training score vector.
                let mut dists: Vec<(f64, usize)> = (0..n_train)
                    .map(|j| {
                        let d_sq: f64 = (0..d)
                            .map(|c| (scores[c] - training_scores[(j, c)]).powi(2))
                            .sum();
                        (d_sq, training_labels[j])
                    })
                    .collect();
                // Ascending by distance; NaN distances compare Equal (stable sort
                // then keeps their original order).
                dists.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));

                let mut votes = vec![0usize; g];
                for &(_, label) in dists.iter().take(k_nn) {
                    // `label < g` guards against out-of-range stored labels.
                    if label < g {
                        votes[label] += 1;
                    }
                }

                if g == 2 {
                    // Return proportion voting for class 1 as probability.
                    // NOTE(review): if n_train == 0 this divides by zero and yields
                    // NaN — the _fit constructors reject n == 0, so confirm no other
                    // path constructs an empty Knn variant.
                    votes[1] as f64 / k_nn as f64
                } else {
                    // Return majority vote class as f64. `max_by_key` returns the
                    // LAST maximum, so vote ties resolve to the highest class index.
                    votes
                        .iter()
                        .enumerate()
                        .max_by_key(|&(_, &v)| v)
                        .map_or(0.0, |(c, _)| c as f64)
                }
            }
        }
    }
}
453
454// ---------------------------------------------------------------------------
455// Class probability vectors (for conformal prediction sets)
456// ---------------------------------------------------------------------------
457
458/// Compute full class probability vectors for each observation.
459///
460/// Returns `n × g` probability vectors suitable for conformal classification.
461/// For each observation, the probabilities sum to 1.
462pub(crate) fn classif_predict_probs(fit: &ClassifFit, scores: &FdMatrix) -> Vec<Vec<f64>> {
463    let n = scores.nrows();
464    let d = scores.ncols();
465    match &fit.method {
466        ClassifMethod::Lda {
467            class_means,
468            cov_chol,
469            priors,
470            n_classes,
471        } => {
472            let g = *n_classes;
473            (0..n)
474                .map(|i| {
475                    let x: Vec<f64> = (0..d).map(|j| scores[(i, j)]).collect();
476                    let disc: Vec<f64> = (0..g)
477                        .map(|c| {
478                            priors[c].max(1e-15).ln()
479                                - 0.5 * mahalanobis_sq(&x, &class_means[c], cov_chol, d)
480                        })
481                        .collect();
482                    softmax(&disc)
483                })
484                .collect()
485        }
486        ClassifMethod::Qda {
487            class_means,
488            class_chols,
489            class_log_dets,
490            priors,
491            n_classes,
492        } => {
493            let g = *n_classes;
494            (0..n)
495                .map(|i| {
496                    let x: Vec<f64> = (0..d).map(|j| scores[(i, j)]).collect();
497                    let disc: Vec<f64> = (0..g)
498                        .map(|c| {
499                            priors[c].max(1e-15).ln()
500                                - 0.5
501                                    * (class_log_dets[c]
502                                        + mahalanobis_sq(&x, &class_means[c], &class_chols[c], d))
503                        })
504                        .collect();
505                    softmax(&disc)
506                })
507                .collect()
508        }
509        ClassifMethod::Knn {
510            training_scores,
511            training_labels,
512            k,
513            n_classes,
514        } => {
515            let g = *n_classes;
516            let n_train = training_scores.nrows();
517            let k_nn = (*k).min(n_train);
518            (0..n)
519                .map(|i| {
520                    let x: Vec<f64> = (0..d).map(|j| scores[(i, j)]).collect();
521                    let mut dists: Vec<(f64, usize)> = (0..n_train)
522                        .map(|j| {
523                            let d_sq: f64 = (0..d)
524                                .map(|c| (x[c] - training_scores[(j, c)]).powi(2))
525                                .sum();
526                            (d_sq, training_labels[j])
527                        })
528                        .collect();
529                    dists
530                        .sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
531                    let mut votes = vec![0usize; g];
532                    for &(_, label) in dists.iter().take(k_nn) {
533                        if label < g {
534                            votes[label] += 1;
535                        }
536                    }
537                    votes.iter().map(|&v| v as f64 / k_nn as f64).collect()
538                })
539                .collect()
540        }
541    }
542}
543
/// Convert a slice of log-scores into probabilities with the numerically
/// stable softmax (the maximum is subtracted before exponentiating).
fn softmax(scores: &[f64]) -> Vec<f64> {
    let peak = scores.iter().fold(f64::NEG_INFINITY, |acc, &s| acc.max(s));
    let mut probs: Vec<f64> = scores.iter().map(|&s| (s - peak).exp()).collect();
    let total: f64 = probs.iter().sum();
    for p in &mut probs {
        *p /= total;
    }
    probs
}
551
552// ---------------------------------------------------------------------------
553// ─── Config-based API ───────────────────────────────────────────────────────
554
/// Configuration for [`fclassif_cv`].
#[derive(Debug, Clone, PartialEq)]
pub struct ClassifCvConfig {
    /// Classification method name (one of "lda", "qda", "knn", "kernel", "dd").
    pub method: String,
    /// Number of FPC components to extract per fold.
    pub ncomp: usize,
    /// Number of cross-validation folds (must be ≥ 2 and ≤ n, or
    /// [`fclassif_cv`] returns `InvalidParameter`).
    pub nfold: usize,
    /// Random seed for fold assignment; fixing it makes CV splits reproducible.
    pub seed: u64,
}
567
568impl Default for ClassifCvConfig {
569    fn default() -> Self {
570        Self {
571            method: "lda".to_string(),
572            ncomp: 3,
573            nfold: 5,
574            seed: 42,
575        }
576    }
577}
578
579/// Cross-validated classification using a configuration struct.
580///
581/// Equivalent to [`fclassif_cv`] but bundles method parameters in [`ClassifCvConfig`].
582///
583/// # Errors
584///
585/// Returns [`FdarError::InvalidParameter`] if `config.nfold < 2` or `config.nfold > n`.
586/// Returns [`FdarError::InvalidParameter`] if `y` contains fewer than 2 distinct classes.
587#[must_use = "expensive computation whose result should not be discarded"]
588pub fn fclassif_cv_with_config(
589    data: &FdMatrix,
590    argvals: &[f64],
591    y: &[usize],
592    scalar_covariates: Option<&FdMatrix>,
593    config: &ClassifCvConfig,
594) -> Result<ClassifCvResult, FdarError> {
595    fclassif_cv(
596        data,
597        argvals,
598        y,
599        scalar_covariates,
600        &config.method,
601        config.ncomp,
602        config.nfold,
603        config.seed,
604    )
605}