/// Two cascaded second-order (biquad) sections forming a K-weighting filter.
///
/// Each sample passes through the `*_pre` section first and the result is fed
/// to the `*_hp` section (see `process_sample`). Coefficients are fixed at
/// construction time for a given sample rate; the `z*` fields hold the running
/// state of each section and are the only fields mutated during processing.
#[derive(Clone, Debug)]
pub struct KWeightedFilter {
/// First section (`pre`): feed-forward coefficient applied to the current input.
b0_pre: f32,
/// First section: feed-forward coefficient folded into the first state word.
b1_pre: f32,
/// First section: feed-forward coefficient folded into the second state word.
b2_pre: f32,
/// First section: feedback coefficient folded into the first state word.
a1_pre: f32,
/// First section: feedback coefficient folded into the second state word.
a2_pre: f32,
/// Second section (`hp`): feed-forward coefficient applied to the first section's output.
b0_hp: f32,
/// Second section: feed-forward coefficient folded into the first state word.
b1_hp: f32,
/// Second section: feed-forward coefficient folded into the second state word.
b2_hp: f32,
/// Second section: feedback coefficient folded into the first state word.
a1_hp: f32,
/// Second section: feedback coefficient folded into the second state word.
a2_hp: f32,
/// First section: first state word (added to `b0_pre * x` to form the output).
z1_pre: f32,
/// First section: second state word (carried into `z1_pre` on the next sample).
z2_pre: f32,
/// Second section: first state word.
z1_hp: f32,
/// Second section: second state word.
z2_hp: f32,
}
impl KWeightedFilter {
    /// Creates a filter whose coefficients are computed for `sample_rate`
    /// (in Hz) and whose internal state starts at zero.
    ///
    /// NOTE(review): a `sample_rate` of 0 is not rejected here; the coefficient
    /// math divides by the sample rate, so such a filter would not produce
    /// meaningful output.
    pub fn new(sample_rate: u32) -> Self {
        let fs = f64::from(sample_rate);
        let (pre_b0, pre_b1, pre_b2, pre_a1, pre_a2, hp_b0, hp_b1, hp_b2, hp_a1, hp_a2) =
            compute_k_weight_coeffs(fs);
        Self {
            b0_pre: pre_b0,
            b1_pre: pre_b1,
            b2_pre: pre_b2,
            a1_pre: pre_a1,
            a2_pre: pre_a2,
            b0_hp: hp_b0,
            b1_hp: hp_b1,
            b2_hp: hp_b2,
            a1_hp: hp_a1,
            a2_hp: hp_a2,
            z1_pre: 0.0,
            z2_pre: 0.0,
            z1_hp: 0.0,
            z2_hp: 0.0,
        }
    }

    /// Filters a single sample and returns the weighted result.
    ///
    /// Both sections are evaluated in transposed direct-form II: the output is
    /// formed from the input and the first state word, then the two state
    /// words are refreshed from the input and that output.
    #[inline]
    pub fn process_sample(&mut self, x: f32) -> f32 {
        // First section (`pre`). The new z1 must be computed from the *old* z2,
        // so both next-state values are derived before anything is stored.
        let stage1 = self.b0_pre * x + self.z1_pre;
        let next_z1_pre = self.b1_pre * x - self.a1_pre * stage1 + self.z2_pre;
        let next_z2_pre = self.b2_pre * x - self.a2_pre * stage1;

        // Second section (`hp`), driven by the first section's output.
        let stage2 = self.b0_hp * stage1 + self.z1_hp;
        let next_z1_hp = self.b1_hp * stage1 - self.a1_hp * stage2 + self.z2_hp;
        let next_z2_hp = self.b2_hp * stage1 - self.a2_hp * stage2;

        self.z1_pre = next_z1_pre;
        self.z2_pre = next_z2_pre;
        self.z1_hp = next_z1_hp;
        self.z2_hp = next_z2_hp;
        stage2
    }

