axonml_nn/layers/
pooling.rs

1//! Pooling Layers - Max and Average Pooling
2//!
3//! Reduces spatial dimensions through pooling operations.
4//!
5//! @version 0.1.0
6//! @author AutomataNexus Development Team
7
8use axonml_autograd::Variable;
9use axonml_tensor::Tensor;
10
11use crate::module::Module;
12
13// =============================================================================
14// MaxPool1d
15// =============================================================================
16
/// Applies max pooling over a 1D signal.
///
/// # Shape
/// - Input: (N, C, L)
/// - Output: (N, C, L_out) where
///   `L_out = (L + 2 * padding - kernel_size) / stride + 1`
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct MaxPool1d {
    // Width of the pooling window.
    kernel_size: usize,
    // Step between consecutive windows.
    stride: usize,
    // Implicit -inf elements added on both ends of the input.
    padding: usize,
}

impl MaxPool1d {
    /// Creates a new MaxPool1d layer.
    ///
    /// Stride defaults to `kernel_size` (non-overlapping windows) and
    /// padding defaults to 0.
    pub fn new(kernel_size: usize) -> Self {
        Self {
            kernel_size,
            stride: kernel_size, // Default stride equals kernel size
            padding: 0,
        }
    }

    /// Creates a MaxPool1d with custom stride and padding.
    pub fn with_options(kernel_size: usize, stride: usize, padding: usize) -> Self {
        Self {
            kernel_size,
            stride,
            padding,
        }
    }
}
47
48impl Module for MaxPool1d {
49    fn forward(&self, input: &Variable) -> Variable {
50        let shape = input.shape();
51        let batch = shape[0];
52        let channels = shape[1];
53        let length = shape[2];
54
55        let out_length = (length + 2 * self.padding - self.kernel_size) / self.stride + 1;
56
57        let input_vec = input.data().to_vec();
58        let mut output_data = vec![f32::NEG_INFINITY; batch * channels * out_length];
59
60        for b in 0..batch {
61            for c in 0..channels {
62                for ol in 0..out_length {
63                    let in_start = ol * self.stride;
64                    let mut max_val = f32::NEG_INFINITY;
65
66                    for k in 0..self.kernel_size {
67                        let il = in_start + k;
68                        if il >= self.padding && il < length + self.padding {
69                            let actual_il = il - self.padding;
70                            let idx = b * channels * length + c * length + actual_il;
71                            max_val = max_val.max(input_vec[idx]);
72                        }
73                    }
74
75                    let out_idx = b * channels * out_length + c * out_length + ol;
76                    output_data[out_idx] = max_val;
77                }
78            }
79        }
80
81        let output = Tensor::from_vec(output_data, &[batch, channels, out_length]).unwrap();
82        Variable::new(output, input.requires_grad())
83    }
84
85    fn name(&self) -> &'static str {
86        "MaxPool1d"
87    }
88}
89
90// =============================================================================
91// MaxPool2d
92// =============================================================================
93
/// Applies max pooling over a 2D signal (image).
///
/// # Shape
/// - Input: (N, C, H, W)
/// - Output: (N, C, H_out, W_out)
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct MaxPool2d {
    // Pooling window as (height, width).
    kernel_size: (usize, usize),
    // Step between windows as (vertical, horizontal).
    stride: (usize, usize),
    // Implicit -inf border as (top/bottom, left/right).
    padding: (usize, usize),
}

impl MaxPool2d {
    /// Creates a new MaxPool2d layer with square kernel.
    ///
    /// Stride defaults to the kernel size (non-overlapping windows) and
    /// padding defaults to zero.
    pub fn new(kernel_size: usize) -> Self {
        Self {
            kernel_size: (kernel_size, kernel_size),
            stride: (kernel_size, kernel_size),
            padding: (0, 0),
        }
    }

