Skip to main content

axonml_nn/layers/
pooling.rs

//! Pooling layers — max and average pooling for 1D and 2D signals, plus 2D
//! adaptive average pooling.
//!
//! `MaxPool1d` / `MaxPool2d` (strided window max with padding),
//! `AvgPool1d` / `AvgPool2d` (strided window mean), and `AdaptiveAvgPool2d`
//! (pooling driven by a requested output size; each output cell averages its own
//! input bin). All
//! implement `Module` (stateless — no parameters, no train/eval difference).
7//!
8//! # File
9//! `crates/axonml-nn/src/layers/pooling.rs`
10//!
11//! # Author
12//! Andrew Jewell Sr. — AutomataNexus LLC
13//! ORCID: 0009-0005-2158-7060
14//!
15//! # Updated
16//! April 14, 2026 11:15 PM EST
17//!
18//! # Disclaimer
19//! Use at own risk. This software is provided "as is", without warranty of any
20//! kind, express or implied. The author and AutomataNexus shall not be held
21//! liable for any damages arising from the use of this software.
22
23use axonml_autograd::Variable;
24use axonml_autograd::functions::{
25    AdaptiveAvgPool2dBackward, AvgPool1dBackward, AvgPool2dBackward, MaxPool1dBackward,
26    MaxPool2dBackward,
27};
28use axonml_autograd::grad_fn::GradFn;
29use axonml_autograd::no_grad::is_grad_enabled;
30use axonml_tensor::Tensor;
31
32use crate::module::Module;
33
34// =============================================================================
35// MaxPool1d
36// =============================================================================
37
/// Applies max pooling over a 1D signal.
///
/// Padded positions are skipped when searching for the maximum, so they never win.
///
/// # Shape
/// - Input: (N, C, L)
/// - Output: (N, C, L_out), where `L_out = (L + 2 * padding - kernel_size) / stride + 1`
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MaxPool1d {
    /// Width of each pooling window.
    kernel_size: usize,
    /// Step between the starts of consecutive windows.
    stride: usize,
    /// Implicit padding added to both ends of the signal.
    padding: usize,
}

impl MaxPool1d {
    /// Creates a MaxPool1d with non-overlapping windows (`stride == kernel_size`)
    /// and no padding.
    pub fn new(kernel_size: usize) -> Self {
        Self::with_options(kernel_size, kernel_size, 0)
    }

    /// Creates a MaxPool1d with an explicit stride and padding.
    pub fn with_options(kernel_size: usize, stride: usize, padding: usize) -> Self {
        Self {
            kernel_size,
            stride,
            padding,
        }
    }
}
68
69impl Module for MaxPool1d {
70    fn forward(&self, input: &Variable) -> Variable {
71        let shape = input.shape();
72        let batch = shape[0];
73        let channels = shape[1];
74        let length = shape[2];
75
76        let out_length = (length + 2 * self.padding - self.kernel_size) / self.stride + 1;
77
78        let input_vec = input.data().to_vec();
79        let mut output_data = vec![f32::NEG_INFINITY; batch * channels * out_length];
80        let mut max_indices = vec![0usize; batch * channels * out_length];
81
82        for b in 0..batch {
83            for c in 0..channels {
84                for ol in 0..out_length {
85                    let in_start = ol * self.stride;
86                    let mut max_val = f32::NEG_INFINITY;
87                    let mut max_idx = 0;
88
89                    for k in 0..self.kernel_size {
90                        let il = in_start + k;
91                        if il >= self.padding && il < length + self.padding {
92                            let actual_il = il - self.padding;
93                            let idx = b * channels * length + c * length + actual_il;
94                            if input_vec[idx] > max_val {
95                                max_val = input_vec[idx];
96                                max_idx = idx;
97                            }
98                        }
99                    }
100
101                    let out_idx = b * channels * out_length + c * out_length + ol;
102                    output_data[out_idx] = max_val;
103                    max_indices[out_idx] = max_idx;
104                }
105            }
106        }
107
108        let output = Tensor::from_vec(output_data, &[batch, channels, out_length])
109            .expect("tensor creation failed");
110
111        let requires_grad = input.requires_grad() && is_grad_enabled();
112        if requires_grad {
113            let grad_fn = GradFn::new(MaxPool1dBackward::new(
114                input.grad_fn().cloned(),
115                shape,
116                max_indices,
117            ));
118            Variable::from_operation(output, grad_fn, true)
119        } else {
120            Variable::new(output, false)
121        }
122    }
123
124    fn name(&self) -> &'static str {
125        "MaxPool1d"
126    }
127}
128
129// =============================================================================
130// MaxPool2d
131// =============================================================================
132
/// Applies max pooling over a 2D signal (image).
///
/// Padded positions are skipped when searching for the maximum, so they never win.
///
/// # Shape
/// - Input: (N, C, H, W)
/// - Output: (N, C, H_out, W_out), where
///   `H_out = (H + 2 * padding.0 - kernel_size.0) / stride.0 + 1` and
///   `W_out = (W + 2 * padding.1 - kernel_size.1) / stride.1 + 1`
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MaxPool2d {
    /// Window size as (height, width).
    kernel_size: (usize, usize),
    /// Step between windows as (vertical, horizontal).
    stride: (usize, usize),
    /// Implicit padding on each side as (vertical, horizontal).
    padding: (usize, usize),
}

