1use crate::tensor::Tensor;
9use crate::error::Result;
10#[cfg(feature = "rayon")]
11use rayon::prelude::*;
12
13pub fn conv2d_optimized(
15 input: &Tensor,
16 weight: &Tensor,
17 bias: Option<&Tensor>,
18 stride: (usize, usize),
19 padding: (usize, usize),
20) -> Result<Tensor> {
21 let input_dims = input.dims();
22 let weight_dims = weight.dims();
23
24 let _batch = input_dims[0];
25 let in_channels = input_dims[1];
26 let in_h = input_dims[2];
27 let in_w = input_dims[3];
28
29 let _out_channels = weight_dims[0];
30 let kernel_h = weight_dims[2];
31 let kernel_w = weight_dims[3];
32
33 let out_h = (in_h + 2 * padding.0 - kernel_h) / stride.0 + 1;
34 let out_w = (in_w + 2 * padding.1 - kernel_w) / stride.1 + 1;
35
36 if kernel_h == 3 && kernel_w == 3 && stride == (1, 1) {
38 conv2d_winograd(input, weight, bias, padding, out_h, out_w)
40 } else if kernel_h * kernel_w * in_channels > 64 {
41 conv2d_im2col(input, weight, bias, stride, padding, out_h, out_w)
43 } else {
44 conv2d_direct(input, weight, bias, stride, padding, out_h, out_w)
46 }
47}
48
49fn conv2d_im2col(
51 input: &Tensor,
52 weight: &Tensor,
53 bias: Option<&Tensor>,
54 stride: (usize, usize),
55 padding: (usize, usize),
56 out_h: usize,
57 out_w: usize,
58) -> Result<Tensor> {
59 let input_dims = input.dims();
60 let weight_dims = weight.dims();
61
62 let batch = input_dims[0];
63 let in_channels = input_dims[1];
64 let in_h = input_dims[2];
65 let in_w = input_dims[3];
66
67 let out_channels = weight_dims[0];
68 let kernel_h = weight_dims[2];
69 let kernel_w = weight_dims[3];
70
71 let input_data = input.data_f32();
72 let weight_data = weight.data_f32();
73
74 let col_size = in_channels * kernel_h * kernel_w;
77 let output_size = out_h * out_w;
78 let mut col_data = vec![0.0f32; batch * col_size * output_size];
79
80 col_data.chunks_mut(col_size * output_size)
82 .enumerate()
83 .for_each(|(b, batch_col)| {
84 for c in 0..in_channels {
85 for kh in 0..kernel_h {
86 for kw in 0..kernel_w {
87 let col_idx = (c * kernel_h * kernel_w + kh * kernel_w + kw) * output_size;
88
89 for oh in 0..out_h {
90 for ow in 0..out_w {
91 let ih = oh * stride.0 + kh;
92 let iw = ow * stride.1 + kw;
93
94 let ih_pad = ih as i32 - padding.0 as i32;
95 let iw_pad = iw as i32 - padding.1 as i32;
96
97 let val = if ih_pad >= 0 && ih_pad < in_h as i32
98 && iw_pad >= 0 && iw_pad < in_w as i32 {
99 let input_idx = b * in_channels * in_h * in_w
100 + c * in_h * in_w
101 + ih_pad as usize * in_w
102 + iw_pad as usize;
103 input_data[input_idx]
104 } else {
105 0.0
106 };
107
108 batch_col[col_idx + oh * out_w + ow] = val;
109 }
110 }
111 }
112 }
113 }
114 });
115
116 let mut output_data = vec![0.0f32; batch * out_channels * output_size];
123
124 #[cfg(feature = "blas")]
126 {
127 use cblas::*;
128 for b in 0..batch {
129 let col_offset = b * col_size * output_size;
130 let out_offset = b * out_channels * output_size;
131
132 unsafe {
133 sgemm(
134 Layout::RowMajor,
135 Transpose::None,
136 Transpose::None,
137 out_channels as i32,
138 output_size as i32,
139 col_size as i32,
140 1.0,
141 &weight_data,
142 col_size as i32,
143 &col_data[col_offset..],
144 output_size as i32,
145 0.0,
146 &mut output_data[out_offset..],
147 output_size as i32,
148 );
149 }
150 }
151 }
152
153 #[cfg(not(feature = "blas"))]
155 {
156 output_data.chunks_mut(out_channels * output_size)
157 .enumerate()
158 .for_each(|(b, batch_out)| {
159 let col_offset = b * col_size * output_size;
160
161 for oc in 0..out_channels {
162 for out_idx in 0..output_size {
163 let mut sum = 0.0f32;
164 for k in 0..col_size {
165 sum += weight_data[oc * col_size + k]
166 * col_data[col_offset + k * output_size + out_idx];
167 }
168 batch_out[oc * output_size + out_idx] = sum;
169 }
170 }
171 });
172 }
173
174 if let Some(bias_tensor) = bias {
176 let bias_data = bias_tensor.data_f32();
177 output_data.chunks_mut(out_channels * output_size)
178 .for_each(|batch_out| {
179 for oc in 0..out_channels {
180 for out_idx in 0..output_size {
181 batch_out[oc * output_size + out_idx] += bias_data[oc];
182 }
183 }
184 });
185 }
186
187 Tensor::from_slice(&output_data, &[batch, out_channels, out_h, out_w])
189}
190
191fn conv2d_winograd(
193 input: &Tensor,
194 weight: &Tensor,
195 bias: Option<&Tensor>,
196 padding: (usize, usize),
197 out_h: usize,
198 out_w: usize,
199) -> Result<Tensor> {
200 let input_dims = input.dims();
204 let weight_dims = weight.dims();
205
206 let _batch = input_dims[0];
207 let _in_channels = input_dims[1];
208 let _out_channels = weight_dims[0];
209
210 let _g = [
212 [1.0, 0.0, 0.0],
213 [0.5, 0.5, 0.5],
214 [0.5, -0.5, 0.5],
215 [0.0, 0.0, 1.0],
216 ];
217
218 let _b_t = [
219 [1.0, 0.0, -1.0, 0.0],
220 [0.0, 1.0, 1.0, 0.0],
221 [0.0, -1.0, 1.0, 0.0],
222 [0.0, 1.0, 0.0, -1.0],
223 ];
224
225 let _a_t = [
226 [1.0, 1.0, 1.0, 0.0],
227 [0.0, 1.0, -1.0, -1.0],
228 ];
229
230 conv2d_im2col(input, weight, bias, (1, 1), padding, out_h, out_w)
233}
234
/// Naive direct convolution: for every (batch, out-channel, output pixel),
/// accumulate input * weight over all in-channels and the kernel window.
///
/// Layouts (row-major NCHW): `input` is [batch, in_c, in_h, in_w] and
/// `weight` is [out_c, in_c, kernel_h, kernel_w] — inferred from the index
/// arithmetic below. The dispatcher selects this path for small receptive
/// fields where im2col's buffer materialization is not worth it.
fn conv2d_direct(
    input: &Tensor,
    weight: &Tensor,
    bias: Option<&Tensor>,
    stride: (usize, usize),
    padding: (usize, usize),
    out_h: usize,
    out_w: usize,
) -> Result<Tensor> {
    let input_dims = input.dims();
    let weight_dims = weight.dims();

    let batch = input_dims[0];
    let in_channels = input_dims[1];
    let in_h = input_dims[2];
    let in_w = input_dims[3];

    let out_channels = weight_dims[0];
    let kernel_h = weight_dims[2];
    let kernel_w = weight_dims[3];

    let input_data = input.data_f32();
    let weight_data = weight.data_f32();

    let mut output = vec![0.0f32; batch * out_channels * out_h * out_w];

    // One chunk per (batch, out-channel) output plane; `idx` enumerates
    // planes in batch-major order, so (b, oc) is recovered by div/mod.
    output.chunks_mut(out_h * out_w)
        .enumerate()
        .for_each(|(idx, out_slice)| {
            let b = idx / out_channels;
            let oc = idx % out_channels;

            for oh in 0..out_h {
                for ow in 0..out_w {
                    let mut sum = 0.0f32;

                    for ic in 0..in_channels {
                        for kh in 0..kernel_h {
                            for kw in 0..kernel_w {
                                // Tap position in the zero-padded input...
                                let ih = oh * stride.0 + kh;
                                let iw = ow * stride.1 + kw;

                                // ...shifted back to real input coordinates
                                // (may be negative, hence the i32 detour).
                                let ih_pad = ih as i32 - padding.0 as i32;
                                let iw_pad = iw as i32 - padding.1 as i32;

                                // Taps landing in the padding contribute
                                // zero and are simply skipped.
                                if ih_pad >= 0 && ih_pad < in_h as i32
                                    && iw_pad >= 0 && iw_pad < in_w as i32 {
                                    let input_idx = b * in_channels * in_h * in_w
                                        + ic * in_h * in_w
                                        + ih_pad as usize * in_w
                                        + iw_pad as usize;
                                    let weight_idx = oc * in_channels * kernel_h * kernel_w
                                        + ic * kernel_h * kernel_w
                                        + kh * kernel_w
                                        + kw;
                                    sum += input_data[input_idx] * weight_data[weight_idx];
                                }
                            }
                        }
                    }

                    out_slice[oh * out_w + ow] = sum;
                }
            }
        });

    // Broadcast the per-channel bias over each (batch, channel) plane.
    if let Some(bias_tensor) = bias {
        let bias_data = bias_tensor.data_f32();
        output.chunks_mut(out_h * out_w)
            .enumerate()
            .for_each(|(idx, out_slice)| {
                let oc = idx % out_channels;
                for val in out_slice.iter_mut() {
                    *val += bias_data[oc];
                }
            });
    }

    Tensor::from_slice(&output, &[batch, out_channels, out_h, out_w])
}
318
#[cfg(test)]
mod tests {
    use super::*;

    /// Same-padding 3x3, stride 1: spatial dims are preserved. Note this
    /// dispatches through the winograd branch, which currently lowers to
    /// im2col internally.
    #[test]
    fn test_conv2d_im2col() {
        let x = Tensor::randn(&[2, 3, 32, 32]);
        let w = Tensor::randn(&[16, 3, 3, 3]);
        let b = Some(Tensor::zeros(&[16]));

        let y = conv2d_optimized(&x, &w, b.as_ref(), (1, 1), (1, 1)).unwrap();
        assert_eq!(y.dims(), &[2, 16, 32, 32]);
    }

    /// Stride 2 with padding 1 halves each spatial dimension.
    #[test]
    fn test_conv2d_stride() {
        let x = Tensor::randn(&[2, 3, 32, 32]);
        let w = Tensor::randn(&[16, 3, 3, 3]);

        let y = conv2d_optimized(&x, &w, None, (2, 2), (1, 1)).unwrap();
        assert_eq!(y.dims(), &[2, 16, 16, 16]);
    }
}
342