use crate::celt_header::CeltPostFilter;
use crate::celt_mdct_window::{self, MdctWindowError};
/// Shortest pitch period, in samples, that the post-filter accepts.
pub const POST_FILTER_MIN_PERIOD: u16 = 15;
/// Longest pitch period, in samples, that the post-filter accepts.
pub const POST_FILTER_MAX_PERIOD: u16 = 1022;
/// Number of selectable tap-weight sets (rows of [`POST_FILTER_TAPS`]).
pub const POST_FILTER_TAPSET_COUNT: usize = 3;
/// Numerator of the gain step: `G = NUMERATOR * (gain_index + 1) / DENOMINATOR`.
pub const POST_FILTER_GAIN_NUMERATOR: u32 = 3;
/// Denominator of the gain step: `G = NUMERATOR * (gain_index + 1) / DENOMINATOR`.
pub const POST_FILTER_GAIN_DENOMINATOR: u32 = 32;
/// Largest valid gain index; indices run `0..=POST_FILTER_GAIN_INDEX_MAX`.
pub const POST_FILTER_GAIN_INDEX_MAX: u8 = 7;
/// Tap weights `[g0, g1, g2]`, one row per tapset.
///
/// `g0` weights the centre tap `y(n - T)`, `g1` each tap one sample either
/// side of it, and `g2` each tap two samples either side. Tapsets 1 and 2 have
/// no outer pair (`g2 == 0`).
pub const POST_FILTER_TAPS: [[f64; 3]; POST_FILTER_TAPSET_COUNT] = [
    // tapset 0: five taps
    [0.306_640_625_0, 0.217_041_015_6, 0.129_638_671_9],
    // tapset 1: three taps
    [0.463_867_187_5, 0.268_066_406_2, 0.0],
    // tapset 2: three taps, strongly centre-weighted
    [0.799_804_687_5, 0.100_097_656_2, 0.0],
];
/// Errors raised while configuring or running the CELT post-filter.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PostFilterError {
    /// The tapset index has no row in `POST_FILTER_TAPS`.
    TapsetOutOfRange { tapset: u8 },
    /// The pitch period lies outside `POST_FILTER_MIN_PERIOD..=POST_FILTER_MAX_PERIOD`.
    PeriodOutOfRange { period: u16 },
    /// The gain index exceeds `POST_FILTER_GAIN_INDEX_MAX`.
    GainIndexOutOfRange { gain_index: u8 },
    /// The output slice cannot hold the `input_len` samples that would be written.
    OutputBufferTooSmall { input_len: usize, output_len: usize },
    /// A transition-region slice is not exactly `expected` samples long.
    TransitionLengthMismatch { expected: usize, provided: usize },
    /// A failure propagated from the MDCT window helper.
    Window(MdctWindowError),
}
impl core::fmt::Display for PostFilterError {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
match self {
PostFilterError::TapsetOutOfRange { tapset } => {
write!(f, "post-filter tapset out of range: {tapset} (max 2)")
}
PostFilterError::PeriodOutOfRange { period } => write!(
f,
"post-filter period out of range: {period} (must be 15..=1022)"
),
PostFilterError::GainIndexOutOfRange { gain_index } => {
write!(
f,
"post-filter gain index out of range: {gain_index} (max 7)"
)
}
PostFilterError::OutputBufferTooSmall {
input_len,
output_len,
} => write!(
f,
"post-filter output buffer too small: need {input_len} samples, got {output_len}"
),
PostFilterError::TransitionLengthMismatch { expected, provided } => write!(
f,
"post-filter transition length mismatch: expected {expected}, got {provided}"
),
PostFilterError::Window(e) => write!(f, "post-filter window error: {e}"),
}
}
}
impl std::error::Error for PostFilterError {}
impl From<MdctWindowError> for PostFilterError {
fn from(e: MdctWindowError) -> Self {
PostFilterError::Window(e)
}
}
/// Returns the `[g0, g1, g2]` tap weights for `tapset`.
///
/// # Errors
///
/// [`PostFilterError::TapsetOutOfRange`] if `tapset` does not index a row of
/// `POST_FILTER_TAPS` (i.e. `tapset >= POST_FILTER_TAPSET_COUNT`).
pub fn tapset_coefficients(tapset: u8) -> Result<[f64; 3], PostFilterError> {
    match POST_FILTER_TAPS.get(usize::from(tapset)) {
        Some(&weights) => Ok(weights),
        None => Err(PostFilterError::TapsetOutOfRange { tapset }),
    }
}
/// Converts a gain index into the linear gain `G = 3 * (gain_index + 1) / 32`.
///
/// Valid indices `0..=POST_FILTER_GAIN_INDEX_MAX` map to `3/32 ..= 24/32`.
///
/// # Errors
///
/// [`PostFilterError::GainIndexOutOfRange`] if `gain_index` exceeds
/// `POST_FILTER_GAIN_INDEX_MAX`.
pub fn post_filter_gain(gain_index: u8) -> Result<f64, PostFilterError> {
    if gain_index <= POST_FILTER_GAIN_INDEX_MAX {
        let steps = u32::from(gain_index) + 1;
        let numerator = f64::from(POST_FILTER_GAIN_NUMERATOR * steps);
        Ok(numerator / f64::from(POST_FILTER_GAIN_DENOMINATOR))
    } else {
        Err(PostFilterError::GainIndexOutOfRange { gain_index })
    }
}
/// Post-filter parameters with the gain `G` already folded into the taps.
///
/// `a0` weights `y(n - period)`, `a1` each of `y(n - period ± 1)`, and `a2`
/// each of `y(n - period ± 2)`. Prefer [`PostFilterCoeffs::new`] or
/// [`PostFilterCoeffs::from_header`] over building this directly, because
/// only they range-check `period`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct PostFilterCoeffs {
    /// Pitch period `T`, in samples.
    pub period: u16,
    /// `G * g0`: weight of the centre tap.
    pub a0: f64,
    /// `G * g1`: weight of each tap one sample either side of the centre.
    pub a1: f64,
    /// `G * g2`: weight of each tap two samples either side of the centre.
    pub a2: f64,
}
impl PostFilterCoeffs {
    /// Validates decoded post-filter parameters and folds the gain into the taps.
    ///
    /// # Errors
    ///
    /// Checked in this order:
    /// - [`PostFilterError::PeriodOutOfRange`] if `period` is outside
    ///   `POST_FILTER_MIN_PERIOD..=POST_FILTER_MAX_PERIOD`;
    /// - [`PostFilterError::GainIndexOutOfRange`] from [`post_filter_gain`];
    /// - [`PostFilterError::TapsetOutOfRange`] from [`tapset_coefficients`].
    pub fn new(period: u16, gain_index: u8, tapset: u8) -> Result<Self, PostFilterError> {
        if !matches!(period, POST_FILTER_MIN_PERIOD..=POST_FILTER_MAX_PERIOD) {
            return Err(PostFilterError::PeriodOutOfRange { period });
        }
        let gain = post_filter_gain(gain_index)?;
        let [a0, a1, a2] = tapset_coefficients(tapset)?.map(|tap| gain * tap);
        Ok(Self { period, a0, a1, a2 })
    }

