// ndarray_conv/conv/mod.rs

//! Provides convolution operations for `ndarray` arrays.
//! Includes the standard convolution and helpers that turn a `ConvMode` into explicit padding and strides.
3
4use std::fmt::Debug;
5
6use ndarray::{
7    Array, ArrayBase, ArrayView, Data, Dim, Dimension, IntoDimension, Ix, RawData, RemoveAxis,
8    SliceArg, SliceInfo, SliceInfoElem,
9};
10use num::traits::NumAssign;
11
12use crate::{
13    dilation::{IntoKernelWithDilation, KernelWithDilation},
14    padding::PaddingExt,
15    ConvMode, PaddingMode,
16};
17
18#[cfg(test)]
19mod tests;
20
/// Explicit convolution parameters obtained by unfolding a `ConvMode`.
///
/// Holds the per-axis padding and strides that the convolution routine
/// consumes directly.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ExplicitConv<const N: usize> {
    /// Padding for each axis, two amounts per axis (presumably
    /// `[leading, trailing]` as consumed by `PaddingExt::padding` — confirm).
    pub padding: [[usize; 2]; N],
    /// Step between consecutive kernel placements along each axis.
    pub strides: [usize; N],
}
29
30impl<const N: usize> ConvMode<N> {
31    pub(crate) fn unfold<S>(self, kernel: &KernelWithDilation<S, N>) -> ExplicitConv<N>
32    where
33        S: ndarray::RawData,
34        Dim<[Ix; N]>: Dimension,
35    {
36        let kernel_dim = kernel.kernel.raw_dim();
37        let kernel_dim: [usize; N] = std::array::from_fn(|i|
38                // k + (k - 1) * (d - 1)
39                kernel_dim[i] * kernel.dilation[i] - kernel.dilation[i] + 1);
40
41        match self {
42            ConvMode::Full => ExplicitConv {
43                padding: std::array::from_fn(|i| [kernel_dim[i] - 1; 2]),
44                strides: [1; N],
45            },
46            ConvMode::Same => ExplicitConv {
47                padding: std::array::from_fn(|i| {
48                    let k_size = kernel_dim[i];
49                    if k_size % 2 == 0 {
50                        [(k_size - 1) / 2 + 1, (k_size - 1) / 2]
51                    } else {
52                        [(k_size - 1) / 2; 2]
53                    }
54                }),
55                strides: [1; N],
56            },
57            ConvMode::Valid => ExplicitConv {
58                padding: [[0; 2]; N],
59                strides: [1; N],
60            },
61            ConvMode::Custom { padding, strides } => ExplicitConv {
62                padding: padding.map(|pad| [pad; 2]),
63                strides,
64            },
65            ConvMode::Explicit { padding, strides } => ExplicitConv { padding, strides },
66        }
67    }
68}
69
70/// Extends `ndarray`'s `ArrayBase` with convolution operations.
71///
72/// This trait adds the `conv` method to `ArrayBase`, enabling
73/// standard convolution operations on N-dimensional arrays.
74///
75/// # Type Parameters
76///
77/// *   `T`: The numeric type of the array elements.
78/// *   `S`: The data storage type of the input array.
79/// *   `SK`: The data storage type of the kernel array.
80pub trait ConvExt<'a, T, S, SK, const N: usize>
81where
82    T: NumAssign + Copy,
83    S: RawData,
84    SK: RawData,
85{
86    /// Performs a standard convolution operation.
87    ///
88    /// This method convolves the input array with a given kernel,
89    /// using the specified convolution mode and padding.
90    ///
91    /// # Arguments
92    ///
93    /// *   `kernel`: The convolution kernel.
94    /// *   `conv_mode`: The convolution mode (`Full`, `Same`, `Valid`, `Custom`, `Explicit`).
95    /// *   `padding_mode`: The padding mode (`Zeros`, `Const`, `Reflect`, `Replicate`, `Circular`, `Custom`, `Explicit`).
96    ///
97    /// # Returns
98    ///
99    fn conv(
100        &self,
101        kernel: impl IntoKernelWithDilation<'a, SK, N>,
102        conv_mode: ConvMode<N>,
103        padding_mode: PaddingMode<N, T>,
104    ) -> Result<Array<T, Dim<[Ix; N]>>, crate::Error<N>>;
105}
106
107impl<'a, T, S, SK, const N: usize> ConvExt<'a, T, S, SK, N> for ArrayBase<S, Dim<[Ix; N]>>
108where
109    T: NumAssign + Copy,
110    S: Data<Elem = T> + 'a,
111    SK: Data<Elem = T> + 'a,
112    Dim<[Ix; N]>: RemoveAxis,
113    [Ix; N]: IntoDimension<Dim = Dim<[Ix; N]>>,
114    SliceInfo<[SliceInfoElem; N], Dim<[Ix; N]>, Dim<[Ix; N]>>:
115        SliceArg<Dim<[Ix; N]>, OutDim = Dim<[Ix; N]>>,
116{
117    fn conv(
118        &self,
119        kernel: impl IntoKernelWithDilation<'a, SK, N>,
120        conv_mode: ConvMode<N>,
121        padding_mode: PaddingMode<N, T>,
122    ) -> Result<Array<T, Dim<[Ix; N]>>, crate::Error<N>> {
123        let kwd = kernel.into_kernel_with_dilation();
124
125        let self_raw_dim = self.raw_dim();
126        if self.shape().iter().product::<usize>() == 0 {
127            return Err(crate::Error::DataShape(self_raw_dim));
128        }
129
130        let kernel_raw_dim = kwd.kernel.raw_dim();
131        if kwd.kernel.shape().iter().product::<usize>() == 0 {
132            return Err(crate::Error::DataShape(kernel_raw_dim));
133        }
134
135        let kernel_raw_dim_with_dilation: [usize; N] =
136            std::array::from_fn(|i| kernel_raw_dim[i] * kwd.dilation[i] - kwd.dilation[i] + 1);
137
138        let cm = conv_mode.unfold(&kwd);
139        let pds = self.padding(padding_mode, cm.padding);
140
141        let pds_raw_dim = pds.raw_dim();
142        if !(0..N).all(|i| kernel_raw_dim_with_dilation[i] <= pds_raw_dim[i]) {
143            return Err(crate::Error::MismatchShape(
144                conv_mode,
145                kernel_raw_dim_with_dilation,
146            ));
147        }
148
149        let offset_list = kwd.gen_offset_list(pds.strides());
150
151        let output_shape: [usize; N] = std::array::from_fn(|i| {
152            (cm.padding[i][0] + cm.padding[i][1] + self_raw_dim[i]
153                - kernel_raw_dim_with_dilation[i])
154                / cm.strides[i]
155                + 1
156        });
157        let mut ret = Array::zeros(output_shape);
158
159        let shape: [usize; N] = std::array::from_fn(|i| ret.raw_dim()[i]);
160        let strides: [usize; N] =
161            std::array::from_fn(|i| cm.strides[i] * pds.strides()[i] as usize);
162
163        // dbg!(&offset_list);
164        // dbg!(strides);
165
166        unsafe {
167            // use raw pointer to improve performance.
168            let p: *mut T = ret.as_mut_ptr();
169
170            // use ArrayView's iter without handle strides
171            let view = ArrayView::from_shape(
172                ndarray::ShapeBuilder::strides(shape, strides),
173                pds.as_slice().unwrap(),
174            )
175            .unwrap();
176
177            view.iter().enumerate().for_each(|(i, cur)| {
178                let mut tmp_res = T::zero();
179
180                offset_list.iter().for_each(|(tmp_offset, tmp_kernel)| {
181                    tmp_res += *(cur as *const T).offset(*tmp_offset) * *tmp_kernel
182                });
183
184                *p.add(i) = tmp_res;
185            });
186        }
187
188        Ok(ret)
189    }
190}