concision_core/ops/pad/utils.rs

1/*
2    Appellation: utils <module>
3    Contrib: FL03 <jo3mccain@icloud.com>
4*/
5use super::{PadAction, PadError, PadMode};
6use crate::ArrayLike;
7use ndarray::{Array, ArrayBase, AxisDescription, Data, DataOwned, Dimension, Slice};
8use num::{FromPrimitive, Num};
9
10fn reader(nb_dim: usize, pad: &[[usize; 2]]) -> Result<Vec<[usize; 2]>, PadError> {
11    if pad.len() == 1 && pad.len() < nb_dim {
12        // The user provided a single padding for all dimensions
13        Ok(vec![pad[0]; nb_dim])
14    } else if pad.len() == nb_dim {
15        Ok(pad.to_vec())
16    } else {
17        Err(PadError::InconsistentDimensions)
18    }
19}
20
21pub fn pad<A, S, D>(
22    data: &ArrayBase<S, D>,
23    pad: &[[usize; 2]],
24    mode: PadMode<A>,
25) -> Result<Array<A, D>, PadError>
26where
27    A: Copy + FromPrimitive + Num,
28    D: Dimension,
29    S: DataOwned<Elem = A>,
30{
31    let pad = reader(data.ndim(), pad)?;
32    let mut dim = data.raw_dim();
33    for (ax, (&ax_len, pad)) in data.shape().iter().zip(pad.iter()).enumerate() {
34        dim[ax] = ax_len + pad[0] + pad[1];
35    }
36
37    let mut padded = data.array_like(dim, mode.init()).to_owned();
38    pad_to(data, &pad, mode, &mut padded)?;
39    Ok(padded)
40}
41
42pub fn pad_to<A, S, D>(
43    data: &ArrayBase<S, D>,
44    pad: &[[usize; 2]],
45    mode: PadMode<A>,
46    output: &mut Array<A, D>,
47) -> super::PadResult
48where
49    A: Copy + FromPrimitive + Num,
50    D: Dimension,
51    S: Data<Elem = A>,
52{
53    let pad = reader(data.ndim(), pad)?;
54
55    // Select portion of padded array that needs to be copied from the original array.
56    output
57        .slice_each_axis_mut(|ad| {
58            let AxisDescription { axis, len, .. } = ad;
59            let pad = pad[axis.index()];
60            Slice::from(pad[0]..len - pad[1])
61        })
62        .assign(data);
63
64    match mode.into_pad_action() {
65        PadAction::StopAfterCopy => {
66            // Do nothing
67            Ok(())
68        }
69        _ => unimplemented!(),
70    }
71}