impl MaxPool2d {
    /// Creates a MaxPool2d with a square kernel, non-overlapping windows
    /// (`stride == kernel_size`) and no padding.
    pub fn new(kernel_size: usize) -> Self {
        let square = (kernel_size, kernel_size);
        Self::with_options(square, square, (0, 0))
    }

    /// Creates a MaxPool2d with explicit kernel size, stride and padding,
    /// each given as (height, width).
    pub fn with_options(
        kernel_size: (usize, usize),
        stride: (usize, usize),
        padding: (usize, usize),
    ) -> Self {
        Self {
            kernel_size,
            stride,
            padding,
        }
    }
}
167
168impl Module for MaxPool2d {
169    fn forward(&self, input: &Variable) -> Variable {
170        let shape = input.shape();
171        let batch = shape[0];
172        let channels = shape[1];
173        let height = shape[2];
174        let width = shape[3];
175
176        let (kh, kw) = self.kernel_size;
177        let (sh, sw) = self.stride;
178        let (ph, pw) = self.padding;
179
180        let out_h = (height + 2 * ph - kh) / sh + 1;
181        let out_w = (width + 2 * pw - kw) / sw + 1;
182
183        // Try GPU path
184        #[cfg(feature = "cuda")]
185        {
186            if let Some((gpu_output, gpu_indices)) =
187                input
188                    .data()
189                    .maxpool2d_cuda(self.kernel_size, self.stride, self.padding)
190            {
191                let max_indices: Vec<usize> = gpu_indices.iter().map(|&i| i as usize).collect();
192
193                let requires_grad = input.requires_grad() && is_grad_enabled();
194                if requires_grad {
195                    let grad_fn = GradFn::new(MaxPool2dBackward::new(
196                        input.grad_fn().cloned(),
197                        shape,
198                        max_indices,
199                        self.kernel_size,
200                        self.stride,
201                        self.padding,
202                    ));
203                    return Variable::from_operation(gpu_output, grad_fn, true);
204                } else {
205                    return Variable::new(gpu_output, false);
206                }
207            }
208        }
209
210        // CPU path
211        let input_vec = input.data().to_vec();
212        let mut output_data = vec![f32::NEG_INFINITY; batch * channels * out_h * out_w];
213        let mut max_indices = vec![0usize; batch * channels * out_h * out_w];
214
215        for b in 0..batch {
216            for c in 0..channels {
217                for oh in 0..out_h {
218                    for ow in 0..out_w {
219                        let mut max_val = f32::NEG_INFINITY;
220                        let mut max_idx = 0;
221
222                        for ki in 0..kh {
223                            for kj in 0..kw {
224                                let ih = oh * sh + ki;
225                                let iw = ow * sw + kj;
226
227                                if ih >= ph && ih < height + ph && iw >= pw && iw < width + pw {
228                                    let actual_ih = ih - ph;
229                                    let actual_iw = iw - pw;
230                                    let idx = b * channels * height * width
231                                        + c * height * width
232                                        + actual_ih * width
233                                        + actual_iw;
234                                    if input_vec[idx] > max_val {
235                                        max_val = input_vec[idx];
236                                        max_idx = idx;
237                                    }
238                                }
239                            }
240                        }
241
242                        let out_idx =
243                            b * channels * out_h * out_w + c * out_h * out_w + oh * out_w + ow;
244                        output_data[out_idx] = max_val;
245                        max_indices[out_idx] = max_idx;
246                    }
247                }
248            }
249        }
250
251        let output = Tensor::from_vec(output_data, &[batch, channels, out_h, out_w])
252            .expect("tensor creation failed");
253
254        let requires_grad = input.requires_grad() && is_grad_enabled();
255        if requires_grad {
256            let grad_fn = GradFn::new(MaxPool2dBackward::new(
257                input.grad_fn().cloned(),
258                shape,
259                max_indices,
260                self.kernel_size,
261                self.stride,
262                self.padding,
263            ));
264            Variable::from_operation(output, grad_fn, true)
265        } else {
266            Variable::new(output, false)
267        }
268    }
269
270    fn name(&self) -> &'static str {
271        "MaxPool2d"
272    }
273}
274
275// =============================================================================
276// AvgPool1d
277// =============================================================================
278
/// Applies average pooling over a 1D signal.
///
/// Padded positions are excluded from both the sum and the divisor, so each output
/// is the mean of the real elements in its window.
///
/// # Shape
/// - Input: (N, C, L)
/// - Output: (N, C, L_out), where `L_out = (L + 2 * padding - kernel_size) / stride + 1`
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AvgPool1d {
    /// Width of each pooling window.
    kernel_size: usize,
    /// Step between the starts of consecutive windows.
    stride: usize,
    /// Implicit padding added to both ends of the signal.
    padding: usize,
}

