1use axonml_autograd::Variable;
18use axonml_autograd::functions::{
19 AdaptiveAvgPool2dBackward, AvgPool1dBackward, AvgPool2dBackward, MaxPool1dBackward,
20 MaxPool2dBackward,
21};
22use axonml_autograd::grad_fn::GradFn;
23use axonml_autograd::no_grad::is_grad_enabled;
24use axonml_tensor::Tensor;
25
26use crate::module::Module;
27
/// 1-D max pooling layer for `[batch, channels, length]` inputs.
pub struct MaxPool1d {
    /// Window length of each pooling region.
    kernel_size: usize,
    /// Step between the starts of consecutive windows.
    stride: usize,
    /// Virtual padding added on each side; padded positions are skipped when
    /// taking the max (see `forward`), so they never contribute a value.
    padding: usize,
}
42
43impl MaxPool1d {
44 pub fn new(kernel_size: usize) -> Self {
46 Self {
47 kernel_size,
48 stride: kernel_size,
49 padding: 0,
50 }
51 }
52
53 pub fn with_options(kernel_size: usize, stride: usize, padding: usize) -> Self {
55 Self {
56 kernel_size,
57 stride,
58 padding,
59 }
60 }
61}
62
63impl Module for MaxPool1d {
64 fn forward(&self, input: &Variable) -> Variable {
65 let shape = input.shape();
66 let batch = shape[0];
67 let channels = shape[1];
68 let length = shape[2];
69
70 let out_length = (length + 2 * self.padding - self.kernel_size) / self.stride + 1;
71
72 let input_vec = input.data().to_vec();
73 let mut output_data = vec![f32::NEG_INFINITY; batch * channels * out_length];
74 let mut max_indices = vec![0usize; batch * channels * out_length];
75
76 for b in 0..batch {
77 for c in 0..channels {
78 for ol in 0..out_length {
79 let in_start = ol * self.stride;
80 let mut max_val = f32::NEG_INFINITY;
81 let mut max_idx = 0;
82
83 for k in 0..self.kernel_size {
84 let il = in_start + k;
85 if il >= self.padding && il < length + self.padding {
86 let actual_il = il - self.padding;
87 let idx = b * channels * length + c * length + actual_il;
88 if input_vec[idx] > max_val {
89 max_val = input_vec[idx];
90 max_idx = idx;
91 }
92 }
93 }
94
95 let out_idx = b * channels * out_length + c * out_length + ol;
96 output_data[out_idx] = max_val;
97 max_indices[out_idx] = max_idx;
98 }
99 }
100 }
101
102 let output = Tensor::from_vec(output_data, &[batch, channels, out_length])
103 .expect("tensor creation failed");
104
105 let requires_grad = input.requires_grad() && is_grad_enabled();
106 if requires_grad {
107 let grad_fn = GradFn::new(MaxPool1dBackward::new(
108 input.grad_fn().cloned(),
109 shape,
110 max_indices,
111 ));
112 Variable::from_operation(output, grad_fn, true)
113 } else {
114 Variable::new(output, false)
115 }
116 }
117
118 fn name(&self) -> &'static str {
119 "MaxPool1d"
120 }
121}
122
/// 2-D max pooling layer for `[batch, channels, height, width]` inputs.
pub struct MaxPool2d {
    /// Window size as `(height, width)`.
    kernel_size: (usize, usize),
    /// Step between window starts as `(height, width)`.
    stride: (usize, usize),
    /// Virtual padding per side as `(height, width)`; padded positions are
    /// skipped when taking the max (see `forward`).
    padding: (usize, usize),
}
137
138impl MaxPool2d {
139 pub fn new(kernel_size: usize) -> Self {
141 Self {
142 kernel_size: (kernel_size, kernel_size),
143 stride: (kernel_size, kernel_size),
144 padding: (0, 0),
145 }
146 }
147
148 pub fn with_options(
150 kernel_size: (usize, usize),
151 stride: (usize, usize),
152 padding: (usize, usize),
153 ) -> Self {
154 Self {
155 kernel_size,
156 stride,
157 padding,
158 }
159 }
160}
161
impl Module for MaxPool2d {
    /// Applies 2-D max pooling to an input of shape
    /// `[batch, channels, height, width]`, producing
    /// `[batch, channels, out_h, out_w]`.
    ///
    /// Records the flat input index of each window's maximum so the backward
    /// pass can route gradients only to the winning elements.
    fn forward(&self, input: &Variable) -> Variable {
        let shape = input.shape();
        let batch = shape[0];
        let channels = shape[1];
        let height = shape[2];
        let width = shape[3];

        let (kh, kw) = self.kernel_size;
        let (sh, sw) = self.stride;
        let (ph, pw) = self.padding;

        // Standard pooling output-size formula (floor division).
        // NOTE(review): underflows in usize when the padded extent is smaller
        // than the kernel — presumably callers guarantee valid sizes; confirm.
        let out_h = (height + 2 * ph - kh) / sh + 1;
        let out_w = (width + 2 * pw - kw) / sw + 1;

        // Fast path: delegate to the CUDA kernel when it succeeds; it returns
        // both the pooled tensor and per-output argmax indices. On None we
        // fall through to the CPU implementation below.
        #[cfg(feature = "cuda")]
        {
            if let Some((gpu_output, gpu_indices)) =
                input
                    .data()
                    .maxpool2d_cuda(self.kernel_size, self.stride, self.padding)
            {
                // Widen the GPU index type to usize for the backward node.
                let max_indices: Vec<usize> = gpu_indices.iter().map(|&i| i as usize).collect();

                let requires_grad = input.requires_grad() && is_grad_enabled();
                if requires_grad {
                    let grad_fn = GradFn::new(MaxPool2dBackward::new(
                        input.grad_fn().cloned(),
                        shape,
                        max_indices,
                        self.kernel_size,
                        self.stride,
                        self.padding,
                    ));
                    return Variable::from_operation(gpu_output, grad_fn, true);
                } else {
                    return Variable::new(gpu_output, false);
                }
            }
        }

        // CPU fallback: naive scan over every output window.
        let input_vec = input.data().to_vec();
        let mut output_data = vec![f32::NEG_INFINITY; batch * channels * out_h * out_w];
        let mut max_indices = vec![0usize; batch * channels * out_h * out_w];

        for b in 0..batch {
            for c in 0..channels {
                for oh in 0..out_h {
                    for ow in 0..out_w {
                        let mut max_val = f32::NEG_INFINITY;
                        let mut max_idx = 0;

                        for ki in 0..kh {
                            for kj in 0..kw {
                                // Window coordinates in the (virtually) padded input.
                                let ih = oh * sh + ki;
                                let iw = ow * sw + kj;

                                // Skip padded positions, so padding never
                                // contributes a value to the max.
                                if ih >= ph && ih < height + ph && iw >= pw && iw < width + pw {
                                    let actual_ih = ih - ph;
                                    let actual_iw = iw - pw;
                                    // Flat NCHW index into the input buffer.
                                    let idx = b * channels * height * width
                                        + c * height * width
                                        + actual_ih * width
                                        + actual_iw;
                                    if input_vec[idx] > max_val {
                                        max_val = input_vec[idx];
                                        max_idx = idx;
                                    }
                                }
                            }
                        }

                        // NOTE(review): a window entirely inside padding would
                        // emit -inf with max_idx defaulting to 0.
                        let out_idx =
                            b * channels * out_h * out_w + c * out_h * out_w + oh * out_w + ow;
                        output_data[out_idx] = max_val;
                        max_indices[out_idx] = max_idx;
                    }
                }
            }
        }

        let output = Tensor::from_vec(output_data, &[batch, channels, out_h, out_w])
            .expect("tensor creation failed");

        // Attach a backward node only when autograd is active for this input.
        let requires_grad = input.requires_grad() && is_grad_enabled();
        if requires_grad {
            let grad_fn = GradFn::new(MaxPool2dBackward::new(
                input.grad_fn().cloned(),
                shape,
                max_indices,
                self.kernel_size,
                self.stride,
                self.padding,
            ));
            Variable::from_operation(output, grad_fn, true)
        } else {
            Variable::new(output, false)
        }
    }

