use crate::bitstream::BitReader;
use crate::error::FlacError;
/// Signature bytes that a FLAC byte stream must begin with (`"fLaC"`).
pub const FLAC_MARKER: &[u8; 4] = &[b'f', b'L', b'a', b'C'];
/// Byte length of a STREAMINFO block body: the sum of its field widths in bits
/// (block sizes, frame sizes, sample rate, channels, bit depth, sample count,
/// MD5 digest) divided by eight.
const STREAMINFO_LEN: usize = (16 + 16 + 24 + 24 + 20 + 3 + 5 + 36 + 128) / 8;
/// Decoded contents of a FLAC STREAMINFO metadata block.
///
/// Bit widths noted below are the widths read by `parse_stream_info`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct StreamInfo {
    /// Minimum block size (16-bit field; presumably in samples per the FLAC spec — confirm).
    pub min_block_size: u16,
    /// Maximum block size (16-bit field; presumably in samples per the FLAC spec — confirm).
    pub max_block_size: u16,
    /// Minimum frame size (24-bit field; presumably bytes, 0 = unknown — confirm).
    pub min_frame_size: u32,
    /// Maximum frame size (24-bit field; presumably bytes, 0 = unknown — confirm).
    pub max_frame_size: u32,
    /// Sample rate (20-bit field); parsing rejects a value of zero.
    pub sample_rate: u32,
    /// Channel count, already incremented from the stored 3-bit value (range 1..=8).
    pub channels: u8,
    /// Bits per sample, already incremented from the stored 5-bit value (range 1..=32).
    pub bits_per_sample: u8,
    /// Total sample count (36-bit field).
    pub total_samples: u64,
    /// The 16 MD5 digest bytes exactly as stored in the block.
    pub md5: [u8; 16],
}
#[derive(Debug)]
pub struct Header {
pub stream_info: StreamInfo,
pub frame_start: usize,
}
/// Validates the `fLaC` signature and walks the metadata block chain after it.
///
/// Every block is bounds-checked before use. The STREAMINFO block (type 0) is
/// decoded; all other block types are skipped without inspecting their bodies.
/// Walking stops after the block whose header carries the last-block flag, and
/// the offset just past that block is reported as `frame_start`.
///
/// # Errors
///
/// - [`FlacError::NotFlac`] if `bytes` does not start with [`FLAC_MARKER`].
/// - [`FlacError::Truncated`] if a block header or block body extends past the
///   end of `bytes`.
/// - [`FlacError::CorruptStream`] in any of these cases:
///   - a STREAMINFO block's length is not `STREAMINFO_LEN`;
///   - a second STREAMINFO block appears;
///   - a block has type 127;
///   - no STREAMINFO block is present at all.
/// - Any error produced while decoding the STREAMINFO body.
pub fn read_header(bytes: &[u8]) -> Result<Header, FlacError> {
    if !bytes.starts_with(FLAC_MARKER) {
        return Err(FlacError::NotFlac);
    }

    let mut cursor = FLAC_MARKER.len();
    let mut found: Option<StreamInfo> = None;

    loop {
        // The 4-byte block header is laid out as follows:
        // - bit 7 of the first byte is the last-block flag;
        // - bits 0-6 of the first byte are the block type;
        // - the remaining 3 bytes are a 24-bit big-endian body length.
        let Some(block_header) = bytes.get(cursor..cursor + 4) else {
            return Err(FlacError::Truncated);
        };
        let last_block = block_header[0] & 0x80 != 0;
        let kind = block_header[0] & 0x7F;
        let length =
            u32::from_be_bytes([0, block_header[1], block_header[2], block_header[3]]) as usize;

        let body_start = cursor + 4;
        let body_end = body_start
            .checked_add(length)
            .ok_or(FlacError::Truncated)?;
        let body = bytes
            .get(body_start..body_end)
            .ok_or(FlacError::Truncated)?;

        match kind {
            0 => {
                // Length is validated before the duplicate check, so a malformed
                // second STREAMINFO reports its length problem first.
                if length != STREAMINFO_LEN {
                    return Err(FlacError::CorruptStream(format!(
                        "STREAMINFO length is {length}, expected {STREAMINFO_LEN}"
                    )));
                }
                if found.is_some() {
                    return Err(FlacError::CorruptStream(
                        "more than one STREAMINFO block".into(),
                    ));
                }
                found = Some(parse_stream_info(body)?);
            }
            127 => {
                return Err(FlacError::CorruptStream(
                    "invalid metadata block type 127".into(),
                ));
            }
            // Any other block type is skipped.
            _ => {}
        }

        cursor = body_end;
        if last_block {
            break;
        }
    }

    let Some(stream_info) = found else {
        return Err(FlacError::CorruptStream("no STREAMINFO block found".into()));
    };
    Ok(Header {
        stream_info,
        frame_start: cursor,
    })
}
/// Decodes a STREAMINFO block body into a [`StreamInfo`].
///
/// Fields are pulled from `body` with a [`BitReader`] in this order:
/// - two 16-bit block sizes;
/// - two 24-bit frame sizes;
/// - a 20-bit sample rate;
/// - a 3-bit channel count and a 5-bit sample depth, both stored minus one;
/// - a 36-bit total sample count;
/// - 16 bytes of MD5 digest.
///
/// # Errors
///
/// Propagates any error from the bit reader (presumably raised when `body` is
/// too short — confirm against `BitReader`). Returns
/// [`FlacError::CorruptStream`] when the decoded sample rate is zero.
fn parse_stream_info(body: &[u8]) -> Result<StreamInfo, FlacError> {
    let mut reader = BitReader::new(body);

    let min_block_size = reader.read_u32(16)? as u16;
    let max_block_size = reader.read_u32(16)? as u16;
    let min_frame_size = reader.read_u32(24)?;
    let max_frame_size = reader.read_u32(24)?;
    let sample_rate = reader.read_u32(20)?;
    // Channel count and bit depth are stored as `value - 1`.
    let channels = 1 + reader.read_u32(3)? as u8;
    let bits_per_sample = 1 + reader.read_u32(5)? as u8;
    let total_samples = reader.read_u64(36)?;

    let mut md5 = [0u8; 16];
    for slot in &mut md5 {
        *slot = reader.read_u32(8)? as u8;
    }

    // The zero check runs only after every read, so any reader error takes precedence.
    if sample_rate != 0 {
        Ok(StreamInfo {
            min_block_size,
            max_block_size,
            min_frame_size,
            max_frame_size,
            sample_rate,
            channels,
            bits_per_sample,
            total_samples,
            md5,
        })
    } else {
        Err(FlacError::CorruptStream(
            "STREAMINFO sample rate is zero".into(),
        ))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a minimal stream: the `fLaC` marker followed by one STREAMINFO
    /// block (flagged as last). The block describes:
    /// - 4096-sample blocks and unknown frame sizes;
    /// - 44 100 Hz, 2 channels, 16 bits per sample;
    /// - 88 200 total samples;
    /// - an MD5 of `0xAB` repeated.
    fn synthetic_header() -> Vec<u8> {
        let mut v = Vec::new();
        v.extend_from_slice(FLAC_MARKER);
        // Block header: last-block flag set, type 0 (STREAMINFO), length 34.
        v.push(0x80);
        v.extend_from_slice(&[0x00, 0x00, 0x22]);
        // min/max block size = 4096.
        v.extend_from_slice(&[0x10, 0x00, 0x10, 0x00]);
        // min/max frame size = 0 (24 bits each).
        v.extend_from_slice(&[0, 0, 0, 0, 0, 0]);
        // Bit-packed fields, in order:
        // - sample rate (20 bits);
        // - channels - 1 (3 bits);
        // - bits per sample - 1 (5 bits);
        // - total samples (36 bits).
        let mut bits = Vec::new();
        push_bits(&mut bits, 44100, 20);
        push_bits(&mut bits, 1, 3);
        push_bits(&mut bits, 15, 5);
        push_bits(&mut bits, 88200, 36);
        v.extend_from_slice(&pack(&bits));
        v.extend_from_slice(&[0xAB; 16]);
        v
    }

    /// Appends the low `n` bits of `value`, MSB first, one bit per element.
    fn push_bits(out: &mut Vec<u8>, value: u64, n: u32) {
        for i in (0..n).rev() {
            out.push(((value >> i) & 1) as u8);
        }
    }

    /// Packs a one-bit-per-element vector into bytes, MSB first.
    fn pack(bits: &[u8]) -> Vec<u8> {
        let mut out = vec![0u8; bits.len().div_ceil(8)];
        for (i, &bit) in bits.iter().enumerate() {
            if bit != 0 {
                out[i / 8] |= 1 << (7 - (i % 8));
            }
        }
        out
    }

    /// Appends a metadata block with the given raw first header byte
    /// (last-block flag | block type) and body.
    fn push_block(out: &mut Vec<u8>, flag_and_type: u8, body: &[u8]) {
        let len = body.len();
        out.push(flag_and_type);
        out.extend_from_slice(&[(len >> 16) as u8, (len >> 8) as u8, len as u8]);
        out.extend_from_slice(body);
    }

    /// Clears the last-block flag on the STREAMINFO header produced by
    /// [`synthetic_header`] so further blocks can be appended.
    fn clear_last_flag(h: &mut [u8]) {
        h[4] &= 0x7F;
    }

    #[test]
    fn rejects_non_flac() {
        assert_eq!(read_header(b"RIFFxxxx").unwrap_err(), FlacError::NotFlac);
        assert_eq!(read_header(b"fL").unwrap_err(), FlacError::NotFlac);
    }

    #[test]
    fn parses_synthetic_streaminfo() {
        let bytes = synthetic_header();
        let h = read_header(&bytes).unwrap();
        let si = &h.stream_info;
        assert_eq!(si.min_block_size, 4096);
        assert_eq!(si.max_block_size, 4096);
        assert_eq!(si.min_frame_size, 0);
        assert_eq!(si.max_frame_size, 0);
        assert_eq!(si.sample_rate, 44100);
        assert_eq!(si.channels, 2);
        assert_eq!(si.bits_per_sample, 16);
        assert_eq!(si.total_samples, 88200);
        assert_eq!(si.md5, [0xAB; 16]);
        assert_eq!(h.frame_start, bytes.len());
    }

    #[test]
    fn truncated_header_errors() {
        let full = synthetic_header();
        // Block body cut short.
        assert_eq!(read_header(&full[..10]).unwrap_err(), FlacError::Truncated);
        // Block header itself cut short.
        assert_eq!(read_header(&full[..6]).unwrap_err(), FlacError::Truncated);
    }

    #[test]
    fn rejects_zero_sample_rate() {
        let mut h = synthetic_header();
        // The 20-bit sample rate starts at byte 18, after:
        // - the 4-byte marker;
        // - the 4-byte block header;
        // - 10 bytes of block/frame sizes.
        h[18] = 0;
        h[19] = 0;
        h[20] &= 0x0F;
        assert!(matches!(read_header(&h), Err(FlacError::CorruptStream(_))));
    }

    #[test]
    fn rejects_missing_streaminfo() {
        let mut v = FLAC_MARKER.to_vec();
        // A single empty block of type 1, flagged as last.
        push_block(&mut v, 0x81, &[]);
        assert!(matches!(read_header(&v), Err(FlacError::CorruptStream(_))));
    }

    #[test]
    fn rejects_wrong_streaminfo_length() {
        let mut v = FLAC_MARKER.to_vec();
        push_block(&mut v, 0x80, &[0; STREAMINFO_LEN - 1]);
        assert!(matches!(read_header(&v), Err(FlacError::CorruptStream(_))));
    }

    #[test]
    fn rejects_duplicate_streaminfo() {
        let mut h = synthetic_header();
        clear_last_flag(&mut h);
        let body = h[8..].to_vec();
        push_block(&mut h, 0x80, &body);
        assert!(matches!(read_header(&h), Err(FlacError::CorruptStream(_))));
    }

    #[test]
    fn rejects_block_type_127() {
        let mut h = synthetic_header();
        clear_last_flag(&mut h);
        push_block(&mut h, 0xFF, &[]);
        assert!(matches!(read_header(&h), Err(FlacError::CorruptStream(_))));
    }

    #[test]
    fn skips_other_blocks_and_reports_frame_start() {
        let mut h = synthetic_header();
        clear_last_flag(&mut h);
        push_block(&mut h, 0x81, &[0; 4]);
        let metadata_end = h.len();
        // Trailing bytes after the metadata must not be consumed.
        h.extend_from_slice(&[0xFF, 0xF8]);
        let header = read_header(&h).unwrap();
        assert_eq!(header.frame_start, metadata_end);
        assert_eq!(header.stream_info.sample_rate, 44100);
    }
}