1use std::fmt::Debug;
5
6use ndarray::{
7 Array, ArrayBase, ArrayView, Data, Dim, Dimension, IntoDimension, Ix, RawData, RemoveAxis,
8 SliceArg, SliceInfo, SliceInfoElem,
9};
10use num::traits::NumAssign;
11
12use crate::{
13 dilation::{IntoKernelWithDilation, KernelWithDilation},
14 padding::PaddingExt,
15 ConvMode, PaddingMode,
16};
17
18#[cfg(test)]
19mod tests;
20
/// A convolution mode lowered to explicit per-axis padding and strides.
///
/// Produced by `ConvMode::unfold`; every `ConvMode` variant maps onto this form.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ExplicitConv<const N: usize> {
    /// Padding for each axis as `[before, after]`, in elements.
    pub padding: [[usize; 2]; N],
    /// Step between consecutive output positions along each axis, in elements.
    pub strides: [usize; N],
}
29
30impl<const N: usize> ConvMode<N> {
31 pub(crate) fn unfold<S>(self, kernel: &KernelWithDilation<S, N>) -> ExplicitConv<N>
32 where
33 S: ndarray::RawData,
34 Dim<[Ix; N]>: Dimension,
35 {
36 let kernel_dim = kernel.kernel.raw_dim();
37 let kernel_dim: [usize; N] = std::array::from_fn(|i|
38 kernel_dim[i] * kernel.dilation[i] - kernel.dilation[i] + 1);
40
41 match self {
42 ConvMode::Full => ExplicitConv {
43 padding: std::array::from_fn(|i| [kernel_dim[i] - 1; 2]),
44 strides: [1; N],
45 },
46 ConvMode::Same => ExplicitConv {
47 padding: std::array::from_fn(|i| {
48 let k_size = kernel_dim[i];
49 if k_size % 2 == 0 {
50 [(k_size - 1) / 2 + 1, (k_size - 1) / 2]
51 } else {
52 [(k_size - 1) / 2; 2]
53 }
54 }),
55 strides: [1; N],
56 },
57 ConvMode::Valid => ExplicitConv {
58 padding: [[0; 2]; N],
59 strides: [1; N],
60 },
61 ConvMode::Custom { padding, strides } => ExplicitConv {
62 padding: padding.map(|pad| [pad; 2]),
63 strides,
64 },
65 ConvMode::Explicit { padding, strides } => ExplicitConv { padding, strides },
66 }
67 }
68}
69
70pub trait ConvExt<'a, T, S, SK, const N: usize>
81where
82 T: NumAssign + Copy,
83 S: RawData,
84 SK: RawData,
85{
86 fn conv(
100 &self,
101 kernel: impl IntoKernelWithDilation<'a, SK, N>,
102 conv_mode: ConvMode<N>,
103 padding_mode: PaddingMode<N, T>,
104 ) -> Result<Array<T, Dim<[Ix; N]>>, crate::Error<N>>;
105}
106
107impl<'a, T, S, SK, const N: usize> ConvExt<'a, T, S, SK, N> for ArrayBase<S, Dim<[Ix; N]>>
108where
109 T: NumAssign + Copy,
110 S: Data<Elem = T> + 'a,
111 SK: Data<Elem = T> + 'a,
112 Dim<[Ix; N]>: RemoveAxis,
113 [Ix; N]: IntoDimension<Dim = Dim<[Ix; N]>>,
114 SliceInfo<[SliceInfoElem; N], Dim<[Ix; N]>, Dim<[Ix; N]>>:
115 SliceArg<Dim<[Ix; N]>, OutDim = Dim<[Ix; N]>>,
116{
117 fn conv(
118 &self,
119 kernel: impl IntoKernelWithDilation<'a, SK, N>,
120 conv_mode: ConvMode<N>,
121 padding_mode: PaddingMode<N, T>,
122 ) -> Result<Array<T, Dim<[Ix; N]>>, crate::Error<N>> {
123 let kwd = kernel.into_kernel_with_dilation();
124
125 let self_raw_dim = self.raw_dim();
126 if self.shape().iter().product::<usize>() == 0 {
127 return Err(crate::Error::DataShape(self_raw_dim));
128 }
129
130 let kernel_raw_dim = kwd.kernel.raw_dim();
131 if kwd.kernel.shape().iter().product::<usize>() == 0 {
132 return Err(crate::Error::DataShape(kernel_raw_dim));
133 }
134
135 let kernel_raw_dim_with_dilation: [usize; N] =
136 std::array::from_fn(|i| kernel_raw_dim[i] * kwd.dilation[i] - kwd.dilation[i] + 1);
137
138 let cm = conv_mode.unfold(&kwd);
139 let pds = self.padding(padding_mode, cm.padding);
140
141 let pds_raw_dim = pds.raw_dim();
142 if !(0..N).all(|i| kernel_raw_dim_with_dilation[i] <= pds_raw_dim[i]) {
143 return Err(crate::Error::MismatchShape(
144 conv_mode,
145 kernel_raw_dim_with_dilation,
146 ));
147 }
148
149 let offset_list = kwd.gen_offset_list(pds.strides());
150
151 let output_shape: [usize; N] = std::array::from_fn(|i| {
152 (cm.padding[i][0] + cm.padding[i][1] + self_raw_dim[i]
153 - kernel_raw_dim_with_dilation[i])
154 / cm.strides[i]
155 + 1
156 });
157 let mut ret = Array::zeros(output_shape);
158
159 let shape: [usize; N] = std::array::from_fn(|i| ret.raw_dim()[i]);
160 let strides: [usize; N] =
161 std::array::from_fn(|i| cm.strides[i] * pds.strides()[i] as usize);
162
163 unsafe {
167 let p: *mut T = ret.as_mut_ptr();
169
170 let view = ArrayView::from_shape(
172 ndarray::ShapeBuilder::strides(shape, strides),
173 pds.as_slice().unwrap(),
174 )
175 .unwrap();
176
177 view.iter().enumerate().for_each(|(i, cur)| {
178 let mut tmp_res = T::zero();
179
180 offset_list.iter().for_each(|(tmp_offset, tmp_kernel)| {
181 tmp_res += *(cur as *const T).offset(*tmp_offset) * *tmp_kernel
182 });
183
184 *p.add(i) = tmp_res;
185 });
186 }
187
188 Ok(ret)
189 }
190}