    /// Creates a MaxPool2d with all options.
    pub fn with_options(
        kernel_size: (usize, usize),
        stride: (usize, usize),
        padding: (usize, usize),
    ) -> Self {
        Self {
            kernel_size,
            stride,
            padding,
        }
    }
}
128
129impl Module for MaxPool2d {
130    fn forward(&self, input: &Variable) -> Variable {
131        let shape = input.shape();
132        let batch = shape[0];
133        let channels = shape[1];
134        let height = shape[2];
135        let width = shape[3];
136
137        let (kh, kw) = self.kernel_size;
138        let (sh, sw) = self.stride;
139        let (ph, pw) = self.padding;
140
141        let out_h = (height + 2 * ph - kh) / sh + 1;
142        let out_w = (width + 2 * pw - kw) / sw + 1;
143
144        let input_vec = input.data().to_vec();
145        let mut output_data = vec![f32::NEG_INFINITY; batch * channels * out_h * out_w];
146
147        for b in 0..batch {
148            for c in 0..channels {
149                for oh in 0..out_h {
150                    for ow in 0..out_w {
151                        let mut max_val = f32::NEG_INFINITY;
152
153                        for ki in 0..kh {
154                            for kj in 0..kw {
155                                let ih = oh * sh + ki;
156                                let iw = ow * sw + kj;
157
158                                if ih >= ph && ih < height + ph && iw >= pw && iw < width + pw {
159                                    let actual_ih = ih - ph;
160                                    let actual_iw = iw - pw;
161                                    let idx = b * channels * height * width
162                                        + c * height * width
163                                        + actual_ih * width
164                                        + actual_iw;
165                                    max_val = max_val.max(input_vec[idx]);
166                                }
167                            }
168                        }
169
170                        let out_idx =
171                            b * channels * out_h * out_w + c * out_h * out_w + oh * out_w + ow;
172                        output_data[out_idx] = max_val;
173                    }
174                }
175            }
176        }
177
178        let output = Tensor::from_vec(output_data, &[batch, channels, out_h, out_w]).unwrap();
179        Variable::new(output, input.requires_grad())
180    }
181
182    fn name(&self) -> &'static str {
183        "MaxPool2d"
184    }
185}
186
187// =============================================================================
188// AvgPool1d
189// =============================================================================
190
/// Applies average pooling over a 1D signal.
///
/// # Shape
/// - Input: (N, C, L)
/// - Output: (N, C, L_out) where
///   `L_out = (L + 2 * padding - kernel_size) / stride + 1`
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct AvgPool1d {
    // Width of the pooling window.
    kernel_size: usize,
    // Step between consecutive windows.
    stride: usize,
    // Implicit zero elements added on both ends of the input.
    padding: usize,
}

impl AvgPool1d {
    /// Creates a new AvgPool1d layer.
    ///
    /// Stride defaults to `kernel_size` (non-overlapping windows) and
    /// padding defaults to 0.
    pub fn new(kernel_size: usize) -> Self {
        Self {
            kernel_size,
            stride: kernel_size,
            padding: 0,
        }
    }

