ghostflow_nn/pooling.rs

//! Pooling layers

use ghostflow_core::Tensor;
use crate::module::Module;

/// Max Pooling 2D
pub struct MaxPool2d {
    kernel_size: (usize, usize),
    stride: (usize, usize),
    padding: (usize, usize),
}

impl MaxPool2d {
    /// Square window with stride equal to the kernel size and no padding.
    pub fn new(kernel_size: usize) -> Self {
        Self::with_params(kernel_size, kernel_size, 0)
    }

    /// Square window with explicit stride and symmetric padding on both spatial dimensions.
    pub fn with_params(kernel_size: usize, stride: usize, padding: usize) -> Self {
        MaxPool2d {
            kernel_size: (kernel_size, kernel_size),
            stride: (stride, stride),
            padding: (padding, padding),
        }
    }
}

impl Module for MaxPool2d {
    fn forward(&self, input: &Tensor) -> Tensor {
        // Expects NCHW input: [batch, channels, height, width].
        let dims = input.dims();
        let batch = dims[0];
        let channels = dims[1];
        let in_h = dims[2];
        let in_w = dims[3];

        let out_h = (in_h + 2 * self.padding.0 - self.kernel_size.0) / self.stride.0 + 1;
        let out_w = (in_w + 2 * self.padding.1 - self.kernel_size.1) / self.stride.1 + 1;

        let data = input.data_f32();
        let mut output = vec![f32::NEG_INFINITY; batch * channels * out_h * out_w];

        for b in 0..batch {
            for c in 0..channels {
                for oh in 0..out_h {
                    for ow in 0..out_w {
                        let mut max_val = f32::NEG_INFINITY;

                        for kh in 0..self.kernel_size.0 {
                            for kw in 0..self.kernel_size.1 {
                                // Map the window position back to input coordinates;
                                // positions that fall in the padding are skipped.
                                let ih = (oh * self.stride.0 + kh) as i32 - self.padding.0 as i32;
                                let iw = (ow * self.stride.1 + kw) as i32 - self.padding.1 as i32;

                                if ih >= 0 && (ih as usize) < in_h && iw >= 0 && (iw as usize) < in_w {
                                    let idx = b * channels * in_h * in_w
                                        + c * in_h * in_w
                                        + (ih as usize) * in_w
                                        + iw as usize;
                                    max_val = max_val.max(data[idx]);
                                }
                            }
                        }

                        let out_idx = b * channels * out_h * out_w + c * out_h * out_w + oh * out_w + ow;
                        output[out_idx] = max_val;
                    }
                }
            }
        }

        Tensor::from_slice(&output, &[batch, channels, out_h, out_w]).unwrap()
    }

    fn parameters(&self) -> Vec<Tensor> { vec![] }
    fn train(&mut self) {}
    fn eval(&mut self) {}
    fn is_training(&self) -> bool { false }
}

/// Average Pooling 2D
pub struct AvgPool2d {
    kernel_size: (usize, usize),
    stride: (usize, usize),
    padding: (usize, usize),
    count_include_pad: bool,
}

impl AvgPool2d {
    /// Square window with stride equal to the kernel size and no padding.
    pub fn new(kernel_size: usize) -> Self {
        Self::with_params(kernel_size, kernel_size, 0)
    }

    /// Square window with explicit stride and symmetric padding.
    /// Padded positions are counted in the divisor (`count_include_pad = true`).
    pub fn with_params(kernel_size: usize, stride: usize, padding: usize) -> Self {
        AvgPool2d {
            kernel_size: (kernel_size, kernel_size),
            stride: (stride, stride),
            padding: (padding, padding),
            count_include_pad: true,
        }
    }
}

impl Module for AvgPool2d {
    fn forward(&self, input: &Tensor) -> Tensor {
        // Expects NCHW input: [batch, channels, height, width].
        let dims = input.dims();
        let batch = dims[0];
        let channels = dims[1];
        let in_h = dims[2];
        let in_w = dims[3];

        let out_h = (in_h + 2 * self.padding.0 - self.kernel_size.0) / self.stride.0 + 1;
        let out_w = (in_w + 2 * self.padding.1 - self.kernel_size.1) / self.stride.1 + 1;

        let data = input.data_f32();
        let mut output = vec![0.0f32; batch * channels * out_h * out_w];

        for b in 0..batch {
            for c in 0..channels {
                for oh in 0..out_h {
                    for ow in 0..out_w {
                        let mut sum = 0.0f32;
                        let mut count = 0;

                        for kh in 0..self.kernel_size.0 {
                            for kw in 0..self.kernel_size.1 {
                                let ih = (oh * self.stride.0 + kh) as i32 - self.padding.0 as i32;
                                let iw = (ow * self.stride.1 + kw) as i32 - self.padding.1 as i32;

                                if ih >= 0 && (ih as usize) < in_h && iw >= 0 && (iw as usize) < in_w {
                                    let idx = b * channels * in_h * in_w
                                        + c * in_h * in_w
                                        + (ih as usize) * in_w
                                        + iw as usize;
                                    sum += data[idx];
                                    count += 1;
                                } else if self.count_include_pad {
                                    // Padded positions add zero to the sum but are
                                    // still counted in the divisor.
                                    count += 1;
                                }
                            }
                        }

                        let out_idx = b * channels * out_h * out_w + c * out_h * out_w + oh * out_w + ow;
                        output[out_idx] = if count > 0 { sum / count as f32 } else { 0.0 };
                    }
                }
            }
        }

        Tensor::from_slice(&output, &[batch, channels, out_h, out_w]).unwrap()
    }

    fn parameters(&self) -> Vec<Tensor> { vec![] }
    fn train(&mut self) {}
    fn eval(&mut self) {}
    fn is_training(&self) -> bool { false }
}

/// Global Average Pooling 2D - reduces spatial dimensions to 1x1
pub struct GlobalAvgPool2d;

impl GlobalAvgPool2d {
    pub fn new() -> Self {
        GlobalAvgPool2d
    }
}

impl Default for GlobalAvgPool2d {
    fn default() -> Self {
        Self::new()
    }
}

impl Module for GlobalAvgPool2d {
    fn forward(&self, input: &Tensor) -> Tensor {
        let dims = input.dims();
        let batch = dims[0];
        let channels = dims[1];
        let spatial_size = dims[2] * dims[3];

        let data = input.data_f32();
        let mut output = vec![0.0f32; batch * channels];

        for b in 0..batch {
            for c in 0..channels {
                // Each (batch, channel) plane is contiguous in NCHW layout.
                let start = b * channels * spatial_size + c * spatial_size;
                let sum: f32 = data[start..start + spatial_size].iter().sum();
                output[b * channels + c] = sum / spatial_size as f32;
            }
        }

        Tensor::from_slice(&output, &[batch, channels, 1, 1]).unwrap()
    }

    fn parameters(&self) -> Vec<Tensor> { vec![] }
    fn train(&mut self) {}
    fn eval(&mut self) {}
    fn is_training(&self) -> bool { false }
}

/// Global Max Pooling 2D
pub struct GlobalMaxPool2d;

impl GlobalMaxPool2d {
    pub fn new() -> Self {
        GlobalMaxPool2d
    }
}

impl Default for GlobalMaxPool2d {
    fn default() -> Self {
        Self::new()
    }
}

impl Module for GlobalMaxPool2d {
    fn forward(&self, input: &Tensor) -> Tensor {
        let dims = input.dims();
        let batch = dims[0];
        let channels = dims[1];
        let spatial_size = dims[2] * dims[3];

        let data = input.data_f32();
        let mut output = vec![f32::NEG_INFINITY; batch * channels];

        for b in 0..batch {
            for c in 0..channels {
                let start = b * channels * spatial_size + c * spatial_size;
                let max_val = data[start..start + spatial_size]
                    .iter()
                    .cloned()
                    .fold(f32::NEG_INFINITY, f32::max);
                output[b * channels + c] = max_val;
            }
        }

        Tensor::from_slice(&output, &[batch, channels, 1, 1]).unwrap()
    }

    fn parameters(&self) -> Vec<Tensor> { vec![] }
    fn train(&mut self) {}
    fn eval(&mut self) {}
    fn is_training(&self) -> bool { false }
}

/// Adaptive Average Pooling 2D - pools to target output size
pub struct AdaptiveAvgPool2d {
    output_size: (usize, usize),
}

impl AdaptiveAvgPool2d {
    pub fn new(output_size: (usize, usize)) -> Self {
        AdaptiveAvgPool2d { output_size }
    }

    pub fn square(size: usize) -> Self {
        Self::new((size, size))
    }
}

impl Module for AdaptiveAvgPool2d {
    fn forward(&self, input: &Tensor) -> Tensor {
        let dims = input.dims();
        let batch = dims[0];
        let channels = dims[1];
        let in_h = dims[2];
        let in_w = dims[3];
        let (out_h, out_w) = self.output_size;

        let data = input.data_f32();
        let mut output = vec![0.0f32; batch * channels * out_h * out_w];

        for b in 0..batch {
            for c in 0..channels {
                for oh in 0..out_h {
                    for ow in 0..out_w {
                        // Compute the input region for this output element:
                        // floor for the start index and ceiling for the end index,
                        // so every input element is covered (and no region is empty)
                        // even when the input size is not divisible by the output size.
                        let ih_start = (oh * in_h) / out_h;
                        let ih_end = ((oh + 1) * in_h + out_h - 1) / out_h;
                        let iw_start = (ow * in_w) / out_w;
                        let iw_end = ((ow + 1) * in_w + out_w - 1) / out_w;

                        let mut sum = 0.0f32;
                        let mut count = 0;

                        for ih in ih_start..ih_end {
                            for iw in iw_start..iw_end {
                                let idx = b * channels * in_h * in_w + c * in_h * in_w + ih * in_w + iw;
                                sum += data[idx];
                                count += 1;
                            }
                        }

                        let out_idx = b * channels * out_h * out_w + c * out_h * out_w + oh * out_w + ow;
                        output[out_idx] = if count > 0 { sum / count as f32 } else { 0.0 };
                    }
                }
            }
        }

        Tensor::from_slice(&output, &[batch, channels, out_h, out_w]).unwrap()
    }

    fn parameters(&self) -> Vec<Tensor> { vec![] }
    fn train(&mut self) {}
    fn eval(&mut self) {}
    fn is_training(&self) -> bool { false }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_max_pool2d() {
        let pool = MaxPool2d::new(2);
        let input = Tensor::randn(&[1, 3, 8, 8]);
        let output = pool.forward(&input);

        assert_eq!(output.dims(), &[1, 3, 4, 4]);
    }
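
    // A hedged value check (not in the original tests): assuming Tensor::from_slice
    // and data_f32 round-trip row-major data the same way forward() uses them, a
    // 4x4 plane pooled with a 2x2 window and stride 2 keeps the maximum of each
    // non-overlapping window.
    #[test]
    fn test_max_pool2d_values() {
        let pool = MaxPool2d::new(2);
        let values: Vec<f32> = vec![
            1.0, 2.0, 3.0, 4.0,
            5.0, 6.0, 7.0, 8.0,
            9.0, 10.0, 11.0, 12.0,
            13.0, 14.0, 15.0, 16.0,
        ];
        let input = Tensor::from_slice(&values, &[1, 1, 4, 4]).unwrap();
        let output = pool.forward(&input);

        assert_eq!(output.dims(), &[1, 1, 2, 2]);
        // Window maxima: [[6, 8], [14, 16]].
        let out = output.data_f32();
        assert_eq!(&out[..], &[6.0, 8.0, 14.0, 16.0]);
    }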

    #[test]
    fn test_avg_pool2d() {
        let pool = AvgPool2d::new(2);
        let input = Tensor::randn(&[1, 3, 8, 8]);
        let output = pool.forward(&input);

        assert_eq!(output.dims(), &[1, 3, 4, 4]);
    }
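
    // Hedged value checks (not in the original tests), under the same assumptions
    // about Tensor::from_slice and data_f32 as test_max_pool2d_values.
    #[test]
    fn test_avg_pool2d_values() {
        let pool = AvgPool2d::new(2);
        let values: Vec<f32> = vec![
            1.0, 2.0, 3.0, 4.0,
            5.0, 6.0, 7.0, 8.0,
            9.0, 10.0, 11.0, 12.0,
            13.0, 14.0, 15.0, 16.0,
        ];
        let input = Tensor::from_slice(&values, &[1, 1, 4, 4]).unwrap();
        let output = pool.forward(&input);

        assert_eq!(output.dims(), &[1, 1, 2, 2]);
        // Window means: [[3.5, 5.5], [11.5, 13.5]].
        let out = output.data_f32();
        assert_eq!(&out[..], &[3.5, 5.5, 11.5, 13.5]);
    }

    // With padding, out-of-bounds positions add zero to the sum but are still
    // counted in the divisor, because count_include_pad defaults to true.
    #[test]
    fn test_avg_pool2d_count_include_pad() {
        let pool = AvgPool2d::with_params(2, 2, 1);
        let values: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
        let input = Tensor::from_slice(&values, &[1, 1, 2, 2]).unwrap();
        let output = pool.forward(&input);

        assert_eq!(output.dims(), &[1, 1, 2, 2]);
        // Each 2x2 window holds one real value and three padded positions.
        let out = output.data_f32();
        assert_eq!(&out[..], &[0.25, 0.5, 0.75, 1.0]);
    }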

    #[test]
    fn test_global_avg_pool() {
        let pool = GlobalAvgPool2d::new();
        let input = Tensor::randn(&[2, 64, 7, 7]);
        let output = pool.forward(&input);

        assert_eq!(output.dims(), &[2, 64, 1, 1]);
    }
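
    // GlobalMaxPool2d had no coverage; a minimal shape check mirroring
    // test_global_avg_pool.
    #[test]
    fn test_global_max_pool() {
        let pool = GlobalMaxPool2d::new();
        let input = Tensor::randn(&[2, 64, 7, 7]);
        let output = pool.forward(&input);

        assert_eq!(output.dims(), &[2, 64, 1, 1]);
    }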

    #[test]
    fn test_adaptive_avg_pool() {
        let pool = AdaptiveAvgPool2d::new((1, 1));
        let input = Tensor::randn(&[2, 64, 7, 7]);
        let output = pool.forward(&input);

        assert_eq!(output.dims(), &[2, 64, 1, 1]);
    }
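
    // A hedged check of the non-divisible case (not in the original tests): with
    // the floor/ceil region bounds used in forward(), pooling a 1x5 row down to
    // 2 columns averages columns [0, 3) and [2, 5).
    #[test]
    fn test_adaptive_avg_pool_non_divisible() {
        let pool = AdaptiveAvgPool2d::new((1, 2));
        let values: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let input = Tensor::from_slice(&values, &[1, 1, 1, 5]).unwrap();
        let output = pool.forward(&input);

        assert_eq!(output.dims(), &[1, 1, 1, 2]);
        // Means of [1, 2, 3] and [3, 4, 5].
        let out = output.data_f32();
        assert_eq!(&out[..], &[2.0, 4.0]);
    }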
}