//! Pooling Layers - Max and Average Pooling
//!
//! # File
//! `crates/axonml-nn/src/layers/pooling.rs`
//!
//! # Author
//! Andrew Jewell Sr - AutomataNexus
//!
//! # Updated
//! March 8, 2026
//!
//! # Disclaimer
//! Use at own risk. This software is provided "as is", without warranty of any
//! kind, express or implied. The author and AutomataNexus shall not be held
//! liable for any damages arising from the use of this software.

use axonml_autograd::Variable;
use axonml_autograd::functions::{
    AdaptiveAvgPool2dBackward, AvgPool1dBackward, AvgPool2dBackward, MaxPool1dBackward,
    MaxPool2dBackward,
};
use axonml_autograd::grad_fn::GradFn;
use axonml_autograd::no_grad::is_grad_enabled;
use axonml_tensor::Tensor;

use crate::module::Module;

// =============================================================================
// MaxPool1d
// =============================================================================

/// Applies max pooling over a 1D signal.
///
/// # Shape
/// - Input: (N, C, L)
/// - Output: (N, C, L_out)
pub struct MaxPool1d {
    kernel_size: usize,
    stride: usize,
    padding: usize,
}

impl MaxPool1d {
    /// Creates a new MaxPool1d layer with stride equal to the kernel size
    /// and no padding (the common non-overlapping configuration).
    ///
    /// # Panics
    /// Panics if `kernel_size` is zero.
    pub fn new(kernel_size: usize) -> Self {
        Self::with_options(kernel_size, kernel_size, 0)
    }

    /// Creates a MaxPool1d with custom stride and padding.
    ///
    /// # Panics
    /// Panics if `kernel_size` or `stride` is zero. A zero stride would
    /// otherwise cause a divide-by-zero when `forward` computes the output
    /// length; validating here surfaces the error at construction time.
    pub fn with_options(kernel_size: usize, stride: usize, padding: usize) -> Self {
        assert!(kernel_size > 0, "MaxPool1d: kernel_size must be > 0");
        assert!(stride > 0, "MaxPool1d: stride must be > 0");
        Self {
            kernel_size,
            stride,
            padding,
        }
    }
}

63impl Module for MaxPool1d {
64    fn forward(&self, input: &Variable) -> Variable {
65        let shape = input.shape();
66        let batch = shape[0];
67        let channels = shape[1];
68        let length = shape[2];
69
70        let out_length = (length + 2 * self.padding - self.kernel_size) / self.stride + 1;
71
72        let input_vec = input.data().to_vec();
73        let mut output_data = vec![f32::NEG_INFINITY; batch * channels * out_length];
74        let mut max_indices = vec![0usize; batch * channels * out_length];
75
76        for b in 0..batch {
77            for c in 0..channels {
78                for ol in 0..out_length {
79                    let in_start = ol * self.stride;
80                    let mut max_val = f32::NEG_INFINITY;
81                    let mut max_idx = 0;
82
83                    for k in 0..self.kernel_size {
84                        let il = in_start + k;
85                        if il >= self.padding && il < length + self.padding {
86                            let actual_il = il - self.padding;
87                            let idx = b * channels * length + c * length + actual_il;
88                            if input_vec[idx] > max_val {
89                                max_val = input_vec[idx];
90                                max_idx = idx;
91                            }
92                        }
93                    }
94
95                    let out_idx = b * channels * out_length + c * out_length + ol;
96                    output_data[out_idx] = max_val;
97                    max_indices[out_idx] = max_idx;
98                }
99            }
100        }
101
102        let output = Tensor::from_vec(output_data, &[batch, channels, out_length])
103            .expect("tensor creation failed");
104
105        let requires_grad = input.requires_grad() && is_grad_enabled();
106        if requires_grad {
107            let grad_fn = GradFn::new(MaxPool1dBackward::new(
108                input.grad_fn().cloned(),
109                shape,
110                max_indices,
111            ));
112            Variable::from_operation(output, grad_fn, true)
113        } else {
114            Variable::new(output, false)
115        }
116    }
117
118    fn name(&self) -> &'static str {
119        "MaxPool1d"
120    }
121}
122
// =============================================================================
// MaxPool2d
// =============================================================================

