trueno/matrix/ops/ml_ops/pooling.rs
1//! Pooling operations (max pool, average pool) for Matrix
2
3use crate::TruenoError;
4
5use super::super::super::Matrix;
6
7impl Matrix<f32> {
8 /// 2D Max Pooling operation for CNN downsampling
9 ///
10 /// Applies max pooling over a 2D input tensor with specified kernel size and stride.
11 ///
12 /// # Arguments
13 /// * `kernel` - (kernel_height, kernel_width) pooling window size
14 /// * `stride` - (stride_height, stride_width) step size
15 ///
16 /// # Examples
17 /// ```
18 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
19 /// use trueno::matrix::Matrix;
20 /// let input = Matrix::from_vec(4, 4, vec![
21 /// 1.0, 2.0, 3.0, 4.0,
22 /// 5.0, 6.0, 7.0, 8.0,
23 /// 9.0, 10.0, 11.0, 12.0,
24 /// 13.0, 14.0, 15.0, 16.0,
25 /// ])?;
26 /// let pooled = input.max_pool2d((2, 2), (2, 2))?;
27 /// assert_eq!(pooled.shape(), (2, 2));
28 /// assert_eq!(pooled.get(0, 0), Some(&6.0)); // max of [1,2,5,6]
29 /// assert_eq!(pooled.get(1, 1), Some(&16.0)); // max of [11,12,15,16]
30 /// # Ok(())
31 /// # }
32 /// ```
33 pub fn max_pool2d(
34 &self,
35 kernel: (usize, usize),
36 stride: (usize, usize),
37 ) -> Result<Matrix<f32>, TruenoError> {
38 let (kh, kw) = kernel;
39 let (sh, sw) = stride;
40
41 if kh == 0 || kw == 0 || sh == 0 || sw == 0 {
42 return Err(TruenoError::InvalidInput(
43 "Kernel and stride dimensions must be positive".into(),
44 ));
45 }
46
47 if kh > self.rows || kw > self.cols {
48 return Err(TruenoError::InvalidInput(format!(
49 "Kernel size ({}, {}) larger than input ({}, {})",
50 kh, kw, self.rows, self.cols
51 )));
52 }
53
54 let out_h = (self.rows - kh) / sh + 1;
55 let out_w = (self.cols - kw) / sw + 1;
56 let mut result = Matrix::new(out_h, out_w);
57
58 for i in 0..out_h {
59 for j in 0..out_w {
60 let mut max_val = f32::NEG_INFINITY;
61 for ki in 0..kh {
62 for kj in 0..kw {
63 let val = self.data[(i * sh + ki) * self.cols + (j * sw + kj)];
64 max_val = max_val.max(val);
65 }
66 }
67 result.data[i * out_w + j] = max_val;
68 }
69 }
70
71 Ok(result)
72 }
73
74 /// 2D Average Pooling operation for CNN downsampling
75 ///
76 /// Applies average pooling over a 2D input tensor with specified kernel size and stride.
77 ///
78 /// # Arguments
79 /// * `kernel` - (kernel_height, kernel_width) pooling window size
80 /// * `stride` - (stride_height, stride_width) step size
81 ///
82 /// # Examples
83 /// ```
84 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
85 /// use trueno::matrix::Matrix;
86 /// let input = Matrix::from_vec(4, 4, vec![
87 /// 1.0, 2.0, 3.0, 4.0,
88 /// 5.0, 6.0, 7.0, 8.0,
89 /// 9.0, 10.0, 11.0, 12.0,
90 /// 13.0, 14.0, 15.0, 16.0,
91 /// ])?;
92 /// let pooled = input.avg_pool2d((2, 2), (2, 2))?;
93 /// assert_eq!(pooled.shape(), (2, 2));
94 /// assert!((pooled.get(0, 0).unwrap_or(&0.0) - 3.5).abs() < 1e-5); // avg of [1,2,5,6]
95 /// # Ok(())
96 /// # }
97 /// ```
98 pub fn avg_pool2d(
99 &self,
100 kernel: (usize, usize),
101 stride: (usize, usize),
102 ) -> Result<Matrix<f32>, TruenoError> {
103 let (kh, kw) = kernel;
104 let (sh, sw) = stride;
105
106 if kh == 0 || kw == 0 || sh == 0 || sw == 0 {
107 return Err(TruenoError::InvalidInput(
108 "Kernel and stride dimensions must be positive".into(),
109 ));
110 }
111
112 if kh > self.rows || kw > self.cols {
113 return Err(TruenoError::InvalidInput(format!(
114 "Kernel size ({}, {}) larger than input ({}, {})",
115 kh, kw, self.rows, self.cols
116 )));
117 }
118
119 let out_h = (self.rows - kh) / sh + 1;
120 let out_w = (self.cols - kw) / sw + 1;
121 let kernel_size = (kh * kw) as f32;
122 let mut result = Matrix::new(out_h, out_w);
123
124 for i in 0..out_h {
125 for j in 0..out_w {
126 let mut sum = 0.0;
127 for ki in 0..kh {
128 for kj in 0..kw {
129 sum += self.data[(i * sh + ki) * self.cols + (j * sw + kj)];
130 }
131 }
132 result.data[i * out_w + j] = sum / kernel_size;
133 }
134 }
135
136 Ok(result)
137 }
138}