    /// Builds coefficients from the period, gain index and tapset of a parsed
    /// header's post-filter fields.
    ///
    /// # Errors
    ///
    /// Same as [`PostFilterCoeffs::new`].
    pub fn from_header(pf: &CeltPostFilter) -> Result<Self, PostFilterError> {
        let (period, gain_index, tapset) = (pf.period, pf.gain_index, pf.tapset);
        Self::new(period, gain_index, tapset)
    }

    /// How many past output samples the comb reaches back: `period + 2`.
    #[must_use]
    pub fn history_len(&self) -> usize {
        2 + usize::from(self.period)
    }
}
/// Stateful CELT pitch post-filter: an IIR comb filter whose taps read back
/// into its own past *output*.
///
/// The history is a fixed-size ring buffer that can hold the farthest tap of
/// the longest allowed period (`POST_FILTER_MAX_PERIOD + 2` samples). One
/// instance therefore accepts any coefficients produced by
/// [`PostFilterCoeffs::new`] without reallocating. State persists across
/// calls, so a stream can be fed in arbitrary block sizes.
#[derive(Debug, Clone)]
pub struct PostFilter {
    /// Ring buffer of the last `HIST_CAP` output samples.
    hist: Vec<f64>,
    /// Slot the next output sample is written to; it currently holds the
    /// oldest retained sample.
    head: usize,
}
impl Default for PostFilter {
    /// Same as [`PostFilter::new`]: an all-zero history.
    fn default() -> Self {
        Self::new()
    }
}
impl PostFilter {
    /// Ring-buffer length: the farthest tap is `y(n - (T + 2))` with
    /// `T <= POST_FILTER_MAX_PERIOD`.
    const HIST_CAP: usize = POST_FILTER_MAX_PERIOD as usize + 2;

