fdars_core/classification/
kernel.rs1use crate::error::FdarError;
4use crate::helpers::{l2_distance, simpsons_weights};
5use crate::iter_maybe_parallel;
6use crate::matrix::FdMatrix;
7
8use super::{compute_accuracy, confusion_matrix, remap_labels, ClassifResult};
9
10pub(super) fn argmax_class(scores: &[f64]) -> usize {
12 scores
13 .iter()
14 .enumerate()
15 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
16 .map_or(0, |(c, _)| c)
17}
18
19pub(super) fn scalar_depth_for_obs(
21 cov: &FdMatrix,
22 i: usize,
23 class_indices: &[usize],
24 p: usize,
25) -> f64 {
26 let nc = class_indices.len() as f64;
27 if nc < 1.0 || p == 0 {
28 return 0.0;
29 }
30 let mut depth = 0.0;
31 for j in 0..p {
32 let val = cov[(i, j)];
33 let rank = class_indices
34 .iter()
35 .filter(|&&k| cov[(k, j)] <= val)
36 .count() as f64;
37 let u = rank / nc.max(1.0);
38 depth += u.min(1.0 - u).min(0.5);
39 }
40 depth / p as f64
41}
42
43pub(super) fn bandwidth_candidates(dists: &[f64], n: usize) -> Vec<f64> {
45 let mut all_dists: Vec<f64> = Vec::new();
46 for i in 0..n {
47 for j in (i + 1)..n {
48 all_dists.push(dists[i * n + j]);
49 }
50 }
51 crate::helpers::sort_nan_safe(&mut all_dists);
52
53 (1..=20)
54 .map(|p| {
55 let idx = (f64::from(p) / 20.0 * (all_dists.len() - 1) as f64) as usize;
56 all_dists[idx.min(all_dists.len() - 1)]
57 })
58 .filter(|&h| h > 1e-15)
59 .collect()
60}
61
62fn loo_accuracy_for_bandwidth(dists: &[f64], labels: &[usize], g: usize, n: usize, h: f64) -> f64 {
64 #[cfg(feature = "parallel")]
65 use rayon::iter::ParallelIterator;
66
67 let correct = iter_maybe_parallel!(0..n)
68 .filter(|&i| {
69 let mut votes = vec![0.0; g];
70 for j in 0..n {
71 if j != i {
72 votes[labels[j]] += gaussian_kernel(dists[i * n + j], h);
73 }
74 }
75 argmax_class(&votes) == labels[i]
76 })
77 .count();
78 correct as f64 / n as f64
79}
80
81pub(super) fn gaussian_kernel(dist: f64, h: f64) -> f64 {
83 if h < 1e-15 {
84 return 0.0;
85 }
86 (-dist * dist / (2.0 * h * h)).exp()
87}
88
89#[must_use = "expensive computation whose result should not be discarded"]
107pub fn fclassif_kernel(
108 data: &FdMatrix,
109 y: &[usize],
110 argvals: &[f64],
111 scalar_covariates: Option<&FdMatrix>,
112 h_func: f64,
113 h_scalar: f64,
114) -> Result<ClassifResult, FdarError> {
115 let n = data.nrows();
116 let m = data.ncols();
117 if n == 0 || y.len() != n || argvals.len() != m {
118 return Err(FdarError::InvalidDimension {
119 parameter: "data/y/argvals",
120 expected: "n > 0, y.len() == n, argvals.len() == m".to_string(),
121 actual: format!(
122 "n={}, y.len()={}, m={}, argvals.len()={}",
123 n,
124 y.len(),
125 m,
126 argvals.len()
127 ),
128 });
129 }
130
131 let (labels, g) = remap_labels(y);
132 if g < 2 {
133 return Err(FdarError::InvalidParameter {
134 parameter: "y",
135 message: format!("need at least 2 classes, got {g}"),
136 });
137 }
138
139 let weights = simpsons_weights(argvals);
140
141 let func_dists = compute_pairwise_l2(data, &weights);
143
144 let scalar_dists = scalar_covariates.map(compute_pairwise_scalar);
146
147 let h_f = if h_func > 0.0 {
149 h_func
150 } else {
151 select_bandwidth_loo(&func_dists, &labels, g, n, true)
152 };
153 let h_s = match &scalar_dists {
154 Some(sd) if h_scalar <= 0.0 => select_bandwidth_loo(sd, &labels, g, n, false),
155 _ => h_scalar,
156 };
157
158 let predicted = kernel_classify_loo(
159 &func_dists,
160 scalar_dists.as_deref(),
161 &labels,
162 g,
163 n,
164 h_f,
165 h_s,
166 );
167 let accuracy = compute_accuracy(&labels, &predicted);
168 let confusion = confusion_matrix(&labels, &predicted, g);
169
170 Ok(ClassifResult {
171 predicted,
172 probabilities: None,
173 accuracy,
174 confusion,
175 n_classes: g,
176 ncomp: 0,
177 })
178}
179
180#[must_use = "expensive computation whose result should not be discarded"]
195pub fn kernel_classify_from_distances(
196 func_dists: &[f64],
197 y: &[usize],
198 scalar_covariates: Option<&FdMatrix>,
199 h_func: f64,
200 h_scalar: f64,
201) -> Result<ClassifResult, FdarError> {
202 let n = y.len();
203 if n == 0 {
204 return Err(FdarError::InvalidDimension {
205 parameter: "y",
206 expected: "n > 0".to_string(),
207 actual: "0".to_string(),
208 });
209 }
210 if func_dists.len() != n * n {
211 return Err(FdarError::InvalidDimension {
212 parameter: "func_dists",
213 expected: format!("{} elements (n*n)", n * n),
214 actual: format!("{} elements", func_dists.len()),
215 });
216 }
217
218 let (labels, g) = remap_labels(y);
219 if g < 2 {
220 return Err(FdarError::InvalidParameter {
221 parameter: "y",
222 message: format!("need at least 2 classes, got {g}"),
223 });
224 }
225
226 let scalar_dists = scalar_covariates.map(compute_pairwise_scalar);
227
228 let h_f = if h_func > 0.0 {
229 h_func
230 } else {
231 select_bandwidth_loo(func_dists, &labels, g, n, true)
232 };
233 let h_s = match &scalar_dists {
234 Some(sd) if h_scalar <= 0.0 => select_bandwidth_loo(sd, &labels, g, n, false),
235 _ => h_scalar,
236 };
237
238 let predicted =
239 kernel_classify_loo(func_dists, scalar_dists.as_deref(), &labels, g, n, h_f, h_s);
240 let accuracy = compute_accuracy(&labels, &predicted);
241 let confusion = confusion_matrix(&labels, &predicted, g);
242
243 Ok(ClassifResult {
244 predicted,
245 probabilities: None,
246 accuracy,
247 confusion,
248 n_classes: g,
249 ncomp: 0,
250 })
251}
252
253fn compute_pairwise_l2(data: &FdMatrix, weights: &[f64]) -> Vec<f64> {
255 #[cfg(feature = "parallel")]
256 use rayon::iter::ParallelIterator;
257
258 let n = data.nrows();
259 let pairs: Vec<(usize, usize)> = (0..n)
261 .flat_map(|i| ((i + 1)..n).map(move |j| (i, j)))
262 .collect();
263 let pair_dists: Vec<(usize, usize, f64)> = iter_maybe_parallel!(pairs)
264 .map(|(i, j)| {
265 let ri = data.row(i);
266 let rj = data.row(j);
267 (i, j, l2_distance(&ri, &rj, weights))
268 })
269 .collect();
270 let mut dists = vec![0.0; n * n];
271 for (i, j, d) in pair_dists {
272 dists[i * n + j] = d;
273 dists[j * n + i] = d;
274 }
275 dists
276}
277
278pub(super) fn compute_pairwise_scalar(scalar_covariates: &FdMatrix) -> Vec<f64> {
280 let n = scalar_covariates.nrows();
281 let p = scalar_covariates.ncols();
282 let mut dists = vec![0.0; n * n];
283 for i in 0..n {
284 for j in (i + 1)..n {
285 let mut d_sq = 0.0;
286 for k in 0..p {
287 d_sq += (scalar_covariates[(i, k)] - scalar_covariates[(j, k)]).powi(2);
288 }
289 let d = d_sq.sqrt();
290 dists[i * n + j] = d;
291 dists[j * n + i] = d;
292 }
293 }
294 dists
295}
296
297pub(super) fn select_bandwidth_loo(
299 dists: &[f64],
300 labels: &[usize],
301 g: usize,
302 n: usize,
303 is_func: bool,
304) -> f64 {
305 let candidates = bandwidth_candidates(dists, n);
306 if candidates.is_empty() {
307 return if is_func { 1.0 } else { 0.5 };
308 }
309
310 let mut best_h = candidates[0];
311 let mut best_acc = 0.0;
312 for &h in &candidates {
313 let acc = loo_accuracy_for_bandwidth(dists, labels, g, n, h);
314 if acc > best_acc {
315 best_acc = acc;
316 best_h = h;
317 }
318 }
319 best_h
320}
321
322fn kernel_classify_loo(
324 func_dists: &[f64],
325 scalar_dists: Option<&[f64]>,
326 labels: &[usize],
327 g: usize,
328 n: usize,
329 h_func: f64,
330 h_scalar: f64,
331) -> Vec<usize> {
332 (0..n)
333 .map(|i| {
334 let mut votes = vec![0.0; g];
335 for j in 0..n {
336 if j == i {
337 continue;
338 }
339 let kf = gaussian_kernel(func_dists[i * n + j], h_func);
340 let ks = match scalar_dists {
341 Some(sd) if h_scalar > 1e-15 => gaussian_kernel(sd[i * n + j], h_scalar),
342 _ => 1.0,
343 };
344 votes[labels[j]] += kf * ks;
345 }
346 argmax_class(&votes)
347 })
348 .collect()
349}
350
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn kernel_from_distances_smoke() {
        // Two tight, well-separated clusters: {0, 1, 2} and {3, 4, 5}.
        // Within-cluster distance 0.1, cross-cluster distance 5.0, zero diagonal.
        let n: usize = 6;
        let cluster_of = |idx: usize| idx / 3;
        let dists: Vec<f64> = (0..n * n)
            .map(|flat| {
                let (i, j) = (flat / n, flat % n);
                if i == j {
                    0.0
                } else if cluster_of(i) == cluster_of(j) {
                    0.1
                } else {
                    5.0
                }
            })
            .collect();

        let y = vec![0, 0, 0, 1, 1, 1];
        let result = kernel_classify_from_distances(&dists, &y, None, 0.5, 0.0).unwrap();
        assert_eq!(result.predicted, y);
    }
}
387}