// libflo_audio/lossy/mdct.rs

// Full disclosure: this code is inspired by Symphonia's MDCT implementation
// and, in part, by FFmpeg's.
3
4use rustfft::{num_complex::Complex, FftPlanner};
5use std::f32::consts::PI;
6use std::sync::Arc;
7
/// Analysis/synthesis window shapes available to the MDCT.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum WindowType {
    /// Sine window: a simple shape that works well for most material.
    Sine,
    /// Kaiser-Bessel-derived window: sharper frequency selectivity.
    KaiserBesselDerived,
    /// Vorbis power-complementary window, tuned for audio coding.
    Vorbis,
}
18
/// Block lengths the MDCT can operate on.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum BlockSize {
    /// 2048-sample long block: fine frequency resolution for stationary material.
    Long,
    /// 256-sample short block: fine time resolution around transients.
    Short,
    /// Start block, bridging a long block into short ones.
    Start,
    /// Stop block, bridging short blocks back into a long one.
    Stop,
}
31
32impl BlockSize {
33    /// Get the number of samples for this block size
34    pub fn samples(self) -> usize {
35        match self {
36            BlockSize::Long | BlockSize::Start | BlockSize::Stop => 2048,
37            BlockSize::Short => 256,
38        }
39    }
40
41    /// Get the number of MDCT coefficients (N/2)
42    pub fn coefficients(self) -> usize {
43        self.samples() / 2
44    }
45}
46
47/// FFT-based MDCT transform for a specific window size
48struct MdctTransform {
49    /// Window size (N)
50    n: usize,
51    /// Number of coefficients (N/2)
52    n2: usize,
53    /// FFT size (N/4)
54    n4: usize,
55    /// Window function
56    window: Vec<f32>,
57    /// Forward FFT
58    fft: Arc<dyn rustfft::Fft<f32>>,
59    /// Twiddle factors: e^(i*π/n2 * (k + 1/8))
60    twiddle: Vec<Complex<f32>>,
61}
62
63impl MdctTransform {
64    fn new(window_size: usize, window_type: WindowType) -> Self {
65        let n = window_size;
66        let n2 = n / 2;
67        let n4 = n / 4;
68
69        // Create window
70        let window = match window_type {
71            WindowType::Sine => Self::sine_window(n),
72            WindowType::KaiserBesselDerived => Self::kbd_window(n, 4.0),
73            WindowType::Vorbis => Self::vorbis_window(n),
74        };
75
76        // Create FFT planner
77        let mut planner = FftPlanner::new();
78        let fft = planner.plan_fft_forward(n4);
79
80        // Pre-compute twiddle factors
81        let twiddle: Vec<Complex<f32>> = (0..n4)
82            .map(|k| {
83                let theta = PI / n2 as f32 * (k as f32 + 0.125);
84                Complex::new(theta.cos(), theta.sin())
85            })
86            .collect();
87
88        Self {
89            n,
90            n2,
91            n4,
92            window,
93            fft,
94            twiddle,
95        }
96    }
97
98    /// Sine window: w[n] = sin(π(n+0.5)/N)
99    fn sine_window(n: usize) -> Vec<f32> {
100        (0..n)
101            .map(|i| (PI * (i as f32 + 0.5) / n as f32).sin())
102            .collect()
103    }
104
105    /// Vorbis window: sin(π/2 * sin²(π(n+0.5)/N))
106    fn vorbis_window(n: usize) -> Vec<f32> {
107        (0..n)
108            .map(|i| {
109                let x = (PI * (i as f32 + 0.5) / n as f32).sin();
110                (PI / 2.0 * x * x).sin()
111            })
112            .collect()
113    }
114
115    /// Kaiser-Bessel Derived window
116    fn kbd_window(n: usize, alpha: f32) -> Vec<f32> {
117        let half = n / 2;
118
119        // Compute Kaiser window for first half
120        let kaiser: Vec<f32> = (0..=half)
121            .map(|i| {
122                Self::bessel_i0(
123                    PI * alpha * (1.0 - (2.0 * i as f32 / half as f32 - 1.0).powi(2)).sqrt(),
124                )
125            })
126            .collect();
127
128        // Cumulative sum
129        let mut cumsum = vec![0.0f32; half + 1];
130        cumsum[0] = kaiser[0];
131        for i in 1..=half {
132            cumsum[i] = cumsum[i - 1] + kaiser[i];
133        }
134        let total = cumsum[half];
135
136        // Build KBD window
137        let mut window = vec![0.0f32; n];
138        for i in 0..half {
139            window[i] = (cumsum[i] / total).sqrt();
140            window[n - 1 - i] = window[i];
141        }
142
143        window
144    }
145
146    /// Modified Bessel function I0 (for KBD window)
147    fn bessel_i0(x: f32) -> f32 {
148        let mut sum = 1.0f32;
149        let mut term = 1.0f32;
150        let x_sq = x * x / 4.0;
151
152        for k in 1..20 {
153            term *= x_sq / (k * k) as f32;
154            sum += term;
155            if term < 1e-10 {
156                break;
157            }
158        }
159
160        sum
161    }
162
163    /// Forward MDCT using FFT - O(N log N)
164    ///
165    /// Based on FFmpeg's ff_mdct_calc_c algorithm.
166    fn forward(&self, samples: &[f32]) -> Vec<f32> {
167        let n = self.n;
168        let n2 = self.n2;
169        let n4 = self.n4;
170        let n8 = n4 / 2;
171        let n3 = 3 * n4;
172
173        // Apply window
174        let x: Vec<f32> = samples
175            .iter()
176            .zip(self.window.iter())
177            .map(|(&s, &w)| s * w)
178            .collect();
179
180        // Pre-rotation: fold N windowed samples into N/4 complex FFT inputs
181        let mut z: Vec<Complex<f32>> = vec![Complex::new(0.0, 0.0); n4];
182
183        for i in 0..n8 {
184            // First butterfly
185            let re = -x[2 * i + n3] - x[n3 - 1 - 2 * i];
186            let im = -x[n4 + 2 * i] + x[n4 - 1 - 2 * i];
187
188            let w = &self.twiddle[i];
189            z[i] = Complex::new(-re * w.re - im * w.im, re * w.im - im * w.re);
190
191            // Second butterfly
192            let re2 = x[2 * i] - x[n2 - 1 - 2 * i];
193            let im2 = -x[n2 + 2 * i] - x[n - 1 - 2 * i];
194
195            let w2 = &self.twiddle[n8 + i];
196            z[n8 + i] = Complex::new(-re2 * w2.re - im2 * w2.im, re2 * w2.im - im2 * w2.re);
197        }
198
199        // Forward FFT
200        self.fft.process(&mut z);
201
202        // Post-rotation: extract N/2 real coefficients
203        let mut output = vec![0.0; n2];
204
205        for i in 0..n8 {
206            let idx1 = n8 - i - 1;
207            let idx2 = n8 + i;
208
209            let w1 = &self.twiddle[idx1];
210            let z1 = z[idx1];
211            let i1 = -z1.re * w1.im + z1.im * w1.re;
212            let r0 = -z1.re * w1.re - z1.im * w1.im;
213
214            let w2 = &self.twiddle[idx2];
215            let z2 = z[idx2];
216            let i0 = -z2.re * w2.im + z2.im * w2.re;
217            let r1 = -z2.re * w2.re - z2.im * w2.im;
218
219            output[2 * idx1] = r0;
220            output[2 * idx1 + 1] = i0;
221            output[2 * idx2] = r1;
222            output[2 * idx2 + 1] = i1;
223        }
224
225        output
226    }
227
228    /// Inverse MDCT using FFT - O(N log N)
229    ///
230    /// Based on Symphonia's IMDCT implementation.
231    fn inverse(&self, spec: &[f32]) -> Vec<f32> {
232        let n = self.n;
233        let n2 = self.n2;
234        let n4 = self.n4;
235        let n8 = n4 / 2;
236
237        // Pre-FFT twiddling
238        let mut z: Vec<Complex<f32>> = Vec::with_capacity(n4);
239
240        for i in 0..n4 {
241            let even = spec[i * 2];
242            let odd = -spec[n2 - 1 - i * 2];
243
244            let w = &self.twiddle[i];
245            z.push(Complex::new(
246                odd * w.im - even * w.re,
247                odd * w.re + even * w.im,
248            ));
249        }
250
251        // Apply forward FFT
252        self.fft.process(&mut z);
253
254        // Post-FFT twiddling and unfolding
255        let mut output = vec![0.0; n];
256        let scale = 2.0 / n2 as f32;
257
258        // First half of FFT output
259        for i in 0..n8 {
260            let w = &self.twiddle[i];
261            let val_re = w.re * z[i].re + w.im * z[i].im;
262            let val_im = w.im * z[i].re - w.re * z[i].im;
263
264            let fi = 2 * i;
265            let ri = n4 - 1 - 2 * i;
266
267            output[ri] = -val_im * scale * self.window[ri];
268            output[n4 + fi] = val_im * scale * self.window[n4 + fi];
269            output[n2 + ri] = val_re * scale * self.window[n2 + ri];
270            output[n2 + n4 + fi] = val_re * scale * self.window[n2 + n4 + fi];
271        }
272
273        // Second half of FFT output
274        for i in 0..n8 {
275            let idx = n8 + i;
276            let w = &self.twiddle[idx];
277            let val_re = w.re * z[idx].re + w.im * z[idx].im;
278            let val_im = w.im * z[idx].re - w.re * z[idx].im;
279
280            let fi = 2 * i;
281            let ri = n4 - 1 - 2 * i;
282
283            output[fi] = -val_re * scale * self.window[fi];
284            output[n4 + ri] = val_re * scale * self.window[n4 + ri];
285            output[n2 + fi] = val_im * scale * self.window[n2 + fi];
286            output[n2 + n4 + ri] = val_im * scale * self.window[n2 + n4 + ri];
287        }
288
289        output
290    }
291}
292
293/// MDCT processor with pre-computed windows and FFT plans
294///
295/// Provides O(N log N) MDCT/IMDCT transforms using FFT acceleration.
296pub struct Mdct {
297    /// Long block transform (2048 samples)
298    long_transform: MdctTransform,
299    /// Short block transform (256 samples)
300    short_transform: MdctTransform,
301    /// Previous frame's windowed samples for overlap-add (per channel)
302    overlap_buffer: Vec<Vec<f32>>,
303    /// Number of channels
304    channels: usize,
305}
306
307impl Mdct {
308    /// Create a new MDCT processor
309    pub fn new(channels: usize, window_type: WindowType) -> Self {
310        let long_transform = MdctTransform::new(2048, window_type);
311        let short_transform = MdctTransform::new(256, window_type);
312
313        // Initialize overlap buffers (N/2 samples per channel for long blocks)
314        let overlap_buffer = vec![vec![0.0f32; 1024]; channels];
315
316        Self {
317            long_transform,
318            short_transform,
319            overlap_buffer,
320            channels,
321        }
322    }
323
324    /// Sine window: w[n] = sin(π(n+0.5)/N)
325    pub fn sine_window(n: usize) -> Vec<f32> {
326        MdctTransform::sine_window(n)
327    }
328
329    /// Vorbis window: sin(π/2 * sin²(π(n+0.5)/N))
330    pub fn vorbis_window(n: usize) -> Vec<f32> {
331        MdctTransform::vorbis_window(n)
332    }
333
334    /// Forward MDCT: N time samples → N/2 frequency coefficients
335    ///
336    /// X[k] = Σ x[n] * w[n] * cos(π/N * (n + 0.5 + N/2) * (k + 0.5))
337    pub fn forward(&self, samples: &[f32], block_size: BlockSize) -> Vec<f32> {
338        let n = block_size.samples();
339        assert!(samples.len() >= n, "Not enough samples for MDCT");
340
341        let transform = match block_size {
342            BlockSize::Long | BlockSize::Start | BlockSize::Stop => &self.long_transform,
343            BlockSize::Short => &self.short_transform,
344        };
345
346        transform.forward(&samples[..n])
347    }
348
349    /// Inverse MDCT: N/2 frequency coefficients → N time samples
350    ///
351    /// y[n] = 2/N * Σ(k=0 to N-1) X[k] * cos(π/N * (n + 0.5 + N/2) * (k + 0.5))
352    pub fn inverse(&self, coeffs: &[f32], block_size: BlockSize) -> Vec<f32> {
353        let n2 = block_size.coefficients();
354        assert!(coeffs.len() >= n2, "Not enough coefficients for IMDCT");
355
356        let transform = match block_size {
357            BlockSize::Long | BlockSize::Start | BlockSize::Stop => &self.long_transform,
358            BlockSize::Short => &self.short_transform,
359        };
360
361        transform.inverse(&coeffs[..n2])
362    }
363
364    /// Process a frame with overlap-add for perfect reconstruction
365    /// Returns N/2 output samples (the middle half after overlap-add)
366    pub fn process_frame(
367        &mut self,
368        samples: &[f32],
369        channel: usize,
370        block_size: BlockSize,
371    ) -> (Vec<f32>, Vec<f32>) {
372        let n = block_size.samples();
373        let n2 = n / 2;
374
375        // Forward MDCT
376        let coeffs = self.forward(samples, block_size);
377
378        // Inverse MDCT (for testing/verification)
379        let reconstructed = self.inverse(&coeffs, block_size);
380
381        // Overlap-add with previous frame
382        let mut output = vec![0.0f32; n2];
383        for i in 0..n2 {
384            output[i] = reconstructed[i] + self.overlap_buffer[channel][i];
385        }
386
387        // Store second half for next frame's overlap
388        self.overlap_buffer[channel].copy_from_slice(&reconstructed[n2..n2 + n2]);
389
390        (coeffs, output)
391    }
392
393    /// Reset overlap buffers (e.g., for seeking)
394    pub fn reset(&mut self) {
395        for buf in &mut self.overlap_buffer {
396            buf.fill(0.0);
397        }
398    }
399
400    /// Encode samples to MDCT coefficients for all channels
401    /// Input: interleaved samples [L, R, L, R, ...]
402    /// Output: MDCT coefficients per channel
403    pub fn analyze(&mut self, samples: &[f32], block_size: BlockSize) -> Vec<Vec<f32>> {
404        let n = block_size.samples();
405        let samples_per_channel = samples.len() / self.channels;
406
407        // Deinterleave
408        let mut channel_data: Vec<Vec<f32>> = (0..self.channels)
409            .map(|_| Vec::with_capacity(samples_per_channel))
410            .collect();
411
412        for (i, &s) in samples.iter().enumerate() {
413            channel_data[i % self.channels].push(s);
414        }
415
416        // MDCT each channel
417        let mut all_coeffs = Vec::with_capacity(self.channels);
418        for data in &channel_data {
419            if data.len() >= n {
420                let coeffs = self.forward(data, block_size);
421                all_coeffs.push(coeffs);
422            } else {
423                // Pad with zeros if not enough samples
424                let mut padded = data.clone();
425                padded.resize(n, 0.0);
426                let coeffs = self.forward(&padded, block_size);
427                all_coeffs.push(coeffs);
428            }
429        }
430
431        all_coeffs
432    }
433
434    /// Synthesize samples from MDCT coefficients with overlap-add
435    /// Input: MDCT coefficients per channel
436    /// Output: interleaved samples
437    pub fn synthesize(&mut self, coeffs: &[Vec<f32>], block_size: BlockSize) -> Vec<f32> {
438        let n = block_size.samples();
439        let n2 = n / 2;
440
441        // IMDCT + overlap-add for each channel
442        let mut channel_outputs: Vec<Vec<f32>> = Vec::with_capacity(self.channels);
443
444        for (ch, ch_coeffs) in coeffs.iter().enumerate() {
445            let reconstructed = self.inverse(ch_coeffs, block_size);
446
447            // Overlap-add
448            let mut output = vec![0.0f32; n2];
449            for i in 0..n2 {
450                output[i] = reconstructed[i] + self.overlap_buffer[ch][i];
451            }
452
453            // Store for next frame
454            self.overlap_buffer[ch].copy_from_slice(&reconstructed[n2..n2 + n2]);
455
456            channel_outputs.push(output);
457        }
458
459        // Interleave
460        let mut output = Vec::with_capacity(n2 * self.channels);
461        for i in 0..n2 {
462            for ch in 0..self.channels {
463                output.push(channel_outputs[ch][i]);
464            }
465        }
466
467        output
468    }
469}