1use std::borrow::Cow;
4
5use ndarray::{
6 s, Array, Array1, ArrayRef, ArrayView1, Axis, AxisDescription, Dimension, Slice, Zip,
7};
8use ndarray_stats::QuantileExt;
9use num_traits::{FromPrimitive, Num, Zero};
10
11use crate::array_like;
12
/// Strategy used to fill the border added around an array when padding it.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum PadMode<T> {
    /// Fills the border with the given constant value.
    Constant(T),

    /// Fills the border of each lane with the first / last data value of that lane.
    Edge,

    /// Fills the border of each lane with the maximum of the lane's data values.
    Maximum,

    /// Fills the border of each lane with the mean of the lane's data values.
    Mean,

    /// Fills the border of each lane with the median of the lane's data values.
    Median,

    /// Fills the border of each lane with the minimum of the lane's data values.
    Minimum,

    /// Mirrors the data around the edge value, without repeating the edge value itself.
    Reflect,

    /// Mirrors the data around the edge, repeating the edge value.
    Symmetric,

    /// Wraps around: the border before the data is filled with the end of the data and the
    /// border after it with the beginning of the data.
    Wrap,
}
63
64impl<T: PartialEq> PadMode<T> {
65 pub(crate) fn init(&self) -> T
66 where
67 T: Copy + Zero,
68 {
69 match *self {
70 PadMode::Constant(init) => init,
71 _ => T::zero(),
72 }
73 }
74
75 fn action(&self) -> PadAction {
76 match *self {
77 PadMode::Constant(_) => PadAction::StopAfterCopy,
78 PadMode::Maximum | PadMode::Mean | PadMode::Median | PadMode::Minimum => {
79 PadAction::ByLane
80 }
81 PadMode::Reflect | PadMode::Symmetric => PadAction::ByReflecting,
82 PadMode::Wrap => PadAction::ByWrapping,
83 PadMode::Edge => PadAction::BySides,
84 }
85 }
86
87 fn dynamic_value(&self, lane: ArrayView1<T>, buffer: &mut Array1<T>) -> T
88 where
89 T: Clone + Copy + FromPrimitive + Num + PartialOrd,
90 {
91 match *self {
92 PadMode::Minimum => *lane.min().expect("Can't find min because of NaN values"),
93 PadMode::Mean => lane.mean().expect("Can't find mean because of NaN values"),
94 PadMode::Median => {
95 buffer.assign(&lane);
96 buffer.as_slice_mut().unwrap().sort_unstable_by(|a, b| {
97 a.partial_cmp(b).expect("Can't find median because of NaN values")
98 });
99 let n = buffer.len();
100 let h = (n - 1) / 2;
101 if n & 1 > 0 {
102 buffer[h]
103 } else {
104 (buffer[h] + buffer[h + 1]) / T::from_u32(2).unwrap()
105 }
106 }
107 PadMode::Maximum => *lane.max().expect("Can't find max because of NaN values"),
108 _ => panic!("Only Minimum, Median and Maximum have a dynamic value"),
109 }
110 }
111
112 fn needs_buffer(&self) -> bool {
113 *self == PadMode::Median
114 }
115}
116
/// Internal strategy for filling the border once the data has been copied into the output.
#[derive(PartialEq)]
enum PadAction {
    /// Nothing more to do: the border keeps the values the output already holds.
    StopAfterCopy,
    /// Fill each lane's border with one value computed from the lane's data.
    ByLane,
    /// Fill the border with a mirrored copy of the neighbouring data.
    ByReflecting,
    /// Fill the border with the data taken from the opposite side.
    ByWrapping,
    /// Fill each lane's border with the first / last data value of the lane.
    BySides,
}
125
126pub fn pad<A, D>(data: &ArrayRef<A, D>, pad: &[[usize; 2]], mode: PadMode<A>) -> Array<A, D>
133where
134 A: Copy + FromPrimitive + Num + PartialOrd,
135 D: Dimension,
136{
137 let pad = read_pad(data.ndim(), pad);
138 let mut new_dim = data.raw_dim();
139 for (ax, (&ax_len, pad)) in data.shape().iter().zip(pad.iter()).enumerate() {
140 new_dim[ax] = ax_len + pad[0] + pad[1];
141 }
142
143 let mut padded = array_like(&data, new_dim, mode.init());
144 pad_to(data, &pad, mode, &mut padded);
145 padded
146}
147
148pub fn pad_to<A, D>(
158 data: &ArrayRef<A, D>,
159 pad: &[[usize; 2]],
160 mode: PadMode<A>,
161 output: &mut Array<A, D>,
162) where
163 A: Copy + FromPrimitive + Num + PartialOrd,
164 D: Dimension,
165{
166 let pad = read_pad(data.ndim(), pad);
167
168 output
170 .slice_each_axis_mut(|ad| {
171 let AxisDescription { axis, len, .. } = ad;
172 let pad = pad[axis.index()];
173 Slice::from(pad[0]..len - pad[1])
174 })
175 .assign(data);
176
177 match mode.action() {
178 PadAction::StopAfterCopy => { }
179 PadAction::ByReflecting => {
180 let edge_offset = match mode {
181 PadMode::Reflect => 1,
182 PadMode::Symmetric => 0,
183 _ => unreachable!(),
184 };
185 for d in 0..data.ndim() {
186 let pad = pad[d];
187 let d = Axis(d);
188
189 let (mut left, rest) = output.view_mut().split_at(d, pad[0]);
190 left.assign(&rest.slice_each_axis(|ad| {
191 if ad.axis == d {
192 Slice::from(edge_offset..edge_offset + pad[0]).step_by(-1)
193 } else {
194 Slice::from(..)
195 }
196 }));
197
198 let idx = output.len_of(d) - pad[1];
199 let (rest, mut right) = output.view_mut().split_at(d, idx);
200 right.assign(&rest.slice_each_axis(|ad| {
201 let AxisDescription { axis, len, .. } = ad;
202 if axis == d {
203 Slice::from(len - pad[1] - edge_offset..len - edge_offset).step_by(-1)
204 } else {
205 Slice::from(..)
206 }
207 }));
208 }
209 }
210 PadAction::ByWrapping => {
211 for d in 0..data.ndim() {
212 let pad = pad[d];
213 let d = Axis(d);
214
215 let (mut left, rest) = output.view_mut().split_at(d, pad[0]);
216 left.assign(&rest.slice_each_axis(|ad| {
217 let AxisDescription { axis, len, .. } = ad;
218 if axis == d {
219 Slice::from(len - pad[0] - pad[1]..len - pad[1])
220 } else {
221 Slice::from(..)
222 }
223 }));
224
225 let idx = output.len_of(d) - pad[1];
226 let (rest, mut right) = output.view_mut().split_at(d, idx);
227 right.assign(&rest.slice_each_axis(|ad| {
228 if ad.axis == d {
229 Slice::from(pad[0]..pad[0] + pad[1])
230 } else {
231 Slice::from(..)
232 }
233 }));
234 }
235 }
236 PadAction::ByLane => {
237 for d in 0..data.ndim() {
238 let start = pad[d][0];
239 let end = start + data.shape()[d];
240 let data_zone = s![start..end];
241 let real_end = output.shape()[d];
242 let mut buffer =
243 if mode.needs_buffer() { Array1::zeros(end - start) } else { Array1::zeros(0) };
244 Zip::from(output.lanes_mut(Axis(d))).for_each(|mut lane| {
245 let v = mode.dynamic_value(lane.slice(data_zone), &mut buffer);
246 for i in 0..start {
247 lane[i] = v;
248 }
249 for i in end..real_end {
250 lane[i] = v;
251 }
252 });
253 }
254 }
255 PadAction::BySides => {
256 for d in 0..data.ndim() {
257 let start = pad[d][0];
258 let end = start + data.shape()[d];
259 let real_end = output.shape()[d];
260 Zip::from(output.lanes_mut(Axis(d))).for_each(|mut lane| {
261 let left = lane[start];
262 let right = lane[end - 1];
263 for i in 0..start {
264 lane[i] = left;
265 }
266 for i in end..real_end {
267 lane[i] = right;
268 }
269 });
270 }
271 }
272 }
273}
274
/// Normalises `pad` so that it holds exactly one `[before, after]` pair per dimension.
///
/// A single pair is repeated for every dimension when `nb_dim` is greater than 1; a slice that
/// already has `nb_dim` pairs is borrowed as is.
///
/// # Panics
///
/// Panics if `pad` has neither 1 nor `nb_dim` pairs.
fn read_pad(nb_dim: usize, pad: &[[usize; 2]]) -> Cow<'_, [[usize; 2]]> {
    match pad {
        [single] if nb_dim > 1 => Cow::Owned(vec![*single; nb_dim]),
        _ if pad.len() == nb_dim => Cow::Borrowed(pad),
        _ => panic!("Inconsistant number of dimensions and pad arrays"),
    }
}