1use std::borrow::Cow;
4
5use ndarray::{
6 s, Array, Array1, ArrayBase, ArrayView1, Axis, AxisDescription, Data, Dimension, Slice, Zip,
7};
8use ndarray_stats::QuantileExt;
9use num_traits::{FromPrimitive, Num, Zero};
10
11use crate::array_like;
12
/// Method used to choose the values written in the padded border of an array.
///
/// The examples show a 1-D array `[1, 2, 3]` padded by two values on each side.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum PadMode<T> {
    /// Pads with the given constant value.
    ///
    /// `[1, 2, 3]` padded with `Constant(0)` → `[0, 0, 1, 2, 3, 0, 0]`
    Constant(T),

    /// Pads with the first/last data value along each axis.
    ///
    /// `[1, 2, 3]` → `[1, 1, 1, 2, 3, 3, 3]`
    Edge,

    /// Pads each lane of the padded axis with the maximum of its data values.
    ///
    /// `[1, 2, 3]` → `[3, 3, 1, 2, 3, 3, 3]`
    Maximum,

    /// Pads each lane of the padded axis with the mean of its data values.
    ///
    /// `[1, 2, 3]` → `[2, 2, 1, 2, 3, 2, 2]`
    Mean,

    /// Pads each lane of the padded axis with the median of its data values.
    ///
    /// `[1, 2, 3]` → `[2, 2, 1, 2, 3, 2, 2]`
    Median,

    /// Pads each lane of the padded axis with the minimum of its data values.
    ///
    /// `[1, 2, 3]` → `[1, 1, 1, 2, 3, 1, 1]`
    Minimum,

    /// Pads with the data mirrored around its first/last value; the edge value is not repeated.
    ///
    /// `[1, 2, 3]` → `[3, 2, 1, 2, 3, 2, 1]`
    Reflect,

    /// Pads with the data mirrored around its outer edge; the edge value is repeated.
    ///
    /// `[1, 2, 3]` → `[2, 1, 1, 2, 3, 3, 2]`
    Symmetric,

    /// Pads with the data taken from the opposite side, as if the axis wrapped around.
    ///
    /// `[1, 2, 3]` → `[2, 3, 1, 2, 3, 1, 2]`
    Wrap,
}
63
64impl<T: PartialEq> PadMode<T> {
65 pub(crate) fn init(&self) -> T
66 where
67 T: Copy + Zero,
68 {
69 match *self {
70 PadMode::Constant(init) => init,
71 _ => T::zero(),
72 }
73 }
74
75 fn action(&self) -> PadAction {
76 match *self {
77 PadMode::Constant(_) => PadAction::StopAfterCopy,
78 PadMode::Maximum | PadMode::Mean | PadMode::Median | PadMode::Minimum => {
79 PadAction::ByLane
80 }
81 PadMode::Reflect | PadMode::Symmetric => PadAction::ByReflecting,
82 PadMode::Wrap => PadAction::ByWrapping,
83 PadMode::Edge => PadAction::BySides,
84 }
85 }
86
87 fn dynamic_value(&self, lane: ArrayView1<T>, buffer: &mut Array1<T>) -> T
88 where
89 T: Clone + Copy + FromPrimitive + Num + PartialOrd,
90 {
91 match *self {
92 PadMode::Minimum => *lane.min().expect("Can't find min because of NaN values"),
93 PadMode::Mean => lane.mean().expect("Can't find mean because of NaN values"),
94 PadMode::Median => {
95 buffer.assign(&lane);
96 buffer.as_slice_mut().unwrap().sort_unstable_by(|a, b| {
97 a.partial_cmp(b).expect("Can't find median because of NaN values")
98 });
99 let n = buffer.len();
100 let h = (n - 1) / 2;
101 if n & 1 > 0 {
102 buffer[h]
103 } else {
104 (buffer[h] + buffer[h + 1]) / T::from_u32(2).unwrap()
105 }
106 }
107 PadMode::Maximum => *lane.max().expect("Can't find max because of NaN values"),
108 _ => panic!("Only Minimum, Median and Maximum have a dynamic value"),
109 }
110 }
111
112 fn needs_buffer(&self) -> bool {
113 *self == PadMode::Median
114 }
115}
116
/// How `pad_to` fills the border once the data has been copied into the output.
#[derive(PartialEq)]
enum PadAction {
    /// Used by `PadMode::Constant`: the padded values don't depend on the data.
    StopAfterCopy,
    /// Each 1-D lane along the padded axis gets a statistic of its data values
    /// (`Maximum`, `Mean`, `Median`, `Minimum`).
    ByLane,
    /// The data is mirrored into the border (`Reflect`, `Symmetric`).
    ByReflecting,
    /// The border is copied from the opposite side of the data (`Wrap`).
    ByWrapping,
    /// The first/last data value of each lane is repeated (`Edge`).
    BySides,
}
125
126pub fn pad<S, A, D>(data: &ArrayBase<S, D>, pad: &[[usize; 2]], mode: PadMode<A>) -> Array<A, D>
133where
134 S: Data<Elem = A>,
135 A: Copy + FromPrimitive + Num + PartialOrd,
136 D: Dimension,
137{
138 let pad = read_pad(data.ndim(), pad);
139 let mut new_dim = data.raw_dim();
140 for (ax, (&ax_len, pad)) in data.shape().iter().zip(pad.iter()).enumerate() {
141 new_dim[ax] = ax_len + pad[0] + pad[1];
142 }
143
144 let mut padded = array_like(&data, new_dim, mode.init());
145 pad_to(data, &pad, mode, &mut padded);
146 padded
147}
148
149pub fn pad_to<S, A, D>(
159 data: &ArrayBase<S, D>,
160 pad: &[[usize; 2]],
161 mode: PadMode<A>,
162 output: &mut Array<A, D>,
163) where
164 S: Data<Elem = A>,
165 A: Copy + FromPrimitive + Num + PartialOrd,
166 D: Dimension,
167{
168 let pad = read_pad(data.ndim(), pad);
169
170 output
172 .slice_each_axis_mut(|ad| {
173 let AxisDescription { axis, len, .. } = ad;
174 let pad = pad[axis.index()];
175 Slice::from(pad[0]..len - pad[1])
176 })
177 .assign(data);
178
179 match mode.action() {
180 PadAction::StopAfterCopy => { }
181 PadAction::ByReflecting => {
182 let edge_offset = match mode {
183 PadMode::Reflect => 1,
184 PadMode::Symmetric => 0,
185 _ => unreachable!(),
186 };
187 for d in 0..data.ndim() {
188 let pad = pad[d];
189 let d = Axis(d);
190
191 let (mut left, rest) = output.view_mut().split_at(d, pad[0]);
192 left.assign(&rest.slice_each_axis(|ad| {
193 if ad.axis == d {
194 Slice::from(edge_offset..edge_offset + pad[0]).step_by(-1)
195 } else {
196 Slice::from(..)
197 }
198 }));
199
200 let idx = output.len_of(d) - pad[1];
201 let (rest, mut right) = output.view_mut().split_at(d, idx);
202 right.assign(&rest.slice_each_axis(|ad| {
203 let AxisDescription { axis, len, .. } = ad;
204 if axis == d {
205 Slice::from(len - pad[1] - edge_offset..len - edge_offset).step_by(-1)
206 } else {
207 Slice::from(..)
208 }
209 }));
210 }
211 }
212 PadAction::ByWrapping => {
213 for d in 0..data.ndim() {
214 let pad = pad[d];
215 let d = Axis(d);
216
217 let (mut left, rest) = output.view_mut().split_at(d, pad[0]);
218 left.assign(&rest.slice_each_axis(|ad| {
219 let AxisDescription { axis, len, .. } = ad;
220 if axis == d {
221 Slice::from(len - pad[0] - pad[1]..len - pad[1])
222 } else {
223 Slice::from(..)
224 }
225 }));
226
227 let idx = output.len_of(d) - pad[1];
228 let (rest, mut right) = output.view_mut().split_at(d, idx);
229 right.assign(&rest.slice_each_axis(|ad| {
230 if ad.axis == d {
231 Slice::from(pad[0]..pad[0] + pad[1])
232 } else {
233 Slice::from(..)
234 }
235 }));
236 }
237 }
238 PadAction::ByLane => {
239 for d in 0..data.ndim() {
240 let start = pad[d][0];
241 let end = start + data.shape()[d];
242 let data_zone = s![start..end];
243 let real_end = output.shape()[d];
244 let mut buffer =
245 if mode.needs_buffer() { Array1::zeros(end - start) } else { Array1::zeros(0) };
246 Zip::from(output.lanes_mut(Axis(d))).for_each(|mut lane| {
247 let v = mode.dynamic_value(lane.slice(data_zone), &mut buffer);
248 for i in 0..start {
249 lane[i] = v;
250 }
251 for i in end..real_end {
252 lane[i] = v;
253 }
254 });
255 }
256 }
257 PadAction::BySides => {
258 for d in 0..data.ndim() {
259 let start = pad[d][0];
260 let end = start + data.shape()[d];
261 let real_end = output.shape()[d];
262 Zip::from(output.lanes_mut(Axis(d))).for_each(|mut lane| {
263 let left = lane[start];
264 let right = lane[end - 1];
265 for i in 0..start {
266 lane[i] = left;
267 }
268 for i in end..real_end {
269 lane[i] = right;
270 }
271 });
272 }
273 }
274 }
275}
276
/// Normalize `pad` to exactly one `[before, after]` pair per axis.
///
/// A single pair given for a multi-dimensional array is repeated for every axis (allocating).
/// A slice that already has one pair per axis is borrowed as-is.
///
/// # Panics
///
/// If `pad` has neither a single pair nor exactly `nb_dim` pairs.
fn read_pad(nb_dim: usize, pad: &[[usize; 2]]) -> Cow<[[usize; 2]]> {
    match pad {
        [single] if nb_dim > 1 => Cow::Owned(vec![*single; nb_dim]),
        _ if pad.len() == nb_dim => Cow::Borrowed(pad),
        _ => panic!("Inconsistant number of dimensions and pad arrays"),
    }
}
286}