numrs/ops/
batchnorm.rs

1//! Batch Normalization Operations
2
3use crate::array::Array;
4use anyhow::Result;
5
6/// Backward compatibility alias
7pub use batch_norm as batch_norm_1d;
8
9/// Batch Normalization 1D
10///
11/// Applies Batch Normalization over a 3D input (Batch, Channels, Length).
12/// 
13/// # Arguments
14/// * `input` - Input tensor [B, C, L]
15/// * `running_mean` - Running mean stats [C] (In-place update during training)
16/// * `running_var` - Running variance stats [C] (In-place update during training)
17/// * `weight` - Learnable gamma [C]
18/// * `bias` - Learnable beta [C]
19/// * `training` - If true, uses batch stats and updates running stats. If false, uses running stats.
20/// * `momentum` - Momentum for running stats check (default use 0.1)
21/// * `eps` - Epsilon for stability
22pub fn batch_norm(
23    input: &Array,
24    running_mean: &mut Array,
25    running_var: &mut Array,
26    weight: &Array,
27    bias: &Array,
28    training: bool,
29    momentum: f32,
30    eps: f32,
31) -> Result<Array> {
32    
33    // Dispatch
34    if training {
35        // TODO: Dispatch to GPU/SIMD if available
36        #[cfg(numrs_kernel_batchnorm_gpu)]
37        {
38             if crate::backend::webgpu::is_available_cached() {
39                 return crate::backend::webgpu::batchnorm::batch_norm_1d_training_webgpu(
40                     input, running_mean, running_var, weight, bias, momentum, eps
41                 );
42             }
43        }
44        
45        crate::backend::cpu::batchnorm::batch_norm_1d_training(
46            input, running_mean, running_var, weight, bias, momentum, eps
47        )
48    } else {
49        // GPU dispatch for inference
50        #[cfg(numrs_kernel_batchnorm_gpu)]
51        {
52             if crate::backend::webgpu::is_available_cached() {
53                 return crate::backend::webgpu::batchnorm::batch_norm_1d_inference_webgpu(
54                     input, running_mean, running_var, weight, bias, eps
55                 );
56             }
57        }
58        
59        crate::backend::cpu::batchnorm::batch_norm_1d_inference(
60            input, running_mean, running_var, weight, bias, eps
61        )
62    }
63}