    fn name(&self) -> &'static str {
        "MaxPool2d"
    }
}
268
/// 1-D average pooling layer for `[batch, channels, length]` inputs.
pub struct AvgPool1d {
    /// Window length of each pooling region.
    kernel_size: usize,
    /// Step between the starts of consecutive windows.
    stride: usize,
    /// Virtual padding added on each side; padded positions are excluded from
    /// both the sum and the divisor (see `forward`).
    padding: usize,
}
279
280impl AvgPool1d {
281 pub fn new(kernel_size: usize) -> Self {
283 Self {
284 kernel_size,
285 stride: kernel_size,
286 padding: 0,
287 }
288 }
289
290 pub fn with_options(kernel_size: usize, stride: usize, padding: usize) -> Self {
292 Self {
293 kernel_size,
294 stride,
295 padding,
296 }
297 }
298}
299
300impl Module for AvgPool1d {
301 fn forward(&self, input: &Variable) -> Variable {
302 let shape = input.shape();
303 let batch = shape[0];
304 let channels = shape[1];
305 let length = shape[2];
306
307 let out_length = (length + 2 * self.padding - self.kernel_size) / self.stride + 1;
308
309 let input_vec = input.data().to_vec();
310 let mut output_data = vec![0.0f32; batch * channels * out_length];
311
312 for b in 0..batch {
313 for c in 0..channels {
314 for ol in 0..out_length {
315 let in_start = ol * self.stride;
316 let mut sum = 0.0f32;
317 let mut count = 0;
318
319 for k in 0..self.kernel_size {
320 let il = in_start + k;
321 if il >= self.padding && il < length + self.padding {
322 let actual_il = il - self.padding;
323 let idx = b * channels * length + c * length + actual_il;
324 sum += input_vec[idx];
325 count += 1;
326 }
327 }
328
329 let out_idx = b * channels * out_length + c * out_length + ol;
330 output_data[out_idx] = if count > 0 { sum / count as f32 } else { 0.0 };
331 }
332 }
333 }
334
335 let output = Tensor::from_vec(output_data, &[batch, channels, out_length])
336 .expect("tensor creation failed");
337
338 let requires_grad = input.requires_grad() && is_grad_enabled();
339 if requires_grad {
340 let grad_fn = GradFn::new(AvgPool1dBackward::new(
341 input.grad_fn().cloned(),
342 shape,
343 self.kernel_size,
344 self.stride,
345 self.padding,
346 ));
347 Variable::from_operation(output, grad_fn, true)
348 } else {
349 Variable::new(output, false)
350 }
351 }
352
353 fn name(&self) -> &'static str {
354 "AvgPool1d"
355 }
356}
357
/// 2-D average pooling layer for `[batch, channels, height, width]` inputs.
pub struct AvgPool2d {
    /// Window size as `(height, width)`.
    kernel_size: (usize, usize),
    /// Step between window starts as `(height, width)`.
    stride: (usize, usize),
    /// Virtual padding per side as `(height, width)`; padded positions are
    /// excluded from both the sum and the divisor (see `forward`).
    padding: (usize, usize),
}
368
369impl AvgPool2d {
370 pub fn new(kernel_size: usize) -> Self {
372 Self {
373 kernel_size: (kernel_size, kernel_size),
374 stride: (kernel_size, kernel_size),
375 padding: (0, 0),
376 }
377 }
378
379 pub fn with_options(
381 kernel_size: (usize, usize),
382 stride: (usize, usize),
383 padding: (usize, usize),
384 ) -> Self {
385 Self {
386 kernel_size,
387 stride,
388 padding,
389 }
390 }
391}
392
impl Module for AvgPool2d {
    /// Applies 2-D average pooling to an input of shape
    /// `[batch, channels, height, width]`, producing
    /// `[batch, channels, out_h, out_w]`.
    ///
    /// Padded positions are excluded from each window's average: the divisor
    /// is the count of in-bounds cells, not `kh * kw`.
    fn forward(&self, input: &Variable) -> Variable {
        let shape = input.shape();
        let batch = shape[0];
        let channels = shape[1];
        let height = shape[2];
        let width = shape[3];

        let (kh, kw) = self.kernel_size;
        let (sh, sw) = self.stride;
        let (ph, pw) = self.padding;

        // Standard pooling output-size formula (floor division).
        let out_h = (height + 2 * ph - kh) / sh + 1;
        let out_w = (width + 2 * pw - kw) / sw + 1;

        // Fast path: delegate to the CUDA kernel when it succeeds; on None we
        // fall through to the CPU implementation below.
        #[cfg(feature = "cuda")]
        {
            if let Some(gpu_output) = input.data().avgpool2d_cuda(
                self.kernel_size,
                self.stride,
                self.padding,
                // NOTE(review): presumably a count_include_pad flag; `false`
                // would match the CPU path (padding excluded) — confirm against
                // the avgpool2d_cuda signature.
                false,
            ) {
                let requires_grad = input.requires_grad() && is_grad_enabled();
                if requires_grad {
                    let grad_fn = GradFn::new(AvgPool2dBackward::new(
                        input.grad_fn().cloned(),
                        shape,
                        self.kernel_size,
                        self.stride,
                        self.padding,
                    ));
                    return Variable::from_operation(gpu_output, grad_fn, true);
                } else {
                    return Variable::new(gpu_output, false);
                }
            }
        }

        // CPU fallback: naive scan over every output window.
        let input_vec = input.data().to_vec();
        let mut output_data = vec![0.0f32; batch * channels * out_h * out_w];

        for b in 0..batch {
            for c in 0..channels {
                for oh in 0..out_h {
                    for ow in 0..out_w {
                        let mut sum = 0.0f32;
                        let mut count = 0;

                        for ki in 0..kh {
                            for kj in 0..kw {
                                // Window coordinates in the (virtually) padded input.
                                let ih = oh * sh + ki;
                                let iw = ow * sw + kj;

                                // Skip padded positions; they contribute to
                                // neither the sum nor the divisor.
                                if ih >= ph && ih < height + ph && iw >= pw && iw < width + pw {
                                    let actual_ih = ih - ph;
                                    let actual_iw = iw - pw;
                                    // Flat NCHW index into the input buffer.
                                    let idx = b * channels * height * width
                                        + c * height * width
                                        + actual_ih * width
                                        + actual_iw;
                                    sum += input_vec[idx];
                                    count += 1;
                                }
                            }
                        }

                        // Guard against an all-padding window (divisor of zero).
                        let out_idx =
                            b * channels * out_h * out_w + c * out_h * out_w + oh * out_w + ow;
                        output_data[out_idx] = if count > 0 { sum / count as f32 } else { 0.0 };
                    }
                }
            }
        }

        let output = Tensor::from_vec(output_data, &[batch, channels, out_h, out_w])
            .expect("tensor creation failed");

        // Attach a backward node only when autograd is active for this input.
        let requires_grad = input.requires_grad() && is_grad_enabled();
        if requires_grad {
            let grad_fn = GradFn::new(AvgPool2dBackward::new(
                input.grad_fn().cloned(),
                shape,
                self.kernel_size,
                self.stride,
                self.padding,
            ));
            Variable::from_operation(output, grad_fn, true)
        } else {
            Variable::new(output, false)
        }
    }