/// Applies max pooling over a 2D signal (image).
///
/// # Shape
/// - Input: (N, C, H, W)
/// - Output: (N, C, H_out, W_out)
pub struct MaxPool2d {
    kernel_size: (usize, usize),
    stride: (usize, usize),
    padding: (usize, usize),
}

impl MaxPool2d {
    /// Creates a new MaxPool2d layer with square kernel.
    pub fn new(kernel_size: usize) -> Self {
        let square = (kernel_size, kernel_size);
        Self::with_options(square, square, (0, 0))
    }

    /// Creates a MaxPool2d with all options.
    pub fn with_options(
        kernel_size: (usize, usize),
        stride: (usize, usize),
        padding: (usize, usize),
    ) -> Self {
        Self {
            kernel_size,
            stride,
            padding,
        }
    }
}

162impl Module for MaxPool2d {
163    fn forward(&self, input: &Variable) -> Variable {
164        let shape = input.shape();
165        let batch = shape[0];
166        let channels = shape[1];
167        let height = shape[2];
168        let width = shape[3];
169
170        let (kh, kw) = self.kernel_size;
171        let (sh, sw) = self.stride;
172        let (ph, pw) = self.padding;
173
174        let out_h = (height + 2 * ph - kh) / sh + 1;
175        let out_w = (width + 2 * pw - kw) / sw + 1;
176
177        // Try GPU path
178        #[cfg(feature = "cuda")]
179        {
180            if let Some((gpu_output, gpu_indices)) =
181                input
182                    .data()
183                    .maxpool2d_cuda(self.kernel_size, self.stride, self.padding)
184            {
185                let max_indices: Vec<usize> = gpu_indices.iter().map(|&i| i as usize).collect();
186
187                let requires_grad = input.requires_grad() && is_grad_enabled();
188                if requires_grad {
189                    let grad_fn = GradFn::new(MaxPool2dBackward::new(
190                        input.grad_fn().cloned(),
191                        shape,
192                        max_indices,
193                        self.kernel_size,
194                        self.stride,
195                        self.padding,
196                    ));
197                    return Variable::from_operation(gpu_output, grad_fn, true);
198                } else {
199                    return Variable::new(gpu_output, false);
200                }
201            }
202        }
203
204        // CPU path
205        let input_vec = input.data().to_vec();
206        let mut output_data = vec![f32::NEG_INFINITY; batch * channels * out_h * out_w];
207        let mut max_indices = vec![0usize; batch * channels * out_h * out_w];
208
209        for b in 0..batch {
210            for c in 0..channels {
211                for oh in 0..out_h {
212                    for ow in 0..out_w {
213                        let mut max_val = f32::NEG_INFINITY;
214                        let mut max_idx = 0;
215
216                        for ki in 0..kh {
217                            for kj in 0..kw {
218                                let ih = oh * sh + ki;
219                                let iw = ow * sw + kj;
220
221                                if ih >= ph && ih < height + ph && iw >= pw && iw < width + pw {
222                                    let actual_ih = ih - ph;
223                                    let actual_iw = iw - pw;
224                                    let idx = b * channels * height * width
225                                        + c * height * width
226                                        + actual_ih * width
227                                        + actual_iw;
228                                    if input_vec[idx] > max_val {
229                                        max_val = input_vec[idx];
230                                        max_idx = idx;
231                                    }
232                                }
233                            }
234                        }
235
236                        let out_idx =
237                            b * channels * out_h * out_w + c * out_h * out_w + oh * out_w + ow;
238                        output_data[out_idx] = max_val;
239                        max_indices[out_idx] = max_idx;
240                    }
241                }
242            }
243        }
244
245        let output = Tensor::from_vec(output_data, &[batch, channels, out_h, out_w])
246            .expect("tensor creation failed");
247
248        let requires_grad = input.requires_grad() && is_grad_enabled();
249        if requires_grad {
250            let grad_fn = GradFn::new(MaxPool2dBackward::new(
251                input.grad_fn().cloned(),
252                shape,
253                max_indices,
254                self.kernel_size,
255                self.stride,
256                self.padding,
257            ));
258            Variable::from_operation(output, grad_fn, true)
259        } else {
260            Variable::new(output, false)
261        }
262    }
263
264    fn name(&self) -> &'static str {
265        "MaxPool2d"
266    }
267}
268
// =============================================================================
// AvgPool1d
// =============================================================================

