1use axonml_autograd::Variable;
24use axonml_autograd::functions::{
25 AdaptiveAvgPool2dBackward, AvgPool1dBackward, AvgPool2dBackward, MaxPool1dBackward,
26 MaxPool2dBackward,
27};
28use axonml_autograd::grad_fn::GradFn;
29use axonml_autograd::no_grad::is_grad_enabled;
30use axonml_tensor::Tensor;
31
32use crate::module::Module;
33
/// 1D max pooling layer over the last dimension of a
/// `[batch, channels, length]` input.
pub struct MaxPool1d {
    // Width of the pooling window.
    kernel_size: usize,
    // Step between successive windows; `new` defaults this to `kernel_size`.
    stride: usize,
    // Implicit padding on each end; padded positions act as -inf in the max.
    padding: usize,
}
48
49impl MaxPool1d {
50 pub fn new(kernel_size: usize) -> Self {
52 Self {
53 kernel_size,
54 stride: kernel_size,
55 padding: 0,
56 }
57 }
58
59 pub fn with_options(kernel_size: usize, stride: usize, padding: usize) -> Self {
61 Self {
62 kernel_size,
63 stride,
64 padding,
65 }
66 }
67}
68
69impl Module for MaxPool1d {
70 fn forward(&self, input: &Variable) -> Variable {
71 let shape = input.shape();
72 let batch = shape[0];
73 let channels = shape[1];
74 let length = shape[2];
75
76 let out_length = (length + 2 * self.padding - self.kernel_size) / self.stride + 1;
77
78 let input_vec = input.data().to_vec();
79 let mut output_data = vec![f32::NEG_INFINITY; batch * channels * out_length];
80 let mut max_indices = vec![0usize; batch * channels * out_length];
81
82 for b in 0..batch {
83 for c in 0..channels {
84 for ol in 0..out_length {
85 let in_start = ol * self.stride;
86 let mut max_val = f32::NEG_INFINITY;
87 let mut max_idx = 0;
88
89 for k in 0..self.kernel_size {
90 let il = in_start + k;
91 if il >= self.padding && il < length + self.padding {
92 let actual_il = il - self.padding;
93 let idx = b * channels * length + c * length + actual_il;
94 if input_vec[idx] > max_val {
95 max_val = input_vec[idx];
96 max_idx = idx;
97 }
98 }
99 }
100
101 let out_idx = b * channels * out_length + c * out_length + ol;
102 output_data[out_idx] = max_val;
103 max_indices[out_idx] = max_idx;
104 }
105 }
106 }
107
108 let output = Tensor::from_vec(output_data, &[batch, channels, out_length])
109 .expect("tensor creation failed");
110
111 let requires_grad = input.requires_grad() && is_grad_enabled();
112 if requires_grad {
113 let grad_fn = GradFn::new(MaxPool1dBackward::new(
114 input.grad_fn().cloned(),
115 shape,
116 max_indices,
117 ));
118 Variable::from_operation(output, grad_fn, true)
119 } else {
120 Variable::new(output, false)
121 }
122 }
123
124 fn name(&self) -> &'static str {
125 "MaxPool1d"
126 }
127}
128
/// 2D max pooling layer over the spatial dimensions of a
/// `[batch, channels, height, width]` input.
pub struct MaxPool2d {
    // Pooling window as (height, width).
    kernel_size: (usize, usize),
    // Step between windows as (height, width); `new` defaults to the kernel size.
    stride: (usize, usize),
    // Implicit padding as (height, width); padded positions act as -inf.
    padding: (usize, usize),
}
143
144impl MaxPool2d {
145 pub fn new(kernel_size: usize) -> Self {
147 Self {
148 kernel_size: (kernel_size, kernel_size),
149 stride: (kernel_size, kernel_size),
150 padding: (0, 0),
151 }
152 }
153
154 pub fn with_options(
156 kernel_size: (usize, usize),
157 stride: (usize, usize),
158 padding: (usize, usize),
159 ) -> Self {
160 Self {
161 kernel_size,
162 stride,
163 padding,
164 }
165 }
166}
167
168impl Module for MaxPool2d {
169 fn forward(&self, input: &Variable) -> Variable {
170 let shape = input.shape();
171 let batch = shape[0];
172 let channels = shape[1];
173 let height = shape[2];
174 let width = shape[3];
175
176 let (kh, kw) = self.kernel_size;
177 let (sh, sw) = self.stride;
178 let (ph, pw) = self.padding;
179
180 let out_h = (height + 2 * ph - kh) / sh + 1;
181 let out_w = (width + 2 * pw - kw) / sw + 1;
182
183 #[cfg(feature = "cuda")]
185 {
186 if let Some((gpu_output, gpu_indices)) =
187 input
188 .data()
189 .maxpool2d_cuda(self.kernel_size, self.stride, self.padding)
190 {
191 let max_indices: Vec<usize> = gpu_indices.iter().map(|&i| i as usize).collect();
192
193 let requires_grad = input.requires_grad() && is_grad_enabled();
194 if requires_grad {
195 let grad_fn = GradFn::new(MaxPool2dBackward::new(
196 input.grad_fn().cloned(),
197 shape,
198 max_indices,
199 self.kernel_size,
200 self.stride,
201 self.padding,
202 ));
203 return Variable::from_operation(gpu_output, grad_fn, true);
204 } else {
205 return Variable::new(gpu_output, false);
206 }
207 }
208 }
209
210 let input_vec = input.data().to_vec();
212 let mut output_data = vec![f32::NEG_INFINITY; batch * channels * out_h * out_w];
213 let mut max_indices = vec![0usize; batch * channels * out_h * out_w];
214
215 for b in 0..batch {
216 for c in 0..channels {
217 for oh in 0..out_h {
218 for ow in 0..out_w {
219 let mut max_val = f32::NEG_INFINITY;
220 let mut max_idx = 0;
221
222 for ki in 0..kh {
223 for kj in 0..kw {
224 let ih = oh * sh + ki;
225 let iw = ow * sw + kj;
226
227 if ih >= ph && ih < height + ph && iw >= pw && iw < width + pw {
228 let actual_ih = ih - ph;
229 let actual_iw = iw - pw;
230 let idx = b * channels * height * width
231 + c * height * width
232 + actual_ih * width
233 + actual_iw;
234 if input_vec[idx] > max_val {
235 max_val = input_vec[idx];
236 max_idx = idx;
237 }
238 }
239 }
240 }
241
242 let out_idx =
243 b * channels * out_h * out_w + c * out_h * out_w + oh * out_w + ow;
244 output_data[out_idx] = max_val;
245 max_indices[out_idx] = max_idx;
246 }
247 }
248 }
249 }
250
251 let output = Tensor::from_vec(output_data, &[batch, channels, out_h, out_w])
252 .expect("tensor creation failed");
253
254 let requires_grad = input.requires_grad() && is_grad_enabled();
255 if requires_grad {
256 let grad_fn = GradFn::new(MaxPool2dBackward::new(
257 input.grad_fn().cloned(),
258 shape,
259 max_indices,
260 self.kernel_size,
261 self.stride,
262 self.padding,
263 ));
264 Variable::from_operation(output, grad_fn, true)
265 } else {
266 Variable::new(output, false)
267 }
268 }
269
270 fn name(&self) -> &'static str {
271 "MaxPool2d"
272 }
273}
274
/// 1D average pooling layer over the last dimension of a
/// `[batch, channels, length]` input. Padded positions are excluded from
/// both the sum and the divisor.
pub struct AvgPool1d {
    // Width of the pooling window.
    kernel_size: usize,
    // Step between successive windows; `new` defaults this to `kernel_size`.
    stride: usize,
    // Implicit padding on each end.
    padding: usize,
}
285
286impl AvgPool1d {
287 pub fn new(kernel_size: usize) -> Self {
289 Self {
290 kernel_size,
291 stride: kernel_size,
292 padding: 0,
293 }
294 }
295
296 pub fn with_options(kernel_size: usize, stride: usize, padding: usize) -> Self {
298 Self {
299 kernel_size,
300 stride,
301 padding,
302 }
303 }
304}
305
306impl Module for AvgPool1d {
307 fn forward(&self, input: &Variable) -> Variable {
308 let shape = input.shape();
309 let batch = shape[0];
310 let channels = shape[1];
311 let length = shape[2];
312
313 let out_length = (length + 2 * self.padding - self.kernel_size) / self.stride + 1;
314
315 let input_vec = input.data().to_vec();
316 let mut output_data = vec![0.0f32; batch * channels * out_length];
317
318 for b in 0..batch {
319 for c in 0..channels {
320 for ol in 0..out_length {
321 let in_start = ol * self.stride;
322 let mut sum = 0.0f32;
323 let mut count = 0;
324
325 for k in 0..self.kernel_size {
326 let il = in_start + k;
327 if il >= self.padding && il < length + self.padding {
328 let actual_il = il - self.padding;
329 let idx = b * channels * length + c * length + actual_il;
330 sum += input_vec[idx];
331 count += 1;
332 }
333 }
334
335 let out_idx = b * channels * out_length + c * out_length + ol;
336 output_data[out_idx] = if count > 0 { sum / count as f32 } else { 0.0 };
337 }
338 }
339 }
340
341 let output = Tensor::from_vec(output_data, &[batch, channels, out_length])
342 .expect("tensor creation failed");
343
344 let requires_grad = input.requires_grad() && is_grad_enabled();
345 if requires_grad {
346 let grad_fn = GradFn::new(AvgPool1dBackward::new(
347 input.grad_fn().cloned(),
348 shape,
349 self.kernel_size,
350 self.stride,
351 self.padding,
352 ));
353 Variable::from_operation(output, grad_fn, true)
354 } else {
355 Variable::new(output, false)
356 }
357 }
358
359 fn name(&self) -> &'static str {
360 "AvgPool1d"
361 }
362}
363
/// 2D average pooling layer over the spatial dimensions of a
/// `[batch, channels, height, width]` input. Padded positions are excluded
/// from both the sum and the divisor.
pub struct AvgPool2d {
    // Pooling window as (height, width).
    kernel_size: (usize, usize),
    // Step between windows as (height, width); `new` defaults to the kernel size.
    stride: (usize, usize),
    // Implicit padding as (height, width).
    padding: (usize, usize),
}
374
375impl AvgPool2d {
376 pub fn new(kernel_size: usize) -> Self {
378 Self {
379 kernel_size: (kernel_size, kernel_size),
380 stride: (kernel_size, kernel_size),
381 padding: (0, 0),
382 }
383 }
384
385 pub fn with_options(
387 kernel_size: (usize, usize),
388 stride: (usize, usize),
389 padding: (usize, usize),
390 ) -> Self {
391 Self {
392 kernel_size,
393 stride,
394 padding,
395 }
396 }
397}
398
399impl Module for AvgPool2d {
400 fn forward(&self, input: &Variable) -> Variable {
401 let shape = input.shape();
402 let batch = shape[0];
403 let channels = shape[1];
404 let height = shape[2];
405 let width = shape[3];
406
407 let (kh, kw) = self.kernel_size;
408 let (sh, sw) = self.stride;
409 let (ph, pw) = self.padding;
410
411 let out_h = (height + 2 * ph - kh) / sh + 1;
412 let out_w = (width + 2 * pw - kw) / sw + 1;
413
414 #[cfg(feature = "cuda")]
416 {
417 if let Some(gpu_output) = input.data().avgpool2d_cuda(
418 self.kernel_size,
419 self.stride,
420 self.padding,
421 false, ) {
423 let requires_grad = input.requires_grad() && is_grad_enabled();
424 if requires_grad {
425 let grad_fn = GradFn::new(AvgPool2dBackward::new(
426 input.grad_fn().cloned(),
427 shape,
428 self.kernel_size,
429 self.stride,
430 self.padding,
431 ));
432 return Variable::from_operation(gpu_output, grad_fn, true);
433 } else {
434 return Variable::new(gpu_output, false);
435 }
436 }
437 }
438
439 let input_vec = input.data().to_vec();
441 let mut output_data = vec![0.0f32; batch * channels * out_h * out_w];
442
443 for b in 0..batch {
444 for c in 0..channels {
445 for oh in 0..out_h {
446 for ow in 0..out_w {
447 let mut sum = 0.0f32;
448 let mut count = 0;
449
450 for ki in 0..kh {
451 for kj in 0..kw {
452 let ih = oh * sh + ki;
453 let iw = ow * sw + kj;
454
455 if ih >= ph && ih < height + ph && iw >= pw && iw < width + pw {
456 let actual_ih = ih - ph;
457 let actual_iw = iw - pw;
458 let idx = b * channels * height * width
459 + c * height * width
460 + actual_ih * width
461 + actual_iw;
462 sum += input_vec[idx];
463 count += 1;
464 }
465 }
466 }
467
468 let out_idx =
469 b * channels * out_h * out_w + c * out_h * out_w + oh * out_w + ow;
470 output_data[out_idx] = if count > 0 { sum / count as f32 } else { 0.0 };
471 }
472 }
473 }
474 }
475
476 let output = Tensor::from_vec(output_data, &[batch, channels, out_h, out_w])
477 .expect("tensor creation failed");
478
479 let requires_grad = input.requires_grad() && is_grad_enabled();
480 if requires_grad {
481 let grad_fn = GradFn::new(AvgPool2dBackward::new(
482 input.grad_fn().cloned(),
483 shape,
484 self.kernel_size,
485 self.stride,
486 self.padding,
487 ));
488 Variable::from_operation(output, grad_fn, true)
489 } else {
490 Variable::new(output, false)
491 }
492 }
493
494 fn name(&self) -> &'static str {
495 "AvgPool2d"
496 }
497}
498
/// Adaptive 2D average pooling: averages variable-sized windows so that the
/// spatial output is always `output_size`, regardless of input size.
pub struct AdaptiveAvgPool2d {
    // Target spatial output as (height, width).
    output_size: (usize, usize),
}
507
508impl AdaptiveAvgPool2d {
509 pub fn new(output_size: (usize, usize)) -> Self {
511 Self { output_size }
512 }
513
514 pub fn square(size: usize) -> Self {
516 Self {
517 output_size: (size, size),
518 }
519 }
520}
521
522impl Module for AdaptiveAvgPool2d {
523 fn forward(&self, input: &Variable) -> Variable {
524 let shape = input.shape();
525 let batch = shape[0];
526 let channels = shape[1];
527 let in_h = shape[2];
528 let in_w = shape[3];
529
530 let (out_h, out_w) = self.output_size;
531 let input_vec = input.data().to_vec();
532 let mut output_data = vec![0.0f32; batch * channels * out_h * out_w];
533
534 for b in 0..batch {
535 for c in 0..channels {
536 for oh in 0..out_h {
537 for ow in 0..out_w {
538 let ih_start = (oh * in_h) / out_h;
539 let ih_end = ((oh + 1) * in_h) / out_h;
540 let iw_start = (ow * in_w) / out_w;
541 let iw_end = ((ow + 1) * in_w) / out_w;
542
543 let mut sum = 0.0f32;
544 let mut count = 0;
545
546 for ih in ih_start..ih_end {
547 for iw in iw_start..iw_end {
548 let idx =
549 b * channels * in_h * in_w + c * in_h * in_w + ih * in_w + iw;
550 sum += input_vec[idx];
551 count += 1;
552 }
553 }
554
555 let out_idx =
556 b * channels * out_h * out_w + c * out_h * out_w + oh * out_w + ow;
557 output_data[out_idx] = if count > 0 { sum / count as f32 } else { 0.0 };
558 }
559 }
560 }
561 }
562
563 let output = Tensor::from_vec(output_data, &[batch, channels, out_h, out_w])
564 .expect("tensor creation failed");
565
566 let requires_grad = input.requires_grad() && is_grad_enabled();
567 if requires_grad {
568 let grad_fn = GradFn::new(AdaptiveAvgPool2dBackward::new(
569 input.grad_fn().cloned(),
570 shape,
571 self.output_size,
572 ));
573 Variable::from_operation(output, grad_fn, true)
574 } else {
575 Variable::new(output, false)
576 }
577 }
578
579 fn name(&self) -> &'static str {
580 "AdaptiveAvgPool2d"
581 }
582}
583
#[cfg(test)]
mod tests {
    use super::*;

