Skip to main content

ndarray_ndimage/interpolation/
spline_filter.rs

1use ndarray::{arr1, s, Array, Array1, ArrayRef, ArrayViewMut1, Axis, Dimension};
2use num_traits::ToPrimitive;
3
4use crate::BorderMode;
5
6/// Multidimensional spline filter.
7///
8/// The multidimensional filter is implemented as a sequence of one-dimensional spline filters. The
9/// input `data` will be processed in `f64` and returned as such.
10///
11/// * `data` - The input N-D data.
12/// * `order` - The order of the spline.
13/// * `mode` - The mode parameter determines how the input array is extended beyond its boundaries.
14///
15/// **Panics** if `order` isn't in the range \[2, 5\].
16pub fn spline_filter<A, D>(
17    data: &ArrayRef<A, D>,
18    order: usize,
19    mode: BorderMode<A>,
20) -> Array<f64, D>
21where
22    A: Copy + ToPrimitive,
23    D: Dimension,
24{
25    let mut data = data.map(|v| v.to_f64().unwrap());
26    if data.len() == 1 {
27        return data;
28    }
29
30    let poles = get_filter_poles(order);
31    let gain = filter_gain(&poles);
32    for axis in 0..data.ndim() {
33        _spline_filter1d(&mut data, mode, Axis(axis), &poles, gain);
34    }
35    data
36}
37
38/// Calculate a 1-D spline filter along the given axis.
39///
40/// The lines of the array along the given axis are filtered by a spline filter. The input `data`
41/// will be processed in `f64` and returned as such.
42///
43/// * `data` - The input N-D data.
44/// * `order` - The order of the spline.
45/// * `mode` - The mode parameter determines how the input array is extended beyond its boundaries.
46/// * `axis` - The axis along which the spline filter is applied.
47///
48/// **Panics** if `order` isn't in the range \[0, 5\].
49pub fn spline_filter1d<A, D>(
50    data: &ArrayRef<A, D>,
51    order: usize,
52    mode: BorderMode<A>,
53    axis: Axis,
54) -> Array<f64, D>
55where
56    A: Copy + ToPrimitive,
57    D: Dimension,
58{
59    let mut data = data.map(|v| v.to_f64().unwrap());
60    if order == 0 || order == 1 || data.len() == 1 {
61        return data;
62    }
63
64    let poles = get_filter_poles(order);
65    let gain = filter_gain(&poles);
66
67    _spline_filter1d(&mut data, mode, axis, &poles, gain);
68    data
69}
70
71fn _spline_filter1d<A, D>(
72    data: &mut Array<f64, D>,
73    mode: BorderMode<A>,
74    axis: Axis,
75    poles: &Array1<f64>,
76    gain: f64,
77) where
78    A: Copy,
79    D: Dimension,
80{
81    for mut line in data.lanes_mut(axis) {
82        for val in line.iter_mut() {
83            *val *= gain;
84        }
85        for &pole in poles {
86            init_causal_coefficient(&mut line, pole, mode);
87            for i in 1..line.len() {
88                line[i] += pole * line[i - 1];
89            }
90
91            init_anticausal_coefficient(&mut line, pole, mode);
92            for i in (0..line.len() - 1).rev() {
93                line[i] = pole * (line[i + 1] - line[i]);
94            }
95        }
96    }
97}
98
99fn get_filter_poles(order: usize) -> Array1<f64> {
100    match order {
101        1 => panic!("Can't use 'spline_filter' with order 1"),
102        2 => arr1(&[8.0f64.sqrt() - 3.0]),
103        3 => arr1(&[3.0f64.sqrt() - 2.0]),
104        4 => arr1(&[
105            (664.0 - 438976.0f64.sqrt()).sqrt() + 304.0f64.sqrt() - 19.0,
106            (664.0 + 438976.0f64.sqrt()).sqrt() - 304.0f64.sqrt() - 19.0,
107        ]),
108        5 => arr1(&[
109            (67.5 - 4436.25f64.sqrt()).sqrt() + 26.25f64.sqrt() - 6.5,
110            (67.5 + 4436.25f64.sqrt()).sqrt() - 26.25f64.sqrt() - 6.5,
111        ]),
112        _ => panic!("Order must be between 2 and 5"),
113    }
114}
115
116fn filter_gain(poles: &Array1<f64>) -> f64 {
117    let mut gain = 1.0;
118    for pole in poles {
119        gain *= (1.0 - pole) * (1.0 - 1.0 / pole);
120    }
121    gain
122}
123
124fn init_causal_coefficient<A>(line: &mut ArrayViewMut1<f64>, pole: f64, mode: BorderMode<A>) {
125    match mode {
126        BorderMode::Constant(_) | BorderMode::Mirror | BorderMode::Wrap => {
127            init_causal_mirror(line, pole)
128        }
129        BorderMode::Nearest | BorderMode::Reflect => init_causal_reflect(line, pole),
130    }
131}
132
133fn init_causal_mirror(line: &mut ArrayViewMut1<f64>, pole: f64) {
134    let mut z_i = pole;
135
136    // TODO I can't find this code anywhere in SciPy. It should be removed.
137    let tolerance: f64 = 1e-15;
138    let last_coefficient = (tolerance.ln().ceil() / pole.abs().ln()) as usize;
139    if last_coefficient < line.len() {
140        let mut sum = line[0];
141        // All values from line[1..last_coefficient]
142        for val in line.iter().take(last_coefficient).skip(1) {
143            sum += z_i * val;
144            z_i *= pole;
145        }
146        line[0] = sum;
147    } else {
148        let inv_z = 1.0 / pole;
149        let z_n_1 = pole.powi(line.len() as i32 - 1);
150        let mut z_2n_2_i = z_n_1 * z_n_1 * inv_z;
151
152        let mut sum = line[0] + (line[line.len() - 1] * z_n_1);
153        for v in line.slice(s![1..line.len() - 1]) {
154            sum += (z_i + z_2n_2_i) * v;
155            z_i *= pole;
156            z_2n_2_i *= inv_z;
157        }
158        line[0] = sum / (1.0 - z_n_1 * z_n_1);
159    }
160}
161
162fn init_causal_reflect(line: &mut ArrayViewMut1<f64>, pole: f64) {
163    let lm1 = line.len() - 1;
164    let mut z_i = pole;
165    let z_n = pole.powi(line.len() as i32);
166    let l0 = line[0];
167
168    line[0] += z_n * line[lm1];
169    for i in 1..line.len() {
170        line[0] += z_i * (line[i] + z_n * line[lm1 - i]);
171        z_i *= pole;
172    }
173    line[0] *= pole / (1.0 - z_n * z_n);
174    line[0] += l0;
175}
176
177fn init_anticausal_coefficient<A>(line: &mut ArrayViewMut1<f64>, pole: f64, mode: BorderMode<A>) {
178    match mode {
179        BorderMode::Constant(_) | BorderMode::Mirror | BorderMode::Wrap => {
180            init_anticausal_mirror(line, pole)
181        }
182        BorderMode::Nearest | BorderMode::Reflect => init_anticausal_reflect(line, pole),
183    }
184}
185
186fn init_anticausal_mirror(line: &mut ArrayViewMut1<f64>, pole: f64) {
187    let lm1 = line.len() - 1;
188    line[lm1] = pole / (pole * pole - 1.0) * (pole * line[line.len() - 2] + line[lm1]);
189}
190
191fn init_anticausal_reflect(line: &mut ArrayViewMut1<f64>, pole: f64) {
192    let lm1 = line.len() - 1;
193    line[lm1] *= pole / (pole - 1.0);
194}