// fdars_core/classification.rs

//! Functional classification with mixed scalar/functional predictors.
//!
//! Implements supervised classification for functional data using:
//! - [`fclassif_lda`] / [`fclassif_qda`] — FPC + LDA/QDA pipeline
//! - [`fclassif_knn`] — FPC + k-NN classifier
//! - [`fclassif_kernel`] — Nonparametric kernel classifier with mixed predictors
//! - [`fclassif_dd`] — Depth-based DD-classifier
//! - [`fclassif_cv`] — Cross-validated error rate
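//!
//! # Example
//!
//! A minimal end-to-end sketch. The module path and the
//! `FdMatrix::from_column_major` constructor follow the unit tests at the
//! bottom of this file; treat them as assumptions if the crate layout differs.
//!
//! ```no_run
//! use fdars_core::classification::fclassif_lda;
//! use fdars_core::matrix::FdMatrix;
//!
//! // 40 curves sampled at 50 points, stored column-major (as in the tests).
//! let data = FdMatrix::from_column_major(vec![0.0; 40 * 50], 40, 50).unwrap();
//! let y: Vec<usize> = (0..40).map(|i| if i < 20 { 0 } else { 1 }).collect();
//! if let Some(res) = fclassif_lda(&data, &y, None, 3) {
//!     println!("training accuracy: {:.3}", res.accuracy);
//! }
//! ```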

use crate::depth::fraiman_muniz_1d;
use crate::helpers::{l2_distance, simpsons_weights};
use crate::matrix::FdMatrix;
use crate::regression::fdata_to_pc_1d;

/// Classification result.
pub struct ClassifResult {
    /// Predicted class labels (length n)
    pub predicted: Vec<usize>,
    /// Posterior/membership probabilities (n x G) — if available
    pub probabilities: Option<FdMatrix>,
    /// Training accuracy
    pub accuracy: f64,
    /// Confusion matrix (G x G): row = true, col = predicted
    pub confusion: Vec<Vec<usize>>,
    /// Number of classes
    pub n_classes: usize,
    /// Number of FPC components used (0 for the kernel and DD classifiers)
    pub ncomp: usize,
}

/// Cross-validation result.
pub struct ClassifCvResult {
    /// Mean error rate across folds
    pub error_rate: f64,
    /// Per-fold error rates
    pub fold_errors: Vec<f64>,
    /// Best ncomp (currently just the `ncomp` passed in; tuning is not implemented)
    pub best_ncomp: usize,
}

// ---------------------------------------------------------------------------
// Utility helpers
// ---------------------------------------------------------------------------

/// Count distinct classes and remap labels to 0..G-1.
pub(crate) fn remap_labels(y: &[usize]) -> (Vec<usize>, usize) {
    let mut labels: Vec<usize> = y.to_vec();
    let mut unique: Vec<usize> = y.to_vec();
    unique.sort_unstable();
    unique.dedup();
    let g = unique.len();
    for label in &mut labels {
        *label = unique.iter().position(|&u| u == *label).unwrap_or(0);
    }
    (labels, g)
}

/// Build confusion matrix (G x G).
fn confusion_matrix(true_labels: &[usize], pred_labels: &[usize], g: usize) -> Vec<Vec<usize>> {
    let mut cm = vec![vec![0usize; g]; g];
    for (&t, &p) in true_labels.iter().zip(pred_labels.iter()) {
        if t < g && p < g {
            cm[t][p] += 1;
        }
    }
    cm
}

/// Accuracy from labels.
fn compute_accuracy(true_labels: &[usize], pred_labels: &[usize]) -> f64 {
    let n = true_labels.len();
    if n == 0 {
        return 0.0;
    }
    let correct = true_labels
        .iter()
        .zip(pred_labels.iter())
        .filter(|(&t, &p)| t == p)
        .count();
    correct as f64 / n as f64
}

/// Extract FPC scores and append optional scalar covariates.
pub(crate) fn build_feature_matrix(
    data: &FdMatrix,
    covariates: Option<&FdMatrix>,
    ncomp: usize,
) -> Option<(FdMatrix, Vec<f64>, FdMatrix)> {
    let fpca = fdata_to_pc_1d(data, ncomp)?;
    let n = data.nrows();
    let d_pc = fpca.scores.ncols();
    let d_cov = covariates.map_or(0, |c| c.ncols());
    let d = d_pc + d_cov;

    let mut features = FdMatrix::zeros(n, d);
    for i in 0..n {
        for j in 0..d_pc {
            features[(i, j)] = fpca.scores[(i, j)];
        }
        if let Some(cov) = covariates {
            for j in 0..d_cov {
                features[(i, d_pc + j)] = cov[(i, j)];
            }
        }
    }

    Some((features, fpca.mean, fpca.rotation))
}

// ---------------------------------------------------------------------------
// LDA: Linear Discriminant Analysis
// ---------------------------------------------------------------------------

/// Compute per-class means, counts, and priors from labeled features.
fn class_means_and_priors(
    features: &FdMatrix,
    labels: &[usize],
    g: usize,
) -> (Vec<Vec<f64>>, Vec<usize>, Vec<f64>) {
    let n = features.nrows();
    let d = features.ncols();
    let mut counts = vec![0usize; g];
    let mut class_means = vec![vec![0.0; d]; g];
    for i in 0..n {
        let c = labels[i];
        counts[c] += 1;
        for j in 0..d {
            class_means[c][j] += features[(i, j)];
        }
    }
    for c in 0..g {
        if counts[c] > 0 {
            for j in 0..d {
                class_means[c][j] /= counts[c] as f64;
            }
        }
    }
    let priors: Vec<f64> = counts.iter().map(|&c| c as f64 / n as f64).collect();
    (class_means, counts, priors)
}

/// Compute pooled within-class covariance (symmetric, regularized).
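///
/// Computes S_w = (1 / max(n - g, 1)) * Σ_c Σ_{i∈c} (x_i - μ_c)(x_i - μ_c)ᵀ,
/// then adds a small ridge (1e-6) to the diagonal so that the downstream
/// Cholesky factorization stays well-posed.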
fn pooled_within_cov(
    features: &FdMatrix,
    labels: &[usize],
    class_means: &[Vec<f64>],
    g: usize,
) -> Vec<f64> {
    let n = features.nrows();
    let d = features.ncols();
    let mut cov = vec![0.0; d * d];
    for i in 0..n {
        let c = labels[i];
        for r in 0..d {
            let dr = features[(i, r)] - class_means[c][r];
            for s in r..d {
                let val = dr * (features[(i, s)] - class_means[c][s]);
                cov[r * d + s] += val;
                if r != s {
                    cov[s * d + r] += val;
                }
            }
        }
    }
    let scale = (n - g).max(1) as f64;
    for v in cov.iter_mut() {
        *v /= scale;
    }
    for j in 0..d {
        cov[j * d + j] += 1e-6;
    }
    cov
}

/// Compute per-class means, pooled within-class covariance, and class priors.
pub(crate) fn lda_params(
    features: &FdMatrix,
    labels: &[usize],
    g: usize,
) -> (Vec<Vec<f64>>, Vec<f64>, Vec<f64>) {
    let (class_means, _counts, priors) = class_means_and_priors(features, labels, g);
    let cov = pooled_within_cov(features, labels, &class_means, g);
    (class_means, cov, priors)
}

/// Cholesky factorization of d×d row-major matrix.
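///
/// Returns the lower-triangular factor L with `mat = L * Lᵀ`, or `None` when
/// the matrix is not numerically positive definite (a non-positive pivot).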
pub(crate) fn cholesky_d(mat: &[f64], d: usize) -> Option<Vec<f64>> {
    let mut l = vec![0.0; d * d];
    for j in 0..d {
        let mut sum = 0.0;
        for k in 0..j {
            sum += l[j * d + k] * l[j * d + k];
        }
        let diag = mat[j * d + j] - sum;
        if diag <= 0.0 {
            return None;
        }
        l[j * d + j] = diag.sqrt();
        for i in (j + 1)..d {
            let mut s = 0.0;
            for k in 0..j {
                s += l[i * d + k] * l[j * d + k];
            }
            l[i * d + j] = (mat[i * d + j] - s) / l[j * d + j];
        }
    }
    Some(l)
}

/// Forward solve L * x = b.
pub(crate) fn forward_solve(l: &[f64], b: &[f64], d: usize) -> Vec<f64> {
    let mut x = vec![0.0; d];
    for i in 0..d {
        let mut s = 0.0;
        for j in 0..i {
            s += l[i * d + j] * x[j];
        }
        x[i] = (b[i] - s) / l[i * d + i];
    }
    x
}

/// Mahalanobis distance squared: (x-mu)^T Sigma^{-1} (x-mu) via Cholesky.
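///
/// With Sigma = L·Lᵀ, forward-solving L·y = x - mu gives
/// (x - mu)ᵀ Sigma⁻¹ (x - mu) = ‖y‖², so no explicit inverse is formed.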
pub(crate) fn mahalanobis_sq(x: &[f64], mu: &[f64], chol: &[f64], d: usize) -> f64 {
    let diff: Vec<f64> = x.iter().zip(mu.iter()).map(|(&a, &b)| a - b).collect();
    let y = forward_solve(chol, &diff, d);
    y.iter().map(|&v| v * v).sum()
}

/// LDA prediction: assign to class minimizing Mahalanobis distance (with prior).
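///
/// Each observation is scored per class by the Gaussian discriminant
/// ln π_c - ½·d²_M(x, μ_c); the shared -½·ln|Σ| term is constant across
/// classes and therefore omitted.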
pub(crate) fn lda_predict(
    features: &FdMatrix,
    class_means: &[Vec<f64>],
    cov_chol: &[f64],
    priors: &[f64],
    g: usize,
) -> Vec<usize> {
    let n = features.nrows();
    let d = features.ncols();

    (0..n)
        .map(|i| {
            let xi: Vec<f64> = (0..d).map(|j| features[(i, j)]).collect();
            let mut best_class = 0;
            let mut best_score = f64::NEG_INFINITY;
            for c in 0..g {
                let maha = mahalanobis_sq(&xi, &class_means[c], cov_chol, d);
                let score = priors[c].max(1e-15).ln() - 0.5 * maha;
                if score > best_score {
                    best_score = score;
                    best_class = c;
                }
            }
            best_class
        })
        .collect()
}

/// FPC + LDA classification.
///
/// # Arguments
/// * `data` — Functional data (n × m)
/// * `y` — Class labels (length n)
/// * `covariates` — Optional scalar covariates (n × p)
/// * `ncomp` — Number of FPC components
pub fn fclassif_lda(
    data: &FdMatrix,
    y: &[usize],
    covariates: Option<&FdMatrix>,
    ncomp: usize,
) -> Option<ClassifResult> {
    let n = data.nrows();
    if n == 0 || y.len() != n || ncomp == 0 {
        return None;
    }

    let (labels, g) = remap_labels(y);
    if g < 2 {
        return None;
    }

    let (features, _mean, _rotation) = build_feature_matrix(data, covariates, ncomp)?;
    let d = features.ncols();
    let (class_means, cov, priors) = lda_params(&features, &labels, g);
    let chol = cholesky_d(&cov, d)?;

    let predicted = lda_predict(&features, &class_means, &chol, &priors, g);
    let accuracy = compute_accuracy(&labels, &predicted);
    let confusion = confusion_matrix(&labels, &predicted, g);

    Some(ClassifResult {
        predicted,
        probabilities: None,
        accuracy,
        confusion,
        n_classes: g,
        ncomp: features.ncols().min(ncomp),
    })
}

// ---------------------------------------------------------------------------
// QDA: Quadratic Discriminant Analysis
// ---------------------------------------------------------------------------

/// Accumulate symmetric covariance from feature rows.
fn accumulate_class_cov(
    features: &FdMatrix,
    members: &[usize],
    mean: &[f64],
    d: usize,
) -> Vec<f64> {
    let mut cov = vec![0.0; d * d];
    for &i in members {
        for r in 0..d {
            let dr = features[(i, r)] - mean[r];
            for s in r..d {
                let val = dr * (features[(i, s)] - mean[s]);
                cov[r * d + s] += val;
                if r != s {
                    cov[s * d + r] += val;
                }
            }
        }
    }
    cov
}

/// Per-class covariance matrices.
fn qda_class_covariances(
    features: &FdMatrix,
    labels: &[usize],
    class_means: &[Vec<f64>],
    g: usize,
) -> Vec<Vec<f64>> {
    let n = features.nrows();
    let d = features.ncols();

    (0..g)
        .map(|c| {
            let members: Vec<usize> = (0..n).filter(|&i| labels[i] == c).collect();
            let nc = members.len();
            let divisor = (nc.saturating_sub(1)).max(1) as f64;
            let mut cov = accumulate_class_cov(features, &members, &class_means[c], d);
            for v in cov.iter_mut() {
                *v /= divisor;
            }
            for j in 0..d {
                cov[j * d + j] += 1e-6;
            }
            cov
        })
        .collect()
}

/// Compute QDA parameters: class means, Cholesky factors, log-dets, priors.
pub(crate) fn build_qda_params(
    features: &FdMatrix,
    labels: &[usize],
    g: usize,
) -> Option<(Vec<Vec<f64>>, Vec<Vec<f64>>, Vec<f64>, Vec<f64>)> {
    let d = features.ncols();
    let (class_means, _counts, priors) = class_means_and_priors(features, labels, g);
    let class_covs = qda_class_covariances(features, labels, &class_means, g);
    let mut class_chols = Vec::with_capacity(g);
    let mut class_log_dets = Vec::with_capacity(g);
    for cov in &class_covs {
        let chol = cholesky_d(cov, d)?;
        class_log_dets.push(log_det_cholesky(&chol, d));
        class_chols.push(chol);
    }
    Some((class_means, class_chols, class_log_dets, priors))
}

/// Log-determinant from Cholesky factor.
pub(crate) fn log_det_cholesky(l: &[f64], d: usize) -> f64 {
    let mut s = 0.0;
    for i in 0..d {
        s += l[i * d + i].ln();
    }
    2.0 * s
}

/// QDA prediction: per-class covariance.
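///
/// Scores each class by ln π_c - ½·(ln|Σ_c| + d²_{M,c}(x, μ_c)); unlike LDA,
/// the per-class log-determinant does not cancel and must be included.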
pub(crate) fn qda_predict(
    features: &FdMatrix,
    class_means: &[Vec<f64>],
    class_chols: &[Vec<f64>],
    class_log_dets: &[f64],
    priors: &[f64],
    g: usize,
) -> Vec<usize> {
    let n = features.nrows();
    let d = features.ncols();

    (0..n)
        .map(|i| {
            let xi: Vec<f64> = (0..d).map(|j| features[(i, j)]).collect();
            let mut best_class = 0;
            let mut best_score = f64::NEG_INFINITY;
            for c in 0..g {
                let maha = mahalanobis_sq(&xi, &class_means[c], &class_chols[c], d);
                let score = priors[c].max(1e-15).ln() - 0.5 * (class_log_dets[c] + maha);
                if score > best_score {
                    best_score = score;
                    best_class = c;
                }
            }
            best_class
        })
        .collect()
}

/// FPC + QDA classification.
pub fn fclassif_qda(
    data: &FdMatrix,
    y: &[usize],
    covariates: Option<&FdMatrix>,
    ncomp: usize,
) -> Option<ClassifResult> {
    let n = data.nrows();
    if n == 0 || y.len() != n || ncomp == 0 {
        return None;
    }

    let (labels, g) = remap_labels(y);
    if g < 2 {
        return None;
    }

    let (features, _mean, _rotation) = build_feature_matrix(data, covariates, ncomp)?;

    let (class_means, class_chols, class_log_dets, priors) =
        build_qda_params(&features, &labels, g)?;

    let predicted = qda_predict(
        &features,
        &class_means,
        &class_chols,
        &class_log_dets,
        &priors,
        g,
    );
    let accuracy = compute_accuracy(&labels, &predicted);
    let confusion = confusion_matrix(&labels, &predicted, g);

    Some(ClassifResult {
        predicted,
        probabilities: None,
        accuracy,
        confusion,
        n_classes: g,
        ncomp: features.ncols().min(ncomp),
    })
}

// ---------------------------------------------------------------------------
// k-NN classifier
// ---------------------------------------------------------------------------

/// FPC + k-NN classification.
///
/// # Arguments
/// * `data` — Functional data (n × m)
/// * `y` — Class labels
/// * `covariates` — Optional scalar covariates
/// * `ncomp` — Number of FPC components
/// * `k_nn` — Number of nearest neighbors
pub fn fclassif_knn(
    data: &FdMatrix,
    y: &[usize],
    covariates: Option<&FdMatrix>,
    ncomp: usize,
    k_nn: usize,
) -> Option<ClassifResult> {
    let n = data.nrows();
    if n == 0 || y.len() != n || ncomp == 0 || k_nn == 0 {
        return None;
    }

    let (labels, g) = remap_labels(y);
    if g < 2 {
        return None;
    }

    let (features, _mean, _rotation) = build_feature_matrix(data, covariates, ncomp)?;
    let d = features.ncols();

    let predicted = knn_predict_loo(&features, &labels, g, d, k_nn);
    let accuracy = compute_accuracy(&labels, &predicted);
    let confusion = confusion_matrix(&labels, &predicted, g);

    Some(ClassifResult {
        predicted,
        probabilities: None,
        accuracy,
        confusion,
        n_classes: g,
        ncomp: d.min(ncomp),
    })
}

/// Leave-one-out k-NN prediction.
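///
/// Each observation is classified by majority vote among its `k_nn` nearest
/// training neighbors in feature space, excluding itself. Ties resolve to the
/// highest class index, since `max_by_key` returns the last maximum.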
pub(crate) fn knn_predict_loo(
    features: &FdMatrix,
    labels: &[usize],
    g: usize,
    d: usize,
    k_nn: usize,
) -> Vec<usize> {
    let n = features.nrows();
    let k_nn = k_nn.min(n.saturating_sub(1));

    (0..n)
        .map(|i| {
            let xi: Vec<f64> = (0..d).map(|j| features[(i, j)]).collect();
            let mut dists: Vec<(f64, usize)> = (0..n)
                .filter(|&j| j != i)
                .map(|j| {
                    let xj: Vec<f64> = (0..d).map(|jj| features[(j, jj)]).collect();
                    let d_sq: f64 = xi.iter().zip(&xj).map(|(&a, &b)| (a - b).powi(2)).sum();
                    (d_sq, labels[j])
                })
                .collect();
            dists.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));

            // Majority vote among k nearest
            let mut votes = vec![0usize; g];
            for &(_, label) in dists.iter().take(k_nn) {
                votes[label] += 1;
            }
            votes
                .iter()
                .enumerate()
                .max_by_key(|&(_, &v)| v)
                .map(|(c, _)| c)
                .unwrap_or(0)
        })
        .collect()
}

// ---------------------------------------------------------------------------
// Nonparametric kernel classifier with mixed predictors
// ---------------------------------------------------------------------------

/// Find class with maximum score.
fn argmax_class(scores: &[f64]) -> usize {
    scores
        .iter()
        .enumerate()
        .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
        .map(|(c, _)| c)
        .unwrap_or(0)
}

/// Compute marginal rank-based scalar depth of observation i w.r.t. class c.
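///
/// For each of the `p` covariates, the within-class empirical cdf value `u`
/// is folded to min(u, 1 - u) (a univariate Tukey-style depth, at most 0.5);
/// the result is the average over covariates.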
fn scalar_depth_for_obs(cov: &FdMatrix, i: usize, class_indices: &[usize], p: usize) -> f64 {
    let nc = class_indices.len() as f64;
    if nc < 1.0 || p == 0 {
        return 0.0;
    }
    let mut depth = 0.0;
    for j in 0..p {
        let val = cov[(i, j)];
        let rank = class_indices
            .iter()
            .filter(|&&k| cov[(k, j)] <= val)
            .count() as f64;
        let u = rank / nc.max(1.0);
        depth += u.min(1.0 - u).min(0.5);
    }
    depth / p as f64
}

/// Generate bandwidth candidates from distance percentiles.
fn bandwidth_candidates(dists: &[f64], n: usize) -> Vec<f64> {
    let mut all_dists: Vec<f64> = Vec::new();
    for i in 0..n {
        for j in (i + 1)..n {
            all_dists.push(dists[i * n + j]);
        }
    }
    // Guard: with fewer than two observations there are no pairwise distances,
    // and the percentile index below would underflow.
    if all_dists.is_empty() {
        return Vec::new();
    }
    all_dists.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));

    (1..=20)
        .map(|p| {
            let idx = (p as f64 / 20.0 * (all_dists.len() - 1) as f64) as usize;
            all_dists[idx.min(all_dists.len() - 1)]
        })
        .filter(|&h| h > 1e-15)
        .collect()
}

/// LOO classification accuracy for a single bandwidth.
fn loo_accuracy_for_bandwidth(dists: &[f64], labels: &[usize], g: usize, n: usize, h: f64) -> f64 {
    let mut correct = 0;
    for i in 0..n {
        let mut votes = vec![0.0; g];
        for j in 0..n {
            if j != i {
                votes[labels[j]] += gaussian_kernel(dists[i * n + j], h);
            }
        }
        if argmax_class(&votes) == labels[i] {
            correct += 1;
        }
    }
    correct as f64 / n as f64
}

/// Gaussian kernel: exp(-d²/(2h²)).
fn gaussian_kernel(dist: f64, h: f64) -> f64 {
    if h < 1e-15 {
        return 0.0;
    }
    (-dist * dist / (2.0 * h * h)).exp()
}

/// Nonparametric kernel classifier for functional data with optional scalar covariates.
///
/// Uses product kernel: K_func × K_scalar. Bandwidth selected by LOO-CV.
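///
/// For observation i, the class-c vote is
/// v_c(i) = Σ_{j≠i, y_j=c} K(d_func(i,j); h_f) · K(d_scalar(i,j); h_s)
/// with Gaussian K, and the prediction is argmax_c v_c(i).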
///
/// # Arguments
/// * `data` — Functional data (n × m)
/// * `argvals` — Evaluation points
/// * `y` — Class labels
/// * `covariates` — Optional scalar covariates (n × p)
/// * `h_func` — Functional bandwidth (0 = auto via LOO-CV)
/// * `h_scalar` — Scalar bandwidth (0 = auto)
pub fn fclassif_kernel(
    data: &FdMatrix,
    argvals: &[f64],
    y: &[usize],
    covariates: Option<&FdMatrix>,
    h_func: f64,
    h_scalar: f64,
) -> Option<ClassifResult> {
    let n = data.nrows();
    let m = data.ncols();
    if n == 0 || y.len() != n || argvals.len() != m {
        return None;
    }

    let (labels, g) = remap_labels(y);
    if g < 2 {
        return None;
    }

    let weights = simpsons_weights(argvals);

    // Compute pairwise functional distances
    let func_dists = compute_pairwise_l2(data, &weights);

    // Compute pairwise scalar distances if covariates exist
    let scalar_dists = covariates.map(compute_pairwise_scalar);

    // Select bandwidths via LOO if needed
    let h_f = if h_func > 0.0 {
        h_func
    } else {
        select_bandwidth_loo(&func_dists, &labels, g, n, true)
    };
    let h_s = match &scalar_dists {
        Some(sd) if h_scalar <= 0.0 => select_bandwidth_loo(sd, &labels, g, n, false),
        _ => h_scalar,
    };

    let predicted = kernel_classify_loo(
        &func_dists,
        scalar_dists.as_deref(),
        &labels,
        g,
        n,
        h_f,
        h_s,
    );
    let accuracy = compute_accuracy(&labels, &predicted);
    let confusion = confusion_matrix(&labels, &predicted, g);

    Some(ClassifResult {
        predicted,
        probabilities: None,
        accuracy,
        confusion,
        n_classes: g,
        ncomp: 0,
    })
}

/// Compute pairwise L2 distances between curves.
fn compute_pairwise_l2(data: &FdMatrix, weights: &[f64]) -> Vec<f64> {
    let n = data.nrows();
    let mut dists = vec![0.0; n * n];
    for i in 0..n {
        let ri = data.row(i);
        for j in (i + 1)..n {
            let rj = data.row(j);
            let d = l2_distance(&ri, &rj, weights);
            dists[i * n + j] = d;
            dists[j * n + i] = d;
        }
    }
    dists
}

/// Compute pairwise Euclidean distances between scalar covariate vectors.
fn compute_pairwise_scalar(covariates: &FdMatrix) -> Vec<f64> {
    let n = covariates.nrows();
    let p = covariates.ncols();
    let mut dists = vec![0.0; n * n];
    for i in 0..n {
        for j in (i + 1)..n {
            let mut d_sq = 0.0;
            for k in 0..p {
                d_sq += (covariates[(i, k)] - covariates[(j, k)]).powi(2);
            }
            let d = d_sq.sqrt();
            dists[i * n + j] = d;
            dists[j * n + i] = d;
        }
    }
    dists
}

/// Select bandwidth by LOO classification accuracy.
fn select_bandwidth_loo(dists: &[f64], labels: &[usize], g: usize, n: usize, is_func: bool) -> f64 {
    let candidates = bandwidth_candidates(dists, n);
    if candidates.is_empty() {
        return if is_func { 1.0 } else { 0.5 };
    }

    let mut best_h = candidates[0];
    let mut best_acc = 0.0;
    for &h in &candidates {
        let acc = loo_accuracy_for_bandwidth(dists, labels, g, n, h);
        if acc > best_acc {
            best_acc = acc;
            best_h = h;
        }
    }
    best_h
}

/// LOO kernel classification with product kernel.
fn kernel_classify_loo(
    func_dists: &[f64],
    scalar_dists: Option<&[f64]>,
    labels: &[usize],
    g: usize,
    n: usize,
    h_func: f64,
    h_scalar: f64,
) -> Vec<usize> {
    (0..n)
        .map(|i| {
            let mut votes = vec![0.0; g];
            for j in 0..n {
                if j == i {
                    continue;
                }
                let kf = gaussian_kernel(func_dists[i * n + j], h_func);
                let ks = match scalar_dists {
                    Some(sd) if h_scalar > 1e-15 => gaussian_kernel(sd[i * n + j], h_scalar),
                    _ => 1.0,
                };
                votes[labels[j]] += kf * ks;
            }
            argmax_class(&votes)
        })
        .collect()
}

// ---------------------------------------------------------------------------
// Depth-based DD-classifier
// ---------------------------------------------------------------------------

/// Compute depth of all observations w.r.t. each class.
fn compute_class_depths(data: &FdMatrix, class_indices: &[Vec<usize>], n: usize) -> FdMatrix {
    let g = class_indices.len();
    let mut depth_scores = FdMatrix::zeros(n, g);
    for c in 0..g {
        if class_indices[c].is_empty() {
            continue;
        }
        let class_data = extract_class_data(data, &class_indices[c]);
        let depths = fraiman_muniz_1d(data, &class_data, true);
        for i in 0..n {
            depth_scores[(i, c)] = depths[i];
        }
    }
    depth_scores
}

/// Blend functional depth scores with scalar rank depth from covariates.
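///
/// Uses fixed blend weights: 0.7 · functional depth + 0.3 · scalar depth.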
fn blend_scalar_depths(
    depth_scores: &mut FdMatrix,
    cov: &FdMatrix,
    class_indices: &[Vec<usize>],
    n: usize,
) {
    let g = class_indices.len();
    let p = cov.ncols();
    for c in 0..g {
        for i in 0..n {
            let sd = scalar_depth_for_obs(cov, i, &class_indices[c], p);
            depth_scores[(i, c)] = 0.7 * depth_scores[(i, c)] + 0.3 * sd;
        }
    }
}

/// Depth-based DD-classifier.
///
/// Computes functional depth of each observation w.r.t. each class,
/// then classifies by maximum depth. When scalar covariates are supplied,
/// the functional depths are blended with a marginal rank depth.
pub fn fclassif_dd(
    data: &FdMatrix,
    y: &[usize],
    covariates: Option<&FdMatrix>,
) -> Option<ClassifResult> {
    let n = data.nrows();
    if n == 0 || y.len() != n {
        return None;
    }

    let (labels, g) = remap_labels(y);
    if g < 2 {
        return None;
    }

    let class_indices: Vec<Vec<usize>> = (0..g)
        .map(|c| (0..n).filter(|&i| labels[i] == c).collect())
        .collect();

    let mut depth_scores = compute_class_depths(data, &class_indices, n);

    if let Some(cov) = covariates {
        blend_scalar_depths(&mut depth_scores, cov, &class_indices, n);
    }

    let predicted: Vec<usize> = (0..n)
        .map(|i| {
            let scores: Vec<f64> = (0..g).map(|c| depth_scores[(i, c)]).collect();
            argmax_class(&scores)
        })
        .collect();

    let accuracy = compute_accuracy(&labels, &predicted);
    let confusion = confusion_matrix(&labels, &predicted, g);

    Some(ClassifResult {
        predicted,
        probabilities: Some(depth_scores),
        accuracy,
        confusion,
        n_classes: g,
        ncomp: 0,
    })
}

/// Extract rows corresponding to given indices into a new FdMatrix.
fn extract_class_data(data: &FdMatrix, indices: &[usize]) -> FdMatrix {
    let nc = indices.len();
    let m = data.ncols();
    let mut result = FdMatrix::zeros(nc, m);
    for (ri, &i) in indices.iter().enumerate() {
        for j in 0..m {
            result[(ri, j)] = data[(i, j)];
        }
    }
    result
}

// ---------------------------------------------------------------------------
// Cross-validation
// ---------------------------------------------------------------------------

/// K-fold cross-validated error rate for functional classification.
///
/// # Arguments
/// * `data` — Functional data (n × m)
/// * `argvals` — Evaluation points
/// * `y` — Class labels
/// * `covariates` — Optional scalar covariates
/// * `method` — "lda", "qda", or "knn"; "kernel" and "dd" have no
///   out-of-sample predictor here and each fold scores as error 1.0
/// * `ncomp` — Number of FPC components (for lda/qda/knn)
/// * `nfold` — Number of CV folds
/// * `seed` — Random seed for fold assignment
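///
/// # Example
///
/// A sketch mirroring `test_fclassif_cv_lda` below (`data`, `argvals`, and
/// `y` assumed in scope):
///
/// ```ignore
/// let cv = fclassif_cv(&data, &argvals, &y, None, "lda", 3, 5, 42).unwrap();
/// assert_eq!(cv.fold_errors.len(), 5);
/// println!("mean CV error: {:.3}", cv.error_rate);
/// ```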
pub fn fclassif_cv(
    data: &FdMatrix,
    argvals: &[f64],
    y: &[usize],
    covariates: Option<&FdMatrix>,
    method: &str,
    ncomp: usize,
    nfold: usize,
    seed: u64,
) -> Option<ClassifCvResult> {
    let n = data.nrows();
    if n < nfold || nfold < 2 || y.len() != n {
        return None;
    }

    let (labels, g) = remap_labels(y);
    if g < 2 {
        return None;
    }

    // Assign folds
    let folds = assign_folds(n, nfold, seed);

    let mut fold_errors = Vec::with_capacity(nfold);

    for fold in 0..nfold {
        let (train_idx, test_idx) = fold_split(&folds, fold);
        let train_data = extract_class_data(data, &train_idx);
        let test_data = extract_class_data(data, &test_idx);
        let train_labels: Vec<usize> = train_idx.iter().map(|&i| labels[i]).collect();
        let test_labels: Vec<usize> = test_idx.iter().map(|&i| labels[i]).collect();

        let train_cov = covariates.map(|c| extract_class_data(c, &train_idx));
        let test_cov = covariates.map(|c| extract_class_data(c, &test_idx));

        let predictions = cv_fold_predict(
            &train_data,
            &test_data,
            argvals,
            &train_labels,
            g,
            train_cov.as_ref(),
            test_cov.as_ref(),
            method,
            ncomp,
        );

        let n_test = test_labels.len();
        let errors = match predictions {
            Some(pred) => {
                let wrong = pred
                    .iter()
                    .zip(&test_labels)
                    .filter(|(&p, &t)| p != t)
                    .count();
                wrong as f64 / n_test as f64
            }
            None => 1.0,
        };
        fold_errors.push(errors);
    }

    let error_rate = fold_errors.iter().sum::<f64>() / nfold as f64;

    Some(ClassifCvResult {
        error_rate,
        fold_errors,
        best_ncomp: ncomp,
    })
}

/// Assign observations to folds.
fn assign_folds(n: usize, nfold: usize, seed: u64) -> Vec<usize> {
    use rand::prelude::*;
    let mut rng = StdRng::seed_from_u64(seed);
    let mut indices: Vec<usize> = (0..n).collect();
    indices.shuffle(&mut rng);

    let mut folds = vec![0usize; n];
    for (rank, &idx) in indices.iter().enumerate() {
        folds[idx] = rank % nfold;
    }
    folds
}

/// Split indices into train and test for given fold.
fn fold_split(folds: &[usize], fold: usize) -> (Vec<usize>, Vec<usize>) {
    let train: Vec<usize> = (0..folds.len()).filter(|&i| folds[i] != fold).collect();
    let test: Vec<usize> = (0..folds.len()).filter(|&i| folds[i] == fold).collect();
    (train, test)
}

/// Predict on test set for one CV fold.
fn cv_fold_predict(
    train_data: &FdMatrix,
    test_data: &FdMatrix,
    _argvals: &[f64],
    train_labels: &[usize],
    g: usize,
    train_cov: Option<&FdMatrix>,
    test_cov: Option<&FdMatrix>,
    method: &str,
    ncomp: usize,
) -> Option<Vec<usize>> {
    let fpca = fdata_to_pc_1d(train_data, ncomp)?;
    match method {
        "lda" => {
            let predictions =
                project_and_classify_lda(test_data, &fpca, train_labels, g, train_cov, test_cov);
            Some(predictions)
        }
        "qda" => {
            let predictions =
                project_and_classify_qda(test_data, &fpca, train_labels, g, train_cov, test_cov);
            Some(predictions)
        }
        "knn" => {
            let predictions =
                project_and_classify_knn(test_data, &fpca, train_labels, g, train_cov, test_cov, 5);
            Some(predictions)
        }
        // kernel and dd classifiers don't support out-of-sample prediction on new data
        "kernel" | "dd" => None,
        _ => None,
    }
}

/// Project test data onto FPCA basis (mean-center, multiply by rotation).
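///
/// Computes score[i][k] = Σ_j (x[i][j] - mean[j]) · rotation[j][k], i.e. test
/// curves are centered with the *training* mean and projected onto the
/// training eigenbasis.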
fn project_test_onto_fpca(test_data: &FdMatrix, fpca: &crate::regression::FpcaResult) -> FdMatrix {
    let n_test = test_data.nrows();
    let m = test_data.ncols();
    let d_pc = fpca.scores.ncols();
    let mut test_features = FdMatrix::zeros(n_test, d_pc);
    for i in 0..n_test {
        for k in 0..d_pc {
            let mut score = 0.0;
            for j in 0..m {
                score += (test_data[(i, j)] - fpca.mean[j]) * fpca.rotation[(j, k)];
            }
            test_features[(i, k)] = score;
        }
    }
    test_features
}

/// Append scalar covariates to FPCA scores to form augmented feature matrix.
fn append_covariates(scores: &FdMatrix, covariates: Option<&FdMatrix>) -> FdMatrix {
    match covariates {
        None => scores.clone(),
        Some(cov) => {
            let n = scores.nrows();
            let d_pc = scores.ncols();
            let d_cov = cov.ncols();
            let mut features = FdMatrix::zeros(n, d_pc + d_cov);
            for i in 0..n {
                for j in 0..d_pc {
                    features[(i, j)] = scores[(i, j)];
                }
                for j in 0..d_cov {
                    features[(i, d_pc + j)] = cov[(i, j)];
                }
            }
            features
        }
    }
}

/// Project test data onto training FPCA and classify with LDA.
fn project_and_classify_lda(
    test_data: &FdMatrix,
    fpca: &crate::regression::FpcaResult,
    train_labels: &[usize],
    g: usize,
    train_cov: Option<&FdMatrix>,
    test_cov: Option<&FdMatrix>,
) -> Vec<usize> {
    let test_pc = project_test_onto_fpca(test_data, fpca);
    let test_features = append_covariates(&test_pc, test_cov);

    let train_features = append_covariates(&fpca.scores, train_cov);
    let (class_means, cov, priors) = lda_params(&train_features, train_labels, g);
    let d = train_features.ncols();
    match cholesky_d(&cov, d) {
        Some(chol) => lda_predict(&test_features, &class_means, &chol, &priors, g),
        None => vec![0; test_data.nrows()],
    }
}

/// Project test data onto training FPCA and classify with QDA.
fn project_and_classify_qda(
    test_data: &FdMatrix,
    fpca: &crate::regression::FpcaResult,
    train_labels: &[usize],
    g: usize,
    train_cov: Option<&FdMatrix>,
    test_cov: Option<&FdMatrix>,
) -> Vec<usize> {
    let n_test = test_data.nrows();
    let test_pc = project_test_onto_fpca(test_data, fpca);
    let test_features = append_covariates(&test_pc, test_cov);

    let train_features = append_covariates(&fpca.scores, train_cov);

    match build_qda_params(&train_features, train_labels, g) {
        Some((class_means, class_chols, class_log_dets, priors)) => qda_predict(
            &test_features,
            &class_means,
            &class_chols,
            &class_log_dets,
            &priors,
            g,
        ),
        None => vec![0; n_test],
    }
}

/// Project test data and classify with k-NN.
fn project_and_classify_knn(
    test_data: &FdMatrix,
    fpca: &crate::regression::FpcaResult,
    train_labels: &[usize],
    g: usize,
    train_cov: Option<&FdMatrix>,
    test_cov: Option<&FdMatrix>,
    k_nn: usize,
) -> Vec<usize> {
    let n_test = test_data.nrows();
    let n_train = fpca.scores.nrows();

    let test_pc = project_test_onto_fpca(test_data, fpca);
    let test_features = append_covariates(&test_pc, test_cov);
    let train_features = append_covariates(&fpca.scores, train_cov);
    let d = train_features.ncols();

    (0..n_test)
        .map(|i| {
            // Distances to all training points in augmented feature space
            let mut dists: Vec<(f64, usize)> = (0..n_train)
                .map(|t| {
                    let d_sq: f64 = (0..d)
                        .map(|k| (test_features[(i, k)] - train_features[(t, k)]).powi(2))
                        .sum();
                    (d_sq, train_labels[t])
                })
                .collect();
            dists.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));

            let mut votes = vec![0usize; g];
            for &(_, label) in dists.iter().take(k_nn.min(n_train)) {
                votes[label] += 1;
            }
            votes
                .iter()
                .enumerate()
                .max_by_key(|&(_, &v)| v)
                .map(|(c, _)| c)
                .unwrap_or(0)
        })
        .collect()
}

// ===========================================================================
// ClassifFit: wrapper for explainability trait
// ===========================================================================

use crate::explain_generic::{FpcPredictor, TaskType};

/// Classification method with stored parameters for prediction.
pub enum ClassifMethod {
    /// Linear Discriminant Analysis.
    Lda {
        class_means: Vec<Vec<f64>>,
        cov_chol: Vec<f64>,
        priors: Vec<f64>,
        n_classes: usize,
    },
    /// Quadratic Discriminant Analysis.
    Qda {
        class_means: Vec<Vec<f64>>,
        class_chols: Vec<Vec<f64>>,
        class_log_dets: Vec<f64>,
        priors: Vec<f64>,
        n_classes: usize,
    },
    /// k-Nearest Neighbors.
    Knn {
        training_scores: FdMatrix,
        training_labels: Vec<usize>,
        k: usize,
        n_classes: usize,
    },
}

/// A fitted classification model that retains FPCA components for explainability.
pub struct ClassifFit {
    /// Classification result (predicted labels, accuracy, confusion matrix).
    pub result: ClassifResult,
    /// FPCA mean function (length m).
    pub fpca_mean: Vec<f64>,
    /// FPCA rotation matrix (m × ncomp).
    pub fpca_rotation: FdMatrix,
    /// FPCA scores (n × ncomp).
    pub fpca_scores: FdMatrix,
    /// Number of FPC components used.
    pub ncomp: usize,
    /// The classification method with stored parameters.
    pub method: ClassifMethod,
}

/// FPC + LDA classification, retaining FPCA and LDA parameters for explainability.
pub fn fclassif_lda_fit(
    data: &FdMatrix,
    y: &[usize],
    covariates: Option<&FdMatrix>,
    ncomp: usize,
) -> Option<ClassifFit> {
    let n = data.nrows();
    if n == 0 || y.len() != n || ncomp == 0 {
        return None;
    }

    let (labels, g) = remap_labels(y);
    if g < 2 {
        return None;
    }

    // _fit variants use FPCA-only features (no covariates) so that stored
    // dimensions are consistent with FpcPredictor::project() / predict_from_scores().
    let (features, mean, rotation) = build_feature_matrix(data, None, ncomp)?;
    let _ = covariates; // acknowledged but not used; see the comment above
    let d = features.ncols();
    let (class_means, cov, priors) = lda_params(&features, &labels, g);
    let chol = cholesky_d(&cov, d)?;

    let predicted = lda_predict(&features, &class_means, &chol, &priors, g);
    let accuracy = compute_accuracy(&labels, &predicted);
    let confusion = confusion_matrix(&labels, &predicted, g);

    Some(ClassifFit {
        result: ClassifResult {
            predicted,
            probabilities: None,
            accuracy,
            confusion,
            n_classes: g,
            ncomp: d,
        },
        fpca_mean: mean.clone(),
        fpca_rotation: rotation,
        fpca_scores: features,
        ncomp: d,
        method: ClassifMethod::Lda {
            class_means,
            cov_chol: chol,
            priors,
            n_classes: g,
        },
    })
}

/// FPC + QDA classification, retaining FPCA and QDA parameters for explainability.
pub fn fclassif_qda_fit(
    data: &FdMatrix,
    y: &[usize],
    covariates: Option<&FdMatrix>,
    ncomp: usize,
) -> Option<ClassifFit> {
    let n = data.nrows();
    if n == 0 || y.len() != n || ncomp == 0 {
        return None;
    }

    let (labels, g) = remap_labels(y);
    if g < 2 {
        return None;
    }

    // _fit variants use FPCA-only features — see fclassif_lda_fit comment.
    let (features, mean, rotation) = build_feature_matrix(data, None, ncomp)?;
    let _ = covariates;
    let (class_means, class_chols, class_log_dets, priors) =
        build_qda_params(&features, &labels, g)?;

    let predicted = qda_predict(
        &features,
        &class_means,
        &class_chols,
        &class_log_dets,
        &priors,
        g,
    );
    let accuracy = compute_accuracy(&labels, &predicted);
    let confusion = confusion_matrix(&labels, &predicted, g);
    let d = features.ncols();

    Some(ClassifFit {
        result: ClassifResult {
            predicted,
            probabilities: None,
            accuracy,
            confusion,
            n_classes: g,
            ncomp: d,
        },
        fpca_mean: mean.clone(),
        fpca_rotation: rotation,
        fpca_scores: features,
        ncomp: d,
        method: ClassifMethod::Qda {
            class_means,
            class_chols,
            class_log_dets,
            priors,
            n_classes: g,
        },
    })
}

/// FPC + k-NN classification, retaining FPCA and training data for explainability.
pub fn fclassif_knn_fit(
    data: &FdMatrix,
    y: &[usize],
    covariates: Option<&FdMatrix>,
    ncomp: usize,
    k_nn: usize,
) -> Option<ClassifFit> {
    let n = data.nrows();
    if n == 0 || y.len() != n || ncomp == 0 || k_nn == 0 {
        return None;
    }

    let (labels, g) = remap_labels(y);
    if g < 2 {
        return None;
    }

    // _fit variants use FPCA-only features — see fclassif_lda_fit comment.
    let (features, mean, rotation) = build_feature_matrix(data, None, ncomp)?;
    let _ = covariates;
    let d = features.ncols();

    let predicted = knn_predict_loo(&features, &labels, g, d, k_nn);
    let accuracy = compute_accuracy(&labels, &predicted);
    let confusion = confusion_matrix(&labels, &predicted, g);

    Some(ClassifFit {
        result: ClassifResult {
            predicted,
            probabilities: None,
            accuracy,
            confusion,
            n_classes: g,
            ncomp: d,
        },
        fpca_mean: mean.clone(),
        fpca_rotation: rotation,
        fpca_scores: features.clone(),
        ncomp: d,
        method: ClassifMethod::Knn {
            training_scores: features,
            training_labels: labels,
            k: k_nn,
            n_classes: g,
        },
    })
}

// ---------------------------------------------------------------------------
// FpcPredictor impl for ClassifFit
// ---------------------------------------------------------------------------

impl FpcPredictor for ClassifFit {
    fn fpca_mean(&self) -> &[f64] {
        &self.fpca_mean
    }

    fn fpca_rotation(&self) -> &FdMatrix {
        &self.fpca_rotation
    }

    fn ncomp(&self) -> usize {
        self.ncomp
    }

    fn training_scores(&self) -> &FdMatrix {
        &self.fpca_scores
    }

    fn task_type(&self) -> TaskType {
        match &self.method {
            ClassifMethod::Lda { n_classes, .. }
            | ClassifMethod::Qda { n_classes, .. }
            | ClassifMethod::Knn { n_classes, .. } => {
                if *n_classes == 2 {
                    TaskType::BinaryClassification
                } else {
                    TaskType::MulticlassClassification(*n_classes)
                }
            }
        }
    }

    fn predict_from_scores(&self, scores: &[f64], _scalar_covariates: Option<&[f64]>) -> f64 {
        match &self.method {
            ClassifMethod::Lda {
                class_means,
                cov_chol,
                priors,
                n_classes,
            } => {
                let g = *n_classes;
                let d = scores.len();
                if g == 2 {
                    // Return P(Y=1) via softmax of discriminant scores
                    let score0 = priors[0].max(1e-15).ln()
                        - 0.5 * mahalanobis_sq(scores, &class_means[0], cov_chol, d);
                    let score1 = priors[1].max(1e-15).ln()
                        - 0.5 * mahalanobis_sq(scores, &class_means[1], cov_chol, d);
                    let max_s = score0.max(score1);
                    let exp0 = (score0 - max_s).exp();
                    let exp1 = (score1 - max_s).exp();
                    exp1 / (exp0 + exp1)
                } else {
                    // Return predicted class as f64
                    let mut best_class = 0;
                    let mut best_score = f64::NEG_INFINITY;
                    for c in 0..g {
                        let maha = mahalanobis_sq(scores, &class_means[c], cov_chol, d);
                        let s = priors[c].max(1e-15).ln() - 0.5 * maha;
                        if s > best_score {
                            best_score = s;
                            best_class = c;
                        }
                    }
                    best_class as f64
                }
            }
            ClassifMethod::Qda {
                class_means,
                class_chols,
                class_log_dets,
                priors,
                n_classes,
            } => {
                let g = *n_classes;
                let d = scores.len();
                if g == 2 {
                    // Return P(Y=1) via softmax of discriminant scores
                    let score0 = priors[0].max(1e-15).ln()
                        - 0.5
                            * (class_log_dets[0]
                                + mahalanobis_sq(scores, &class_means[0], &class_chols[0], d));
                    let score1 = priors[1].max(1e-15).ln()
                        - 0.5
                            * (class_log_dets[1]
                                + mahalanobis_sq(scores, &class_means[1], &class_chols[1], d));
                    let max_s = score0.max(score1);
                    let exp0 = (score0 - max_s).exp();
                    let exp1 = (score1 - max_s).exp();
                    exp1 / (exp0 + exp1)
                } else {
                    let mut best_class = 0;
                    let mut best_score = f64::NEG_INFINITY;
                    for c in 0..g {
                        let maha = mahalanobis_sq(scores, &class_means[c], &class_chols[c], d);
                        let s = priors[c].max(1e-15).ln() - 0.5 * (class_log_dets[c] + maha);
                        if s > best_score {
                            best_score = s;
                            best_class = c;
                        }
                    }
                    best_class as f64
                }
            }
            ClassifMethod::Knn {
                training_scores,
                training_labels,
                k,
                n_classes,
            } => {
                let g = *n_classes;
                let n_train = training_scores.nrows();
                let d = scores.len();
                let k_nn = (*k).min(n_train);

                let mut dists: Vec<(f64, usize)> = (0..n_train)
                    .map(|j| {
                        let d_sq: f64 = (0..d)
                            .map(|c| (scores[c] - training_scores[(j, c)]).powi(2))
                            .sum();
                        (d_sq, training_labels[j])
                    })
                    .collect();
                dists.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));

                let mut votes = vec![0usize; g];
                for &(_, label) in dists.iter().take(k_nn) {
                    if label < g {
                        votes[label] += 1;
                    }
                }

                if g == 2 {
                    // Return proportion voting for class 1 as probability
                    votes[1] as f64 / k_nn as f64
                } else {
                    // Return majority vote class as f64
                    votes
                        .iter()
                        .enumerate()
                        .max_by_key(|&(_, &v)| v)
                        .map(|(c, _)| c as f64)
                        .unwrap_or(0.0)
                }
            }
        }
    }
}

// ---------------------------------------------------------------------------
// Class probability vectors (for conformal prediction sets)
// ---------------------------------------------------------------------------

/// Compute full class probability vectors for each observation.
///
/// Returns `n × g` probability vectors suitable for conformal classification.
/// For each observation, the probabilities sum to 1.
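///
/// LDA/QDA probabilities are softmax-normalized discriminant scores; k-NN
/// probabilities are the neighbor-vote fractions votes_c / k.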
1514pub(crate) fn classif_predict_probs(fit: &ClassifFit, scores: &FdMatrix) -> Vec<Vec<f64>> {
1515    let n = scores.nrows();
1516    let d = scores.ncols();
1517    match &fit.method {
1518        ClassifMethod::Lda {
1519            class_means,
1520            cov_chol,
1521            priors,
1522            n_classes,
1523        } => {
1524            let g = *n_classes;
1525            (0..n)
1526                .map(|i| {
1527                    let x: Vec<f64> = (0..d).map(|j| scores[(i, j)]).collect();
1528                    let disc: Vec<f64> = (0..g)
1529                        .map(|c| {
1530                            priors[c].max(1e-15).ln()
1531                                - 0.5 * mahalanobis_sq(&x, &class_means[c], cov_chol, d)
1532                        })
1533                        .collect();
1534                    softmax(&disc)
1535                })
1536                .collect()
1537        }
1538        ClassifMethod::Qda {
1539            class_means,
1540            class_chols,
1541            class_log_dets,
1542            priors,
1543            n_classes,
1544        } => {
1545            let g = *n_classes;
1546            (0..n)
1547                .map(|i| {
1548                    let x: Vec<f64> = (0..d).map(|j| scores[(i, j)]).collect();
1549                    let disc: Vec<f64> = (0..g)
1550                        .map(|c| {
1551                            priors[c].max(1e-15).ln()
1552                                - 0.5
1553                                    * (class_log_dets[c]
1554                                        + mahalanobis_sq(&x, &class_means[c], &class_chols[c], d))
1555                        })
1556                        .collect();
1557                    softmax(&disc)
1558                })
1559                .collect()
1560        }
1561        ClassifMethod::Knn {
1562            training_scores,
1563            training_labels,
1564            k,
1565            n_classes,
1566        } => {
1567            let g = *n_classes;
1568            let n_train = training_scores.nrows();
1569            let k_nn = (*k).min(n_train);
1570            (0..n)
1571                .map(|i| {
1572                    let x: Vec<f64> = (0..d).map(|j| scores[(i, j)]).collect();
1573                    let mut dists: Vec<(f64, usize)> = (0..n_train)
1574                        .map(|j| {
1575                            let d_sq: f64 = (0..d)
1576                                .map(|c| (x[c] - training_scores[(j, c)]).powi(2))
1577                                .sum();
1578                            (d_sq, training_labels[j])
1579                        })
1580                        .collect();
1581                    dists
1582                        .sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
1583                    let mut votes = vec![0usize; g];
1584                    for &(_, label) in dists.iter().take(k_nn) {
1585                        if label < g {
1586                            votes[label] += 1;
1587                        }
1588                    }
1589                    votes.iter().map(|&v| v as f64 / k_nn as f64).collect()
1590                })
1591                .collect()
1592        }
1593    }
1594}
1595
1596/// Softmax of a vector of log-scores → probabilities.
1597fn softmax(scores: &[f64]) -> Vec<f64> {
1598    let max_s = scores.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
1599    let exps: Vec<f64> = scores.iter().map(|&s| (s - max_s).exp()).collect();
1600    let sum: f64 = exps.iter().sum();
1601    exps.iter().map(|&e| e / sum).collect()
1602}
1603
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------

#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::PI;

    fn uniform_grid(m: usize) -> Vec<f64> {
        (0..m).map(|i| i as f64 / (m - 1) as f64).collect()
    }

    /// Generate two well-separated classes of curves.
    fn generate_two_class_data(n_per: usize, m: usize) -> (FdMatrix, Vec<usize>, Vec<f64>) {
        let t = uniform_grid(m);
        let n = 2 * n_per;
        let mut col_major = vec![0.0; n * m];

        for i in 0..n_per {
            for (j, &tj) in t.iter().enumerate() {
                // Class 0: sin
                col_major[i + j * n] =
                    (2.0 * PI * tj).sin() + 0.05 * ((i * 7 + j * 3) % 100) as f64 / 100.0;
            }
        }
        for i in 0..n_per {
            for (j, &tj) in t.iter().enumerate() {
                // Class 1: -sin (opposite phase)
                col_major[(i + n_per) + j * n] =
                    -(2.0 * PI * tj).sin() + 0.05 * ((i * 7 + j * 3) % 100) as f64 / 100.0;
            }
        }

        let data = FdMatrix::from_column_major(col_major, n, m).unwrap();
        let labels: Vec<usize> = (0..n).map(|i| if i < n_per { 0 } else { 1 }).collect();
        (data, labels, t)
    }

    #[test]
    fn test_fclassif_lda_basic() {
        let (data, labels, _t) = generate_two_class_data(20, 50);
        let result = fclassif_lda(&data, &labels, None, 3).unwrap();

        assert_eq!(result.predicted.len(), 40);
        assert_eq!(result.n_classes, 2);
        assert!(
            result.accuracy > 0.8,
            "LDA accuracy should be high: {}",
            result.accuracy
        );
    }

    #[test]
    fn test_fclassif_qda_basic() {
        let (data, labels, _t) = generate_two_class_data(20, 50);
        let result = fclassif_qda(&data, &labels, None, 3).unwrap();

        assert_eq!(result.predicted.len(), 40);
        assert!(
            result.accuracy > 0.8,
            "QDA accuracy should be high: {}",
            result.accuracy
        );
    }

    #[test]
    fn test_fclassif_knn_basic() {
        let (data, labels, _t) = generate_two_class_data(20, 50);
        let result = fclassif_knn(&data, &labels, None, 3, 5).unwrap();

        assert_eq!(result.predicted.len(), 40);
        assert!(
            result.accuracy > 0.7,
            "k-NN accuracy should be reasonable: {}",
            result.accuracy
        );
    }

    #[test]
    fn test_fclassif_kernel_basic() {
        let (data, labels, t) = generate_two_class_data(20, 50);
        let result = fclassif_kernel(&data, &t, &labels, None, 0.0, 0.0).unwrap();

        assert_eq!(result.predicted.len(), 40);
        assert!(
            result.accuracy > 0.7,
            "Kernel accuracy should be reasonable: {}",
            result.accuracy
        );
    }

    #[test]
    fn test_fclassif_dd_basic() {
        let (data, labels, _t) = generate_two_class_data(20, 50);
        let result = fclassif_dd(&data, &labels, None).unwrap();

        assert_eq!(result.predicted.len(), 40);
        assert_eq!(result.n_classes, 2);
        // DD-classifier should work on well-separated data
        assert!(
            result.accuracy > 0.6,
            "DD accuracy should be reasonable: {}",
            result.accuracy
        );
        assert!(result.probabilities.is_some());
    }

    #[test]
    fn test_confusion_matrix_shape() {
        let (data, labels, _t) = generate_two_class_data(15, 50);
        let result = fclassif_lda(&data, &labels, None, 2).unwrap();

        assert_eq!(result.confusion.len(), 2);
        assert_eq!(result.confusion[0].len(), 2);
        assert_eq!(result.confusion[1].len(), 2);

        // Total should equal n
        let total: usize = result.confusion.iter().flat_map(|row| row.iter()).sum();
        assert_eq!(total, 30);
    }

    #[test]
    fn test_fclassif_cv_lda() {
        let (data, labels, t) = generate_two_class_data(25, 50);
        let result = fclassif_cv(&data, &t, &labels, None, "lda", 3, 5, 42).unwrap();

        assert_eq!(result.fold_errors.len(), 5);
        assert!(
            result.error_rate < 0.3,
            "CV error should be low: {}",
            result.error_rate
        );
    }

    #[test]
    fn test_fclassif_invalid_input() {
        let data = FdMatrix::zeros(0, 0);
        assert!(fclassif_lda(&data, &[], None, 1).is_none());

        let data = FdMatrix::zeros(10, 50);
        let labels = vec![0; 10]; // single class
        assert!(fclassif_lda(&data, &labels, None, 1).is_none());
    }

    #[test]
    fn test_remap_labels() {
        let (mapped, g) = remap_labels(&[5, 5, 10, 10, 20]);
        assert_eq!(g, 3);
        assert_eq!(mapped, vec![0, 0, 1, 1, 2]);
    }
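
    /// Added check (a small sketch): labels that are already contiguous in
    /// 0..G-1 should pass through remap_labels unchanged, even out of order.
    #[test]
    fn test_remap_labels_already_contiguous() {
        let (mapped, g) = remap_labels(&[2, 0, 1, 1]);
        assert_eq!(g, 3);
        assert_eq!(mapped, vec![2, 0, 1, 1]);
    }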

    #[test]
    fn test_fclassif_lda_with_covariates() {
        let n_per = 15;
        let n = 2 * n_per;
        let m = 50;
        let t = uniform_grid(m);

        // Curves are identical across classes
        let mut col_major = vec![0.0; n * m];
        for i in 0..n {
            for (j, &tj) in t.iter().enumerate() {
                col_major[i + j * n] = (2.0 * PI * tj).sin();
            }
        }
        let data = FdMatrix::from_column_major(col_major, n, m).unwrap();

        // But the covariate separates the classes: 0 vs 10
        let mut cov_data = vec![0.0; n];
        for i in n_per..n {
            cov_data[i] = 10.0;
        }
        let covariates = FdMatrix::from_column_major(cov_data, n, 1).unwrap();

        let labels: Vec<usize> = (0..n).map(|i| if i < n_per { 0 } else { 1 }).collect();

        let result = fclassif_lda(&data, &labels, Some(&covariates), 2).unwrap();
        assert!(
            result.accuracy > 0.9,
            "Covariate should enable separation: {}",
            result.accuracy
        );
    }

    // -----------------------------------------------------------------------
    // Additional coverage tests
    // -----------------------------------------------------------------------

    /// Helper: generate two-class data with scalar covariates.
    fn generate_two_class_with_covariates(
        n_per: usize,
        m: usize,
        p_cov: usize,
    ) -> (FdMatrix, Vec<usize>, Vec<f64>, FdMatrix) {
        let (data, labels, t) = generate_two_class_data(n_per, m);
        let n = 2 * n_per;
        // Covariates: class 0 → low values, class 1 → high values
        let mut cov_data = vec![0.0; n * p_cov];
        for i in 0..n {
            for j in 0..p_cov {
                let base = if labels[i] == 0 { 0.0 } else { 5.0 };
                cov_data[i + j * n] = base + 0.1 * ((i * 3 + j * 7) % 50) as f64 / 50.0;
            }
        }
        let covariates = FdMatrix::from_column_major(cov_data, n, p_cov).unwrap();
        (data, labels, t, covariates)
    }

    #[test]
    fn test_fclassif_cv_qda() {
        let (data, labels, t) = generate_two_class_data(25, 50);
        let result = fclassif_cv(&data, &t, &labels, None, "qda", 3, 5, 42).unwrap();

        assert_eq!(result.fold_errors.len(), 5);
        assert!(
            result.error_rate < 0.4,
            "QDA CV error should be low: {}",
            result.error_rate
        );
        assert_eq!(result.best_ncomp, 3);
    }

    #[test]
    fn test_fclassif_cv_knn() {
        let (data, labels, t) = generate_two_class_data(25, 50);
        let result = fclassif_cv(&data, &t, &labels, None, "knn", 3, 5, 42).unwrap();

        assert_eq!(result.fold_errors.len(), 5);
        assert!(
            result.error_rate < 0.4,
            "k-NN CV error should be low: {}",
            result.error_rate
        );
    }

    #[test]
    fn test_fclassif_cv_kernel() {
        let (data, labels, t) = generate_two_class_data(25, 50);
        let result = fclassif_cv(&data, &t, &labels, None, "kernel", 3, 5, 42).unwrap();

        assert_eq!(result.fold_errors.len(), 5);
        // Kernel CV: placeholder prediction may not be accurate, just ensure it runs
        assert!(result.error_rate >= 0.0 && result.error_rate <= 1.0);
    }

    #[test]
    fn test_fclassif_cv_dd() {
        let (data, labels, t) = generate_two_class_data(25, 50);
        let result = fclassif_cv(&data, &t, &labels, None, "dd", 3, 5, 42).unwrap();

        assert_eq!(result.fold_errors.len(), 5);
        assert!(result.error_rate >= 0.0 && result.error_rate <= 1.0);
    }

    #[test]
    fn test_fclassif_cv_invalid_method() {
        let (data, labels, t) = generate_two_class_data(25, 50);
        // "bogus" method hits the `_ => None` arm in cv_fold_predict
        let result = fclassif_cv(&data, &t, &labels, None, "bogus", 3, 5, 42);

        // Should still return Some — fold errors will be 1.0 for each fold
        let r = result.unwrap();
        assert!((r.error_rate - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_fclassif_cv_too_few_folds() {
        let (data, labels, t) = generate_two_class_data(10, 50);
        // nfold < 2 → None
        assert!(fclassif_cv(&data, &t, &labels, None, "lda", 3, 1, 42).is_none());
        // n < nfold → None
        assert!(fclassif_cv(&data, &t, &labels, None, "lda", 3, 100, 42).is_none());
    }

    #[test]
    fn test_fclassif_cv_single_class() {
        let (data, _labels, t) = generate_two_class_data(10, 50);
        let single = vec![0usize; 20]; // only one class
        assert!(fclassif_cv(&data, &t, &single, None, "lda", 3, 5, 42).is_none());
    }

    #[test]
    fn test_fclassif_kernel_with_covariates() {
        let (data, labels, t, covariates) = generate_two_class_with_covariates(20, 50, 2);
        let result = fclassif_kernel(&data, &t, &labels, Some(&covariates), 0.0, 0.0).unwrap();

        assert_eq!(result.predicted.len(), 40);
        assert!(
            result.accuracy > 0.5,
            "Kernel+cov accuracy should be reasonable: {}",
            result.accuracy
        );
        assert_eq!(result.ncomp, 0); // kernel doesn't use ncomp
    }

    #[test]
    fn test_fclassif_kernel_with_covariates_manual_bandwidth() {
        let (data, labels, t, covariates) = generate_two_class_with_covariates(15, 50, 1);
        // Provide explicit bandwidths (>0 skips LOO selection)
        let result = fclassif_kernel(&data, &t, &labels, Some(&covariates), 1.0, 1.0).unwrap();

        assert_eq!(result.predicted.len(), 30);
        assert!(result.accuracy >= 0.0 && result.accuracy <= 1.0);
    }

    #[test]
    fn test_fclassif_dd_with_covariates() {
        let (data, labels, _t, covariates) = generate_two_class_with_covariates(20, 50, 2);
        let result = fclassif_dd(&data, &labels, Some(&covariates)).unwrap();

        assert_eq!(result.predicted.len(), 40);
        assert_eq!(result.n_classes, 2);
        assert!(
            result.accuracy > 0.5,
            "DD+cov accuracy should be reasonable: {}",
            result.accuracy
        );
        assert!(result.probabilities.is_some());
    }

    #[test]
    fn test_fclassif_dd_with_single_covariate() {
        // Curves are identical; only the covariate separates classes
        let n_per = 15;
        let n = 2 * n_per;
        let m = 50;
        let t = uniform_grid(m);

        let mut col_major = vec![0.0; n * m];
        for i in 0..n {
            for (j, &tj) in t.iter().enumerate() {
                col_major[i + j * n] = (2.0 * PI * tj).sin();
            }
        }
        let data = FdMatrix::from_column_major(col_major, n, m).unwrap();
        let labels: Vec<usize> = (0..n).map(|i| if i < n_per { 0 } else { 1 }).collect();

        // Covariate: class 0 → [0..1], class 1 → [10..11]
        let mut cov_data = vec![0.0; n];
        for i in 0..n_per {
            cov_data[i] = i as f64 / n_per as f64;
        }
        for i in n_per..n {
            cov_data[i] = 10.0 + (i - n_per) as f64 / n_per as f64;
        }
        let covariates = FdMatrix::from_column_major(cov_data, n, 1).unwrap();

        let result = fclassif_dd(&data, &labels, Some(&covariates)).unwrap();
        // The scalar blending should help even when curves are identical
        assert!(
            result.accuracy >= 0.5,
            "DD with scalar covariate: {}",
            result.accuracy
        );
    }

    #[test]
    fn test_scalar_depth_for_obs_edge_cases() {
        // Empty class indices → depth = 0
        let cov = FdMatrix::from_column_major(vec![1.0, 2.0, 3.0, 4.0], 4, 1).unwrap();
        assert_eq!(scalar_depth_for_obs(&cov, 0, &[], 1), 0.0);

        // p=0 → depth = 0
        let cov0 = FdMatrix::zeros(4, 0);
        assert_eq!(scalar_depth_for_obs(&cov0, 0, &[0, 1, 2, 3], 0), 0.0);

        // Normal case: all indices
        let depth = scalar_depth_for_obs(&cov, 1, &[0, 1, 2, 3], 1);
        assert!(depth > 0.0 && depth <= 0.5, "depth={}", depth);

        // Observations at the extremes of the sample should have low depth
        let depth_min = scalar_depth_for_obs(&cov, 0, &[0, 1, 2, 3], 1);
        let depth_max = scalar_depth_for_obs(&cov, 3, &[0, 1, 2, 3], 1);
        assert!(depth_min <= 0.5, "depth_min={}", depth_min);
        assert!(depth_max <= 0.5, "depth_max={}", depth_max);
    }
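
    /// Added sketch: a depth function should rank central observations at
    /// least as deep as extreme ones (center-outward ordering). This assumes
    /// scalar_depth_for_obs behaves like a univariate depth, consistent with
    /// the bounds checked above.
    #[test]
    fn test_scalar_depth_center_outward() {
        let cov = FdMatrix::from_column_major(vec![1.0, 2.0, 3.0, 4.0, 5.0], 5, 1).unwrap();
        let idx = [0usize, 1, 2, 3, 4];
        let d_center = scalar_depth_for_obs(&cov, 2, &idx, 1);
        let d_edge = scalar_depth_for_obs(&cov, 0, &idx, 1);
        assert!(
            d_center >= d_edge,
            "center depth {} should be >= edge depth {}",
            d_center,
            d_edge
        );
    }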

    #[test]
    fn test_scalar_depth_for_obs_multivariate() {
        // 2 covariates
        let cov =
            FdMatrix::from_column_major(vec![1.0, 2.0, 3.0, 4.0, 10.0, 20.0, 30.0, 40.0], 4, 2)
                .unwrap();
        let depth = scalar_depth_for_obs(&cov, 1, &[0, 1, 2, 3], 2);
        assert!(depth > 0.0 && depth <= 0.5, "multivar depth={}", depth);
    }

    #[test]
    fn test_blend_scalar_depths_modifies_scores() {
        let n = 6;
        let g = 2;
        let mut depth_scores = FdMatrix::zeros(n, g);
        // Fill with some values
        for i in 0..n {
            depth_scores[(i, 0)] = 0.5;
            depth_scores[(i, 1)] = 0.3;
        }

        let cov = FdMatrix::from_column_major(vec![1.0, 2.0, 3.0, 10.0, 20.0, 30.0], n, 1).unwrap();
        let class_indices = vec![vec![0, 1, 2], vec![3, 4, 5]];

        let original_00 = depth_scores[(0, 0)];
        blend_scalar_depths(&mut depth_scores, &cov, &class_indices, n);

        // Scores should have been modified (blended with 0.7 / 0.3 weights)
        let blended_00 = depth_scores[(0, 0)];
        // blended = 0.7 * 0.5 + 0.3 * scalar_depth
        assert!(
            (blended_00 - original_00).abs() > 1e-10,
            "blend should change scores: original={}, blended={}",
            original_00,
            blended_00
        );
    }

    #[test]
    fn test_compute_pairwise_scalar() {
        let n = 4;
        // 2 covariates
        let cov = FdMatrix::from_column_major(vec![0.0, 1.0, 2.0, 3.0, 0.0, 0.0, 0.0, 0.0], n, 2)
            .unwrap();
        let dists = compute_pairwise_scalar(&cov);
        assert_eq!(dists.len(), n * n);

        // Diagonal should be zero
        for i in 0..n {
            assert!((dists[i * n + i]).abs() < 1e-15);
        }
        // Symmetry
        for i in 0..n {
            for j in 0..n {
                assert!((dists[i * n + j] - dists[j * n + i]).abs() < 1e-15);
            }
        }
        // d(0,1) = sqrt(1^2 + 0^2) = 1.0
        assert!((dists[1] - 1.0).abs() < 1e-10);
        // d(0,3) = sqrt(3^2 + 0^2) = 3.0
        assert!((dists[3] - 3.0).abs() < 1e-10);
    }
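
    /// Added sketch: Euclidean pairwise distances scale linearly when every
    /// covariate is multiplied by a constant (plain Euclidean distance is
    /// implied by the d(0,1) and d(0,3) checks above).
    #[test]
    fn test_compute_pairwise_scalar_scaling() {
        let cov = FdMatrix::from_column_major(vec![0.0, 1.0, 2.0], 3, 1).unwrap();
        let doubled = FdMatrix::from_column_major(vec![0.0, 2.0, 4.0], 3, 1).unwrap();
        let d1 = compute_pairwise_scalar(&cov);
        let d2 = compute_pairwise_scalar(&doubled);
        for (a, b) in d1.iter().zip(d2.iter()) {
            assert!((2.0 * a - b).abs() < 1e-12, "distances should double");
        }
    }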

    #[test]
    fn test_fclassif_cv_lda_with_covariates() {
        let (data, labels, t, covariates) = generate_two_class_with_covariates(25, 50, 1);
        let result = fclassif_cv(&data, &t, &labels, Some(&covariates), "lda", 3, 5, 42).unwrap();

        assert_eq!(result.fold_errors.len(), 5);
        assert!(result.error_rate >= 0.0 && result.error_rate <= 1.0);
    }

    #[test]
    fn test_fclassif_cv_qda_with_covariates() {
        let (data, labels, t, covariates) = generate_two_class_with_covariates(25, 50, 1);
        let result = fclassif_cv(&data, &t, &labels, Some(&covariates), "qda", 3, 5, 42).unwrap();

        assert_eq!(result.fold_errors.len(), 5);
        assert!(result.error_rate >= 0.0 && result.error_rate <= 1.0);
    }

    #[test]
    fn test_fclassif_cv_knn_with_covariates() {
        let (data, labels, t, covariates) = generate_two_class_with_covariates(25, 50, 1);
        let result = fclassif_cv(&data, &t, &labels, Some(&covariates), "knn", 3, 5, 42).unwrap();

        assert_eq!(result.fold_errors.len(), 5);
        assert!(result.error_rate >= 0.0 && result.error_rate <= 1.0);
    }

    #[test]
    fn test_fclassif_cv_kernel_with_covariates() {
        let (data, labels, t, covariates) = generate_two_class_with_covariates(25, 50, 1);
        let result =
            fclassif_cv(&data, &t, &labels, Some(&covariates), "kernel", 3, 5, 42).unwrap();

        assert_eq!(result.fold_errors.len(), 5);
        assert!(result.error_rate >= 0.0 && result.error_rate <= 1.0);
    }

    #[test]
    fn test_fclassif_cv_dd_with_covariates() {
        let (data, labels, t, covariates) = generate_two_class_with_covariates(25, 50, 2);
        let result = fclassif_cv(&data, &t, &labels, Some(&covariates), "dd", 3, 5, 42).unwrap();

        assert_eq!(result.fold_errors.len(), 5);
        assert!(result.error_rate >= 0.0 && result.error_rate <= 1.0);
    }

    #[test]
    fn test_fclassif_kernel_invalid_inputs() {
        let data = FdMatrix::zeros(0, 0);
        assert!(fclassif_kernel(&data, &[], &[], None, 0.0, 0.0).is_none());

        let data = FdMatrix::zeros(5, 10);
        let t = uniform_grid(10);
        let labels = vec![0; 5]; // single class
        assert!(fclassif_kernel(&data, &t, &labels, None, 0.0, 0.0).is_none());

        // Mismatched argvals length
        let labels2 = vec![0, 0, 0, 1, 1];
        let wrong_t = vec![0.0, 1.0]; // wrong length
        assert!(fclassif_kernel(&data, &wrong_t, &labels2, None, 0.0, 0.0).is_none());
    }

    #[test]
    fn test_fclassif_dd_invalid_inputs() {
        let data = FdMatrix::zeros(0, 0);
        assert!(fclassif_dd(&data, &[], None).is_none());

        let data = FdMatrix::zeros(5, 10);
        let labels = vec![0; 5]; // single class
        assert!(fclassif_dd(&data, &labels, None).is_none());
    }

    #[test]
    fn test_argmax_class() {
        // Empty input falls back to class 0; otherwise the max index wins.
        assert_eq!(argmax_class(&[]), 0);
        assert_eq!(argmax_class(&[0.1]), 0);
        assert_eq!(argmax_class(&[0.1, 0.9, 0.5]), 1);
    }

    #[test]
    fn test_gaussian_kernel_values() {
        // h=0 → 0
        assert_eq!(gaussian_kernel(1.0, 0.0), 0.0);
        // dist=0 → 1
        assert!((gaussian_kernel(0.0, 1.0) - 1.0).abs() < 1e-15);
        // Normal case
        let k = gaussian_kernel(1.0, 1.0);
        let expected = (-0.5_f64).exp();
        assert!((k - expected).abs() < 1e-10);
    }
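
    /// Added sketch: for a fixed bandwidth the kernel weight should decay
    /// monotonically with distance and stay strictly positive (the form
    /// exp(-0.5 (d/h)^2) is implied by the values checked above).
    #[test]
    fn test_gaussian_kernel_monotone_in_distance() {
        let h = 1.0;
        let mut prev = gaussian_kernel(0.0, h);
        for i in 1..=5 {
            let k = gaussian_kernel(i as f64, h);
            assert!(k < prev, "kernel should decay with distance");
            assert!(k > 0.0, "kernel weight should stay positive");
            prev = k;
        }
    }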

    #[test]
    fn test_fclassif_qda_with_covariates() {
        let (data, labels, _t, covariates) = generate_two_class_with_covariates(20, 50, 1);
        let result = fclassif_qda(&data, &labels, Some(&covariates), 3).unwrap();

        assert_eq!(result.predicted.len(), 40);
        assert!(
            result.accuracy > 0.5,
            "QDA+cov accuracy: {}",
            result.accuracy
        );
    }

    #[test]
    fn test_fclassif_knn_with_covariates() {
        let (data, labels, _t, covariates) = generate_two_class_with_covariates(20, 50, 1);
        let result = fclassif_knn(&data, &labels, Some(&covariates), 3, 5).unwrap();

        assert_eq!(result.predicted.len(), 40);
        assert!(
            result.accuracy > 0.5,
            "k-NN+cov accuracy: {}",
            result.accuracy
        );
    }

    #[test]
    fn test_fclassif_knn_invalid_k() {
        let (data, labels, _t) = generate_two_class_data(10, 50);
        // k_nn == 0 → None
        assert!(fclassif_knn(&data, &labels, None, 3, 0).is_none());
    }

    #[test]
    fn test_bandwidth_candidates_empty_distances() {
        // All distances zero → candidates filtered out
        let dists = vec![0.0; 9];
        let cands = bandwidth_candidates(&dists, 3);
        assert!(cands.is_empty());
    }

    #[test]
    fn test_select_bandwidth_loo_empty_candidates() {
        // All distances zero → empty candidates → default bandwidth
        let dists = vec![0.0; 9];
        let labels = vec![0, 0, 1];
        let h = select_bandwidth_loo(&dists, &labels, 2, 3, true);
        assert!((h - 1.0).abs() < 1e-10, "default func bandwidth: {}", h);

        let h2 = select_bandwidth_loo(&dists, &labels, 2, 3, false);
        assert!((h2 - 0.5).abs() < 1e-10, "default scalar bandwidth: {}", h2);
    }

    #[test]
    fn test_fold_split() {
        let folds = vec![0, 1, 2, 0, 1, 2];
        let (train, test) = fold_split(&folds, 1);
        assert_eq!(train, vec![0, 2, 3, 5]);
        assert_eq!(test, vec![1, 4]);
    }
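
    /// Added sketch: for any fold id, fold_split should partition 0..n into
    /// disjoint train/test index sets (consistent with the split checked above).
    #[test]
    fn test_fold_split_partitions_indices() {
        let folds = vec![0, 1, 2, 0, 1, 2];
        for k in 0..3 {
            let (train, test) = fold_split(&folds, k);
            assert_eq!(train.len() + test.len(), folds.len());
            let mut all: Vec<usize> = train.iter().chain(test.iter()).copied().collect();
            all.sort_unstable();
            let expected: Vec<usize> = (0..folds.len()).collect();
            assert_eq!(all, expected);
        }
    }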

    #[test]
    fn test_assign_folds_deterministic() {
        let f1 = assign_folds(10, 3, 42);
        let f2 = assign_folds(10, 3, 42);
        assert_eq!(f1, f2);

        // All fold indices in [0, nfold)
        for &f in &f1 {
            assert!(f < 3);
        }
    }
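
    /// Added sketch: assign_folds should label every one of the n observations
    /// with a fold id in [0, nfold), so the returned vector has length n.
    #[test]
    fn test_assign_folds_covers_all_observations() {
        let f = assign_folds(25, 5, 7);
        assert_eq!(f.len(), 25);
        assert!(f.iter().all(|&fi| fi < 5));
    }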

    #[test]
    fn test_project_test_onto_fpca() {
        let n_train = 20;
        let m = 50;
        let ncomp = 3;
        let (data, _labels, _t) = generate_two_class_data(n_train / 2, m);

        let fpca = fdata_to_pc_1d(&data, ncomp).unwrap();

        // Create a small "test" matrix
        let n_test = 5;
        let mut test_col = vec![0.0; n_test * m];
        for i in 0..n_test {
            for j in 0..m {
                test_col[i + j * n_test] = data[(i, j)] + 0.01;
            }
        }
        let test_data = FdMatrix::from_column_major(test_col, n_test, m).unwrap();

        let projected = project_test_onto_fpca(&test_data, &fpca);
        assert_eq!(projected.nrows(), n_test);
        assert_eq!(projected.ncols(), ncomp);
    }

    #[test]
    fn test_fclassif_three_classes() {
        let n_per = 15;
        let n = 3 * n_per;
        let m = 50;
        let t = uniform_grid(m);

        let mut col_major = vec![0.0; n * m];
        // Class 0: sin
        for i in 0..n_per {
            for (j, &tj) in t.iter().enumerate() {
                col_major[i + j * n] =
                    (2.0 * PI * tj).sin() + 0.02 * ((i * 7 + j * 3) % 100) as f64 / 100.0;
            }
        }
        // Class 1: cos
        for i in 0..n_per {
            for (j, &tj) in t.iter().enumerate() {
                col_major[(i + n_per) + j * n] =
                    (2.0 * PI * tj).cos() + 0.02 * ((i * 7 + j * 3) % 100) as f64 / 100.0;
            }
        }
        // Class 2: constant
        for i in 0..n_per {
            for j in 0..m {
                col_major[(i + 2 * n_per) + j * n] =
                    3.0 + 0.02 * ((i * 7 + j * 3) % 100) as f64 / 100.0;
            }
        }

        let data = FdMatrix::from_column_major(col_major, n, m).unwrap();
        let labels: Vec<usize> = (0..n)
            .map(|i| {
                if i < n_per {
                    0
                } else if i < 2 * n_per {
                    1
                } else {
                    2
                }
            })
            .collect();

        let result = fclassif_lda(&data, &labels, None, 3).unwrap();
        assert_eq!(result.n_classes, 3);
        assert!(
            result.accuracy > 0.8,
            "Three-class accuracy: {}",
            result.accuracy
        );
    }
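
    /// Added sketch: softmax (defined above) should return a valid probability
    /// vector and stay finite for large log-scores thanks to max-subtraction.
    #[test]
    fn test_softmax_probabilities() {
        let p = softmax(&[1000.0, 1001.0, 999.0]);
        let sum: f64 = p.iter().sum();
        assert!((sum - 1.0).abs() < 1e-12, "probabilities sum to 1: {}", sum);
        assert!(p.iter().all(|&x| x.is_finite() && x >= 0.0));
        // The largest log-score gets the largest probability
        assert!(p[1] > p[0] && p[0] > p[2]);
    }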
}