1use std::f32::consts::PI;
9
/// Edge table for the CELT energy bands.
///
/// Band `b` spans `CELT_BANDS[b]..CELT_BANDS[b + 1]`, so the 22 edges describe
/// [`NUM_BANDS`] bands. NOTE(review): the values appear to match the `eband5ms`
/// table in libopus (`celt/modes.c`); confirm the unit (presumably MDCT bins at a
/// reference frame size) expected by callers.
pub const CELT_BANDS: [usize; 22] = [
    0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 14, 16, 20, 24, 28, 34, 40, 48, 60, 78, 100,
];

/// Number of energy bands described by [`CELT_BANDS`] (one fewer than the edges).
///
/// Derived from the edge table rather than hard-coded so the two constants can
/// never disagree.
const NUM_BANDS: usize = CELT_BANDS.len() - 1;
20
21#[derive(Debug, Clone, PartialEq, Eq)]
23pub struct CeltFrameConfig {
24 pub frame_size: usize,
26 pub channels: u8,
28 pub start_band: u8,
30 pub end_band: u8,
32}
33
34impl Default for CeltFrameConfig {
35 fn default() -> Self {
37 Self {
38 frame_size: 960,
39 channels: 2,
40 start_band: 0,
41 end_band: NUM_BANDS as u8,
42 }
43 }
44}
45
46impl CeltFrameConfig {
47 pub fn with_frame_size(mut self, frame_size: usize) -> Self {
49 self.frame_size = frame_size;
50 self
51 }
52
53 pub fn with_channels(mut self, channels: u8) -> Self {
55 self.channels = channels;
56 self
57 }
58
59 pub fn with_start_band(mut self, start_band: u8) -> Self {
61 self.start_band = start_band;
62 self
63 }
64
65 pub fn with_end_band(mut self, end_band: u8) -> Self {
67 self.end_band = end_band;
68 self
69 }
70}
71
72#[derive(Debug, Clone)]
77pub struct CeltEnergy {
78 pub bands: [f32; NUM_BANDS],
80}
81
82impl CeltEnergy {
83 pub fn new() -> Self {
85 Self {
86 bands: [0.0f32; NUM_BANDS],
87 }
88 }
89
90 pub fn energy(&self, band: usize) -> f32 {
94 if band < NUM_BANDS {
95 self.bands[band]
96 } else {
97 0.0
98 }
99 }
100
101 pub fn set_energy(&mut self, band: usize, val: f32) {
105 if band < NUM_BANDS {
106 self.bands[band] = val;
107 }
108 }
109}
110
111#[derive(Debug, Clone)]
113pub struct CeltFrame {
114 pub config: CeltFrameConfig,
116 pub energy: CeltEnergy,
118 pub collapsed_mask: u32,
120 pub samples: Vec<f32>,
122}
123
124impl CeltFrame {
125 pub fn new(config: CeltFrameConfig) -> Self {
127 let sample_count = config.frame_size;
128 Self {
129 config,
130 energy: CeltEnergy::new(),
131 collapsed_mask: 0,
132 samples: vec![0.0f32; sample_count],
133 }
134 }
135
136 pub fn sample_count(&self) -> usize {
138 self.config.frame_size
139 }
140}
141
142#[derive(Debug)]
149pub struct CeltDecoder {
150 pub config: CeltFrameConfig,
152 pub prev_energy: CeltEnergy,
154}
155
156impl CeltDecoder {
157 pub fn new(config: CeltFrameConfig) -> Self {
159 Self {
160 config,
161 prev_energy: CeltEnergy::new(),
162 }
163 }
164
165 pub fn decode_frame(&mut self, data: &[u8]) -> Result<CeltFrame, String> {
176 let mut frame = CeltFrame::new(self.config.clone());
177
178 let start = self.config.start_band as usize;
179 let end = self.config.end_band as usize;
180
181 for band in start..end {
182 let band_idx = band - start;
183 let energy = if band_idx < data.len() {
184 let raw = data[band_idx] as i8;
185 raw as f32 / 16.0
186 } else {
187 0.0
188 };
189 let predicted = self.prev_energy.energy(band);
191 let new_energy = predicted + energy;
192 frame.energy.set_energy(band, new_energy);
193 self.prev_energy.set_energy(band, new_energy);
194 }
195
196 Ok(frame)
197 }
198
199 pub fn apply_mdct_inverse(&self, coeffs: &[f32]) -> Vec<f32> {
212 let n = coeffs.len();
213 if n == 0 {
214 return Vec::new();
215 }
216
217 let scale = (2.0f32 / n as f32).sqrt();
218 let mut output = vec![0.0f32; n];
219
220 for (idx, out) in output.iter_mut().enumerate() {
221 let n_f = n as f32;
222 let nn = idx as f32 + 0.5 + n_f / 2.0;
223 let mut acc = 0.0f32;
224 for (k, &coeff) in coeffs.iter().enumerate() {
225 let kk = k as f32 + 0.5;
226 acc += coeff * (PI / n_f * nn * kk).cos();
227 }
228 *out = scale * acc;
229 }
230
231 output
232 }
233}
234
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_celt_bands_has_22_elements() {
        assert_eq!(CELT_BANDS.len(), 22, "21 bands require 22 edges");
    }

    #[test]
    fn test_celt_bands_starts_at_zero() {
        assert_eq!(CELT_BANDS[0], 0);
    }

    #[test]
    fn test_celt_bands_monotonically_increasing() {
        for i in 1..CELT_BANDS.len() {
            assert!(
                CELT_BANDS[i] > CELT_BANDS[i - 1],
                "CELT_BANDS[{}] = {} should be > CELT_BANDS[{}] = {}",
                i,
                CELT_BANDS[i],
                i - 1,
                CELT_BANDS[i - 1]
            );
        }
    }

    #[test]
    fn test_celt_frame_config_default() {
        let cfg = CeltFrameConfig::default();
        assert_eq!(cfg.frame_size, 960);
        assert_eq!(cfg.channels, 2);
        assert_eq!(cfg.start_band, 0);
        assert_eq!(cfg.end_band, 21);
    }

    #[test]
    fn test_celt_frame_config_builder() {
        let cfg = CeltFrameConfig::default()
            .with_frame_size(480)
            .with_channels(1)
            .with_start_band(2)
            .with_end_band(18);
        assert_eq!(cfg.frame_size, 480);
        assert_eq!(cfg.channels, 1);
        assert_eq!(cfg.start_band, 2);
        assert_eq!(cfg.end_band, 18);
    }

    #[test]
    fn test_celt_energy_new_all_zero() {
        let energy = CeltEnergy::new();
        for band in 0..NUM_BANDS {
            assert_eq!(energy.energy(band), 0.0);
        }
    }

    #[test]
    fn test_celt_energy_set_and_get() {
        let mut energy = CeltEnergy::new();
        // Avoid literals such as 3.14, which trip clippy's deny-by-default
        // `approx_constant` lint. Any exactly representable value works here.
        energy.set_energy(5, 2.5);
        assert!((energy.energy(5) - 2.5).abs() < 1e-6);
    }

    #[test]
    fn test_celt_energy_out_of_range() {
        let mut energy = CeltEnergy::new();
        energy.set_energy(100, 99.0);
        assert_eq!(energy.energy(100), 0.0);
    }

    #[test]
    fn test_celt_frame_sample_count() {
        let cfg = CeltFrameConfig::default().with_frame_size(480);
        let frame = CeltFrame::new(cfg);
        assert_eq!(frame.sample_count(), 480);
    }

    #[test]
    fn test_celt_frame_sample_count_960() {
        let cfg = CeltFrameConfig::default();
        let frame = CeltFrame::new(cfg);
        assert_eq!(frame.sample_count(), 960);
    }

    #[test]
    fn test_celt_decoder_new() {
        let cfg = CeltFrameConfig::default();
        let dec = CeltDecoder::new(cfg.clone());
        assert_eq!(dec.config, cfg);
    }

    #[test]
    fn test_celt_decoder_decode_frame_returns_correct_size() {
        let cfg = CeltFrameConfig::default().with_frame_size(480);
        let mut dec = CeltDecoder::new(cfg);
        let data = vec![0u8; 21];
        let frame = dec.decode_frame(&data).expect("should succeed");
        assert_eq!(frame.sample_count(), 480);
    }

    #[test]
    fn test_celt_decoder_decode_frame_zero_data_zero_energy() {
        let cfg = CeltFrameConfig::default();
        let mut dec = CeltDecoder::new(cfg);
        let data = vec![0u8; 21];
        let frame = dec.decode_frame(&data).expect("should succeed");
        for band in 0..NUM_BANDS {
            assert_eq!(frame.energy.energy(band), 0.0);
        }
    }

    #[test]
    fn test_celt_decoder_energy_accumulates_across_frames() {
        let mut dec = CeltDecoder::new(CeltFrameConfig::default());
        // Each byte 16 decodes to a +1.0 delta on top of the previous frame.
        let data = vec![16u8; NUM_BANDS];
        let first = dec.decode_frame(&data).expect("should succeed");
        let second = dec.decode_frame(&data).expect("should succeed");
        for band in 0..NUM_BANDS {
            assert_eq!(first.energy.energy(band), 1.0);
            assert_eq!(second.energy.energy(band), 2.0);
        }
    }

    #[test]
    fn test_celt_decoder_bytes_are_signed_deltas() {
        let mut dec = CeltDecoder::new(CeltFrameConfig::default());
        // 0xF0 is -16 as i8, i.e. a -1.0 delta. Missing bytes are zero deltas.
        let frame = dec.decode_frame(&[0xF0]).expect("should succeed");
        assert_eq!(frame.energy.energy(0), -1.0);
        for band in 1..NUM_BANDS {
            assert_eq!(frame.energy.energy(band), 0.0);
        }
    }

    #[test]
    fn test_celt_decoder_data_is_relative_to_start_band() {
        let cfg = CeltFrameConfig::default()
            .with_start_band(2)
            .with_end_band(5);
        let mut dec = CeltDecoder::new(cfg);
        let frame = dec.decode_frame(&[16, 32, 48]).expect("should succeed");
        let expected = [0.0, 0.0, 1.0, 2.0, 3.0, 0.0];
        for (band, &want) in expected.iter().enumerate() {
            assert_eq!(frame.energy.energy(band), want, "band {}", band);
        }
    }

    #[test]
    fn test_celt_decoder_apply_mdct_inverse_all_zero_input() {
        let cfg = CeltFrameConfig::default();
        let dec = CeltDecoder::new(cfg);
        let coeffs = vec![0.0f32; 16];
        let output = dec.apply_mdct_inverse(&coeffs);
        assert_eq!(output.len(), 16);
        for &sample in &output {
            assert!(sample.abs() < 1e-6, "expected zero, got {}", sample);
        }
    }

    #[test]
    fn test_celt_decoder_apply_mdct_inverse_empty_input() {
        let cfg = CeltFrameConfig::default();
        let dec = CeltDecoder::new(cfg);
        let output = dec.apply_mdct_inverse(&[]);
        assert!(output.is_empty());
    }

    #[test]
    fn test_celt_decoder_apply_mdct_inverse_nonzero() {
        let cfg = CeltFrameConfig::default();
        let dec = CeltDecoder::new(cfg);
        let mut coeffs = vec![0.0f32; 8];
        coeffs[0] = 1.0;
        let output = dec.apply_mdct_inverse(&coeffs);
        assert_eq!(output.len(), 8);
        let any_nonzero = output.iter().any(|&s| s.abs() > 1e-6);
        assert!(any_nonzero, "IMDCT of non-zero input must not be all-zero");
    }
}