    fn name(&self) -> &'static str {
        "AvgPool2d"
    }
}
492
/// Adaptive 2-D average pooling: reduces any `[batch, channels, H, W]` input
/// to a fixed spatial size by averaging over adaptively-sized bins.
pub struct AdaptiveAvgPool2d {
    /// Target spatial size as `(out_h, out_w)`.
    output_size: (usize, usize),
}
501
502impl AdaptiveAvgPool2d {
503 pub fn new(output_size: (usize, usize)) -> Self {
505 Self { output_size }
506 }
507
508 pub fn square(size: usize) -> Self {
510 Self {
511 output_size: (size, size),
512 }
513 }
514}
515
516impl Module for AdaptiveAvgPool2d {
517 fn forward(&self, input: &Variable) -> Variable {
518 let shape = input.shape();
519 let batch = shape[0];
520 let channels = shape[1];
521 let in_h = shape[2];
522 let in_w = shape[3];
523
524 let (out_h, out_w) = self.output_size;
525 let input_vec = input.data().to_vec();
526 let mut output_data = vec![0.0f32; batch * channels * out_h * out_w];
527
528 for b in 0..batch {
529 for c in 0..channels {
530 for oh in 0..out_h {
531 for ow in 0..out_w {
532 let ih_start = (oh * in_h) / out_h;
533 let ih_end = ((oh + 1) * in_h) / out_h;
534 let iw_start = (ow * in_w) / out_w;
535 let iw_end = ((ow + 1) * in_w) / out_w;
536
537 let mut sum = 0.0f32;
538 let mut count = 0;
539
540 for ih in ih_start..ih_end {
541 for iw in iw_start..iw_end {
542 let idx =
543 b * channels * in_h * in_w + c * in_h * in_w + ih * in_w + iw;
544 sum += input_vec[idx];
545 count += 1;
546 }
547 }
548
549 let out_idx =
550 b * channels * out_h * out_w + c * out_h * out_w + oh * out_w + ow;
551 output_data[out_idx] = if count > 0 { sum / count as f32 } else { 0.0 };
552 }
553 }
554 }
555 }
556
557 let output = Tensor::from_vec(output_data, &[batch, channels, out_h, out_w])
558 .expect("tensor creation failed");
559
560 let requires_grad = input.requires_grad() && is_grad_enabled();
561 if requires_grad {
562 let grad_fn = GradFn::new(AdaptiveAvgPool2dBackward::new(
563 input.grad_fn().cloned(),
564 shape,
565 self.output_size,
566 ));
567 Variable::from_operation(output, grad_fn, true)
568 } else {
569 Variable::new(output, false)
570 }
571 }
572
573 fn name(&self) -> &'static str {
574 "AdaptiveAvgPool2d"
575 }
576}
577
#[cfg(test)]
mod tests {
    use super::*;

