ndarray_ndimage/
pad.rs

//! This module defines padding methods for N-dimensional images.
2
3use std::borrow::Cow;
4
5use ndarray::{
6    s, Array, Array1, ArrayBase, ArrayView1, Axis, AxisDescription, Data, Dimension, Slice, Zip,
7};
8use ndarray_stats::QuantileExt;
9use num_traits::{FromPrimitive, Num, Zero};
10
11use crate::array_like;
12
/// Strategy used to choose the values written into the padded border of an array.
///
/// Every example below pads the 1-D array `[1, 2, 3]` with two values on each side. For N-D
/// arrays the padding is applied along each axis in turn, lane by lane.
#[derive(Copy, Clone, Debug, PartialEq)]
pub enum PadMode<T> {
    /// Fills the border with the given constant value.
    ///
    /// `[1, 2, 3] -> [T, T, 1, 2, 3, T, T]`
    Constant(T),

    /// Repeats the first value of the lane before it and the last value after it.
    ///
    /// `[1, 2, 3] -> [1, 1, 1, 2, 3, 3, 3]`
    Edge,

    /// Fills the border with the maximum of the (non-padded) lane along the padded axis.
    ///
    /// `[1, 2, 3] -> [3, 3, 1, 2, 3, 3, 3]`
    Maximum,

    /// Fills the border with the mean of the (non-padded) lane along the padded axis.
    ///
    /// `[1, 2, 3] -> [2, 2, 1, 2, 3, 2, 2]`
    Mean,

    /// Fills the border with the median of the (non-padded) lane along the padded axis.
    ///
    /// `[1, 2, 3] -> [2, 2, 1, 2, 3, 2, 2]`
    Median,

    /// Fills the border with the minimum of the (non-padded) lane along the padded axis.
    ///
    /// `[1, 2, 3] -> [1, 1, 1, 2, 3, 1, 1]`
    Minimum,

    /// Mirrors the lane around its first and last values, without repeating them.
    ///
    /// `[1, 2, 3] -> [3, 2, 1, 2, 3, 2, 1]`
    Reflect,

    /// Mirrors the lane around its edges, so the first and last values are repeated.
    ///
    /// `[1, 2, 3] -> [2, 1, 1, 2, 3, 3, 2]`
    Symmetric,

