// ndspec/core/integration.rs

1use ndarray::{ArrayView1, ArrayView2};
2use rayon::prelude::*;
3
4// sequential trapezoidal integration
5pub fn trapz(y: ArrayView1<f64>, x: ArrayView1<f64>) -> f64 {
6    (0..(x.len() - 1))
7        .map(|i| (x[i + 1] - x[i]) * (y[i + 1] + y[i]) / 2.0)
8        .sum()
9}
10
11// parallel trapezoidal integration
12#[allow(dead_code)]
13pub fn trapzp(y: ArrayView1<f64>, x: ArrayView1<f64>) -> f64 {
14    (0..x.len() - 1)
15        .into_par_iter()
16        .map(|i| (x[i + 1] - x[i]) * (y[i + 1] + y[i]) / 2.0)
17        .sum()
18}
19
20// parallel 2D trapezoidal integration
21pub fn trapz2d(f: ArrayView2<f64>, x: ArrayView1<f64>, y: ArrayView1<f64>) -> f64 {
22    let nx = x.len();
23    let ny = y.len();
24
25    (0..ny - 1)
26        .into_par_iter()
27        .map(|j| {
28            let dy = (y[j + 1] - y[j]) / 2.0;
29            (0..nx - 1)
30                //                .into_par_iter()
31                .map(|i| {
32                    let dx = (x[i + 1] - x[i]) / 2.0;
33                    (f[[j, i]] + f[[j + 1, i]] + f[[j, i + 1]] + f[[j + 1, i + 1]]) * dx * dy
34                })
35                .sum::<f64>()
36        })
37        .sum()
38}
39
/// Unit tests for the trapezoidal integration routines.
///
/// All inputs are chosen so that every per-interval term and every partial
/// sum is exactly representable in `f64`; exact `assert_eq!` comparisons are
/// therefore safe, regardless of summation order.
#[cfg(test)]
mod tests_integration {
    use super::*;
    use ndarray::arr1;
    use ndarray::{Array1, Array2};

    #[test]
    fn test_trapz() {
        // Linear ramp on a uniform grid: 1.5 + 2.5 + 3.5.
        let y1 = arr1(&[1.0, 2.0, 3.0, 4.0]);
        let x1 = arr1(&[0.0, 1.0, 2.0, 3.0]);
        let result1 = trapz(y1.view(), x1.view());
        assert_eq!(result1, 7.5);

        // Line through the origin with slope 0.5: 0.25 + 0.75.
        let y2 = arr1(&[0.0, 0.5, 1.0]);
        let x2 = arr1(&[0.0, 1.0, 2.0]);
        let result2 = trapz(y2.view(), x2.view());
        assert_eq!(result2, 1.0);

        // Non-uniform spacing: 1 * (0 + 2) / 2 + 2 * (2 + 2) / 2.
        let y3 = arr1(&[0.0, 2.0, 2.0]);
        let x3 = arr1(&[0.0, 1.0, 3.0]);
        let result3 = trapz(y3.view(), x3.view());
        assert_eq!(result3, 5.0);
    }

    #[test]
    fn test_trapzp() {
        // Same data as the sequential test; the parallel variant must agree.
        let y1 = arr1(&[1.0, 2.0, 3.0, 4.0]);
        let x1 = arr1(&[0.0, 1.0, 2.0, 3.0]);
        assert_eq!(trapzp(y1.view(), x1.view()), 7.5);

        // Non-uniform spacing: 1 * (0 + 2) / 2 + 2 * (2 + 2) / 2.
        let y2 = arr1(&[0.0, 2.0, 2.0]);
        let x2 = arr1(&[0.0, 1.0, 3.0]);
        assert_eq!(trapzp(y2.view(), x2.view()), 5.0);
    }

    #[test]
    fn test_trapz2d() {
        // Samples of (x * y)^2 on a 3x3 unit-spaced grid. The four cells
        // contribute 0.25 + 1.25 + 1.25 + 6.25.
        let x = Array1::from(vec![0., 1., 2.]);
        let y = Array1::from(vec![0., 1., 2.]);
        let z = Array2::from_shape_vec((3, 3), vec![0.0, 0.0, 0.0, 0.0, 1.0, 4.0, 0.0, 4.0, 16.0])
            .unwrap();
        assert_eq!(trapz2d(z.view(), x.view(), y.view()), 9.0);
    }
}
71}