Skip to main content

oxideav_ac4/
encoder_mdct.rs

1//! Forward MDCT analysis for the AC-4 IMS encoder (round 48).
2//!
3//! Per ETSI TS 103 190-1 §5.5.2 (Pseudocodes 60-64) the AC-4 transform is
4//! a standard MDCT — same Princen-Bradley TDAC convention, same KBD
5//! windowing, same `2N`-input / `N`-output symmetry — only the spec uses
6//! an FFT-factored unfolding for speed. The decoder ([`crate::mdct`])
7//! ships the inverse direction (IMDCT). This module supplies the
8//! complementary forward direction in the simplest possible form: the
9//! direct-summation cosine matrix evaluated naively in O(N²). Correctness
10//! first; speed last (encoder isn't on a hot path).
11//!
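//! The forward/inverse pair matched against the decoder's IMDCT is the
//! standard MDCT pair, with the normalisation split the way
//! [`mdct_naive`] implements it (factor 2 on the forward side, `1/N` on
//! the inverse side):
//!
//!   X[k] = 2 * sum_{n=0..2N-1} x[n] * cos(pi/N * (n + 0.5 + N/2) * (k + 0.5))
//!   y[n] = (1/N) * sum_{k=0..N-1} X[k] * cos(pi/N * (n + 0.5 + N/2) * (k + 0.5))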
17//!
//! End-to-end Princen-Bradley round-trip tests (forward MDCT → IMDCT →
//! KBD overlap-add over 4 consecutive frames, for a constant signal and
//! for a sinusoid) live in this file's `tests` submodule to lock the sign
//! convention to the decoder.
21
22use crate::mdct::kbd_window;
23use core::f64::consts::PI;
24
/// Multiplies `samples` element-wise, in place, by `window`.
///
/// Used as the KBD analysis window over the `2N`-sample MDCT input; the
/// same symmetric KBD window is used for synthesis. Both slices are
/// expected to have the same length (checked only in debug builds; in
/// release builds the shorter length wins).
pub fn apply_kbd_window(samples: &mut [f32], window: &[f32]) {
    debug_assert_eq!(samples.len(), window.len());
    samples
        .iter_mut()
        .zip(window)
        .for_each(|(sample, &gain)| *sample *= gain);
}
33
34/// Naive forward MDCT: transforms `2N` time-domain samples (already
35/// windowed) into `N` spectral coefficients via the direct-summation
36/// cosine basis. O(N²); intended for correctness, not throughput.
37///
38/// The scaling matches the decoder's IMDCT in [`crate::mdct::imdct`]
39/// (which itself divides post-IFFT by `N`): forward = `2 * sum_n x*cos`,
40/// inverse = `(1/N) * sum_k X*cos`. The `2x` prefactor here is the
41/// canonical AC-4 / AAC normalisation that makes the windowed
42/// Princen-Bradley round-trip recover the original signal exactly in
43/// steady-state frames.
44///
45/// Sign convention matches [`crate::mdct::imdct`] so the round-trip
46/// `imdct(mdct(x)) ≈ x` (with TDAC overlap-add and KBD windowing).
47pub fn mdct_naive(x: &[f32]) -> Vec<f32> {
48    let two_n = x.len();
49    let n = two_n / 2;
50    let mut out = vec![0.0_f32; n];
51    let n_f = n as f64;
52    for (k, out_k) in out.iter_mut().enumerate() {
53        let kf = k as f64;
54        let mut acc = 0.0_f64;
55        for (nn, &xv) in x.iter().enumerate().take(two_n) {
56            let nf = nn as f64;
57            let theta = PI / n_f * (nf + 0.5 + n_f * 0.5) * (kf + 0.5);
58            acc += xv as f64 * theta.cos();
59        }
60        // Forward-direction scaling complementary to the decoder's
61        // IMDCT post-IFFT divide-by-N.
62        *out_k = (2.0 * acc) as f32;
63    }
64    out
65}
66
/// Forward-MDCT analysis state for one channel.
///
/// Keeps the previous `N` input samples so that each call to
/// [`EncoderMdctState::analyse_frame`] can build a `2N`-sample window that
/// overlaps the prior one by 50%, as TDAC requires.
#[derive(Clone, Debug)]
pub struct EncoderMdctState {
    /// Transform length `N`: samples per input frame and coefficients per
    /// output spectrum. Each analysis consumes `2N` samples (the `N` held
    /// in `history` followed by the `N` new ones).
    pub n: u32,
    /// The most recent `N` input PCM samples, used as the leading half of
    /// the next analysis window. Expected to hold exactly `n` samples; the
    /// window-length check in `apply_kbd_window` is only a `debug_assert`.
    pub history: Vec<f32>,
}
77
78impl EncoderMdctState {
79    /// Fresh encoder state with N-sample zero history (TDAC priming —
80    /// the very first encoded frame will lose half a window of energy at
81    /// the leading edge, identical to the decoder's first-frame behaviour).
82    pub fn new(n: u32) -> Self {
83        Self {
84            n,
85            history: vec![0.0_f32; n as usize],
86        }
87    }
88
89    /// Push a fresh `N`-sample frame of input PCM through the windowed
90    /// forward MDCT and return the `N` spectral coefficients. The caller
91    /// owns the input buffer — this method consumes it logically (copies
92    /// into the state).
93    pub fn analyse_frame(&mut self, frame: &[f32]) -> Vec<f32> {
94        let n = self.n as usize;
95        assert_eq!(frame.len(), n, "encoder: frame must be exactly N samples");
96        // Build the 2N-sample MDCT input: previous N + new N.
97        let mut buf = Vec::with_capacity(2 * n);
98        buf.extend_from_slice(&self.history);
99        buf.extend_from_slice(frame);
100        // KBD window per §5.5.3 — same window as synthesis (Princen-Bradley).
101        let window = kbd_window(self.n);
102        apply_kbd_window(&mut buf, &window);
103        // Forward MDCT.
104        let spec = mdct_naive(&buf);
105        // Slide history forward — drop the old N, keep this frame's N
106        // for the next call's leading half.
107        self.history.clear();
108        self.history.extend_from_slice(frame);
109        spec
110    }
111}
112
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mdct::{imdct, imdct_olap_symmetric, kbd_window};

    /// Runs every spectrum in `specs` through the decoder's IMDCT and
    /// symmetric KBD overlap-add, returning one PCM block per spectrum.
    ///
    /// Encoder spectrum `i` covers `[frame i-1, frame i]` (frame -1 being
    /// the zero priming history). An N-sample region is therefore only
    /// alias-free once both spectra covering it have been overlap-added.
    /// So output block `i` reconstructs input frame `i - 1`, i.e. the
    /// codec pair has one frame of latency.
    fn decode_all(specs: &[Vec<f32>], n: u32) -> Vec<Vec<f32>> {
        let window = kbd_window(n);
        let mut overlap = vec![0.0_f32; n as usize];
        specs
            .iter()
            .map(|spec| {
                let y = imdct(spec);
                imdct_olap_symmetric(&y, &window, &mut overlap)
            })
            .collect()
    }

    /// Simplest sanity check: the forward MDCT of an all-zero input is an
    /// all-zero spectrum of length N.
    #[test]
    fn forward_mdct_zero_in_zero_out() {
        let n = 8;
        let x = vec![0.0_f32; 2 * n];
        let spec = mdct_naive(&x);
        assert_eq!(spec.len(), n);
        for &s in &spec {
            assert!(s.abs() < 1e-6, "got non-zero from zero input: {s}");
        }
    }

    /// The forward MDCT is linear: it commutes with scaling (homogeneity)
    /// and with addition (additivity).
    #[test]
    fn forward_mdct_is_linear() {
        let n = 8;
        let mut x = vec![0.0_f32; 2 * n];
        x[3] = 0.7;
        x[10] = -0.3;
        let s_x = mdct_naive(&x);

        // Homogeneity: MDCT(2.5 * x) == 2.5 * MDCT(x).
        let scaled: Vec<f32> = x.iter().map(|v| v * 2.5).collect();
        let s_scaled = mdct_naive(&scaled);
        for (a, b) in s_x.iter().zip(&s_scaled) {
            assert!((a * 2.5 - b).abs() < 1e-4, "homogeneity broken");
        }

        // Additivity: MDCT(x + y) == MDCT(x) + MDCT(y).
        let mut y = vec![0.0_f32; 2 * n];
        y[0] = 0.25;
        y[15] = -0.5;
        let sum: Vec<f32> = x.iter().zip(&y).map(|(a, b)| a + b).collect();
        let s_y = mdct_naive(&y);
        let s_sum = mdct_naive(&sum);
        for ((a, b), c) in s_x.iter().zip(&s_y).zip(&s_sum) {
            assert!((a + b - c).abs() < 1e-4, "additivity broken");
        }
    }

    /// Princen-Bradley TDAC: forward MDCT → IMDCT → KBD overlap-add over 4
    /// consecutive frames of a constant signal recovers the constant once
    /// two spectra overlap (within floating-point error). This is the
    /// headline test that locks the sign convention between `mdct_naive`
    /// and the decoder's `imdct`.
    #[test]
    fn princen_bradley_constant_signal_recovers_in_middle_frame() {
        let n = 16u32;
        let nsz = n as usize;
        // 4 frames of constant 0.1 amplitude.
        let frames: Vec<Vec<f32>> = (0..4).map(|_| vec![0.1_f32; nsz]).collect();

        let mut enc = EncoderMdctState::new(n);
        let specs: Vec<Vec<f32>> = frames.iter().map(|f| enc.analyse_frame(f)).collect();
        let pcm_out = decode_all(&specs, n);

        // Output blocks 1 and 2 reconstruct input frames 0 and 1; block 0
        // reconstructs the zero priming history and is skipped.
        for (idx, block) in pcm_out.iter().enumerate().skip(1).take(2) {
            for &v in block {
                assert!(
                    (v - 0.1).abs() < 0.01,
                    "Princen-Bradley failed in output block {idx}: got {v}, expected 0.1"
                );
            }
        }
    }

    /// EncoderMdctState round-trips a sinusoid through 4 consecutive frames.
    /// Because of the one-frame codec latency, decoder output block 2 must
    /// match input frame 1. The frequency is chosen so consecutive frames
    /// differ (each frame is the negation of the previous one). A
    /// frame-misaligned comparison therefore fails instead of passing by
    /// coincidence.
    #[test]
    fn encoder_state_sinewave_round_trips() {
        let n = 32u32;
        let nsz = n as usize;
        // x[t] = sin(2*pi * k * t / (2N)); with k = 3 each N-sample frame
        // advances the phase by 3*pi, so frame i+1 == -frame i.
        let freq_bin = 3.0_f32;
        let make_frame = |start: usize| -> Vec<f32> {
            (0..nsz)
                .map(|i| {
                    let t = (start + i) as f32;
                    (2.0 * std::f32::consts::PI * freq_bin * t / (2.0 * n as f32)).sin()
                })
                .collect()
        };
        let frames: Vec<Vec<f32>> = (0..4).map(|i| make_frame(i * nsz)).collect();
        let mut enc = EncoderMdctState::new(n);
        let specs: Vec<Vec<f32>> = frames.iter().map(|f| enc.analyse_frame(f)).collect();
        let pcm_out = decode_all(&specs, n);

        // Output block 2 is the overlap of spectra 1 and 2, i.e. frame 1.
        let orig = &frames[1];
        let recon = &pcm_out[2];
        let mut sig_e = 0.0_f64;
        let mut err_e = 0.0_f64;
        for (o, r) in orig.iter().zip(recon.iter()) {
            sig_e += (*o as f64).powi(2);
            err_e += (*o as f64 - *r as f64).powi(2);
        }
        let snr_db = 10.0 * (sig_e / err_e.max(1e-30)).log10();
        assert!(snr_db > 40.0, "round-trip SNR too low: {snr_db:.1} dB");
    }
}