// numrs/ops/batchnorm.rs
//! Batch Normalization Operations

use crate::array::Array;
use anyhow::Result;
6/// Backward compatibility alias
7pub use batch_norm as batch_norm_1d;
8
9/// Batch Normalization 1D
10///
11/// Applies Batch Normalization over a 3D input (Batch, Channels, Length).
12///
13/// # Arguments
14/// * `input` - Input tensor [B, C, L]
15/// * `running_mean` - Running mean stats [C] (In-place update during training)
16/// * `running_var` - Running variance stats [C] (In-place update during training)
17/// * `weight` - Learnable gamma [C]
18/// * `bias` - Learnable beta [C]
19/// * `training` - If true, uses batch stats and updates running stats. If false, uses running stats.
20/// * `momentum` - Momentum for running stats check (default use 0.1)
21/// * `eps` - Epsilon for stability
22pub fn batch_norm(
23 input: &Array,
24 running_mean: &mut Array,
25 running_var: &mut Array,
26 weight: &Array,
27 bias: &Array,
28 training: bool,
29 momentum: f32,
30 eps: f32,
31) -> Result<Array> {
32
33 // Dispatch
34 if training {
35 // TODO: Dispatch to GPU/SIMD if available
36 #[cfg(numrs_kernel_batchnorm_gpu)]
37 {
38 if crate::backend::webgpu::is_available_cached() {
39 return crate::backend::webgpu::batchnorm::batch_norm_1d_training_webgpu(
40 input, running_mean, running_var, weight, bias, momentum, eps
41 );
42 }
43 }
44
45 crate::backend::cpu::batchnorm::batch_norm_1d_training(
46 input, running_mean, running_var, weight, bias, momentum, eps
47 )
48 } else {
49 // GPU dispatch for inference
50 #[cfg(numrs_kernel_batchnorm_gpu)]
51 {
52 if crate::backend::webgpu::is_available_cached() {
53 return crate::backend::webgpu::batchnorm::batch_norm_1d_inference_webgpu(
54 input, running_mean, running_var, weight, bias, eps
55 );
56 }
57 }
58
59 crate::backend::cpu::batchnorm::batch_norm_1d_inference(
60 input, running_mean, running_var, weight, bias, eps
61 )
62 }
63}