    // 2x2 max pooling on a 4x4 ramp: each window's max is its bottom-right
    // element (6, 8, 14, 16).
    #[test]
    fn test_maxpool2d() {
        let pool = MaxPool2d::new(2);
        let input = Variable::new(
            Tensor::from_vec(
                vec![
                    1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0,
                    15.0, 16.0,
                ],
                &[1, 1, 4, 4],
            )
            .unwrap(),
            false,
        );
        let output = pool.forward(&input);
        assert_eq!(output.shape(), vec![1, 1, 2, 2]);
        assert_eq!(output.data().to_vec(), vec![6.0, 8.0, 14.0, 16.0]);
    }

    // sum(maxpool(x)).backward(): gradient of 1.0 lands only on each window's
    // argmax (flat indices 5, 7, 13, 15); non-winning elements get 0.
    #[test]
    fn test_maxpool2d_backward() {
        let pool = MaxPool2d::new(2);
        let input = Variable::new(
            Tensor::from_vec(
                vec![
                    1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0,
                    15.0, 16.0,
                ],
                &[1, 1, 4, 4],
            )
            .unwrap(),
            true,
        );
        let output = pool.forward(&input);
        let loss = output.sum();
        loss.backward();

        assert!(input.grad().is_some(), "MaxPool2d: gradient should flow");
        let grad = input.grad().unwrap();
        assert_eq!(grad.shape(), &[1, 1, 4, 4]);
        let grad_vec = grad.to_vec();
        assert_eq!(grad_vec[5], 1.0);
        assert_eq!(grad_vec[7], 1.0);
        assert_eq!(grad_vec[13], 1.0);
        assert_eq!(grad_vec[15], 1.0);
        assert_eq!(grad_vec[0], 0.0);
    }

