1use crate::tensor::Tensor;
9use crate::error::Result;
10use rayon::prelude::*;
11
12pub fn conv2d_optimized(
14 input: &Tensor,
15 weight: &Tensor,
16 bias: Option<&Tensor>,
17 stride: (usize, usize),
18 padding: (usize, usize),
19) -> Result<Tensor> {
20 let input_dims = input.dims();
21 let weight_dims = weight.dims();
22
23 let _batch = input_dims[0];
24 let in_channels = input_dims[1];
25 let in_h = input_dims[2];
26 let in_w = input_dims[3];
27
28 let _out_channels = weight_dims[0];
29 let kernel_h = weight_dims[2];
30 let kernel_w = weight_dims[3];
31
32 let out_h = (in_h + 2 * padding.0 - kernel_h) / stride.0 + 1;
33 let out_w = (in_w + 2 * padding.1 - kernel_w) / stride.1 + 1;
34
35 if kernel_h == 3 && kernel_w == 3 && stride == (1, 1) {
37 conv2d_winograd(input, weight, bias, padding, out_h, out_w)
39 } else if kernel_h * kernel_w * in_channels > 64 {
40 conv2d_im2col(input, weight, bias, stride, padding, out_h, out_w)
42 } else {
43 conv2d_direct(input, weight, bias, stride, padding, out_h, out_w)
45 }
46}
47
/// im2col + GEMM convolution.
///
/// For each batch element, unfolds every input window into one column of a
/// `[col_size, out_h * out_w]` matrix (`col_size = in_channels * kernel_h *
/// kernel_w`), then multiplies by the weight viewed as a `[out_channels,
/// col_size]` row-major matrix. Uses BLAS `sgemm` when the `blas` feature is
/// enabled, otherwise a rayon-parallel naive matmul.
///
/// `out_h` / `out_w` are precomputed by the caller and must be consistent with
/// `stride` and `padding`. Input layout is NCHW, weight layout is OIHW.
fn conv2d_im2col(
    input: &Tensor,
    weight: &Tensor,
    bias: Option<&Tensor>,
    stride: (usize, usize),
    padding: (usize, usize),
    out_h: usize,
    out_w: usize,
) -> Result<Tensor> {
    let input_dims = input.dims();
    let weight_dims = weight.dims();

    let batch = input_dims[0];
    let in_channels = input_dims[1];
    let in_h = input_dims[2];
    let in_w = input_dims[3];

    let out_channels = weight_dims[0];
    let kernel_h = weight_dims[2];
    let kernel_w = weight_dims[3];

    let input_data = input.data_f32();
    let weight_data = weight.data_f32();

    // One column per output position, one row per (channel, kh, kw) tap.
    let col_size = in_channels * kernel_h * kernel_w;
    let output_size = out_h * out_w;
    let mut col_data = vec![0.0f32; batch * col_size * output_size];

    // Fill the im2col buffer, one batch element per rayon task.
    col_data.par_chunks_mut(col_size * output_size)
        .enumerate()
        .for_each(|(b, batch_col)| {
            for c in 0..in_channels {
                for kh in 0..kernel_h {
                    for kw in 0..kernel_w {
                        // Start of the row for this (channel, kh, kw) tap.
                        let col_idx = (c * kernel_h * kernel_w + kh * kernel_w + kw) * output_size;

                        for oh in 0..out_h {
                            for ow in 0..out_w {
                                let ih = oh * stride.0 + kh;
                                let iw = ow * stride.1 + kw;

                                // Shift from padded to real input coordinates;
                                // signed so the padding region goes negative.
                                let ih_pad = ih as i32 - padding.0 as i32;
                                let iw_pad = iw as i32 - padding.1 as i32;

                                // Out-of-bounds taps read the zero padding.
                                let val = if ih_pad >= 0 && ih_pad < in_h as i32
                                    && iw_pad >= 0 && iw_pad < in_w as i32 {
                                    let input_idx = b * in_channels * in_h * in_w
                                        + c * in_h * in_w
                                        + ih_pad as usize * in_w
                                        + iw_pad as usize;
                                    input_data[input_idx]
                                } else {
                                    0.0
                                };

                                batch_col[col_idx + oh * out_w + ow] = val;
                            }
                        }
                    }
                }
            }
        });

    let mut output_data = vec![0.0f32; batch * out_channels * output_size];

    // GEMM per batch element:
    //   output[out_channels x output_size] =
    //       weight[out_channels x col_size] * col[col_size x output_size]
    #[cfg(feature = "blas")]
    {
        use cblas::*;
        for b in 0..batch {
            let col_offset = b * col_size * output_size;
            let out_offset = b * out_channels * output_size;

            // SAFETY-relevant call: dimensions/leading dims must match the
            // buffers sized above (m=out_channels, n=output_size, k=col_size).
            unsafe {
                sgemm(
                    Layout::RowMajor,
                    Transpose::None,
                    Transpose::None,
                    out_channels as i32,
                    output_size as i32,
                    col_size as i32,
                    1.0,
                    &weight_data,
                    col_size as i32,
                    &col_data[col_offset..],
                    output_size as i32,
                    0.0,
                    &mut output_data[out_offset..],
                    output_size as i32,
                );
            }
        }
    }

    // Fallback matmul when BLAS is unavailable; parallel over batch elements.
    #[cfg(not(feature = "blas"))]
    {
        output_data.par_chunks_mut(out_channels * output_size)
            .enumerate()
            .for_each(|(b, batch_out)| {
                let col_offset = b * col_size * output_size;

                for oc in 0..out_channels {
                    for out_idx in 0..output_size {
                        // Dot product of weight row `oc` with column `out_idx`.
                        let mut sum = 0.0f32;
                        for k in 0..col_size {
                            sum += weight_data[oc * col_size + k]
                                * col_data[col_offset + k * output_size + out_idx];
                        }
                        batch_out[oc * output_size + out_idx] = sum;
                    }
                }
            });
    }

    // Broadcast the per-channel bias over every spatial position.
    if let Some(bias_tensor) = bias {
        let bias_data = bias_tensor.data_f32();
        output_data.par_chunks_mut(out_channels * output_size)
            .for_each(|batch_out| {
                for oc in 0..out_channels {
                    for out_idx in 0..output_size {
                        batch_out[oc * output_size + out_idx] += bias_data[oc];
                    }
                }
            });
    }

    Tensor::from_slice(&output_data, &[batch, out_channels, out_h, out_w])
}
189
190fn conv2d_winograd(
192 input: &Tensor,
193 weight: &Tensor,
194 bias: Option<&Tensor>,
195 padding: (usize, usize),
196 out_h: usize,
197 out_w: usize,
198) -> Result<Tensor> {
199 let input_dims = input.dims();
203 let weight_dims = weight.dims();
204
205 let _batch = input_dims[0];
206 let _in_channels = input_dims[1];
207 let _out_channels = weight_dims[0];
208
209 let _g = [
211 [1.0, 0.0, 0.0],
212 [0.5, 0.5, 0.5],
213 [0.5, -0.5, 0.5],
214 [0.0, 0.0, 1.0],
215 ];
216
217 let _b_t = [
218 [1.0, 0.0, -1.0, 0.0],
219 [0.0, 1.0, 1.0, 0.0],
220 [0.0, -1.0, 1.0, 0.0],
221 [0.0, 1.0, 0.0, -1.0],
222 ];
223
224 let _a_t = [
225 [1.0, 1.0, 1.0, 0.0],
226 [0.0, 1.0, -1.0, -1.0],
227 ];
228
229 conv2d_im2col(input, weight, bias, (1, 1), padding, out_h, out_w)
232}
233
234fn conv2d_direct(
236 input: &Tensor,
237 weight: &Tensor,
238 bias: Option<&Tensor>,
239 stride: (usize, usize),
240 padding: (usize, usize),
241 out_h: usize,
242 out_w: usize,
243) -> Result<Tensor> {
244 let input_dims = input.dims();
245 let weight_dims = weight.dims();
246
247 let batch = input_dims[0];
248 let in_channels = input_dims[1];
249 let in_h = input_dims[2];
250 let in_w = input_dims[3];
251
252 let out_channels = weight_dims[0];
253 let kernel_h = weight_dims[2];
254 let kernel_w = weight_dims[3];
255
256 let input_data = input.data_f32();
257 let weight_data = weight.data_f32();
258
259 let mut output = vec![0.0f32; batch * out_channels * out_h * out_w];
260
261 output.par_chunks_mut(out_h * out_w)
263 .enumerate()
264 .for_each(|(idx, out_slice)| {
265 let b = idx / out_channels;
266 let oc = idx % out_channels;
267
268 for oh in 0..out_h {
269 for ow in 0..out_w {
270 let mut sum = 0.0f32;
271
272 for ic in 0..in_channels {
273 for kh in 0..kernel_h {
274 for kw in 0..kernel_w {
275 let ih = oh * stride.0 + kh;
276 let iw = ow * stride.1 + kw;
277
278 let ih_pad = ih as i32 - padding.0 as i32;
279 let iw_pad = iw as i32 - padding.1 as i32;
280
281 if ih_pad >= 0 && ih_pad < in_h as i32
282 && iw_pad >= 0 && iw_pad < in_w as i32 {
283 let input_idx = b * in_channels * in_h * in_w
284 + ic * in_h * in_w
285 + ih_pad as usize * in_w
286 + iw_pad as usize;
287 let weight_idx = oc * in_channels * kernel_h * kernel_w
288 + ic * kernel_h * kernel_w
289 + kh * kernel_w
290 + kw;
291 sum += input_data[input_idx] * weight_data[weight_idx];
292 }
293 }
294 }
295 }
296
297 out_slice[oh * out_w + ow] = sum;
298 }
299 }
300 });
301
302 if let Some(bias_tensor) = bias {
304 let bias_data = bias_tensor.data_f32();
305 output.par_chunks_mut(out_h * out_w)
306 .enumerate()
307 .for_each(|(idx, out_slice)| {
308 let oc = idx % out_channels;
309 for val in out_slice.iter_mut() {
310 *val += bias_data[oc];
311 }
312 });
313 }
314
315 Tensor::from_slice(&output, &[batch, out_channels, out_h, out_w])
316}
317
#[cfg(test)]
mod tests {
    use super::*;

    /// A 3x3 kernel with stride 1 and padding 1 is a "same" convolution:
    /// spatial dims are preserved, channels become out_channels.
    #[test]
    fn test_conv2d_im2col() {
        let input = Tensor::randn(&[2, 3, 32, 32]);
        let weight = Tensor::randn(&[16, 3, 3, 3]);
        let bias = Some(Tensor::zeros(&[16]));

        let result = conv2d_optimized(&input, &weight, bias.as_ref(), (1, 1), (1, 1)).unwrap();
        assert_eq!(result.dims(), &[2, 16, 32, 32]);
    }

    /// Stride 2 halves the spatial dims: (32 + 2*1 - 3) / 2 + 1 = 16.
    #[test]
    fn test_conv2d_stride() {
        let input = Tensor::randn(&[2, 3, 32, 32]);
        let weight = Tensor::randn(&[16, 3, 3, 3]);

        let result = conv2d_optimized(&input, &weight, None, (2, 2), (1, 1)).unwrap();
        assert_eq!(result.dims(), &[2, 16, 16, 16]);
    }
}