ghostflow_nn/
norm.rs

//! Normalization layers

use ghostflow_core::Tensor;
use crate::module::Module;

/// Batch Normalization for 1D inputs (N, C) or (N, C, L).
///
/// Each channel is normalized as `y = gamma * (x - mean) / sqrt(var + eps) + beta`,
/// using batch statistics computed on the fly in training mode and the stored
/// running statistics in evaluation mode.
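///
/// # Example
///
/// A minimal usage sketch (marked `ignore` because the public re-export path
/// for `BatchNorm1d` is not shown in this file):
///
/// ```ignore
/// // Requires the `Module` trait to be in scope for `forward`.
/// let bn = BatchNorm1d::new(8);
/// let x = Tensor::randn(&[4, 8]);
/// let y = bn.forward(&x);
/// assert_eq!(y.dims(), x.dims());
/// ```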
pub struct BatchNorm1d {
    #[allow(dead_code)]
    num_features: usize,
    gamma: Tensor,  // scale
    beta: Tensor,   // shift
    running_mean: Tensor,
    running_var: Tensor,
    eps: f32,
    #[allow(dead_code)]
    momentum: f32,
    training: bool,
}

impl BatchNorm1d {
    pub fn new(num_features: usize) -> Self {
        Self::with_params(num_features, 1e-5, 0.1)
    }

    pub fn with_params(num_features: usize, eps: f32, momentum: f32) -> Self {
        BatchNorm1d {
            num_features,
            gamma: Tensor::ones(&[num_features]),
            beta: Tensor::zeros(&[num_features]),
            running_mean: Tensor::zeros(&[num_features]),
            running_var: Tensor::ones(&[num_features]),
            eps,
            momentum,
            training: true,
        }
    }
}

impl Module for BatchNorm1d {
    fn forward(&self, input: &Tensor) -> Tensor {
        let dims = input.dims();
        let data = input.data_f32();
        let gamma = self.gamma.data_f32();
        let beta = self.beta.data_f32();

        let batch_size = dims[0];
        let channels = dims[1];
        let spatial_size = if dims.len() > 2 { dims[2] } else { 1 };

        let (mean, var) = if self.training {
            // Compute per-channel batch statistics. Note that `forward` takes
            // `&self`, so the running statistics (and `momentum`) are not
            // updated here; they are only read back in evaluation mode.
            let mut mean = vec![0.0f32; channels];
            let mut var = vec![0.0f32; channels];
            let n = (batch_size * spatial_size) as f32;

            for c in 0..channels {
                let mut sum = 0.0f32;
                for b in 0..batch_size {
                    for s in 0..spatial_size {
                        let idx = b * channels * spatial_size + c * spatial_size + s;
                        sum += data[idx];
                    }
                }
                mean[c] = sum / n;

                let mut var_sum = 0.0f32;
                for b in 0..batch_size {
                    for s in 0..spatial_size {
                        let idx = b * channels * spatial_size + c * spatial_size + s;
                        var_sum += (data[idx] - mean[c]).powi(2);
                    }
                }
                var[c] = var_sum / n;
            }

            (mean, var)
        } else {
            (self.running_mean.data_f32(), self.running_var.data_f32())
        };

        // Normalize: y = gamma * (x - mean) / sqrt(var + eps) + beta
        let mut output = vec![0.0f32; data.len()];

        for b in 0..batch_size {
            for c in 0..channels {
                let std = (var[c] + self.eps).sqrt();
                for s in 0..spatial_size {
                    let idx = b * channels * spatial_size + c * spatial_size + s;
                    output[idx] = gamma[c] * (data[idx] - mean[c]) / std + beta[c];
                }
            }
        }

        Tensor::from_slice(&output, dims).unwrap()
    }

    fn parameters(&self) -> Vec<Tensor> {
        vec![self.gamma.clone(), self.beta.clone()]
    }

    fn train(&mut self) { self.training = true; }
    fn eval(&mut self) { self.training = false; }
    fn is_training(&self) -> bool { self.training }
}

/// Batch Normalization for 2D inputs (N, C, H, W)
///
/// Normalizes each channel over the batch and spatial dimensions with the same
/// `y = gamma * (x - mean) / sqrt(var + eps) + beta` transform as `BatchNorm1d`.
pub struct BatchNorm2d {
    #[allow(dead_code)]
    num_features: usize,
    gamma: Tensor,
    beta: Tensor,
    running_mean: Tensor,
    running_var: Tensor,
    eps: f32,
    #[allow(dead_code)]
    momentum: f32,
    training: bool,
}

impl BatchNorm2d {
    pub fn new(num_features: usize) -> Self {
        Self::with_params(num_features, 1e-5, 0.1)
    }

    pub fn with_params(num_features: usize, eps: f32, momentum: f32) -> Self {
        BatchNorm2d {
            num_features,
            gamma: Tensor::ones(&[num_features]),
            beta: Tensor::zeros(&[num_features]),
            running_mean: Tensor::zeros(&[num_features]),
            running_var: Tensor::ones(&[num_features]),
            eps,
            momentum,
            training: true,
        }
    }
}

impl Module for BatchNorm2d {
    fn forward(&self, input: &Tensor) -> Tensor {
        let dims = input.dims();
        let data = input.data_f32();
        let gamma = self.gamma.data_f32();
        let beta = self.beta.data_f32();

        let batch_size = dims[0];
        let channels = dims[1];
        let height = dims[2];
        let width = dims[3];
        let spatial_size = height * width;

        let (mean, var) = if self.training {
            // Per-channel batch statistics; as in `BatchNorm1d`, the running
            // statistics are not updated here because `forward` takes `&self`.
            let mut mean = vec![0.0f32; channels];
            let mut var = vec![0.0f32; channels];
            let n = (batch_size * spatial_size) as f32;

            for c in 0..channels {
                let mut sum = 0.0f32;
                for b in 0..batch_size {
                    for h in 0..height {
                        for w in 0..width {
                            let idx = b * channels * spatial_size + c * spatial_size + h * width + w;
                            sum += data[idx];
                        }
                    }
                }
                mean[c] = sum / n;

                let mut var_sum = 0.0f32;
                for b in 0..batch_size {
                    for h in 0..height {
                        for w in 0..width {
                            let idx = b * channels * spatial_size + c * spatial_size + h * width + w;
                            var_sum += (data[idx] - mean[c]).powi(2);
                        }
                    }
                }
                var[c] = var_sum / n;
            }

            (mean, var)
        } else {
            (self.running_mean.data_f32(), self.running_var.data_f32())
        };

        // Normalize each element with its channel's statistics.
        let mut output = vec![0.0f32; data.len()];

        for b in 0..batch_size {
            for c in 0..channels {
                let std = (var[c] + self.eps).sqrt();
                for h in 0..height {
                    for w in 0..width {
                        let idx = b * channels * spatial_size + c * spatial_size + h * width + w;
                        output[idx] = gamma[c] * (data[idx] - mean[c]) / std + beta[c];
                    }
                }
            }
        }

        Tensor::from_slice(&output, dims).unwrap()
    }

    fn parameters(&self) -> Vec<Tensor> {
        vec![self.gamma.clone(), self.beta.clone()]
    }

    fn train(&mut self) { self.training = true; }
    fn eval(&mut self) { self.training = false; }
    fn is_training(&self) -> bool { self.training }
}

/// Layer Normalization
///
/// Normalizes over the trailing `normalized_shape` elements of the input,
/// independently for every leading (batch) index.
pub struct LayerNorm {
    normalized_shape: Vec<usize>,
    gamma: Tensor,
    beta: Tensor,
    eps: f32,
    training: bool,
}

impl LayerNorm {
    pub fn new(normalized_shape: &[usize]) -> Self {
        Self::with_eps(normalized_shape, 1e-5)
    }

    pub fn with_eps(normalized_shape: &[usize], eps: f32) -> Self {
        let numel: usize = normalized_shape.iter().product();

        LayerNorm {
            normalized_shape: normalized_shape.to_vec(),
            gamma: Tensor::ones(&[numel]),
            beta: Tensor::zeros(&[numel]),
            eps,
            training: true,
        }
    }
}

impl Module for LayerNorm {
    fn forward(&self, input: &Tensor) -> Tensor {
        let dims = input.dims();
        let data = input.data_f32();
        let gamma = self.gamma.data_f32();
        let beta = self.beta.data_f32();

        let norm_size: usize = self.normalized_shape.iter().product();
        let batch_size = data.len() / norm_size;

        let mut output = vec![0.0f32; data.len()];

        for b in 0..batch_size {
            let start = b * norm_size;
            let end = start + norm_size;
            let slice = &data[start..end];

            // Compute mean
            let mean: f32 = slice.iter().sum::<f32>() / norm_size as f32;

            // Compute variance
            let var: f32 = slice.iter().map(|&x| (x - mean).powi(2)).sum::<f32>() / norm_size as f32;
            let std = (var + self.eps).sqrt();

            // Normalize
            for i in 0..norm_size {
                output[start + i] = gamma[i] * (slice[i] - mean) / std + beta[i];
            }
        }

        Tensor::from_slice(&output, dims).unwrap()
    }

    fn parameters(&self) -> Vec<Tensor> {
        vec![self.gamma.clone(), self.beta.clone()]
    }

    fn train(&mut self) { self.training = true; }
    fn eval(&mut self) { self.training = false; }
    fn is_training(&self) -> bool { self.training }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_batchnorm2d() {
        let bn = BatchNorm2d::new(16);
        let input = Tensor::randn(&[2, 16, 8, 8]);
        let output = bn.forward(&input);

        assert_eq!(output.dims(), input.dims());
    }
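
    // Additional test, in the same style as the existing ones: a sketch
    // checking that BatchNorm1d preserves the input shape for both the
    // (N, C) and (N, C, L) layouts it documents.
    #[test]
    fn test_batchnorm1d() {
        let bn = BatchNorm1d::new(8);

        let input_2d = Tensor::randn(&[4, 8]);
        assert_eq!(bn.forward(&input_2d).dims(), input_2d.dims());

        let input_3d = Tensor::randn(&[4, 8, 16]);
        assert_eq!(bn.forward(&input_3d).dims(), input_3d.dims());
    }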

    #[test]
    fn test_layernorm() {
        let ln = LayerNorm::new(&[64]);
        let input = Tensor::randn(&[2, 10, 64]);
        let output = ln.forward(&input);

        assert_eq!(output.dims(), input.dims());
    }
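
    // Sketch of a property check (assumes `data_f32()` exposes the output as an
    // indexable f32 buffer, as `forward` already relies on): with the default
    // gamma = 1 and beta = 0, every normalized row should have mean close to 0.
    #[test]
    fn test_layernorm_zero_mean() {
        let ln = LayerNorm::new(&[64]);
        let input = Tensor::randn(&[2, 10, 64]);
        let output = ln.forward(&input);
        let out = output.data_f32();

        for row in 0..(2 * 10) {
            let slice = &out[row * 64..(row + 1) * 64];
            let mean: f32 = slice.iter().sum::<f32>() / 64.0;
            assert!(mean.abs() < 1e-3);
        }
    }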
}