1use axonml_autograd::Variable;
18use axonml_autograd::functions::{
19 AdaptiveAvgPool2dBackward, AvgPool1dBackward, AvgPool2dBackward, MaxPool1dBackward,
20 MaxPool2dBackward,
21};
22use axonml_autograd::grad_fn::GradFn;
23use axonml_autograd::no_grad::is_grad_enabled;
24use axonml_tensor::Tensor;
25
26use crate::module::Module;
27
/// 1D max pooling over a `[batch, channels, length]` input.
///
/// Reduces the length axis by taking the maximum of each sliding window.
pub struct MaxPool1d {
    /// Width of the pooling window.
    kernel_size: usize,
    /// Step between consecutive window starts.
    stride: usize,
    /// Implicit padding on each end; padded positions never win the max.
    padding: usize,
}
42
43impl MaxPool1d {
44 pub fn new(kernel_size: usize) -> Self {
46 Self {
47 kernel_size,
48 stride: kernel_size,
49 padding: 0,
50 }
51 }
52
53 pub fn with_options(kernel_size: usize, stride: usize, padding: usize) -> Self {
55 Self {
56 kernel_size,
57 stride,
58 padding,
59 }
60 }
61}
62
63impl Module for MaxPool1d {
64 fn forward(&self, input: &Variable) -> Variable {
65 let shape = input.shape();
66 let batch = shape[0];
67 let channels = shape[1];
68 let length = shape[2];
69
70 let out_length = (length + 2 * self.padding - self.kernel_size) / self.stride + 1;
71
72 let input_vec = input.data().to_vec();
73 let mut output_data = vec![f32::NEG_INFINITY; batch * channels * out_length];
74 let mut max_indices = vec![0usize; batch * channels * out_length];
75
76 for b in 0..batch {
77 for c in 0..channels {
78 for ol in 0..out_length {
79 let in_start = ol * self.stride;
80 let mut max_val = f32::NEG_INFINITY;
81 let mut max_idx = 0;
82
83 for k in 0..self.kernel_size {
84 let il = in_start + k;
85 if il >= self.padding && il < length + self.padding {
86 let actual_il = il - self.padding;
87 let idx = b * channels * length + c * length + actual_il;
88 if input_vec[idx] > max_val {
89 max_val = input_vec[idx];
90 max_idx = idx;
91 }
92 }
93 }
94
95 let out_idx = b * channels * out_length + c * out_length + ol;
96 output_data[out_idx] = max_val;
97 max_indices[out_idx] = max_idx;
98 }
99 }
100 }
101
102 let output = Tensor::from_vec(output_data, &[batch, channels, out_length]).unwrap();
103
104 let requires_grad = input.requires_grad() && is_grad_enabled();
105 if requires_grad {
106 let grad_fn = GradFn::new(MaxPool1dBackward::new(
107 input.grad_fn().cloned(),
108 shape,
109 max_indices,
110 ));
111 Variable::from_operation(output, grad_fn, true)
112 } else {
113 Variable::new(output, false)
114 }
115 }
116
117 fn name(&self) -> &'static str {
118 "MaxPool1d"
119 }
120}
121
/// 2D max pooling over a `[batch, channels, height, width]` input.
///
/// All parameters are `(height, width)` pairs.
pub struct MaxPool2d {
    /// Pooling window size as `(kh, kw)`.
    kernel_size: (usize, usize),
    /// Window step as `(sh, sw)`.
    stride: (usize, usize),
    /// Implicit padding as `(ph, pw)`; padded positions never win the max.
    padding: (usize, usize),
}
136
137impl MaxPool2d {
138 pub fn new(kernel_size: usize) -> Self {
140 Self {
141 kernel_size: (kernel_size, kernel_size),
142 stride: (kernel_size, kernel_size),
143 padding: (0, 0),
144 }
145 }
146
147 pub fn with_options(
149 kernel_size: (usize, usize),
150 stride: (usize, usize),
151 padding: (usize, usize),
152 ) -> Self {
153 Self {
154 kernel_size,
155 stride,
156 padding,
157 }
158 }
159}
160
impl Module for MaxPool2d {
    /// Applies 2D max pooling to a `[batch, channels, height, width]` input,
    /// recording the flat input index of every window maximum so the backward
    /// pass can scatter gradients back to the winning elements.
    fn forward(&self, input: &Variable) -> Variable {
        let shape = input.shape();
        let batch = shape[0];
        let channels = shape[1];
        let height = shape[2];
        let width = shape[3];

        let (kh, kw) = self.kernel_size;
        let (sh, sw) = self.stride;
        let (ph, pw) = self.padding;

        // Standard pooling output-size formula (floor division).
        let out_h = (height + 2 * ph - kh) / sh + 1;
        let out_w = (width + 2 * pw - kw) / sw + 1;

        // GPU fast path; falls through to the CPU loop below when the kernel
        // returns None (presumably when the tensor is not device-resident —
        // confirm against maxpool2d_cuda).
        #[cfg(feature = "cuda")]
        {
            if let Some((gpu_output, gpu_indices)) =
                input
                    .data()
                    .maxpool2d_cuda(self.kernel_size, self.stride, self.padding)
            {
                // Widen the kernel's argmax indices to usize for the backward op.
                let max_indices: Vec<usize> = gpu_indices.iter().map(|&i| i as usize).collect();

                let requires_grad = input.requires_grad() && is_grad_enabled();
                if requires_grad {
                    let grad_fn = GradFn::new(MaxPool2dBackward::new(
                        input.grad_fn().cloned(),
                        shape,
                        max_indices,
                        self.kernel_size,
                        self.stride,
                        self.padding,
                    ));
                    return Variable::from_operation(gpu_output, grad_fn, true);
                } else {
                    return Variable::new(gpu_output, false);
                }
            }
        }

        // CPU fallback: scan every output window and keep value + flat index
        // of the maximum. Windows consisting entirely of padding stay at the
        // -inf initialization (index 0).
        let input_vec = input.data().to_vec();
        let mut output_data = vec![f32::NEG_INFINITY; batch * channels * out_h * out_w];
        let mut max_indices = vec![0usize; batch * channels * out_h * out_w];

        for b in 0..batch {
            for c in 0..channels {
                for oh in 0..out_h {
                    for ow in 0..out_w {
                        let mut max_val = f32::NEG_INFINITY;
                        let mut max_idx = 0;

                        for ki in 0..kh {
                            for kj in 0..kw {
                                // Coordinates in the padded input frame.
                                let ih = oh * sh + ki;
                                let iw = ow * sw + kj;

                                // Only real (non-padding) positions can win the max.
                                if ih >= ph && ih < height + ph && iw >= pw && iw < width + pw {
                                    let actual_ih = ih - ph;
                                    let actual_iw = iw - pw;
                                    let idx = b * channels * height * width
                                        + c * height * width
                                        + actual_ih * width
                                        + actual_iw;
                                    if input_vec[idx] > max_val {
                                        max_val = input_vec[idx];
                                        max_idx = idx;
                                    }
                                }
                            }
                        }

                        let out_idx =
                            b * channels * out_h * out_w + c * out_h * out_w + oh * out_w + ow;
                        output_data[out_idx] = max_val;
                        max_indices[out_idx] = max_idx;
                    }
                }
            }
        }

        let output = Tensor::from_vec(output_data, &[batch, channels, out_h, out_w]).unwrap();

        let requires_grad = input.requires_grad() && is_grad_enabled();
        if requires_grad {
            let grad_fn = GradFn::new(MaxPool2dBackward::new(
                input.grad_fn().cloned(),
                shape,
                max_indices,
                self.kernel_size,
                self.stride,
                self.padding,
            ));
            Variable::from_operation(output, grad_fn, true)
        } else {
            Variable::new(output, false)
        }
    }

