1use ndarray::{
5 Array, ArrayBase, ArrayView, Data, Dim, Dimension, IntoDimension, Ix, RawData, RemoveAxis,
6 SliceArg, SliceInfo, SliceInfoElem,
7};
8use num::traits::NumAssign;
9
10use crate::{
11 dilation::{IntoKernelWithDilation, KernelWithDilation},
12 padding::PaddingExt,
13 ConvMode, PaddingMode,
14};
15
16#[cfg(test)]
17mod tests;
18
/// Fully resolved convolution geometry: explicit per-axis padding and strides.
///
/// Within this file it is produced by `ConvMode::unfold` and consumed by
/// `ConvExt::conv`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ExplicitConv<const N: usize> {
    /// Padding for each axis as a two-element pair. `conv` adds both entries
    /// to the input extent when sizing the output.
    /// NOTE(review): presumably ordered `[leading, trailing]`; confirm against
    /// `PaddingExt::padding`, which is defined elsewhere.
    pub padding: [[usize; 2]; N],
    /// Per-axis stride. `conv` multiplies the padded array's memory stride by
    /// this value to step between output positions.
    pub strides: [usize; N],
}
27
28impl<const N: usize> ConvMode<N> {
29 pub(crate) fn unfold<S>(self, kernel: &KernelWithDilation<S, N>) -> ExplicitConv<N>
30 where
31 S: ndarray::RawData,
32 Dim<[Ix; N]>: Dimension,
33 {
34 let kernel_dim = kernel.kernel.raw_dim();
35 let kernel_dim: [usize; N] = std::array::from_fn(|i|
36 kernel_dim[i] * kernel.dilation[i] - kernel.dilation[i] + 1);
38
39 match self {
40 ConvMode::Full => ExplicitConv {
41 padding: std::array::from_fn(|i| [kernel_dim[i] - 1; 2]),
42 strides: [1; N],
43 },
44 ConvMode::Same => ExplicitConv {
45 padding: std::array::from_fn(|i| {
46 let k_size = kernel_dim[i];
47 if k_size.is_multiple_of(2) {
48 [(k_size - 1) / 2 + 1, (k_size - 1) / 2]
49 } else {
50 [(k_size - 1) / 2; 2]
51 }
52 }),
53 strides: [1; N],
54 },
55 ConvMode::Valid => ExplicitConv {
56 padding: [[0; 2]; N],
57 strides: [1; N],
58 },
59 ConvMode::Custom { padding, strides } => ExplicitConv {
60 padding: padding.map(|pad| [pad; 2]),
61 strides,
62 },
63 ConvMode::Explicit { padding, strides } => ExplicitConv { padding, strides },
64 }
65 }
66}
67
/// Extension trait that adds an N-dimensional convolution method.
///
/// This file implements it for `ArrayBase<S, Dim<[Ix; N]>>`.
///
/// # Type parameters
///
/// * `'a` - lifetime forwarded to [`IntoKernelWithDilation`].
/// * `T` - numeric element type of the returned array.
/// * `S` - storage type; the implementation in this file uses it as the
///   storage of the `ArrayBase` receiver.
/// * `SK` - storage type forwarded to [`IntoKernelWithDilation`] for the kernel.
/// * `N` - number of dimensions shared by input, kernel and output.
pub trait ConvExt<'a, T, S, SK, const N: usize>
where
    T: NumAssign + Copy,
    S: RawData,
    SK: RawData,
{
    /// Convolves `self` with `kernel`.
    ///
    /// # Arguments
    ///
    /// * `kernel` - anything convertible into a kernel with dilation via
    ///   [`IntoKernelWithDilation`].
    /// * `conv_mode` - selects how padding and strides are derived.
    /// * `padding_mode` - selects how the padded border is filled.
    ///
    /// # Errors
    ///
    /// Returns `crate::Error<N>` on failure. The implementation in this file
    /// reports `DataShape` when the input or the kernel contains a zero-length
    /// axis, and `MismatchShape` when the dilated kernel exceeds the padded
    /// input along any axis.
    fn conv(
        &self,
        kernel: impl IntoKernelWithDilation<'a, SK, N>,
        conv_mode: ConvMode<N>,
        padding_mode: PaddingMode<N, T>,
    ) -> Result<Array<T, Dim<[Ix; N]>>, crate::Error<N>>;
}
104
105impl<'a, T, S, SK, const N: usize> ConvExt<'a, T, S, SK, N> for ArrayBase<S, Dim<[Ix; N]>>
106where
107 T: NumAssign + Copy,
108 S: Data<Elem = T> + 'a,
109 SK: Data<Elem = T> + 'a,
110 Dim<[Ix; N]>: RemoveAxis,
111 [Ix; N]: IntoDimension<Dim = Dim<[Ix; N]>>,
112 SliceInfo<[SliceInfoElem; N], Dim<[Ix; N]>, Dim<[Ix; N]>>:
113 SliceArg<Dim<[Ix; N]>, OutDim = Dim<[Ix; N]>>,
114{
115 fn conv(
116 &self,
117 kernel: impl IntoKernelWithDilation<'a, SK, N>,
118 conv_mode: ConvMode<N>,
119 padding_mode: PaddingMode<N, T>,
120 ) -> Result<Array<T, Dim<[Ix; N]>>, crate::Error<N>> {
121 let kwd = kernel.into_kernel_with_dilation();
122
123 let self_raw_dim = self.raw_dim();
124 if self.shape().iter().product::<usize>() == 0 {
125 return Err(crate::Error::DataShape(self_raw_dim));
126 }
127
128 let kernel_raw_dim = kwd.kernel.raw_dim();
129 if kwd.kernel.shape().iter().product::<usize>() == 0 {
130 return Err(crate::Error::DataShape(kernel_raw_dim));
131 }
132
133 let kernel_raw_dim_with_dilation: [usize; N] =
134 std::array::from_fn(|i| kernel_raw_dim[i] * kwd.dilation[i] - kwd.dilation[i] + 1);
135
136 let cm = conv_mode.unfold(&kwd);
137 let pds = self.padding(padding_mode, cm.padding);
138
139 let pds_raw_dim = pds.raw_dim();
140 if !(0..N).all(|i| kernel_raw_dim_with_dilation[i] <= pds_raw_dim[i]) {
141 return Err(crate::Error::MismatchShape(
142 conv_mode,
143 kernel_raw_dim_with_dilation,
144 ));
145 }
146
147 let offset_list = kwd.gen_offset_list(pds.strides());
148
149 let output_shape: [usize; N] = std::array::from_fn(|i| {
150 (cm.padding[i][0] + cm.padding[i][1] + self_raw_dim[i]
151 - kernel_raw_dim_with_dilation[i])
152 / cm.strides[i]
153 + 1
154 });
155 let mut ret = Array::zeros(output_shape);
156
157 let shape: [usize; N] = std::array::from_fn(|i| ret.raw_dim()[i]);
158 let strides: [usize; N] =
159 std::array::from_fn(|i| cm.strides[i] * pds.strides()[i] as usize);
160
161 unsafe {
165 let p: *mut T = ret.as_mut_ptr();
167
168 let view = ArrayView::from_shape(
170 ndarray::ShapeBuilder::strides(shape, strides),
171 pds.as_slice().unwrap(),
172 )
173 .unwrap();
174
175 view.iter().enumerate().for_each(|(i, cur)| {
176 let mut tmp_res = T::zero();
177
178 offset_list.iter().for_each(|(tmp_offset, tmp_kernel)| {
179 tmp_res += *(cur as *const T).offset(*tmp_offset) * *tmp_kernel
180 });
181
182 *p.add(i) = tmp_res;
183 });
184 }
185
186 Ok(ret)
187 }
188}