1use crate::helpers::simpsons_weights;
4use crate::iter_maybe_parallel;
5#[cfg(feature = "parallel")]
6use rayon::iter::ParallelIterator;
7use std::f64::consts::PI;
8
9pub fn integrate_simpson(values: &[f64], argvals: &[f64]) -> f64 {
15 if values.len() != argvals.len() || values.is_empty() {
16 return 0.0;
17 }
18
19 let weights = simpsons_weights(argvals);
20 values
21 .iter()
22 .zip(weights.iter())
23 .map(|(&v, &w)| v * w)
24 .sum()
25}
26
27pub fn inner_product(curve1: &[f64], curve2: &[f64], argvals: &[f64]) -> f64 {
34 if curve1.len() != curve2.len() || curve1.len() != argvals.len() || curve1.is_empty() {
35 return 0.0;
36 }
37
38 let weights = simpsons_weights(argvals);
39 curve1
40 .iter()
41 .zip(curve2.iter())
42 .zip(weights.iter())
43 .map(|((&c1, &c2), &w)| c1 * c2 * w)
44 .sum()
45}
46
47pub fn inner_product_matrix(data: &[f64], n: usize, m: usize, argvals: &[f64]) -> Vec<f64> {
58 if n == 0 || m == 0 || argvals.len() != m || data.len() != n * m {
59 return Vec::new();
60 }
61
62 let weights = simpsons_weights(argvals);
63
64 let upper_triangle: Vec<(usize, usize, f64)> = iter_maybe_parallel!(0..n)
66 .flat_map(|i| {
67 (i..n)
68 .map(|j| {
69 let mut ip = 0.0;
70 for k in 0..m {
71 ip += data[i + k * n] * data[j + k * n] * weights[k];
72 }
73 (i, j, ip)
74 })
75 .collect::<Vec<_>>()
76 })
77 .collect();
78
79 let mut result = vec![0.0; n * n];
81 for (i, j, ip) in upper_triangle {
82 result[i + j * n] = ip;
83 result[j + i * n] = ip;
84 }
85
86 result
87}
88
/// Maps a 1-based index pair `(a, b)` of a symmetric matrix to its offset in
/// packed lower-triangular storage. Row `r` holds columns `1..=r`, and rows
/// are stored one after another starting at offset 0.
///
/// The mapping is symmetric: `packed_sym_index(a, b) == packed_sym_index(b, a)`.
/// Both indices must be at least 1; passing 0 underflows.
fn packed_sym_index(a: usize, b: usize) -> usize {
    let row = a.max(b);
    let col = a.min(b);
    // Rows 1..row-1 occupy 1 + 2 + ... + (row - 1) = row * (row - 1) / 2 slots.
    let row_start = row * (row - 1) / 2;
    row_start + col - 1
}
95
96fn adot_pair_sum(inprod: &[f64], n: usize, i: usize, j: usize) -> f64 {
98 let ij = packed_sym_index(i, j);
99 let ii = packed_sym_index(i, i);
100 let jj = packed_sym_index(j, j);
101 let mut sumr = 0.0;
102
103 for r in 1..=n {
104 if i == r || j == r {
105 sumr += PI;
106 } else {
107 let rr = packed_sym_index(r, r);
108 let ir = packed_sym_index(i, r);
109 let rj = packed_sym_index(r, j);
110
111 let num = inprod[ij] - inprod[ir] - inprod[rj] + inprod[rr];
112 let aux1 = (inprod[ii] - 2.0 * inprod[ir] + inprod[rr]).sqrt();
113 let aux2 = (inprod[jj] - 2.0 * inprod[rj] + inprod[rr]).sqrt();
114 let den = aux1 * aux2;
115
116 let mut quo = if den.abs() > 1e-10 { num / den } else { 0.0 };
117 quo = quo.clamp(-1.0, 1.0);
118
119 sumr += (PI - quo.acos()).abs();
120 }
121 }
122
123 sumr
124}
125
126pub fn compute_adot(n: usize, inprod: &[f64]) -> Vec<f64> {
127 if n == 0 {
128 return Vec::new();
129 }
130
131 let expected_len = (n * n + n) / 2;
132 if inprod.len() != expected_len {
133 return Vec::new();
134 }
135
136 let out_len = (n * n - n + 2) / 2;
137 let mut adot_vec = vec![0.0; out_len];
138
139 adot_vec[0] = PI * (n + 1) as f64;
140
141 let pairs: Vec<(usize, usize)> = (2..=n).flat_map(|i| (1..i).map(move |j| (i, j))).collect();
143
144 let results: Vec<(usize, f64)> = iter_maybe_parallel!(pairs)
146 .map(|(i, j)| {
147 let sumr = adot_pair_sum(inprod, n, i, j);
148 let idx = 1 + ((i - 1) * (i - 2) / 2) + j - 1;
149 (idx, sumr)
150 })
151 .collect();
152
153 for (idx, val) in results {
155 if idx < adot_vec.len() {
156 adot_vec[idx] = val;
157 }
158 }
159
160 adot_vec
161}
162
/// Evaluates the quadratic form `r' A r` for the residual vector `r`.
///
/// `A` is the symmetric matrix described by the packed `adot_vec`:
/// - every diagonal entry equals `adot_vec[0]`;
/// - the entry for the 1-based pair `i > j` is stored at slot
///   `1 + (i-1)(i-2)/2 + (j-1)`.
///
/// Pairs whose slot lies beyond the end of `adot_vec` are skipped silently.
/// Returns `0.0` when either input is empty.
pub fn pcvm_statistic(adot_vec: &[f64], residuals: &[f64]) -> f64 {
    if residuals.is_empty() || adot_vec.is_empty() {
        return 0.0;
    }

    // Off-diagonal slots are laid out row by row (row 2, then row 3, ...), so a
    // running counter visits them in exactly storage order.
    let mut cross = 0.0;
    let mut slot = 1;
    for (row, &r_row) in residuals.iter().enumerate().skip(1) {
        for &r_col in &residuals[..row] {
            if let Some(&a) = adot_vec.get(slot) {
                cross += r_row * a * r_col;
            }
            slot += 1;
        }
    }

    let diag: f64 = residuals.iter().map(|r| r * r).sum();
    // Each off-diagonal pair appears twice in the symmetric form.
    adot_vec[0] * diag + 2.0 * cross
}
184
/// Per-projection statistics produced by [`rp_stat`], one entry per projection
/// in projection order.
#[derive(Debug, Clone, PartialEq)]
pub struct RpStatResult {
    /// Cramér–von Mises-type statistic for each projection.
    pub cvm: Vec<f64>,
    /// Kolmogorov–Smirnov-type statistic for each projection.
    pub ks: Vec<f64>,
}
192
193pub fn rp_stat(proj_x_ord: &[i32], residuals: &[f64], n_proj: usize) -> RpStatResult {
195 let n = residuals.len();
196
197 if n == 0 || n_proj == 0 || proj_x_ord.len() != n * n_proj {
198 return RpStatResult {
199 cvm: Vec::new(),
200 ks: Vec::new(),
201 };
202 }
203
204 let stats: Vec<(f64, f64)> = iter_maybe_parallel!(0..n_proj)
206 .map(|p| {
207 let mut y = vec![0.0; n];
208 let mut cumsum = 0.0;
209
210 for i in 0..n {
211 let idx = proj_x_ord[p * n + i] as usize;
212 if idx > 0 && idx <= n {
213 cumsum += residuals[idx - 1];
214 }
215 y[i] = cumsum;
216 }
217
218 let sum_y_sq: f64 = y.iter().map(|yi| yi * yi).sum();
219 let cvm = sum_y_sq / (n * n) as f64;
220
221 let max_abs_y = y.iter().map(|yi| yi.abs()).fold(0.0, f64::max);
222 let ks = max_abs_y / (n as f64).sqrt();
223
224 (cvm, ks)
225 })
226 .collect();
227
228 let cvm_stats: Vec<f64> = stats.iter().map(|(cvm, _)| *cvm).collect();
229 let ks_stats: Vec<f64> = stats.iter().map(|(_, ks)| *ks).collect();
230
231 RpStatResult {
232 cvm: cvm_stats,
233 ks: ks_stats,
234 }
235}
236
237pub fn knn_predict(
239 distance_matrix: &[f64],
240 y: &[f64],
241 n_train: usize,
242 n_test: usize,
243 k: usize,
244) -> Vec<f64> {
245 if n_train == 0 || n_test == 0 || k == 0 || y.len() != n_train {
246 return vec![0.0; n_test];
247 }
248
249 let k = k.min(n_train);
250
251 iter_maybe_parallel!(0..n_test)
252 .map(|i| {
253 let mut distances: Vec<(usize, f64)> = (0..n_train)
255 .map(|j| (j, distance_matrix[i + j * n_test]))
256 .collect();
257
258 distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
260
261 let sum: f64 = distances.iter().take(k).map(|(j, _)| y[*j]).sum();
263 sum / k as f64
264 })
265 .collect()
266}
267
268pub fn knn_loocv(distance_matrix: &[f64], y: &[f64], n: usize, k: usize) -> f64 {
270 if n == 0 || k == 0 || y.len() != n || distance_matrix.len() != n * n {
271 return f64::INFINITY;
272 }
273
274 let k = k.min(n - 1);
275
276 let errors: Vec<f64> = iter_maybe_parallel!(0..n)
277 .map(|i| {
278 let mut distances: Vec<(usize, f64)> = (0..n)
280 .filter(|&j| j != i)
281 .map(|j| (j, distance_matrix[i + j * n]))
282 .collect();
283
284 distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
286
287 let pred: f64 = distances.iter().take(k).map(|(j, _)| y[*j]).sum::<f64>() / k as f64;
289
290 (y[i] - pred).powi(2)
292 })
293 .collect();
294
295 errors.iter().sum::<f64>() / n as f64
296}
297
#[cfg(test)]
mod tests {
    use super::*;

