// fdars_core/classification/fit.rs
1//! ClassifFit: fitted classification model with explainability support.
2
3use crate::error::FdarError;
4use crate::explain_generic::{FpcPredictor, TaskType};
5use crate::matrix::FdMatrix;
6
7use super::knn::knn_predict_loo;
8use super::lda::{lda_params, lda_predict};
9use super::qda::{build_qda_params, qda_predict};
10use super::{
11    build_feature_matrix, compute_accuracy, confusion_matrix, remap_labels, ClassifCvResult,
12    ClassifResult,
13};
14use crate::linalg::{cholesky_d, mahalanobis_sq};
15
16use super::cv::fclassif_cv;
17
/// Classification method with stored parameters for prediction.
///
/// Each variant carries everything needed to score a new observation in
/// FPC-score space (see the `FpcPredictor::predict_from_scores` impl below).
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub enum ClassifMethod {
    /// Linear Discriminant Analysis.
    Lda {
        /// Per-class mean vectors in feature space (one per class).
        class_means: Vec<Vec<f64>>,
        /// Cholesky factor of the pooled covariance, as produced by `cholesky_d`.
        cov_chol: Vec<f64>,
        /// Class prior probabilities (length `n_classes`).
        priors: Vec<f64>,
        /// Number of distinct classes.
        n_classes: usize,
    },
    /// Quadratic Discriminant Analysis.
    Qda {
        /// Per-class mean vectors in feature space (one per class).
        class_means: Vec<Vec<f64>>,
        /// Per-class covariance Cholesky factors.
        class_chols: Vec<Vec<f64>>,
        /// Per-class log-determinant terms entering the QDA discriminant.
        class_log_dets: Vec<f64>,
        /// Class prior probabilities (length `n_classes`).
        priors: Vec<f64>,
        /// Number of distinct classes.
        n_classes: usize,
    },
    /// k-Nearest Neighbors.
    Knn {
        /// FPC scores of the training observations.
        training_scores: FdMatrix,
        /// Remapped training labels in `0..n_classes`.
        training_labels: Vec<usize>,
        /// Number of neighbors to consult (capped at the training size at predict time).
        k: usize,
        /// Number of distinct classes.
        n_classes: usize,
    },
}
45
/// A fitted classification model that retains FPCA components for explainability.
///
/// Implements [`FpcPredictor`] so generic explainability tooling can project
/// new curves into FPC-score space and query the fitted classifier.
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub struct ClassifFit {
    /// Classification result (predicted labels, accuracy, confusion matrix).
    pub result: ClassifResult,
    /// FPCA mean function (length m).
    pub fpca_mean: Vec<f64>,
    /// FPCA rotation matrix (m × ncomp).
    pub fpca_rotation: FdMatrix,
    /// FPCA scores of the training data (n × ncomp).
    pub fpca_scores: FdMatrix,
    /// Number of FPC components used (the feature dimension `d`).
    pub ncomp: usize,
    /// The classification method with stored parameters.
    pub method: ClassifMethod,
}
63
64/// FPC + LDA classification, retaining FPCA and LDA parameters for explainability.
65///
66/// # Errors
67///
68/// Returns [`FdarError::InvalidDimension`] if `data` has zero rows or `y.len() != n`.
69/// Returns [`FdarError::InvalidParameter`] if `ncomp` is zero.
70/// Returns [`FdarError::InvalidParameter`] if `y` contains fewer than 2 distinct classes.
71/// Returns [`FdarError::ComputationFailed`] if the SVD decomposition in FPCA fails.
72/// Returns [`FdarError::ComputationFailed`] if the pooled covariance Cholesky factorization fails.
73#[must_use = "expensive computation whose result should not be discarded"]
74pub fn fclassif_lda_fit(
75    data: &FdMatrix,
76    y: &[usize],
77    scalar_covariates: Option<&FdMatrix>,
78    ncomp: usize,
79) -> Result<ClassifFit, FdarError> {
80    let n = data.nrows();
81    if n == 0 || y.len() != n {
82        return Err(FdarError::InvalidDimension {
83            parameter: "data/y",
84            expected: "n > 0 and y.len() == n".to_string(),
85            actual: format!("n={}, y.len()={}", n, y.len()),
86        });
87    }
88    if ncomp == 0 {
89        return Err(FdarError::InvalidParameter {
90            parameter: "ncomp",
91            message: "must be > 0".to_string(),
92        });
93    }
94
95    let (labels, g) = remap_labels(y);
96    if g < 2 {
97        return Err(FdarError::InvalidParameter {
98            parameter: "y",
99            message: format!("need at least 2 classes, got {g}"),
100        });
101    }
102
103    // _fit variants use FPCA-only features (no scalar_covariates) so that stored
104    // dimensions are consistent with FpcPredictor::project() / predict_from_scores().
105    let (features, mean, rotation) = build_feature_matrix(data, None, ncomp)?;
106    let _ = scalar_covariates; // acknowledged but not used — see docstring
107    let d = features.ncols();
108    let (class_means, cov, priors) = lda_params(&features, &labels, g);
109    let chol = cholesky_d(&cov, d)?;
110
111    let predicted = lda_predict(&features, &class_means, &chol, &priors, g);
112    let accuracy = compute_accuracy(&labels, &predicted);
113    let confusion = confusion_matrix(&labels, &predicted, g);
114
115    Ok(ClassifFit {
116        result: ClassifResult {
117            predicted,
118            probabilities: None,
119            accuracy,
120            confusion,
121            n_classes: g,
122            ncomp: d,
123        },
124        fpca_mean: mean.clone(),
125        fpca_rotation: rotation,
126        fpca_scores: features,
127        ncomp: d,
128        method: ClassifMethod::Lda {
129            class_means,
130            cov_chol: chol,
131            priors,
132            n_classes: g,
133        },
134    })
135}
136
137/// FPC + QDA classification, retaining FPCA and QDA parameters for explainability.
138///
139/// # Errors
140///
141/// Returns [`FdarError::InvalidDimension`] if `data` has zero rows or `y.len() != n`.
142/// Returns [`FdarError::InvalidParameter`] if `ncomp` is zero.
143/// Returns [`FdarError::InvalidParameter`] if `y` contains fewer than 2 distinct classes.
144/// Returns [`FdarError::ComputationFailed`] if the SVD decomposition in FPCA fails.
145/// Returns [`FdarError::ComputationFailed`] if a per-class covariance Cholesky factorization fails.
146#[must_use = "expensive computation whose result should not be discarded"]
147pub fn fclassif_qda_fit(
148    data: &FdMatrix,
149    y: &[usize],
150    scalar_covariates: Option<&FdMatrix>,
151    ncomp: usize,
152) -> Result<ClassifFit, FdarError> {
153    let n = data.nrows();
154    if n == 0 || y.len() != n {
155        return Err(FdarError::InvalidDimension {
156            parameter: "data/y",
157            expected: "n > 0 and y.len() == n".to_string(),
158            actual: format!("n={}, y.len()={}", n, y.len()),
159        });
160    }
161    if ncomp == 0 {
162        return Err(FdarError::InvalidParameter {
163            parameter: "ncomp",
164            message: "must be > 0".to_string(),
165        });
166    }
167
168    let (labels, g) = remap_labels(y);
169    if g < 2 {
170        return Err(FdarError::InvalidParameter {
171            parameter: "y",
172            message: format!("need at least 2 classes, got {g}"),
173        });
174    }
175
176    // _fit variants use FPCA-only features — see fclassif_lda_fit comment.
177    let (features, mean, rotation) = build_feature_matrix(data, None, ncomp)?;
178    let _ = scalar_covariates;
179    let (class_means, class_chols, class_log_dets, priors) =
180        build_qda_params(&features, &labels, g)?;
181
182    let predicted = qda_predict(
183        &features,
184        &class_means,
185        &class_chols,
186        &class_log_dets,
187        &priors,
188        g,
189    );
190    let accuracy = compute_accuracy(&labels, &predicted);
191    let confusion = confusion_matrix(&labels, &predicted, g);
192    let d = features.ncols();
193
194    Ok(ClassifFit {
195        result: ClassifResult {
196            predicted,
197            probabilities: None,
198            accuracy,
199            confusion,
200            n_classes: g,
201            ncomp: d,
202        },
203        fpca_mean: mean.clone(),
204        fpca_rotation: rotation,
205        fpca_scores: features,
206        ncomp: d,
207        method: ClassifMethod::Qda {
208            class_means,
209            class_chols,
210            class_log_dets,
211            priors,
212            n_classes: g,
213        },
214    })
215}
216
217/// FPC + k-NN classification, retaining FPCA and training data for explainability.
218///
219/// # Errors
220///
221/// Returns [`FdarError::InvalidDimension`] if `data` has zero rows or `y.len() != n`.
222/// Returns [`FdarError::InvalidParameter`] if `ncomp` is zero.
223/// Returns [`FdarError::InvalidParameter`] if `k_nn` is zero.
224/// Returns [`FdarError::InvalidParameter`] if `y` contains fewer than 2 distinct classes.
225/// Returns [`FdarError::ComputationFailed`] if the SVD decomposition in FPCA fails.
226#[must_use = "expensive computation whose result should not be discarded"]
227pub fn fclassif_knn_fit(
228    data: &FdMatrix,
229    y: &[usize],
230    scalar_covariates: Option<&FdMatrix>,
231    ncomp: usize,
232    k_nn: usize,
233) -> Result<ClassifFit, FdarError> {
234    let n = data.nrows();
235    if n == 0 || y.len() != n {
236        return Err(FdarError::InvalidDimension {
237            parameter: "data/y",
238            expected: "n > 0 and y.len() == n".to_string(),
239            actual: format!("n={}, y.len()={}", n, y.len()),
240        });
241    }
242    if ncomp == 0 {
243        return Err(FdarError::InvalidParameter {
244            parameter: "ncomp",
245            message: "must be > 0".to_string(),
246        });
247    }
248    if k_nn == 0 {
249        return Err(FdarError::InvalidParameter {
250            parameter: "k_nn",
251            message: "must be > 0".to_string(),
252        });
253    }
254
255    let (labels, g) = remap_labels(y);
256    if g < 2 {
257        return Err(FdarError::InvalidParameter {
258            parameter: "y",
259            message: format!("need at least 2 classes, got {g}"),
260        });
261    }
262
263    // _fit variants use FPCA-only features — see fclassif_lda_fit comment.
264    let (features, mean, rotation) = build_feature_matrix(data, None, ncomp)?;
265    let _ = scalar_covariates;
266    let d = features.ncols();
267
268    let predicted = knn_predict_loo(&features, &labels, g, d, k_nn);
269    let accuracy = compute_accuracy(&labels, &predicted);
270    let confusion = confusion_matrix(&labels, &predicted, g);
271
272    Ok(ClassifFit {
273        result: ClassifResult {
274            predicted,
275            probabilities: None,
276            accuracy,
277            confusion,
278            n_classes: g,
279            ncomp: d,
280        },
281        fpca_mean: mean.clone(),
282        fpca_rotation: rotation,
283        fpca_scores: features.clone(),
284        ncomp: d,
285        method: ClassifMethod::Knn {
286            training_scores: features,
287            training_labels: labels,
288            k: k_nn,
289            n_classes: g,
290        },
291    })
292}
293
294// ---------------------------------------------------------------------------
295// FpcPredictor impl for ClassifFit
296// ---------------------------------------------------------------------------
297
// Exposes the fitted FPCA basis and a score-space predictor so that generic
// explainability tooling (`explain_generic`) can work with classification fits.
impl FpcPredictor for ClassifFit {
    fn fpca_mean(&self) -> &[f64] {
        &self.fpca_mean
    }

