//! Pooling layers: max, average, and adaptive average pooling over 1-D and 2-D inputs.

use axonml_autograd::Variable;
use axonml_tensor::Tensor;

use crate::module::Module;

/// 1-D max pooling over a `[batch, channels, length]` input.
pub struct MaxPool1d {
    kernel_size: usize,
    stride: usize,
    padding: usize,
}

impl MaxPool1d {
    /// Creates a `MaxPool1d` whose stride defaults to `kernel_size`, with no padding.
    pub fn new(kernel_size: usize) -> Self {
        Self {
            kernel_size,
            stride: kernel_size,
            padding: 0,
        }
    }

    /// Creates a `MaxPool1d` with an explicit stride and padding.
    pub fn with_options(kernel_size: usize, stride: usize, padding: usize) -> Self {
        Self {
            kernel_size,
            stride,
            padding,
        }
    }
}

impl Module for MaxPool1d {
    fn forward(&self, input: &Variable) -> Variable {
        let shape = input.shape();
        let batch = shape[0];
        let channels = shape[1];
        let length = shape[2];

        // Standard pooling output size: (length + 2 * padding - kernel) / stride + 1.
        let out_length = (length + 2 * self.padding - self.kernel_size) / self.stride + 1;

        let input_vec = input.data().to_vec();
        let mut output_data = vec![f32::NEG_INFINITY; batch * channels * out_length];

        for b in 0..batch {
            for c in 0..channels {
                for ol in 0..out_length {
                    let in_start = ol * self.stride;
                    let mut max_val = f32::NEG_INFINITY;

                    for k in 0..self.kernel_size {
                        // `il` indexes the conceptually padded input; positions inside
                        // the padding region are skipped, which is equivalent to
                        // padding with negative infinity.
                        let il = in_start + k;
                        if il >= self.padding && il < length + self.padding {
                            let actual_il = il - self.padding;
                            let idx = b * channels * length + c * length + actual_il;
                            max_val = max_val.max(input_vec[idx]);
                        }
                    }

                    let out_idx = b * channels * out_length + c * out_length + ol;
                    output_data[out_idx] = max_val;
                }
            }
        }

        let output = Tensor::from_vec(output_data, &[batch, channels, out_length]).unwrap();
        Variable::new(output, input.requires_grad())
    }

    fn name(&self) -> &'static str {
        "MaxPool1d"
    }
}

/// 2-D max pooling over a `[batch, channels, height, width]` input.
pub struct MaxPool2d {
    kernel_size: (usize, usize),
    stride: (usize, usize),
    padding: (usize, usize),
}

impl MaxPool2d {
    /// Creates a square `MaxPool2d` whose stride defaults to `kernel_size`, with no padding.
    pub fn new(kernel_size: usize) -> Self {
        Self {
            kernel_size: (kernel_size, kernel_size),
            stride: (kernel_size, kernel_size),
            padding: (0, 0),
        }
    }

    /// Creates a `MaxPool2d` with explicit per-axis kernel size, stride, and padding.
    pub fn with_options(
        kernel_size: (usize, usize),
        stride: (usize, usize),
        padding: (usize, usize),
    ) -> Self {
        Self {
            kernel_size,
            stride,
            padding,
        }
    }
}

impl Module for MaxPool2d {
    fn forward(&self, input: &Variable) -> Variable {
        let shape = input.shape();
        let batch = shape[0];
        let channels = shape[1];
        let height = shape[2];
        let width = shape[3];

        let (kh, kw) = self.kernel_size;
        let (sh, sw) = self.stride;
        let (ph, pw) = self.padding;

        // Standard pooling output size per axis: (dim + 2 * pad - kernel) / stride + 1.
        let out_h = (height + 2 * ph - kh) / sh + 1;
        let out_w = (width + 2 * pw - kw) / sw + 1;

        let input_vec = input.data().to_vec();
        let mut output_data = vec![f32::NEG_INFINITY; batch * channels * out_h * out_w];

        for b in 0..batch {
            for c in 0..channels {
                for oh in 0..out_h {
                    for ow in 0..out_w {
                        let mut max_val = f32::NEG_INFINITY;

                        for ki in 0..kh {
                            for kj in 0..kw {
                                // (ih, iw) index the conceptually padded input;
                                // padded positions are skipped.
                                let ih = oh * sh + ki;
                                let iw = ow * sw + kj;

                                if ih >= ph && ih < height + ph && iw >= pw && iw < width + pw {
                                    let actual_ih = ih - ph;
                                    let actual_iw = iw - pw;
                                    let idx = b * channels * height * width
                                        + c * height * width
                                        + actual_ih * width
                                        + actual_iw;
                                    max_val = max_val.max(input_vec[idx]);
                                }
                            }
                        }

                        let out_idx =
                            b * channels * out_h * out_w + c * out_h * out_w + oh * out_w + ow;
                        output_data[out_idx] = max_val;
                    }
                }
            }
        }

        let output = Tensor::from_vec(output_data, &[batch, channels, out_h, out_w]).unwrap();
        Variable::new(output, input.requires_grad())
    }

    fn name(&self) -> &'static str {
        "MaxPool2d"
    }
}

/// 1-D average pooling over a `[batch, channels, length]` input.
pub struct AvgPool1d {
    kernel_size: usize,
    stride: usize,
    padding: usize,
}

impl AvgPool1d {
    /// Creates an `AvgPool1d` whose stride defaults to `kernel_size`, with no padding.
    pub fn new(kernel_size: usize) -> Self {
        Self {
            kernel_size,
            stride: kernel_size,
            padding: 0,
        }
    }

    /// Creates an `AvgPool1d` with an explicit stride and padding.
    pub fn with_options(kernel_size: usize, stride: usize, padding: usize) -> Self {
        Self {
            kernel_size,
            stride,
            padding,
        }
    }
}

impl Module for AvgPool1d {
    fn forward(&self, input: &Variable) -> Variable {
        let shape = input.shape();
        let batch = shape[0];
        let channels = shape[1];
        let length = shape[2];

        let out_length = (length + 2 * self.padding - self.kernel_size) / self.stride + 1;

        let input_vec = input.data().to_vec();
        let mut output_data = vec![0.0f32; batch * channels * out_length];

        for b in 0..batch {
            for c in 0..channels {
                for ol in 0..out_length {
                    let in_start = ol * self.stride;
                    let mut sum = 0.0f32;
                    let mut count = 0;

                    for k in 0..self.kernel_size {
                        let il = in_start + k;
                        if il >= self.padding && il < length + self.padding {
                            let actual_il = il - self.padding;
                            let idx = b * channels * length + c * length + actual_il;
                            sum += input_vec[idx];
                            count += 1;
                        }
                    }

                    // Only in-bounds samples grow `count`, so padded positions are
                    // excluded from the divisor of the average.
                    let out_idx = b * channels * out_length + c * out_length + ol;
                    output_data[out_idx] = if count > 0 { sum / count as f32 } else { 0.0 };
                }
            }
        }

        let output = Tensor::from_vec(output_data, &[batch, channels, out_length]).unwrap();
        Variable::new(output, input.requires_grad())
    }

    fn name(&self) -> &'static str {
        "AvgPool1d"
    }
}

/// 2-D average pooling over a `[batch, channels, height, width]` input.
pub struct AvgPool2d {
    kernel_size: (usize, usize),
    stride: (usize, usize),
    padding: (usize, usize),
}

impl AvgPool2d {
    /// Creates a square `AvgPool2d` whose stride defaults to `kernel_size`, with no padding.
    pub fn new(kernel_size: usize) -> Self {
        Self {
            kernel_size: (kernel_size, kernel_size),
            stride: (kernel_size, kernel_size),
            padding: (0, 0),
        }
    }

    /// Creates an `AvgPool2d` with explicit per-axis kernel size, stride, and padding.
    pub fn with_options(
        kernel_size: (usize, usize),
        stride: (usize, usize),
        padding: (usize, usize),
    ) -> Self {
        Self {
            kernel_size,
            stride,
            padding,
        }
    }
}

impl Module for AvgPool2d {
    fn forward(&self, input: &Variable) -> Variable {
        let shape = input.shape();
        let batch = shape[0];
        let channels = shape[1];
        let height = shape[2];
        let width = shape[3];

        let (kh, kw) = self.kernel_size;
        let (sh, sw) = self.stride;
        let (ph, pw) = self.padding;

        let out_h = (height + 2 * ph - kh) / sh + 1;
        let out_w = (width + 2 * pw - kw) / sw + 1;

        let input_vec = input.data().to_vec();
        let mut output_data = vec![0.0f32; batch * channels * out_h * out_w];

        for b in 0..batch {
            for c in 0..channels {
                for oh in 0..out_h {
                    for ow in 0..out_w {
                        let mut sum = 0.0f32;
                        let mut count = 0;

                        for ki in 0..kh {
                            for kj in 0..kw {
                                let ih = oh * sh + ki;
                                let iw = ow * sw + kj;

                                if ih >= ph && ih < height + ph && iw >= pw && iw < width + pw {
                                    let actual_ih = ih - ph;
                                    let actual_iw = iw - pw;
                                    let idx = b * channels * height * width
                                        + c * height * width
                                        + actual_ih * width
                                        + actual_iw;
                                    sum += input_vec[idx];
                                    count += 1;
                                }
                            }
                        }

                        // As in AvgPool1d, padded positions are excluded from the
                        // divisor of the average.
                        let out_idx =
                            b * channels * out_h * out_w + c * out_h * out_w + oh * out_w + ow;
                        output_data[out_idx] = if count > 0 { sum / count as f32 } else { 0.0 };
                    }
                }
            }
        }

        let output = Tensor::from_vec(output_data, &[batch, channels, out_h, out_w]).unwrap();
        Variable::new(output, input.requires_grad())
    }

    fn name(&self) -> &'static str {
        "AvgPool2d"
    }
}

/// Adaptive 2-D average pooling that produces a fixed `output_size` regardless
/// of the input's spatial dimensions.
pub struct AdaptiveAvgPool2d {
    output_size: (usize, usize),
}

impl AdaptiveAvgPool2d {
    /// Creates a pool with the given `(height, width)` output size.
    pub fn new(output_size: (usize, usize)) -> Self {
        Self { output_size }
    }

    /// Convenience constructor for a square `size x size` output.
    pub fn square(size: usize) -> Self {
        Self {
            output_size: (size, size),
        }
    }
}

impl Module for AdaptiveAvgPool2d {
    fn forward(&self, input: &Variable) -> Variable {
        let shape = input.shape();
        let batch = shape[0];
        let channels = shape[1];
        let in_h = shape[2];
        let in_w = shape[3];

        let (out_h, out_w) = self.output_size;
        let input_vec = input.data().to_vec();
        let mut output_data = vec![0.0f32; batch * channels * out_h * out_w];

        for b in 0..batch {
            for c in 0..channels {
                for oh in 0..out_h {
                    for ow in 0..out_w {
                        // Each output cell averages the input region [start, end).
                        // Flooring the start and ceiling the end follows the usual
                        // adaptive-pooling window definition and keeps every window
                        // non-empty even when the output is larger than the input.
                        let ih_start = (oh * in_h) / out_h;
                        let ih_end = ((oh + 1) * in_h).div_ceil(out_h);
                        let iw_start = (ow * in_w) / out_w;
                        let iw_end = ((ow + 1) * in_w).div_ceil(out_w);

                        let mut sum = 0.0f32;
                        let mut count = 0;

                        for ih in ih_start..ih_end {
                            for iw in iw_start..iw_end {
                                let idx =
                                    b * channels * in_h * in_w + c * in_h * in_w + ih * in_w + iw;
                                sum += input_vec[idx];
                                count += 1;
                            }
                        }

                        let out_idx =
                            b * channels * out_h * out_w + c * out_h * out_w + oh * out_w + ow;
                        output_data[out_idx] = if count > 0 { sum / count as f32 } else { 0.0 };
                    }
                }
            }
        }

        let output = Tensor::from_vec(output_data, &[batch, channels, out_h, out_w]).unwrap();
        Variable::new(output, input.requires_grad())
    }

    fn name(&self) -> &'static str {
        "AdaptiveAvgPool2d"
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_maxpool2d() {
        let pool = MaxPool2d::new(2);
        let input = Variable::new(
            Tensor::from_vec(
                vec![
                    1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0,
                    15.0, 16.0,
                ],
                &[1, 1, 4, 4],
            )
            .unwrap(),
            false,
        );
        let output = pool.forward(&input);
        assert_eq!(output.shape(), vec![1, 1, 2, 2]);
        assert_eq!(output.data().to_vec(), vec![6.0, 8.0, 14.0, 16.0]);
    }

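    // The 1-D pools have no coverage above; this sketch exercises MaxPool1d
    // under the same Variable/Tensor API assumptions as the surrounding tests.
    // With kernel_size 2 the stride defaults to 2, so a length-6 input yields
    // (6 - 2) / 2 + 1 = 3 non-overlapping windows.
    #[test]
    fn test_maxpool1d() {
        let pool = MaxPool1d::new(2);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[1, 1, 6]).unwrap(),
            false,
        );
        let output = pool.forward(&input);
        assert_eq!(output.shape(), vec![1, 1, 3]);
        assert_eq!(output.data().to_vec(), vec![2.0, 4.0, 6.0]);
    }
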
    #[test]
    fn test_avgpool2d() {
        let pool = AvgPool2d::new(2);
        let input = Variable::new(
            Tensor::from_vec(
                vec![
                    1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0,
                    15.0, 16.0,
                ],
                &[1, 1, 4, 4],
            )
            .unwrap(),
            false,
        );
        let output = pool.forward(&input);
        assert_eq!(output.shape(), vec![1, 1, 2, 2]);
        assert_eq!(output.data().to_vec(), vec![3.5, 5.5, 11.5, 13.5]);
    }

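    // Companion sketch for AvgPool1d, again assuming the surrounding test API:
    // each non-overlapping window of two elements averages to its midpoint.
    #[test]
    fn test_avgpool1d() {
        let pool = AvgPool1d::new(2);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[1, 1, 6]).unwrap(),
            false,
        );
        let output = pool.forward(&input);
        assert_eq!(output.shape(), vec![1, 1, 3]);
        assert_eq!(output.data().to_vec(), vec![1.5, 3.5, 5.5]);
    }
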
    #[test]
    fn test_adaptive_avgpool2d() {
        let pool = AdaptiveAvgPool2d::new((1, 1));
        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[1, 1, 2, 2]).unwrap(),
            false,
        );
        let output = pool.forward(&input);
        assert_eq!(output.shape(), vec![1, 1, 1, 1]);
        assert_eq!(output.data().to_vec(), vec![2.5]);
    }
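
    // Sketch of the padded AvgPool2d path, under the same test API as above:
    // because padded positions are excluded from the divisor, a 2x2 input
    // pooled with kernel 2, stride 2, and padding 1 sees exactly one real
    // sample per window, so the output reproduces the input values.
    #[test]
    fn test_avgpool2d_padding_excluded_from_divisor() {
        let pool = AvgPool2d::with_options((2, 2), (2, 2), (1, 1));
        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[1, 1, 2, 2]).unwrap(),
            false,
        );
        let output = pool.forward(&input);
        assert_eq!(output.shape(), vec![1, 1, 2, 2]);
        assert_eq!(output.data().to_vec(), vec![1.0, 2.0, 3.0, 4.0]);
    }

    // Sketch of adaptive pooling when the output size does not divide the
    // input size: with floor/ceil window bounds, a 3x3 input mapped to 2x2
    // uses overlapping 2x2 windows, e.g. the top-left cell averages
    // (1 + 2 + 4 + 5) / 4 = 3.0.
    #[test]
    fn test_adaptive_avgpool2d_overlapping_windows() {
        let pool = AdaptiveAvgPool2d::square(2);
        let input = Variable::new(
            Tensor::from_vec(
                vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0],
                &[1, 1, 3, 3],
            )
            .unwrap(),
            false,
        );
        let output = pool.forward(&input);
        assert_eq!(output.shape(), vec![1, 1, 2, 2]);
        assert_eq!(output.data().to_vec(), vec![3.0, 4.0, 6.0, 7.0]);
    }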
}