fdars_core/classification/
kernel.rs1use crate::error::FdarError;
4use crate::helpers::{l2_distance, simpsons_weights};
5use crate::iter_maybe_parallel;
6use crate::matrix::FdMatrix;
7
8use super::{compute_accuracy, confusion_matrix, remap_labels, ClassifResult};
9
10pub(super) fn argmax_class(scores: &[f64]) -> usize {
12 scores
13 .iter()
14 .enumerate()
15 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
16 .map_or(0, |(c, _)| c)
17}
18
19pub(super) fn scalar_depth_for_obs(
21 cov: &FdMatrix,
22 i: usize,
23 class_indices: &[usize],
24 p: usize,
25) -> f64 {
26 let nc = class_indices.len() as f64;
27 if nc < 1.0 || p == 0 {
28 return 0.0;
29 }
30 let mut depth = 0.0;
31 for j in 0..p {
32 let val = cov[(i, j)];
33 let rank = class_indices
34 .iter()
35 .filter(|&&k| cov[(k, j)] <= val)
36 .count() as f64;
37 let u = rank / nc.max(1.0);
38 depth += u.min(1.0 - u).min(0.5);
39 }
40 depth / p as f64
41}
42
43pub(super) fn bandwidth_candidates(dists: &[f64], n: usize) -> Vec<f64> {
45 let mut all_dists: Vec<f64> = Vec::new();
46 for i in 0..n {
47 for j in (i + 1)..n {
48 all_dists.push(dists[i * n + j]);
49 }
50 }
51 crate::helpers::sort_nan_safe(&mut all_dists);
52
53 (1..=20)
54 .map(|p| {
55 let idx = (f64::from(p) / 20.0 * (all_dists.len() - 1) as f64) as usize;
56 all_dists[idx.min(all_dists.len() - 1)]
57 })
58 .filter(|&h| h > 1e-15)
59 .collect()
60}
61
62fn loo_accuracy_for_bandwidth(dists: &[f64], labels: &[usize], g: usize, n: usize, h: f64) -> f64 {
64 #[cfg(feature = "parallel")]
65 use rayon::iter::ParallelIterator;
66
67 let correct = iter_maybe_parallel!(0..n)
68 .filter(|&i| {
69 let mut votes = vec![0.0; g];
70 for j in 0..n {
71 if j != i {
72 votes[labels[j]] += gaussian_kernel(dists[i * n + j], h);
73 }
74 }
75 argmax_class(&votes) == labels[i]
76 })
77 .count();
78 correct as f64 / n as f64
79}
80
81pub(super) fn gaussian_kernel(dist: f64, h: f64) -> f64 {
83 if h < 1e-15 {
84 return 0.0;
85 }
86 (-dist * dist / (2.0 * h * h)).exp()
87}
88
89#[must_use = "expensive computation whose result should not be discarded"]
107pub fn fclassif_kernel(
108 data: &FdMatrix,
109 y: &[usize],
110 argvals: &[f64],
111 scalar_covariates: Option<&FdMatrix>,
112 h_func: f64,
113 h_scalar: f64,
114) -> Result<ClassifResult, FdarError> {
115 let n = data.nrows();
116 let m = data.ncols();
117 if n == 0 || y.len() != n || argvals.len() != m {
118 return Err(FdarError::InvalidDimension {
119 parameter: "data/y/argvals",
120 expected: "n > 0, y.len() == n, argvals.len() == m".to_string(),
121 actual: format!(
122 "n={}, y.len()={}, m={}, argvals.len()={}",
123 n,
124 y.len(),
125 m,
126 argvals.len()
127 ),
128 });
129 }
130
131 let (labels, g) = remap_labels(y);
132 if g < 2 {
133 return Err(FdarError::InvalidParameter {
134 parameter: "y",
135 message: format!("need at least 2 classes, got {g}"),
136 });
137 }
138
139 let weights = simpsons_weights(argvals);
140
141 let func_dists = compute_pairwise_l2(data, &weights);
143
144 let scalar_dists = scalar_covariates.map(compute_pairwise_scalar);
146
147 let h_f = if h_func > 0.0 {
149 h_func
150 } else {
151 select_bandwidth_loo(&func_dists, &labels, g, n, true)
152 };
153 let h_s = match &scalar_dists {
154 Some(sd) if h_scalar <= 0.0 => select_bandwidth_loo(sd, &labels, g, n, false),
155 _ => h_scalar,
156 };
157
158 let predicted = kernel_classify_loo(
159 &func_dists,
160 scalar_dists.as_deref(),
161 &labels,
162 g,
163 n,
164 h_f,
165 h_s,
166 );
167 let accuracy = compute_accuracy(&labels, &predicted);
168 let confusion = confusion_matrix(&labels, &predicted, g);
169
170 Ok(ClassifResult {
171 predicted,
172 probabilities: None,
173 accuracy,
174 confusion,
175 n_classes: g,
176 ncomp: 0,
177 })
178}
179
180fn compute_pairwise_l2(data: &FdMatrix, weights: &[f64]) -> Vec<f64> {
182 #[cfg(feature = "parallel")]
183 use rayon::iter::ParallelIterator;
184
185 let n = data.nrows();
186 let pairs: Vec<(usize, usize)> = (0..n)
188 .flat_map(|i| ((i + 1)..n).map(move |j| (i, j)))
189 .collect();
190 let pair_dists: Vec<(usize, usize, f64)> = iter_maybe_parallel!(pairs)
191 .map(|(i, j)| {
192 let ri = data.row(i);
193 let rj = data.row(j);
194 (i, j, l2_distance(&ri, &rj, weights))
195 })
196 .collect();
197 let mut dists = vec![0.0; n * n];
198 for (i, j, d) in pair_dists {
199 dists[i * n + j] = d;
200 dists[j * n + i] = d;
201 }
202 dists
203}
204
205pub(super) fn compute_pairwise_scalar(scalar_covariates: &FdMatrix) -> Vec<f64> {
207 let n = scalar_covariates.nrows();
208 let p = scalar_covariates.ncols();
209 let mut dists = vec![0.0; n * n];
210 for i in 0..n {
211 for j in (i + 1)..n {
212 let mut d_sq = 0.0;
213 for k in 0..p {
214 d_sq += (scalar_covariates[(i, k)] - scalar_covariates[(j, k)]).powi(2);
215 }
216 let d = d_sq.sqrt();
217 dists[i * n + j] = d;
218 dists[j * n + i] = d;
219 }
220 }
221 dists
222}
223
224pub(super) fn select_bandwidth_loo(
226 dists: &[f64],
227 labels: &[usize],
228 g: usize,
229 n: usize,
230 is_func: bool,
231) -> f64 {
232 let candidates = bandwidth_candidates(dists, n);
233 if candidates.is_empty() {
234 return if is_func { 1.0 } else { 0.5 };
235 }
236
237 let mut best_h = candidates[0];
238 let mut best_acc = 0.0;
239 for &h in &candidates {
240 let acc = loo_accuracy_for_bandwidth(dists, labels, g, n, h);
241 if acc > best_acc {
242 best_acc = acc;
243 best_h = h;
244 }
245 }
246 best_h
247}
248
249fn kernel_classify_loo(
251 func_dists: &[f64],
252 scalar_dists: Option<&[f64]>,
253 labels: &[usize],
254 g: usize,
255 n: usize,
256 h_func: f64,
257 h_scalar: f64,
258) -> Vec<usize> {
259 (0..n)
260 .map(|i| {
261 let mut votes = vec![0.0; g];
262 for j in 0..n {
263 if j == i {
264 continue;
265 }
266 let kf = gaussian_kernel(func_dists[i * n + j], h_func);
267 let ks = match scalar_dists {
268 Some(sd) if h_scalar > 1e-15 => gaussian_kernel(sd[i * n + j], h_scalar),
269 _ => 1.0,
270 };
271 votes[labels[j]] += kf * ks;
272 }
273 argmax_class(&votes)
274 })
275 .collect()
276}