1use crate::helpers::{simpsons_weights, simpsons_weights_2d};
4use rayon::prelude::*;
5
6pub fn mean_1d(data: &[f64], nrow: usize, ncol: usize) -> Vec<f64> {
16 if nrow == 0 || ncol == 0 || data.len() != nrow * ncol {
17 return Vec::new();
18 }
19
20 (0..ncol)
21 .into_par_iter()
22 .map(|j| {
23 let mut sum = 0.0;
24 for i in 0..nrow {
25 sum += data[i + j * nrow];
26 }
27 sum / nrow as f64
28 })
29 .collect()
30}
31
32pub fn mean_2d(data: &[f64], nrow: usize, ncol: usize) -> Vec<f64> {
36 mean_1d(data, nrow, ncol)
38}
39
40pub fn center_1d(data: &[f64], nrow: usize, ncol: usize) -> Vec<f64> {
50 if nrow == 0 || ncol == 0 || data.len() != nrow * ncol {
51 return Vec::new();
52 }
53
54 let means: Vec<f64> = (0..ncol)
56 .into_par_iter()
57 .map(|j| {
58 let mut sum = 0.0;
59 for i in 0..nrow {
60 sum += data[i + j * nrow];
61 }
62 sum / nrow as f64
63 })
64 .collect();
65
66 let mut centered = vec![0.0; nrow * ncol];
68 for j in 0..ncol {
69 for i in 0..nrow {
70 centered[i + j * nrow] = data[i + j * nrow] - means[j];
71 }
72 }
73
74 centered
75}
76
77pub fn norm_lp_1d(data: &[f64], nrow: usize, ncol: usize, argvals: &[f64], p: f64) -> Vec<f64> {
89 if nrow == 0 || ncol == 0 || argvals.len() != ncol || data.len() != nrow * ncol {
90 return Vec::new();
91 }
92
93 let weights = simpsons_weights(argvals);
94
95 (0..nrow)
96 .into_par_iter()
97 .map(|i| {
98 let mut integral = 0.0;
99 for j in 0..ncol {
100 let val = data[i + j * nrow].abs().powf(p);
101 integral += val * weights[j];
102 }
103 integral.powf(1.0 / p)
104 })
105 .collect()
106}
107
108pub fn deriv_1d(
120 data: &[f64],
121 nrow: usize,
122 ncol: usize,
123 argvals: &[f64],
124 nderiv: usize,
125) -> Vec<f64> {
126 if nrow == 0 || ncol == 0 || argvals.len() != ncol || nderiv < 1 || data.len() != nrow * ncol {
127 return vec![0.0; nrow * ncol];
128 }
129
130 let mut current = data.to_vec();
131
132 let h0 = argvals[1] - argvals[0];
134 let hn = argvals[ncol - 1] - argvals[ncol - 2];
135 let h_central: Vec<f64> = (1..(ncol - 1))
136 .map(|j| argvals[j + 1] - argvals[j - 1])
137 .collect();
138
139 for _ in 0..nderiv {
140 let deriv: Vec<f64> = (0..nrow)
142 .into_par_iter()
143 .flat_map(|i| {
144 let mut row_deriv = vec![0.0; ncol];
145
146 row_deriv[0] = (current[i + nrow] - current[i]) / h0;
148
149 for j in 1..(ncol - 1) {
151 row_deriv[j] = (current[i + (j + 1) * nrow] - current[i + (j - 1) * nrow])
152 / h_central[j - 1];
153 }
154
155 row_deriv[ncol - 1] =
157 (current[i + (ncol - 1) * nrow] - current[i + (ncol - 2) * nrow]) / hn;
158
159 row_deriv
160 })
161 .collect();
162
163 current = vec![0.0; nrow * ncol];
165 for i in 0..nrow {
166 for j in 0..ncol {
167 current[i + j * nrow] = deriv[i * ncol + j];
168 }
169 }
170 }
171
172 current
173}
174
/// Partial derivatives produced by [`deriv_2d`].
///
/// Each field is a column-major `n x (m1 * m2)` matrix: the value for curve `i` at grid
/// position `(si, ti)` is stored at index `i + (si + ti * m1) * n`.
#[derive(Debug, Clone, PartialEq)]
pub struct Deriv2DResult {
    /// Finite-difference derivative along the `s` grid.
    pub ds: Vec<f64>,
    /// Finite-difference derivative along the `t` grid.
    pub dt: Vec<f64>,
    /// Mixed finite-difference derivative across both grids.
    pub dsdt: Vec<f64>,
}
184
185pub fn deriv_2d(
200 data: &[f64],
201 n: usize,
202 argvals_s: &[f64],
203 argvals_t: &[f64],
204 m1: usize,
205 m2: usize,
206) -> Option<Deriv2DResult> {
207 let ncol = m1 * m2;
208 if n == 0 || ncol == 0 || argvals_s.len() != m1 || argvals_t.len() != m2 {
209 return None;
210 }
211
212 let get_val = |i: usize, si: usize, ti: usize| -> f64 { data[i + (si + ti * m1) * n] };
214
215 let hs: Vec<f64> = (0..m1)
217 .map(|j| {
218 if j == 0 {
219 argvals_s[1] - argvals_s[0]
220 } else if j == m1 - 1 {
221 argvals_s[m1 - 1] - argvals_s[m1 - 2]
222 } else {
223 argvals_s[j + 1] - argvals_s[j - 1]
224 }
225 })
226 .collect();
227
228 let ht: Vec<f64> = (0..m2)
230 .map(|j| {
231 if j == 0 {
232 argvals_t[1] - argvals_t[0]
233 } else if j == m2 - 1 {
234 argvals_t[m2 - 1] - argvals_t[m2 - 2]
235 } else {
236 argvals_t[j + 1] - argvals_t[j - 1]
237 }
238 })
239 .collect();
240
241 let results: Vec<(Vec<f64>, Vec<f64>, Vec<f64>)> = (0..n)
243 .into_par_iter()
244 .map(|i| {
245 let mut ds = vec![0.0; m1 * m2];
246 let mut dt = vec![0.0; m1 * m2];
247 let mut dsdt = vec![0.0; m1 * m2];
248
249 for ti in 0..m2 {
250 for si in 0..m1 {
251 let idx = si + ti * m1;
252
253 if si == 0 {
255 ds[idx] = (get_val(i, 1, ti) - get_val(i, 0, ti)) / hs[si];
256 } else if si == m1 - 1 {
257 ds[idx] = (get_val(i, m1 - 1, ti) - get_val(i, m1 - 2, ti)) / hs[si];
258 } else {
259 ds[idx] = (get_val(i, si + 1, ti) - get_val(i, si - 1, ti)) / hs[si];
260 }
261
262 if ti == 0 {
264 dt[idx] = (get_val(i, si, 1) - get_val(i, si, 0)) / ht[ti];
265 } else if ti == m2 - 1 {
266 dt[idx] = (get_val(i, si, m2 - 1) - get_val(i, si, m2 - 2)) / ht[ti];
267 } else {
268 dt[idx] = (get_val(i, si, ti + 1) - get_val(i, si, ti - 1)) / ht[ti];
269 }
270
271 let denom = hs[si] * ht[ti];
273
274 if si == 0 && ti == 0 {
275 dsdt[idx] = (get_val(i, 1, 1) - get_val(i, 0, 1) - get_val(i, 1, 0)
276 + get_val(i, 0, 0))
277 / denom;
278 } else if si == m1 - 1 && ti == 0 {
279 dsdt[idx] =
280 (get_val(i, m1 - 1, 1) - get_val(i, m1 - 2, 1) - get_val(i, m1 - 1, 0)
281 + get_val(i, m1 - 2, 0))
282 / denom;
283 } else if si == 0 && ti == m2 - 1 {
284 dsdt[idx] =
285 (get_val(i, 1, m2 - 1) - get_val(i, 0, m2 - 1) - get_val(i, 1, m2 - 2)
286 + get_val(i, 0, m2 - 2))
287 / denom;
288 } else if si == m1 - 1 && ti == m2 - 1 {
289 dsdt[idx] = (get_val(i, m1 - 1, m2 - 1)
290 - get_val(i, m1 - 2, m2 - 1)
291 - get_val(i, m1 - 1, m2 - 2)
292 + get_val(i, m1 - 2, m2 - 2))
293 / denom;
294 } else if si == 0 {
295 dsdt[idx] =
296 (get_val(i, 1, ti + 1) - get_val(i, 0, ti + 1) - get_val(i, 1, ti - 1)
297 + get_val(i, 0, ti - 1))
298 / denom;
299 } else if si == m1 - 1 {
300 dsdt[idx] = (get_val(i, m1 - 1, ti + 1)
301 - get_val(i, m1 - 2, ti + 1)
302 - get_val(i, m1 - 1, ti - 1)
303 + get_val(i, m1 - 2, ti - 1))
304 / denom;
305 } else if ti == 0 {
306 dsdt[idx] =
307 (get_val(i, si + 1, 1) - get_val(i, si - 1, 1) - get_val(i, si + 1, 0)
308 + get_val(i, si - 1, 0))
309 / denom;
310 } else if ti == m2 - 1 {
311 dsdt[idx] = (get_val(i, si + 1, m2 - 1)
312 - get_val(i, si - 1, m2 - 1)
313 - get_val(i, si + 1, m2 - 2)
314 + get_val(i, si - 1, m2 - 2))
315 / denom;
316 } else {
317 dsdt[idx] = (get_val(i, si + 1, ti + 1)
318 - get_val(i, si - 1, ti + 1)
319 - get_val(i, si + 1, ti - 1)
320 + get_val(i, si - 1, ti - 1))
321 / denom;
322 }
323 }
324 }
325
326 (ds, dt, dsdt)
327 })
328 .collect();
329
330 let mut ds_mat = vec![0.0; n * ncol];
332 let mut dt_mat = vec![0.0; n * ncol];
333 let mut dsdt_mat = vec![0.0; n * ncol];
334
335 for i in 0..n {
336 for j in 0..ncol {
337 ds_mat[i + j * n] = results[i].0[j];
338 dt_mat[i + j * n] = results[i].1[j];
339 dsdt_mat[i + j * n] = results[i].2[j];
340 }
341 }
342
343 Some(Deriv2DResult {
344 ds: ds_mat,
345 dt: dt_mat,
346 dsdt: dsdt_mat,
347 })
348}
349
350pub fn geometric_median_1d(
362 data: &[f64],
363 nrow: usize,
364 ncol: usize,
365 argvals: &[f64],
366 max_iter: usize,
367 tol: f64,
368) -> Vec<f64> {
369 if nrow == 0 || ncol == 0 || argvals.len() != ncol || data.len() != nrow * ncol {
370 return Vec::new();
371 }
372
373 let weights = simpsons_weights(argvals);
374
375 let mut median: Vec<f64> = (0..ncol)
377 .map(|j| {
378 let mut sum = 0.0;
379 for i in 0..nrow {
380 sum += data[i + j * nrow];
381 }
382 sum / nrow as f64
383 })
384 .collect();
385
386 for _ in 0..max_iter {
387 let distances: Vec<f64> = (0..nrow)
389 .map(|i| {
390 let mut dist_sq = 0.0;
391 for j in 0..ncol {
392 let diff = data[i + j * nrow] - median[j];
393 dist_sq += diff * diff * weights[j];
394 }
395 dist_sq.sqrt()
396 })
397 .collect();
398
399 let eps = 1e-10;
401 let inv_distances: Vec<f64> = distances
402 .iter()
403 .map(|d| if *d > eps { 1.0 / d } else { 1.0 / eps })
404 .collect();
405
406 let sum_inv_dist: f64 = inv_distances.iter().sum();
407
408 let new_median: Vec<f64> = (0..ncol)
410 .map(|j| {
411 let mut weighted_sum = 0.0;
412 for i in 0..nrow {
413 weighted_sum += data[i + j * nrow] * inv_distances[i];
414 }
415 weighted_sum / sum_inv_dist
416 })
417 .collect();
418
419 let diff: f64 = median
421 .iter()
422 .zip(new_median.iter())
423 .map(|(a, b)| (a - b).abs())
424 .sum::<f64>()
425 / ncol as f64;
426
427 median = new_median;
428
429 if diff < tol {
430 break;
431 }
432 }
433
434 median
435}
436
437pub fn geometric_median_2d(
441 data: &[f64],
442 nrow: usize,
443 ncol: usize,
444 argvals_s: &[f64],
445 argvals_t: &[f64],
446 max_iter: usize,
447 tol: f64,
448) -> Vec<f64> {
449 let expected_cols = argvals_s.len() * argvals_t.len();
450 if nrow == 0 || ncol == 0 || ncol != expected_cols || data.len() != nrow * ncol {
451 return Vec::new();
452 }
453
454 let weights = simpsons_weights_2d(argvals_s, argvals_t);
455
456 let mut median: Vec<f64> = (0..ncol)
458 .map(|j| {
459 let mut sum = 0.0;
460 for i in 0..nrow {
461 sum += data[i + j * nrow];
462 }
463 sum / nrow as f64
464 })
465 .collect();
466
467 for _ in 0..max_iter {
468 let distances: Vec<f64> = (0..nrow)
470 .map(|i| {
471 let mut dist_sq = 0.0;
472 for j in 0..ncol {
473 let diff = data[i + j * nrow] - median[j];
474 dist_sq += diff * diff * weights[j];
475 }
476 dist_sq.sqrt()
477 })
478 .collect();
479
480 let eps = 1e-10;
482 let inv_distances: Vec<f64> = distances
483 .iter()
484 .map(|d| if *d > eps { 1.0 / d } else { 1.0 / eps })
485 .collect();
486
487 let sum_inv_dist: f64 = inv_distances.iter().sum();
488
489 let new_median: Vec<f64> = (0..ncol)
491 .map(|j| {
492 let mut weighted_sum = 0.0;
493 for i in 0..nrow {
494 weighted_sum += data[i + j * nrow] * inv_distances[i];
495 }
496 weighted_sum / sum_inv_dist
497 })
498 .collect();
499
500 let diff: f64 = median
502 .iter()
503 .zip(new_median.iter())
504 .map(|(a, b)| (a - b).abs())
505 .sum::<f64>()
506 / ncol as f64;
507
508 median = new_median;
509
510 if diff < tol {
511 break;
512 }
513 }
514
515 median
516}
517
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::PI;

    /// Returns `n` equally spaced points covering `[0, 1]` (callers pass `n >= 2`).
    fn uniform_grid(n: usize) -> Vec<f64> {
        (0..n).map(|i| i as f64 / (n - 1) as f64).collect()
    }

    // ----- mean_1d / mean_2d -----

    #[test]
    fn test_mean_1d() {
        // Two rows, three columns, column-major: columns are (1, 3), (2, 4), (3, 5).
        let data = vec![1.0, 3.0, 2.0, 4.0, 3.0, 5.0];
        let mean = mean_1d(&data, 2, 3);
        assert_eq!(mean, vec![2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_mean_1d_single_sample() {
        let data = vec![1.0, 2.0, 3.0];
        let mean = mean_1d(&data, 1, 3);
        assert_eq!(mean, vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_mean_1d_invalid() {
        assert!(mean_1d(&[], 0, 0).is_empty());
        // Data length does not match the declared shape.
        assert!(mean_1d(&[1.0], 1, 2).is_empty());
    }

    #[test]
    fn test_mean_2d_delegates() {
        let data = vec![1.0, 3.0, 2.0, 4.0];
        let mean1d = mean_1d(&data, 2, 2);
        let mean2d = mean_2d(&data, 2, 2);
        assert_eq!(mean1d, mean2d);
    }

    // ----- center_1d -----

    #[test]
    fn test_center_1d() {
        // Column means are 2, 3 and 4, so every column centers to (-1, 1).
        let data = vec![1.0, 3.0, 2.0, 4.0, 3.0, 5.0];
        let centered = center_1d(&data, 2, 3);
        assert_eq!(centered, vec![-1.0, 1.0, -1.0, 1.0, -1.0, 1.0]);
    }

    #[test]
    fn test_center_1d_mean_zero() {
        let data = vec![1.0, 3.0, 2.0, 4.0, 3.0, 5.0];
        let centered = center_1d(&data, 2, 3);
        let centered_mean = mean_1d(&centered, 2, 3);
        for m in centered_mean {
            assert!(m.abs() < 1e-10, "Centered data should have zero mean");
        }
    }

    #[test]
    fn test_center_1d_invalid() {
        assert!(center_1d(&[], 0, 0).is_empty());
    }

    // ----- norm_lp_1d -----

    #[test]
    fn test_norm_lp_1d_constant() {
        let argvals = uniform_grid(21);
        // One curve, constant value 2 on [0, 1].
        let data = vec![2.0; 21];
        let norms = norm_lp_1d(&data, 1, 21, &argvals, 2.0);
        assert_eq!(norms.len(), 1);
        assert!(
            (norms[0] - 2.0).abs() < 0.1,
            "L2 norm of constant 2 should be 2"
        );
    }

    #[test]
    fn test_norm_lp_1d_sine() {
        let argvals = uniform_grid(101);
        // One curve, sin(pi * x) on [0, 1]; its squared L2 norm is 1/2.
        let data: Vec<f64> = argvals.iter().map(|&x| (PI * x).sin()).collect();
        let norms = norm_lp_1d(&data, 1, 101, &argvals, 2.0);
        let expected = 0.5_f64.sqrt();
        assert!(
            (norms[0] - expected).abs() < 0.05,
            "Expected {}, got {}",
            expected,
            norms[0]
        );
    }

    #[test]
    fn test_norm_lp_1d_invalid() {
        assert!(norm_lp_1d(&[], 0, 0, &[], 2.0).is_empty());
    }

    // ----- deriv_1d -----

    #[test]
    fn test_deriv_1d_linear() {
        let argvals = uniform_grid(21);
        // One curve, f(x) = x.
        let data = argvals.clone();
        let deriv = deriv_1d(&data, 1, 21, &argvals, 1);
        // Only interior points are checked.
        for j in 2..19 {
            assert!((deriv[j] - 1.0).abs() < 0.1, "Derivative of x should be 1");
        }
    }

    #[test]
    fn test_deriv_1d_quadratic() {
        let argvals = uniform_grid(51);
        // One curve, f(x) = x^2.
        let data: Vec<f64> = argvals.iter().map(|&x| x * x).collect();
        let deriv = deriv_1d(&data, 1, 51, &argvals, 1);
        // Only interior points are checked.
        for j in 5..45 {
            let expected = 2.0 * argvals[j];
            assert!(
                (deriv[j] - expected).abs() < 0.1,
                "Derivative of x^2 should be 2x"
            );
        }
    }

    #[test]
    fn test_deriv_1d_invalid() {
        let result = deriv_1d(&[], 0, 0, &[], 1);
        assert!(result.is_empty() || result.iter().all(|&x| x == 0.0));
    }

    // ----- geometric_median_1d -----

    #[test]
    fn test_geometric_median_identical_curves() {
        let argvals = uniform_grid(21);
        // Five copies of sin(2 * pi * x): the median must reproduce that curve.
        let n = 5;
        let m = 21;
        let mut data = vec![0.0; n * m];
        for i in 0..n {
            for j in 0..m {
                data[i + j * n] = (2.0 * PI * argvals[j]).sin();
            }
        }
        let median = geometric_median_1d(&data, n, m, &argvals, 100, 1e-6);
        for j in 0..m {
            let expected = (2.0 * PI * argvals[j]).sin();
            assert!(
                (median[j] - expected).abs() < 0.01,
                "Median should equal all curves"
            );
        }
    }

    #[test]
    fn test_geometric_median_converges() {
        let argvals = uniform_grid(21);
        let n = 10;
        let m = 21;
        let mut data = vec![0.0; n * m];
        for i in 0..n {
            for j in 0..m {
                data[i + j * n] = (i as f64 / n as f64) * argvals[j];
            }
        }
        let median = geometric_median_1d(&data, n, m, &argvals, 100, 1e-6);
        assert_eq!(median.len(), m);
        assert!(median.iter().all(|&x| x.is_finite()));
    }

    #[test]
    fn test_geometric_median_invalid() {
        assert!(geometric_median_1d(&[], 0, 0, &[], 100, 1e-6).is_empty());
    }
}
710}