Skip to main content

trueno/matrix/ops/ml_ops/
pooling.rs

1//! Pooling operations (max pool, average pool) for Matrix
2
3use crate::TruenoError;
4
5use super::super::super::Matrix;
6
7impl Matrix<f32> {
8    /// 2D Max Pooling operation for CNN downsampling
9    ///
10    /// Applies max pooling over a 2D input tensor with specified kernel size and stride.
11    ///
12    /// # Arguments
13    /// * `kernel` - (kernel_height, kernel_width) pooling window size
14    /// * `stride` - (stride_height, stride_width) step size
15    ///
16    /// # Examples
17    /// ```
18    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
19    /// use trueno::matrix::Matrix;
20    /// let input = Matrix::from_vec(4, 4, vec![
21    ///     1.0, 2.0, 3.0, 4.0,
22    ///     5.0, 6.0, 7.0, 8.0,
23    ///     9.0, 10.0, 11.0, 12.0,
24    ///     13.0, 14.0, 15.0, 16.0,
25    /// ])?;
26    /// let pooled = input.max_pool2d((2, 2), (2, 2))?;
27    /// assert_eq!(pooled.shape(), (2, 2));
28    /// assert_eq!(pooled.get(0, 0), Some(&6.0));  // max of [1,2,5,6]
29    /// assert_eq!(pooled.get(1, 1), Some(&16.0)); // max of [11,12,15,16]
30    /// # Ok(())
31    /// # }
32    /// ```
33    pub fn max_pool2d(
34        &self,
35        kernel: (usize, usize),
36        stride: (usize, usize),
37    ) -> Result<Matrix<f32>, TruenoError> {
38        let (kh, kw) = kernel;
39        let (sh, sw) = stride;
40
41        if kh == 0 || kw == 0 || sh == 0 || sw == 0 {
42            return Err(TruenoError::InvalidInput(
43                "Kernel and stride dimensions must be positive".into(),
44            ));
45        }
46
47        if kh > self.rows || kw > self.cols {
48            return Err(TruenoError::InvalidInput(format!(
49                "Kernel size ({}, {}) larger than input ({}, {})",
50                kh, kw, self.rows, self.cols
51            )));
52        }
53
54        let out_h = (self.rows - kh) / sh + 1;
55        let out_w = (self.cols - kw) / sw + 1;
56        let mut result = Matrix::new(out_h, out_w);
57
58        for i in 0..out_h {
59            for j in 0..out_w {
60                let mut max_val = f32::NEG_INFINITY;
61                for ki in 0..kh {
62                    for kj in 0..kw {
63                        let val = self.data[(i * sh + ki) * self.cols + (j * sw + kj)];
64                        max_val = max_val.max(val);
65                    }
66                }
67                result.data[i * out_w + j] = max_val;
68            }
69        }
70
71        Ok(result)
72    }
73
74    /// 2D Average Pooling operation for CNN downsampling
75    ///
76    /// Applies average pooling over a 2D input tensor with specified kernel size and stride.
77    ///
78    /// # Arguments
79    /// * `kernel` - (kernel_height, kernel_width) pooling window size
80    /// * `stride` - (stride_height, stride_width) step size
81    ///
82    /// # Examples
83    /// ```
84    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
85    /// use trueno::matrix::Matrix;
86    /// let input = Matrix::from_vec(4, 4, vec![
87    ///     1.0, 2.0, 3.0, 4.0,
88    ///     5.0, 6.0, 7.0, 8.0,
89    ///     9.0, 10.0, 11.0, 12.0,
90    ///     13.0, 14.0, 15.0, 16.0,
91    /// ])?;
92    /// let pooled = input.avg_pool2d((2, 2), (2, 2))?;
93    /// assert_eq!(pooled.shape(), (2, 2));
94    /// assert!((pooled.get(0, 0).unwrap_or(&0.0) - 3.5).abs() < 1e-5);  // avg of [1,2,5,6]
95    /// # Ok(())
96    /// # }
97    /// ```
98    pub fn avg_pool2d(
99        &self,
100        kernel: (usize, usize),
101        stride: (usize, usize),
102    ) -> Result<Matrix<f32>, TruenoError> {
103        let (kh, kw) = kernel;
104        let (sh, sw) = stride;
105
106        if kh == 0 || kw == 0 || sh == 0 || sw == 0 {
107            return Err(TruenoError::InvalidInput(
108                "Kernel and stride dimensions must be positive".into(),
109            ));
110        }
111
112        if kh > self.rows || kw > self.cols {
113            return Err(TruenoError::InvalidInput(format!(
114                "Kernel size ({}, {}) larger than input ({}, {})",
115                kh, kw, self.rows, self.cols
116            )));
117        }
118
119        let out_h = (self.rows - kh) / sh + 1;
120        let out_w = (self.cols - kw) / sw + 1;
121        let kernel_size = (kh * kw) as f32;
122        let mut result = Matrix::new(out_h, out_w);
123
124        for i in 0..out_h {
125            for j in 0..out_w {
126                let mut sum = 0.0;
127                for ki in 0..kh {
128                    for kj in 0..kw {
129                        sum += self.data[(i * sh + ki) * self.cols + (j * sw + kj)];
130                    }
131                }
132                result.data[i * out_w + j] = sum / kernel_size;
133            }
134        }
135
136        Ok(result)
137    }
138}