    /// Creates a filter whose history is all zeros.
    #[must_use]
    pub fn new() -> Self {
        Self {
            hist: vec![0.0; Self::HIST_CAP],
            head: 0,
        }
    }

    /// Clears the history back to silence, exactly as if freshly constructed.
    pub fn reset(&mut self) {
        self.hist.fill(0.0);
        self.head = 0;
    }

    /// Output sample produced `delay` steps ago (`delay == 1` is the most
    /// recent). Slots never written read as `0.0`.
    fn past(&self, delay: usize) -> f64 {
        debug_assert!((1..=Self::HIST_CAP).contains(&delay));
        let idx = (self.head + Self::HIST_CAP - delay) % Self::HIST_CAP;
        self.hist[idx]
    }

    /// Records an output sample, overwriting the oldest one.
    fn push(&mut self, y: f64) {
        self.hist[self.head] = y;
        self.head = (self.head + 1) % Self::HIST_CAP;
    }

    /// Filters one sample and records the result in the history.
    ///
    /// Computes
    /// `y(n) = x(n) + a0·y(n-T) + a1·(y(n-T+1) + y(n-T-1)) + a2·(y(n-T+2) + y(n-T-2))`.
    ///
    /// # Panics
    ///
    /// Coefficients from [`PostFilterCoeffs::new`] are always in range. A
    /// hand-built `c.period` outside `3..=POST_FILTER_MAX_PERIOD` trips the
    /// history-lookup assertion (or a subtraction overflow) in debug builds.
    #[must_use]
    pub fn step(&mut self, x: f64, c: &PostFilterCoeffs) -> f64 {
        let y = self.comb_response(x, c);
        self.push(y);
        y
    }

    /// Filters `samples` in place, continuing from the current history.
    pub fn process_in_place(&mut self, samples: &mut [f64], c: &PostFilterCoeffs) {
        for x in samples.iter_mut() {
            *x = self.step(*x, c);
        }
    }

    /// Filters `input` into the first `input.len()` entries of `output`.
    ///
    /// Any remaining entries of `output` are left untouched. Returns the
    /// number of samples written.
    ///
    /// # Errors
    ///
    /// [`PostFilterError::OutputBufferTooSmall`] if `output` is shorter than
    /// `input`. The check runs before any sample is filtered, so the
    /// history is unchanged on error.
    pub fn process(
        &mut self,
        input: &[f64],
        output: &mut [f64],
        c: &PostFilterCoeffs,
    ) -> Result<usize, PostFilterError> {
        if output.len() < input.len() {
            return Err(PostFilterError::OutputBufferTooSmall {
                input_len: input.len(),
                output_len: output.len(),
            });
        }
        for (y, &x) in output.iter_mut().zip(input) {
            *y = self.step(x, c);
        }
        Ok(input.len())
    }

