1#![forbid(unsafe_code)]
23
/// Stream-level parameters describing an encoded audio stream.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FlacStreamInfo {
    /// Smallest block (frame) size used in the stream.
    pub min_block_size: u16,
    /// Largest block (frame) size used in the stream.
    pub max_block_size: u16,
    /// Sample rate in Hz (e.g. `44100`).
    pub sample_rate: u32,
    /// Number of audio channels.
    pub channels: u8,
    /// Bits per sample (e.g. `16`).
    pub bits_per_sample: u8,
    /// Total sample count of the stream.
    /// NOTE(review): presumably per channel, with `0` meaning "unknown" — confirm.
    pub total_samples: u64,
}
44
/// Per-frame parameters of an encoded frame.
///
/// NOTE(review): not produced or consumed by the encode/decode functions in this
/// file; kept for callers that track frame metadata separately.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FlacFrameHeader {
    /// Number of samples in this frame's block.
    pub block_size: u16,
    /// Sample rate in Hz.
    pub sample_rate: u32,
    /// Number of audio channels.
    pub channels: u8,
    /// Bits per sample.
    pub bits_per_sample: u8,
    /// Sequence number of this frame within the stream.
    pub frame_number: u32,
}
59
/// Encoder settings passed to `encode_flac_frame`.
///
/// NOTE(review): in this file only `compression_level` is read by the encoder;
/// the remaining fields describe the stream but do not affect encoding here.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FlacEncoderConfig {
    /// Sample rate in Hz (default `44100`).
    pub sample_rate: u32,
    /// Number of audio channels (default `2`).
    pub channels: u8,
    /// Bits per sample (default `16`).
    pub bits_per_sample: u8,
    /// Intended block size (default `4096`); not enforced by the encoder.
    pub block_size: u16,
    /// Caps the fixed-predictor search: `0..=1` allows orders up to 1, `2..=4` up
    /// to 2, and anything higher up to 4 (default `5`).
    pub compression_level: u8,
}
78
79impl Default for FlacEncoderConfig {
80 fn default() -> Self {
81 Self {
82 sample_rate: 44100,
83 channels: 2,
84 bits_per_sample: 16,
85 block_size: 4096,
86 compression_level: 5,
87 }
88 }
89}
90
/// Maps a signed residual onto an unsigned code so small magnitudes stay small:
/// `0, -1, 1, -2, 2, ...` become `0, 1, 2, 3, 4, ...`.
///
/// Total over all of `i32`, including `i32::MIN` (maps to `u32::MAX`).
#[inline]
fn zigzag_encode(v: i32) -> u32 {
    // `v >> 31` (arithmetic) is all ones for negatives and zero otherwise, so the
    // XOR flips the doubled value for negatives. Unlike `-v - 1`, this cannot
    // overflow for `i32::MIN`.
    ((v << 1) ^ (v >> 31)) as u32
}
104
/// Inverse of the zigzag mapping: `0, 1, 2, 3, ...` become `0, -1, 1, -2, ...`.
#[inline]
fn zigzag_decode(u: u32) -> i32 {
    // The low bit carries the sign. XOR with 0 keeps the magnitude; XOR with all
    // ones (bitwise NOT) turns `m` into `-m - 1`.
    let magnitude = (u >> 1) as i32;
    let sign_mask = -((u & 1) as i32);
    magnitude ^ sign_mask
}
114
115fn rice_encode(residuals: &[i32], rice_param: u8) -> Vec<u8> {
121 let k = rice_param;
122 let mut bits: Vec<bool> = Vec::new();
123
124 for &r in residuals {
125 let u = zigzag_encode(r);
126 let quotient = u >> k;
127 let remainder = u & ((1u32 << k).wrapping_sub(1));
128
129 for _ in 0..quotient {
131 bits.push(false);
132 }
133 bits.push(true);
134
135 for bit_idx in (0..k).rev() {
137 bits.push((remainder >> bit_idx) & 1 != 0);
138 }
139 }
140
141 let mut out = Vec::with_capacity((bits.len() + 7) / 8);
143 let mut byte = 0u8;
144 let mut fill = 0u8;
145 for bit in bits {
146 byte = (byte << 1) | u8::from(bit);
147 fill += 1;
148 if fill == 8 {
149 out.push(byte);
150 byte = 0;
151 fill = 0;
152 }
153 }
154 if fill > 0 {
155 out.push(byte << (8 - fill));
156 }
157 out
158}
159
160fn rice_decode(data: &[u8], count: usize, rice_param: u8) -> Result<Vec<i32>, String> {
162 let k = rice_param;
163 let mut byte_pos = 0usize;
164 let mut bit_pos = 0u8;
165
166 let read_bit = |bp: &mut usize, bi: &mut u8| -> Result<bool, String> {
167 if *bp >= data.len() {
168 return Err("Rice decode: unexpected end of data".to_string());
169 }
170 let bit = (data[*bp] >> (7 - *bi)) & 1 != 0;
171 *bi += 1;
172 if *bi == 8 {
173 *bp += 1;
174 *bi = 0;
175 }
176 Ok(bit)
177 };
178
179 let mut out = Vec::with_capacity(count);
180 for _ in 0..count {
181 let mut quotient = 0u32;
183 loop {
184 let bit = read_bit(&mut byte_pos, &mut bit_pos)?;
185 if bit {
186 break;
187 }
188 quotient += 1;
189 if quotient > 1_048_576 {
190 return Err("Rice decode: quotient overflow (corrupt data)".to_string());
191 }
192 }
193
194 let mut remainder = 0u32;
196 for _ in 0..k {
197 let bit = read_bit(&mut byte_pos, &mut bit_pos)?;
198 remainder = (remainder << 1) | u32::from(bit);
199 }
200
201 let u = (quotient << k) | remainder;
202 out.push(zigzag_decode(u));
203 }
204
205 Ok(out)
206}
207
208fn optimal_predictor_order(samples: &[i16], compression_level: u8) -> u8 {
218 if samples.is_empty() {
219 return 0;
220 }
221 let max_order = match compression_level {
222 0..=1 => 1u8,
223 2..=4 => 2,
224 _ => 4,
225 };
226 let max_order = max_order.min(samples.len().saturating_sub(1) as u8).min(4);
227
228 let mut best_order = 0u8;
229 let mut best_cost = u64::MAX;
230
231 for order in 0..=max_order {
232 let residuals = fixed_predict(samples, order);
233 let cost: u64 = residuals.iter().map(|r| r.unsigned_abs() as u64).sum();
234 if cost < best_cost {
235 best_cost = cost;
236 best_order = order;
237 }
238 }
239 best_order
240}
241
/// Computes fixed polynomial-predictor residuals of the given `order`.
///
/// One residual is produced for each sample at index `order..`. The first `order`
/// samples serve as warm-up and get no residual. Orders 1–4 subtract the standard
/// finite-difference predictions. Order 0 and any order above 4 return the sample
/// itself. If there are no samples past the warm-up, the result is empty.
fn fixed_predict(samples: &[i16], order: u8) -> Vec<i32> {
    let span = usize::from(order);
    if samples.len() <= span {
        return Vec::new();
    }

    // Each window ends at the sample being predicted and reaches back `span` samples.
    samples
        .windows(span + 1)
        .map(|w| {
            let cur = i32::from(w[span]);
            let prev = |back: usize| i32::from(w[span - back]);
            match order {
                0 => cur,
                1 => cur - prev(1),
                2 => cur - 2 * prev(1) + prev(2),
                3 => cur - 3 * prev(1) + 3 * prev(2) - prev(3),
                4 => cur - 4 * prev(1) + 6 * prev(2) - 4 * prev(3) + prev(4),
                _ => cur,
            }
        })
        .collect()
}
267
/// Reverses `fixed_predict`: rebuilds samples from `warmup` followed by `residuals`.
///
/// The output is `warmup` followed by one restored sample per residual. Orders 1–4
/// add the matching finite-difference prediction. Order 0 and any order above 4
/// use the residual as the sample.
///
/// Arithmetic wraps, and the final narrowing to `i16` truncates. Valid streams
/// never overflow, so they round-trip exactly. Corrupt residuals yield garbage
/// samples instead of a debug-build overflow panic.
///
/// # Panics
///
/// Panics (index out of range) if `residuals` is non-empty and
/// `warmup.len() < order` for an order in `1..=4`.
fn fixed_restore(residuals: &[i32], order: u8, warmup: &[i16]) -> Vec<i16> {
    use std::num::Wrapping;

    let mut out: Vec<i32> = Vec::with_capacity(warmup.len() + residuals.len());
    out.extend(warmup.iter().map(|&v| i32::from(v)));

    for &r in residuals {
        let n = out.len();
        // `back(j)` is the already-restored sample `j` positions before this one.
        let back = |j: usize| Wrapping(out[n - j]);
        let prediction = match order {
            1 => back(1),
            2 => Wrapping(2) * back(1) - back(2),
            3 => Wrapping(3) * back(1) - Wrapping(3) * back(2) + back(3),
            4 => Wrapping(4) * back(1) - Wrapping(6) * back(2) + Wrapping(4) * back(3) - back(4),
            _ => Wrapping(0),
        };
        out.push((Wrapping(r) + prediction).0);
    }

    out.into_iter().map(|v| v as i16).collect()
}
288
289fn optimal_rice_param(residuals: &[i32]) -> u8 {
291 if residuals.is_empty() {
292 return 0;
293 }
294 let mut best_k = 0u8;
295 let mut best_cost = u64::MAX;
296 for k in 0..=14u8 {
297 let cost: u64 = residuals
298 .iter()
299 .map(|&r| {
300 let u = zigzag_encode(r);
301 1u64 + u64::from(k) + u64::from(u >> k)
302 })
303 .sum();
304 if cost < best_cost {
305 best_cost = cost;
306 best_k = k;
307 }
308 }
309 best_k
310}
311
312pub fn encode_flac_frame(samples: &[i16], config: &FlacEncoderConfig) -> Vec<u8> {
329 if samples.is_empty() {
330 return vec![0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0];
332 }
333
334 let order = optimal_predictor_order(samples, config.compression_level);
335 let residuals = fixed_predict(samples, order);
336 let k = optimal_rice_param(&residuals);
337 let rice_bytes = rice_encode(&residuals, k);
338
339 let warmup_count = (order as usize).min(samples.len());
340 let mut out = Vec::new();
341
342 out.push(order);
344 out.push(k);
345 let wc = warmup_count as u16;
346 out.extend_from_slice(&wc.to_be_bytes());
347
348 for &s in &samples[..warmup_count] {
350 out.extend_from_slice(&s.to_be_bytes());
351 }
352
353 let rc = residuals.len() as u32;
355 out.extend_from_slice(&rc.to_be_bytes());
356
357 let rl = rice_bytes.len() as u32;
359 out.extend_from_slice(&rl.to_be_bytes());
360 out.extend_from_slice(&rice_bytes);
361
362 out
363}
364
365pub fn decode_flac_frame(data: &[u8], _info: &FlacStreamInfo) -> Result<Vec<i16>, String> {
367 if data.len() < 12 {
368 return Err("FLAC frame too short".to_string());
369 }
370
371 let order = data[0];
372 if order > 4 {
373 return Err(format!("Invalid predictor order: {order}"));
374 }
375 let k = data[1];
376 if k > 30 {
377 return Err(format!("Invalid Rice parameter: {k}"));
378 }
379 let warmup_count = u16::from_be_bytes([data[2], data[3]]) as usize;
380
381 let mut pos = 4;
382
383 if pos + warmup_count * 2 > data.len() {
385 return Err("FLAC frame: warmup overruns data".to_string());
386 }
387 let mut warmup = Vec::with_capacity(warmup_count);
388 for _ in 0..warmup_count {
389 let s = i16::from_be_bytes([data[pos], data[pos + 1]]);
390 warmup.push(s);
391 pos += 2;
392 }
393
394 if pos + 4 > data.len() {
396 return Err("FLAC frame: missing residual count".to_string());
397 }
398 let residual_count =
399 u32::from_be_bytes([data[pos], data[pos + 1], data[pos + 2], data[pos + 3]]) as usize;
400 pos += 4;
401
402 if pos + 4 > data.len() {
404 return Err("FLAC frame: missing rice length".to_string());
405 }
406 let rice_len =
407 u32::from_be_bytes([data[pos], data[pos + 1], data[pos + 2], data[pos + 3]]) as usize;
408 pos += 4;
409
410 if pos + rice_len > data.len() {
411 return Err("FLAC frame: rice data overruns frame".to_string());
412 }
413 let rice_data = &data[pos..pos + rice_len];
414
415 let residuals = if residual_count == 0 {
417 Vec::new()
418 } else {
419 rice_decode(rice_data, residual_count, k)?
420 };
421
422 let samples = fixed_restore(&residuals, order, &warmup);
424 Ok(samples)
425}
426
// Unit tests: fixed-predictor residuals, Rice coding, zigzag mapping, predictor
// selection, and whole-frame encode/decode round-trips.
#[cfg(test)]
mod tests {
    use super::*;

