Skip to main content

ndarray_ndimage/
pad.rs

//! This module defines padding methods for N-D images.
2
3use std::borrow::Cow;
4
5use ndarray::{
6    s, Array, Array1, ArrayRef, ArrayView1, Axis, AxisDescription, Dimension, Slice, Zip,
7};
8use ndarray_stats::QuantileExt;
9use num_traits::{FromPrimitive, Num, Zero};
10
11use crate::array_like;
12
/// Method that will be used to select the padded values.
///
/// Each example shows `[1, 2, 3]` padded with two values on both sides.
#[derive(Copy, Clone, Debug, PartialEq)]
pub enum PadMode<T> {
    /// Fills the padding with the given constant value.
    ///
    /// `[1, 2, 3] -> [T, T, 1, 2, 3, T, T]`
    Constant(T),

    /// Repeats the first value before the data and the last value after it.
    ///
    /// `[1, 2, 3] -> [1, 1, 1, 2, 3, 3, 3]`
    Edge,

    /// Fills the padding with the largest data value along each axis.
    ///
    /// `[1, 2, 3] -> [3, 3, 1, 2, 3, 3, 3]`
    Maximum,

    /// Fills the padding with the mean of the data values along each axis.
    ///
    /// `[1, 2, 3] -> [2, 2, 1, 2, 3, 2, 2]`
    Mean,

    /// Fills the padding with the median of the data values along each axis.
    ///
    /// `[1, 2, 3] -> [2, 2, 1, 2, 3, 2, 2]`
    Median,

    /// Fills the padding with the smallest data value along each axis.
    ///
    /// `[1, 2, 3] -> [1, 1, 1, 2, 3, 1, 1]`
    Minimum,

    /// Mirrors the data around its first and last values; those edge values are not repeated.
    ///
    /// `[1, 2, 3] -> [3, 2, 1, 2, 3, 2, 1]`
    Reflect,

    /// Mirrors the data around the array boundary, so the edge values are repeated once.
    ///
    /// `[1, 2, 3] -> [2, 1, 1, 2, 3, 3, 2]`
    Symmetric,