    fn name(&self) -> &'static str {
        "MaxPool2d"
    }
}
266
/// 1D average pooling over a `[batch, channels, length]` input.
///
/// Each output is the mean of the in-bounds elements of its window
/// (padding is excluded from both the sum and the divisor).
pub struct AvgPool1d {
    /// Width of the pooling window.
    kernel_size: usize,
    /// Step between consecutive window starts.
    stride: usize,
    /// Implicit padding on each end, excluded from the average.
    padding: usize,
}
277
278impl AvgPool1d {
279 pub fn new(kernel_size: usize) -> Self {
281 Self {
282 kernel_size,
283 stride: kernel_size,
284 padding: 0,
285 }
286 }
287
288 pub fn with_options(kernel_size: usize, stride: usize, padding: usize) -> Self {
290 Self {
291 kernel_size,
292 stride,
293 padding,
294 }
295 }
296}
297
298impl Module for AvgPool1d {
299 fn forward(&self, input: &Variable) -> Variable {
300 let shape = input.shape();
301 let batch = shape[0];
302 let channels = shape[1];
303 let length = shape[2];
304
305 let out_length = (length + 2 * self.padding - self.kernel_size) / self.stride + 1;
306
307 let input_vec = input.data().to_vec();
308 let mut output_data = vec![0.0f32; batch * channels * out_length];
309
310 for b in 0..batch {
311 for c in 0..channels {
312 for ol in 0..out_length {
313 let in_start = ol * self.stride;
314 let mut sum = 0.0f32;
315 let mut count = 0;
316
317 for k in 0..self.kernel_size {
318 let il = in_start + k;
319 if il >= self.padding && il < length + self.padding {
320 let actual_il = il - self.padding;
321 let idx = b * channels * length + c * length + actual_il;
322 sum += input_vec[idx];
323 count += 1;
324 }
325 }
326
327 let out_idx = b * channels * out_length + c * out_length + ol;
328 output_data[out_idx] = if count > 0 { sum / count as f32 } else { 0.0 };
329 }
330 }
331 }
332
333 let output = Tensor::from_vec(output_data, &[batch, channels, out_length]).unwrap();
334
335 let requires_grad = input.requires_grad() && is_grad_enabled();
336 if requires_grad {
337 let grad_fn = GradFn::new(AvgPool1dBackward::new(
338 input.grad_fn().cloned(),
339 shape,
340 self.kernel_size,
341 self.stride,
342 self.padding,
343 ));
344 Variable::from_operation(output, grad_fn, true)
345 } else {
346 Variable::new(output, false)
347 }
348 }
349
350 fn name(&self) -> &'static str {
351 "AvgPool1d"
352 }
353}
354
/// 2D average pooling over a `[batch, channels, height, width]` input.
///
/// All parameters are `(height, width)` pairs; padding is excluded from
/// both the sum and the divisor of each window's average.
pub struct AvgPool2d {
    /// Pooling window size as `(kh, kw)`.
    kernel_size: (usize, usize),
    /// Window step as `(sh, sw)`.
    stride: (usize, usize),
    /// Implicit padding as `(ph, pw)`, excluded from the average.
    padding: (usize, usize),
}
365
366impl AvgPool2d {
367 pub fn new(kernel_size: usize) -> Self {
369 Self {
370 kernel_size: (kernel_size, kernel_size),
371 stride: (kernel_size, kernel_size),
372 padding: (0, 0),
373 }
374 }
375
376 pub fn with_options(
378 kernel_size: (usize, usize),
379 stride: (usize, usize),
380 padding: (usize, usize),
381 ) -> Self {
382 Self {
383 kernel_size,
384 stride,
385 padding,
386 }
387 }
388}
389
impl Module for AvgPool2d {
    /// Applies 2D average pooling to a `[batch, channels, height, width]`
    /// input. Each output is the mean over the in-bounds elements of its
    /// window; padding is excluded from both sum and divisor.
    fn forward(&self, input: &Variable) -> Variable {
        let shape = input.shape();
        let batch = shape[0];
        let channels = shape[1];
        let height = shape[2];
        let width = shape[3];

        let (kh, kw) = self.kernel_size;
        let (sh, sw) = self.stride;
        let (ph, pw) = self.padding;

        // Standard pooling output-size formula (floor division).
        let out_h = (height + 2 * ph - kh) / sh + 1;
        let out_w = (width + 2 * pw - kw) / sw + 1;

        // GPU fast path; falls through to the CPU loop below when the kernel
        // returns None.
        #[cfg(feature = "cuda")]
        {
            // NOTE(review): the trailing `false` presumably means
            // count_include_pad = false, matching the CPU path's divisor —
            // confirm against avgpool2d_cuda's signature.
            if let Some(gpu_output) = input.data().avgpool2d_cuda(
                self.kernel_size,
                self.stride,
                self.padding,
                false, ) {
                let requires_grad = input.requires_grad() && is_grad_enabled();
                if requires_grad {
                    let grad_fn = GradFn::new(AvgPool2dBackward::new(
                        input.grad_fn().cloned(),
                        shape,
                        self.kernel_size,
                        self.stride,
                        self.padding,
                    ));
                    return Variable::from_operation(gpu_output, grad_fn, true);
                } else {
                    return Variable::new(gpu_output, false);
                }
            }
        }

        // CPU fallback: sum the valid (non-padding) samples of each window
        // and divide by the number actually summed.
        let input_vec = input.data().to_vec();
        let mut output_data = vec![0.0f32; batch * channels * out_h * out_w];

        for b in 0..batch {
            for c in 0..channels {
                for oh in 0..out_h {
                    for ow in 0..out_w {
                        let mut sum = 0.0f32;
                        let mut count = 0;

                        for ki in 0..kh {
                            for kj in 0..kw {
                                // Coordinates in the padded input frame.
                                let ih = oh * sh + ki;
                                let iw = ow * sw + kj;

                                // Only real (non-padding) positions are averaged.
                                if ih >= ph && ih < height + ph && iw >= pw && iw < width + pw {
                                    let actual_ih = ih - ph;
                                    let actual_iw = iw - pw;
                                    let idx = b * channels * height * width
                                        + c * height * width
                                        + actual_ih * width
                                        + actual_iw;
                                    sum += input_vec[idx];
                                    count += 1;
                                }
                            }
                        }

                        let out_idx =
                            b * channels * out_h * out_w + c * out_h * out_w + oh * out_w + ow;
                        // All-padding windows average to zero.
                        output_data[out_idx] = if count > 0 { sum / count as f32 } else { 0.0 };
                    }
                }
            }
        }

        let output = Tensor::from_vec(output_data, &[batch, channels, out_h, out_w]).unwrap();

        let requires_grad = input.requires_grad() && is_grad_enabled();
        if requires_grad {
            let grad_fn = GradFn::new(AvgPool2dBackward::new(
                input.grad_fn().cloned(),
                shape,
                self.kernel_size,
                self.stride,
                self.padding,
            ));
            Variable::from_operation(output, grad_fn, true)
        } else {
            Variable::new(output, false)
        }
    }