    // Order 0 applies no prediction, so residuals equal the input samples.
    #[test]
    fn test_fixed_predict_order0() {
        let samples: Vec<i16> = vec![10, 20, 30, 40, 50];
        let residuals = fixed_predict(&samples, 0);
        let expected: Vec<i32> = vec![10, 20, 30, 40, 50];
        assert_eq!(residuals, expected, "Order 0 should be identity");
    }

    // A first-order difference of a constant signal is zero everywhere.
    #[test]
    fn test_fixed_predict_order1_constant() {
        let samples: Vec<i16> = vec![100, 100, 100, 100];
        let residuals = fixed_predict(&samples, 1);
        assert!(
            residuals.iter().all(|&r| r == 0),
            "Constant signal should produce all-zero order-1 residuals: {:?}",
            residuals
        );
    }

    // For every order 0..=4, predicting and then restoring from the first `order`
    // samples (the warm-up) must reproduce the input exactly.
    #[test]
    fn test_fixed_predict_restore_roundtrip() {
        let samples: Vec<i16> = vec![10, -5, 300, -200, 0, 127, -128, 500];
        for order in 0..=4u8 {
            // Guard against orders that would leave no residuals.
            if samples.len() <= order as usize {
                continue;
            }
            let residuals = fixed_predict(&samples, order);
            let warmup = &samples[..order as usize];
            let restored = fixed_restore(&residuals, order, warmup);
            assert_eq!(
                restored, samples,
                "Order {order} predict-restore roundtrip must be lossless"
            );
        }
    }