    /// Repeats the data periodically: the values at the start pad the end, and the values at
    /// the end pad the start.
    ///
    /// `[1, 2, 3] -> [2, 3, 1, 2, 3, 1, 2]`
    Wrap,
}
63
64impl<T: PartialEq> PadMode<T> {
65    pub(crate) fn init(&self) -> T
66    where
67        T: Copy + Zero,
68    {
69        match *self {
70            PadMode::Constant(init) => init,
71            _ => T::zero(),
72        }
73    }
74
75    fn action(&self) -> PadAction {
76        match *self {
77            PadMode::Constant(_) => PadAction::StopAfterCopy,
78            PadMode::Maximum | PadMode::Mean | PadMode::Median | PadMode::Minimum => {
79                PadAction::ByLane
80            }
81            PadMode::Reflect | PadMode::Symmetric => PadAction::ByReflecting,
82            PadMode::Wrap => PadAction::ByWrapping,
83            PadMode::Edge => PadAction::BySides,
84        }
85    }
86
87    fn dynamic_value(&self, lane: ArrayView1<T>, buffer: &mut Array1<T>) -> T
88    where
89        T: Clone + Copy + FromPrimitive + Num + PartialOrd,
90    {
91        match *self {
92            PadMode::Minimum => *lane.min().expect("Can't find min because of NaN values"),
93            PadMode::Mean => lane.mean().expect("Can't find mean because of NaN values"),
94            PadMode::Median => {
95                buffer.assign(&lane);
96                buffer.as_slice_mut().unwrap().sort_unstable_by(|a, b| {
97                    a.partial_cmp(b).expect("Can't find median because of NaN values")
98                });
99                let n = buffer.len();
100                let h = (n - 1) / 2;
101                if n & 1 > 0 {
102                    buffer[h]
103                } else {
104                    (buffer[h] + buffer[h + 1]) / T::from_u32(2).unwrap()
105                }
106            }
107            PadMode::Maximum => *lane.max().expect("Can't find max because of NaN values"),
108            _ => panic!("Only Minimum, Median and Maximum have a dynamic value"),
109        }
110    }
111
112    fn needs_buffer(&self) -> bool {
113        *self == PadMode::Median
114    }
115}
116
/// How the padded region is filled once the data has been copied, derived from a `PadMode`.
#[derive(PartialEq)]
enum PadAction {
    /// `Constant`: no value is derived from the data.
    StopAfterCopy,
    /// Statistic modes: every lane is padded with one value computed from its data.
    ByLane,
    /// `Reflect` and `Symmetric`: the padding mirrors the data.
    ByReflecting,
    /// `Wrap`: the padding repeats the data periodically.
    ByWrapping,
    /// `Edge`: the padding repeats the first/last data values.
    BySides,
}
125
126/// Pad an image.
127///
128/// * `data` - A N-D array of the data to pad.
129/// * `pad` - Number of values padded to the edges of each axis.
130/// * `mode` - Method that will be used to select the padded values. See the
131///   [`PadMode`](crate::PadMode) enum for more information.
132pub fn pad<A, D>(data: &ArrayRef<A, D>, pad: &[[usize; 2]], mode: PadMode<A>) -> Array<A, D>
133where
134    A: Copy + FromPrimitive + Num + PartialOrd,
135    D: Dimension,
136{
137    let pad = read_pad(data.ndim(), pad);
138    let mut new_dim = data.raw_dim();
139    for (ax, (&ax_len, pad)) in data.shape().iter().zip(pad.iter()).enumerate() {
140        new_dim[ax] = ax_len + pad[0] + pad[1];
141    }
142
143    let mut padded = array_like(&data, new_dim, mode.init());
144    pad_to(data, &pad, mode, &mut padded);
145    padded
146}
147
148/// Pad an image.
149///
150/// Write the result in the already_allocated array `output`.
151///
152/// * `data` - A N-D array of the data to pad.
153/// * `pad` - Number of values padded to the edges of each axis.
154/// * `mode` - Method that will be used to select the padded values. See the
155///   [`PadMode`](crate::PadMode) enum for more information.
156/// * `output` - An already allocated N-D array used to write the results.
157pub fn pad_to<A, D>(
158    data: &ArrayRef<A, D>,
159    pad: &[[usize; 2]],
160    mode: PadMode<A>,
161    output: &mut Array<A, D>,
162) where
163    A: Copy + FromPrimitive + Num + PartialOrd,
164    D: Dimension,
165{
166    let pad = read_pad(data.ndim(), pad);
167
168    // Select portion of padded array that needs to be copied from the original array.
169    output
170        .slice_each_axis_mut(|ad| {
171            let AxisDescription { axis, len, .. } = ad;
172            let pad = pad[axis.index()];
173            Slice::from(pad[0]..len - pad[1])
174        })
175        .assign(data);
176
177    match mode.action() {
178        PadAction::StopAfterCopy => { /* Nothing */ }
179        PadAction::ByReflecting => {
180            let edge_offset = match mode {
181                PadMode::Reflect => 1,
182                PadMode::Symmetric => 0,
183                _ => unreachable!(),
184            };
185            for d in 0..data.ndim() {
186                let pad = pad[d];
187                let d = Axis(d);
188
189                let (mut left, rest) = output.view_mut().split_at(d, pad[0]);
190                left.assign(&rest.slice_each_axis(|ad| {
191                    if ad.axis == d {
192                        Slice::from(edge_offset..edge_offset + pad[0]).step_by(-1)
193                    } else {
194                        Slice::from(..)
195                    }
196                }));
197
198                let idx = output.len_of(d) - pad[1];
199                let (rest, mut right) = output.view_mut().split_at(d, idx);
200                right.assign(&rest.slice_each_axis(|ad| {
201                    let AxisDescription { axis, len, .. } = ad;
202                    if axis == d {
203                        Slice::from(len - pad[1] - edge_offset..len - edge_offset).step_by(-1)
204                    } else {
205                        Slice::from(..)
206                    }
207                }));
208            }
209        }
210        PadAction::ByWrapping => {
211            for d in 0..data.ndim() {
212                let pad = pad[d];
213                let d = Axis(d);
214
215                let (mut left, rest) = output.view_mut().split_at(d, pad[0]);
216                left.assign(&rest.slice_each_axis(|ad| {
217                    let AxisDescription { axis, len, .. } = ad;
218                    if axis == d {
219                        Slice::from(len - pad[0] - pad[1]..len - pad[1])
220                    } else {
221                        Slice::from(..)
222                    }
223                }));
224
225                let idx = output.len_of(d) - pad[1];
226                let (rest, mut right) = output.view_mut().split_at(d, idx);
227                right.assign(&rest.slice_each_axis(|ad| {
228                    if ad.axis == d {
229                        Slice::from(pad[0]..pad[0] + pad[1])
230                    } else {
231                        Slice::from(..)
232                    }
233                }));
234            }
235        }
236        PadAction::ByLane => {
237            for d in 0..data.ndim() {
238                let start = pad[d][0];
239                let end = start + data.shape()[d];
240                let data_zone = s![start..end];
241                let real_end = output.shape()[d];
242                let mut buffer =
243                    if mode.needs_buffer() { Array1::zeros(end - start) } else { Array1::zeros(0) };
244                Zip::from(output.lanes_mut(Axis(d))).for_each(|mut lane| {
245                    let v = mode.dynamic_value(lane.slice(data_zone), &mut buffer);
246                    for i in 0..start {
247                        lane[i] = v;
248                    }
249                    for i in end..real_end {
250                        lane[i] = v;
251                    }
252                });
253            }
254        }
255        PadAction::BySides => {
256            for d in 0..data.ndim() {
257                let start = pad[d][0];
258                let end = start + data.shape()[d];
259                let real_end = output.shape()[d];
260                Zip::from(output.lanes_mut(Axis(d))).for_each(|mut lane| {
261                    let left = lane[start];
262                    let right = lane[end - 1];
263                    for i in 0..start {
264                        lane[i] = left;
265                    }
266                    for i in end..real_end {
267                        lane[i] = right;
268                    }
269                });
270            }
271        }
272    }
273}
274
/// Expands the user-provided padding into one `[before, after]` pair per axis.
///
/// A single pair is broadcast to every axis of an array with more than one dimension.
/// Otherwise the number of pairs must equal `nb_dim`, and the input is borrowed unchanged.
///
/// # Panics
///
/// Panics if the number of pairs can't be matched with `nb_dim`.
fn read_pad(nb_dim: usize, pad: &[[usize; 2]]) -> Cow<'_, [[usize; 2]]> {
    match pad.len() {
        1 if nb_dim > 1 => Cow::Owned(vec![pad[0]; nb_dim]),
        given if given == nb_dim => Cow::Borrowed(pad),
        _ => panic!("Inconsistant number of dimensions and pad arrays"),
    }
}