    /// Filters one overlap region while cross-fading from `old` to `new`.
    ///
    /// Each output is `(1 - w²)·comb(x, old) + w²·comb(x, new)`, where `w` is
    /// the MDCT window tap for that position. Both comb responses read the
    /// same shared history, and only the mixed sample is recorded in it.
    /// Returns `overlap`; an `overlap` of 0 with empty `input` is a no-op.
    ///
    /// # Errors
    ///
    /// - [`PostFilterError::TransitionLengthMismatch`] if `input.len() != overlap`;
    /// - [`PostFilterError::OutputBufferTooSmall`] if `output.len() < overlap`;
    /// - [`PostFilterError::Window`] if a window tap cannot be computed.
    ///
    /// The length checks run before any state change.
    pub fn process_gain_transition(
        &mut self,
        input: &[f64],
        output: &mut [f64],
        old: &PostFilterCoeffs,
        new: &PostFilterCoeffs,
        overlap: usize,
    ) -> Result<usize, PostFilterError> {
        if input.len() != overlap {
            return Err(PostFilterError::TransitionLengthMismatch {
                expected: overlap,
                provided: input.len(),
            });
        }
        if output.len() < overlap {
            return Err(PostFilterError::OutputBufferTooSmall {
                input_len: overlap,
                output_len: output.len(),
            });
        }
        for (n, (&x, y)) in input.iter().zip(output.iter_mut()).enumerate() {
            // Fetch the window before touching the history, so a failure on
            // the first sample leaves the filter unchanged.
            // NOTE(review): a failure at a later `n` would leave the earlier
            // samples already pushed. Confirm that window_tap only fails for
            // overlap-wide reasons.
            let w = celt_mdct_window::window_tap(n, overlap)?;
            let w2 = w * w;
            let y_old = self.comb_response(x, old);
            let y_new = self.comb_response(x, new);
            let mixed = (1.0 - w2) * y_old + w2 * y_new;
            self.push(mixed);
            *y = mixed;
        }
        Ok(overlap)
    }

