Skip to main content

fdars_core/classification/
kernel.rs

1//! Nonparametric kernel classifier with mixed predictors.
2
3use crate::error::FdarError;
4use crate::helpers::{l2_distance, simpsons_weights};
5use crate::iter_maybe_parallel;
6use crate::matrix::FdMatrix;
7
8use super::{compute_accuracy, confusion_matrix, remap_labels, ClassifResult};
9
10/// Find class with maximum score.
11pub(super) fn argmax_class(scores: &[f64]) -> usize {
12    scores
13        .iter()
14        .enumerate()
15        .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
16        .map_or(0, |(c, _)| c)
17}
18
19/// Compute marginal rank-based scalar depth of observation i w.r.t. class c.
20pub(super) fn scalar_depth_for_obs(
21    cov: &FdMatrix,
22    i: usize,
23    class_indices: &[usize],
24    p: usize,
25) -> f64 {
26    let nc = class_indices.len() as f64;
27    if nc < 1.0 || p == 0 {
28        return 0.0;
29    }
30    let mut depth = 0.0;
31    for j in 0..p {
32        let val = cov[(i, j)];
33        let rank = class_indices
34            .iter()
35            .filter(|&&k| cov[(k, j)] <= val)
36            .count() as f64;
37        let u = rank / nc.max(1.0);
38        depth += u.min(1.0 - u).min(0.5);
39    }
40    depth / p as f64
41}
42
43/// Generate bandwidth candidates from distance percentiles.
44pub(super) fn bandwidth_candidates(dists: &[f64], n: usize) -> Vec<f64> {
45    let mut all_dists: Vec<f64> = Vec::new();
46    for i in 0..n {
47        for j in (i + 1)..n {
48            all_dists.push(dists[i * n + j]);
49        }
50    }
51    crate::helpers::sort_nan_safe(&mut all_dists);
52
53    (1..=20)
54        .map(|p| {
55            let idx = (f64::from(p) / 20.0 * (all_dists.len() - 1) as f64) as usize;
56            all_dists[idx.min(all_dists.len() - 1)]
57        })
58        .filter(|&h| h > 1e-15)
59        .collect()
60}
61
62/// LOO classification accuracy for a single bandwidth.
63fn loo_accuracy_for_bandwidth(dists: &[f64], labels: &[usize], g: usize, n: usize, h: f64) -> f64 {
64    #[cfg(feature = "parallel")]
65    use rayon::iter::ParallelIterator;
66
67    let correct = iter_maybe_parallel!(0..n)
68        .filter(|&i| {
69            let mut votes = vec![0.0; g];
70            for j in 0..n {
71                if j != i {
72                    votes[labels[j]] += gaussian_kernel(dists[i * n + j], h);
73                }
74            }
75            argmax_class(&votes) == labels[i]
76        })
77        .count();
78    correct as f64 / n as f64
79}
80
81/// Gaussian kernel: exp(-d²/(2h²)).
82pub(super) fn gaussian_kernel(dist: f64, h: f64) -> f64 {
83    if h < 1e-15 {
84        return 0.0;
85    }
86    (-dist * dist / (2.0 * h * h)).exp()
87}
88
89/// Nonparametric kernel classifier for functional data with optional scalar covariates.
90///
91/// Uses product kernel: K_func × K_scalar. Bandwidth selected by LOO-CV.
92///
93/// # Arguments
94/// * `data` — Functional data (n × m)
95/// * `y` — Class labels
96/// * `argvals` — Evaluation points
97/// * `scalar_covariates` — Optional scalar covariates (n × p)
98/// * `h_func` — Functional bandwidth (0 = auto via LOO-CV)
99/// * `h_scalar` — Scalar bandwidth (0 = auto)
100///
101/// # Errors
102///
103/// Returns [`FdarError::InvalidDimension`] if `data` has zero rows, `y.len() != n`,
104/// or `argvals.len() != m`.
105/// Returns [`FdarError::InvalidParameter`] if `y` contains fewer than 2 distinct classes.
106#[must_use = "expensive computation whose result should not be discarded"]
107pub fn fclassif_kernel(
108    data: &FdMatrix,
109    y: &[usize],
110    argvals: &[f64],
111    scalar_covariates: Option<&FdMatrix>,
112    h_func: f64,
113    h_scalar: f64,
114) -> Result<ClassifResult, FdarError> {
115    let n = data.nrows();
116    let m = data.ncols();
117    if n == 0 || y.len() != n || argvals.len() != m {
118        return Err(FdarError::InvalidDimension {
119            parameter: "data/y/argvals",
120            expected: "n > 0, y.len() == n, argvals.len() == m".to_string(),
121            actual: format!(
122                "n={}, y.len()={}, m={}, argvals.len()={}",
123                n,
124                y.len(),
125                m,
126                argvals.len()
127            ),
128        });
129    }
130
131    let (labels, g) = remap_labels(y);
132    if g < 2 {
133        return Err(FdarError::InvalidParameter {
134            parameter: "y",
135            message: format!("need at least 2 classes, got {g}"),
136        });
137    }
138
139    let weights = simpsons_weights(argvals);
140
141    // Compute pairwise functional distances
142    let func_dists = compute_pairwise_l2(data, &weights);
143
144    // Compute pairwise scalar distances if covariates exist
145    let scalar_dists = scalar_covariates.map(compute_pairwise_scalar);
146
147    // Select bandwidths via LOO if needed
148    let h_f = if h_func > 0.0 {
149        h_func
150    } else {
151        select_bandwidth_loo(&func_dists, &labels, g, n, true)
152    };
153    let h_s = match &scalar_dists {
154        Some(sd) if h_scalar <= 0.0 => select_bandwidth_loo(sd, &labels, g, n, false),
155        _ => h_scalar,
156    };
157
158    let predicted = kernel_classify_loo(
159        &func_dists,
160        scalar_dists.as_deref(),
161        &labels,
162        g,
163        n,
164        h_f,
165        h_s,
166    );
167    let accuracy = compute_accuracy(&labels, &predicted);
168    let confusion = confusion_matrix(&labels, &predicted, g);
169
170    Ok(ClassifResult {
171        predicted,
172        probabilities: None,
173        accuracy,
174        confusion,
175        n_classes: g,
176        ncomp: 0,
177    })
178}
179
180/// Kernel classifier from a precomputed functional distance matrix.
181///
182/// Works with **any** distance matrix (elastic, DTW, Lp, or custom).
183/// Bandwidth is selected via LOO-CV if `h_func <= 0`.
184///
185/// # Arguments
186/// * `func_dists` — Flat n × n functional distance matrix (row-major)
187/// * `y` — Class labels (length n, 0-indexed)
188/// * `scalar_covariates` — Optional scalar covariates (n × p), uses Euclidean distance internally
189/// * `h_func` — Functional bandwidth (0 = auto via LOO-CV)
190/// * `h_scalar` — Scalar bandwidth (0 = auto)
191///
192/// # Errors
193/// Returns errors if `y.len() != n` or fewer than 2 classes.
194#[must_use = "expensive computation whose result should not be discarded"]
195pub fn kernel_classify_from_distances(
196    func_dists: &[f64],
197    y: &[usize],
198    scalar_covariates: Option<&FdMatrix>,
199    h_func: f64,
200    h_scalar: f64,
201) -> Result<ClassifResult, FdarError> {
202    let n = y.len();
203    if n == 0 {
204        return Err(FdarError::InvalidDimension {
205            parameter: "y",
206            expected: "n > 0".to_string(),
207            actual: "0".to_string(),
208        });
209    }
210    if func_dists.len() != n * n {
211        return Err(FdarError::InvalidDimension {
212            parameter: "func_dists",
213            expected: format!("{} elements (n*n)", n * n),
214            actual: format!("{} elements", func_dists.len()),
215        });
216    }
217
218    let (labels, g) = remap_labels(y);
219    if g < 2 {
220        return Err(FdarError::InvalidParameter {
221            parameter: "y",
222            message: format!("need at least 2 classes, got {g}"),
223        });
224    }
225
226    let scalar_dists = scalar_covariates.map(compute_pairwise_scalar);
227
228    let h_f = if h_func > 0.0 {
229        h_func
230    } else {
231        select_bandwidth_loo(func_dists, &labels, g, n, true)
232    };
233    let h_s = match &scalar_dists {
234        Some(sd) if h_scalar <= 0.0 => select_bandwidth_loo(sd, &labels, g, n, false),
235        _ => h_scalar,
236    };
237
238    let predicted =
239        kernel_classify_loo(func_dists, scalar_dists.as_deref(), &labels, g, n, h_f, h_s);
240    let accuracy = compute_accuracy(&labels, &predicted);
241    let confusion = confusion_matrix(&labels, &predicted, g);
242
243    Ok(ClassifResult {
244        predicted,
245        probabilities: None,
246        accuracy,
247        confusion,
248        n_classes: g,
249        ncomp: 0,
250    })
251}
252
253/// Compute pairwise L2 distances between curves.
254fn compute_pairwise_l2(data: &FdMatrix, weights: &[f64]) -> Vec<f64> {
255    #[cfg(feature = "parallel")]
256    use rayon::iter::ParallelIterator;
257
258    let n = data.nrows();
259    // Build upper-triangle pair list, compute distances in parallel, then scatter.
260    let pairs: Vec<(usize, usize)> = (0..n)
261        .flat_map(|i| ((i + 1)..n).map(move |j| (i, j)))
262        .collect();
263    let pair_dists: Vec<(usize, usize, f64)> = iter_maybe_parallel!(pairs)
264        .map(|(i, j)| {
265            let ri = data.row(i);
266            let rj = data.row(j);
267            (i, j, l2_distance(&ri, &rj, weights))
268        })
269        .collect();
270    let mut dists = vec![0.0; n * n];
271    for (i, j, d) in pair_dists {
272        dists[i * n + j] = d;
273        dists[j * n + i] = d;
274    }
275    dists
276}
277
278/// Compute pairwise Euclidean distances between scalar covariate vectors.
279pub(super) fn compute_pairwise_scalar(scalar_covariates: &FdMatrix) -> Vec<f64> {
280    let n = scalar_covariates.nrows();
281    let p = scalar_covariates.ncols();
282    let mut dists = vec![0.0; n * n];
283    for i in 0..n {
284        for j in (i + 1)..n {
285            let mut d_sq = 0.0;
286            for k in 0..p {
287                d_sq += (scalar_covariates[(i, k)] - scalar_covariates[(j, k)]).powi(2);
288            }
289            let d = d_sq.sqrt();
290            dists[i * n + j] = d;
291            dists[j * n + i] = d;
292        }
293    }
294    dists
295}
296
297/// Select bandwidth by LOO classification accuracy.
298pub(super) fn select_bandwidth_loo(
299    dists: &[f64],
300    labels: &[usize],
301    g: usize,
302    n: usize,
303    is_func: bool,
304) -> f64 {
305    let candidates = bandwidth_candidates(dists, n);
306    if candidates.is_empty() {
307        return if is_func { 1.0 } else { 0.5 };
308    }
309
310    let mut best_h = candidates[0];
311    let mut best_acc = 0.0;
312    for &h in &candidates {
313        let acc = loo_accuracy_for_bandwidth(dists, labels, g, n, h);
314        if acc > best_acc {
315            best_acc = acc;
316            best_h = h;
317        }
318    }
319    best_h
320}
321
322/// LOO kernel classification with product kernel.
323fn kernel_classify_loo(
324    func_dists: &[f64],
325    scalar_dists: Option<&[f64]>,
326    labels: &[usize],
327    g: usize,
328    n: usize,
329    h_func: f64,
330    h_scalar: f64,
331) -> Vec<usize> {
332    (0..n)
333        .map(|i| {
334            let mut votes = vec![0.0; g];
335            for j in 0..n {
336                if j == i {
337                    continue;
338                }
339                let kf = gaussian_kernel(func_dists[i * n + j], h_func);
340                let ks = match scalar_dists {
341                    Some(sd) if h_scalar > 1e-15 => gaussian_kernel(sd[i * n + j], h_scalar),
342                    _ => 1.0,
343                };
344                votes[labels[j]] += kf * ks;
345            }
346            argmax_class(&votes)
347        })
348        .collect()
349}
350
#[cfg(test)]
mod tests {
    use super::*;

    /// Row-major 6 × 6 distance matrix for two tight clusters, {0, 1, 2} and {3, 4, 5}:
    /// 0.1 within a cluster, 5.0 across clusters, 0.0 on the diagonal.
    fn two_cluster_distances() -> Vec<f64> {
        let n = 6;
        let cluster = |idx: usize| idx / 3;
        (0..n * n)
            .map(|flat| {
                let (i, j) = (flat / n, flat % n);
                if i == j {
                    0.0
                } else if cluster(i) == cluster(j) {
                    0.1
                } else {
                    5.0
                }
            })
            .collect()
    }

    #[test]
    fn kernel_from_distances_smoke() {
        let dists = two_cluster_distances();
        let y = vec![0, 0, 0, 1, 1, 1];
        let result = kernel_classify_from_distances(&dists, &y, None, 0.5, 0.0).unwrap();
        assert_eq!(result.predicted, vec![0, 0, 0, 1, 1, 1]);
    }
}
383        let y = vec![0, 0, 0, 1, 1, 1];
384        let result = kernel_classify_from_distances(&dists, &y, None, 0.5, 0.0).unwrap();
385        assert_eq!(result.predicted, vec![0, 0, 0, 1, 1, 1]);
386    }
387}