impl AvgPool1d {
    /// Creates an AvgPool1d with non-overlapping windows (`stride == kernel_size`)
    /// and no padding.
    pub fn new(kernel_size: usize) -> Self {
        Self::with_options(kernel_size, kernel_size, 0)
    }

    /// Creates an AvgPool1d with an explicit stride and padding.
    pub fn with_options(kernel_size: usize, stride: usize, padding: usize) -> Self {
        Self {
            kernel_size,
            stride,
            padding,
        }
    }
}
305
306impl Module for AvgPool1d {
307    fn forward(&self, input: &Variable) -> Variable {
308        let shape = input.shape();
309        let batch = shape[0];
310        let channels = shape[1];
311        let length = shape[2];
312
313        let out_length = (length + 2 * self.padding - self.kernel_size) / self.stride + 1;
314
315        let input_vec = input.data().to_vec();
316        let mut output_data = vec![0.0f32; batch * channels * out_length];
317
318        for b in 0..batch {
319            for c in 0..channels {
320                for ol in 0..out_length {
321                    let in_start = ol * self.stride;
322                    let mut sum = 0.0f32;
323                    let mut count = 0;
324
325                    for k in 0..self.kernel_size {
326                        let il = in_start + k;
327                        if il >= self.padding && il < length + self.padding {
328                            let actual_il = il - self.padding;
329                            let idx = b * channels * length + c * length + actual_il;
330                            sum += input_vec[idx];
331                            count += 1;
332                        }
333                    }
334
335                    let out_idx = b * channels * out_length + c * out_length + ol;
336                    output_data[out_idx] = if count > 0 { sum / count as f32 } else { 0.0 };
337                }
338            }
339        }
340
341        let output = Tensor::from_vec(output_data, &[batch, channels, out_length])
342            .expect("tensor creation failed");
343
344        let requires_grad = input.requires_grad() && is_grad_enabled();
345        if requires_grad {
346            let grad_fn = GradFn::new(AvgPool1dBackward::new(
347                input.grad_fn().cloned(),
348                shape,
349                self.kernel_size,
350                self.stride,
351                self.padding,
352            ));
353            Variable::from_operation(output, grad_fn, true)
354        } else {
355            Variable::new(output, false)
356        }
357    }
358
359    fn name(&self) -> &'static str {
360        "AvgPool1d"
361    }
362}
363
364// =============================================================================
365// AvgPool2d
366// =============================================================================
367
/// Applies average pooling over a 2D signal (image).
///
/// Padded positions are excluded from both the sum and the divisor, so each output
/// is the mean of the real elements in its window.
///
/// # Shape
/// - Input: (N, C, H, W)
/// - Output: (N, C, H_out, W_out), where
///   `H_out = (H + 2 * padding.0 - kernel_size.0) / stride.0 + 1` and
///   `W_out = (W + 2 * padding.1 - kernel_size.1) / stride.1 + 1`
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AvgPool2d {
    /// Window size as (height, width).
    kernel_size: (usize, usize),
    /// Step between windows as (vertical, horizontal).
    stride: (usize, usize),
    /// Implicit padding on each side as (vertical, horizontal).
    padding: (usize, usize),
}