    /// Comb-filter output for input `x` against the current history, without
    /// recording it. This is the single definition of the filter equation,
    /// shared by [`PostFilter::step`] and [`PostFilter::process_gain_transition`].
    fn comb_response(&self, x: f64, c: &PostFilterCoeffs) -> f64 {
        let t = usize::from(c.period);
        x + c.a0 * self.past(t)
            + c.a1 * (self.past(t - 1) + self.past(t + 1))
            + c.a2 * (self.past(t - 2) + self.past(t + 2))
    }
}
/// Mixes two renderings of a transition region with the squared MDCT window:
/// `out[n] = (1 - w(n)²)·old_out[n] + w(n)²·new_out[n]`.
///
/// Unlike [`PostFilter::process_gain_transition`], this touches no filter
/// history; it only mixes the two given slices. Only the first `overlap`
/// entries of `out` are written.
///
/// # Errors
///
/// - [`PostFilterError::TransitionLengthMismatch`] if `old_out` or `new_out`
///   is not exactly `overlap` long. `provided` is the length of the offending
///   slice, and `old_out` is checked first.
/// - [`PostFilterError::OutputBufferTooSmall`] if `out.len() < overlap`.
/// - [`PostFilterError::Window`] if a window tap cannot be computed. Earlier
///   entries of `out` may already have been written.
pub fn crossfade_transition(
    old_out: &[f64],
    new_out: &[f64],
    out: &mut [f64],
    overlap: usize,
) -> Result<(), PostFilterError> {
    // Report the length that is actually wrong. Taking the min of both
    // lengths would report `overlap` itself whenever one slice is too long,
    // giving the self-contradictory "expected N, got N".
    let offending = if old_out.len() != overlap {
        Some(old_out.len())
    } else if new_out.len() != overlap {
        Some(new_out.len())
    } else {
        None
    };
    if let Some(provided) = offending {
        return Err(PostFilterError::TransitionLengthMismatch {
            expected: overlap,
            provided,
        });
    }
    if out.len() < overlap {
        return Err(PostFilterError::OutputBufferTooSmall {
            input_len: overlap,
            output_len: out.len(),
        });
    }
    for (n, ((dst, &a), &b)) in out.iter_mut().zip(old_out).zip(new_out).enumerate() {
        let w = celt_mdct_window::window_tap(n, overlap)?;
        let w2 = w * w;
        *dst = (1.0 - w2) * a + w2 * b;
    }
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn tapsets_match_rfc() {
        assert_eq!(
            tapset_coefficients(0).unwrap(),
            [0.306_640_625_0, 0.217_041_015_6, 0.129_638_671_9]
        );
        assert_eq!(
            tapset_coefficients(1).unwrap(),
            [0.463_867_187_5, 0.268_066_406_2, 0.0]
        );
        assert_eq!(
            tapset_coefficients(2).unwrap(),
            [0.799_804_687_5, 0.100_097_656_2, 0.0]
        );
    }
    #[test]
    fn tapsets_one_and_two_drop_g2() {
        assert_eq!(tapset_coefficients(1).unwrap()[2], 0.0);
        assert_eq!(tapset_coefficients(2).unwrap()[2], 0.0);
    }
    #[test]
    fn tapset_out_of_range_rejected() {
        assert_eq!(
            tapset_coefficients(3),
            Err(PostFilterError::TapsetOutOfRange { tapset: 3 })
        );
    }
    #[test]
    fn gain_formula() {
        assert_eq!(post_filter_gain(0).unwrap(), 3.0 / 32.0);
        assert_eq!(post_filter_gain(7).unwrap(), 24.0 / 32.0);
        assert_eq!(post_filter_gain(3).unwrap(), 12.0 / 32.0);
    }
    #[test]
    fn gain_monotone() {
        let mut prev = -1.0;
        for gi in 0..=POST_FILTER_GAIN_INDEX_MAX {
            let g = post_filter_gain(gi).unwrap();
            assert!(g > prev, "gain not monotone at {gi}");
            prev = g;
        }
        assert_eq!(post_filter_gain(POST_FILTER_GAIN_INDEX_MAX).unwrap(), 0.75);
    }
    #[test]
    fn gain_index_out_of_range_rejected() {
        assert_eq!(
            post_filter_gain(8),
            Err(PostFilterError::GainIndexOutOfRange { gain_index: 8 })
        );
    }
    #[test]
    fn coeffs_fold_gain() {
        let c = PostFilterCoeffs::new(100, 7, 0).unwrap();
        let g = 0.75;
        assert_eq!(c.a0, g * 0.306_640_625_0);
        assert_eq!(c.a1, g * 0.217_041_015_6);
        assert_eq!(c.a2, g * 0.129_638_671_9);
        assert_eq!(c.period, 100);
        assert_eq!(c.history_len(), 102);
    }
    #[test]
    fn period_bounds_enforced() {
        assert!(PostFilterCoeffs::new(15, 0, 0).is_ok());
        assert!(PostFilterCoeffs::new(1022, 0, 0).is_ok());
        assert_eq!(
            PostFilterCoeffs::new(14, 0, 0),
            Err(PostFilterError::PeriodOutOfRange { period: 14 })
        );
        assert_eq!(
            PostFilterCoeffs::new(1023, 0, 0),
            Err(PostFilterError::PeriodOutOfRange { period: 1023 })
        );
    }
    #[test]
    fn from_header_round_trips() {
        let pf = CeltPostFilter {
            octave: 2,
            period: 64,
            gain_index: 4,
            tapset: 1,
        };
        let c = PostFilterCoeffs::from_header(&pf).unwrap();
        assert_eq!(c, PostFilterCoeffs::new(64, 4, 1).unwrap());
    }
    #[test]
    fn fresh_filter_passes_through_until_period() {
        let c = PostFilterCoeffs::new(15, 7, 0).unwrap();
        let mut f = PostFilter::new();
        for k in 0..13 {
            let x = (k as f64) + 1.0;
            assert_eq!(f.step(x, &c), x, "sample {k} should pass through");
        }
    }
    #[test]
    fn comb_recurrence_matches_hand_expansion() {
        let c = PostFilterCoeffs::new(15, 7, 0).unwrap();
        let mut f = PostFilter::new();
        let mut ys = Vec::new();
        ys.push(f.step(1.0, &c));
        for _ in 0..20 {
            ys.push(f.step(0.0, &c));
        }
        assert_eq!(ys[0], 1.0);
        for (n, y) in ys.iter().enumerate().take(13).skip(1) {
            assert_eq!(*y, 0.0, "y({n}) should be zero before feedback");
        }
        assert_eq!(ys[13], c.a2 * 1.0);
        assert_eq!(ys[14], c.a1 * 1.0);
        assert_eq!(ys[15], c.a0 * 1.0);
        assert_eq!(ys[16], c.a1 * 1.0);
        assert_eq!(ys[17], c.a2 * 1.0);
        assert_eq!(ys[18], 0.0);
    }
    #[test]
    fn impulse_response_is_symmetric_about_period() {
        let c = PostFilterCoeffs::new(20, 5, 0).unwrap();
        let mut f = PostFilter::new();
        let mut ys = vec![f.step(1.0, &c)];
        for _ in 0..30 {
            ys.push(f.step(0.0, &c));
        }
        assert_eq!(ys[19], ys[21], "g1 pair not symmetric");
        assert_eq!(ys[18], ys[22], "g2 pair not symmetric");
    }
    #[test]
    fn history_carries_across_blocks() {
        let c = PostFilterCoeffs::new(15, 6, 0).unwrap();
        let stream: Vec<f64> = (0..40).map(|k| ((k * 7) % 11) as f64 - 5.0).collect();
        let mut whole = PostFilter::new();
        let mut ref_out = stream.clone();
        whole.process_in_place(&mut ref_out, &c);
        let mut split = PostFilter::new();
        let mut a = stream[..18].to_vec();
        let mut b = stream[18..].to_vec();
        split.process_in_place(&mut a, &c);
        split.process_in_place(&mut b, &c);
        for (i, v) in a.iter().chain(b.iter()).enumerate() {
            assert_eq!(*v, ref_out[i], "block-split mismatch at {i}");
        }
    }
    #[test]
    fn process_writes_output_buffer() {
        let c = PostFilterCoeffs::new(15, 3, 1).unwrap();
        let input: Vec<f64> = (0..30).map(|k| (k as f64).sin()).collect();
        let mut out = vec![0.0; 30];
        let mut f = PostFilter::new();
        let n = f.process(&input, &mut out, &c).unwrap();
        assert_eq!(n, 30);
        let mut g = PostFilter::new();
        let mut ref_out = input.clone();
        g.process_in_place(&mut ref_out, &c);
        assert_eq!(out, ref_out);
    }
    #[test]
    fn process_accepts_longer_output() {
        let c = PostFilterCoeffs::new(15, 0, 0).unwrap();
        let input = [1.0_f64, 2.0, 3.0];
        let mut out = [9.0_f64; 6];
        let mut f = PostFilter::new();
        let n = f.process(&input, &mut out, &c).unwrap();
        assert_eq!(n, 3);
        assert_eq!(&out[..3], &[1.0, 2.0, 3.0]);
        assert_eq!(&out[3..], &[9.0, 9.0, 9.0]);
    }
    #[test]
    fn process_rejects_short_output() {
        let c = PostFilterCoeffs::new(15, 0, 0).unwrap();
        let input = [1.0_f64, 2.0, 3.0];
        let mut out = [0.0_f64; 2];
        let mut f = PostFilter::new();
        let err = f.process(&input, &mut out, &c).unwrap_err();
        assert_eq!(
            err,
            PostFilterError::OutputBufferTooSmall {
                input_len: 3,
                output_len: 2,
            }
        );
        let mut g = PostFilter::new();
        assert_eq!(g.step(1.0, &c), f.step(1.0, &c));
    }
    #[test]
    fn reset_zeroes_history() {
        let c = PostFilterCoeffs::new(15, 7, 0).unwrap();
        let mut f = PostFilter::new();
        for k in 0..40 {
            let _ = f.step((k as f64) + 1.0, &c);
        }
        f.reset();
        assert_eq!(f.step(2.5, &c), 2.5);
    }
    #[test]
    fn gain_transition_endpoints() {
        let overlap = 16usize;
        let old = PostFilterCoeffs::new(15, 1, 0).unwrap();
        let new = PostFilterCoeffs::new(15, 7, 0).unwrap();
        let mut prime = PostFilter::new();
        for k in 0..40 {
            let _ = prime.step(((k % 5) as f64) - 2.0, &old);
        }
        let input: Vec<f64> = (0..overlap).map(|k| (k as f64).cos()).collect();
        let mut f_old = prime.clone();
        let mut old_out = input.clone();
        f_old.process_in_place(&mut old_out, &old);
        let mut f_new = prime.clone();
        let mut new_out = input.clone();
        f_new.process_in_place(&mut new_out, &new);
        let mut f = prime.clone();
        let mut out = vec![0.0; overlap];
        f.process_gain_transition(&input, &mut out, &old, &new, overlap)
            .unwrap();
        let w0 = {
            let w = celt_mdct_window::window_tap(0, overlap).unwrap();
            w * w
        };
        let expect0 = (1.0 - w0) * old_out[0] + w0 * new_out[0];
        assert!((out[0] - expect0).abs() < 1e-12, "n=0 mix wrong");
        let wlast = {
            let w = celt_mdct_window::window_tap(overlap - 1, overlap).unwrap();
            w * w
        };
        assert!(wlast > 0.99, "window square should approach 1 at the end");
    }
    #[test]
    fn gain_transition_identity_when_same_params() {
        let overlap = 16usize;
        let c = PostFilterCoeffs::new(15, 4, 0).unwrap();
        let input: Vec<f64> = (0..overlap).map(|k| (k as f64) * 0.1 - 0.5).collect();
        let mut a = PostFilter::new();
        let mut out_a = vec![0.0; overlap];
        a.process_gain_transition(&input, &mut out_a, &c, &c, overlap)
            .unwrap();
        let mut b = PostFilter::new();
        let mut out_b = input.clone();
        b.process_in_place(&mut out_b, &c);
        for (i, (x, y)) in out_a.iter().zip(out_b.iter()).enumerate() {
            assert!((x - y).abs() < 1e-12, "transition!=plain at {i}");
        }
    }
    #[test]
    fn gain_transition_length_mismatch_rejected() {
        let c = PostFilterCoeffs::new(15, 0, 0).unwrap();
        let mut f = PostFilter::new();
        let input = [1.0; 8];
        let mut out = [0.0; 8];
        let err = f
            .process_gain_transition(&input, &mut out, &c, &c, 16)
            .unwrap_err();
        assert_eq!(
            err,
            PostFilterError::TransitionLengthMismatch {
                expected: 16,
                provided: 8,
            }
        );
    }
    #[test]
    fn crossfade_helper_mixes() {
        let overlap = 8usize;
        let old_out = vec![1.0; overlap];
        let new_out = vec![3.0; overlap];
        let mut out = vec![0.0; overlap];
        crossfade_transition(&old_out, &new_out, &mut out, overlap).unwrap();
        for (n, v) in out.iter().enumerate() {
            let w = celt_mdct_window::window_tap(n, overlap).unwrap();
            let w2 = w * w;
            let expect = (1.0 - w2) * 1.0 + w2 * 3.0;
            assert!((v - expect).abs() < 1e-12, "mix wrong at {n}");
            assert!(*v >= 1.0 - 1e-12 && *v <= 3.0 + 1e-12);
        }
    }
    #[test]
    fn crossfade_helper_rejects_mismatch() {
        let mut out = vec![0.0; 8];
        let err = crossfade_transition(&[1.0; 8], &[2.0; 4], &mut out, 8).unwrap_err();
        assert_eq!(
            err,
            PostFilterError::TransitionLengthMismatch {
                expected: 8,
                provided: 4,
            }
        );
    }
    #[test]
    fn crossfade_helper_rejects_short_output() {
        let mut out = vec![0.0; 4];
        let err = crossfade_transition(&[1.0; 8], &[2.0; 8], &mut out, 8).unwrap_err();
        assert_eq!(
            err,
            PostFilterError::OutputBufferTooSmall {
                input_len: 8,
                output_len: 4,
            }
        );
    }
    #[test]
    fn constants_match_rfc() {
        assert_eq!(POST_FILTER_MIN_PERIOD, 15);
        assert_eq!(POST_FILTER_MAX_PERIOD, 1022);
        assert_eq!(POST_FILTER_TAPSET_COUNT, 3);
        assert_eq!(POST_FILTER_GAIN_NUMERATOR, 3);
        assert_eq!(POST_FILTER_GAIN_DENOMINATOR, 32);
        assert_eq!(POST_FILTER_GAIN_INDEX_MAX, 7);
    }
    #[test]
    fn error_display() {
        assert!(PostFilterError::TapsetOutOfRange { tapset: 9 }
            .to_string()
            .contains("tapset"));
        assert!(PostFilterError::PeriodOutOfRange { period: 5 }
            .to_string()
            .contains("period"));
        assert!(PostFilterError::GainIndexOutOfRange { gain_index: 9 }
            .to_string()
            .contains("gain"));
        assert!(PostFilterError::OutputBufferTooSmall {
            input_len: 4,
            output_len: 1
        }
        .to_string()
        .contains("need 4"));
        assert!(PostFilterError::TransitionLengthMismatch {
            expected: 16,
            provided: 8
        }
        .to_string()
        .contains("expected 16"));
        let w = PostFilterError::Window(MdctWindowError::ZeroLength);
        assert!(w.to_string().contains("window"));
    }
    #[test]
    fn max_reach_fits_history() {
        let c = PostFilterCoeffs::new(POST_FILTER_MAX_PERIOD, 7, 0).unwrap();
        assert_eq!(c.history_len(), PostFilter::HIST_CAP);
        let t = usize::from(POST_FILTER_MAX_PERIOD);
        let mut f = PostFilter::new();
        // Drive an impulse far enough that the farthest tap, y(n - (T + 2)),
        // must read the oldest ring-buffer slot. This checks the wrap-around
        // at full capacity, not just the absence of a panic.
        let mut ys = vec![f.step(1.0, &c)];
        for _ in 0..(t + 2) {
            ys.push(f.step(0.0, &c));
        }
        assert_eq!(ys.len(), t + 3);
        for (n, y) in ys.iter().enumerate().take(t - 2).skip(1) {
            assert_eq!(*y, 0.0, "y({n}) should be zero before feedback");
        }
        assert_eq!(ys[t - 2], c.a2);
        assert_eq!(ys[t - 1], c.a1);
        assert_eq!(ys[t], c.a0);
        assert_eq!(ys[t + 1], c.a1);
        assert_eq!(ys[t + 2], c.a2);
    }
}