use svod_tensor::Tensor;
use crate::state::StateDict;
use super::error::{Error, Result};
const BN_EPS: f32 = 1e-5;
pub fn fold_batchnorm(mut sd: StateDict) -> Result<StateDict> {
sd.retain(|k, _| !k.ends_with("num_batches_tracked"));
let var_keys: Vec<String> = sd.keys().filter(|k| k.ends_with("running_var")).cloned().collect();
for key in var_keys {
let var = sd.remove(&key).expect("key just enumerated");
let var_f32 = var.as_vec::<f32>().map_err(|e| Error::Tensor { source: Box::new(e) })?;
let invstd: Vec<f32> = var_f32.iter().map(|&v| 1.0 / (v + BN_EPS).sqrt()).collect();
sd.insert(key, Tensor::from_slice(&invstd));
}
Ok(sd)
}