/// Applies average pooling over a 1D signal.
pub struct AvgPool1d {
    kernel_size: usize,
    stride: usize,
    padding: usize,
}

impl AvgPool1d {
    /// Creates a new AvgPool1d layer (stride = kernel size, no padding).
    pub fn new(kernel_size: usize) -> Self {
        Self::with_options(kernel_size, kernel_size, 0)
    }

    /// Creates an AvgPool1d with custom stride and padding.
    pub fn with_options(kernel_size: usize, stride: usize, padding: usize) -> Self {
        Self {
            kernel_size,
            stride,
            padding,
        }
    }
}

300impl Module for AvgPool1d {
301    fn forward(&self, input: &Variable) -> Variable {
302        let shape = input.shape();
303        let batch = shape[0];
304        let channels = shape[1];
305        let length = shape[2];
306
307        let out_length = (length + 2 * self.padding - self.kernel_size) / self.stride + 1;
308
309        let input_vec = input.data().to_vec();
310        let mut output_data = vec![0.0f32; batch * channels * out_length];
311
312        for b in 0..batch {
313            for c in 0..channels {
314                for ol in 0..out_length {
315                    let in_start = ol * self.stride;
316                    let mut sum = 0.0f32;
317                    let mut count = 0;
318
319                    for k in 0..self.kernel_size {
320                        let il = in_start + k;
321                        if il >= self.padding && il < length + self.padding {
322                            let actual_il = il - self.padding;
323                            let idx = b * channels * length + c * length + actual_il;
324                            sum += input_vec[idx];
325                            count += 1;
326                        }
327                    }
328
329                    let out_idx = b * channels * out_length + c * out_length + ol;
330                    output_data[out_idx] = if count > 0 { sum / count as f32 } else { 0.0 };
331                }
332            }
333        }
334
335        let output = Tensor::from_vec(output_data, &[batch, channels, out_length])
336            .expect("tensor creation failed");
337
338        let requires_grad = input.requires_grad() && is_grad_enabled();
339        if requires_grad {
340            let grad_fn = GradFn::new(AvgPool1dBackward::new(
341                input.grad_fn().cloned(),
342                shape,
343                self.kernel_size,
344                self.stride,
345                self.padding,
346            ));
347            Variable::from_operation(output, grad_fn, true)
348        } else {
349            Variable::new(output, false)
350        }
351    }
352
353    fn name(&self) -> &'static str {
354        "AvgPool1d"
355    }
356}
357
// =============================================================================
// AvgPool2d
// =============================================================================

/// Applies average pooling over a 2D signal (image).
pub struct AvgPool2d {
    kernel_size: (usize, usize),
    stride: (usize, usize),
    padding: (usize, usize),
}

impl AvgPool2d {
    /// Creates a new AvgPool2d layer with square kernel.
    pub fn new(kernel_size: usize) -> Self {
        let square = (kernel_size, kernel_size);
        Self::with_options(square, square, (0, 0))
    }

    /// Creates an AvgPool2d with all options.
    pub fn with_options(
        kernel_size: (usize, usize),
        stride: (usize, usize),
        padding: (usize, usize),
    ) -> Self {
        Self {
            kernel_size,
            stride,
            padding,
        }
    }
}