impl AvgPool2d {
    /// Creates an AvgPool2d with a square kernel, non-overlapping windows
    /// (`stride == kernel_size`) and no padding.
    pub fn new(kernel_size: usize) -> Self {
        let square = (kernel_size, kernel_size);
        Self::with_options(square, square, (0, 0))
    }

    /// Creates an AvgPool2d with explicit kernel size, stride and padding,
    /// each given as (height, width).
    pub fn with_options(
        kernel_size: (usize, usize),
        stride: (usize, usize),
        padding: (usize, usize),
    ) -> Self {
        Self {
            kernel_size,
            stride,
            padding,
        }
    }
}
398
399impl Module for AvgPool2d {
400    fn forward(&self, input: &Variable) -> Variable {
401        let shape = input.shape();
402        let batch = shape[0];
403        let channels = shape[1];
404        let height = shape[2];
405        let width = shape[3];
406
407        let (kh, kw) = self.kernel_size;
408        let (sh, sw) = self.stride;
409        let (ph, pw) = self.padding;
410
411        let out_h = (height + 2 * ph - kh) / sh + 1;
412        let out_w = (width + 2 * pw - kw) / sw + 1;
413
414        // Try GPU path
415        #[cfg(feature = "cuda")]
416        {
417            if let Some(gpu_output) = input.data().avgpool2d_cuda(
418                self.kernel_size,
419                self.stride,
420                self.padding,
421                false, // count_include_pad=false matches CPU behavior
422            ) {
423                let requires_grad = input.requires_grad() && is_grad_enabled();
424                if requires_grad {
425                    let grad_fn = GradFn::new(AvgPool2dBackward::new(
426                        input.grad_fn().cloned(),
427                        shape,
428                        self.kernel_size,
429                        self.stride,
430                        self.padding,
431                    ));
432                    return Variable::from_operation(gpu_output, grad_fn, true);
433                } else {
434                    return Variable::new(gpu_output, false);
435                }
436            }
437        }
438
439        // CPU path
440        let input_vec = input.data().to_vec();
441        let mut output_data = vec![0.0f32; batch * channels * out_h * out_w];
442
443        for b in 0..batch {
444            for c in 0..channels {
445                for oh in 0..out_h {
446                    for ow in 0..out_w {
447                        let mut sum = 0.0f32;
448                        let mut count = 0;
449
450                        for ki in 0..kh {
451                            for kj in 0..kw {
452                                let ih = oh * sh + ki;
453                                let iw = ow * sw + kj;
454
455                                if ih >= ph && ih < height + ph && iw >= pw && iw < width + pw {
456                                    let actual_ih = ih - ph;
457                                    let actual_iw = iw - pw;
458                                    let idx = b * channels * height * width
459                                        + c * height * width
460                                        + actual_ih * width
461                                        + actual_iw;
462                                    sum += input_vec[idx];
463                                    count += 1;
464                                }
465                            }
466                        }
467
468                        let out_idx =
469                            b * channels * out_h * out_w + c * out_h * out_w + oh * out_w + ow;
470                        output_data[out_idx] = if count > 0 { sum / count as f32 } else { 0.0 };
471                    }
472                }
473            }
474        }
475
476        let output = Tensor::from_vec(output_data, &[batch, channels, out_h, out_w])
477            .expect("tensor creation failed");
478
479        let requires_grad = input.requires_grad() && is_grad_enabled();
480        if requires_grad {
481            let grad_fn = GradFn::new(AvgPool2dBackward::new(
482                input.grad_fn().cloned(),
483                shape,
484                self.kernel_size,
485                self.stride,
486                self.padding,
487            ));
488            Variable::from_operation(output, grad_fn, true)
489        } else {
490            Variable::new(output, false)
491        }
492    }
493
494    fn name(&self) -> &'static str {
495        "AvgPool2d"
496    }
497}
498
499// =============================================================================
500// AdaptiveAvgPool2d
501// =============================================================================
502
/// Applies adaptive average pooling to produce a requested output size.
///
/// Instead of a fixed kernel and stride, each output cell averages its own bin of
/// input rows and columns, derived from the input size and `output_size`.
///
/// # Shape
/// - Input: (N, C, H, W)
/// - Output: (N, C, H_out, W_out), where `(H_out, W_out) = output_size`
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AdaptiveAvgPool2d {
    /// Target spatial size as (height, width).
    output_size: (usize, usize),
}