    /// Creates an AvgPool1d with custom stride and padding.
    pub fn with_options(kernel_size: usize, stride: usize, padding: usize) -> Self {
        Self {
            kernel_size,
            stride,
            padding,
        }
    }
}
217
218impl Module for AvgPool1d {
219    fn forward(&self, input: &Variable) -> Variable {
220        let shape = input.shape();
221        let batch = shape[0];
222        let channels = shape[1];
223        let length = shape[2];
224
225        let out_length = (length + 2 * self.padding - self.kernel_size) / self.stride + 1;
226
227        let input_vec = input.data().to_vec();
228        let mut output_data = vec![0.0f32; batch * channels * out_length];
229
230        for b in 0..batch {
231            for c in 0..channels {
232                for ol in 0..out_length {
233                    let in_start = ol * self.stride;
234                    let mut sum = 0.0f32;
235                    let mut count = 0;
236
237                    for k in 0..self.kernel_size {
238                        let il = in_start + k;
239                        if il >= self.padding && il < length + self.padding {
240                            let actual_il = il - self.padding;
241                            let idx = b * channels * length + c * length + actual_il;
242                            sum += input_vec[idx];
243                            count += 1;
244                        }
245                    }
246
247                    let out_idx = b * channels * out_length + c * out_length + ol;
248                    output_data[out_idx] = if count > 0 { sum / count as f32 } else { 0.0 };
249                }
250            }
251        }
252
253        let output = Tensor::from_vec(output_data, &[batch, channels, out_length]).unwrap();
254        Variable::new(output, input.requires_grad())
255    }
256
257    fn name(&self) -> &'static str {
258        "AvgPool1d"
259    }
260}
261
262// =============================================================================
263// AvgPool2d
264// =============================================================================
265
/// Applies average pooling over a 2D signal (image).
///
/// # Shape
/// - Input: (N, C, H, W)
/// - Output: (N, C, H_out, W_out)
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct AvgPool2d {
    // Pooling window as (height, width).
    kernel_size: (usize, usize),
    // Step between windows as (vertical, horizontal).
    stride: (usize, usize),
    // Implicit zero border as (top/bottom, left/right).
    padding: (usize, usize),
}

impl AvgPool2d {
    /// Creates a new AvgPool2d layer with square kernel.
    ///
    /// Stride defaults to the kernel size (non-overlapping windows) and
    /// padding defaults to zero.
    pub fn new(kernel_size: usize) -> Self {
        Self {
            kernel_size: (kernel_size, kernel_size),
            stride: (kernel_size, kernel_size),
            padding: (0, 0),
        }
    }

    /// Creates an AvgPool2d with all options.
    pub fn with_options(
        kernel_size: (usize, usize),
        stride: (usize, usize),
        padding: (usize, usize),
    ) -> Self {
        Self {
            kernel_size,
            stride,
            padding,
        }
    }
}
296
297impl Module for AvgPool2d {
298    fn forward(&self, input: &Variable) -> Variable {
299        let shape = input.shape();
300        let batch = shape[0];
301        let channels = shape[1];
302        let height = shape[2];
303        let width = shape[3];
304
305        let (kh, kw) = self.kernel_size;
306        let (sh, sw) = self.stride;
307        let (ph, pw) = self.padding;
308
309        let out_h = (height + 2 * ph - kh) / sh + 1;
310        let out_w = (width + 2 * pw - kw) / sw + 1;
311
312        let input_vec = input.data().to_vec();
313        let mut output_data = vec![0.0f32; batch * channels * out_h * out_w];
314
315        for b in 0..batch {
316            for c in 0..channels {
317                for oh in 0..out_h {
318                    for ow in 0..out_w {
319                        let mut sum = 0.0f32;
320                        let mut count = 0;
321
322                        for ki in 0..kh {
323                            for kj in 0..kw {
324                                let ih = oh * sh + ki;
325                                let iw = ow * sw + kj;
326
327                                if ih >= ph && ih < height + ph && iw >= pw && iw < width + pw {
328                                    let actual_ih = ih - ph;
329                                    let actual_iw = iw - pw;
330                                    let idx = b * channels * height * width
331                                        + c * height * width
332                                        + actual_ih * width
333                                        + actual_iw;
334                                    sum += input_vec[idx];
335                                    count += 1;
336                                }
337                            }
338                        }
339
340                        let out_idx =
341                            b * channels * out_h * out_w + c * out_h * out_w + oh * out_w + ow;
342                        output_data[out_idx] = if count > 0 { sum / count as f32 } else { 0.0 };
343                    }
344                }
345            }
346        }
347
348        let output = Tensor::from_vec(output_data, &[batch, channels, out_h, out_w]).unwrap();
349        Variable::new(output, input.requires_grad())
350    }
351
352    fn name(&self) -> &'static str {
353        "AvgPool2d"
354    }
355}
356
357// =============================================================================
358// AdaptiveAvgPool2d
359// =============================================================================
360
/// Applies adaptive average pooling to produce specified output size.
///
/// This automatically determines the pooling regions needed to achieve
/// the desired output dimensions, regardless of the input's spatial size.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct AdaptiveAvgPool2d {
    // Target spatial size as (height, width).
    output_size: (usize, usize),
}