    fn fpca_rotation(&self) -> &FdMatrix {
        &self.fpca_rotation
    }

    fn ncomp(&self) -> usize {
        self.ncomp
    }

    fn training_scores(&self) -> &FdMatrix {
        &self.fpca_scores
    }

    /// Binary vs. multiclass is decided solely by the stored `n_classes`.
    fn task_type(&self) -> TaskType {
        match &self.method {
            ClassifMethod::Lda { n_classes, .. }
            | ClassifMethod::Qda { n_classes, .. }
            | ClassifMethod::Knn { n_classes, .. } => {
                if *n_classes == 2 {
                    TaskType::BinaryClassification
                } else {
                    TaskType::MulticlassClassification(*n_classes)
                }
            }
        }
    }

    /// Predict a single observation from its FPC score vector.
    ///
    /// Return value depends on the task:
    /// - binary (2 classes): P(Y = 1), a probability in [0, 1];
    /// - multiclass: the winning class index, cast to `f64`.
    ///
    /// Scalar covariates are ignored: _fit models are built on FPCA-only
    /// features (see the fit functions above).
    ///
    /// NOTE(review): `scores.len()` is taken as the feature dimension `d`, so
    /// callers must pass exactly `ncomp` scores — confirm at call sites.
    fn predict_from_scores(&self, scores: &[f64], _scalar_covariates: Option<&[f64]>) -> f64 {
        match &self.method {
            ClassifMethod::Lda {
                class_means,
                cov_chol,
                priors,
                n_classes,
            } => {
                let g = *n_classes;
                let d = scores.len();
                if g == 2 {
                    // Return P(Y=1) via softmax of discriminant scores.
                    // priors are clamped at 1e-15 so ln() never hits -inf.
                    let score0 = priors[0].max(1e-15).ln()
                        - 0.5 * mahalanobis_sq(scores, &class_means[0], cov_chol, d);
                    let score1 = priors[1].max(1e-15).ln()
                        - 0.5 * mahalanobis_sq(scores, &class_means[1], cov_chol, d);
                    // Subtract the max before exponentiating for numerical stability.
                    let max_s = score0.max(score1);
                    let exp0 = (score0 - max_s).exp();
                    let exp1 = (score1 - max_s).exp();
                    exp1 / (exp0 + exp1)
                } else {
                    // Return predicted class as f64 (argmax of the LDA discriminant).
                    let mut best_class = 0;
                    let mut best_score = f64::NEG_INFINITY;
                    for c in 0..g {
                        let maha = mahalanobis_sq(scores, &class_means[c], cov_chol, d);
                        let s = priors[c].max(1e-15).ln() - 0.5 * maha;
                        if s > best_score {
                            best_score = s;
                            best_class = c;
                        }
                    }
                    best_class as f64
                }
            }
            ClassifMethod::Qda {
                class_means,
                class_chols,
                class_log_dets,
                priors,
                n_classes,
            } => {
                let g = *n_classes;
                let d = scores.len();
                if g == 2 {
                    // Return P(Y=1) via softmax of discriminant scores.
                    // QDA adds the per-class log-determinant to the Mahalanobis term.
                    let score0 = priors[0].max(1e-15).ln()
                        - 0.5
                            * (class_log_dets[0]
                                + mahalanobis_sq(scores, &class_means[0], &class_chols[0], d));
                    let score1 = priors[1].max(1e-15).ln()
                        - 0.5
                            * (class_log_dets[1]
                                + mahalanobis_sq(scores, &class_means[1], &class_chols[1], d));
                    let max_s = score0.max(score1);
                    let exp0 = (score0 - max_s).exp();
                    let exp1 = (score1 - max_s).exp();
                    exp1 / (exp0 + exp1)
                } else {
                    // Argmax of the QDA discriminant over all classes.
                    let mut best_class = 0;
                    let mut best_score = f64::NEG_INFINITY;
                    for c in 0..g {
                        let maha = mahalanobis_sq(scores, &class_means[c], &class_chols[c], d);
                        let s = priors[c].max(1e-15).ln() - 0.5 * (class_log_dets[c] + maha);
                        if s > best_score {
                            best_score = s;
                            best_class = c;
                        }
                    }
                    best_class as f64
                }
            }
            ClassifMethod::Knn {
                training_scores,
                training_labels,
                k,
                n_classes,
            } => {
                let g = *n_classes;
                let n_train = training_scores.nrows();
                let d = scores.len();
                // Cap k at the training size so `take(k_nn)` is well-defined.
                // NOTE(review): if n_train were 0, k_nn would be 0 and the binary
                // branch below would divide by zero; the fit functions reject
                // n == 0, so this should be unreachable — confirm no other
                // construction path exists.
                let k_nn = (*k).min(n_train);

                // Squared Euclidean distances to every training observation.
                let mut dists: Vec<(f64, usize)> = (0..n_train)
                    .map(|j| {
                        let d_sq: f64 = (0..d)
                            .map(|c| (scores[c] - training_scores[(j, c)]).powi(2))
                            .sum();
                        (d_sq, training_labels[j])
                    })
                    .collect();
                // NaN distances compare as Equal, i.e. they are left in place.
                dists.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));

                // Majority vote among the k nearest neighbors.
                let mut votes = vec![0usize; g];
                for &(_, label) in dists.iter().take(k_nn) {
                    if label < g {
                        votes[label] += 1;
                    }
                }

                if g == 2 {
                    // Return proportion voting for class 1 as probability
                    votes[1] as f64 / k_nn as f64
                } else {
                    // Return majority vote class as f64
                    votes
                        .iter()
                        .enumerate()
                        .max_by_key(|&(_, &v)| v)
                        .map_or(0.0, |(c, _)| c as f64)
                }
            }
        }
    }
}
444
445// ---------------------------------------------------------------------------
446// Class probability vectors (for conformal prediction sets)
447// ---------------------------------------------------------------------------
448
449/// Compute full class probability vectors for each observation.
450///
451/// Returns `n × g` probability vectors suitable for conformal classification.
452/// For each observation, the probabilities sum to 1.
453pub(crate) fn classif_predict_probs(fit: &ClassifFit, scores: &FdMatrix) -> Vec<Vec<f64>> {
454    let n = scores.nrows();
455    let d = scores.ncols();
456    match &fit.method {
457        ClassifMethod::Lda {
458            class_means,
459            cov_chol,
460            priors,
461            n_classes,
462        } => {
463            let g = *n_classes;
464            (0..n)
465                .map(|i| {
466                    let x: Vec<f64> = (0..d).map(|j| scores[(i, j)]).collect();
467                    let disc: Vec<f64> = (0..g)
468                        .map(|c| {
469                            priors[c].max(1e-15).ln()
470                                - 0.5 * mahalanobis_sq(&x, &class_means[c], cov_chol, d)
471                        })
472                        .collect();
473                    softmax(&disc)
474                })
475                .collect()
476        }
477        ClassifMethod::Qda {
478            class_means,
479            class_chols,
480            class_log_dets,
481            priors,
482            n_classes,
483        } => {
484            let g = *n_classes;
485            (0..n)
486                .map(|i| {
487                    let x: Vec<f64> = (0..d).map(|j| scores[(i, j)]).collect();
488                    let disc: Vec<f64> = (0..g)
489                        .map(|c| {
490                            priors[c].max(1e-15).ln()
491                                - 0.5
492                                    * (class_log_dets[c]
493                                        + mahalanobis_sq(&x, &class_means[c], &class_chols[c], d))
494                        })
495                        .collect();
496                    softmax(&disc)
497                })
498                .collect()
499        }
500        ClassifMethod::Knn {
501            training_scores,
502            training_labels,
503            k,
504            n_classes,
505        } => {
506            let g = *n_classes;
507            let n_train = training_scores.nrows();
508            let k_nn = (*k).min(n_train);
509            (0..n)
510                .map(|i| {
511                    let x: Vec<f64> = (0..d).map(|j| scores[(i, j)]).collect();
512                    let mut dists: Vec<(f64, usize)> = (0..n_train)
513                        .map(|j| {
514                            let d_sq: f64 = (0..d)
515                                .map(|c| (x[c] - training_scores[(j, c)]).powi(2))
516                                .sum();
517                            (d_sq, training_labels[j])
518                        })
519                        .collect();
520                    dists
521                        .sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
522                    let mut votes = vec![0usize; g];
523                    for &(_, label) in dists.iter().take(k_nn) {
524                        if label < g {
525                            votes[label] += 1;
526                        }
527                    }
528                    votes.iter().map(|&v| v as f64 / k_nn as f64).collect()
529                })
530                .collect()
531        }
532    }
533}
534
/// Numerically stable softmax: log-scores → probabilities summing to 1.
///
/// The maximum score is subtracted before exponentiation so large inputs
/// cannot overflow. An empty input yields an empty output.
fn softmax(scores: &[f64]) -> Vec<f64> {
    let mut max_s = f64::NEG_INFINITY;
    for &s in scores {
        max_s = max_s.max(s);
    }
    let mut out = Vec::with_capacity(scores.len());
    let mut total = 0.0;
    for &s in scores {
        let e = (s - max_s).exp();
        total += e;
        out.push(e);
    }
    for p in &mut out {
        *p /= total;
    }
    out
}
542
543// ---------------------------------------------------------------------------
544// ─── Config-based API ───────────────────────────────────────────────────────
545
/// Configuration for [`fclassif_cv`].
///
/// Construct via [`Default`] and override fields as needed, e.g.
/// `ClassifCvConfig { ncomp: 5, ..Default::default() }`.
#[derive(Debug, Clone, PartialEq)]
pub struct ClassifCvConfig {
    /// Classification method name (one of "lda", "qda", "knn", "kernel", "dd").
    pub method: String,
    /// Number of FPC components.
    pub ncomp: usize,
    /// Number of cross-validation folds (must satisfy `2 <= nfold <= n`).
    pub nfold: usize,
    /// Random seed for fold assignment.
    pub seed: u64,
}
558
559impl Default for ClassifCvConfig {
560    fn default() -> Self {
561        Self {
562            method: "lda".to_string(),
563            ncomp: 3,
564            nfold: 5,
565            seed: 42,
566        }
567    }
568}
569
570/// Cross-validated classification using a configuration struct.
571///
572/// Equivalent to [`fclassif_cv`] but bundles method parameters in [`ClassifCvConfig`].
573///
574/// # Errors
575///
576/// Returns [`FdarError::InvalidParameter`] if `config.nfold < 2` or `config.nfold > n`.
577/// Returns [`FdarError::InvalidParameter`] if `y` contains fewer than 2 distinct classes.
578#[must_use = "expensive computation whose result should not be discarded"]
579pub fn fclassif_cv_with_config(
580    data: &FdMatrix,
581    argvals: &[f64],
582    y: &[usize],
583    scalar_covariates: Option<&FdMatrix>,
584    config: &ClassifCvConfig,
585) -> Result<ClassifCvResult, FdarError> {
586    fclassif_cv(
587        data,
588        argvals,
589        y,
590        scalar_covariates,
591        &config.method,
592        config.ncomp,
593        config.nfold,
594        config.seed,
595    )
596}