ndarray_conv/conv/
mod.rs

1//! Provides convolution operations for `ndarray` arrays.
2//! Includes standard convolution and related utilities.
3
4use ndarray::{
5    Array, ArrayBase, ArrayView, Data, Dim, Dimension, IntoDimension, Ix, RawData, RemoveAxis,
6    SliceArg, SliceInfo, SliceInfoElem,
7};
8use num::traits::NumAssign;
9
10use crate::{
11    dilation::{IntoKernelWithDilation, KernelWithDilation},
12    padding::PaddingExt,
13    ConvMode, PaddingMode,
14};
15
16#[cfg(test)]
17mod tests;
18
/// Explicit convolution parameters obtained by unfolding a `ConvMode`.
///
/// The high-level mode has already been resolved against a concrete
/// (possibly dilated) kernel, so the convolution routine can consume
/// these values directly.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ExplicitConv<const N: usize> {
    /// Two padding amounts per axis; their sum is the total padding added
    /// along that axis. NOTE(review): presumably `[leading, trailing]` —
    /// confirm against `PaddingExt::padding`.
    pub padding: [[usize; 2]; N],
    /// Per-axis step between consecutive output positions (in input elements).
    pub strides: [usize; N],
}
27
28impl<const N: usize> ConvMode<N> {
29    pub(crate) fn unfold<S>(self, kernel: &KernelWithDilation<S, N>) -> ExplicitConv<N>
30    where
31        S: ndarray::RawData,
32        Dim<[Ix; N]>: Dimension,
33    {
34        let kernel_dim = kernel.kernel.raw_dim();
35        let kernel_dim: [usize; N] = std::array::from_fn(|i|
36                // k + (k - 1) * (d - 1)
37                kernel_dim[i] * kernel.dilation[i] - kernel.dilation[i] + 1);
38
39        match self {
40            ConvMode::Full => ExplicitConv {
41                padding: std::array::from_fn(|i| [kernel_dim[i] - 1; 2]),
42                strides: [1; N],
43            },
44            ConvMode::Same => ExplicitConv {
45                padding: std::array::from_fn(|i| {
46                    let k_size = kernel_dim[i];
47                    if k_size.is_multiple_of(2) {
48                        [(k_size - 1) / 2 + 1, (k_size - 1) / 2]
49                    } else {
50                        [(k_size - 1) / 2; 2]
51                    }
52                }),
53                strides: [1; N],
54            },
55            ConvMode::Valid => ExplicitConv {
56                padding: [[0; 2]; N],
57                strides: [1; N],
58            },
59            ConvMode::Custom { padding, strides } => ExplicitConv {
60                padding: padding.map(|pad| [pad; 2]),
61                strides,
62            },
63            ConvMode::Explicit { padding, strides } => ExplicitConv { padding, strides },
64        }
65    }
66}
67
68/// Extends `ndarray`'s `ArrayBase` with convolution operations.
69///
70/// This trait adds the `conv` method to `ArrayBase`, enabling
71/// standard convolution operations on N-dimensional arrays.
72///
73/// # Type Parameters
74///
75/// *   `T`: The numeric type of the array elements.
76/// *   `S`: The data storage type of the input array.
77/// *   `SK`: The data storage type of the kernel array.
78pub trait ConvExt<'a, T, S, SK, const N: usize>
79where
80    T: NumAssign + Copy,
81    S: RawData,
82    SK: RawData,
83{
84    /// Performs a standard convolution operation.
85    ///
86    /// This method convolves the input array with a given kernel,
87    /// using the specified convolution mode and padding.
88    ///
89    /// # Arguments
90    ///
91    /// *   `kernel`: The convolution kernel. Can be a reference to an array, or an array with dilation settings created using `with_dilation()`.
92    /// *   `conv_mode`: The convolution mode (`Full`, `Same`, `Valid`, `Custom`, `Explicit`).
93    /// *   `padding_mode`: The padding mode (`Zeros`, `Const`, `Reflect`, `Replicate`, `Circular`, `Custom`, `Explicit`).
94    ///
95    /// # Returns
96    ///
97    /// Returns `Ok(Array<T, Dim<[Ix; N]>>)` containing the convolution result, or an `Err(Error<N>)` if the operation fails
98    /// (e.g., due to incompatible shapes or zero-sized dimensions).
99    ///
100    /// # Example
101    ///
102    /// ```rust
103    /// use ndarray::array;
104    /// use ndarray_conv::{ConvExt, ConvMode, PaddingMode};
105    ///
106    /// let input = array![[1, 2, 3], [4, 5, 6]];
107    /// let kernel = array![[1, 1], [1, 1]];
108    /// let result = input.conv(&kernel, ConvMode::Same, PaddingMode::Zeros).unwrap();
109    /// ```
110    fn conv(
111        &self,
112        kernel: impl IntoKernelWithDilation<'a, SK, N>,
113        conv_mode: ConvMode<N>,
114        padding_mode: PaddingMode<N, T>,
115    ) -> Result<Array<T, Dim<[Ix; N]>>, crate::Error<N>>;
116}
117
118impl<'a, T, S, SK, const N: usize> ConvExt<'a, T, S, SK, N> for ArrayBase<S, Dim<[Ix; N]>>
119where
120    T: NumAssign + Copy,
121    S: Data<Elem = T> + 'a,
122    SK: Data<Elem = T> + 'a,
123    Dim<[Ix; N]>: RemoveAxis,
124    [Ix; N]: IntoDimension<Dim = Dim<[Ix; N]>>,
125    SliceInfo<[SliceInfoElem; N], Dim<[Ix; N]>, Dim<[Ix; N]>>:
126        SliceArg<Dim<[Ix; N]>, OutDim = Dim<[Ix; N]>>,
127{
128    fn conv(
129        &self,
130        kernel: impl IntoKernelWithDilation<'a, SK, N>,
131        conv_mode: ConvMode<N>,
132        padding_mode: PaddingMode<N, T>,
133    ) -> Result<Array<T, Dim<[Ix; N]>>, crate::Error<N>> {
134        let kwd = kernel.into_kernel_with_dilation();
135
136        let self_raw_dim = self.raw_dim();
137        if self.shape().iter().product::<usize>() == 0 {
138            return Err(crate::Error::DataShape(self_raw_dim));
139        }
140
141        let kernel_raw_dim = kwd.kernel.raw_dim();
142        if kwd.kernel.shape().iter().product::<usize>() == 0 {
143            return Err(crate::Error::DataShape(kernel_raw_dim));
144        }
145
146        let kernel_raw_dim_with_dilation: [usize; N] =
147            std::array::from_fn(|i| kernel_raw_dim[i] * kwd.dilation[i] - kwd.dilation[i] + 1);
148
149        let cm = conv_mode.unfold(&kwd);
150        let pds = self.padding(padding_mode, cm.padding);
151
152        let pds_raw_dim = pds.raw_dim();
153        if !(0..N).all(|i| kernel_raw_dim_with_dilation[i] <= pds_raw_dim[i]) {
154            return Err(crate::Error::MismatchShape(
155                conv_mode,
156                kernel_raw_dim_with_dilation,
157            ));
158        }
159
160        let offset_list = kwd.gen_offset_list(pds.strides());
161
162        let output_shape: [usize; N] = std::array::from_fn(|i| {
163            (cm.padding[i][0] + cm.padding[i][1] + self_raw_dim[i]
164                - kernel_raw_dim_with_dilation[i])
165                / cm.strides[i]
166                + 1
167        });
168        let mut ret = Array::zeros(output_shape);
169
170        let shape: [usize; N] = std::array::from_fn(|i| ret.raw_dim()[i]);
171        let strides: [usize; N] =
172            std::array::from_fn(|i| cm.strides[i] * pds.strides()[i] as usize);
173
174        // dbg!(&offset_list);
175        // dbg!(strides);
176
177        unsafe {
178            // use raw pointer to improve performance.
179            let p: *mut T = ret.as_mut_ptr();
180
181            // use ArrayView's iter without handle strides
182            let view = ArrayView::from_shape(
183                ndarray::ShapeBuilder::strides(shape, strides),
184                pds.as_slice().unwrap(),
185            )
186            .unwrap();
187
188            view.iter().enumerate().for_each(|(i, cur)| {
189                let mut tmp_res = T::zero();
190
191                offset_list.iter().for_each(|(tmp_offset, tmp_kernel)| {
192                    tmp_res += *(cur as *const T).offset(*tmp_offset) * *tmp_kernel
193                });
194
195                *p.add(i) = tmp_res;
196            });
197        }
198
199        Ok(ret)
200    }
201}