impl AdaptiveAvgPool2d {
    /// Creates an AdaptiveAvgPool2d producing `output_size` = (height, width).
    pub fn new(output_size: (usize, usize)) -> Self {
        Self { output_size }
    }

    /// Creates an AdaptiveAvgPool2d with a square `size x size` output.
    pub fn square(size: usize) -> Self {
        Self::new((size, size))
    }
}
521
522impl Module for AdaptiveAvgPool2d {
523    fn forward(&self, input: &Variable) -> Variable {
524        let shape = input.shape();
525        let batch = shape[0];
526        let channels = shape[1];
527        let in_h = shape[2];
528        let in_w = shape[3];
529
530        let (out_h, out_w) = self.output_size;
531        let input_vec = input.data().to_vec();
532        let mut output_data = vec![0.0f32; batch * channels * out_h * out_w];
533
534        for b in 0..batch {
535            for c in 0..channels {
536                for oh in 0..out_h {
537                    for ow in 0..out_w {
538                        let ih_start = (oh * in_h) / out_h;
539                        let ih_end = ((oh + 1) * in_h) / out_h;
540                        let iw_start = (ow * in_w) / out_w;
541                        let iw_end = ((ow + 1) * in_w) / out_w;
542
543                        let mut sum = 0.0f32;
544                        let mut count = 0;
545
546                        for ih in ih_start..ih_end {
547                            for iw in iw_start..iw_end {
548                                let idx =
549                                    b * channels * in_h * in_w + c * in_h * in_w + ih * in_w + iw;
550                                sum += input_vec[idx];
551                                count += 1;
552                            }
553                        }
554
555                        let out_idx =
556                            b * channels * out_h * out_w + c * out_h * out_w + oh * out_w + ow;
557                        output_data[out_idx] = if count > 0 { sum / count as f32 } else { 0.0 };
558                    }
559                }
560            }
561        }
562
563        let output = Tensor::from_vec(output_data, &[batch, channels, out_h, out_w])
564            .expect("tensor creation failed");
565
566        let requires_grad = input.requires_grad() && is_grad_enabled();
567        if requires_grad {
568            let grad_fn = GradFn::new(AdaptiveAvgPool2dBackward::new(
569                input.grad_fn().cloned(),
570                shape,
571                self.output_size,
572            ));
573            Variable::from_operation(output, grad_fn, true)
574        } else {
575            Variable::new(output, false)
576        }
577    }
578
579    fn name(&self) -> &'static str {
580        "AdaptiveAvgPool2d"
581    }
582}
583
584// =============================================================================
585// Tests
586// =============================================================================
587
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `(1, 1, 4, 4)` input holding 1.0, 2.0, ..., 16.0 in row-major order.
    fn ramp_4x4(requires_grad: bool) -> Variable {
        let values: Vec<f32> = (1..=16).map(|v| v as f32).collect();
        let tensor = Tensor::from_vec(values, &[1, 1, 4, 4]).unwrap();
        Variable::new(tensor, requires_grad)
    }

    #[test]
    fn test_maxpool2d() {
        let input = ramp_4x4(false);
        let output = MaxPool2d::new(2).forward(&input);

        assert_eq!(output.shape(), vec![1, 1, 2, 2]);
        // The bottom-right corner of each 2x2 block is that block's maximum.
        assert_eq!(output.data().to_vec(), vec![6.0, 8.0, 14.0, 16.0]);
    }

    #[test]
    fn test_maxpool2d_backward() {
        let input = ramp_4x4(true);
        let output = MaxPool2d::new(2).forward(&input);
        let loss = output.sum();
        loss.backward();

        let grad = input.grad().expect("MaxPool2d: gradient should flow");
        assert_eq!(grad.shape(), &[1, 1, 4, 4]);

        let grad_vec = grad.to_vec();
        // Only the window maxima (6, 8, 14, 16) at flat indices 5, 7, 13, 15 get gradient.
        for &winner in &[5usize, 7, 13, 15] {
            assert_eq!(grad_vec[winner], 1.0);
        }
        assert_eq!(grad_vec[0], 0.0);
    }

    #[test]
    fn test_avgpool2d() {
        let input = ramp_4x4(false);
        let output = AvgPool2d::new(2).forward(&input);

        assert_eq!(output.shape(), vec![1, 1, 2, 2]);
        assert_eq!(output.data().to_vec(), vec![3.5, 5.5, 11.5, 13.5]);
    }

    #[test]
    fn test_avgpool2d_backward() {
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 16], &[1, 1, 4, 4]).expect("tensor creation failed"),
            true,
        );
        let output = AvgPool2d::new(2).forward(&input);
        let loss = output.sum();
        loss.backward();

        let grad = input.grad().expect("AvgPool2d: gradient should flow");
        // Every input lies in exactly one 2x2 window, so it receives 1/4 of that
        // window's gradient.
        let grad_vec = grad.to_vec();
        for v in grad_vec.iter() {
            assert!((*v - 0.25).abs() < 1e-6);
        }
    }

    #[test]
    fn test_adaptive_avgpool2d() {
        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[1, 1, 2, 2])
                .expect("tensor creation failed"),
            false,
        );
        let output = AdaptiveAvgPool2d::new((1, 1)).forward(&input);

        assert_eq!(output.shape(), vec![1, 1, 1, 1]);
        assert_eq!(output.data().to_vec(), vec![2.5]);
    }

    #[test]
    fn test_adaptive_avgpool2d_backward() {
        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[1, 1, 2, 2])
                .expect("tensor creation failed"),
            true,
        );
        let output = AdaptiveAvgPool2d::new((1, 1)).forward(&input);
        let loss = output.sum();
        loss.backward();

        let grad = input
            .grad()
            .expect("AdaptiveAvgPool2d: gradient should flow");
        // Global pooling over four elements: each one receives 1/4 of the gradient.
        let grad_vec = grad.to_vec();
        for v in grad_vec.iter() {
            assert!((*v - 0.25).abs() < 1e-6);
        }
    }
}