use crate::error::VadError;
pub(crate) struct CmvnStats {
pub means: Vec<f32>,
pub inv_stds: Vec<f32>,
}
impl CmvnStats {
pub fn from_kaldi_binary(data: &[u8]) -> Result<Self, VadError> {
if data.len() < 15 {
return Err(VadError::BackendError("CMVN file too small".into()));
}
let b_pos = data
.iter()
.position(|&b| b == b'B')
.ok_or_else(|| VadError::BackendError("CMVN: no binary marker 'B' found".into()))?;
let mut pos = b_pos + 1;
if pos >= data.len() {
return Err(VadError::BackendError("CMVN: truncated after 'B'".into()));
}
let type_marker = data[pos];
pos += 1;
let is_double = match type_marker {
b'F' => false,
b'D' => true,
_ => {
return Err(VadError::BackendError(format!(
"CMVN: unexpected type marker '{}'",
type_marker as char
)))
}
};
if pos >= data.len() || data[pos] != b'M' {
return Err(VadError::BackendError("CMVN: expected 'M' marker".into()));
}
pos += 1;
if pos < data.len() && data[pos] == b' ' {
pos += 1;
}
if pos + 5 > data.len() {
return Err(VadError::BackendError("CMVN: truncated at rows".into()));
}
pos += 1; let rows = i32::from_le_bytes([data[pos], data[pos + 1], data[pos + 2], data[pos + 3]]);
pos += 4;
if rows != 2 {
return Err(VadError::BackendError(format!(
"CMVN: expected 2 rows, got {rows}"
)));
}
if pos + 5 > data.len() {
return Err(VadError::BackendError("CMVN: truncated at cols".into()));
}
pos += 1; let cols = i32::from_le_bytes([data[pos], data[pos + 1], data[pos + 2], data[pos + 3]]);
pos += 4;
if cols < 2 {
return Err(VadError::BackendError(format!(
"CMVN: expected cols >= 2, got {cols}"
)));
}
let dim = (cols - 1) as usize;
let elem_size = if is_double { 8 } else { 4 };
let total_elems = 2 * cols as usize;
let data_size = total_elems * elem_size;
if pos + data_size > data.len() {
return Err(VadError::BackendError(format!(
"CMVN: file too small for {total_elems} elements"
)));
}
let read_f64 = |offset: usize| -> f64 {
if is_double {
f64::from_le_bytes([
data[offset],
data[offset + 1],
data[offset + 2],
data[offset + 3],
data[offset + 4],
data[offset + 5],
data[offset + 6],
data[offset + 7],
])
} else {
f32::from_le_bytes([
data[offset],
data[offset + 1],
data[offset + 2],
data[offset + 3],
]) as f64
}
};
let row0_start = pos;
let row1_start = pos + cols as usize * elem_size;
let count = read_f64(row0_start + dim * elem_size);
if count < 1.0 {
return Err(VadError::BackendError(format!(
"CMVN: count must be >= 1, got {count}"
)));
}
let floor: f64 = 1e-20;
let mut means = Vec::with_capacity(dim);
let mut inv_stds = Vec::with_capacity(dim);
for d in 0..dim {
let sum = read_f64(row0_start + d * elem_size);
let sum_sq = read_f64(row1_start + d * elem_size);
let mean = sum / count;
let variance = (sum_sq / count) - mean * mean;
let variance = if variance < floor { floor } else { variance };
let istd = 1.0 / variance.sqrt();
means.push(mean as f32);
inv_stds.push(istd as f32);
}
Ok(Self { means, inv_stds })
}
#[inline]
pub fn normalize(&self, features: &mut [f32]) {
for (i, feat) in features.iter_mut().enumerate() {
*feat = (*feat - self.means[i]) * self.inv_stds[i];
}
}
}
#[cfg(test)]
mod tests {
use super::*;
const CMVN_BYTES: &[u8] = include_bytes!(concat!(env!("OUT_DIR"), "/firered_cmvn.ark"));
#[test]
fn parse_cmvn_dimensions() {
let stats = CmvnStats::from_kaldi_binary(CMVN_BYTES).unwrap();
assert_eq!(stats.means.len(), 80);
assert_eq!(stats.inv_stds.len(), 80);
}
#[test]
fn parse_cmvn_values_match_python() {
let stats = CmvnStats::from_kaldi_binary(CMVN_BYTES).unwrap();
let ref_means: [f64; 5] = [
10.42295174919564,
10.862097411631494,
11.764544378124809,
12.490164701573908,
13.25983008289003,
];
let ref_inv_stds: [f64; 5] = [
0.2494980879825924,
0.23563235243542163,
0.23145152525802104,
0.2332233926481505,
0.23182660283718737,
];
for i in 0..5 {
let mean_diff = (stats.means[i] as f64 - ref_means[i]).abs();
assert!(
mean_diff < 1e-4,
"mean[{i}]: rust={}, python={}, diff={mean_diff}",
stats.means[i],
ref_means[i]
);
let istd_diff = (stats.inv_stds[i] as f64 - ref_inv_stds[i]).abs();
assert!(
istd_diff < 1e-4,
"inv_std[{i}]: rust={}, python={}, diff={istd_diff}",
stats.inv_stds[i],
ref_inv_stds[i]
);
}
}
#[test]
fn normalize_applies_correctly() {
let stats = CmvnStats {
means: vec![1.0, 2.0],
inv_stds: vec![0.5, 0.25],
};
let mut features = vec![3.0, 6.0];
stats.normalize(&mut features);
assert!((features[0] - 1.0).abs() < 1e-6);
assert!((features[1] - 1.0).abs() < 1e-6);
}
#[test]
fn parse_invalid_data() {
assert!(CmvnStats::from_kaldi_binary(b"").is_err());
assert!(CmvnStats::from_kaldi_binary(b"too short").is_err());
}
}