ndarray_ndimage/interpolation/spline_filter.rs

1use ndarray::{arr1, s, Array, Array1, ArrayBase, ArrayViewMut1, Axis, Data, Dimension};
2use num_traits::ToPrimitive;
3
4use crate::BorderMode;
5
6/// Multidimensional spline filter.
7///
8/// The multidimensional filter is implemented as a sequence of one-dimensional spline filters. The
9/// input `data` will be processed in `f64` and returned as such.
10///
11/// * `data` - The input N-D data.
12/// * `order` - The order of the spline.
13/// * `mode` - The mode parameter determines how the input array is extended beyond its boundaries.
14///
15/// **Panics** if `order` isn't in the range \[2, 5\].
16pub fn spline_filter<S, A, D>(
17    data: &ArrayBase<S, D>,
18    order: usize,
19    mode: BorderMode<A>,
20) -> Array<f64, D>
21where
22    S: Data<Elem = A>,
23    A: Copy + ToPrimitive,
24    D: Dimension,
25{
26    let mut data = data.map(|v| v.to_f64().unwrap());
27    if data.len() == 1 {
28        return data;
29    }
30
31    let poles = get_filter_poles(order);
32    let gain = filter_gain(&poles);
33    for axis in 0..data.ndim() {
34        _spline_filter1d(&mut data, mode, Axis(axis), &poles, gain);
35    }
36    data
37}
38
39/// Calculate a 1-D spline filter along the given axis.
40///
41/// The lines of the array along the given axis are filtered by a spline filter. The input `data`
42/// will be processed in `f64` and returned as such.
43///
44/// * `data` - The input N-D data.
45/// * `order` - The order of the spline.
46/// * `mode` - The mode parameter determines how the input array is extended beyond its boundaries.
47/// * `axis` - The axis along which the spline filter is applied.
48///
49/// **Panics** if `order` isn't in the range \[0, 5\].
50pub fn spline_filter1d<S, A, D>(
51    data: &ArrayBase<S, D>,
52    order: usize,
53    mode: BorderMode<A>,
54    axis: Axis,
55) -> Array<f64, D>
56where
57    S: Data<Elem = A>,
58    A: Copy + ToPrimitive,
59    D: Dimension,
60{
61    let mut data = data.map(|v| v.to_f64().unwrap());
62    if order == 0 || order == 1 || data.len() == 1 {
63        return data;
64    }
65
66    let poles = get_filter_poles(order);
67    let gain = filter_gain(&poles);
68
69    _spline_filter1d(&mut data, mode, axis, &poles, gain);
70    data
71}
72
73fn _spline_filter1d<A, D>(
74    data: &mut Array<f64, D>,
75    mode: BorderMode<A>,
76    axis: Axis,
77    poles: &Array1<f64>,
78    gain: f64,
79) where
80    A: Copy,
81    D: Dimension,
82{
83    for mut line in data.lanes_mut(axis) {
84        for val in line.iter_mut() {
85            *val *= gain;
86        }
87        for &pole in poles {
88            init_causal_coefficient(&mut line, pole, mode);
89            for i in 1..line.len() {
90                line[i] += pole * line[i - 1];
91            }
92
93            init_anticausal_coefficient(&mut line, pole, mode);
94            for i in (0..line.len() - 1).rev() {
95                line[i] = pole * (line[i + 1] - line[i]);
96            }
97        }
98    }
99}
100
101fn get_filter_poles(order: usize) -> Array1<f64> {
102    match order {
103        1 => panic!("Can't use 'spline_filter' with order 1"),
104        2 => arr1(&[8.0f64.sqrt() - 3.0]),
105        3 => arr1(&[3.0f64.sqrt() - 2.0]),
106        4 => arr1(&[
107            (664.0 - 438976.0f64.sqrt()).sqrt() + 304.0f64.sqrt() - 19.0,
108            (664.0 + 438976.0f64.sqrt()).sqrt() - 304.0f64.sqrt() - 19.0,
109        ]),
110        5 => arr1(&[
111            (67.5 - 4436.25f64.sqrt()).sqrt() + 26.25f64.sqrt() - 6.5,
112            (67.5 + 4436.25f64.sqrt()).sqrt() - 26.25f64.sqrt() - 6.5,
113        ]),
114        _ => panic!("Order must be between 2 and 5"),
115    }
116}
117
118fn filter_gain(poles: &Array1<f64>) -> f64 {
119    let mut gain = 1.0;
120    for pole in poles {
121        gain *= (1.0 - pole) * (1.0 - 1.0 / pole);
122    }
123    gain
124}
125
126fn init_causal_coefficient<A>(line: &mut ArrayViewMut1<f64>, pole: f64, mode: BorderMode<A>) {
127    match mode {
128        BorderMode::Constant(_) | BorderMode::Mirror | BorderMode::Wrap => {
129            init_causal_mirror(line, pole)
130        }
131        BorderMode::Nearest | BorderMode::Reflect => init_causal_reflect(line, pole),
132    }
133}
134
135fn init_causal_mirror(line: &mut ArrayViewMut1<f64>, pole: f64) {
136    let mut z_i = pole;
137
138    // TODO I can't find this code anywhere in SciPy. It should be removed.
139    let tolerance: f64 = 1e-15;
140    let last_coefficient = (tolerance.ln().ceil() / pole.abs().ln()) as usize;
141    if last_coefficient < line.len() {
142        let mut sum = line[0];
143        // All values from line[1..last_coefficient]
144        for val in line.iter().take(last_coefficient).skip(1) {
145            sum += z_i * val;
146            z_i *= pole;
147        }
148        line[0] = sum;
149    } else {
150        let inv_z = 1.0 / pole;
151        let z_n_1 = pole.powi(line.len() as i32 - 1);
152        let mut z_2n_2_i = z_n_1 * z_n_1 * inv_z;
153
154        let mut sum = line[0] + (line[line.len() - 1] * z_n_1);
155        for v in line.slice(s![1..line.len() - 1]) {
156            sum += (z_i + z_2n_2_i) * v;
157            z_i *= pole;
158            z_2n_2_i *= inv_z;
159        }
160        line[0] = sum / (1.0 - z_n_1 * z_n_1);
161    }
162}
163
164fn init_causal_reflect(line: &mut ArrayViewMut1<f64>, pole: f64) {
165    let lm1 = line.len() - 1;
166    let mut z_i = pole;
167    let z_n = pole.powi(line.len() as i32);
168    let l0 = line[0];
169
170    line[0] += z_n * line[lm1];
171    for i in 1..line.len() {
172        line[0] += z_i * (line[i] + z_n * line[lm1 - i]);
173        z_i *= pole;
174    }
175    line[0] *= pole / (1.0 - z_n * z_n);
176    line[0] += l0;
177}
178
179fn init_anticausal_coefficient<A>(line: &mut ArrayViewMut1<f64>, pole: f64, mode: BorderMode<A>) {
180    match mode {
181        BorderMode::Constant(_) | BorderMode::Mirror | BorderMode::Wrap => {
182            init_anticausal_mirror(line, pole)
183        }
184        BorderMode::Nearest | BorderMode::Reflect => init_anticausal_reflect(line, pole),
185    }
186}
187
188fn init_anticausal_mirror(line: &mut ArrayViewMut1<f64>, pole: f64) {
189    let lm1 = line.len() - 1;
190    line[lm1] = pole / (pole * pole - 1.0) * (pole * line[line.len() - 2] + line[lm1]);
191}
192
193fn init_anticausal_reflect(line: &mut ArrayViewMut1<f64>, pole: f64) {
194    let lm1 = line.len() - 1;
195    line[lm1] *= pole / (pole - 1.0);
196}