numrs/backend/cpu/batchnorm.rs

use crate::array::Array;
use anyhow::{Result, anyhow};

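/// BatchNorm1d forward pass in training mode for a `[batch, channels, length]` input.
/// Computes per-channel batch statistics, updates `running_mean`/`running_var` in place
/// using `momentum` (the running variance uses the unbiased estimate), and returns
/// `(x - mean) / sqrt(var + eps) * weight + bias`.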
pub fn batch_norm_1d_training(
    input: &Array,
    running_mean: &mut Array,
    running_var: &mut Array,
    weight: &Array,
    bias: &Array,
    momentum: f32,
    eps: f32,
) -> Result<Array> {
    if input.shape.len() != 3 {
        return Err(anyhow!("BatchNorm1D input must be 3D [Batch, Channels, Length]"));
    }

    let batch_size = input.shape[0];
    let channels = input.shape[1];
    let length = input.shape[2];

    if weight.shape[0] != channels || bias.shape[0] != channels {
        return Err(anyhow!("Weight/Bias size must match the number of channels"));
    }

    // Per-channel statistics are computed over the batch and length dimensions.
    let mut batch_mean = vec![0.0; channels];
    let mut batch_var = vec![0.0; channels];
    let num_elements = (batch_size * length) as f32;

    let input_data = &input.data;

    // Batch mean per channel.
    for c in 0..channels {
        let mut sum = 0.0;
        for b in 0..batch_size {
            for l in 0..length {
                let idx = b * (channels * length) + c * length + l;
                sum += input_data[idx];
            }
        }
        batch_mean[c] = sum / num_elements;
    }

    // Biased batch variance per channel.
    for c in 0..channels {
        let mut sum_sq_diff = 0.0;
        let mean = batch_mean[c];
        for b in 0..batch_size {
            for l in 0..length {
                let idx = b * (channels * length) + c * length + l;
                let diff = input_data[idx] - mean;
                sum_sq_diff += diff * diff;
            }
        }
        batch_var[c] = sum_sq_diff / num_elements;
    }

    // Update running statistics; the running variance uses the unbiased estimate.
    for c in 0..channels {
        running_mean.data[c] = (1.0 - momentum) * running_mean.data[c] + momentum * batch_mean[c];
        let unbiased_var = batch_var[c] * num_elements / (num_elements - 1.0);
        running_var.data[c] = (1.0 - momentum) * running_var.data[c] + momentum * unbiased_var;
    }

    // Normalize with the batch statistics, then apply the per-channel affine transform.
    let mut output_data = vec![0.0; input.data.len()];

    let weight_data = &weight.data;
    let bias_data = &bias.data;

    for c in 0..channels {
        let mean = batch_mean[c];
        let var = batch_var[c];
        let inv_std = 1.0 / (var + eps).sqrt();
        let w = weight_data[c];
        let b = bias_data[c];

        for i in 0..batch_size {
            for l in 0..length {
                let idx = i * (channels * length) + c * length + l;
                let val = input_data[idx];
                let normalized = (val - mean) * inv_std;
                output_data[idx] = normalized * w + b;
            }
        }
    }

    Ok(Array::new(input.shape.clone(), output_data))
}

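/// BatchNorm1d forward pass in inference mode: normalizes with the stored running
/// mean and variance instead of batch statistics, then applies the affine transform.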
pub fn batch_norm_1d_inference(
    input: &Array,
    running_mean: &Array,
    running_var: &Array,
    weight: &Array,
    bias: &Array,
    eps: f32,
) -> Result<Array> {
    if input.shape.len() != 3 {
        return Err(anyhow!("BatchNorm1D input must be 3D"));
    }
    let channels = input.shape[1];

    let mut output_data = vec![0.0; input.data.len()];

    let input_data = &input.data;
    let mean_data = &running_mean.data;
    let var_data = &running_var.data;
    let weight_data = &weight.data;
    let bias_data = &bias.data;

    let batch_size = input.shape[0];
    let length = input.shape[2];

    // Normalize each channel with its running statistics and apply the affine transform.
    for c in 0..channels {
        let mean = mean_data[c];
        let var = var_data[c];
        let inv_std = 1.0 / (var + eps).sqrt();
        let w = weight_data[c];
        let b = bias_data[c];

        for i in 0..batch_size {
            for l in 0..length {
                let idx = i * (channels * length) + c * length + l;
                let val = input_data[idx];
                let normalized = (val - mean) * inv_std;
                output_data[idx] = normalized * w + b;
            }
        }
    }

    Ok(Array::new(input.shape.clone(), output_data))
}

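/// Naive BatchNorm1d backward pass. Recomputes the per-channel batch statistics from
/// `input` (rather than caching them from the forward pass) and returns
/// `(grad_input, grad_weight, grad_bias)`.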
pub fn batchnorm_backward_naive(
    grad_output: &Array,
    input: &Array,
    weight: &Array,
    bias: &Array,
    eps: f32,
) -> Result<(Array, Array, Array)> {
    if input.shape.len() != 3 {
        return Err(anyhow!("BatchNorm1D input must be 3D"));
    }

    let batch_size = input.shape[0];
    let channels = input.shape[1];
    let length = input.shape[2];
    let num_elements = (batch_size * length) as f32;

    let mut grad_input_data = vec![0.0; input.data.len()];
    let mut grad_weight_data = vec![0.0; channels];
    let mut grad_bias_data = vec![0.0; channels];

    let grad_out_data = &grad_output.data;
    let input_data = &input.data;
    let weight_data = &weight.data;

    // Recompute per-channel batch statistics, then accumulate the gradient reductions.
    for c in 0..channels {
        let mut sum = 0.0;
        for b in 0..batch_size {
            for l in 0..length {
                let idx = b * (channels * length) + c * length + l;
                sum += input_data[idx];
            }
        }
        let mean = sum / num_elements;

        let mut sum_sq_diff = 0.0;
        for b in 0..batch_size {
            for l in 0..length {
                let idx = b * (channels * length) + c * length + l;
                let diff = input_data[idx] - mean;
                sum_sq_diff += diff * diff;
            }
        }
        let var = sum_sq_diff / num_elements;
        let std = (var + eps).sqrt();
        let inv_std = 1.0 / std;

        let mut sum_grad_out = 0.0;
        let mut sum_grad_out_x_hat = 0.0;

        for b in 0..batch_size {
            for l in 0..length {
                let idx = b * (channels * length) + c * length + l;
                let val = input_data[idx];
                let go = grad_out_data[idx];
                let x_hat = (val - mean) * inv_std;

                sum_grad_out += go;
                sum_grad_out_x_hat += go * x_hat;
            }
        }

        // Parameter gradients reduce over the batch and length dimensions.
        grad_bias_data[c] = sum_grad_out;
        grad_weight_data[c] = sum_grad_out_x_hat;

        // Input gradient: dL/dx = gamma / (N * std) * (N * dL/dy - sum(dL/dy) - x_hat * sum(dL/dy * x_hat)).
        let gamma = weight_data[c];
        let factor = gamma / (num_elements * std);

        for b in 0..batch_size {
            for l in 0..length {
                let idx = b * (channels * length) + c * length + l;
                let val = input_data[idx];
                let go = grad_out_data[idx];
                let x_hat = (val - mean) * inv_std;

                let num = num_elements * go - sum_grad_out - x_hat * sum_grad_out_x_hat;
                grad_input_data[idx] = factor * num;
            }
        }
    }

    Ok((
        Array::new(input.shape.clone(), grad_input_data),
        Array::new(weight.shape.clone(), grad_weight_data),
        Array::new(bias.shape.clone(), grad_bias_data),
    ))
}
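
// A minimal sanity-check sketch, not part of the original file: it relies only on the
// `Array::new(shape, data)` constructor and the `shape`/`data` fields already used above.
// For a per-channel constant input the batch variance is zero, so the normalized value is
// zero and every output element should equal that channel's bias.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn constant_input_maps_to_bias() {
        // 1 batch, 2 channels, length 3; each channel holds a constant value.
        let input = Array::new(vec![1, 2, 3], vec![5.0, 5.0, 5.0, -2.0, -2.0, -2.0]);
        let mut running_mean = Array::new(vec![2], vec![0.0, 0.0]);
        let mut running_var = Array::new(vec![2], vec![1.0, 1.0]);
        let weight = Array::new(vec![2], vec![1.0, 1.0]);
        let bias = Array::new(vec![2], vec![0.5, -0.5]);

        let out = batch_norm_1d_training(
            &input, &mut running_mean, &mut running_var, &weight, &bias, 0.1, 1e-5,
        )
        .unwrap();

        // Channel 0 (first 3 elements) should equal 0.5; channel 1 should equal -0.5.
        for i in 0..out.data.len() {
            let expected = if i < 3 { 0.5 } else { -0.5 };
            assert!((out.data[i] - expected).abs() < 1e-4);
        }
    }
}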