    // Rice encode/decode must round-trip zero, small and large residuals of both
    // signs for parameters k = 0..=6.
    #[test]
    fn test_rice_encode_decode_roundtrip() {
        let residuals = vec![0i32, 1, -1, 5, -5, 100, -100, 0];
        for k in 0..=6u8 {
            let encoded = rice_encode(&residuals, k);
            let decoded =
                rice_decode(&encoded, residuals.len(), k).expect("rice decode should succeed");
            assert_eq!(decoded, residuals, "Rice roundtrip failed for k={k}");
        }
    }

    // With k = 0 each zero residual codes as a single `1` bit, so 64 zeros pack
    // into exactly 8 bytes.
    #[test]
    fn test_rice_encode_zeros() {
        let zeros = vec![0i32; 64];
        let k = optimal_rice_param(&zeros);
        assert_eq!(k, 0, "All zeros should use k=0");
        let encoded = rice_encode(&zeros, k);
        assert_eq!(
            encoded.len(),
            8,
            "64 zero residuals at k=0 should be 8 bytes"
        );
    }

    // Silence costs nothing at order 0, and ties resolve to the lowest order.
    #[test]
    fn test_optimal_predictor_silence() {
        let silence: Vec<i16> = vec![0; 128];
        let order = optimal_predictor_order(&silence, 5);
        assert_eq!(order, 0, "Silence should pick order 0");
    }

