numrs/backend/cpu/batchnorm.rs

use crate::array::Array;
use anyhow::{Result, anyhow};

/// BatchNorm1D training implementation (CPU, naive).
///
/// Input: [Batch, Channels, Length]
/// Weight/Bias/RunningStats: [Channels]
/// Output: same shape as Input
pub fn batch_norm_1d_training(
    input: &Array,
    running_mean: &mut Array,
    running_var: &mut Array,
    weight: &Array,
    bias: &Array,
    momentum: f32,
    eps: f32,
) -> Result<Array> {
    if input.shape.len() != 3 {
        return Err(anyhow!("BatchNorm1D input must be 3D [Batch, Channels, Length]"));
    }

    let batch_size = input.shape[0];
    let channels = input.shape[1];
    let length = input.shape[2];

    // Check parameter shapes
    if weight.shape[0] != channels || bias.shape[0] != channels {
        return Err(anyhow!("Weight/Bias length must match the number of channels"));
    }
    // The unbiased variance update below divides by (B*L - 1), so training
    // needs at least two values per channel.
    if batch_size * length < 2 {
        return Err(anyhow!("BatchNorm1D training requires more than one value per channel"));
    }

    // 1. Calculate batch statistics, one value per channel.
    // Mean(c) = sum(x[b,c,l]) / (B*L)
    let mut batch_mean = vec![0.0; channels];
    let mut batch_var = vec![0.0; channels];
    let num_elements = (batch_size * length) as f32;

    let input_data = &input.data;

    // Pass 1: Mean
    for c in 0..channels {
        let mut sum = 0.0;
        for b in 0..batch_size {
            for l in 0..length {
                let idx = b * (channels * length) + c * length + l;
                sum += input_data[idx];
            }
        }
        batch_mean[c] = sum / num_elements;
    }

    // Pass 2: Variance (biased, divided by B*L)
    for c in 0..channels {
        let mut sum_sq_diff = 0.0;
        let mean = batch_mean[c];
        for b in 0..batch_size {
            for l in 0..length {
                let idx = b * (channels * length) + c * length + l;
                let diff = input_data[idx] - mean;
                sum_sq_diff += diff * diff;
            }
        }
        batch_var[c] = sum_sq_diff / num_elements;
    }

    // 2. Update running statistics.
    // running = (1 - momentum) * running + momentum * batch
    // Following the PyTorch convention, `momentum` is the weight of the new
    // batch statistic (default 0.1 there). The running variance uses the
    // unbiased (Bessel-corrected) estimate, also matching PyTorch, while the
    // normalization below keeps the biased batch variance.
    for c in 0..channels {
        running_mean.data[c] = (1.0 - momentum) * running_mean.data[c] + momentum * batch_mean[c];
        let unbiased_var = batch_var[c] * num_elements / (num_elements - 1.0);
        running_var.data[c] = (1.0 - momentum) * running_var.data[c] + momentum * unbiased_var;
    }
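
    // Worked example of the update above (momentum = 0.1): starting from
    // running_mean = 0.0 with batch_mean = 2.0, the new value is
    // 0.9 * 0.0 + 0.1 * 2.0 = 0.2, so the running estimate moves 10% of the
    // way toward the batch statistic on every training step.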

    // 3. Normalize, then scale and shift:
    // y = (x - mean) / sqrt(var + eps) * weight + bias
    let mut output_data = vec![0.0; input.data.len()];

    let weight_data = &weight.data;
    let bias_data = &bias.data;

    for c in 0..channels {
        let mean = batch_mean[c];
        let var = batch_var[c];
        let inv_std = 1.0 / (var + eps).sqrt();
        let w = weight_data[c];
        let b = bias_data[c];

        for i in 0..batch_size {
            for l in 0..length {
                let idx = i * (channels * length) + c * length + l;
                let val = input_data[idx];
                let normalized = (val - mean) * inv_std;
                output_data[idx] = normalized * w + b;
            }
        }
    }

    Ok(Array::new(input.shape.clone(), output_data))
}
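
// A minimal usage sketch (commented out; not part of the module API). It
// assumes `Array::new(shape, data)` takes a `Vec<usize>` and a `Vec<f32>`,
// which is how the constructor is used throughout this file:
//
//     let input = Array::new(vec![2, 3, 4], vec![0.0; 24]);
//     let mut running_mean = Array::new(vec![3], vec![0.0; 3]);
//     let mut running_var = Array::new(vec![3], vec![1.0; 3]);
//     let weight = Array::new(vec![3], vec![1.0; 3]);
//     let bias = Array::new(vec![3], vec![0.0; 3]);
//     let out = batch_norm_1d_training(
//         &input, &mut running_mean, &mut running_var,
//         &weight, &bias, 0.1, 1e-5,
//     )?;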

/// BatchNorm1D inference implementation (CPU, naive).
/// Uses the accumulated running statistics instead of batch statistics.
pub fn batch_norm_1d_inference(
    input: &Array,
    running_mean: &Array,
    running_var: &Array,
    weight: &Array,
    bias: &Array,
    eps: f32,
) -> Result<Array> {
    if input.shape.len() != 3 {
        return Err(anyhow!("BatchNorm1D input must be 3D [Batch, Channels, Length]"));
    }

    let batch_size = input.shape[0];
    let channels = input.shape[1];
    let length = input.shape[2];

    if weight.shape[0] != channels
        || bias.shape[0] != channels
        || running_mean.shape[0] != channels
        || running_var.shape[0] != channels
    {
        return Err(anyhow!("Weight/Bias/RunningStats length must match the number of channels"));
    }

    let mut output_data = vec![0.0; input.data.len()];

    let input_data = &input.data;
    let mean_data = &running_mean.data;
    let var_data = &running_var.data;
    let weight_data = &weight.data;
    let bias_data = &bias.data;

    for c in 0..channels {
        let mean = mean_data[c];
        let var = var_data[c];
        let inv_std = 1.0 / (var + eps).sqrt();
        let w = weight_data[c];
        let b = bias_data[c];

        for i in 0..batch_size {
            for l in 0..length {
                let idx = i * (channels * length) + c * length + l;
                let val = input_data[idx];
                let normalized = (val - mean) * inv_std;
                output_data[idx] = normalized * w + b;
            }
        }
    }

    Ok(Array::new(input.shape.clone(), output_data))
}
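
// Inference takes the same parameters as training but reads the frozen
// running statistics, e.g. (sketch, same assumptions as the example above):
//
//     let out = batch_norm_1d_inference(
//         &input, &running_mean, &running_var, &weight, &bias, 1e-5,
//     )?;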

/// BatchNorm1D backward pass, naive (training mode).
/// Returns (grad_input, grad_weight, grad_bias).
pub fn batchnorm_backward_naive(
    grad_output: &Array,
    input: &Array,
    weight: &Array,
    bias: &Array,
    eps: f32,
) -> Result<(Array, Array, Array)> {
    if input.shape.len() != 3 {
        return Err(anyhow!("BatchNorm1D input must be 3D [Batch, Channels, Length]"));
    }

    let batch_size = input.shape[0];
    let channels = input.shape[1];
    let length = input.shape[2];
    let num_elements = (batch_size * length) as f32;

    // Gradient buffers
    let mut grad_input_data = vec![0.0; input.data.len()];
    let mut grad_weight_data = vec![0.0; channels];
    let mut grad_bias_data = vec![0.0; channels];

    let grad_out_data = &grad_output.data;
    let input_data = &input.data;
    let weight_data = &weight.data; // gamma

    // Process each channel independently.
    for c in 0..channels {
        // 1. Recompute the batch mean and variance of the input.
        let mut sum = 0.0;
        for b in 0..batch_size {
            for l in 0..length {
                let idx = b * (channels * length) + c * length + l;
                sum += input_data[idx];
            }
        }
        let mean = sum / num_elements;

        let mut sum_sq_diff = 0.0;
        for b in 0..batch_size {
            for l in 0..length {
                let idx = b * (channels * length) + c * length + l;
                let diff = input_data[idx] - mean;
                sum_sq_diff += diff * diff;
            }
        }
        let var = sum_sq_diff / num_elements;
        let std = (var + eps).sqrt();
        let inv_std = 1.0 / std;

        // 2. Accumulate the per-channel reductions the gradients need.
        let mut sum_grad_out = 0.0;
        let mut sum_grad_out_x_hat = 0.0;

        for b in 0..batch_size {
            for l in 0..length {
                let idx = b * (channels * length) + c * length + l;
                let val = input_data[idx];
                let go = grad_out_data[idx];
                let x_hat = (val - mean) * inv_std;

                sum_grad_out += go;
                sum_grad_out_x_hat += go * x_hat;
            }
        }

        // 3. Parameter gradients.
        // With y = gamma * x_hat + beta:
        //   dL/dbeta  = sum(grad_out)
        //   dL/dgamma = sum(grad_out * x_hat)
        grad_bias_data[c] = sum_grad_out;
        grad_weight_data[c] = sum_grad_out_x_hat;

        // 4. Input gradient.
        // Since dL/dx_hat = grad_out * gamma, expanding the chain rule through
        // the variance and mean terms:
        //   dvar  = sum(dx_hat * (x - mean)) * -0.5 * (var + eps)^-1.5
        //   dmean = sum(dx_hat) * -inv_std + dvar * sum(-2 * (x - mean)) / N
        //   dx    = dx_hat * inv_std + dvar * 2 * (x - mean) / N + dmean / N
        // and folding the terms gives the standard closed form:
        //   dx = (gamma / (N * std)) * (N * grad_out - sum_grad_out - x_hat * sum_grad_out_x_hat)
        let gamma = weight_data[c];
        let factor = gamma / (num_elements * std);

        for b in 0..batch_size {
            for l in 0..length {
                let idx = b * (channels * length) + c * length + l;
                let val = input_data[idx];
                let go = grad_out_data[idx];
                let x_hat = (val - mean) * inv_std;

                let num = num_elements * go - sum_grad_out - x_hat * sum_grad_out_x_hat;
                grad_input_data[idx] = factor * num;
            }
        }
    }

    Ok((
        Array::new(input.shape.clone(), grad_input_data),
        Array::new(weight.shape.clone(), grad_weight_data),
        Array::new(bias.shape.clone(), grad_bias_data),
    ))
}
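
#[cfg(test)]
mod tests {
    // Minimal sanity checks rather than a full test suite. They assume
    // `Array::new(Vec<usize>, Vec<f32>)` and public `shape`/`data` fields,
    // consistent with how `Array` is used throughout this file.
    use super::*;

    #[test]
    fn constant_input_normalizes_to_bias() {
        // With a constant signal per channel, x - mean == 0, so the
        // normalized value is 0 and the output collapses to the bias.
        let input = Array::new(vec![1, 2, 2], vec![1.0, 1.0, 2.0, 2.0]);
        let mut running_mean = Array::new(vec![2], vec![0.0, 0.0]);
        let mut running_var = Array::new(vec![2], vec![1.0, 1.0]);
        let weight = Array::new(vec![2], vec![1.0, 1.0]);
        let bias = Array::new(vec![2], vec![0.5, -0.5]);

        let out = batch_norm_1d_training(
            &input, &mut running_mean, &mut running_var,
            &weight, &bias, 0.1, 1e-5,
        )
        .unwrap();

        assert!((out.data[0] - 0.5).abs() < 1e-6);
        assert!((out.data[2] + 0.5).abs() < 1e-6);
        // The running mean moves 10% of the way toward the batch mean.
        assert!((running_mean.data[0] - 0.1).abs() < 1e-6);
    }

    #[test]
    fn grad_bias_is_per_channel_sum_of_grad_output() {
        // dL/dbeta reduces grad_output over batch and length, so for a single
        // channel it should equal the plain sum of the incoming gradient.
        let input = Array::new(vec![1, 1, 3], vec![1.0, 2.0, 3.0]);
        let weight = Array::new(vec![1], vec![1.0]);
        let bias = Array::new(vec![1], vec![0.0]);
        let grad_out = Array::new(vec![1, 1, 3], vec![0.1, 0.2, 0.3]);

        let (_grad_input, _grad_weight, grad_bias) =
            batchnorm_backward_naive(&grad_out, &input, &weight, &bias, 1e-5).unwrap();

        assert!((grad_bias.data[0] - 0.6).abs() < 1e-6);
    }
}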