Skip to main content

oximedia_codec/
flac_codec.rs

1//! Simplified FLAC codec using fixed linear prediction + Rice coding.
2//!
3//! This module implements a lightweight FLAC-style encoder/decoder based on
4//! fixed predictors (order 0-4) rather than full LPC analysis. It is useful
5//! for educational purposes and lightweight lossless audio compression.
6//!
7//! ## Fixed predictor formulas
8//!
9//! | Order | Residual formula                                           |
10//! |-------|------------------------------------------------------------|
11//! | 0     | `r[i] = s[i]`                                              |
12//! | 1     | `r[i] = s[i] - s[i-1]`                                     |
13//! | 2     | `r[i] = s[i] - 2*s[i-1] + s[i-2]`                         |
14//! | 3     | `r[i] = s[i] - 3*s[i-1] + 3*s[i-2] - s[i-3]`             |
15//! | 4     | `r[i] = s[i] - 4*s[i-1] + 6*s[i-2] - 4*s[i-3] + s[i-4]` |
16//!
17//! ## Rice coding
18//!
19//! Residuals are zigzag-folded (signed → unsigned) then split into a unary
20//! quotient and binary remainder using a Rice parameter `k`.
21
22#![forbid(unsafe_code)]
23
24// =============================================================================
25// Stream info & frame header
26// =============================================================================
27
/// Stream-level metadata for a FLAC-style stream.
///
/// This is a plain data holder. [`decode_flac_frame`] accepts a reference to
/// it but does not currently read any of its fields.
#[derive(Debug, Clone)]
pub struct FlacStreamInfo {
    /// Minimum block size in samples.
    pub min_block_size: u16,
    /// Maximum block size in samples.
    pub max_block_size: u16,
    /// Sample rate in Hz.
    pub sample_rate: u32,
    /// Number of audio channels.
    pub channels: u8,
    /// Bits per sample (e.g. 16).
    pub bits_per_sample: u8,
    /// Total samples in stream (0 if unknown).
    pub total_samples: u64,
}
44
/// Per-frame header fields for a FLAC-style frame.
///
/// This is a plain data holder. The frame encode/decode functions visible in
/// this module serialise their own compact header (predictor order, Rice
/// parameter, warmup count) and do not read or write this struct.
#[derive(Debug, Clone)]
pub struct FlacFrameHeader {
    /// Block size (samples per channel) in this frame.
    pub block_size: u16,
    /// Sample rate in Hz.
    pub sample_rate: u32,
    /// Number of channels.
    pub channels: u8,
    /// Bits per sample.
    pub bits_per_sample: u8,
    /// Frame number (sequential).
    pub frame_number: u32,
}
59
60// =============================================================================
61// Encoder configuration
62// =============================================================================
63
/// Settings that control how [`encode_flac_frame`] compresses a block.
///
/// Of these fields, the frame encoder in this module only consults
/// `compression_level`; the rest describe the audio format for callers.
#[derive(Debug, Clone)]
pub struct FlacEncoderConfig {
    /// Sampling frequency, in Hz.
    pub sample_rate: u32,
    /// How many audio channels the stream carries.
    pub channels: u8,
    /// Sample bit depth (8, 16, 24).
    pub bits_per_sample: u8,
    /// Samples per channel in one block (default 4096).
    pub block_size: u16,
    /// Effort level 0-8; higher levels let the encoder try more predictor
    /// orders. Default 5.
    pub compression_level: u8,
}