    // A linear ramp is cheaper to code once a difference predictor removes the trend.
    #[test]
    fn test_optimal_predictor_linear_ramp() {
        let ramp: Vec<i16> = (0..128).map(|i| i as i16).collect();
        let order = optimal_predictor_order(&ramp, 5);
        assert!(
            order >= 1,
            "Linear ramp should pick order >= 1, got {order}"
        );
    }

    // Full frame round-trip on a 512-sample ramp plus a small sine component.
    #[test]
    fn test_encode_decode_frame_roundtrip() {
        let config = FlacEncoderConfig::default();
        let info = FlacStreamInfo {
            min_block_size: 4096,
            max_block_size: 4096,
            sample_rate: 44100,
            channels: 2,
            bits_per_sample: 16,
            total_samples: 0,
        };

        let samples: Vec<i16> = (0..512)
            .map(|i| {
                let ramp = (i as f64 / 512.0 * 1000.0) as i16;
                let sine = (100.0 * (i as f64 * 0.1).sin()) as i16;
                ramp.saturating_add(sine)
            })
            .collect();

        let encoded = encode_flac_frame(&samples, &config);
        let decoded = decode_flac_frame(&encoded, &info).expect("decode should succeed");
        assert_eq!(
            decoded, samples,
            "Frame encode-decode roundtrip must be lossless"
        );
    }

