libflo_audio/lossy/
mdct.rs1use rustfft::{num_complex::Complex, FftPlanner};
5use std::f32::consts::PI;
6use std::sync::Arc;
7
/// Window shape applied to each block by the MDCT.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WindowType {
    /// Sine window.
    Sine,
    /// Kaiser-Bessel-derived window.
    KaiserBesselDerived,
    /// Vorbis window.
    Vorbis,
}
18
/// Block type of a single MDCT frame.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BlockSize {
    /// 2048-sample block.
    Long,
    /// 256-sample block.
    Short,
    /// 2048-sample block, sized like `Long`
    /// (presumably a long-to-short transition block; TODO confirm).
    Start,
    /// 2048-sample block, sized like `Long`
    /// (presumably a short-to-long transition block; TODO confirm).
    Stop,
}

impl BlockSize {
    /// Number of time-domain samples spanned by one block of this type.
    pub fn samples(self) -> usize {
        match self {
            Self::Short => 256,
            Self::Long | Self::Start | Self::Stop => 2048,
        }
    }

    /// Number of spectral coefficients per block: half of [`Self::samples`].
    pub fn coefficients(self) -> usize {
        let window_len = self.samples();
        window_len / 2
    }
}
46
47struct MdctTransform {
49 n: usize,
51 n2: usize,
53 n4: usize,
55 window: Vec<f32>,
57 fft: Arc<dyn rustfft::Fft<f32>>,
59 twiddle: Vec<Complex<f32>>,
61}
62
63impl MdctTransform {
64 fn new(window_size: usize, window_type: WindowType) -> Self {
65 let n = window_size;
66 let n2 = n / 2;
67 let n4 = n / 4;
68
69 let window = match window_type {
71 WindowType::Sine => Self::sine_window(n),
72 WindowType::KaiserBesselDerived => Self::kbd_window(n, 4.0),
73 WindowType::Vorbis => Self::vorbis_window(n),
74 };
75
76 let mut planner = FftPlanner::new();
78 let fft = planner.plan_fft_forward(n4);
79
80 let twiddle: Vec<Complex<f32>> = (0..n4)
82 .map(|k| {
83 let theta = PI / n2 as f32 * (k as f32 + 0.125);
84 Complex::new(theta.cos(), theta.sin())
85 })
86 .collect();
87
88 Self {
89 n,
90 n2,
91 n4,
92 window,
93 fft,
94 twiddle,
95 }
96 }
97
/// Returns the `n`-point sine window `sin(PI * (i + 0.5) / n)`.
fn sine_window(n: usize) -> Vec<f32> {
    let mut window = Vec::with_capacity(n);
    for i in 0..n {
        let phase = PI * (i as f32 + 0.5) / n as f32;
        window.push(phase.sin());
    }
    window
}
104
/// Returns the `n`-point Vorbis window
/// `sin(PI / 2 * sin^2(PI * (i + 0.5) / n))`.
fn vorbis_window(n: usize) -> Vec<f32> {
    let mut window = Vec::with_capacity(n);
    for i in 0..n {
        let x = (PI * (i as f32 + 0.5) / n as f32).sin();
        let shaped = (PI / 2.0 * x * x).sin();
        window.push(shaped);
    }
    window
}
114
115 fn kbd_window(n: usize, alpha: f32) -> Vec<f32> {
117 let half = n / 2;
118
119 let kaiser: Vec<f32> = (0..=half)
121 .map(|i| {
122 Self::bessel_i0(
123 PI * alpha * (1.0 - (2.0 * i as f32 / half as f32 - 1.0).powi(2)).sqrt(),
124 )
125 })
126 .collect();
127
128 let mut cumsum = vec![0.0f32; half + 1];
130 cumsum[0] = kaiser[0];
131 for i in 1..=half {
132 cumsum[i] = cumsum[i - 1] + kaiser[i];
133 }
134 let total = cumsum[half];
135
136 let mut window = vec![0.0f32; n];
138 for i in 0..half {
139 window[i] = (cumsum[i] / total).sqrt();
140 window[n - 1 - i] = window[i];
141 }
142
143 window
144 }
145
/// Approximates the zeroth-order modified Bessel function of the first kind
/// with the power series `sum_k ((x / 2)^k / k!)^2`.
///
/// At most 19 terms beyond the leading `1` are added; summation stops early
/// once a term falls below `1e-10`.
fn bessel_i0(x: f32) -> f32 {
    let quarter_sq = x * x / 4.0;
    let mut term = 1.0f32;
    let mut sum = 1.0f32;

    let mut k = 1;
    while k < 20 {
        // term_k = term_{k-1} * (x^2 / 4) / k^2
        term *= quarter_sq / (k * k) as f32;
        sum += term;
        if term < 1e-10 {
            break;
        }
        k += 1;
    }

    sum
}
162
163 fn forward(&self, samples: &[f32]) -> Vec<f32> {
167 let n = self.n;
168 let n2 = self.n2;
169 let n4 = self.n4;
170 let n8 = n4 / 2;
171 let n3 = 3 * n4;
172
173 let x: Vec<f32> = samples
175 .iter()
176 .zip(self.window.iter())
177 .map(|(&s, &w)| s * w)
178 .collect();
179
180 let mut z: Vec<Complex<f32>> = vec![Complex::new(0.0, 0.0); n4];
182
183 for i in 0..n8 {
184 let re = -x[2 * i + n3] - x[n3 - 1 - 2 * i];
186 let im = -x[n4 + 2 * i] + x[n4 - 1 - 2 * i];
187
188 let w = &self.twiddle[i];
189 z[i] = Complex::new(-re * w.re - im * w.im, re * w.im - im * w.re);
190
191 let re2 = x[2 * i] - x[n2 - 1 - 2 * i];
193 let im2 = -x[n2 + 2 * i] - x[n - 1 - 2 * i];
194
195 let w2 = &self.twiddle[n8 + i];
196 z[n8 + i] = Complex::new(-re2 * w2.re - im2 * w2.im, re2 * w2.im - im2 * w2.re);
197 }
198
199 self.fft.process(&mut z);
201
202 let mut output = vec![0.0; n2];
204
205 for i in 0..n8 {
206 let idx1 = n8 - i - 1;
207 let idx2 = n8 + i;
208
209 let w1 = &self.twiddle[idx1];
210 let z1 = z[idx1];
211 let i1 = -z1.re * w1.im + z1.im * w1.re;
212 let r0 = -z1.re * w1.re - z1.im * w1.im;
213
214 let w2 = &self.twiddle[idx2];
215 let z2 = z[idx2];
216 let i0 = -z2.re * w2.im + z2.im * w2.re;
217 let r1 = -z2.re * w2.re - z2.im * w2.im;
218
219 output[2 * idx1] = r0;
220 output[2 * idx1 + 1] = i0;
221 output[2 * idx2] = r1;
222 output[2 * idx2 + 1] = i1;
223 }
224
225 output
226 }
227
228 fn inverse(&self, spec: &[f32]) -> Vec<f32> {
232 let n = self.n;
233 let n2 = self.n2;
234 let n4 = self.n4;
235 let n8 = n4 / 2;
236
237 let mut z: Vec<Complex<f32>> = Vec::with_capacity(n4);
239
240 for i in 0..n4 {
241 let even = spec[i * 2];
242 let odd = -spec[n2 - 1 - i * 2];
243
244 let w = &self.twiddle[i];
245 z.push(Complex::new(
246 odd * w.im - even * w.re,
247 odd * w.re + even * w.im,
248 ));
249 }
250
251 self.fft.process(&mut z);
253
254 let mut output = vec![0.0; n];
256 let scale = 2.0 / n2 as f32;
257
258 for i in 0..n8 {
260 let w = &self.twiddle[i];
261 let val_re = w.re * z[i].re + w.im * z[i].im;
262 let val_im = w.im * z[i].re - w.re * z[i].im;
263
264 let fi = 2 * i;
265 let ri = n4 - 1 - 2 * i;
266
267 output[ri] = -val_im * scale * self.window[ri];
268 output[n4 + fi] = val_im * scale * self.window[n4 + fi];
269 output[n2 + ri] = val_re * scale * self.window[n2 + ri];
270 output[n2 + n4 + fi] = val_re * scale * self.window[n2 + n4 + fi];
271 }
272
273 for i in 0..n8 {
275 let idx = n8 + i;
276 let w = &self.twiddle[idx];
277 let val_re = w.re * z[idx].re + w.im * z[idx].im;
278 let val_im = w.im * z[idx].re - w.re * z[idx].im;
279
280 let fi = 2 * i;
281 let ri = n4 - 1 - 2 * i;
282
283 output[fi] = -val_re * scale * self.window[fi];
284 output[n4 + ri] = val_re * scale * self.window[n4 + ri];
285 output[n2 + fi] = val_im * scale * self.window[n2 + fi];
286 output[n2 + n4 + ri] = val_im * scale * self.window[n2 + n4 + ri];
287 }
288
289 output
290 }
291}
292
293pub struct Mdct {
297 long_transform: MdctTransform,
299 short_transform: MdctTransform,
301 overlap_buffer: Vec<Vec<f32>>,
303 channels: usize,
305}
306
307impl Mdct {
308 pub fn new(channels: usize, window_type: WindowType) -> Self {
310 let long_transform = MdctTransform::new(2048, window_type);
311 let short_transform = MdctTransform::new(256, window_type);
312
313 let overlap_buffer = vec![vec![0.0f32; 1024]; channels];
315
316 Self {
317 long_transform,
318 short_transform,
319 overlap_buffer,
320 channels,
321 }
322 }
323
324 pub fn sine_window(n: usize) -> Vec<f32> {
326 MdctTransform::sine_window(n)
327 }
328
329 pub fn vorbis_window(n: usize) -> Vec<f32> {
331 MdctTransform::vorbis_window(n)
332 }
333
334 pub fn forward(&self, samples: &[f32], block_size: BlockSize) -> Vec<f32> {
338 let n = block_size.samples();
339 assert!(samples.len() >= n, "Not enough samples for MDCT");
340
341 let transform = match block_size {
342 BlockSize::Long | BlockSize::Start | BlockSize::Stop => &self.long_transform,
343 BlockSize::Short => &self.short_transform,
344 };
345
346 transform.forward(&samples[..n])
347 }
348
349 pub fn inverse(&self, coeffs: &[f32], block_size: BlockSize) -> Vec<f32> {
353 let n2 = block_size.coefficients();
354 assert!(coeffs.len() >= n2, "Not enough coefficients for IMDCT");
355
356 let transform = match block_size {
357 BlockSize::Long | BlockSize::Start | BlockSize::Stop => &self.long_transform,
358 BlockSize::Short => &self.short_transform,
359 };
360
361 transform.inverse(&coeffs[..n2])
362 }
363
364 pub fn process_frame(
367 &mut self,
368 samples: &[f32],
369 channel: usize,
370 block_size: BlockSize,
371 ) -> (Vec<f32>, Vec<f32>) {
372 let n = block_size.samples();
373 let n2 = n / 2;
374
375 let coeffs = self.forward(samples, block_size);
377
378 let reconstructed = self.inverse(&coeffs, block_size);
380
381 let mut output = vec![0.0f32; n2];
383 for i in 0..n2 {
384 output[i] = reconstructed[i] + self.overlap_buffer[channel][i];
385 }
386
387 self.overlap_buffer[channel].copy_from_slice(&reconstructed[n2..n2 + n2]);
389
390 (coeffs, output)
391 }
392
393 pub fn reset(&mut self) {
395 for buf in &mut self.overlap_buffer {
396 buf.fill(0.0);
397 }
398 }
399
400 pub fn analyze(&mut self, samples: &[f32], block_size: BlockSize) -> Vec<Vec<f32>> {
404 let n = block_size.samples();
405 let samples_per_channel = samples.len() / self.channels;
406
407 let mut channel_data: Vec<Vec<f32>> = (0..self.channels)
409 .map(|_| Vec::with_capacity(samples_per_channel))
410 .collect();
411
412 for (i, &s) in samples.iter().enumerate() {
413 channel_data[i % self.channels].push(s);
414 }
415
416 let mut all_coeffs = Vec::with_capacity(self.channels);
418 for data in &channel_data {
419 if data.len() >= n {
420 let coeffs = self.forward(data, block_size);
421 all_coeffs.push(coeffs);
422 } else {
423 let mut padded = data.clone();
425 padded.resize(n, 0.0);
426 let coeffs = self.forward(&padded, block_size);
427 all_coeffs.push(coeffs);
428 }
429 }
430
431 all_coeffs
432 }
433
434 pub fn synthesize(&mut self, coeffs: &[Vec<f32>], block_size: BlockSize) -> Vec<f32> {
438 let n = block_size.samples();
439 let n2 = n / 2;
440
441 let mut channel_outputs: Vec<Vec<f32>> = Vec::with_capacity(self.channels);
443
444 for (ch, ch_coeffs) in coeffs.iter().enumerate() {
445 let reconstructed = self.inverse(ch_coeffs, block_size);
446
447 let mut output = vec![0.0f32; n2];
449 for i in 0..n2 {
450 output[i] = reconstructed[i] + self.overlap_buffer[ch][i];
451 }
452
453 self.overlap_buffer[ch].copy_from_slice(&reconstructed[n2..n2 + n2]);
455
456 channel_outputs.push(output);
457 }
458
459 let mut output = Vec::with_capacity(n2 * self.channels);
461 for i in 0..n2 {
462 for ch in 0..self.channels {
463 output.push(channel_outputs[ch][i]);
464 }
465 }
466
467 output
468 }
469}