impl Default for FlacEncoderConfig {
    /// 44.1 kHz, 2 channels, 16 bits per sample, 4096-sample blocks, level 5.
    fn default() -> Self {
        let (sample_rate, channels, bits_per_sample) = (44100, 2, 16);
        let (block_size, compression_level) = (4096, 5);
        Self {
            sample_rate,
            channels,
            bits_per_sample,
            block_size,
            compression_level,
        }
    }
}
90
91// =============================================================================
92// Rice coding
93// =============================================================================
94
/// Zigzag-encode a signed value to unsigned: `n >= 0 → 2n`, `n < 0 → 2|n|-1`.
///
/// The mapping is a bijection from `i32` onto `u32`, so small magnitudes of
/// either sign become small unsigned values (0, -1, 1, -2, … → 0, 1, 2, 3, …).
///
/// Implemented branch-free as `(v << 1) ^ (v >> 31)`: the arithmetic right
/// shift yields an all-ones mask for negative inputs, which flips the doubled
/// value into the odd codes. Unlike negating `v`, this cannot overflow, so
/// `i32::MIN` encodes to `u32::MAX` instead of panicking in debug builds.
#[inline]
fn zigzag_encode(v: i32) -> u32 {
    ((v << 1) ^ (v >> 31)) as u32
}
104
/// Invert the zigzag fold: even codes map to `u / 2`, odd codes to `-(u / 2) - 1`.
#[inline]
fn zigzag_decode(u: u32) -> i32 {
    let magnitude = (u >> 1) as i32;
    // 0 for even codes, -1 (all ones) for odd codes; XOR with all ones is
    // bitwise NOT, i.e. `-magnitude - 1`.
    let sign_mask = -((u & 1) as i32);
    magnitude ^ sign_mask
}
114
115/// Rice-encode a sequence of residuals with parameter `k`.
116///
117/// Each residual is zigzag-folded, then the quotient (`zigzag >> k`) is
118/// unary-coded (q zeros followed by a 1) and the remainder is binary-coded
119/// in `k` bits. Bits are packed MSB-first into bytes.
120fn rice_encode(residuals: &[i32], rice_param: u8) -> Vec<u8> {
121    let k = rice_param;
122    let mut bits: Vec<bool> = Vec::new();
123
124    for &r in residuals {
125        let u = zigzag_encode(r);
126        let quotient = u >> k;
127        let remainder = u & ((1u32 << k).wrapping_sub(1));
128
129        // Unary: quotient zeros then a 1
130        for _ in 0..quotient {
131            bits.push(false);
132        }
133        bits.push(true);
134
135        // Binary remainder (MSB first)
136        for bit_idx in (0..k).rev() {
137            bits.push((remainder >> bit_idx) & 1 != 0);
138        }
139    }
140
141    // Pack into bytes
142    let mut out = Vec::with_capacity((bits.len() + 7) / 8);
143    let mut byte = 0u8;
144    let mut fill = 0u8;
145    for bit in bits {
146        byte = (byte << 1) | u8::from(bit);
147        fill += 1;
148        if fill == 8 {
149            out.push(byte);
150            byte = 0;
151            fill = 0;
152        }
153    }
154    if fill > 0 {
155        out.push(byte << (8 - fill));
156    }
157    out
158}
159
160/// Rice-decode `count` residuals from `data` with parameter `k`.
161fn rice_decode(data: &[u8], count: usize, rice_param: u8) -> Result<Vec<i32>, String> {
162    let k = rice_param;
163    let mut byte_pos = 0usize;
164    let mut bit_pos = 0u8;
165
166    let read_bit = |bp: &mut usize, bi: &mut u8| -> Result<bool, String> {
167        if *bp >= data.len() {
168            return Err("Rice decode: unexpected end of data".to_string());
169        }
170        let bit = (data[*bp] >> (7 - *bi)) & 1 != 0;
171        *bi += 1;
172        if *bi == 8 {
173            *bp += 1;
174            *bi = 0;
175        }
176        Ok(bit)
177    };
178
179    let mut out = Vec::with_capacity(count);
180    for _ in 0..count {
181        // Read unary quotient: count zeros until a 1
182        let mut quotient = 0u32;
183        loop {
184            let bit = read_bit(&mut byte_pos, &mut bit_pos)?;
185            if bit {
186                break;
187            }
188            quotient += 1;
189            if quotient > 1_048_576 {
190                return Err("Rice decode: quotient overflow (corrupt data)".to_string());
191            }
192        }
193
194        // Read k-bit remainder
195        let mut remainder = 0u32;
196        for _ in 0..k {
197            let bit = read_bit(&mut byte_pos, &mut bit_pos)?;
198            remainder = (remainder << 1) | u32::from(bit);
199        }
200
201        let u = (quotient << k) | remainder;
202        out.push(zigzag_decode(u));
203    }
204
205    Ok(out)
206}
207
208// =============================================================================
209// Fixed linear prediction
210// =============================================================================
211
212/// Compute the optimal fixed predictor order (0-4) for a block of samples.
213///
214/// Tries each order, picks the one yielding the smallest sum-of-absolute
215/// residuals (SAR). `compression_level` limits the maximum order tested
216/// (level 0-1 → max order 1, 2-4 → max order 2, 5-8 → max order 4).
217fn optimal_predictor_order(samples: &[i16], compression_level: u8) -> u8 {
218    if samples.is_empty() {
219        return 0;
220    }
221    let max_order = match compression_level {
222        0..=1 => 1u8,
223        2..=4 => 2,
224        _ => 4,
225    };
226    let max_order = max_order.min(samples.len().saturating_sub(1) as u8).min(4);
227
228    let mut best_order = 0u8;
229    let mut best_cost = u64::MAX;
230
231    for order in 0..=max_order {
232        let residuals = fixed_predict(samples, order);
233        let cost: u64 = residuals.iter().map(|r| r.unsigned_abs() as u64).sum();
234        if cost < best_cost {
235            best_cost = cost;
236            best_order = order;
237        }
238    }
239    best_order
240}
241
/// Run the fixed predictor of the requested `order` over `samples` and return
/// the prediction residuals.
///
/// The leading `order` samples serve as warmup and produce no residual, so
/// the result has `samples.len() - order` entries (empty when the block is
/// not longer than `order`). Orders outside 1-4 pass the samples through
/// unchanged, like order 0.
fn fixed_predict(samples: &[i16], order: u8) -> Vec<i32> {
    let warmup = usize::from(order);
    if samples.len() <= warmup {
        return Vec::new();
    }

    // Widen on access so the differences cannot overflow i16.
    let at = |idx: usize| i32::from(samples[idx]);

    (warmup..samples.len())
        .map(|i| match order {
            1 => at(i) - at(i - 1),
            2 => at(i) - 2 * at(i - 1) + at(i - 2),
            3 => at(i) - 3 * at(i - 1) + 3 * at(i - 2) - at(i - 3),
            4 => at(i) - 4 * at(i - 1) + 6 * at(i - 2) - 4 * at(i - 3) + at(i - 4),
            _ => at(i),
        })
        .collect()
}
267
/// Undo fixed linear prediction — reconstruct samples from residuals + warmup.
///
/// The output starts with `warmup` and then appends one reconstructed sample
/// per residual, each computed from the residual and the most recent
/// `order` outputs (the inverse of the fixed predictor formulas). Orders
/// outside 1-4 copy the residuals through unchanged, like order 0.
///
/// This function never panics on malformed input:
/// * history that does not exist yet (warmup shorter than `order`) is treated
///   as zero instead of indexing out of bounds;
/// * arithmetic wraps, so residuals that no valid encoder would produce yield
///   wrapped sample values rather than an overflow panic in debug builds.
///
/// For well-formed input the wrapping arithmetic is exact. The final
/// narrowing to `i16` keeps the low 16 bits of each reconstructed value.
fn fixed_restore(residuals: &[i32], order: u8, warmup: &[i16]) -> Vec<i16> {
    let mut out: Vec<i32> = Vec::with_capacity(warmup.len() + residuals.len());
    out.extend(warmup.iter().map(|&v| i32::from(v)));

    for &r in residuals {
        let n = out.len();
        // `back(k)` is the k-th most recent output, or 0 if there is none.
        let back = |k: usize| n.checked_sub(k).map_or(0, |idx| out[idx]);

        let sample = match order {
            1 => r.wrapping_add(back(1)),
            2 => r
                .wrapping_add(back(1).wrapping_mul(2))
                .wrapping_sub(back(2)),
            3 => r
                .wrapping_add(back(1).wrapping_mul(3))
                .wrapping_sub(back(2).wrapping_mul(3))
                .wrapping_add(back(3)),
            4 => r
                .wrapping_add(back(1).wrapping_mul(4))
                .wrapping_sub(back(2).wrapping_mul(6))
                .wrapping_add(back(3).wrapping_mul(4))
                .wrapping_sub(back(4)),
            _ => r,
        };
        out.push(sample);
    }

    // Intentional truncation: keeps the low 16 bits of each value.
    out.iter().map(|&v| v as i16).collect()
}
288
289/// Select optimal Rice parameter for a set of residuals.
290fn optimal_rice_param(residuals: &[i32]) -> u8 {
291    if residuals.is_empty() {
292        return 0;
293    }
294    let mut best_k = 0u8;
295    let mut best_cost = u64::MAX;
296    for k in 0..=14u8 {
297        let cost: u64 = residuals
298            .iter()
299            .map(|&r| {
300                let u = zigzag_encode(r);
301                1u64 + u64::from(k) + u64::from(u >> k)
302            })
303            .sum();
304        if cost < best_cost {
305            best_cost = cost;
306            best_k = k;
307        }
308    }
309    best_k
310}
311
312// =============================================================================
313// Frame-level encode / decode
314// =============================================================================
315
316/// Encode i16 samples into a single FLAC-style frame.
317///
318/// Uses fixed linear prediction (order 0-4) + Rice coding.
319///
320/// Binary layout:
321/// ```text
322/// [1B order] [1B rice_param] [2B warmup_count]
323/// [warmup_count * 2B warmup samples (big-endian i16)]
324/// [4B residual_count (big-endian u32)]
325/// [4B rice_byte_len (big-endian u32)]
326/// [rice_byte_len bytes of Rice-coded residuals]
327/// ```
328pub fn encode_flac_frame(samples: &[i16], config: &FlacEncoderConfig) -> Vec<u8> {
329    if samples.is_empty() {
330        // Minimal empty frame: order=0, k=0, 0 warmup, 0 residuals, 0 rice bytes
331        return vec![0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0];
332    }
333
334    let order = optimal_predictor_order(samples, config.compression_level);
335    let residuals = fixed_predict(samples, order);
336    let k = optimal_rice_param(&residuals);
337    let rice_bytes = rice_encode(&residuals, k);
338
339    let warmup_count = (order as usize).min(samples.len());
340    let mut out = Vec::new();
341
342    // Header
343    out.push(order);
344    out.push(k);
345    let wc = warmup_count as u16;
346    out.extend_from_slice(&wc.to_be_bytes());
347
348    // Warmup samples
349    for &s in &samples[..warmup_count] {
350        out.extend_from_slice(&s.to_be_bytes());
351    }
352
353    // Residual count
354    let rc = residuals.len() as u32;
355    out.extend_from_slice(&rc.to_be_bytes());
356
357    // Rice data length + data
358    let rl = rice_bytes.len() as u32;
359    out.extend_from_slice(&rl.to_be_bytes());
360    out.extend_from_slice(&rice_bytes);
361
362    out
363}
364
365/// Decode a FLAC-style frame back to i16 samples.
366pub fn decode_flac_frame(data: &[u8], _info: &FlacStreamInfo) -> Result<Vec<i16>, String> {
367    if data.len() < 12 {
368        return Err("FLAC frame too short".to_string());
369    }
370
371    let order = data[0];
372    if order > 4 {
373        return Err(format!("Invalid predictor order: {order}"));
374    }
375    let k = data[1];
376    if k > 30 {
377        return Err(format!("Invalid Rice parameter: {k}"));
378    }
379    let warmup_count = u16::from_be_bytes([data[2], data[3]]) as usize;
380
381    let mut pos = 4;
382
383    // Read warmup
384    if pos + warmup_count * 2 > data.len() {
385        return Err("FLAC frame: warmup overruns data".to_string());
386    }
387    let mut warmup = Vec::with_capacity(warmup_count);
388    for _ in 0..warmup_count {
389        let s = i16::from_be_bytes([data[pos], data[pos + 1]]);
390        warmup.push(s);
391        pos += 2;
392    }
393
394    // Residual count
395    if pos + 4 > data.len() {
396        return Err("FLAC frame: missing residual count".to_string());
397    }
398    let residual_count =
399        u32::from_be_bytes([data[pos], data[pos + 1], data[pos + 2], data[pos + 3]]) as usize;
400    pos += 4;
401
402    // Rice data length
403    if pos + 4 > data.len() {
404        return Err("FLAC frame: missing rice length".to_string());
405    }
406    let rice_len =
407        u32::from_be_bytes([data[pos], data[pos + 1], data[pos + 2], data[pos + 3]]) as usize;
408    pos += 4;
409
410    if pos + rice_len > data.len() {
411        return Err("FLAC frame: rice data overruns frame".to_string());
412    }
413    let rice_data = &data[pos..pos + rice_len];
414
415    // Decode residuals
416    let residuals = if residual_count == 0 {
417        Vec::new()
418    } else {
419        rice_decode(rice_data, residual_count, k)?
420    };
421
422    // Restore signal
423    let samples = fixed_restore(&residuals, order, &warmup);
424    Ok(samples)
425}
426
427// =============================================================================
428// Tests
429// =============================================================================
430
// Unit tests for the prediction, Rice-coding and frame-level round trips.
#[cfg(test)]
mod tests {
    use super::*;