    // Pins every field of `FlacEncoderConfig::default()`.
    #[test]
    fn test_flac_config_default() {
        let config = FlacEncoderConfig::default();
        assert_eq!(config.sample_rate, 44100);
        assert_eq!(config.channels, 2);
        assert_eq!(config.bits_per_sample, 16);
        assert_eq!(config.block_size, 4096);
        assert_eq!(config.compression_level, 5);
    }

    // Empty input still yields a decodable (header-only) frame that decodes to
    // no samples.
    #[test]
    fn test_encode_empty_block() {
        let config = FlacEncoderConfig::default();
        let info = FlacStreamInfo {
            min_block_size: 0,
            max_block_size: 0,
            sample_rate: 44100,
            channels: 1,
            bits_per_sample: 16,
            total_samples: 0,
        };

        let encoded = encode_flac_frame(&[], &config);
        assert!(
            !encoded.is_empty(),
            "Empty input should still produce a frame header"
        );
        let decoded = decode_flac_frame(&encoded, &info).expect("decode empty should succeed");
        assert!(
            decoded.is_empty(),
            "Decoded empty frame should have no samples"
        );
    }

    // Zigzag encode/decode must be inverse over the 16-bit range and small values.
    #[test]
    fn test_zigzag_roundtrip() {
        for v in [-1000i32, -1, 0, 1, 1000, i16::MIN as i32, i16::MAX as i32] {
            let u = zigzag_encode(v);
            let back = zigzag_decode(u);
            assert_eq!(back, v, "zigzag roundtrip failed for {v}");
        }
    }

    // The second difference of i^2 is the constant 2.
    #[test]
    fn test_fixed_predict_order2_quadratic() {
        let samples: Vec<i16> = (0..20).map(|i: i16| i * i).collect();
        let residuals = fixed_predict(&samples, 2);
        let all_two = residuals.iter().all(|&r| r == 2);
        assert!(
            all_two,
            "Quadratic signal order-2 residuals should all be 2: {:?}",
            residuals
        );
    }

    // 4096 samples of a slowly varying sine at the highest compression level must
    // beat the raw 16-bit size (8192 bytes) and still round-trip losslessly.
    #[test]
    fn test_encode_decode_large_block() {
        let config = FlacEncoderConfig {
            compression_level: 8,
            ..FlacEncoderConfig::default()
        };
        let info = FlacStreamInfo {
            min_block_size: 4096,
            max_block_size: 4096,
            sample_rate: 44100,
            channels: 1,
            bits_per_sample: 16,
            total_samples: 4096,
        };
        let samples: Vec<i16> = (0..4096)
            .map(|i| (1000.0 * (i as f64 * 0.05).sin()) as i16)
            .collect();
        let encoded = encode_flac_frame(&samples, &config);
        assert!(
            encoded.len() < 8192,
            "Compressed frame ({} bytes) should be smaller than raw (8192)",
            encoded.len()
        );
        let decoded = decode_flac_frame(&encoded, &info).expect("decode");
        assert_eq!(decoded, samples);
    }
}