1use crate::helpers::{simpsons_weights, simpsons_weights_2d, NUMERICAL_EPS};
4use rayon::prelude::*;
5
/// One-dimensional finite-difference derivative at grid index `idx`.
///
/// `values(k)` yields the function value at grid index `k`. The stencil is a
/// forward difference at the first point, a backward difference at the last
/// point (`n_points - 1`) and a central difference everywhere else. In every
/// case the difference of the two stencil values is divided by
/// `step_sizes[idx]`, so the caller must supply the matching spacing: one grid
/// step at the ends, the span between both neighbours in the interior.
///
/// Intended for `n_points >= 2`; with fewer points `values(1)` is still
/// requested.
fn finite_diff_1d(
    values: impl Fn(usize) -> f64,
    idx: usize,
    n_points: usize,
    step_sizes: &[f64],
) -> f64 {
    let (lower, upper) = match idx {
        0 => (0, 1),
        k if k == n_points - 1 => (k - 1, k),
        k => (k - 1, k + 1),
    };
    (values(upper) - values(lower)) / step_sizes[idx]
}
24
25fn compute_2d_derivatives(
29 get_val: impl Fn(usize, usize) -> f64,
30 si: usize,
31 ti: usize,
32 m1: usize,
33 m2: usize,
34 hs: &[f64],
35 ht: &[f64],
36) -> (f64, f64, f64) {
37 let ds = finite_diff_1d(|s| get_val(s, ti), si, m1, hs);
39
40 let dt = finite_diff_1d(|t| get_val(si, t), ti, m2, ht);
42
43 let denom = hs[si] * ht[ti];
45
46 let (s_lo, s_hi) = if si == 0 {
48 (0, 1)
49 } else if si == m1 - 1 {
50 (m1 - 2, m1 - 1)
51 } else {
52 (si - 1, si + 1)
53 };
54
55 let (t_lo, t_hi) = if ti == 0 {
56 (0, 1)
57 } else if ti == m2 - 1 {
58 (m2 - 2, m2 - 1)
59 } else {
60 (ti - 1, ti + 1)
61 };
62
63 let dsdt = (get_val(s_hi, t_hi) - get_val(s_lo, t_hi) - get_val(s_hi, t_lo)
64 + get_val(s_lo, t_lo))
65 / denom;
66
67 (ds, dt, dsdt)
68}
69
70fn weiszfeld_iteration(
74 data: &[f64],
75 n: usize,
76 m: usize,
77 weights: &[f64],
78 max_iter: usize,
79 tol: f64,
80) -> Vec<f64> {
81 let mut median: Vec<f64> = (0..m)
83 .map(|j| {
84 let mut sum = 0.0;
85 for i in 0..n {
86 sum += data[i + j * n];
87 }
88 sum / n as f64
89 })
90 .collect();
91
92 for _ in 0..max_iter {
93 let distances: Vec<f64> = (0..n)
95 .map(|i| {
96 let mut dist_sq = 0.0;
97 for j in 0..m {
98 let diff = data[i + j * n] - median[j];
99 dist_sq += diff * diff * weights[j];
100 }
101 dist_sq.sqrt()
102 })
103 .collect();
104
105 let inv_distances: Vec<f64> = distances
107 .iter()
108 .map(|d| {
109 if *d > NUMERICAL_EPS {
110 1.0 / d
111 } else {
112 1.0 / NUMERICAL_EPS
113 }
114 })
115 .collect();
116
117 let sum_inv_dist: f64 = inv_distances.iter().sum();
118
119 let new_median: Vec<f64> = (0..m)
121 .map(|j| {
122 let mut weighted_sum = 0.0;
123 for i in 0..n {
124 weighted_sum += data[i + j * n] * inv_distances[i];
125 }
126 weighted_sum / sum_inv_dist
127 })
128 .collect();
129
130 let diff: f64 = median
132 .iter()
133 .zip(new_median.iter())
134 .map(|(a, b)| (a - b).abs())
135 .sum::<f64>()
136 / m as f64;
137
138 median = new_median;
139
140 if diff < tol {
141 break;
142 }
143 }
144
145 median
146}
147
148pub fn mean_1d(data: &[f64], n: usize, m: usize) -> Vec<f64> {
158 if n == 0 || m == 0 || data.len() != n * m {
159 return Vec::new();
160 }
161
162 (0..m)
163 .into_par_iter()
164 .map(|j| {
165 let mut sum = 0.0;
166 for i in 0..n {
167 sum += data[i + j * n];
168 }
169 sum / n as f64
170 })
171 .collect()
172}
173
174pub fn mean_2d(data: &[f64], n: usize, m: usize) -> Vec<f64> {
178 mean_1d(data, n, m)
180}
181
182pub fn center_1d(data: &[f64], n: usize, m: usize) -> Vec<f64> {
192 if n == 0 || m == 0 || data.len() != n * m {
193 return Vec::new();
194 }
195
196 let means: Vec<f64> = (0..m)
198 .into_par_iter()
199 .map(|j| {
200 let mut sum = 0.0;
201 for i in 0..n {
202 sum += data[i + j * n];
203 }
204 sum / n as f64
205 })
206 .collect();
207
208 let mut centered = vec![0.0; n * m];
210 for j in 0..m {
211 for i in 0..n {
212 centered[i + j * n] = data[i + j * n] - means[j];
213 }
214 }
215
216 centered
217}
218
219pub fn norm_lp_1d(data: &[f64], n: usize, m: usize, argvals: &[f64], p: f64) -> Vec<f64> {
231 if n == 0 || m == 0 || argvals.len() != m || data.len() != n * m {
232 return Vec::new();
233 }
234
235 let weights = simpsons_weights(argvals);
236
237 (0..n)
238 .into_par_iter()
239 .map(|i| {
240 let mut integral = 0.0;
241 for j in 0..m {
242 let val = data[i + j * n].abs().powf(p);
243 integral += val * weights[j];
244 }
245 integral.powf(1.0 / p)
246 })
247 .collect()
248}
249
250pub fn deriv_1d(data: &[f64], n: usize, m: usize, argvals: &[f64], nderiv: usize) -> Vec<f64> {
262 if n == 0 || m == 0 || argvals.len() != m || nderiv < 1 || data.len() != n * m {
263 return vec![0.0; n * m];
264 }
265
266 let mut current = data.to_vec();
267
268 let h0 = argvals[1] - argvals[0];
270 let hn = argvals[m - 1] - argvals[m - 2];
271 let h_central: Vec<f64> = (1..(m - 1))
272 .map(|j| argvals[j + 1] - argvals[j - 1])
273 .collect();
274
275 for _ in 0..nderiv {
276 let deriv: Vec<f64> = (0..n)
278 .into_par_iter()
279 .flat_map(|i| {
280 let mut row_deriv = vec![0.0; m];
281
282 row_deriv[0] = (current[i + n] - current[i]) / h0;
284
285 for j in 1..(m - 1) {
287 row_deriv[j] =
288 (current[i + (j + 1) * n] - current[i + (j - 1) * n]) / h_central[j - 1];
289 }
290
291 row_deriv[m - 1] = (current[i + (m - 1) * n] - current[i + (m - 2) * n]) / hn;
293
294 row_deriv
295 })
296 .collect();
297
298 current = vec![0.0; n * m];
300 for i in 0..n {
301 for j in 0..m {
302 current[i + j * n] = deriv[i * m + j];
303 }
304 }
305 }
306
307 current
308}
309
/// Partial derivatives of `n` surfaces, as returned by [`deriv_2d`].
///
/// Each field is a column-major `n × (m1 * m2)` matrix. Entry
/// `[i + (si + ti * m1) * n]` holds the value for surface `i` at grid node
/// `(argvals_s[si], argvals_t[ti])`.
#[derive(Debug, Clone, PartialEq)]
pub struct Deriv2DResult {
    /// First partial derivative with respect to `s`.
    pub ds: Vec<f64>,
    /// First partial derivative with respect to `t`.
    pub dt: Vec<f64>,
    /// Mixed second partial derivative with respect to `s` and `t`.
    pub dsdt: Vec<f64>,
}
319
320pub fn deriv_2d(
335 data: &[f64],
336 n: usize,
337 argvals_s: &[f64],
338 argvals_t: &[f64],
339 m1: usize,
340 m2: usize,
341) -> Option<Deriv2DResult> {
342 let ncol = m1 * m2;
343 if n == 0 || ncol == 0 || argvals_s.len() != m1 || argvals_t.len() != m2 {
344 return None;
345 }
346
347 let hs: Vec<f64> = (0..m1)
349 .map(|j| {
350 if j == 0 {
351 argvals_s[1] - argvals_s[0]
352 } else if j == m1 - 1 {
353 argvals_s[m1 - 1] - argvals_s[m1 - 2]
354 } else {
355 argvals_s[j + 1] - argvals_s[j - 1]
356 }
357 })
358 .collect();
359
360 let ht: Vec<f64> = (0..m2)
362 .map(|j| {
363 if j == 0 {
364 argvals_t[1] - argvals_t[0]
365 } else if j == m2 - 1 {
366 argvals_t[m2 - 1] - argvals_t[m2 - 2]
367 } else {
368 argvals_t[j + 1] - argvals_t[j - 1]
369 }
370 })
371 .collect();
372
373 let results: Vec<(Vec<f64>, Vec<f64>, Vec<f64>)> = (0..n)
375 .into_par_iter()
376 .map(|i| {
377 let mut ds = vec![0.0; m1 * m2];
378 let mut dt = vec![0.0; m1 * m2];
379 let mut dsdt = vec![0.0; m1 * m2];
380
381 let get_val = |si: usize, ti: usize| -> f64 { data[i + (si + ti * m1) * n] };
383
384 for ti in 0..m2 {
385 for si in 0..m1 {
386 let idx = si + ti * m1;
387 let (ds_val, dt_val, dsdt_val) =
388 compute_2d_derivatives(get_val, si, ti, m1, m2, &hs, &ht);
389 ds[idx] = ds_val;
390 dt[idx] = dt_val;
391 dsdt[idx] = dsdt_val;
392 }
393 }
394
395 (ds, dt, dsdt)
396 })
397 .collect();
398
399 let mut ds_mat = vec![0.0; n * ncol];
401 let mut dt_mat = vec![0.0; n * ncol];
402 let mut dsdt_mat = vec![0.0; n * ncol];
403
404 for i in 0..n {
405 for j in 0..ncol {
406 ds_mat[i + j * n] = results[i].0[j];
407 dt_mat[i + j * n] = results[i].1[j];
408 dsdt_mat[i + j * n] = results[i].2[j];
409 }
410 }
411
412 Some(Deriv2DResult {
413 ds: ds_mat,
414 dt: dt_mat,
415 dsdt: dsdt_mat,
416 })
417}
418
419pub fn geometric_median_1d(
431 data: &[f64],
432 n: usize,
433 m: usize,
434 argvals: &[f64],
435 max_iter: usize,
436 tol: f64,
437) -> Vec<f64> {
438 if n == 0 || m == 0 || argvals.len() != m || data.len() != n * m {
439 return Vec::new();
440 }
441
442 let weights = simpsons_weights(argvals);
443 weiszfeld_iteration(data, n, m, &weights, max_iter, tol)
444}
445
446pub fn geometric_median_2d(
459 data: &[f64],
460 n: usize,
461 m: usize,
462 argvals_s: &[f64],
463 argvals_t: &[f64],
464 max_iter: usize,
465 tol: f64,
466) -> Vec<f64> {
467 let expected_cols = argvals_s.len() * argvals_t.len();
468 if n == 0 || m == 0 || m != expected_cols || data.len() != n * m {
469 return Vec::new();
470 }
471
472 let weights = simpsons_weights_2d(argvals_s, argvals_t);
473 weiszfeld_iteration(data, n, m, &weights, max_iter, tol)
474}
475
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::PI;

    /// `n` equally spaced points on `[0, 1]`; requires `n >= 2`.
    fn uniform_grid(n: usize) -> Vec<f64> {
        (0..n).map(|i| i as f64 / (n - 1) as f64).collect()
    }

    #[test]
    fn test_mean_1d() {
        // Column-major 2 x 3 matrix: the columns are [1, 3], [2, 4], [3, 5].
        let data = vec![1.0, 3.0, 2.0, 4.0, 3.0, 5.0];
        let mean = mean_1d(&data, 2, 3);
        assert_eq!(mean, vec![2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_mean_1d_single_sample() {
        let data = vec![1.0, 2.0, 3.0];
        let mean = mean_1d(&data, 1, 3);
        assert_eq!(mean, vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_mean_1d_invalid() {
        assert!(mean_1d(&[], 0, 0).is_empty());
        // Length mismatch: one value for a 1 x 2 matrix.
        assert!(mean_1d(&[1.0], 1, 2).is_empty());
    }

    #[test]
    fn test_mean_2d_delegates() {
        let data = vec![1.0, 3.0, 2.0, 4.0];
        let mean1d = mean_1d(&data, 2, 2);
        let mean2d = mean_2d(&data, 2, 2);
        assert_eq!(mean1d, mean2d);
    }

    #[test]
    fn test_center_1d() {
        let data = vec![1.0, 3.0, 2.0, 4.0, 3.0, 5.0];
        let centered = center_1d(&data, 2, 3);
        assert_eq!(centered, vec![-1.0, 1.0, -1.0, 1.0, -1.0, 1.0]);
    }

    #[test]
    fn test_center_1d_mean_zero() {
        let data = vec![1.0, 3.0, 2.0, 4.0, 3.0, 5.0];
        let centered = center_1d(&data, 2, 3);
        let centered_mean = mean_1d(&centered, 2, 3);
        for m in centered_mean {
            assert!(m.abs() < 1e-10, "Centered data should have zero mean");
        }
    }

    #[test]
    fn test_center_1d_invalid() {
        assert!(center_1d(&[], 0, 0).is_empty());
    }

    #[test]
    fn test_norm_lp_1d_constant() {
        let argvals = uniform_grid(21);
        let data = vec![2.0; 21];
        let norms = norm_lp_1d(&data, 1, 21, &argvals, 2.0);
        assert_eq!(norms.len(), 1);
        assert!(
            (norms[0] - 2.0).abs() < 0.1,
            "L2 norm of constant 2 should be 2"
        );
    }

    #[test]
    fn test_norm_lp_1d_sine() {
        let argvals = uniform_grid(101);
        let data: Vec<f64> = argvals.iter().map(|&x| (PI * x).sin()).collect();
        let norms = norm_lp_1d(&data, 1, 101, &argvals, 2.0);
        // ||sin(pi x)||_2 on [0, 1] is sqrt(1/2).
        let expected = 0.5_f64.sqrt();
        assert!(
            (norms[0] - expected).abs() < 0.05,
            "Expected {}, got {}",
            expected,
            norms[0]
        );
    }

    #[test]
    fn test_norm_lp_1d_invalid() {
        assert!(norm_lp_1d(&[], 0, 0, &[], 2.0).is_empty());
    }

    #[test]
    fn test_deriv_1d_linear() {
        let argvals = uniform_grid(21);
        let data = argvals.clone();
        let deriv = deriv_1d(&data, 1, 21, &argvals, 1);
        for d in &deriv[2..19] {
            assert!((d - 1.0).abs() < 0.1, "Derivative of x should be 1");
        }
    }

    #[test]
    fn test_deriv_1d_quadratic() {
        let argvals = uniform_grid(51);
        let data: Vec<f64> = argvals.iter().map(|&x| x * x).collect();
        let deriv = deriv_1d(&data, 1, 51, &argvals, 1);
        for j in 5..45 {
            let expected = 2.0 * argvals[j];
            assert!(
                (deriv[j] - expected).abs() < 0.1,
                "Derivative of x^2 should be 2x"
            );
        }
    }

    #[test]
    fn test_deriv_1d_invalid() {
        let result = deriv_1d(&[], 0, 0, &[], 1);
        assert!(result.is_empty() || result.iter().all(|&x| x == 0.0));
    }

    #[test]
    fn test_geometric_median_identical_curves() {
        let argvals = uniform_grid(21);
        let n = 5;
        let m = 21;
        let mut data = vec![0.0; n * m];
        for i in 0..n {
            for j in 0..m {
                data[i + j * n] = (2.0 * PI * argvals[j]).sin();
            }
        }
        let median = geometric_median_1d(&data, n, m, &argvals, 100, 1e-6);
        for j in 0..m {
            let expected = (2.0 * PI * argvals[j]).sin();
            assert!(
                (median[j] - expected).abs() < 0.01,
                "Median should equal all curves"
            );
        }
    }

    #[test]
    fn test_geometric_median_converges() {
        let argvals = uniform_grid(21);
        let n = 10;
        let m = 21;
        let mut data = vec![0.0; n * m];
        for i in 0..n {
            for j in 0..m {
                data[i + j * n] = (i as f64 / n as f64) * argvals[j];
            }
        }
        let median = geometric_median_1d(&data, n, m, &argvals, 100, 1e-6);
        assert_eq!(median.len(), m);
        assert!(median.iter().all(|&x| x.is_finite()));
    }

    #[test]
    fn test_geometric_median_invalid() {
        assert!(geometric_median_1d(&[], 0, 0, &[], 100, 1e-6).is_empty());
    }

    #[test]
    fn test_deriv_2d_linear_surface() {
        let m1 = 11;
        let m2 = 11;
        let argvals_s: Vec<f64> = (0..m1).map(|i| i as f64 / (m1 - 1) as f64).collect();
        let argvals_t: Vec<f64> = (0..m2).map(|i| i as f64 / (m2 - 1) as f64).collect();

        let n = 1;
        let ncol = m1 * m2;
        let mut data = vec![0.0; n * ncol];

        // f(s, t) = 2s + 3t
        for si in 0..m1 {
            for ti in 0..m2 {
                let s = argvals_s[si];
                let t = argvals_t[ti];
                let idx = si + ti * m1;
                data[idx] = 2.0 * s + 3.0 * t;
            }
        }

        let result = deriv_2d(&data, n, &argvals_s, &argvals_t, m1, m2).unwrap();

        for si in 2..(m1 - 2) {
            for ti in 2..(m2 - 2) {
                let idx = si + ti * m1;
                assert!(
                    (result.ds[idx] - 2.0).abs() < 0.2,
                    "∂f/∂s at ({}, {}) = {}, expected 2",
                    si,
                    ti,
                    result.ds[idx]
                );
            }
        }

        for si in 2..(m1 - 2) {
            for ti in 2..(m2 - 2) {
                let idx = si + ti * m1;
                assert!(
                    (result.dt[idx] - 3.0).abs() < 0.2,
                    "∂f/∂t at ({}, {}) = {}, expected 3",
                    si,
                    ti,
                    result.dt[idx]
                );
            }
        }

        for si in 2..(m1 - 2) {
            for ti in 2..(m2 - 2) {
                let idx = si + ti * m1;
                assert!(
                    result.dsdt[idx].abs() < 0.5,
                    "∂²f/∂s∂t at ({}, {}) = {}, expected 0",
                    si,
                    ti,
                    result.dsdt[idx]
                );
            }
        }
    }

    #[test]
    fn test_deriv_2d_quadratic_surface() {
        let m1 = 21;
        let m2 = 21;
        let argvals_s: Vec<f64> = (0..m1).map(|i| i as f64 / (m1 - 1) as f64).collect();
        let argvals_t: Vec<f64> = (0..m2).map(|i| i as f64 / (m2 - 1) as f64).collect();

        let n = 1;
        let ncol = m1 * m2;
        let mut data = vec![0.0; n * ncol];

        // f(s, t) = s * t
        for si in 0..m1 {
            for ti in 0..m2 {
                let s = argvals_s[si];
                let t = argvals_t[ti];
                let idx = si + ti * m1;
                data[idx] = s * t;
            }
        }

        let result = deriv_2d(&data, n, &argvals_s, &argvals_t, m1, m2).unwrap();

        for si in 3..(m1 - 3) {
            for ti in 3..(m2 - 3) {
                let idx = si + ti * m1;
                let expected = argvals_t[ti];
                assert!(
                    (result.ds[idx] - expected).abs() < 0.1,
                    "∂f/∂s at ({}, {}) = {}, expected {}",
                    si,
                    ti,
                    result.ds[idx],
                    expected
                );
            }
        }

        for si in 3..(m1 - 3) {
            for ti in 3..(m2 - 3) {
                let idx = si + ti * m1;
                let expected = argvals_s[si];
                assert!(
                    (result.dt[idx] - expected).abs() < 0.1,
                    "∂f/∂t at ({}, {}) = {}, expected {}",
                    si,
                    ti,
                    result.dt[idx],
                    expected
                );
            }
        }

        for si in 3..(m1 - 3) {
            for ti in 3..(m2 - 3) {
                let idx = si + ti * m1;
                assert!(
                    (result.dsdt[idx] - 1.0).abs() < 0.3,
                    "∂²f/∂s∂t at ({}, {}) = {}, expected 1",
                    si,
                    ti,
                    result.dsdt[idx]
                );
            }
        }
    }

    #[test]
    fn test_deriv_2d_invalid_input() {
        let result = deriv_2d(&[], 0, &[], &[], 0, 0);
        assert!(result.is_none());

        // argvals_t has 3 points but m2 = 2.
        let data = vec![1.0; 4];
        let argvals = vec![0.0, 1.0];
        let result = deriv_2d(&data, 1, &argvals, &[0.0, 0.5, 1.0], 2, 2);
        assert!(result.is_none());
    }

    #[test]
    fn test_geometric_median_2d_basic() {
        let m1 = 5;
        let m2 = 5;
        let m = m1 * m2;
        let n = 3;
        let argvals_s: Vec<f64> = (0..m1).map(|i| i as f64 / (m1 - 1) as f64).collect();
        let argvals_t: Vec<f64> = (0..m2).map(|i| i as f64 / (m2 - 1) as f64).collect();

        let mut data = vec![0.0; n * m];

        // Three identical surfaces f(s, t) = s + t.
        for i in 0..n {
            for si in 0..m1 {
                for ti in 0..m2 {
                    let idx = si + ti * m1;
                    let s = argvals_s[si];
                    let t = argvals_t[ti];
                    data[i + idx * n] = s + t;
                }
            }
        }

        let median = geometric_median_2d(&data, n, m, &argvals_s, &argvals_t, 100, 1e-6);
        assert_eq!(median.len(), m);

        for si in 0..m1 {
            for ti in 0..m2 {
                let idx = si + ti * m1;
                let expected = argvals_s[si] + argvals_t[ti];
                assert!(
                    (median[idx] - expected).abs() < 0.01,
                    "Median at ({}, {}) = {}, expected {}",
                    si,
                    ti,
                    median[idx],
                    expected
                );
            }
        }
    }
}