ndarray_ndimage/interpolation/zoom_shift.rs

1use std::ops::{Add, Sub};
2
3use ndarray::{s, Array, Array2, ArrayBase, ArrayViewMut1, Data, Ix3, Zip};
4use num_traits::{FromPrimitive, Num, ToPrimitive};
5
6use crate::{array_like, pad, round_ties_even, spline_filter, BorderMode, PadMode};
7
8/// Shift an array.
9///
10/// The array is shifted using spline interpolation of the requested order. Points outside the
11/// boundaries of the input are filled according to the given mode.
12///
13/// * `data` - A 3D array of the data to shift.
14/// * `shift` - The shift along the axes.
15/// * `order` - The order of the spline.
16/// * `mode` - The mode parameter determines how the input array is extended beyond its boundaries.
17/// * `prefilter` - Determines if the input array is prefiltered with spline_filter before
18///   interpolation. The default is `true`, which will create a temporary `f64` array of filtered
19///   values if `order > 1`. If setting this to `false`, the output will be slightly blurred if
20///   `order > 1`, unless the input is prefiltered.
21pub fn shift<S, A>(
22    data: &ArrayBase<S, Ix3>,
23    shift: [f64; 3],
24    order: usize,
25    mode: BorderMode<A>,
26    prefilter: bool,
27) -> Array<A, Ix3>
28where
29    S: Data<Elem = A>,
30    A: Copy + Num + FromPrimitive + PartialOrd + ToPrimitive,
31{
32    let dim = [data.dim().0, data.dim().1, data.dim().2];
33    let shift = shift.map(|s| -s);
34    run_zoom_shift(data, dim, [1.0, 1.0, 1.0], shift, order, mode, prefilter)
35}
36
37/// Zoom an array.
38///
39/// The array is zoomed using spline interpolation of the requested order.
40///
41/// * `data` - A 3D array of the data to zoom
42/// * `zoom` - The zoom factor along the axes.
43/// * `order` - The order of the spline.
44/// * `mode` - The mode parameter determines how the input array is extended beyond its boundaries.
45/// * `prefilter` - Determines if the input array is prefiltered with spline_filter before
46///   interpolation. The default is `true`, which will create a temporary `f64` array of filtered
47///   values if `order > 1`. If setting this to `false`, the output will be slightly blurred if
48///   `order > 1`, unless the input is prefiltered.
49pub fn zoom<S, A>(
50    data: &ArrayBase<S, Ix3>,
51    zoom: [f64; 3],
52    order: usize,
53    mode: BorderMode<A>,
54    prefilter: bool,
55) -> Array<A, Ix3>
56where
57    S: Data<Elem = A>,
58    A: Copy + Num + FromPrimitive + PartialOrd + ToPrimitive,
59{
60    let mut o_dim = data.raw_dim();
61    for (ax, (&ax_len, zoom)) in data.shape().iter().zip(zoom.iter()).enumerate() {
62        o_dim[ax] = round_ties_even(ax_len as f64 * zoom) as usize;
63    }
64    let o_dim = [o_dim[0], o_dim[1], o_dim[2]];
65
66    let mut nom = data.raw_dim();
67    let mut div = o_dim.clone();
68    for ax in 0..data.ndim() {
69        nom[ax] -= 1;
70        div[ax] -= 1;
71    }
72    let zoom = [
73        nom[0] as f64 / div[0] as f64,
74        nom[1] as f64 / div[1] as f64,
75        nom[2] as f64 / div[2] as f64,
76    ];
77
78    run_zoom_shift(data, o_dim, zoom, [0.0, 0.0, 0.0], order, mode, prefilter)
79}
80
81fn run_zoom_shift<S, A>(
82    data: &ArrayBase<S, Ix3>,
83    odim: [usize; 3],
84    zooms: [f64; 3],
85    shifts: [f64; 3],
86    order: usize,
87    mode: BorderMode<A>,
88    prefilter: bool,
89) -> Array<A, Ix3>
90where
91    S: Data<Elem = A>,
92    A: Copy + Num + FromPrimitive + PartialOrd + ToPrimitive,
93{
94    let idim = [data.dim().0, data.dim().1, data.dim().2];
95    let mut out = array_like(&data, odim, A::zero());
96    if prefilter && order > 1 {
97        // We need to allocate and work on filtered data
98        let (data, nb_prepad) = match mode {
99            BorderMode::Nearest => {
100                let padded = pad(data, &[[12, 12]], PadMode::Edge);
101                (spline_filter(&padded, order, mode), 12)
102            }
103            _ => (spline_filter(data, order, mode), 0),
104        };
105        let reslicer = ZoomShiftReslicer::new(idim, odim, zooms, shifts, order, mode, nb_prepad);
106        Zip::indexed(&mut out).for_each(|idx, o| {
107            *o = A::from_f64(reslicer.interpolate(&data, idx)).unwrap();
108        });
109    } else {
110        // We can use the &data as-is
111        let reslicer = ZoomShiftReslicer::new(idim, odim, zooms, shifts, order, mode, 0);
112        Zip::indexed(&mut out).for_each(|idx, o| {
113            *o = A::from_f64(reslicer.interpolate(data, idx)).unwrap();
114        });
115    }
116    out
117}
118
119/// Zoom shift transformation (only scaling and translation).
120struct ZoomShiftReslicer {
121    order: usize,
122    offsets: [Vec<isize>; 3],
123    edge_offsets: [Array2<isize>; 3],
124    is_edge_case: [Vec<bool>; 3],
125    splvals: [Array2<f64>; 3],
126    zeros: [Vec<bool>; 3],
127    cval: f64,
128}
129
130impl ZoomShiftReslicer {
131    /// Build all necessary data to call `interpolate`.
132    pub fn new<A>(
133        idim: [usize; 3],
134        odim: [usize; 3],
135        zooms: [f64; 3],
136        shifts: [f64; 3],
137        order: usize,
138        mode: BorderMode<A>,
139        nb_prepad: isize,
140    ) -> ZoomShiftReslicer
141    where
142        A: Copy + ToPrimitive,
143    {
144        let offsets = [vec![0; odim[0]], vec![0; odim[1]], vec![0; odim[2]]];
145        let is_edge_case = [vec![false; odim[0]], vec![false; odim[1]], vec![false; odim[2]]];
146        let (edge_offsets, splvals) = if order > 0 {
147            let dim0 = (odim[0], order + 1);
148            let dim1 = (odim[1], order + 1);
149            let dim2 = (odim[2], order + 1);
150            let e = [Array2::zeros(dim0), Array2::zeros(dim1), Array2::zeros(dim2)];
151            let s = [Array2::zeros(dim0), Array2::zeros(dim1), Array2::zeros(dim2)];
152            (e, s)
153        } else {
154            // We do not need to allocate when order == 0
155            let e = [Array2::zeros((0, 0)), Array2::zeros((0, 0)), Array2::zeros((0, 0))];
156            let s = [Array2::zeros((0, 0)), Array2::zeros((0, 0)), Array2::zeros((0, 0))];
157            (e, s)
158        };
159        let zeros = [vec![false; odim[0]], vec![false; odim[1]], vec![false; odim[2]]];
160        let cval = match mode {
161            BorderMode::Constant(cval) => cval.to_f64().unwrap(),
162            _ => 0.0,
163        };
164
165        let mut reslicer =
166            ZoomShiftReslicer { order, offsets, edge_offsets, is_edge_case, splvals, zeros, cval };
167        reslicer.build_arrays(idim, odim, zooms, shifts, order, mode, nb_prepad);
168        reslicer
169    }
170
171    fn build_arrays<A>(
172        &mut self,
173        idim: [usize; 3],
174        odim: [usize; 3],
175        zooms: [f64; 3],
176        shifts: [f64; 3],
177        order: usize,
178        mode: BorderMode<A>,
179        nb_prepad: isize,
180    ) where
181        A: Copy,
182    {
183        // Modes without an anlaytic prefilter or explicit prepadding use mirror extension
184        let spline_mode = match mode {
185            BorderMode::Constant(_) | BorderMode::Wrap => BorderMode::Mirror,
186            _ => mode,
187        };
188        let iorder = order as isize;
189        let idim = [
190            idim[0] as isize + 2 * nb_prepad,
191            idim[1] as isize + 2 * nb_prepad,
192            idim[2] as isize + 2 * nb_prepad,
193        ];
194        let nb_prepad = nb_prepad as f64;
195
196        for axis in 0..3 {
197            let splvals = &mut self.splvals[axis];
198            let offsets = &mut self.offsets[axis];
199            let edge_offsets = &mut self.edge_offsets[axis];
200            let is_edge_case = &mut self.is_edge_case[axis];
201            let zeros = &mut self.zeros[axis];
202            let len = idim[axis] as f64;
203            for from in 0..odim[axis] {
204                let mut to = (from as f64 + shifts[axis]) * zooms[axis] + nb_prepad;
205                match mode {
206                    BorderMode::Nearest => {}
207                    _ => to = map_coordinates(to, idim[axis] as f64, mode),
208                };
209                if to > -1.0 {
210                    if order > 0 {
211                        build_splines(to, &mut splvals.row_mut(from), order);
212                    }
213                    if order & 1 == 0 {
214                        to += 0.5;
215                    }
216
217                    let start = to.floor() as isize - iorder / 2;
218                    offsets[from] = start;
219                    if start < 0 || start + iorder >= idim[axis] {
220                        is_edge_case[from] = true;
221                        for o in 0..=order {
222                            let x = (start + o as isize) as f64;
223                            let idx = map_coordinates(x, len, spline_mode) as isize;
224                            edge_offsets[(from, o)] = idx - start;
225                        }
226                    }
227                } else {
228                    zeros[from] = true;
229                }
230            }
231        }
232    }
233
234    /// Spline interpolation with up-to 8 neighbors of a point.
235    pub fn interpolate<A, S>(&self, data: &ArrayBase<S, Ix3>, start: (usize, usize, usize)) -> f64
236    where
237        S: Data<Elem = A>,
238        A: ToPrimitive + Add<Output = A> + Sub<Output = A> + Copy,
239    {
240        if self.zeros[0][start.0] || self.zeros[1][start.1] || self.zeros[2][start.2] {
241            return self.cval;
242        }
243
244        // Order = 0
245        // We do not want to go further because
246        // - it would be uselessly slower
247        // - self.splvals is empty so it would crash (although we could fill it with 1.0)
248        if self.edge_offsets[0].is_empty() {
249            let x = self.offsets[0][start.0] as usize;
250            let y = self.offsets[1][start.1] as usize;
251            let z = self.offsets[2][start.2] as usize;
252            return data[(x, y, z)].to_f64().unwrap();
253        }
254
255        // Linear interpolation use a nxnxn block. This is simple enough, but we must adjust this
256        // block when the `start` is near the edges.
257        let n = self.order + 1;
258        let valid_index = |original_offset, is_edge, start, d: usize, v| {
259            (original_offset + if is_edge { self.edge_offsets[d][(start, v)] } else { v as isize })
260                as usize
261        };
262
263        let original_offset_x = self.offsets[0][start.0];
264        let is_edge_x = self.is_edge_case[0][start.0];
265        let mut xs = [0; 6];
266        let original_offset_y = self.offsets[1][start.1];
267        let is_edge_y = self.is_edge_case[1][start.1];
268        let mut ys = [0; 6];
269        let original_offset_z = self.offsets[2][start.2];
270        let is_edge_z = self.is_edge_case[2][start.2];
271        let mut zs = [0; 6];
272        for i in 0..n {
273            xs[i] = valid_index(original_offset_x, is_edge_x, start.0, 0, i);
274            ys[i] = valid_index(original_offset_y, is_edge_y, start.1, 1, i);
275            zs[i] = valid_index(original_offset_z, is_edge_z, start.2, 2, i);
276        }
277
278        let mut t = 0.0;
279        for (z, &idx_z) in zs[..n].iter().enumerate() {
280            let spline_z = self.splvals[2][(start.2, z)];
281            for (y, &idx_y) in ys[..n].iter().enumerate() {
282                let spline_yz = self.splvals[1][(start.1, y)] * spline_z;
283                for (x, &idx_x) in xs[..n].iter().enumerate() {
284                    let spline_xyz = self.splvals[0][(start.0, x)] * spline_yz;
285                    t += data[(idx_x, idx_y, idx_z)].to_f64().unwrap() * spline_xyz;
286                }
287            }
288        }
289        t
290    }
291}
292
293fn build_splines(to: f64, spline: &mut ArrayViewMut1<f64>, order: usize) {
294    let x = to - if order & 1 == 1 { to } else { to + 0.5 }.floor();
295    match order {
296        1 => spline[0] = 1.0 - x,
297        2 => {
298            spline[0] = 0.5 * (0.5 - x).powi(2);
299            spline[1] = 0.75 - x * x;
300        }
301        3 => {
302            let z = 1.0 - x;
303            spline[0] = z * z * z / 6.0;
304            spline[1] = (x * x * (x - 2.0) * 3.0 + 4.0) / 6.0;
305            spline[2] = (z * z * (z - 2.0) * 3.0 + 4.0) / 6.0;
306        }
307        4 => {
308            let t = x * x;
309            let y = 1.0 + x;
310            let z = 1.0 - x;
311            spline[0] = (0.5 - x).powi(4) / 24.0;
312            spline[1] = y * (y * (y * (5.0 - y) / 6.0 - 1.25) + 5.0 / 24.0) + 55.0 / 96.0;
313            spline[2] = t * (t * 0.25 - 0.625) + 115.0 / 192.0;
314            spline[3] = z * (z * (z * (5.0 - z) / 6.0 - 1.25) + 5.0 / 24.0) + 55.0 / 96.0;
315        }
316        5 => {
317            let y = 1.0 - x;
318            let t = y * y;
319            spline[0] = y * t * t / 120.0;
320            let y = x + 1.0;
321            spline[1] = y * (y * (y * (y * (y / 24.0 - 0.375) + 1.25) - 1.75) + 0.625) + 0.425;
322            let t = x * x;
323            spline[2] = t * (t * (0.25 - x / 12.0) - 0.5) + 0.55;
324            let z = 1.0 - x;
325            let t = z * z;
326            spline[3] = t * (t * (0.25 - z / 12.0) - 0.5) + 0.55;
327            let z = z + 1.0;
328            spline[4] = z * (z * (z * (z * (z / 24.0 - 0.375) + 1.25) - 1.75) + 0.625) + 0.425;
329        }
330        _ => panic!("order must be between 1 and 5"),
331    }
332    spline[order] = 1.0 - spline.slice(s![..order]).sum();
333}
334
335fn map_coordinates<A>(mut idx: f64, len: f64, mode: BorderMode<A>) -> f64 {
336    match mode {
337        BorderMode::Constant(_) => {
338            if idx < 0.0 || idx >= len {
339                idx = -1.0;
340            }
341        }
342        BorderMode::Nearest => {
343            if idx < 0.0 {
344                idx = 0.0;
345            } else if idx >= len {
346                idx = len - 1.0;
347            }
348        }
349        BorderMode::Mirror => {
350            let s2 = 2.0 * len - 2.0;
351            if idx < 0.0 {
352                idx = s2 * (-idx / s2).floor() + idx;
353                idx = if idx <= 1.0 - len { idx + s2 } else { -idx };
354            } else if idx >= len {
355                idx -= s2 * (idx / s2).floor();
356                if idx >= len {
357                    idx = s2 - idx;
358                }
359            }
360        }
361        BorderMode::Reflect => {
362            let s2 = 2.0 * len;
363            if idx < 0.0 {
364                if idx < -s2 {
365                    idx = s2 * (-idx / s2).floor() + idx;
366                }
367                idx = if idx < -len { idx + s2 } else { -idx - 1.0 };
368            } else if idx >= len {
369                idx -= s2 * (idx / s2).floor();
370                if idx >= len {
371                    idx = s2 - idx - 1.0;
372                }
373            }
374        }
375        BorderMode::Wrap => {
376            let s = len - 1.0;
377            if idx < 0.0 {
378                idx += s * ((-idx / s).floor() + 1.0);
379            } else if idx >= len {
380                idx -= s * (idx / s).floor();
381            }
382        }
383    };
384    idx
385}