    // Order 0 performs no prediction: residuals equal the samples.
    #[test]
    fn test_fixed_predict_order0() {
        let samples: Vec<i16> = vec![10, 20, 30, 40, 50];
        let residuals = fixed_predict(&samples, 0);
        // Order 0: residual = sample itself
        let expected: Vec<i32> = vec![10, 20, 30, 40, 50];
        assert_eq!(residuals, expected, "Order 0 should be identity");
    }

    // First differences of a constant signal are all zero.
    #[test]
    fn test_fixed_predict_order1_constant() {
        let samples: Vec<i16> = vec![100, 100, 100, 100];
        let residuals = fixed_predict(&samples, 1);
        // s[i] - s[i-1] = 0 for constant
        assert!(
            residuals.iter().all(|&r| r == 0),
            "Constant signal should produce all-zero order-1 residuals: {:?}",
            residuals
        );
    }

    // predict followed by restore must reproduce the input for every order.
    #[test]
    fn test_fixed_predict_restore_roundtrip() {
        let samples: Vec<i16> = vec![10, -5, 300, -200, 0, 127, -128, 500];
        for order in 0..=4u8 {
            // Skip orders that would leave no residuals to restore from.
            if samples.len() <= order as usize {
                continue;
            }
            let residuals = fixed_predict(&samples, order);
            let warmup = &samples[..order as usize];
            let restored = fixed_restore(&residuals, order, warmup);
            assert_eq!(
                restored, samples,
                "Order {order} predict-restore roundtrip must be lossless"
            );
        }
    }