    // 2x2 average pooling on the same ramp: each window averages to the mean
    // of its four elements (3.5, 5.5, 11.5, 13.5).
    #[test]
    fn test_avgpool2d() {
        let pool = AvgPool2d::new(2);
        let input = Variable::new(
            Tensor::from_vec(
                vec![
                    1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0,
                    15.0, 16.0,
                ],
                &[1, 1, 4, 4],
            )
            .unwrap(),
            false,
        );
        let output = pool.forward(&input);
        assert_eq!(output.shape(), vec![1, 1, 2, 2]);
        assert_eq!(output.data().to_vec(), vec![3.5, 5.5, 11.5, 13.5]);
    }

    // sum(avgpool(x)).backward(): every input element receives 1/(2*2) = 0.25,
    // since each contributes equally to exactly one window's average.
    #[test]
    fn test_avgpool2d_backward() {
        let pool = AvgPool2d::new(2);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 16], &[1, 1, 4, 4]).expect("tensor creation failed"),
            true,
        );
        let output = pool.forward(&input);
        let loss = output.sum();
        loss.backward();

        assert!(input.grad().is_some(), "AvgPool2d: gradient should flow");
        let grad = input.grad().unwrap();
        for &v in &grad.to_vec() {
            assert!((v - 0.25).abs() < 1e-6);
        }
    }

    // Adaptive pooling to (1, 1) is a global average: mean of 1..4 is 2.5.
    #[test]
    fn test_adaptive_avgpool2d() {
        let pool = AdaptiveAvgPool2d::new((1, 1));
        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[1, 1, 2, 2])
                .expect("tensor creation failed"),
            false,
        );
        let output = pool.forward(&input);
        assert_eq!(output.shape(), vec![1, 1, 1, 1]);
        assert_eq!(output.data().to_vec(), vec![2.5]);
    }

    // Backward through the global average: each of the 4 inputs gets 1/4.
    #[test]
    fn test_adaptive_avgpool2d_backward() {
        let pool = AdaptiveAvgPool2d::new((1, 1));
        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[1, 1, 2, 2])
                .expect("tensor creation failed"),
            true,
        );
        let output = pool.forward(&input);
        let loss = output.sum();
        loss.backward();

        assert!(
            input.grad().is_some(),
            "AdaptiveAvgPool2d: gradient should flow"
        );
        let grad = input.grad().unwrap();
        for &v in &grad.to_vec() {
            assert!((v - 0.25).abs() < 1e-6);
        }
    }
}