    fn name(&self) -> &'static str {
        "AvgPool2d"
    }
}
488
/// 2D adaptive average pooling: produces a fixed spatial output size
/// regardless of the input's height and width.
pub struct AdaptiveAvgPool2d {
    /// Target `(out_h, out_w)` of the pooled output.
    output_size: (usize, usize),
}
497
498impl AdaptiveAvgPool2d {
499 pub fn new(output_size: (usize, usize)) -> Self {
501 Self { output_size }
502 }
503
504 pub fn square(size: usize) -> Self {
506 Self {
507 output_size: (size, size),
508 }
509 }
510}
511
512impl Module for AdaptiveAvgPool2d {
513 fn forward(&self, input: &Variable) -> Variable {
514 let shape = input.shape();
515 let batch = shape[0];
516 let channels = shape[1];
517 let in_h = shape[2];
518 let in_w = shape[3];
519
520 let (out_h, out_w) = self.output_size;
521 let input_vec = input.data().to_vec();
522 let mut output_data = vec![0.0f32; batch * channels * out_h * out_w];
523
524 for b in 0..batch {
525 for c in 0..channels {
526 for oh in 0..out_h {
527 for ow in 0..out_w {
528 let ih_start = (oh * in_h) / out_h;
529 let ih_end = ((oh + 1) * in_h) / out_h;
530 let iw_start = (ow * in_w) / out_w;
531 let iw_end = ((ow + 1) * in_w) / out_w;
532
533 let mut sum = 0.0f32;
534 let mut count = 0;
535
536 for ih in ih_start..ih_end {
537 for iw in iw_start..iw_end {
538 let idx =
539 b * channels * in_h * in_w + c * in_h * in_w + ih * in_w + iw;
540 sum += input_vec[idx];
541 count += 1;
542 }
543 }
544
545 let out_idx =
546 b * channels * out_h * out_w + c * out_h * out_w + oh * out_w + ow;
547 output_data[out_idx] = if count > 0 { sum / count as f32 } else { 0.0 };
548 }
549 }
550 }
551 }
552
553 let output = Tensor::from_vec(output_data, &[batch, channels, out_h, out_w]).unwrap();
554
555 let requires_grad = input.requires_grad() && is_grad_enabled();
556 if requires_grad {
557 let grad_fn = GradFn::new(AdaptiveAvgPool2dBackward::new(
558 input.grad_fn().cloned(),
559 shape,
560 self.output_size,
561 ));
562 Variable::from_operation(output, grad_fn, true)
563 } else {
564 Variable::new(output, false)
565 }
566 }
567
568 fn name(&self) -> &'static str {
569 "AdaptiveAvgPool2d"
570 }
571}
572
#[cfg(test)]
mod tests {
    use super::*;

