1use crate::spec::{PoolConfig, PoolKernel};
8
9pub fn pool_2d(
18 input: &[f32],
19 input_mask: &[u8],
20 input_shape: &[usize],
21 config: &PoolConfig,
22) -> (Vec<f32>, Vec<u8>, Vec<usize>) {
23 assert_eq!(input_shape.len(), 2, "pool_2d requires 2D input shape");
24 let h = input_shape[0];
25 let w = input_shape[1];
26 assert_eq!(input.len(), h * w);
27 assert_eq!(input_mask.len(), h * w);
28
29 let ks = config.kernel_size;
30 let stride = config.stride;
31 assert!(ks > 0, "kernel_size must be > 0");
32 assert!(stride > 0, "stride must be > 0");
33
34 let out_h = if h >= ks { (h - ks) / stride + 1 } else { 0 };
35 let out_w = if w >= ks { (w - ks) / stride + 1 } else { 0 };
36 let out_len = out_h * out_w;
37
38 let mut output = vec![0.0f32; out_len];
39 let mut output_mask = vec![0u8; out_len];
40 let output_shape = vec![out_h, out_w];
41
42 for oh in 0..out_h {
43 for ow in 0..out_w {
44 let r0 = oh * stride;
45 let c0 = ow * stride;
46 let out_idx = oh * out_w + ow;
47
48 let mut valid_count = 0u32;
49 let mut accum = match config.kernel {
50 PoolKernel::Max => f32::NEG_INFINITY,
51 PoolKernel::Min => f32::INFINITY,
52 PoolKernel::Mean | PoolKernel::Sum => 0.0,
53 };
54
55 for kr in 0..ks {
56 for kc in 0..ks {
57 let r = r0 + kr;
58 let c = c0 + kc;
59 let idx = r * w + c;
60 if input_mask[idx] == 1 {
61 let val = input[idx];
62 valid_count += 1;
63 match config.kernel {
64 PoolKernel::Mean | PoolKernel::Sum => accum += val,
65 PoolKernel::Max => {
66 if val > accum {
67 accum = val;
68 }
69 }
70 PoolKernel::Min => {
71 if val < accum {
72 accum = val;
73 }
74 }
75 }
76 }
77 }
78 }
79
80 if valid_count > 0 {
81 output_mask[out_idx] = 1;
82 output[out_idx] = match config.kernel {
83 PoolKernel::Mean => accum / valid_count as f32,
84 PoolKernel::Max | PoolKernel::Min | PoolKernel::Sum => accum,
85 };
86 }
87 }
88 }
89
90 (output, output_mask, output_shape)
91}
92
#[cfg(test)]
mod tests {
    use super::*;

    fn pool_cfg(kernel: PoolKernel, kernel_size: usize, stride: usize) -> PoolConfig {
        PoolConfig {
            kernel,
            kernel_size,
            stride,
        }
    }

    // Row-major 4x4 grid holding 1.0..=16.0:
    //  1  2  3  4
    //  5  6  7  8
    //  9 10 11 12
    // 13 14 15 16
    fn grid_4x4() -> Vec<f32> {
        (1..=16).map(|x| x as f32).collect()
    }

    #[test]
    fn mean_pool_2x2_stride2_on_4x4() {
        let input = grid_4x4();
        let mask = vec![1u8; 16];
        let cfg = pool_cfg(PoolKernel::Mean, 2, 2);

        let (output, out_mask, out_shape) = pool_2d(&input, &mask, &[4, 4], &cfg);
        assert_eq!(out_shape, vec![2, 2]);
        assert_eq!(out_mask, vec![1, 1, 1, 1]);

        // (1+2+5+6)/4, (3+4+7+8)/4, (9+10+13+14)/4, (11+12+15+16)/4
        assert!((output[0] - 3.5).abs() < 1e-6);
        assert!((output[1] - 5.5).abs() < 1e-6);
        assert!((output[2] - 11.5).abs() < 1e-6);
        assert!((output[3] - 13.5).abs() < 1e-6);
    }

