Skip to main content

murk_obs/
pool.rs

1//! Spatial pooling operations for observation extraction.
2//!
3//! Pooling reduces the spatial resolution of a gathered observation
4//! buffer by sliding a window and computing an aggregate (mean, max,
5//! min, or sum) over valid cells within each window.
6
7use crate::spec::{PoolConfig, PoolKernel};
8
9/// Apply 2D spatial pooling to a gathered observation buffer.
10///
11/// `input` is the flat gather buffer in row-major order.
12/// `input_mask` has 1 for valid cells, 0 for padding.
13/// `input_shape` is `[H, W]`.
14///
15/// Returns `(output, output_mask, output_shape)` where output_shape
16/// is `[(H - kernel_size) / stride + 1, (W - kernel_size) / stride + 1]`.
17pub fn pool_2d(
18    input: &[f32],
19    input_mask: &[u8],
20    input_shape: &[usize],
21    config: &PoolConfig,
22) -> (Vec<f32>, Vec<u8>, Vec<usize>) {
23    assert_eq!(input_shape.len(), 2, "pool_2d requires 2D input shape");
24    let h = input_shape[0];
25    let w = input_shape[1];
26    assert_eq!(input.len(), h * w);
27    assert_eq!(input_mask.len(), h * w);
28
29    let ks = config.kernel_size;
30    let stride = config.stride;
31    assert!(ks > 0, "kernel_size must be > 0");
32    assert!(stride > 0, "stride must be > 0");
33
34    let out_h = if h >= ks { (h - ks) / stride + 1 } else { 0 };
35    let out_w = if w >= ks { (w - ks) / stride + 1 } else { 0 };
36    let out_len = out_h * out_w;
37
38    let mut output = vec![0.0f32; out_len];
39    let mut output_mask = vec![0u8; out_len];
40    let output_shape = vec![out_h, out_w];
41
42    for oh in 0..out_h {
43        for ow in 0..out_w {
44            let r0 = oh * stride;
45            let c0 = ow * stride;
46            let out_idx = oh * out_w + ow;
47
48            let mut valid_count = 0u32;
49            let mut accum = match config.kernel {
50                PoolKernel::Max => f32::NEG_INFINITY,
51                PoolKernel::Min => f32::INFINITY,
52                PoolKernel::Mean | PoolKernel::Sum => 0.0,
53            };
54
55            for kr in 0..ks {
56                for kc in 0..ks {
57                    let r = r0 + kr;
58                    let c = c0 + kc;
59                    let idx = r * w + c;
60                    if input_mask[idx] == 1 {
61                        let val = input[idx];
62                        valid_count += 1;
63                        match config.kernel {
64                            PoolKernel::Mean | PoolKernel::Sum => accum += val,
65                            PoolKernel::Max => {
66                                if val > accum {
67                                    accum = val;
68                                }
69                            }
70                            PoolKernel::Min => {
71                                if val < accum {
72                                    accum = val;
73                                }
74                            }
75                        }
76                    }
77                }
78            }
79
80            if valid_count > 0 {
81                output_mask[out_idx] = 1;
82                output[out_idx] = match config.kernel {
83                    PoolKernel::Mean => accum / valid_count as f32,
84                    PoolKernel::Max | PoolKernel::Min | PoolKernel::Sum => accum,
85                };
86            }
87        }
88    }
89
90    (output, output_mask, output_shape)
91}
92
#[cfg(test)]
mod tests {
    use super::*;

    fn pool_cfg(kernel: PoolKernel, kernel_size: usize, stride: usize) -> PoolConfig {
        PoolConfig {
            kernel,
            kernel_size,
            stride,
        }
    }

    /// Row-major buffer holding `1.0, 2.0, ..., n`.
    fn ramp(n: usize) -> Vec<f32> {
        (1..=n).map(|x| x as f32).collect()
    }