    // 2x2/stride-2 max pool over a 4x4 ramp: each quadrant's max is its
    // bottom-right element.
    #[test]
    fn test_maxpool2d() {
        let pool = MaxPool2d::new(2);
        let input = Variable::new(
            Tensor::from_vec(
                vec![
                    1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0,
                    15.0, 16.0,
                ],
                &[1, 1, 4, 4],
            )
            .unwrap(),
            false,
        );
        let output = pool.forward(&input);
        assert_eq!(output.shape(), vec![1, 1, 2, 2]);
        assert_eq!(output.data().to_vec(), vec![6.0, 8.0, 14.0, 16.0]);
    }

    // Gradient of sum(maxpool) is 1.0 at each window's argmax, 0 elsewhere.
    #[test]
    fn test_maxpool2d_backward() {
        let pool = MaxPool2d::new(2);
        let input = Variable::new(
            Tensor::from_vec(
                vec![
                    1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0,
                    15.0, 16.0,
                ],
                &[1, 1, 4, 4],
            )
            .unwrap(),
            true,
        );
        let output = pool.forward(&input);
        let loss = output.sum();
        loss.backward();

        assert!(input.grad().is_some(), "MaxPool2d: gradient should flow");
        let grad = input.grad().unwrap();
        assert_eq!(grad.shape(), &[1, 1, 4, 4]);
        let grad_vec = grad.to_vec();
        // Flat indices 5, 7, 13, 15 are the four window maxima (6, 8, 14, 16).
        assert_eq!(grad_vec[5], 1.0);
        assert_eq!(grad_vec[7], 1.0);
        assert_eq!(grad_vec[13], 1.0);
        assert_eq!(grad_vec[15], 1.0);
        assert_eq!(grad_vec[0], 0.0);
    }

