use anyhow::{Result, bail};
use snapcast_proto::SampleFormat;
use snapcast_proto::message::codec_header::CodecHeader;
use symphonia::core::audio::SampleBuffer;
use symphonia::core::codecs::{CODEC_TYPE_VORBIS, CodecParameters, DecoderOptions};
use symphonia::core::formats::Packet;
use crate::decoder::Decoder;
fn parse_vorbis_header(payload: &[u8]) -> Result<(SampleFormat, Vec<u8>)> {
if payload.len() < 28 || &payload[0..4] != b"OggS" {
bail!("not an Ogg bitstream");
}
let num_segments = payload[26] as usize;
let header_size = 27 + num_segments;
if payload.len() < header_size {
bail!("Ogg page header truncated");
}
let packet_start = header_size;
let remaining = &payload[packet_start..];
if remaining.len() < 16 {
bail!("Vorbis identification header too small");
}
if remaining[0] != 1 || &remaining[1..7] != b"vorbis" {
bail!("not a Vorbis identification header");
}
let channels = remaining[11] as u16;
let sample_rate = u32::from_le_bytes(remaining[12..16].try_into().unwrap());
if sample_rate == 0 || channels == 0 {
bail!("invalid Vorbis header: rate={sample_rate}, channels={channels}");
}
let sf = SampleFormat::new(sample_rate, 16, channels);
Ok((sf, payload.to_vec()))
}
/// Vorbis decoder that wraps a symphonia codec instance and tracks the stream's sample format.
pub struct VorbisDecoder {
/// Codec object obtained from symphonia's codec registry.
decoder: Box<dyn symphonia::core::codecs::Decoder>,
/// Sample format derived from the codec header the decoder was built from.
sample_format: SampleFormat,
/// Running counter, starting at zero, that is passed to each new packet and then incremented.
packet_id: u64,
}
impl VorbisDecoder {
    /// Builds a decoder from the codec header of a stream.
    ///
    /// The header payload is parsed with `parse_vorbis_header`; the resulting sample rate,
    /// channel mask and extra data are passed to symphonia's codec registry to create the
    /// Vorbis decoder. The packet counter starts at zero.
    ///
    /// # Errors
    ///
    /// Returns an error when the header cannot be parsed or when the registry fails to
    /// create a Vorbis decoder.
    fn from_header(header: &CodecHeader) -> Result<Self> {
        let (sf, extra_data) = parse_vorbis_header(&header.payload)?;

        // Mask with the lowest `channels` bits set. The count originates from a single
        // untrusted header byte, so the shift is checked: an unchecked `1 << n` panics in
        // debug builds (and wraps in release builds) once `n` reaches the integer width.
        // Counts of 32 or more saturate to an all-ones mask.
        let channel_bits = 1u32
            .checked_shl(u32::from(sf.channels()))
            .map_or(u32::MAX, |bit| bit - 1);
        // Fall back to a single front-left channel when the mask is not a valid `Channels` value.
        let channels = symphonia::core::audio::Channels::from_bits(channel_bits)
            .unwrap_or(symphonia::core::audio::Channels::FRONT_LEFT);

        let mut params = CodecParameters::new();
        params
            .for_codec(CODEC_TYPE_VORBIS)
            .with_sample_rate(sf.rate())
            .with_channels(channels)
            .with_extra_data(extra_data.into_boxed_slice());

        let decoder = symphonia::default::get_codecs()
            .make(&params, &DecoderOptions::default())
            .map_err(|e| anyhow::anyhow!("failed to create Vorbis decoder: {e}"))?;

        Ok(Self {
            decoder,
            sample_format: sf,
            packet_id: 0,
        })
    }
}
impl Decoder for VorbisDecoder {
fn set_header(&mut self, header: &CodecHeader) -> Result<SampleFormat> {
*self = Self::from_header(header)?;
Ok(self.sample_format)
}
fn decode(&mut self, data: &mut Vec<u8>) -> Result<bool> {
if data.is_empty() {
return Ok(false);
}
tracing::trace!(
codec = "vorbis",
input_bytes = data.len(),
packet_id = self.packet_id,
"decode"
);
let packet = Packet::new_from_slice(0, self.packet_id, 0, data);
self.packet_id += 1;
let decoded = match self.decoder.decode(&packet) {
Ok(buf) => buf,
Err(e) => {
tracing::warn!(codec = "vorbis", error = %e, "decode failed");
return Ok(false);
}
};
let spec = *decoded.spec();
let frames = decoded.frames() as u64;
let mut sample_buf = SampleBuffer::<i16>::new(frames, spec);
sample_buf.copy_interleaved_ref(decoded);
let mut out = Vec::with_capacity(sample_buf.samples().len() * 2);
for &s in sample_buf.samples() {
out.extend_from_slice(&s.to_le_bytes());
}
*data = out;
Ok(true)
}
}
/// Creates a [`VorbisDecoder`] from a codec header.
///
/// Thin public wrapper around `VorbisDecoder::from_header`; its result, including any
/// error, is returned unchanged.
pub fn create(header: &CodecHeader) -> Result<VorbisDecoder> {
VorbisDecoder::from_header(header)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Wraps `packet` in a single-segment page: the `OggS` tag, a zero byte, `header_type`,
    /// eight zero bytes, `serial`, eight more zero bytes, a segment count of one, the
    /// packet length and finally the packet itself.
    fn ogg_page(header_type: u8, serial: u32, packet: &[u8]) -> Vec<u8> {
        let mut page = Vec::with_capacity(28 + packet.len());
        page.extend_from_slice(b"OggS");
        page.extend_from_slice(&[0, header_type]);
        page.extend_from_slice(&0u64.to_le_bytes());
        page.extend_from_slice(&serial.to_le_bytes());
        page.extend_from_slice(&0u32.to_le_bytes());
        page.extend_from_slice(&0u32.to_le_bytes());
        page.push(1);
        // Test packets are far shorter than 255 bytes, so the cast cannot truncate.
        page.push(packet.len() as u8);
        page.extend_from_slice(packet);
        page
    }

    /// First 16 bytes of an identification header: type byte, `vorbis` tag, four zero
    /// bytes, channel count and little-endian sample rate.
    fn vorbis_ident_prefix(channels: u8, sample_rate: u32) -> Vec<u8> {
        let mut ident = vec![1u8];
        ident.extend_from_slice(b"vorbis");
        ident.extend_from_slice(&0u32.to_le_bytes());
        ident.push(channels);
        ident.extend_from_slice(&sample_rate.to_le_bytes());
        ident
    }

    /// Page holding a 30-byte identification header for 44.1 kHz stereo.
    fn ogg_vorbis_header_44100_2() -> Vec<u8> {
        let mut ident = vorbis_ident_prefix(2, 44100);
        for field in [0i32, 128000, 0] {
            ident.extend_from_slice(&field.to_le_bytes());
        }
        ident.extend_from_slice(&[0xb8, 1]);
        ogg_page(0x02, 1, &ident)
    }

    #[test]
    fn parse_header_44100_2() {
        let (sf, _) = parse_vorbis_header(&ogg_vorbis_header_44100_2()).unwrap();
        assert_eq!(sf.rate(), 44100);
        assert_eq!(sf.channels(), 2);
        assert_eq!(sf.bits(), 16);
    }

    #[test]
    fn parse_header_48000_6() {
        // Only the leading fields matter to the parser; pad the rest with zeros.
        let mut ident = vorbis_ident_prefix(6, 48000);
        ident.resize(30, 0);
        let page = ogg_page(0x02, 1, &ident);

        let (sf, _) = parse_vorbis_header(&page).unwrap();
        assert_eq!(sf.rate(), 48000);
        assert_eq!(sf.channels(), 6);
    }

    #[test]
    fn not_ogg_fails() {
        assert!(parse_vorbis_header(b"NOPE_not_ogg_data_at_all!!!!!").is_err());
    }

    #[test]
    fn not_vorbis_fails() {
        // A well-formed page whose 16-byte packet is all zeros, so the type byte is not 1.
        let page = ogg_page(0, 0, &[0u8; 16]);
        assert!(parse_vorbis_header(&page).is_err());
    }
}