    /// Filters `input` into `output`, one sample at a time, carrying the
    /// filter state across the whole block (and on into later calls).
    ///
    /// # Panics
    ///
    /// Panics if `input` and `output` have different lengths.
    pub fn process_block(&mut self, input: &[f32], output: &mut [f32]) {
        assert_eq!(
            input.len(),
            output.len(),
            "KWeightedFilter::process_block: input and output length mismatch"
        );
        for (dst, &src) in output.iter_mut().zip(input) {
            *dst = self.process_sample(src);
        }
    }

    /// Zeroes the state of both sections; the coefficients are left untouched.
    pub fn reset(&mut self) {
        for state in [
            &mut self.z1_pre,
            &mut self.z2_pre,
            &mut self.z1_hp,
            &mut self.z2_hp,
        ] {
            *state = 0.0;
        }
    }
}
/// Computes the coefficients of both filter sections for sample rate `fs` (Hz).
///
/// The result is ordered
/// `(b0_pre, b1_pre, b2_pre, a1_pre, a2_pre, b0_hp, b1_hp, b2_hp, a1_hp, a2_hp)`.
/// Every coefficient is already divided by its section's leading denominator
/// term, so the implicit `a0` is 1. All arithmetic is done in `f64`; values are
/// narrowed to `f32` only when returned.
fn compute_k_weight_coeffs(fs: f64) -> (f32, f32, f32, f32, f32, f32, f32, f32, f32, f32) {
    use std::f64::consts::PI;

    // Design constants of the shelving (`pre`) section.
    const SHELF_GAIN_DB: f64 = 3.999_843_853_973_347;
    const SHELF_FREQ_HZ: f64 = 1_681.974_450_955_533;
    const SHELF_Q: f64 = 0.707_213_195_806_047_6;
    // Design constants of the high-pass (`hp`) section.
    const HP_FREQ_HZ: f64 = 38.134_566_580_756_27;
    const HP_Q: f64 = 0.500_316_983_843_589_1;

    // --- shelving section -------------------------------------------------
    let gain_hi = 10_f64.powf(SHELF_GAIN_DB / 20.0);
    let gain_mid = gain_hi.powf(0.5);
    let ks = (PI * SHELF_FREQ_HZ / fs).tan();
    let ks_sq = ks * ks;
    let shelf_norm = 1.0 + ks / SHELF_Q + ks_sq;

    let shelf_b0 = (gain_hi + gain_mid * ks / SHELF_Q + ks_sq) / shelf_norm;
    let shelf_b1 = 2.0 * (ks_sq - gain_hi) / shelf_norm;
    let shelf_b2 = (gain_hi - gain_mid * ks / SHELF_Q + ks_sq) / shelf_norm;
    let shelf_a1 = 2.0 * (ks_sq - 1.0) / shelf_norm;
    let shelf_a2 = (1.0 - ks / SHELF_Q + ks_sq) / shelf_norm;

    // --- high-pass section --------------------------------------------------
    let kh = (PI * HP_FREQ_HZ / fs).tan();
    let kh_sq = kh * kh;
    let hp_norm = 1.0 + kh / HP_Q + kh_sq;

    // Numerator is (1, -2, 1) / norm, so b0 and b2 coincide.
    let hp_b0 = 1.0 / hp_norm;
    let hp_b1 = -2.0 / hp_norm;
    let hp_b2 = hp_b0;
    let hp_a1 = 2.0 * (kh_sq - 1.0) / hp_norm;
    let hp_a2 = (1.0 - kh / HP_Q + kh_sq) / hp_norm;

    (
        shelf_b0 as f32,
        shelf_b1 as f32,
        shelf_b2 as f32,
        shelf_a1 as f32,
        shelf_a2 as f32,
        hp_b0 as f32,
        hp_b1 as f32,
        hp_b2 as f32,
        hp_a1 as f32,
        hp_a2 as f32,
    )
}
pub fn k_weight_4ch_simd(
channels: [&[f32]; 4],
filters: &mut [KWeightedFilter; 4],
) -> [Vec<f32>; 4] {
let len = channels[0].len();
assert_eq!(channels[1].len(), len, "channel 1 length mismatch");
assert_eq!(channels[2].len(), len, "channel 2 length mismatch");
assert_eq!(channels[3].len(), len, "channel 3 length mismatch");
let mut out0 = vec![0.0f32; len];
let mut out1 = vec![0.0f32; len];
let mut out2 = vec![0.0f32; len];
let mut out3 = vec![0.0f32; len];
let b0p = filters[0].b0_pre;
let b1p = filters[0].b1_pre;
let b2p = filters[0].b2_pre;
let a1p = filters[0].a1_pre;
let a2p = filters[0].a2_pre;
let b0h = filters[0].b0_hp;
let b1h = filters[0].b1_hp;
let b2h = filters[0].b2_hp;
let a1h = filters[0].a1_hp;
let a2h = filters[0].a2_hp;
let mut z1p = [
filters[0].z1_pre,
filters[1].z1_pre,
filters[2].z1_pre,
filters[3].z1_pre,
];
let mut z2p = [
filters[0].z2_pre,
filters[1].z2_pre,
filters[2].z2_pre,
filters[3].z2_pre,
];
let mut z1h = [
filters[0].z1_hp,
filters[1].z1_hp,
filters[2].z1_hp,
filters[3].z1_hp,
];
let mut z2h = [
filters[0].z2_hp,
filters[1].z2_hp,
filters[2].z2_hp,
filters[3].z2_hp,
];
for i in 0..len {
let x = [
channels[0][i],
channels[1][i],
channels[2][i],
channels[3][i],
];
let y1_0 = b0p * x[0] + z1p[0];
let y1_1 = b0p * x[1] + z1p[1];
let y1_2 = b0p * x[2] + z1p[2];
let y1_3 = b0p * x[3] + z1p[3];
z1p[0] = b1p * x[0] - a1p * y1_0 + z2p[0];
z1p[1] = b1p * x[1] - a1p * y1_1 + z2p[1];
z1p[2] = b1p * x[2] - a1p * y1_2 + z2p[2];
z1p[3] = b1p * x[3] - a1p * y1_3 + z2p[3];
z2p[0] = b2p * x[0] - a2p * y1_0;
z2p[1] = b2p * x[1] - a2p * y1_1;
z2p[2] = b2p * x[2] - a2p * y1_2;
z2p[3] = b2p * x[3] - a2p * y1_3;
let y2_0 = b0h * y1_0 + z1h[0];
let y2_1 = b0h * y1_1 + z1h[1];
let y2_2 = b0h * y1_2 + z1h[2];
let y2_3 = b0h * y1_3 + z1h[3];
z1h[0] = b1h * y1_0 - a1h * y2_0 + z2h[0];
z1h[1] = b1h * y1_1 - a1h * y2_1 + z2h[1];
z1h[2] = b1h * y1_2 - a1h * y2_2 + z2h[2];
z1h[3] = b1h * y1_3 - a1h * y2_3 + z2h[3];
z2h[0] = b2h * y1_0 - a2h * y2_0;
z2h[1] = b2h * y1_1 - a2h * y2_1;
z2h[2] = b2h * y1_2 - a2h * y2_2;
z2h[3] = b2h * y1_3 - a2h * y2_3;
out0[i] = y2_0;
out1[i] = y2_1;
out2[i] = y2_2;
out3[i] = y2_3;
}
for (ch, filt) in filters.iter_mut().enumerate() {
filt.z1_pre = z1p[ch];
filt.z2_pre = z2p[ch];
filt.z1_hp = z1h[ch];
filt.z2_hp = z2h[ch];
}
[out0, out1, out2, out3]
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::f32::consts::PI;

    /// Fresh filter configured for 48 kHz.
    fn make_filter_48k() -> KWeightedFilter {
        KWeightedFilter::new(48_000)
    }

    /// A constant input must settle close to zero at the output.
    #[test]
    fn test_k_weighted_filter_dc() {
        let mut filter = make_filter_48k();
        // Only the most recent output is kept; the 0.0 seed is what would be
        // reported if no sample were processed at all.
        let last = (0..20_000).fold(0.0_f32, |_, _| filter.process_sample(1.0));
        assert!(
            last.abs() < 0.01,
            "DC input not attenuated: settled output = {last}"
        );
    }

    /// A 1 kHz sine must come out with an output/input energy ratio inside a
    /// broad sanity window, measured over the second half of the signal.
    #[test]
    fn test_k_weighted_filter_roundtrip() {
        let sr = 48_000u32;
        let freq = 1_000.0_f32;
        let n = 4 * sr as usize;
        let warmup = n / 2;

        let mut filter = KWeightedFilter::new(sr);
        let (mut energy_in, mut energy_out) = (0.0_f64, 0.0_f64);
        for i in 0..n {
            let x = (2.0 * PI * freq * i as f32 / sr as f32).sin();
            let y = filter.process_sample(x);
            // The filter still has to see the warm-up samples; they are just
            // not counted towards the energies.
            if i < warmup {
                continue;
            }
            energy_in += f64::from(x * x);
            energy_out += f64::from(y * y);
        }

        let ratio = energy_out / energy_in;
        assert!(
            ratio > 0.5 && ratio < 8.0,
            "1 kHz energy ratio {ratio:.3} outside [0.5, 8.0] — filter may be broken"
        );
    }

    /// The four-channel routine must agree with running each channel through
    /// its own scalar filter.
    #[test]
    fn test_k_weighted_4ch_matches_scalar() {
        let sr = 48_000u32;
        let n = 2048;
        let freq = 1_000.0_f32;

        // Same tone on every channel, offset by a different phase.
        let phases = [0.0_f32, 0.5, 1.0, 1.5];
        let inputs: Vec<Vec<f32>> = phases
            .iter()
            .map(|&phase| {
                (0..n)
                    .map(|i| (2.0 * PI * freq * i as f32 / sr as f32 + phase).sin())
                    .collect()
            })
            .collect();

        // Reference: an independent scalar filter per channel.
        let scalar_out: Vec<Vec<f32>> = inputs
            .iter()
            .map(|input| {
                let mut reference = KWeightedFilter::new(sr);
                let mut buffer = vec![0.0_f32; n];
                reference.process_block(input, &mut buffer);
                buffer
            })
            .collect();

        let mut simd_filters: [KWeightedFilter; 4] = [
            KWeightedFilter::new(sr),
            KWeightedFilter::new(sr),
            KWeightedFilter::new(sr),
            KWeightedFilter::new(sr),
        ];
        let simd_out = k_weight_4ch_simd(
            [&inputs[0], &inputs[1], &inputs[2], &inputs[3]],
            &mut simd_filters,
        );

        for (ch, (simd_ch, scalar_ch)) in simd_out.iter().zip(scalar_out.iter()).enumerate() {
            for (i, (&s, &r)) in simd_ch.iter().zip(scalar_ch.iter()).enumerate() {
                assert!(
                    (s - r).abs() < 1e-5,
                    "ch={ch} i={i}: simd={s} scalar={r} diff={}",
                    (s - r).abs()
                );
            }
        }
    }

    /// Block processing must match feeding the same samples one by one.
    #[test]
    fn test_k_weighted_block_vs_sample() {
        let sr = 44_100u32;
        let n = 512;
        let input: Vec<f32> = (0..n).map(|i| (i as f32 * 0.01).sin()).collect();

        let out_block = {
            let mut block_filter = KWeightedFilter::new(sr);
            let mut buffer = vec![0.0f32; n];
            block_filter.process_block(&input, &mut buffer);
            buffer
        };

        let mut sample_filter = KWeightedFilter::new(sr);
        let mut out_sample = Vec::with_capacity(n);
        for &x in &input {
            out_sample.push(sample_filter.process_sample(x));
        }

        for (i, (&b, &s)) in out_block.iter().zip(out_sample.iter()).enumerate() {
            assert!(
                (b - s).abs() < 1e-7,
                "index {i}: block={b} sample={s} diff={}",
                (b - s).abs()
            );
        }
    }
}