    // 2x2/stride-2 average pool over the same ramp: quadrant means.
    #[test]
    fn test_avgpool2d() {
        let pool = AvgPool2d::new(2);
        let input = Variable::new(
            Tensor::from_vec(
                vec![
                    1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0,
                    15.0, 16.0,
                ],
                &[1, 1, 4, 4],
            )
            .unwrap(),
            false,
        );
        let output = pool.forward(&input);
        assert_eq!(output.shape(), vec![1, 1, 2, 2]);
        assert_eq!(output.data().to_vec(), vec![3.5, 5.5, 11.5, 13.5]);
    }

    // Gradient of sum(avgpool) is 1/window_area = 0.25 everywhere (2x2 window).
    #[test]
    fn test_avgpool2d_backward() {
        let pool = AvgPool2d::new(2);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 16], &[1, 1, 4, 4]).unwrap(),
            true,
        );
        let output = pool.forward(&input);
        let loss = output.sum();
        loss.backward();

        assert!(input.grad().is_some(), "AvgPool2d: gradient should flow");
        let grad = input.grad().unwrap();
        for &v in &grad.to_vec() {
            assert!((v - 0.25).abs() < 1e-6);
        }
    }

    // Global average pool (output 1x1) = mean of all four elements.
    #[test]
    fn test_adaptive_avgpool2d() {
        let pool = AdaptiveAvgPool2d::new((1, 1));
        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[1, 1, 2, 2]).unwrap(),
            false,
        );
        let output = pool.forward(&input);
        assert_eq!(output.shape(), vec![1, 1, 1, 1]);
        assert_eq!(output.data().to_vec(), vec![2.5]);
    }

    // Gradient of the global mean is 1/4 = 0.25 per input element.
    #[test]
    fn test_adaptive_avgpool2d_backward() {
        let pool = AdaptiveAvgPool2d::new((1, 1));
        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[1, 1, 2, 2]).unwrap(),
            true,
        );
        let output = pool.forward(&input);
        let loss = output.sum();
        loss.backward();

        assert!(
            input.grad().is_some(),
            "AdaptiveAvgPool2d: gradient should flow"
        );
        let grad = input.grad().unwrap();
        for &v in &grad.to_vec() {
            assert!((v - 0.25).abs() < 1e-6);
        }
    }
}