    // Rice encode/decode must be lossless for a range of parameters.
    #[test]
    fn test_rice_encode_decode_roundtrip() {
        let residuals = vec![0i32, 1, -1, 5, -5, 100, -100, 0];
        for k in 0..=6u8 {
            let encoded = rice_encode(&residuals, k);
            let decoded =
                rice_decode(&encoded, residuals.len(), k).expect("rice decode should succeed");
            assert_eq!(decoded, residuals, "Rice roundtrip failed for k={k}");
        }
    }

    // All-zero residuals select k=0 and cost exactly one bit each.
    #[test]
    fn test_rice_encode_zeros() {
        let zeros = vec![0i32; 64];
        let k = optimal_rice_param(&zeros);
        assert_eq!(k, 0, "All zeros should use k=0");
        let encoded = rice_encode(&zeros, k);
        // k=0: each zero is zigzag(0)=0, unary 0 = just a '1' bit → 1 bit each
        // 64 bits = 8 bytes
        assert_eq!(
            encoded.len(),
            8,
            "64 zero residuals at k=0 should be 8 bytes"
        );
    }

    // Every order has zero cost on silence; ties resolve to the lowest order.
    #[test]
    fn test_optimal_predictor_silence() {
        let silence: Vec<i16> = vec![0; 128];
        let order = optimal_predictor_order(&silence, 5);
        assert_eq!(order, 0, "Silence should pick order 0");
    }