    #[test]
    fn max_pool_2x2_stride2() {
        let input = grid_4x4();
        let mask = vec![1u8; 16];
        let cfg = pool_cfg(PoolKernel::Max, 2, 2);

        let (output, _, out_shape) = pool_2d(&input, &mask, &[4, 4], &cfg);
        assert_eq!(out_shape, vec![2, 2]);
        assert_eq!(output, vec![6.0, 8.0, 14.0, 16.0]);
    }

    #[test]
    fn min_pool_2x2_stride2() {
        let input = grid_4x4();
        let mask = vec![1u8; 16];
        let cfg = pool_cfg(PoolKernel::Min, 2, 2);

        let (output, _, _) = pool_2d(&input, &mask, &[4, 4], &cfg);
        assert_eq!(output, vec![1.0, 3.0, 9.0, 11.0]);
    }

    #[test]
    fn sum_pool_2x2_stride2() {
        let input = grid_4x4();
        let mask = vec![1u8; 16];
        let cfg = pool_cfg(PoolKernel::Sum, 2, 2);

        let (output, _, _) = pool_2d(&input, &mask, &[4, 4], &cfg);
        // 1+2+5+6, 3+4+7+8, 9+10+13+14, 11+12+15+16
        assert_eq!(output, vec![14.0, 22.0, 46.0, 54.0]);
    }

    #[test]
    fn partial_valid_mask_mean() {
        let input = grid_4x4();
        // Masks out 2.0 (top-left window) and 7.0 (top-right window).
        let mask = vec![1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1];
        let cfg = pool_cfg(PoolKernel::Mean, 2, 2);

        let (output, out_mask, _) = pool_2d(&input, &mask, &[4, 4], &cfg);
        // The mean covers valid cells only: (1+5+6)/3 and (3+4+8)/3.
        assert!((output[0] - 4.0).abs() < 1e-6);
        assert_eq!(out_mask[0], 1);
        assert!((output[1] - 5.0).abs() < 1e-6);
        assert_eq!(out_mask, vec![1, 1, 1, 1]);
    }

    #[test]
    fn partial_valid_mask_max() {
        let input = grid_4x4();
        // Masks out 6.0 and 8.0, the maxima of the two top windows.
        let mask = vec![1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1];
        let cfg = pool_cfg(PoolKernel::Max, 2, 2);

        let (output, out_mask, _) = pool_2d(&input, &mask, &[4, 4], &cfg);
        assert_eq!(output, vec![5.0, 7.0, 14.0, 16.0]);
        assert_eq!(out_mask, vec![1, 1, 1, 1]);
    }

    #[test]
    fn partial_valid_mask_min() {
        let input = grid_4x4();
        // Masks out 1.0 and 3.0, the minima of the two top windows.
        let mask = vec![0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1];
        let cfg = pool_cfg(PoolKernel::Min, 2, 2);

        let (output, out_mask, _) = pool_2d(&input, &mask, &[4, 4], &cfg);
        assert_eq!(output, vec![2.0, 4.0, 9.0, 11.0]);
        assert_eq!(out_mask, vec![1, 1, 1, 1]);
    }

    #[test]
    fn all_masked_window_gives_zero() {
        let input = vec![1.0, 2.0, 3.0, 4.0];
        let mask = vec![0, 0, 0, 0];
        let cfg = pool_cfg(PoolKernel::Mean, 2, 2);

        let (output, out_mask, out_shape) = pool_2d(&input, &mask, &[2, 2], &cfg);
        assert_eq!(out_shape, vec![1, 1]);
        assert_eq!(output, vec![0.0]);
        assert_eq!(out_mask, vec![0]);
    }