    /// `n` equally spaced points on `[0, 1]`, both endpoints included (needs `n >= 2`).
    fn uniform_grid(n: usize) -> Vec<f64> {
        (0..n).map(|i| i as f64 / (n - 1) as f64).collect()
    }

    /// Packed (1-based, lower-triangular) inner-product vector for `n`
    /// orthonormal curves: 1.0 in every diagonal slot, 0.0 everywhere else.
    fn orthonormal_inprod(n: usize) -> Vec<f64> {
        let mut inprod = vec![0.0; (n * (n + 1)) / 2];
        for i in 1..=n {
            // Diagonal entry (i, i) lives at i*(i-1)/2 + i - 1 in packed storage.
            let idx = i * (i - 1) / 2 + i - 1;
            inprod[idx] = 1.0;
        }
        inprod
    }

    #[test]
    fn test_integrate_simpson_constant() {
        let argvals = uniform_grid(11);
        let values = vec![1.0; 11];
        let result = integrate_simpson(&values, &argvals);
        assert!((result - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_inner_product_orthogonal() {
        let argvals = uniform_grid(101);
        let curve1: Vec<f64> = argvals.iter().map(|&t| (2.0 * PI * t).sin()).collect();
        let curve2: Vec<f64> = argvals.iter().map(|&t| (2.0 * PI * t).cos()).collect();
        let result = inner_product(&curve1, &curve2, &argvals);
        assert!(result.abs() < 0.01);
    }

    #[test]
    fn test_inner_product_matrix_symmetry() {
        let n = 5;
        let m = 10;
        let argvals = uniform_grid(m);
        let data: Vec<f64> = (0..n * m).map(|i| (i as f64).sin()).collect();

        let matrix = inner_product_matrix(&data, n, m, &argvals);
        assert_eq!(matrix.len(), n * n);

        for i in 0..n {
            for j in 0..n {
                let diff = (matrix[i + j * n] - matrix[j + i * n]).abs();
                assert!(diff < 1e-10, "Matrix should be symmetric");
            }
        }
    }

    #[test]
    fn test_knn_predict() {
        let n_train = 10;
        let n_test = 3;
        let k = 3;

        let mut distance_matrix = vec![0.0; n_test * n_train];
        for i in 0..n_test {
            for j in 0..n_train {
                distance_matrix[i + j * n_test] = ((i as f64) - (j as f64)).abs();
            }
        }

        let y: Vec<f64> = (0..n_train).map(|i| i as f64).collect();
        let predictions = knn_predict(&distance_matrix, &y, n_train, n_test, k);

        assert_eq!(predictions.len(), n_test);

        // Equal distances keep the lower training index first, so the neighbour
        // sets are {0,1,2}, {1,0,2} and {2,1,3}.
        let expected = [1.0, 1.0, 2.0];
        for (got, want) in predictions.iter().zip(expected.iter()) {
            assert!((got - want).abs() < 1e-12, "expected {}, got {}", want, got);
        }
    }

    #[test]
    fn test_packed_sym_index() {
        assert_eq!(packed_sym_index(1, 1), 0);
        assert_eq!(packed_sym_index(2, 1), 1);
        assert_eq!(packed_sym_index(1, 2), 1);
        assert_eq!(packed_sym_index(2, 2), 2);
        assert_eq!(packed_sym_index(3, 3), 5);
        assert_eq!(packed_sym_index(4, 2), packed_sym_index(2, 4));
    }

    #[test]
    fn test_compute_adot_basic() {
        let n = 4;
        let inprod = orthonormal_inprod(n);

        let adot = compute_adot(n, &inprod);

        let expected_len = (n * n - n + 2) / 2;
        assert_eq!(
            adot.len(),
            expected_len,
            "Adot length should be (n^2-n+2)/2"
        );
        assert!(
            (adot[0] - PI * (n + 1) as f64).abs() < 1e-10,
            "First element should be π*(n+1), got {}",
            adot[0]
        );
        for (i, &val) in adot.iter().enumerate() {
            assert!(val.is_finite(), "Adot[{}] should be finite, got {}", i, val);
        }
    }

    #[test]
    fn test_compute_adot_orthonormal_values() {
        let n = 4;
        let adot = compute_adot(n, &orthonormal_inprod(n));

        // For orthonormal curves, each r outside {i, j} sees cosine 1/2 (angle
        // π/3) and contributes π - π/3. The two r in {i, j} contribute π each.
        let off_diag = 2.0 * PI + (n - 2) as f64 * (2.0 * PI / 3.0);
        assert!((adot[0] - 5.0 * PI).abs() < 1e-10);
        for (i, &val) in adot.iter().enumerate().skip(1) {
            assert!(
                (val - off_diag).abs() < 1e-10,
                "Adot[{}] should be {}, got {}",
                i,
                off_diag,
                val
            );
        }
    }

    #[test]
    fn test_compute_adot_n1() {
        let n = 1;
        let inprod = vec![1.0];
        let adot = compute_adot(n, &inprod);

        assert_eq!(adot.len(), 1, "n=1 should give length 1");
        assert!(
            (adot[0] - PI * 2.0).abs() < 1e-10,
            "n=1: first element should be π*2, got {}",
            adot[0]
        );
    }

    #[test]
    fn test_compute_adot_invalid() {
        assert!(compute_adot(0, &[]).is_empty());

        assert!(compute_adot(4, &[1.0, 2.0]).is_empty());
    }

    #[test]
    fn test_pcvm_statistic_basic() {
        let n = 4;
        let adot = compute_adot(n, &orthonormal_inprod(n));
        let residuals = vec![0.5, -0.3, 0.2, -0.1];

        let stat = pcvm_statistic(&adot, &residuals);

        assert!(stat.is_finite(), "PCvM statistic should be finite");
        assert!(stat >= 0.0, "PCvM statistic should be non-negative");
    }

    #[test]
    fn test_pcvm_statistic_known_value() {
        // n = 3: diagonal weight 10; off-diagonal slots (2,1)=1, (3,1)=2, (3,2)=3.
        // 10*(1+4+9) + 2*(2*1*1 + 3*2*1 + 3*3*2) = 140 + 52 = 192.
        let stat = pcvm_statistic(&[10.0, 1.0, 2.0, 3.0], &[1.0, 2.0, 3.0]);
        assert!((stat - 192.0).abs() < 1e-12, "got {}", stat);
    }

    #[test]
    fn test_pcvm_statistic_zero_residuals() {
        let n = 4;
        let adot = compute_adot(n, &orthonormal_inprod(n));
        let residuals = vec![0.0, 0.0, 0.0, 0.0];

        let stat = pcvm_statistic(&adot, &residuals);
        assert!(
            stat.abs() < 1e-10,
            "PCvM with zero residuals should be ~0, got {}",
            stat
        );
    }

    #[test]
    fn test_pcvm_statistic_empty() {
        assert!(pcvm_statistic(&[], &[]).abs() < 1e-10);
        assert!(pcvm_statistic(&[1.0], &[]).abs() < 1e-10);
    }

    #[test]
    fn test_rp_stat_basic() {
        let n_proj = 3;
        let residuals = vec![0.5, -0.3, 0.2, -0.1, 0.4];

        // One block of 1-based observation indices per projection.
        let proj_x_ord: Vec<i32> = vec![
            1, 3, 5, 2, 4, // projection 0
            2, 4, 1, 5, 3, // projection 1
            5, 1, 3, 4, 2, // projection 2
        ];

        let result = rp_stat(&proj_x_ord, &residuals, n_proj);

        assert_eq!(result.cvm.len(), n_proj);
        assert_eq!(result.ks.len(), n_proj);
        for &cvm_val in &result.cvm {
            assert!(cvm_val >= 0.0, "CvM stat should be non-negative");
            assert!(cvm_val.is_finite(), "CvM stat should be finite");
        }
        for &ks_val in &result.ks {
            assert!(ks_val >= 0.0, "KS stat should be non-negative");
            assert!(ks_val.is_finite(), "KS stat should be finite");
        }
    }

    #[test]
    fn test_rp_stat_known_value() {
        // Visiting observation 2 then observation 1 gives cumulative sums y = [3, 4].
        let result = rp_stat(&[2, 1], &[1.0, 3.0], 1);
        assert!((result.cvm[0] - 6.25).abs() < 1e-12, "got {}", result.cvm[0]);
        assert!(
            (result.ks[0] - 4.0 / 2.0_f64.sqrt()).abs() < 1e-12,
            "got {}",
            result.ks[0]
        );
    }

    #[test]
    fn test_rp_stat_invalid() {
        let result = rp_stat(&[], &[], 0);
        assert!(result.cvm.is_empty());
        assert!(result.ks.is_empty());

        let result = rp_stat(&[], &[1.0], 0);
        assert!(result.cvm.is_empty());
    }

    #[test]
    fn test_knn_loocv_basic() {
        let size = 5;
        let k = 2;
        let mut dist = vec![0.0; size * size];
        for i in 0..size {
            for j in 0..size {
                dist[i + j * size] = ((i as f64) - (j as f64)).abs();
            }
        }
        let y: Vec<f64> = (0..size).map(|i| i as f64 * 2.0).collect();

        let mse = knn_loocv(&dist, &y, size, k);

        assert!(mse.is_finite(), "k-NN LOOCV MSE should be finite");
        assert!(mse >= 0.0, "k-NN LOOCV MSE should be non-negative");
        // Interior points are predicted exactly. Each endpoint is off by 3
        // (squared error 9), so MSE = 18 / 5.
        assert!((mse - 3.6).abs() < 1e-12, "expected 3.6, got {}", mse);
    }

    #[test]
    fn test_knn_loocv_perfect() {
        let n = 4;
        let k = 1;
        let mut dist = vec![100.0; n * n];
        for i in 0..n {
            dist[i + i * n] = 0.0;
        }
        // Pair up observations (0, 1) and (2, 3) with a small mutual distance.
        dist[n] = 0.1; // (0, 1)
        dist[1] = 0.1; // (1, 0)
        dist[2 + 3 * n] = 0.1;
        dist[3 + 2 * n] = 0.1;

        let y = vec![1.0, 1.0, 5.0, 5.0];
        let mse = knn_loocv(&dist, &y, n, k);

        assert!(
            mse < 1e-10,
            "k-NN LOOCV MSE should be ~0 for perfectly paired data, got {}",
            mse
        );
    }

    #[test]
    fn test_knn_loocv_invalid() {
        assert!(knn_loocv(&[], &[], 0, 1).is_infinite());
        assert!(knn_loocv(&[0.0], &[1.0], 1, 0).is_infinite());
    }
}