    /// Wraps the lane around: its end pads the beginning and its beginning pads the end.
    ///
    /// `[1, 2, 3] -> [2, 3, 1, 2, 3, 1, 2]`
    Wrap,
}
63
64impl<T: PartialEq> PadMode<T> {
65    pub(crate) fn init(&self) -> T
66    where
67        T: Copy + Zero,
68    {
69        match *self {
70            PadMode::Constant(init) => init,
71            _ => T::zero(),
72        }
73    }
74
75    fn action(&self) -> PadAction {
76        match *self {
77            PadMode::Constant(_) => PadAction::StopAfterCopy,
78            PadMode::Maximum | PadMode::Mean | PadMode::Median | PadMode::Minimum => {
79                PadAction::ByLane
80            }
81            PadMode::Reflect | PadMode::Symmetric => PadAction::ByReflecting,
82            PadMode::Wrap => PadAction::ByWrapping,
83            PadMode::Edge => PadAction::BySides,
84        }
85    }
86
87    fn dynamic_value(&self, lane: ArrayView1<T>, buffer: &mut Array1<T>) -> T
88    where
89        T: Clone + Copy + FromPrimitive + Num + PartialOrd,
90    {
91        match *self {
92            PadMode::Minimum => *lane.min().expect("Can't find min because of NaN values"),
93            PadMode::Mean => lane.mean().expect("Can't find mean because of NaN values"),
94            PadMode::Median => {
95                buffer.assign(&lane);
96                buffer.as_slice_mut().unwrap().sort_unstable_by(|a, b| {
97                    a.partial_cmp(b).expect("Can't find median because of NaN values")
98                });
99                let n = buffer.len();
100                let h = (n - 1) / 2;
101                if n & 1 > 0 {
102                    buffer[h]
103                } else {
104                    (buffer[h] + buffer[h + 1]) / T::from_u32(2).unwrap()
105                }
106            }
107            PadMode::Maximum => *lane.max().expect("Can't find max because of NaN values"),
108            _ => panic!("Only Minimum, Median and Maximum have a dynamic value"),
109        }
110    }
111
112    fn needs_buffer(&self) -> bool {
113        *self == PadMode::Median
114    }
115}
116
/// Internal strategy describing how the borders are filled once the original data has been
/// copied into the central region of the output.
#[derive(PartialEq)]
enum PadAction {
    /// Nothing more to do: the borders keep the values already present in the output.
    StopAfterCopy,
    /// Each lane's borders are filled with a statistic computed over that lane.
    ByLane,
    /// The borders are a mirror image of the neighbouring values.
    ByReflecting,
    /// The borders are copied from the opposite end of the data.
    ByWrapping,
    /// The borders repeat the outermost data value on their side.
    BySides,
}
125
126/// Pad an image.
127///
128/// * `data` - A N-D array of the data to pad.
129/// * `pad` - Number of values padded to the edges of each axis.
130/// * `mode` - Method that will be used to select the padded values. See the
131///   [`PadMode`](crate::PadMode) enum for more information.
132pub fn pad<S, A, D>(data: &ArrayBase<S, D>, pad: &[[usize; 2]], mode: PadMode<A>) -> Array<A, D>
133where
134    S: Data<Elem = A>,
135    A: Copy + FromPrimitive + Num + PartialOrd,
136    D: Dimension,
137{
138    let pad = read_pad(data.ndim(), pad);
139    let mut new_dim = data.raw_dim();
140    for (ax, (&ax_len, pad)) in data.shape().iter().zip(pad.iter()).enumerate() {
141        new_dim[ax] = ax_len + pad[0] + pad[1];
142    }
143
144    let mut padded = array_like(&data, new_dim, mode.init());
145    pad_to(data, &pad, mode, &mut padded);
146    padded
147}
148
149/// Pad an image.
150///
151/// Write the result in the already_allocated array `output`.
152///
153/// * `data` - A N-D array of the data to pad.
154/// * `pad` - Number of values padded to the edges of each axis.
155/// * `mode` - Method that will be used to select the padded values. See the
156///   [`PadMode`](crate::PadMode) enum for more information.
157/// * `output` - An already allocated N-D array used to write the results.
158pub fn pad_to<S, A, D>(
159    data: &ArrayBase<S, D>,
160    pad: &[[usize; 2]],
161    mode: PadMode<A>,
162    output: &mut Array<A, D>,
163) where
164    S: Data<Elem = A>,
165    A: Copy + FromPrimitive + Num + PartialOrd,
166    D: Dimension,
167{
168    let pad = read_pad(data.ndim(), pad);
169
170    // Select portion of padded array that needs to be copied from the original array.
171    output
172        .slice_each_axis_mut(|ad| {
173            let AxisDescription { axis, len, .. } = ad;
174            let pad = pad[axis.index()];
175            Slice::from(pad[0]..len - pad[1])
176        })
177        .assign(data);
178
179    match mode.action() {
180        PadAction::StopAfterCopy => { /* Nothing */ }
181        PadAction::ByReflecting => {
182            let edge_offset = match mode {
183                PadMode::Reflect => 1,
184                PadMode::Symmetric => 0,
185                _ => unreachable!(),
186            };
187            for d in 0..data.ndim() {
188                let pad = pad[d];
189                let d = Axis(d);
190
191                let (mut left, rest) = output.view_mut().split_at(d, pad[0]);
192                left.assign(&rest.slice_each_axis(|ad| {
193                    if ad.axis == d {
194                        Slice::from(edge_offset..edge_offset + pad[0]).step_by(-1)
195                    } else {
196                        Slice::from(..)
197                    }
198                }));
199
200                let idx = output.len_of(d) - pad[1];
201                let (rest, mut right) = output.view_mut().split_at(d, idx);
202                right.assign(&rest.slice_each_axis(|ad| {
203                    let AxisDescription { axis, len, .. } = ad;
204                    if axis == d {
205                        Slice::from(len - pad[1] - edge_offset..len - edge_offset).step_by(-1)
206                    } else {
207                        Slice::from(..)
208                    }
209                }));
210            }
211        }
212        PadAction::ByWrapping => {
213            for d in 0..data.ndim() {
214                let pad = pad[d];
215                let d = Axis(d);
216
217                let (mut left, rest) = output.view_mut().split_at(d, pad[0]);
218                left.assign(&rest.slice_each_axis(|ad| {
219                    let AxisDescription { axis, len, .. } = ad;
220                    if axis == d {
221                        Slice::from(len - pad[0] - pad[1]..len - pad[1])
222                    } else {
223                        Slice::from(..)
224                    }
225                }));
226
227                let idx = output.len_of(d) - pad[1];
228                let (rest, mut right) = output.view_mut().split_at(d, idx);
229                right.assign(&rest.slice_each_axis(|ad| {
230                    if ad.axis == d {
231                        Slice::from(pad[0]..pad[0] + pad[1])
232                    } else {
233                        Slice::from(..)
234                    }
235                }));
236            }
237        }
238        PadAction::ByLane => {
239            for d in 0..data.ndim() {
240                let start = pad[d][0];
241                let end = start + data.shape()[d];
242                let data_zone = s![start..end];
243                let real_end = output.shape()[d];
244                let mut buffer =
245                    if mode.needs_buffer() { Array1::zeros(end - start) } else { Array1::zeros(0) };
246                Zip::from(output.lanes_mut(Axis(d))).for_each(|mut lane| {
247                    let v = mode.dynamic_value(lane.slice(data_zone), &mut buffer);
248                    for i in 0..start {
249                        lane[i] = v;
250                    }
251                    for i in end..real_end {
252                        lane[i] = v;
253                    }
254                });
255            }
256        }
257        PadAction::BySides => {
258            for d in 0..data.ndim() {
259                let start = pad[d][0];
260                let end = start + data.shape()[d];
261                let real_end = output.shape()[d];
262                Zip::from(output.lanes_mut(Axis(d))).for_each(|mut lane| {
263                    let left = lane[start];
264                    let right = lane[end - 1];
265                    for i in 0..start {
266                        lane[i] = left;
267                    }
268                    for i in end..real_end {
269                        lane[i] = right;
270                    }
271                });
272            }
273        }
274    }
275}
276
/// Normalizes the user-provided padding to exactly one `[before, after]` pair per axis.
///
/// A single pair given for a multi-dimensional array is repeated for every axis (allocating).
/// A slice that already has one pair per axis is borrowed as-is.
///
/// # Panics
///
/// If `pad` has neither one pair per axis nor a single pair for a multi-dimensional array.
fn read_pad(nb_dim: usize, pad: &[[usize; 2]]) -> Cow<'_, [[usize; 2]]> {
    match pad {
        // One pair shared by every axis.
        [shared] if nb_dim > 1 => Cow::Owned(vec![*shared; nb_dim]),
        _ if pad.len() == nb_dim => Cow::Borrowed(pad),
        _ => panic!("Inconsistant number of dimensions and pad arrays"),
    }
}