393impl Module for AvgPool2d {
394    fn forward(&self, input: &Variable) -> Variable {
395        let shape = input.shape();
396        let batch = shape[0];
397        let channels = shape[1];
398        let height = shape[2];
399        let width = shape[3];
400
401        let (kh, kw) = self.kernel_size;
402        let (sh, sw) = self.stride;
403        let (ph, pw) = self.padding;
404
405        let out_h = (height + 2 * ph - kh) / sh + 1;
406        let out_w = (width + 2 * pw - kw) / sw + 1;
407
408        // Try GPU path
409        #[cfg(feature = "cuda")]
410        {
411            if let Some(gpu_output) = input.data().avgpool2d_cuda(
412                self.kernel_size,
413                self.stride,
414                self.padding,
415                false, // count_include_pad=false matches CPU behavior
416            ) {
417                let requires_grad = input.requires_grad() && is_grad_enabled();
418                if requires_grad {
419                    let grad_fn = GradFn::new(AvgPool2dBackward::new(
420                        input.grad_fn().cloned(),
421                        shape,
422                        self.kernel_size,
423                        self.stride,
424                        self.padding,
425                    ));
426                    return Variable::from_operation(gpu_output, grad_fn, true);
427                } else {
428                    return Variable::new(gpu_output, false);
429                }
430            }
431        }
432
433        // CPU path
434        let input_vec = input.data().to_vec();
435        let mut output_data = vec![0.0f32; batch * channels * out_h * out_w];
436
437        for b in 0..batch {
438            for c in 0..channels {
439                for oh in 0..out_h {
440                    for ow in 0..out_w {
441                        let mut sum = 0.0f32;
442                        let mut count = 0;
443
444                        for ki in 0..kh {
445                            for kj in 0..kw {
446                                let ih = oh * sh + ki;
447                                let iw = ow * sw + kj;
448
449                                if ih >= ph && ih < height + ph && iw >= pw && iw < width + pw {
450                                    let actual_ih = ih - ph;
451                                    let actual_iw = iw - pw;
452                                    let idx = b * channels * height * width
453                                        + c * height * width
454                                        + actual_ih * width
455                                        + actual_iw;
456                                    sum += input_vec[idx];
457                                    count += 1;
458                                }
459                            }
460                        }
461
462                        let out_idx =
463                            b * channels * out_h * out_w + c * out_h * out_w + oh * out_w + ow;
464                        output_data[out_idx] = if count > 0 { sum / count as f32 } else { 0.0 };
465                    }
466                }
467            }
468        }
469
470        let output = Tensor::from_vec(output_data, &[batch, channels, out_h, out_w])
471            .expect("tensor creation failed");
472
473        let requires_grad = input.requires_grad() && is_grad_enabled();
474        if requires_grad {
475            let grad_fn = GradFn::new(AvgPool2dBackward::new(
476                input.grad_fn().cloned(),
477                shape,
478                self.kernel_size,
479                self.stride,
480                self.padding,
481            ));
482            Variable::from_operation(output, grad_fn, true)
483        } else {
484            Variable::new(output, false)
485        }
486    }
487
488    fn name(&self) -> &'static str {
489        "AvgPool2d"
490    }
491}
492
// =============================================================================
// AdaptiveAvgPool2d
// =============================================================================

/// Applies adaptive average pooling to produce specified output size.
pub struct AdaptiveAvgPool2d {
    output_size: (usize, usize),
}

impl AdaptiveAvgPool2d {
    /// Creates a new AdaptiveAvgPool2d.
    pub fn new(output_size: (usize, usize)) -> Self {
        Self { output_size }
    }

    /// Creates an AdaptiveAvgPool2d with square output.
    pub fn square(size: usize) -> Self {
        Self::new((size, size))
    }
}