impl AdaptiveAvgPool2d {
    /// Creates a new AdaptiveAvgPool2d.
    pub fn new(output_size: (usize, usize)) -> Self {
        Self { output_size }
    }

    /// Creates an AdaptiveAvgPool2d with square output.
    pub fn square(size: usize) -> Self {
        Self {
            output_size: (size, size),
        }
    }
}
382
383impl Module for AdaptiveAvgPool2d {
384    fn forward(&self, input: &Variable) -> Variable {
385        let shape = input.shape();
386        let batch = shape[0];
387        let channels = shape[1];
388        let in_h = shape[2];
389        let in_w = shape[3];
390
391        let (out_h, out_w) = self.output_size;
392        let input_vec = input.data().to_vec();
393        let mut output_data = vec![0.0f32; batch * channels * out_h * out_w];
394
395        for b in 0..batch {
396            for c in 0..channels {
397                for oh in 0..out_h {
398                    for ow in 0..out_w {
399                        // Calculate input region for this output pixel
400                        let ih_start = (oh * in_h) / out_h;
401                        let ih_end = ((oh + 1) * in_h) / out_h;
402                        let iw_start = (ow * in_w) / out_w;
403                        let iw_end = ((ow + 1) * in_w) / out_w;
404
405                        let mut sum = 0.0f32;
406                        let mut count = 0;
407
408                        for ih in ih_start..ih_end {
409                            for iw in iw_start..iw_end {
410                                let idx =
411                                    b * channels * in_h * in_w + c * in_h * in_w + ih * in_w + iw;
412                                sum += input_vec[idx];
413                                count += 1;
414                            }
415                        }
416
417                        let out_idx =
418                            b * channels * out_h * out_w + c * out_h * out_w + oh * out_w + ow;
419                        output_data[out_idx] = if count > 0 { sum / count as f32 } else { 0.0 };
420                    }
421                }
422            }
423        }
424
425        let output = Tensor::from_vec(output_data, &[batch, channels, out_h, out_w]).unwrap();
426        Variable::new(output, input.requires_grad())
427    }
428
429    fn name(&self) -> &'static str {
430        "AdaptiveAvgPool2d"
431    }
432}
433
434// =============================================================================
435// Tests
436// =============================================================================
437
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a (1, 1, 4, 4) input holding 1..=16 in row-major order.
    fn four_by_four() -> Variable {
        let values: Vec<f32> = (1..=16).map(|v| v as f32).collect();
        Variable::new(Tensor::from_vec(values, &[1, 1, 4, 4]).unwrap(), false)
    }

    #[test]
    fn test_maxpool2d() {
        let result = MaxPool2d::new(2).forward(&four_by_four());
        assert_eq!(result.shape(), vec![1, 1, 2, 2]);
        // Each 2x2 window keeps its largest element.
        assert_eq!(result.data().to_vec(), vec![6.0, 8.0, 14.0, 16.0]);
    }

    #[test]
    fn test_avgpool2d() {
        let result = AvgPool2d::new(2).forward(&four_by_four());
        assert_eq!(result.shape(), vec![1, 1, 2, 2]);
        // Each 2x2 window is replaced by its mean.
        assert_eq!(result.data().to_vec(), vec![3.5, 5.5, 11.5, 13.5]);
    }

    #[test]
    fn test_adaptive_avgpool2d() {
        let source = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[1, 1, 2, 2]).unwrap(),
            false,
        );
        let result = AdaptiveAvgPool2d::new((1, 1)).forward(&source);
        assert_eq!(result.shape(), vec![1, 1, 1, 1]);
        // A 1x1 target averages the whole spatial plane.
        assert_eq!(result.data().to_vec(), vec![2.5]);
    }
}