    #[test]
    fn mean_pool_2x2_stride2_on_4x4() {
        // 4x4 input, values 1..16
        let input = ramp(16);
        let mask = vec![1u8; 16];
        let cfg = pool_cfg(PoolKernel::Mean, 2, 2);

        let (output, out_mask, out_shape) = pool_2d(&input, &mask, &[4, 4], &cfg);
        assert_eq!(out_shape, vec![2, 2]);
        assert_eq!(out_mask, vec![1, 1, 1, 1]);

        // Top-left: (1+2+5+6)/4 = 3.5
        assert!((output[0] - 3.5).abs() < 1e-6);
        // Top-right: (3+4+7+8)/4 = 5.5
        assert!((output[1] - 5.5).abs() < 1e-6);
        // Bottom-left: (9+10+13+14)/4 = 11.5
        assert!((output[2] - 11.5).abs() < 1e-6);
        // Bottom-right: (11+12+15+16)/4 = 13.5
        assert!((output[3] - 13.5).abs() < 1e-6);
    }

    #[test]
    fn max_pool_2x2_stride2() {
        let input = ramp(16);
        let mask = vec![1u8; 16];
        let cfg = pool_cfg(PoolKernel::Max, 2, 2);

        let (output, _, out_shape) = pool_2d(&input, &mask, &[4, 4], &cfg);
        assert_eq!(out_shape, vec![2, 2]);
        assert_eq!(output, vec![6.0, 8.0, 14.0, 16.0]);
    }

    #[test]
    fn min_pool_2x2_stride2() {
        let input = ramp(16);
        let mask = vec![1u8; 16];
        let cfg = pool_cfg(PoolKernel::Min, 2, 2);

        let (output, _, _) = pool_2d(&input, &mask, &[4, 4], &cfg);
        assert_eq!(output, vec![1.0, 3.0, 9.0, 11.0]);
    }

    #[test]
    fn sum_pool_2x2_stride2() {
        let input = ramp(16);
        let mask = vec![1u8; 16];
        let cfg = pool_cfg(PoolKernel::Sum, 2, 2);

        let (output, _, _) = pool_2d(&input, &mask, &[4, 4], &cfg);
        // Top-left: 1+2+5+6=14
        assert_eq!(output, vec![14.0, 22.0, 46.0, 54.0]);
    }

    #[test]
    fn partial_valid_mask_mean() {
        // 4x4 input, but some cells masked out
        let input = ramp(16);
        let mask = vec![1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1];
        let cfg = pool_cfg(PoolKernel::Mean, 2, 2);

        let (output, out_mask, _) = pool_2d(&input, &mask, &[4, 4], &cfg);
        // Every window still has at least one valid cell.
        assert_eq!(out_mask, vec![1, 1, 1, 1]);
        // Top-left: (1+5+6)/3 = 4.0 (cell [0,1]=2 masked out)
        assert!((output[0] - 4.0).abs() < 1e-6);
        // Top-right: (3+4+8)/3 = 5.0 (cell [1,2]=7 masked out)
        assert!((output[1] - 5.0).abs() < 1e-6);
        // Bottom windows are fully valid.
        assert!((output[2] - 11.5).abs() < 1e-6);
        assert!((output[3] - 13.5).abs() < 1e-6);
    }

    #[test]
    fn all_masked_window_gives_zero() {
        let input = vec![1.0, 2.0, 3.0, 4.0];
        let mask = vec![0, 0, 0, 0];
        let cfg = pool_cfg(PoolKernel::Mean, 2, 2);

        let (output, out_mask, out_shape) = pool_2d(&input, &mask, &[2, 2], &cfg);
        assert_eq!(out_shape, vec![1, 1]);
        assert_eq!(output, vec![0.0]);
        assert_eq!(out_mask, vec![0]);
    }

    #[test]
    fn all_masked_window_max_min_do_not_leak_infinity() {
        let input = vec![1.0, 2.0, 3.0, 4.0];
        let mask = vec![0u8; 4];

        for kernel in [PoolKernel::Max, PoolKernel::Min] {
            let (output, out_mask, _) = pool_2d(&input, &mask, &[2, 2], &pool_cfg(kernel, 2, 2));
            assert_eq!(output, vec![0.0]);
            assert_eq!(out_mask, vec![0]);
        }
    }

