Skip to main content

oximedia_codec/
celt.rs

1//! Standalone CELT frame decoding types and scaffolding.
2//!
3//! CELT (Constrained Energy Lapped Transform) is the music/wideband codec
4//! used within Opus. This module provides lightweight frame-level types for
5//! energy bookkeeping and the MDCT-IV inverse transform, independent of the
6//! full Opus decoder pipeline found in `crate::opus::celt`.
7
8use std::f32::consts::PI;
9
/// CELT band edges (the Opus `eband5ms` table), in units of 200 Hz.
///
/// One unit is one MDCT bin of the shortest CELT frame (2.5 ms, 120 samples
/// at 48 kHz). There are 21 bands bounded by 22 edges: the first edge is
/// 0 Hz and the last (100) corresponds to 20 kHz.
///
/// For a frame of `frame_size` samples at 48 kHz, multiply each edge by
/// `frame_size / 120` to obtain MDCT bin indices. For example, a 20 ms,
/// 960-sample frame uses a factor of 8, so the last edge becomes bin 800.
pub const CELT_BANDS: [usize; 22] = [
    0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 14, 16, 20, 24, 28, 34, 40, 48, 60, 78, 100,
];

/// Number of CELT frequency bands (one fewer than the number of edges).
///
/// Derived from [`CELT_BANDS`] so the two constants cannot drift apart.
const NUM_BANDS: usize = CELT_BANDS.len() - 1;
20
21/// Configuration for a CELT frame decoder.
22#[derive(Debug, Clone, PartialEq, Eq)]
23pub struct CeltFrameConfig {
24    /// Frame size in samples (120, 240, 480, 960, or 1920).
25    pub frame_size: usize,
26    /// Number of audio channels (1 or 2).
27    pub channels: u8,
28    /// First band index (for hybrid Opus mode).
29    pub start_band: u8,
30    /// One-past-last band index.
31    pub end_band: u8,
32}
33
34impl Default for CeltFrameConfig {
35    /// Default configuration: 960-sample stereo frame covering all 21 bands.
36    fn default() -> Self {
37        Self {
38            frame_size: 960,
39            channels: 2,
40            start_band: 0,
41            end_band: NUM_BANDS as u8,
42        }
43    }
44}
45
46impl CeltFrameConfig {
47    /// Sets the frame size and returns `self` for chaining.
48    pub fn with_frame_size(mut self, frame_size: usize) -> Self {
49        self.frame_size = frame_size;
50        self
51    }
52
53    /// Sets the channel count and returns `self` for chaining.
54    pub fn with_channels(mut self, channels: u8) -> Self {
55        self.channels = channels;
56        self
57    }
58
59    /// Sets the start band and returns `self` for chaining.
60    pub fn with_start_band(mut self, start_band: u8) -> Self {
61        self.start_band = start_band;
62        self
63    }
64
65    /// Sets the end band and returns `self` for chaining.
66    pub fn with_end_band(mut self, end_band: u8) -> Self {
67        self.end_band = end_band;
68        self
69    }
70}
71
72/// Per-band log-domain energy for a CELT frame.
73///
74/// All 21 CELT bands are tracked. Energy values are in the log domain
75/// (natural logarithm of linear energy) as used internally by CELT.
76#[derive(Debug, Clone)]
77pub struct CeltEnergy {
78    /// Per-band energy values (log domain), indexed 0..21.
79    pub bands: [f32; NUM_BANDS],
80}
81
82impl CeltEnergy {
83    /// Creates a `CeltEnergy` with all bands initialised to zero.
84    pub fn new() -> Self {
85        Self {
86            bands: [0.0f32; NUM_BANDS],
87        }
88    }
89
90    /// Returns the energy for the given band index.
91    ///
92    /// Returns `0.0` for out-of-range indices.
93    pub fn energy(&self, band: usize) -> f32 {
94        if band < NUM_BANDS {
95            self.bands[band]
96        } else {
97            0.0
98        }
99    }
100
101    /// Sets the energy for the given band index.
102    ///
103    /// Out-of-range indices are silently ignored.
104    pub fn set_energy(&mut self, band: usize, val: f32) {
105        if band < NUM_BANDS {
106            self.bands[band] = val;
107        }
108    }
109}
110
111/// A decoded CELT frame.
112#[derive(Debug, Clone)]
113pub struct CeltFrame {
114    /// Frame configuration.
115    pub config: CeltFrameConfig,
116    /// Per-band energy decoded from the bitstream.
117    pub energy: CeltEnergy,
118    /// Bitmask of collapsed (zeroed) bands.
119    pub collapsed_mask: u32,
120    /// Decoded output samples.
121    pub samples: Vec<f32>,
122}
123
124impl CeltFrame {
125    /// Creates a new zeroed `CeltFrame` for the given configuration.
126    pub fn new(config: CeltFrameConfig) -> Self {
127        let sample_count = config.frame_size;
128        Self {
129            config,
130            energy: CeltEnergy::new(),
131            collapsed_mask: 0,
132            samples: vec![0.0f32; sample_count],
133        }
134    }
135
136    /// Returns the number of samples in this frame (per channel).
137    pub fn sample_count(&self) -> usize {
138        self.config.frame_size
139    }
140}
141
142/// CELT frame decoder scaffold.
143///
144/// Provides energy decoding and the MDCT-IV inverse transform. Full
145/// entropy-coded CELT decoding (PVQ, fine energy, band prediction …) is
146/// provided by `crate::opus::celt`; this type is intentionally lightweight
147/// and intended for testing and scaffolding.
148#[derive(Debug)]
149pub struct CeltDecoder {
150    /// Frame configuration.
151    pub config: CeltFrameConfig,
152    /// Per-band energy state carried across frames.
153    pub prev_energy: CeltEnergy,
154}
155
156impl CeltDecoder {
157    /// Creates a new `CeltDecoder` for the given configuration.
158    pub fn new(config: CeltFrameConfig) -> Self {
159        Self {
160            config,
161            prev_energy: CeltEnergy::new(),
162        }
163    }
164
165    /// Parses energy values from the first bytes of `data` and returns a
166    /// `CeltFrame` with zeroed samples.
167    ///
168    /// Each active band contributes one byte to the energy encoding: the byte
169    /// is interpreted as a signed i8 and scaled by 1/16 to produce a
170    /// log-domain energy value. Bands beyond the end of `data` default to
171    /// `0.0`.
172    ///
173    /// Full PVQ coefficient decoding is not performed; this method is a
174    /// scaffold.
175    pub fn decode_frame(&mut self, data: &[u8]) -> Result<CeltFrame, String> {
176        let mut frame = CeltFrame::new(self.config.clone());
177
178        let start = self.config.start_band as usize;
179        let end = self.config.end_band as usize;
180
181        for band in start..end {
182            let band_idx = band - start;
183            let energy = if band_idx < data.len() {
184                let raw = data[band_idx] as i8;
185                raw as f32 / 16.0
186            } else {
187                0.0
188            };
189            // Delta from previous frame (simple inter-frame prediction).
190            let predicted = self.prev_energy.energy(band);
191            let new_energy = predicted + energy;
192            frame.energy.set_energy(band, new_energy);
193            self.prev_energy.set_energy(band, new_energy);
194        }
195
196        Ok(frame)
197    }
198
199    /// Computes the Type-IV MDCT inverse transform (IMDCT-IV).
200    ///
201    /// Given `N = coeffs.len()` spectral coefficients `X[k]`, the output
202    /// time-domain samples are:
203    ///
204    /// ```text
205    /// x[n] = sqrt(2/N) * sum_{k=0}^{N-1} X[k] * cos(π/N * (n + 0.5 + N/2) * (k + 0.5))
206    /// ```
207    ///
208    /// for `n = 0 .. N-1`.
209    ///
210    /// Returns an empty vector if `coeffs` is empty.
211    pub fn apply_mdct_inverse(&self, coeffs: &[f32]) -> Vec<f32> {
212        let n = coeffs.len();
213        if n == 0 {
214            return Vec::new();
215        }
216
217        let scale = (2.0f32 / n as f32).sqrt();
218        let mut output = vec![0.0f32; n];
219
220        for (idx, out) in output.iter_mut().enumerate() {
221            let n_f = n as f32;
222            let nn = idx as f32 + 0.5 + n_f / 2.0;
223            let mut acc = 0.0f32;
224            for (k, &coeff) in coeffs.iter().enumerate() {
225                let kk = k as f32 + 0.5;
226                acc += coeff * (PI / n_f * nn * kk).cos();
227            }
228            *out = scale * acc;
229        }
230
231        output
232    }
233}
234
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_celt_bands_has_22_elements() {
        assert_eq!(CELT_BANDS.len(), 22, "21 bands require 22 edges");
    }

    #[test]
    fn test_celt_bands_starts_at_zero() {
        assert_eq!(CELT_BANDS[0], 0);
    }

    #[test]
    fn test_celt_bands_ends_at_100() {
        assert_eq!(CELT_BANDS[NUM_BANDS], 100);
    }

    #[test]
    fn test_celt_bands_monotonically_increasing() {
        for (i, pair) in CELT_BANDS.windows(2).enumerate() {
            assert!(
                pair[1] > pair[0],
                "CELT_BANDS[{}] = {} should be > CELT_BANDS[{}] = {}",
                i + 1,
                pair[1],
                i,
                pair[0]
            );
        }
    }

    #[test]
    fn test_celt_frame_config_default() {
        let cfg = CeltFrameConfig::default();
        assert_eq!(cfg.frame_size, 960);
        assert_eq!(cfg.channels, 2);
        assert_eq!(cfg.start_band, 0);
        assert_eq!(cfg.end_band, 21);
    }

    #[test]
    fn test_celt_frame_config_builder() {
        let cfg = CeltFrameConfig::default()
            .with_frame_size(480)
            .with_channels(1)
            .with_start_band(2)
            .with_end_band(18);
        assert_eq!(cfg.frame_size, 480);
        assert_eq!(cfg.channels, 1);
        assert_eq!(cfg.start_band, 2);
        assert_eq!(cfg.end_band, 18);
    }

    #[test]
    fn test_celt_energy_new_all_zero() {
        let energy = CeltEnergy::new();
        for band in 0..NUM_BANDS {
            assert_eq!(energy.energy(band), 0.0);
        }
    }

    #[test]
    fn test_celt_energy_set_and_get() {
        let mut energy = CeltEnergy::new();
        // Not 3.14: clippy's deny-by-default `approx_constant` lint rejects it.
        energy.set_energy(5, 2.5);
        assert!((energy.energy(5) - 2.5).abs() < 1e-6);
    }

    #[test]
    fn test_celt_energy_out_of_range() {
        let mut energy = CeltEnergy::new();
        // Out-of-range set should not panic.
        energy.set_energy(100, 99.0);
        // Out-of-range get should return 0.
        assert_eq!(energy.energy(100), 0.0);
    }

    #[test]
    fn test_celt_frame_sample_count() {
        let cfg = CeltFrameConfig::default().with_frame_size(480);
        let frame = CeltFrame::new(cfg);
        assert_eq!(frame.sample_count(), 480);
    }

    #[test]
    fn test_celt_frame_sample_count_960() {
        let cfg = CeltFrameConfig::default();
        let frame = CeltFrame::new(cfg);
        assert_eq!(frame.sample_count(), 960);
    }

    #[test]
    fn test_celt_decoder_new() {
        let cfg = CeltFrameConfig::default();
        let dec = CeltDecoder::new(cfg.clone());
        assert_eq!(dec.config, cfg);
    }

    #[test]
    fn test_celt_decoder_decode_frame_returns_correct_size() {
        let cfg = CeltFrameConfig::default().with_frame_size(480);
        let mut dec = CeltDecoder::new(cfg);
        let data = vec![0u8; 21];
        let frame = dec.decode_frame(&data).expect("should succeed");
        assert_eq!(frame.sample_count(), 480);
    }

    #[test]
    fn test_celt_decoder_decode_frame_zero_data_zero_energy() {
        let cfg = CeltFrameConfig::default();
        let mut dec = CeltDecoder::new(cfg);
        let data = vec![0u8; 21];
        let frame = dec.decode_frame(&data).expect("should succeed");
        for band in 0..NUM_BANDS {
            assert_eq!(frame.energy.energy(band), 0.0);
        }
    }

    #[test]
    fn test_celt_decoder_decode_frame_signed_byte_scaling() {
        let mut dec = CeltDecoder::new(CeltFrameConfig::default());
        let frame = dec
            .decode_frame(&[0x10, 0xF0, 0x80])
            .expect("should succeed");
        assert_eq!(frame.energy.energy(0), 1.0);
        assert_eq!(frame.energy.energy(1), -1.0);
        assert_eq!(frame.energy.energy(2), -8.0);
        // Bands without a byte keep the previous (zero) energy.
        assert_eq!(frame.energy.energy(3), 0.0);
    }

    #[test]
    fn test_celt_decoder_decode_frame_accumulates_across_frames() {
        let mut dec = CeltDecoder::new(CeltFrameConfig::default());
        let _ = dec.decode_frame(&[0x10]).expect("first frame");
        let frame = dec.decode_frame(&[0x10]).expect("second frame");
        assert_eq!(frame.energy.energy(0), 2.0);
        assert_eq!(dec.prev_energy.energy(0), 2.0);
    }

    #[test]
    fn test_celt_decoder_decode_frame_honours_start_band() {
        let cfg = CeltFrameConfig::default()
            .with_start_band(2)
            .with_end_band(4);
        let mut dec = CeltDecoder::new(cfg);
        let frame = dec
            .decode_frame(&[0x20, 0x30, 0x40])
            .expect("should succeed");
        assert_eq!(frame.energy.energy(1), 0.0);
        assert_eq!(frame.energy.energy(2), 2.0);
        assert_eq!(frame.energy.energy(3), 3.0);
        assert_eq!(frame.energy.energy(4), 0.0);
    }

    #[test]
    fn test_celt_decoder_apply_mdct_inverse_all_zero_input() {
        let cfg = CeltFrameConfig::default();
        let dec = CeltDecoder::new(cfg);
        // All-zero coefficients must produce all-zero output.
        let coeffs = vec![0.0f32; 16];
        let output = dec.apply_mdct_inverse(&coeffs);
        assert_eq!(output.len(), 16);
        for &sample in &output {
            assert!(sample.abs() < 1e-6, "expected zero, got {}", sample);
        }
    }

    #[test]
    fn test_celt_decoder_apply_mdct_inverse_empty_input() {
        let cfg = CeltFrameConfig::default();
        let dec = CeltDecoder::new(cfg);
        let output = dec.apply_mdct_inverse(&[]);
        assert!(output.is_empty());
    }

    #[test]
    fn test_celt_decoder_apply_mdct_inverse_nonzero() {
        let cfg = CeltFrameConfig::default();
        let dec = CeltDecoder::new(cfg);
        // Single non-zero DC coefficient should yield non-zero output.
        let mut coeffs = vec![0.0f32; 8];
        coeffs[0] = 1.0;
        let output = dec.apply_mdct_inverse(&coeffs);
        assert_eq!(output.len(), 8);
        let any_nonzero = output.iter().any(|&s| s.abs() > 1e-6);
        assert!(any_nonzero, "IMDCT of non-zero input must not be all-zero");
    }

    #[test]
    fn test_celt_decoder_apply_mdct_inverse_matches_reference_formula() {
        let dec = CeltDecoder::new(CeltFrameConfig::default());
        let coeffs: Vec<f32> = (0..16)
            .map(|k| ((k * 7 % 5) as f32 - 2.0) * 0.25)
            .collect();
        let output = dec.apply_mdct_inverse(&coeffs);
        assert_eq!(output.len(), coeffs.len());

        // Evaluate the documented formula directly in f64.
        let nf = coeffs.len() as f64;
        for (i, &x) in output.iter().enumerate() {
            let sum: f64 = coeffs
                .iter()
                .enumerate()
                .map(|(k, &c)| {
                    f64::from(c)
                        * (std::f64::consts::PI / nf
                            * (i as f64 + 0.5 + nf / 2.0)
                            * (k as f64 + 0.5))
                            .cos()
                })
                .sum();
            let expected = (2.0 / nf).sqrt() * sum;
            assert!(
                (f64::from(x) - expected).abs() < 1e-4,
                "sample {}: got {}, expected {}",
                i,
                x,
                expected
            );
        }
    }
}