    // A ramp is better predicted by any order >= 1 than by order 0.
    #[test]
    fn test_optimal_predictor_linear_ramp() {
        let ramp: Vec<i16> = (0..128).map(|i| i as i16).collect();
        let order = optimal_predictor_order(&ramp, 5);
        // A unit-slope ramp has constant order-1 residuals (all 1) and all-zero
        // residuals from order 2 upward, so any order >= 1 beats order 0.
        assert!(
            order >= 1,
            "Linear ramp should pick order >= 1, got {order}"
        );
    }

    // Full frame encode → decode must reproduce the samples exactly.
    #[test]
    fn test_encode_decode_frame_roundtrip() {
        let config = FlacEncoderConfig::default();
        let info = FlacStreamInfo {
            min_block_size: 4096,
            max_block_size: 4096,
            sample_rate: 44100,
            channels: 2,
            bits_per_sample: 16,
            total_samples: 0,
        };

        // Generate a test signal: ramp + sine
        let samples: Vec<i16> = (0..512)
            .map(|i| {
                let ramp = (i as f64 / 512.0 * 1000.0) as i16;
                let sine = (100.0 * (i as f64 * 0.1).sin()) as i16;
                ramp.saturating_add(sine)
            })
            .collect();

        let encoded = encode_flac_frame(&samples, &config);
        let decoded = decode_flac_frame(&encoded, &info).expect("decode should succeed");
        assert_eq!(
            decoded, samples,
            "Frame encode-decode roundtrip must be lossless"
        );
    }

