Skip to main content

axonml_nn/layers/
pooling.rs

//! Pooling Layers - Max and Average Pooling
//!
//! # File
//! `crates/axonml-nn/src/layers/pooling.rs`
//!
//! # Author
//! Andrew Jewell Sr - AutomataNexus
//!
//! # Updated
//! March 8, 2026
//!
//! # Disclaimer
//! Use at own risk. This software is provided "as is", without warranty of any
//! kind, express or implied. The author and AutomataNexus shall not be held
//! liable for any damages arising from the use of this software.
16
17use axonml_autograd::Variable;
18use axonml_autograd::functions::{
19    AdaptiveAvgPool2dBackward, AvgPool1dBackward, AvgPool2dBackward, MaxPool1dBackward,
20    MaxPool2dBackward,
21};
22use axonml_autograd::grad_fn::GradFn;
23use axonml_autograd::no_grad::is_grad_enabled;
24use axonml_tensor::Tensor;
25
26use crate::module::Module;
27
28// =============================================================================
29// MaxPool1d
30// =============================================================================
31
32/// Applies max pooling over a 1D signal.
33///
34/// # Shape
35/// - Input: (N, C, L)
36/// - Output: (N, C, L_out)
37pub struct MaxPool1d {
38    kernel_size: usize,
39    stride: usize,
40    padding: usize,
41}
42
43impl MaxPool1d {
44    /// Creates a new MaxPool1d layer.
45    pub fn new(kernel_size: usize) -> Self {
46        Self {
47            kernel_size,
48            stride: kernel_size,
49            padding: 0,
50        }
51    }
52
53    /// Creates a MaxPool1d with custom stride and padding.
54    pub fn with_options(kernel_size: usize, stride: usize, padding: usize) -> Self {
55        Self {
56            kernel_size,
57            stride,
58            padding,
59        }
60    }
61}
62
63impl Module for MaxPool1d {
64    fn forward(&self, input: &Variable) -> Variable {
65        let shape = input.shape();
66        let batch = shape[0];
67        let channels = shape[1];
68        let length = shape[2];
69
70        let out_length = (length + 2 * self.padding - self.kernel_size) / self.stride + 1;
71
72        let input_vec = input.data().to_vec();
73        let mut output_data = vec![f32::NEG_INFINITY; batch * channels * out_length];
74        let mut max_indices = vec![0usize; batch * channels * out_length];
75
76        for b in 0..batch {
77            for c in 0..channels {
78                for ol in 0..out_length {
79                    let in_start = ol * self.stride;
80                    let mut max_val = f32::NEG_INFINITY;
81                    let mut max_idx = 0;
82
83                    for k in 0..self.kernel_size {
84                        let il = in_start + k;
85                        if il >= self.padding && il < length + self.padding {
86                            let actual_il = il - self.padding;
87                            let idx = b * channels * length + c * length + actual_il;
88                            if input_vec[idx] > max_val {
89                                max_val = input_vec[idx];
90                                max_idx = idx;
91                            }
92                        }
93                    }
94
95                    let out_idx = b * channels * out_length + c * out_length + ol;
96                    output_data[out_idx] = max_val;
97                    max_indices[out_idx] = max_idx;
98                }
99            }
100        }
101
102        let output = Tensor::from_vec(output_data, &[batch, channels, out_length]).unwrap();
103
104        let requires_grad = input.requires_grad() && is_grad_enabled();
105        if requires_grad {
106            let grad_fn = GradFn::new(MaxPool1dBackward::new(
107                input.grad_fn().cloned(),
108                shape,
109                max_indices,
110            ));
111            Variable::from_operation(output, grad_fn, true)
112        } else {
113            Variable::new(output, false)
114        }
115    }
116
117    fn name(&self) -> &'static str {
118        "MaxPool1d"
119    }
120}
121
122// =============================================================================
123// MaxPool2d
124// =============================================================================
125
126/// Applies max pooling over a 2D signal (image).
127///
128/// # Shape
129/// - Input: (N, C, H, W)
130/// - Output: (N, C, H_out, W_out)
131pub struct MaxPool2d {
132    kernel_size: (usize, usize),
133    stride: (usize, usize),
134    padding: (usize, usize),
135}
136
137impl MaxPool2d {
138    /// Creates a new MaxPool2d layer with square kernel.
139    pub fn new(kernel_size: usize) -> Self {
140        Self {
141            kernel_size: (kernel_size, kernel_size),
142            stride: (kernel_size, kernel_size),
143            padding: (0, 0),
144        }
145    }
146
147    /// Creates a MaxPool2d with all options.
148    pub fn with_options(
149        kernel_size: (usize, usize),
150        stride: (usize, usize),
151        padding: (usize, usize),
152    ) -> Self {
153        Self {
154            kernel_size,
155            stride,
156            padding,
157        }
158    }
159}
160
161impl Module for MaxPool2d {
162    fn forward(&self, input: &Variable) -> Variable {
163        let shape = input.shape();
164        let batch = shape[0];
165        let channels = shape[1];
166        let height = shape[2];
167        let width = shape[3];
168
169        let (kh, kw) = self.kernel_size;
170        let (sh, sw) = self.stride;
171        let (ph, pw) = self.padding;
172
173        let out_h = (height + 2 * ph - kh) / sh + 1;
174        let out_w = (width + 2 * pw - kw) / sw + 1;
175
176        // Try GPU path
177        #[cfg(feature = "cuda")]
178        {
179            if let Some((gpu_output, gpu_indices)) =
180                input
181                    .data()
182                    .maxpool2d_cuda(self.kernel_size, self.stride, self.padding)
183            {
184                let max_indices: Vec<usize> = gpu_indices.iter().map(|&i| i as usize).collect();
185
186                let requires_grad = input.requires_grad() && is_grad_enabled();
187                if requires_grad {
188                    let grad_fn = GradFn::new(MaxPool2dBackward::new(
189                        input.grad_fn().cloned(),
190                        shape,
191                        max_indices,
192                        self.kernel_size,
193                        self.stride,
194                        self.padding,
195                    ));
196                    return Variable::from_operation(gpu_output, grad_fn, true);
197                } else {
198                    return Variable::new(gpu_output, false);
199                }
200            }
201        }
202
203        // CPU path
204        let input_vec = input.data().to_vec();
205        let mut output_data = vec![f32::NEG_INFINITY; batch * channels * out_h * out_w];
206        let mut max_indices = vec![0usize; batch * channels * out_h * out_w];
207
208        for b in 0..batch {
209            for c in 0..channels {
210                for oh in 0..out_h {
211                    for ow in 0..out_w {
212                        let mut max_val = f32::NEG_INFINITY;
213                        let mut max_idx = 0;
214
215                        for ki in 0..kh {
216                            for kj in 0..kw {
217                                let ih = oh * sh + ki;
218                                let iw = ow * sw + kj;
219
220                                if ih >= ph && ih < height + ph && iw >= pw && iw < width + pw {
221                                    let actual_ih = ih - ph;
222                                    let actual_iw = iw - pw;
223                                    let idx = b * channels * height * width
224                                        + c * height * width
225                                        + actual_ih * width
226                                        + actual_iw;
227                                    if input_vec[idx] > max_val {
228                                        max_val = input_vec[idx];
229                                        max_idx = idx;
230                                    }
231                                }
232                            }
233                        }
234
235                        let out_idx =
236                            b * channels * out_h * out_w + c * out_h * out_w + oh * out_w + ow;
237                        output_data[out_idx] = max_val;
238                        max_indices[out_idx] = max_idx;
239                    }
240                }
241            }
242        }
243
244        let output = Tensor::from_vec(output_data, &[batch, channels, out_h, out_w]).unwrap();
245
246        let requires_grad = input.requires_grad() && is_grad_enabled();
247        if requires_grad {
248            let grad_fn = GradFn::new(MaxPool2dBackward::new(
249                input.grad_fn().cloned(),
250                shape,
251                max_indices,
252                self.kernel_size,
253                self.stride,
254                self.padding,
255            ));
256            Variable::from_operation(output, grad_fn, true)
257        } else {
258            Variable::new(output, false)
259        }
260    }
261
262    fn name(&self) -> &'static str {
263        "MaxPool2d"
264    }
265}
266
267// =============================================================================
268// AvgPool1d
269// =============================================================================
270
271/// Applies average pooling over a 1D signal.
272pub struct AvgPool1d {
273    kernel_size: usize,
274    stride: usize,
275    padding: usize,
276}
277
278impl AvgPool1d {
279    /// Creates a new AvgPool1d layer.
280    pub fn new(kernel_size: usize) -> Self {
281        Self {
282            kernel_size,
283            stride: kernel_size,
284            padding: 0,
285        }
286    }
287
288    /// Creates an AvgPool1d with custom stride and padding.
289    pub fn with_options(kernel_size: usize, stride: usize, padding: usize) -> Self {
290        Self {
291            kernel_size,
292            stride,
293            padding,
294        }
295    }
296}
297
298impl Module for AvgPool1d {
299    fn forward(&self, input: &Variable) -> Variable {
300        let shape = input.shape();
301        let batch = shape[0];
302        let channels = shape[1];
303        let length = shape[2];
304
305        let out_length = (length + 2 * self.padding - self.kernel_size) / self.stride + 1;
306
307        let input_vec = input.data().to_vec();
308        let mut output_data = vec![0.0f32; batch * channels * out_length];
309
310        for b in 0..batch {
311            for c in 0..channels {
312                for ol in 0..out_length {
313                    let in_start = ol * self.stride;
314                    let mut sum = 0.0f32;
315                    let mut count = 0;
316
317                    for k in 0..self.kernel_size {
318                        let il = in_start + k;
319                        if il >= self.padding && il < length + self.padding {
320                            let actual_il = il - self.padding;
321                            let idx = b * channels * length + c * length + actual_il;
322                            sum += input_vec[idx];
323                            count += 1;
324                        }
325                    }
326
327                    let out_idx = b * channels * out_length + c * out_length + ol;
328                    output_data[out_idx] = if count > 0 { sum / count as f32 } else { 0.0 };
329                }
330            }
331        }
332
333        let output = Tensor::from_vec(output_data, &[batch, channels, out_length]).unwrap();
334
335        let requires_grad = input.requires_grad() && is_grad_enabled();
336        if requires_grad {
337            let grad_fn = GradFn::new(AvgPool1dBackward::new(
338                input.grad_fn().cloned(),
339                shape,
340                self.kernel_size,
341                self.stride,
342                self.padding,
343            ));
344            Variable::from_operation(output, grad_fn, true)
345        } else {
346            Variable::new(output, false)
347        }
348    }
349
350    fn name(&self) -> &'static str {
351        "AvgPool1d"
352    }
353}
354
355// =============================================================================
356// AvgPool2d
357// =============================================================================
358
359/// Applies average pooling over a 2D signal (image).
360pub struct AvgPool2d {
361    kernel_size: (usize, usize),
362    stride: (usize, usize),
363    padding: (usize, usize),
364}
365
366impl AvgPool2d {
367    /// Creates a new AvgPool2d layer with square kernel.
368    pub fn new(kernel_size: usize) -> Self {
369        Self {
370            kernel_size: (kernel_size, kernel_size),
371            stride: (kernel_size, kernel_size),
372            padding: (0, 0),
373        }
374    }
375
376    /// Creates an AvgPool2d with all options.
377    pub fn with_options(
378        kernel_size: (usize, usize),
379        stride: (usize, usize),
380        padding: (usize, usize),
381    ) -> Self {
382        Self {
383            kernel_size,
384            stride,
385            padding,
386        }
387    }
388}
389
390impl Module for AvgPool2d {
391    fn forward(&self, input: &Variable) -> Variable {
392        let shape = input.shape();
393        let batch = shape[0];
394        let channels = shape[1];
395        let height = shape[2];
396        let width = shape[3];
397
398        let (kh, kw) = self.kernel_size;
399        let (sh, sw) = self.stride;
400        let (ph, pw) = self.padding;
401
402        let out_h = (height + 2 * ph - kh) / sh + 1;
403        let out_w = (width + 2 * pw - kw) / sw + 1;
404
405        // Try GPU path
406        #[cfg(feature = "cuda")]
407        {
408            if let Some(gpu_output) = input.data().avgpool2d_cuda(
409                self.kernel_size,
410                self.stride,
411                self.padding,
412                false, // count_include_pad=false matches CPU behavior
413            ) {
414                let requires_grad = input.requires_grad() && is_grad_enabled();
415                if requires_grad {
416                    let grad_fn = GradFn::new(AvgPool2dBackward::new(
417                        input.grad_fn().cloned(),
418                        shape,
419                        self.kernel_size,
420                        self.stride,
421                        self.padding,
422                    ));
423                    return Variable::from_operation(gpu_output, grad_fn, true);
424                } else {
425                    return Variable::new(gpu_output, false);
426                }
427            }
428        }
429
430        // CPU path
431        let input_vec = input.data().to_vec();
432        let mut output_data = vec![0.0f32; batch * channels * out_h * out_w];
433
434        for b in 0..batch {
435            for c in 0..channels {
436                for oh in 0..out_h {
437                    for ow in 0..out_w {
438                        let mut sum = 0.0f32;
439                        let mut count = 0;
440
441                        for ki in 0..kh {
442                            for kj in 0..kw {
443                                let ih = oh * sh + ki;
444                                let iw = ow * sw + kj;
445
446                                if ih >= ph && ih < height + ph && iw >= pw && iw < width + pw {
447                                    let actual_ih = ih - ph;
448                                    let actual_iw = iw - pw;
449                                    let idx = b * channels * height * width
450                                        + c * height * width
451                                        + actual_ih * width
452                                        + actual_iw;
453                                    sum += input_vec[idx];
454                                    count += 1;
455                                }
456                            }
457                        }
458
459                        let out_idx =
460                            b * channels * out_h * out_w + c * out_h * out_w + oh * out_w + ow;
461                        output_data[out_idx] = if count > 0 { sum / count as f32 } else { 0.0 };
462                    }
463                }
464            }
465        }
466
467        let output = Tensor::from_vec(output_data, &[batch, channels, out_h, out_w]).unwrap();
468
469        let requires_grad = input.requires_grad() && is_grad_enabled();
470        if requires_grad {
471            let grad_fn = GradFn::new(AvgPool2dBackward::new(
472                input.grad_fn().cloned(),
473                shape,
474                self.kernel_size,
475                self.stride,
476                self.padding,
477            ));
478            Variable::from_operation(output, grad_fn, true)
479        } else {
480            Variable::new(output, false)
481        }
482    }
483
484    fn name(&self) -> &'static str {
485        "AvgPool2d"
486    }
487}
488
489// =============================================================================
490// AdaptiveAvgPool2d
491// =============================================================================
492
493/// Applies adaptive average pooling to produce specified output size.
494pub struct AdaptiveAvgPool2d {
495    output_size: (usize, usize),
496}
497
498impl AdaptiveAvgPool2d {
499    /// Creates a new AdaptiveAvgPool2d.
500    pub fn new(output_size: (usize, usize)) -> Self {
501        Self { output_size }
502    }
503
504    /// Creates an AdaptiveAvgPool2d with square output.
505    pub fn square(size: usize) -> Self {
506        Self {
507            output_size: (size, size),
508        }
509    }
510}
511
512impl Module for AdaptiveAvgPool2d {
513    fn forward(&self, input: &Variable) -> Variable {
514        let shape = input.shape();
515        let batch = shape[0];
516        let channels = shape[1];
517        let in_h = shape[2];
518        let in_w = shape[3];
519
520        let (out_h, out_w) = self.output_size;
521        let input_vec = input.data().to_vec();
522        let mut output_data = vec![0.0f32; batch * channels * out_h * out_w];
523
524        for b in 0..batch {
525            for c in 0..channels {
526                for oh in 0..out_h {
527                    for ow in 0..out_w {
528                        let ih_start = (oh * in_h) / out_h;
529                        let ih_end = ((oh + 1) * in_h) / out_h;
530                        let iw_start = (ow * in_w) / out_w;
531                        let iw_end = ((ow + 1) * in_w) / out_w;
532
533                        let mut sum = 0.0f32;
534                        let mut count = 0;
535
536                        for ih in ih_start..ih_end {
537                            for iw in iw_start..iw_end {
538                                let idx =
539                                    b * channels * in_h * in_w + c * in_h * in_w + ih * in_w + iw;
540                                sum += input_vec[idx];
541                                count += 1;
542                            }
543                        }
544
545                        let out_idx =
546                            b * channels * out_h * out_w + c * out_h * out_w + oh * out_w + ow;
547                        output_data[out_idx] = if count > 0 { sum / count as f32 } else { 0.0 };
548                    }
549                }
550            }
551        }
552
553        let output = Tensor::from_vec(output_data, &[batch, channels, out_h, out_w]).unwrap();
554
555        let requires_grad = input.requires_grad() && is_grad_enabled();
556        if requires_grad {
557            let grad_fn = GradFn::new(AdaptiveAvgPool2dBackward::new(
558                input.grad_fn().cloned(),
559                shape,
560                self.output_size,
561            ));
562            Variable::from_operation(output, grad_fn, true)
563        } else {
564            Variable::new(output, false)
565        }
566    }
567
568    fn name(&self) -> &'static str {
569        "AdaptiveAvgPool2d"
570    }
571}
572
// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    // -------------------------------------------------------------------------
    // MaxPool1d
    // -------------------------------------------------------------------------

    #[test]
    fn test_maxpool1d() {
        let pool = MaxPool1d::new(2);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 3.0, 2.0, 5.0, 4.0, 0.0], &[1, 1, 6]).unwrap(),
            false,
        );
        let output = pool.forward(&input);
        assert_eq!(output.shape(), vec![1, 1, 3]);
        assert_eq!(output.data().to_vec(), vec![3.0, 5.0, 4.0]);
    }

    #[test]
    fn test_maxpool1d_padding_is_not_a_value() {
        // All-negative input: if padding were treated as 0.0 it would win the max.
        let pool = MaxPool1d::with_options(3, 1, 1);
        let input = Variable::new(
            Tensor::from_vec(vec![-1.0, -5.0, -2.0], &[1, 1, 3]).unwrap(),
            false,
        );
        let output = pool.forward(&input);
        assert_eq!(output.shape(), vec![1, 1, 3]);
        assert_eq!(output.data().to_vec(), vec![-1.0, -1.0, -2.0]);
    }

    #[test]
    #[should_panic]
    fn test_maxpool1d_kernel_larger_than_input_panics() {
        let pool = MaxPool1d::new(4);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0], &[1, 1, 2]).unwrap(),
            false,
        );
        let _ = pool.forward(&input);
    }

    // -------------------------------------------------------------------------
    // MaxPool2d
    // -------------------------------------------------------------------------

    #[test]
    fn test_maxpool2d() {
        let pool = MaxPool2d::new(2);
        let input = Variable::new(
            Tensor::from_vec(
                vec![
                    1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0,
                    15.0, 16.0,
                ],
                &[1, 1, 4, 4],
            )
            .unwrap(),
            false,
        );
        let output = pool.forward(&input);
        assert_eq!(output.shape(), vec![1, 1, 2, 2]);
        assert_eq!(output.data().to_vec(), vec![6.0, 8.0, 14.0, 16.0]);
    }

    #[test]
    fn test_maxpool2d_channels_are_independent() {
        let pool = MaxPool2d::new(2);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 8.0, 7.0, 6.0, 5.0], &[1, 2, 2, 2])
                .unwrap(),
            false,
        );
        let output = pool.forward(&input);
        assert_eq!(output.shape(), vec![1, 2, 1, 1]);
        assert_eq!(output.data().to_vec(), vec![4.0, 8.0]);
    }

    #[test]
    fn test_maxpool2d_padding_is_not_a_value() {
        // Each padded window contains exactly one real (negative) cell.
        let pool = MaxPool2d::with_options((2, 2), (2, 2), (1, 1));
        let input = Variable::new(
            Tensor::from_vec(vec![-1.0, -2.0, -3.0, -4.0], &[1, 1, 2, 2]).unwrap(),
            false,
        );
        let output = pool.forward(&input);
        assert_eq!(output.shape(), vec![1, 1, 2, 2]);
        assert_eq!(output.data().to_vec(), vec![-1.0, -2.0, -3.0, -4.0]);
    }

    #[test]
    fn test_maxpool2d_backward() {
        let pool = MaxPool2d::new(2);
        let input = Variable::new(
            Tensor::from_vec(
                vec![
                    1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0,
                    15.0, 16.0,
                ],
                &[1, 1, 4, 4],
            )
            .unwrap(),
            true,
        );
        let output = pool.forward(&input);
        let loss = output.sum();
        loss.backward();

        assert!(input.grad().is_some(), "MaxPool2d: gradient should flow");
        let grad = input.grad().unwrap();
        assert_eq!(grad.shape(), &[1, 1, 4, 4]);
        let grad_vec = grad.to_vec();
        // Only max positions (6,8,14,16) at indices [5,7,13,15] should have gradient
        assert_eq!(grad_vec[5], 1.0);
        assert_eq!(grad_vec[7], 1.0);
        assert_eq!(grad_vec[13], 1.0);
        assert_eq!(grad_vec[15], 1.0);
        assert_eq!(grad_vec[0], 0.0);
    }

    // -------------------------------------------------------------------------
    // AvgPool1d
    // -------------------------------------------------------------------------

    #[test]
    fn test_avgpool1d() {
        let pool = AvgPool1d::new(2);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[1, 1, 6]).unwrap(),
            false,
        );
        let output = pool.forward(&input);
        assert_eq!(output.shape(), vec![1, 1, 3]);
        assert_eq!(output.data().to_vec(), vec![1.5, 3.5, 5.5]);
    }

    #[test]
    fn test_avgpool1d_padding_not_counted() {
        // Edge windows average only their in-bounds elements.
        let pool = AvgPool1d::with_options(3, 1, 1);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0], &[1, 1, 3]).unwrap(),
            false,
        );
        let output = pool.forward(&input);
        assert_eq!(output.shape(), vec![1, 1, 3]);
        assert_eq!(output.data().to_vec(), vec![1.5, 2.0, 2.5]);
    }

    #[test]
    #[should_panic]
    fn test_avgpool1d_zero_stride_panics() {
        let pool = AvgPool1d::with_options(2, 0, 0);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[1, 1, 4]).unwrap(),
            false,
        );
        let _ = pool.forward(&input);
    }

    // -------------------------------------------------------------------------
    // AvgPool2d
    // -------------------------------------------------------------------------

    #[test]
    fn test_avgpool2d() {
        let pool = AvgPool2d::new(2);
        let input = Variable::new(
            Tensor::from_vec(
                vec![
                    1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0,
                    15.0, 16.0,
                ],
                &[1, 1, 4, 4],
            )
            .unwrap(),
            false,
        );
        let output = pool.forward(&input);
        assert_eq!(output.shape(), vec![1, 1, 2, 2]);
        assert_eq!(output.data().to_vec(), vec![3.5, 5.5, 11.5, 13.5]);
    }

    #[test]
    fn test_avgpool2d_padding_not_counted() {
        // Each padded window holds exactly one real cell, so it averages to that cell.
        let pool = AvgPool2d::with_options((2, 2), (2, 2), (1, 1));
        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[1, 1, 2, 2]).unwrap(),
            false,
        );
        let output = pool.forward(&input);
        assert_eq!(output.shape(), vec![1, 1, 2, 2]);
        assert_eq!(output.data().to_vec(), vec![1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_avgpool2d_backward() {
        let pool = AvgPool2d::new(2);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 16], &[1, 1, 4, 4]).unwrap(),
            true,
        );
        let output = pool.forward(&input);
        let loss = output.sum();
        loss.backward();

        assert!(input.grad().is_some(), "AvgPool2d: gradient should flow");
        let grad = input.grad().unwrap();
        // Each input element contributes to exactly one pool window, gets 1/4 of the gradient
        for &v in &grad.to_vec() {
            assert!((v - 0.25).abs() < 1e-6);
        }
    }

    // -------------------------------------------------------------------------
    // AdaptiveAvgPool2d
    // -------------------------------------------------------------------------

    #[test]
    fn test_adaptive_avgpool2d() {
        let pool = AdaptiveAvgPool2d::new((1, 1));
        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[1, 1, 2, 2]).unwrap(),
            false,
        );
        let output = pool.forward(&input);
        assert_eq!(output.shape(), vec![1, 1, 1, 1]);
        assert_eq!(output.data().to_vec(), vec![2.5]);
    }

    #[test]
    fn test_adaptive_avgpool2d_same_size_is_identity() {
        let pool = AdaptiveAvgPool2d::square(2);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[1, 1, 2, 2]).unwrap(),
            false,
        );
        let output = pool.forward(&input);
        assert_eq!(output.shape(), vec![1, 1, 2, 2]);
        assert_eq!(output.data().to_vec(), vec![1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_adaptive_avgpool2d_backward() {
        let pool = AdaptiveAvgPool2d::new((1, 1));
        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[1, 1, 2, 2]).unwrap(),
            true,
        );
        let output = pool.forward(&input);
        let loss = output.sum();
        loss.backward();

        assert!(
            input.grad().is_some(),
            "AdaptiveAvgPool2d: gradient should flow"
        );
        let grad = input.grad().unwrap();
        for &v in &grad.to_vec() {
            assert!((v - 0.25).abs() < 1e-6);
        }
    }
}