    /// 1D max pooling with the default stride (= kernel size): windows
    /// [1,2] and [3,4] reduce to their maxima.
    #[test]
    fn test_maxpool1d() {
        let pool = MaxPool1d::new(2);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[1, 1, 4]).unwrap(),
            false,
        );
        let output = pool.forward(&input);
        assert_eq!(output.shape(), vec![1, 1, 2]);
        assert_eq!(output.data().to_vec(), vec![2.0, 4.0]);
    }

    /// 1D average pooling with the default stride (= kernel size): windows
    /// [1,2] and [3,4] reduce to their means.
    #[test]
    fn test_avgpool1d() {
        let pool = AvgPool1d::new(2);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[1, 1, 4]).unwrap(),
            false,
        );
        let output = pool.forward(&input);
        assert_eq!(output.shape(), vec![1, 1, 2]);
        assert_eq!(output.data().to_vec(), vec![1.5, 3.5]);
    }

    /// 2x2 max pooling on a 4x4 ramp picks the bottom-right of each window.
    #[test]
    fn test_maxpool2d() {
        let pool = MaxPool2d::new(2);
        let input = Variable::new(
            Tensor::from_vec(
                vec![
                    1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0,
                    15.0, 16.0,
                ],
                &[1, 1, 4, 4],
            )
            .unwrap(),
            false,
        );
        let output = pool.forward(&input);
        assert_eq!(output.shape(), vec![1, 1, 2, 2]);
        assert_eq!(output.data().to_vec(), vec![6.0, 8.0, 14.0, 16.0]);
    }

    /// Gradient flows only to each window's argmax (flat indices 5, 7, 13, 15).
    #[test]
    fn test_maxpool2d_backward() {
        let pool = MaxPool2d::new(2);
        let input = Variable::new(
            Tensor::from_vec(
                vec![
                    1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0,
                    15.0, 16.0,
                ],
                &[1, 1, 4, 4],
            )
            .unwrap(),
            true,
        );
        let output = pool.forward(&input);
        let loss = output.sum();
        loss.backward();

        assert!(input.grad().is_some(), "MaxPool2d: gradient should flow");
        let grad = input.grad().unwrap();
        assert_eq!(grad.shape(), &[1, 1, 4, 4]);
        let grad_vec = grad.to_vec();
        assert_eq!(grad_vec[5], 1.0);
        assert_eq!(grad_vec[7], 1.0);
        assert_eq!(grad_vec[13], 1.0);
        assert_eq!(grad_vec[15], 1.0);
        assert_eq!(grad_vec[0], 0.0);
    }

    /// 2x2 average pooling on a 4x4 ramp yields each window's mean.
    #[test]
    fn test_avgpool2d() {
        let pool = AvgPool2d::new(2);
        let input = Variable::new(
            Tensor::from_vec(
                vec![
                    1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0,
                    15.0, 16.0,
                ],
                &[1, 1, 4, 4],
            )
            .unwrap(),
            false,
        );
        let output = pool.forward(&input);
        assert_eq!(output.shape(), vec![1, 1, 2, 2]);
        assert_eq!(output.data().to_vec(), vec![3.5, 5.5, 11.5, 13.5]);
    }

    /// Average pooling spreads a unit gradient evenly: 1/4 per window element.
    #[test]
    fn test_avgpool2d_backward() {
        let pool = AvgPool2d::new(2);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 16], &[1, 1, 4, 4]).expect("tensor creation failed"),
            true,
        );
        let output = pool.forward(&input);
        let loss = output.sum();
        loss.backward();

        assert!(input.grad().is_some(), "AvgPool2d: gradient should flow");
        let grad = input.grad().unwrap();
        for &v in &grad.to_vec() {
            assert!((v - 0.25).abs() < 1e-6);
        }
    }

    /// Global (1x1) adaptive pooling equals the mean of all elements.
    #[test]
    fn test_adaptive_avgpool2d() {
        let pool = AdaptiveAvgPool2d::new((1, 1));
        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[1, 1, 2, 2])
                .expect("tensor creation failed"),
            false,
        );
        let output = pool.forward(&input);
        assert_eq!(output.shape(), vec![1, 1, 1, 1]);
        assert_eq!(output.data().to_vec(), vec![2.5]);
    }

    /// Global adaptive pooling spreads a unit gradient evenly: 1/4 per element.
    #[test]
    fn test_adaptive_avgpool2d_backward() {
        let pool = AdaptiveAvgPool2d::new((1, 1));
        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[1, 1, 2, 2])
                .expect("tensor creation failed"),
            true,
        );
        let output = pool.forward(&input);
        let loss = output.sum();
        loss.backward();

        assert!(
            input.grad().is_some(),
            "AdaptiveAvgPool2d: gradient should flow"
        );
        let grad = input.grad().unwrap();
        for &v in &grad.to_vec() {
            assert!((v - 0.25).abs() < 1e-6);
        }
    }
}