    // Pins the documented default configuration values.
    #[test]
    fn test_flac_config_default() {
        let config = FlacEncoderConfig::default();
        assert_eq!(config.sample_rate, 44100);
        assert_eq!(config.channels, 2);
        assert_eq!(config.bits_per_sample, 16);
        assert_eq!(config.block_size, 4096);
        assert_eq!(config.compression_level, 5);
    }

    // An empty block still yields a header-only frame that decodes to nothing.
    #[test]
    fn test_encode_empty_block() {
        let config = FlacEncoderConfig::default();
        let info = FlacStreamInfo {
            min_block_size: 0,
            max_block_size: 0,
            sample_rate: 44100,
            channels: 1,
            bits_per_sample: 16,
            total_samples: 0,
        };

        let encoded = encode_flac_frame(&[], &config);
        assert!(
            !encoded.is_empty(),
            "Empty input should still produce a frame header"
        );
        let decoded = decode_flac_frame(&encoded, &info).expect("decode empty should succeed");
        assert!(
            decoded.is_empty(),
            "Decoded empty frame should have no samples"
        );
    }

    // Zigzag fold/unfold must round-trip across the i16 range and beyond.
    #[test]
    fn test_zigzag_roundtrip() {
        for v in [-1000i32, -1, 0, 1, 1000, i16::MIN as i32, i16::MAX as i32] {
            let u = zigzag_encode(v);
            let back = zigzag_decode(u);
            assert_eq!(back, v, "zigzag roundtrip failed for {v}");
        }
    }

    // Second differences of i^2 are the constant 2.
    #[test]
    fn test_fixed_predict_order2_quadratic() {
        // Quadratic: s[i] = i^2 → second differences are constant
        let samples: Vec<i16> = (0..20).map(|i: i16| i * i).collect();
        let residuals = fixed_predict(&samples, 2);
        // After warmup of 2, all second-order residuals for a quadratic should be constant (= 2)
        let all_two = residuals.iter().all(|&r| r == 2);
        assert!(
            all_two,
            "Quadratic signal order-2 residuals should all be 2: {:?}",
            residuals
        );
    }

    // A full-size block must both compress below raw size and round-trip.
    #[test]
    fn test_encode_decode_large_block() {
        let config = FlacEncoderConfig {
            compression_level: 8,
            ..FlacEncoderConfig::default()
        };
        let info = FlacStreamInfo {
            min_block_size: 4096,
            max_block_size: 4096,
            sample_rate: 44100,
            channels: 1,
            bits_per_sample: 16,
            total_samples: 4096,
        };
        let samples: Vec<i16> = (0..4096)
            .map(|i| (1000.0 * (i as f64 * 0.05).sin()) as i16)
            .collect();
        let encoded = encode_flac_frame(&samples, &config);
        // Compressed should be smaller than raw (4096 * 2 = 8192 bytes)
        assert!(
            encoded.len() < 8192,
            "Compressed frame ({} bytes) should be smaller than raw (8192)",
            encoded.len()
        );
        let decoded = decode_flac_frame(&encoded, &info).expect("decode");
        assert_eq!(decoded, samples);
    }
}