516impl Module for AdaptiveAvgPool2d {
517    fn forward(&self, input: &Variable) -> Variable {
518        let shape = input.shape();
519        let batch = shape[0];
520        let channels = shape[1];
521        let in_h = shape[2];
522        let in_w = shape[3];
523
524        let (out_h, out_w) = self.output_size;
525        let input_vec = input.data().to_vec();
526        let mut output_data = vec![0.0f32; batch * channels * out_h * out_w];
527
528        for b in 0..batch {
529            for c in 0..channels {
530                for oh in 0..out_h {
531                    for ow in 0..out_w {
532                        let ih_start = (oh * in_h) / out_h;
533                        let ih_end = ((oh + 1) * in_h) / out_h;
534                        let iw_start = (ow * in_w) / out_w;
535                        let iw_end = ((ow + 1) * in_w) / out_w;
536
537                        let mut sum = 0.0f32;
538                        let mut count = 0;
539
540                        for ih in ih_start..ih_end {
541                            for iw in iw_start..iw_end {
542                                let idx =
543                                    b * channels * in_h * in_w + c * in_h * in_w + ih * in_w + iw;
544                                sum += input_vec[idx];
545                                count += 1;
546                            }
547                        }
548
549                        let out_idx =
550                            b * channels * out_h * out_w + c * out_h * out_w + oh * out_w + ow;
551                        output_data[out_idx] = if count > 0 { sum / count as f32 } else { 0.0 };
552                    }
553                }
554            }
555        }
556
557        let output = Tensor::from_vec(output_data, &[batch, channels, out_h, out_w])
558            .expect("tensor creation failed");
559
560        let requires_grad = input.requires_grad() && is_grad_enabled();
561        if requires_grad {
562            let grad_fn = GradFn::new(AdaptiveAvgPool2dBackward::new(
563                input.grad_fn().cloned(),
564                shape,
565                self.output_size,
566            ));
567            Variable::from_operation(output, grad_fn, true)
568        } else {
569            Variable::new(output, false)
570        }
571    }
572
573    fn name(&self) -> &'static str {
574        "AdaptiveAvgPool2d"
575    }
576}
577
// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a (1, 1, 4, 4) input holding the values 1..=16 in row-major order.
    fn ramp_4x4(requires_grad: bool) -> Variable {
        let values: Vec<f32> = (1..=16).map(|v| v as f32).collect();
        Variable::new(
            Tensor::from_vec(values, &[1, 1, 4, 4]).expect("tensor creation failed"),
            requires_grad,
        )
    }

    #[test]
    fn test_maxpool2d() {
        let pool = MaxPool2d::new(2);
        let output = pool.forward(&ramp_4x4(false));
        assert_eq!(output.shape(), vec![1, 1, 2, 2]);
        assert_eq!(output.data().to_vec(), vec![6.0, 8.0, 14.0, 16.0]);
    }

    #[test]
    fn test_maxpool2d_backward() {
        let pool = MaxPool2d::new(2);
        let input = ramp_4x4(true);
        let output = pool.forward(&input);
        let loss = output.sum();
        loss.backward();

        assert!(input.grad().is_some(), "MaxPool2d: gradient should flow");
        let grad = input.grad().unwrap();
        assert_eq!(grad.shape(), &[1, 1, 4, 4]);
        let grad_vec = grad.to_vec();
        // Only the window maxima (6, 8, 14, 16) at flat indices 5, 7, 13, 15
        // receive gradient; other positions stay zero.
        for idx in [5usize, 7, 13, 15] {
            assert_eq!(grad_vec[idx], 1.0);
        }
        assert_eq!(grad_vec[0], 0.0);
    }

    #[test]
    fn test_avgpool2d() {
        let pool = AvgPool2d::new(2);
        let output = pool.forward(&ramp_4x4(false));
        assert_eq!(output.shape(), vec![1, 1, 2, 2]);
        assert_eq!(output.data().to_vec(), vec![3.5, 5.5, 11.5, 13.5]);
    }

    #[test]
    fn test_avgpool2d_backward() {
        let pool = AvgPool2d::new(2);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 16], &[1, 1, 4, 4]).expect("tensor creation failed"),
            true,
        );
        let loss = pool.forward(&input).sum();
        loss.backward();

        assert!(input.grad().is_some(), "AvgPool2d: gradient should flow");
        // Every element belongs to exactly one 2x2 window, so each gets 1/4.
        for &v in &input.grad().unwrap().to_vec() {
            assert!((v - 0.25).abs() < 1e-6);
        }
    }

    #[test]
    fn test_adaptive_avgpool2d() {
        let pool = AdaptiveAvgPool2d::new((1, 1));
        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[1, 1, 2, 2])
                .expect("tensor creation failed"),
            false,
        );
        let output = pool.forward(&input);
        assert_eq!(output.shape(), vec![1, 1, 1, 1]);
        assert_eq!(output.data().to_vec(), vec![2.5]);
    }

    #[test]
    fn test_adaptive_avgpool2d_backward() {
        let pool = AdaptiveAvgPool2d::new((1, 1));
        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[1, 1, 2, 2])
                .expect("tensor creation failed"),
            true,
        );
        let loss = pool.forward(&input).sum();
        loss.backward();

        assert!(
            input.grad().is_some(),
            "AdaptiveAvgPool2d: gradient should flow"
        );
        for &v in &input.grad().unwrap().to_vec() {
            assert!((v - 0.25).abs() < 1e-6);
        }
    }
}