    #[test]
    fn max_min_on_all_negative_values() {
        let input = vec![-3.0, -1.0, -2.0, -4.0];
        let mask = vec![1u8; 4];

        let (max_out, _, _) = pool_2d(&input, &mask, &[2, 2], &pool_cfg(PoolKernel::Max, 2, 2));
        assert_eq!(max_out, vec![-1.0]);
        let (min_out, _, _) = pool_2d(&input, &mask, &[2, 2], &pool_cfg(PoolKernel::Min, 2, 2));
        assert_eq!(min_out, vec![-4.0]);
    }

    #[test]
    fn stride_1_produces_larger_output() {
        let input = ramp(9);
        let mask = vec![1u8; 9];
        let cfg = pool_cfg(PoolKernel::Mean, 2, 1);

        let (_, _, out_shape) = pool_2d(&input, &mask, &[3, 3], &cfg);
        // (3-2)/1+1=2 in each dim
        assert_eq!(out_shape, vec![2, 2]);
    }

    #[test]
    fn non_square_input_overlapping_windows() {
        // 3 rows x 4 columns, values 1..12.
        let input = ramp(12);
        let mask = vec![1u8; 12];

        let (max_out, _, out_shape) =
            pool_2d(&input, &mask, &[3, 4], &pool_cfg(PoolKernel::Max, 2, 1));
        assert_eq!(out_shape, vec![2, 3]);
        assert_eq!(max_out, vec![6.0, 7.0, 8.0, 10.0, 11.0, 12.0]);

        let (sum_out, _, _) = pool_2d(&input, &mask, &[3, 4], &pool_cfg(PoolKernel::Sum, 2, 1));
        assert_eq!(sum_out, vec![14.0, 18.0, 22.0, 30.0, 34.0, 38.0]);
    }

    #[test]
    fn stride_not_dividing_input_drops_trailing_cells() {
        // 5x5 input, values 1..25; (5-2)/2+1 = 2 windows per axis, so the
        // last row and column are never visited.
        let input = ramp(25);
        let mask = vec![1u8; 25];

        let (output, _, out_shape) =
            pool_2d(&input, &mask, &[5, 5], &pool_cfg(PoolKernel::Max, 2, 2));
        assert_eq!(out_shape, vec![2, 2]);
        assert_eq!(output, vec![7.0, 9.0, 17.0, 19.0]);
    }

    #[test]
    fn kernel_larger_than_input_gives_empty() {
        let input = vec![1.0, 2.0, 3.0, 4.0];
        let mask = vec![1u8; 4];
        let cfg = pool_cfg(PoolKernel::Mean, 3, 1);

        let (output, _, out_shape) = pool_2d(&input, &mask, &[2, 2], &cfg);
        assert_eq!(out_shape, vec![0, 0]);
        assert!(output.is_empty());
    }

    #[test]
    fn kernel_fitting_only_one_axis_gives_empty() {
        // 1 row x 4 columns: the 2x2 kernel fits along W but not along H.
        let input = vec![1.0, 2.0, 3.0, 4.0];
        let mask = vec![1u8; 4];

        let (output, out_mask, out_shape) =
            pool_2d(&input, &mask, &[1, 4], &pool_cfg(PoolKernel::Mean, 2, 1));
        assert_eq!(out_shape, vec![0, 3]);
        assert!(output.is_empty());
        assert!(out_mask.is_empty());
    }

    #[test]
    #[should_panic(expected = "kernel_size must be > 0")]
    fn zero_kernel_size_panics() {
        let _ = pool_2d(&[1.0], &[1], &[1, 1], &pool_cfg(PoolKernel::Mean, 0, 1));
    }

    #[test]
    #[should_panic(expected = "stride must be > 0")]
    fn zero_stride_panics() {
        let _ = pool_2d(&[1.0], &[1], &[1, 1], &pool_cfg(PoolKernel::Mean, 1, 0));
    }

    #[test]
    #[should_panic(expected = "pool_2d requires 2D input shape")]
    fn non_2d_shape_panics() {
        let _ = pool_2d(&[1.0], &[1], &[1, 1, 1], &pool_cfg(PoolKernel::Mean, 1, 1));
    }

    #[test]
    #[should_panic]
    fn input_length_mismatch_panics() {
        let _ = pool_2d(
            &[1.0, 2.0, 3.0],
            &[1, 1, 1],
            &[2, 2],
            &pool_cfg(PoolKernel::Mean, 1, 1),
        );
    }
}