    #[test]
    fn all_masked_window_max_and_min_do_not_leak_infinity() {
        // Max/Min start from -inf/+inf. A fully masked window must still report
        // the 0.0 / mask-0 default rather than the sentinel.
        let input = vec![1.0, 2.0, 3.0, 4.0];
        let mask = vec![0u8; 4];
        for cfg in &[
            pool_cfg(PoolKernel::Max, 2, 2),
            pool_cfg(PoolKernel::Min, 2, 2),
        ] {
            let (output, out_mask, _) = pool_2d(&input, &mask, &[2, 2], cfg);
            assert_eq!(output, vec![0.0]);
            assert_eq!(out_mask, vec![0]);
        }
    }

    #[test]
    fn stride_1_produces_larger_output() {
        // 1 2 3
        // 4 5 6
        // 7 8 9
        let input: Vec<f32> = (1..=9).map(|x| x as f32).collect();
        let mask = vec![1u8; 9];
        let cfg = pool_cfg(PoolKernel::Mean, 2, 1);

        let (output, out_mask, out_shape) = pool_2d(&input, &mask, &[3, 3], &cfg);
        assert_eq!(out_shape, vec![2, 2]);
        assert_eq!(out_mask, vec![1, 1, 1, 1]);
        // Overlapping windows: (1+2+4+5)/4, (2+3+5+6)/4, (4+5+7+8)/4, (5+6+8+9)/4.
        assert_eq!(output, vec![3.0, 4.0, 6.0, 7.0]);
    }

    #[test]
    fn non_square_input_uses_row_major_layout() {
        // 2 rows x 3 columns:
        // 1 2 3
        // 4 5 6
        let input: Vec<f32> = (1..=6).map(|x| x as f32).collect();
        let mask = vec![1u8; 6];
        let cfg = pool_cfg(PoolKernel::Sum, 2, 1);

        let (output, _, out_shape) = pool_2d(&input, &mask, &[2, 3], &cfg);
        assert_eq!(out_shape, vec![1, 2]);
        // 1+2+4+5 and 2+3+5+6. Swapping h and w in the indexing would break this.
        assert_eq!(output, vec![12.0, 16.0]);
    }

    #[test]
    fn kernel_larger_than_input_gives_empty() {
        let input = vec![1.0, 2.0, 3.0, 4.0];
        let mask = vec![1u8; 4];
        let cfg = pool_cfg(PoolKernel::Mean, 3, 1);

        let (output, _, out_shape) = pool_2d(&input, &mask, &[2, 2], &cfg);
        assert_eq!(out_shape, vec![0, 0]);
        assert!(output.is_empty());
    }

    #[test]
    fn kernel_fitting_only_one_axis_gives_empty() {
        let input = vec![1.0, 2.0, 3.0];
        let mask = vec![1u8; 3];
        let cfg = pool_cfg(PoolKernel::Sum, 2, 1);

        let (output, out_mask, out_shape) = pool_2d(&input, &mask, &[3, 1], &cfg);
        assert_eq!(out_shape, vec![2, 0]);
        assert!(output.is_empty());
        assert!(out_mask.is_empty());
    }

    #[test]
    #[should_panic(expected = "stride must be > 0")]
    fn zero_stride_panics() {
        let input = vec![1.0; 4];
        let mask = vec![1u8; 4];
        let _ = pool_2d(&input, &mask, &[2, 2], &pool_cfg(PoolKernel::Sum, 2, 0));
    }

    #[test]
    #[should_panic(expected = "kernel_size must be > 0")]
    fn zero_kernel_size_panics() {
        let input = vec![1.0; 4];
        let mask = vec![1u8; 4];
        let _ = pool_2d(&input, &mask, &[2, 2], &pool_cfg(PoolKernel::Sum, 0, 1));
    }

    #[test]
    #[should_panic(expected = "pool_2d requires 2D input shape")]
    fn non_2d_shape_panics() {
        let input = vec![1.0; 4];
        let mask = vec![1u8; 4];
        let _ = pool_2d(&input, &mask, &[4], &pool_cfg(PoolKernel::Sum, 1, 1));
    }
}