use alloc::vec::Vec;
use serde::{Deserialize, Serialize};
use crate::error::{Result, ShravanError};
use crate::format::{AudioFormat, FormatInfo};
/// Fixed-size fields of an `OpusHead` identification packet, as extracted by
/// [`parse_opus_head`].
///
/// Byte offsets below are relative to the start of the packet (the 8-byte
/// `OpusHead` magic occupies offsets 0..8). Multi-byte fields are read as
/// little-endian.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct OpusHead {
/// Version byte (offset 8). Stored as-is; the parser does not validate it.
pub version: u8,
/// Channel count (offset 9). The parser rejects a value of zero.
pub channel_count: u8,
/// Pre-skip (offsets 10..12). The decoder in this file subtracts it from the
/// final granule position when computing the stream duration.
pub pre_skip: u16,
/// Input sample rate (offsets 12..16). Informational here: the decoder in this
/// file always reports 48000 regardless of this value.
pub input_sample_rate: u32,
/// Output gain (offsets 16..18), signed.
/// NOTE(review): presumably Q7.8 dB as in RFC 7845 — not interpreted in this file.
pub output_gain: i16,
/// Channel mapping family (offset 18). Any mapping table that may follow is
/// not parsed.
pub channel_mapping_family: u8,
}
/// Parses the fixed 19-byte prefix of an `OpusHead` identification packet.
///
/// Bytes beyond offset 18 (for example a channel mapping table) are ignored.
///
/// # Errors
///
/// * [`ShravanError::InvalidHeader`] if the packet is shorter than 19 bytes or
///   does not begin with the `OpusHead` magic.
/// * [`ShravanError::InvalidChannels`] if the channel count byte is zero.
#[must_use = "parsed Opus header is returned and should not be discarded"]
pub fn parse_opus_head(packet: &[u8]) -> Result<OpusHead> {
    const MIN_LEN: usize = 19;
    const MAGIC: &[u8; 8] = b"OpusHead";

    if packet.len() < MIN_LEN {
        return Err(ShravanError::InvalidHeader(
            "OpusHead packet too short (need >= 19 bytes)".into(),
        ));
    }
    if !packet.starts_with(MAGIC) {
        return Err(ShravanError::InvalidHeader("missing OpusHead magic".into()));
    }

    // The length guard above makes every fixed index below in-bounds.
    let channel_count = packet[9];
    if channel_count == 0 {
        return Err(ShravanError::InvalidChannels(0));
    }

    Ok(OpusHead {
        version: packet[8],
        channel_count,
        pre_skip: u16::from_le_bytes([packet[10], packet[11]]),
        input_sample_rate: u32::from_le_bytes([packet[12], packet[13], packet[14], packet[15]]),
        output_gain: i16::from_le_bytes([packet[16], packet[17]]),
        channel_mapping_family: packet[18],
    })
}
/// Parses an `OpusTags` packet into audio metadata.
///
/// After the 8-byte `OpusTags` magic is verified, the remainder of the packet
/// is handed to `crate::tag::read_vorbis_comment`.
///
/// # Errors
///
/// Returns [`ShravanError::InvalidHeader`] if the packet is shorter than
/// 8 bytes or the magic does not match; otherwise whatever the comment reader
/// returns.
#[must_use = "parsed Opus tags are returned and should not be discarded"]
#[cfg(feature = "tag")]
pub fn parse_opus_tags(packet: &[u8]) -> Result<crate::tag::AudioMetadata> {
    match packet.get(..8) {
        None => Err(ShravanError::InvalidHeader(
            "OpusTags packet too short".into(),
        )),
        Some(magic) if magic != b"OpusTags" => {
            Err(ShravanError::InvalidHeader("missing OpusTags magic".into()))
        }
        Some(_) => crate::tag::read_vorbis_comment(&packet[8..]),
    }
}
/// Validates an `OpusTags` packet when the `tag` feature is disabled.
///
/// Only the 8-byte `OpusTags` magic is checked; the comment payload is not
/// inspected.
///
/// # Errors
///
/// Returns [`ShravanError::InvalidHeader`] if the packet is shorter than
/// 8 bytes or the magic does not match.
#[must_use = "parsed Opus tags result should not be discarded"]
#[cfg(not(feature = "tag"))]
pub fn parse_opus_tags(packet: &[u8]) -> Result<()> {
    match packet.get(..8) {
        None => Err(ShravanError::InvalidHeader(
            "OpusTags packet too short".into(),
        )),
        Some(magic) if magic != b"OpusTags" => {
            Err(ShravanError::InvalidHeader("missing OpusTags magic".into()))
        }
        Some(_) => Ok(()),
    }
}
/// Scans `data` backwards for the last Ogg page header and returns its
/// granule position.
///
/// A candidate is accepted when the bytes `OggS` are followed by a zero
/// version byte. Only offsets that leave at least 27 bytes before the end of
/// `data` are considered. The granule is the little-endian `i64` stored at
/// offsets 6..14 of the page header and is returned as-is (no special
/// handling of negative values).
///
/// Returns `None` when `data` is shorter than 27 bytes or no candidate is
/// found.
fn find_last_granule(data: &[u8]) -> Option<i64> {
    const MIN_HEADER_LEN: usize = 27;

    // Highest offset at which a minimal page header still fits.
    let last_start = data.len().checked_sub(MIN_HEADER_LEN)?;

    (0..=last_start)
        .rev()
        .find(|&at| &data[at..at + 4] == b"OggS" && data[at + 4] == 0)
        .map(|at| {
            let mut raw = [0u8; 8];
            raw.copy_from_slice(&data[at + 6..at + 14]);
            i64::from_le_bytes(raw)
        })
}
/// Builds stream information for an Ogg Opus stream from its extracted packets.
///
/// `packets[0]` must be the `OpusHead` packet. If a second packet exists it is
/// run through [`parse_opus_tags`] on a best-effort basis: a malformed tags
/// packet does not fail the call. `raw_data` is the complete Ogg byte stream
/// and is only used to locate the final granule position.
///
/// No audio is decoded: the returned sample vector is always empty. The
/// reported sample rate is fixed at 48000 and the bit depth at 16.
///
/// The sample count is the last granule position minus the header pre-skip,
/// floored at zero. It is taken directly from that integer value, and the
/// duration is derived from it, so the count cannot be pushed one sample low
/// by a division/multiplication round trip through floating point.
///
/// # Errors
///
/// * [`ShravanError::EndOfStream`] if `packets` is empty.
/// * Any error from [`parse_opus_head`] on the first packet.
pub(crate) fn decode_from_packets(
    packets: &[Vec<u8>],
    raw_data: &[u8],
) -> Result<(FormatInfo, Vec<f32>)> {
    let Some(head_packet) = packets.first() else {
        return Err(ShravanError::EndOfStream);
    };
    let head = parse_opus_head(head_packet)?;

    // Tags are optional metadata; deliberately ignore a parse failure.
    if let Some(tags_packet) = packets.get(1) {
        let _ = parse_opus_tags(tags_packet);
    }

    // A missing granule, or one that does not exceed the pre-skip, yields zero.
    let playable = find_last_granule(raw_data)
        .map_or(0, |granule| granule.saturating_sub(i64::from(head.pre_skip)));
    let total_samples = u64::try_from(playable).unwrap_or(0);
    let duration_secs = total_samples as f64 / 48000.0;

    let info = FormatInfo {
        format: AudioFormat::Opus,
        sample_rate: 48000,
        channels: u16::from(head.channel_count),
        bit_depth: 16,
        duration_secs,
        total_samples,
    };
    Ok((info, Vec::new()))
}
/// Reads stream information from a complete Ogg Opus byte stream.
///
/// Splits `data` into packets with `crate::ogg::extract_packets` and forwards
/// them, together with the raw bytes, to [`decode_from_packets`].
///
/// # Errors
///
/// Propagates any error from packet extraction or from
/// [`decode_from_packets`].
#[must_use = "decoded audio data is returned and should not be discarded"]
pub fn decode(data: &[u8]) -> Result<(FormatInfo, Vec<f32>)> {
    let ogg_packets = crate::ogg::extract_packets(data)?;
    decode_from_packets(&ogg_packets, data)
}
/// Pre-skip value the encoder writes into the `OpusHead` packet; the encoder
/// also uses it as the starting granule position before any audio is added.
const ENCODER_PRE_SKIP: u16 = 312;
/// Samples per channel consumed by one encoded frame. The encoder only accepts
/// 48000 Hz input, so one frame spans 20 ms.
const FRAME_SIZE: usize = 960;
/// Serialises a 19-byte `OpusHead` packet.
///
/// Layout: `OpusHead` magic, version `1`, `channels`, `pre_skip` (LE u16),
/// `input_sample_rate` (LE u32), output gain `0` (LE i16), channel mapping
/// family `0`.
fn serialize_opus_head(channels: u8, pre_skip: u16, input_sample_rate: u32) -> Vec<u8> {
    const VERSION: u8 = 1;
    const OUTPUT_GAIN: i16 = 0;
    const MAPPING_FAMILY: u8 = 0;

    let mut head = Vec::with_capacity(19);
    head.extend_from_slice(b"OpusHead");
    head.extend_from_slice(&[VERSION, channels]);
    head.extend_from_slice(&pre_skip.to_le_bytes());
    head.extend_from_slice(&input_sample_rate.to_le_bytes());
    head.extend_from_slice(&OUTPUT_GAIN.to_le_bytes());
    head.push(MAPPING_FAMILY);
    head
}
/// Serialises a minimal `OpusTags` packet.
///
/// Layout: `OpusTags` magic, vendor string length (LE u32), the vendor string
/// `shravan`, and a user comment count of zero (LE u32).
fn serialize_opus_tags() -> Vec<u8> {
    const MAGIC: &[u8] = b"OpusTags";
    const VENDOR: &[u8] = b"shravan";
    const COMMENT_COUNT: u32 = 0;

    let mut tags = Vec::with_capacity(MAGIC.len() + 4 + VENDOR.len() + 4);
    tags.extend_from_slice(MAGIC);
    tags.extend_from_slice(&(VENDOR.len() as u32).to_le_bytes());
    tags.extend_from_slice(VENDOR);
    tags.extend_from_slice(&COMMENT_COUNT.to_le_bytes());
    tags
}
/// Byte-oriented range encoder with carry propagation.
///
/// The coder keeps a 31-bit `low`/`range` pair. Whenever `range` drops to
/// `0x0080_0000` or below, the top bits of `low` are shifted out one byte at
/// a time. A byte equal to `0xFF` cannot be written immediately because a
/// later carry could still ripple through it, so such bytes are counted in
/// `carry_count` and flushed once the next non-`0xFF` byte arrives.
struct RangeEncoder {
    /// Bytes emitted so far.
    buf: Vec<u8>,
    /// Lower bound of the current coding interval.
    low: u32,
    /// Width of the current coding interval.
    range: u32,
    /// Number of deferred `0xFF` bytes awaiting carry resolution.
    carry_count: u32,
    /// Last byte shifted out but not yet written; `-1` when there is none.
    cache: i32,
    /// Incremented by 8 on every renormalisation step; not read in this file.
    bits_used: u32,
}

impl RangeEncoder {
    /// Interval width at or below which a byte is shifted out.
    const RENORM_LIMIT: u32 = 0x0080_0000;
    /// Mask keeping the 23 bits of `low` that remain after a byte is shifted out.
    const LOW_MASK: u32 = 0x007F_FFFF;

    /// Creates an encoder with an empty buffer and the full initial interval.
    fn new() -> Self {
        Self {
            buf: Vec::new(),
            low: 0,
            range: 0x8000_0000,
            carry_count: 0,
            cache: -1,
            bits_used: 0,
        }
    }

    /// Encodes a symbol occupying the cumulative-frequency span `[fl, fh)` out
    /// of `total`.
    ///
    /// Debug builds assert `total > 0`, `fl < fh` and `fh <= total`. In
    /// release builds a zero `total` is silently ignored.
    fn encode(&mut self, fl: u32, fh: u32, total: u32) {
        debug_assert!(total > 0, "range encoder total must be > 0");
        debug_assert!(fl < fh, "range encoder fl must be < fh");
        debug_assert!(fh <= total, "range encoder fh must be <= total");
        if total == 0 {
            return;
        }
        let step = self.range / total;
        let offset = step.wrapping_mul(fl);
        // The topmost symbol absorbs the rounding remainder of the division.
        let narrowed = if fh < total {
            step.wrapping_mul(fh - fl)
        } else {
            self.range.wrapping_sub(offset)
        };
        self.low = self.low.wrapping_add(offset);
        self.range = narrowed;
        self.normalize();
    }

    /// Encodes one equiprobable bit.
    fn encode_bit(&mut self, val: bool) {
        let symbol = u32::from(val);
        self.encode(symbol, symbol + 1, 2);
    }

    /// Encodes `val` as one of `total` equiprobable symbols.
    ///
    /// Nothing is written when `total` is 0 or 1.
    fn encode_uint(&mut self, val: u32, total: u32) {
        if total > 1 {
            self.encode(val, val + 1, total);
        }
    }

    /// Shifts bytes out of `low` until `range` is above the renormalisation limit.
    fn normalize(&mut self) {
        while self.range <= Self::RENORM_LIMIT {
            self.carry_out();
            self.low <<= 8;
            self.range <<= 8;
            self.bits_used += 8;
        }
    }

    /// Moves the top byte of `low` (plus its carry bit) towards the output.
    fn carry_out(&mut self) {
        // Nine bits: bit 8 is the carry, bits 0..8 the outgoing byte.
        let top = (self.low >> 23) as i32;
        self.low &= Self::LOW_MASK;

        if top == 0xFF {
            // A later carry may still ripple through this byte; defer it.
            self.carry_count += 1;
            return;
        }

        let carry_bit = (top >> 8) as u32;
        if self.cache >= 0 {
            self.buf
                .push((self.cache as u32).wrapping_add(carry_bit) as u8);
        }
        // Deferred 0xFF bytes become 0x00 when a carry arrived, else stay 0xFF.
        let filler = carry_bit.wrapping_add(0xFF) as u8;
        let flushed_len = self.buf.len() + self.carry_count as usize;
        self.buf.resize(flushed_len, filler);
        self.carry_count = 0;
        self.cache = top & 0xFF;
    }

    /// Flushes the pending state and returns the encoded bytes.
    ///
    /// NOTE(review): when no byte is cached (`cache < 0`), deferred `0xFF`
    /// bytes are not written, and the cached byte is combined with the whole
    /// `low >> 23` value rather than a single carry bit — confirm this matches
    /// the intended decoder.
    fn finish(mut self) -> Vec<u8> {
        if self.cache >= 0 {
            let top = self.low >> 23;
            self.buf.push((self.cache as u32).wrapping_add(top) as u8);
            let filler = top.wrapping_add(0xFF).wrapping_sub(1) as u8;
            let flushed_len = self.buf.len() + self.carry_count as usize;
            self.buf.resize(flushed_len, filler);
        }

        let nbits = match self.range {
            0 => 1,
            live => 32u32.saturating_sub(live.leading_zeros()).max(1),
        };
        let nbytes = nbits.div_ceil(8);
        let tail = self.low >> (23u32.saturating_sub(nbits));
        // Emit the low `nbytes` bytes of `tail`, most significant first.
        for byte_index in (0..nbytes).rev() {
            self.buf.push((tail >> (8 * byte_index)) as u8);
        }
        self.buf
    }

    /// Estimate of the bytes consumed so far: written bytes, one for the cached
    /// byte, plus any deferred carry bytes.
    fn bytes_used(&self) -> usize {
        self.buf.len() + 1 + self.carry_count as usize
    }
}
use crate::fft::mdct_forward;
/// Applies a sine window to `frame` in place.
///
/// Sample `i` of an `n`-sample frame is multiplied by
/// `sin(pi / n * (i + 0.5))`. An empty frame is left untouched.
#[inline]
fn sine_window(frame: &mut [f32]) {
    // Angular step between consecutive window positions.
    let step = core::f32::consts::PI / (frame.len() as f32);
    for (index, value) in frame.iter_mut().enumerate() {
        *value *= libm::sinf(step * (index as f32 + 0.5));
    }
}
/// Band boundaries expressed as MDCT bin indices: band `b` covers bins
/// `CELT_BAND_EDGES[b]..CELT_BAND_EDGES[b + 1]`. The final edge (480) equals
/// the number of MDCT coefficients produced per frame (`FRAME_SIZE / 2`).
/// NOTE(review): the relationship of this table to the standard CELT band
/// layout is not established in this file — confirm before relying on it.
const CELT_BAND_EDGES: [u16; 22] = [
0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 14, 17, 21, 26, 32, 40, 50, 62, 78, 100, 480,
];
/// Number of bands described by `CELT_BAND_EDGES` (one fewer than its edges).
const NUM_CELT_BANDS: usize = CELT_BAND_EDGES.len() - 1;
/// Computes the base-2 log of the mean squared coefficient for every band.
///
/// Band boundaries come from `CELT_BAND_EDGES` and are clamped to
/// `mdct.len()`. A band that lies partly or wholly beyond the end of `mdct`
/// is evaluated over whatever coefficients exist; an empty band yields the
/// floor value `log2(1e-10)`. (Without clamping the band start, an input
/// that ended before a band began would slice with `start > end` and panic.)
fn compute_band_energies(mdct: &[f32]) -> [f32; NUM_CELT_BANDS] {
    let mut energies = [0.0f32; NUM_CELT_BANDS];
    for (energy, edges) in energies.iter_mut().zip(CELT_BAND_EDGES.windows(2)) {
        let end = usize::from(edges[1]).min(mdct.len());
        // Clamp the start as well so short inputs produce an empty band.
        let start = usize::from(edges[0]).min(end);
        let bins = &mdct[start..end];

        let power = bins.iter().fold(0.0f32, |acc, &c| acc + c * c);
        // Treat an empty band as width 1 to avoid dividing by zero.
        let width = bins.len().max(1) as f32;
        *energy = libm::log2f((power / width).max(1e-10));
    }
    energies
}
/// Quantises log-domain band energies to half-unit steps.
///
/// Each energy is doubled, rounded to the nearest integer and converted to
/// `i16`.
fn quantize_band_energies(energies: &[f32; NUM_CELT_BANDS]) -> [i16; NUM_CELT_BANDS] {
    let mut steps = [0i16; NUM_CELT_BANDS];
    for (slot, &log_energy) in steps.iter_mut().zip(energies) {
        *slot = libm::roundf(log_energy * 2.0) as i16;
    }
    steps
}
/// Delta-codes the quantised band energies into the range encoder.
///
/// Each band is coded as its difference from the previous band's value,
/// clamped to `-64..=63` and written as one of 128 equiprobable symbols
/// (offset by 64). The first band is predicted from zero.
///
/// The predictor tracks the *reconstructed* value (previous prediction plus
/// the clamped delta), which is the only value a decoder can know. Tracking
/// the unclamped input instead would desynchronise encoder and decoder every
/// time a delta is clamped; with this scheme the coded values catch up with
/// the true energies over the following bands.
fn encode_band_energies(rc: &mut RangeEncoder, quant: &[i16; NUM_CELT_BANDS]) {
    let mut reconstructed = 0i32;
    for &q in quant {
        let delta = (i32::from(q) - reconstructed).clamp(-64, 63);
        // `delta + 64` lies in 0..=127, so the cast is lossless.
        rc.encode_uint((delta + 64) as u32, 128);
        reconstructed += delta;
    }
}
/// Writes per-band unit-norm coefficients into `norms`.
///
/// For every band in `CELT_BAND_EDGES`, the coefficients of `mdct` are
/// divided by the band's Euclidean norm (floored at `1e-10`) and stored at
/// the same indices in `norms`.
///
/// Band boundaries are clamped to `mdct.len()`, and writes stop at the end of
/// `norms`, so mismatched or short slices are handled without panicking.
/// Entries of `norms` not covered by any band are left unchanged.
fn normalize_bands(mdct: &[f32], norms: &mut [f32]) {
    for edges in CELT_BAND_EDGES.windows(2) {
        let end = usize::from(edges[1]).min(mdct.len());
        // Clamp the start as well so short inputs produce an empty band.
        let start = usize::from(edges[0]).min(end);
        let band = &mdct[start..end];

        let sum_sq = band.iter().fold(0.0f32, |acc, &c| acc + c * c);
        let norm = libm::sqrtf(sum_sq).max(1e-10);

        // `zip` stops at the shorter side, so a short `norms` cannot overflow.
        for (out, &c) in norms.iter_mut().skip(start).zip(band) {
            *out = c / norm;
        }
    }
}
/// Encodes the sign of each normalised coefficient, band by band, until the
/// byte budget is exhausted.
///
/// One bit is written per coefficient (`true` when the coefficient is
/// `>= 0.0`). Encoding stops as soon as `rc.bytes_used()` reaches
/// `target_bytes`, or at the first band that has no coefficients.
///
/// Band boundaries are clamped to `norms.len()`; an input that ends inside or
/// before a band simply terminates encoding there instead of underflowing the
/// band size.
fn encode_spectral_shape(rc: &mut RangeEncoder, norms: &[f32], target_bytes: usize) {
    for edges in CELT_BAND_EDGES.windows(2) {
        let end = usize::from(edges[1]).min(norms.len());
        // Clamp the start as well so `start..end` is always a valid range.
        let start = usize::from(edges[0]).min(end);
        if start == end || rc.bytes_used() >= target_bytes {
            break;
        }
        for &coeff in &norms[start..end] {
            if rc.bytes_used() >= target_bytes {
                break;
            }
            rc.encode_bit(coeff >= 0.0);
        }
    }
}
/// Encodes one frame of interleaved samples into a packet of exactly
/// `target_bytes` bytes.
///
/// Steps: down-mix the first `FRAME_SIZE` sample groups to mono (missing
/// samples are skipped), apply a sine window, run the forward MDCT into
/// `FRAME_SIZE / 2` coefficients, then range-code the quantised band energies
/// followed by coefficient signs. The packet starts with a TOC byte and is
/// zero-padded or truncated to `target_bytes`.
///
/// The TOC byte uses configuration 31 (CELT-only, fullband, 20 ms, mono, one
/// frame per packet). Each packet covers `FRAME_SIZE` = 960 samples at
/// 48 kHz, i.e. 20 ms; configuration 30 would declare a 10 ms frame and
/// contradict the 960-sample granule advance used by the container.
///
/// `channels` must be non-zero; the public encoder validates this.
fn encode_celt_frame(samples: &[f32], channels: u16, target_bytes: usize) -> Vec<u8> {
    // TOC layout: config (5 bits) | stereo flag (1 bit) | frame-count code (2 bits).
    const TOC_CELT_FULLBAND_20MS_MONO: u8 = 31 << 3;

    let ch = usize::from(channels);

    // Average the channels of each sample group; out-of-range samples are skipped.
    let mut mono = vec![0.0f32; FRAME_SIZE];
    for (i, m) in mono.iter_mut().enumerate() {
        let sum = (0..ch)
            .filter_map(|c| samples.get(i * ch + c))
            .fold(0.0f32, |acc, &s| acc + s);
        *m = sum / ch as f32;
    }

    sine_window(&mut mono);

    let mdct_size = FRAME_SIZE / 2;
    let mut mdct = vec![0.0f32; mdct_size];
    mdct_forward(&mono, &mut mdct);

    let quant_energies = quantize_band_energies(&compute_band_energies(&mdct));
    let mut norms = vec![0.0f32; mdct_size];
    normalize_bands(&mdct, &mut norms);

    let mut rc = RangeEncoder::new();
    encode_band_energies(&mut rc, &quant_energies);
    // One byte of the budget is reserved for the TOC.
    encode_spectral_shape(&mut rc, &norms, target_bytes.saturating_sub(1));

    let mut packet = Vec::with_capacity(target_bytes);
    packet.push(TOC_CELT_FULLBAND_20MS_MONO);
    packet.extend_from_slice(&rc.finish());
    // Pads with zeros when short, truncates when long.
    packet.resize(target_bytes, 0);
    packet
}
/// Encodes interleaved `f32` samples into an Ogg Opus byte stream.
///
/// The stream consists of an `OpusHead` page, an `OpusTags` page, and one
/// page per 960-sample frame. A trailing partial frame is zero-padded for
/// encoding but advances the granule position only by the samples actually
/// supplied. Empty input produces a single silent frame. The last audio page
/// carries the end-of-stream flag.
///
/// # Errors
///
/// * [`ShravanError::InvalidSampleRate`] unless `sample_rate` is 48000.
/// * [`ShravanError::InvalidChannels`] unless `channels` is 1 or 2.
/// * [`ShravanError::EncodeError`] unless `bitrate` is within
///   `32000..=256000`.
#[must_use = "encoded Opus/Ogg bytes are returned and should not be discarded"]
pub fn encode(samples: &[f32], sample_rate: u32, channels: u16, bitrate: u32) -> Result<Vec<u8>> {
    if sample_rate != 48000 {
        return Err(ShravanError::InvalidSampleRate(sample_rate));
    }
    if !matches!(channels, 1 | 2) {
        return Err(ShravanError::InvalidChannels(channels));
    }
    if !(32000..=256000).contains(&bitrate) {
        return Err(ShravanError::EncodeError(alloc::format!(
            "bitrate must be 32000..=256000, got {bitrate}"
        )));
    }

    let ch = usize::from(channels);
    // 50 frames per second; never go below 10 bytes per packet.
    let packet_len = (bitrate / 8 / 50).max(10) as usize;
    let frame_len = FRAME_SIZE * ch;

    // Each entry pairs an encoded packet with the granule position of its page.
    let mut frames: Vec<(Vec<u8>, i64)> = Vec::new();
    let mut granule = i64::from(ENCODER_PRE_SKIP);
    for chunk in samples.chunks(frame_len) {
        let packet = if chunk.len() < frame_len {
            // Final partial frame: pad with silence up to a full frame.
            let mut padded = vec![0.0f32; frame_len];
            padded[..chunk.len()].copy_from_slice(chunk);
            encode_celt_frame(&padded, channels, packet_len)
        } else {
            encode_celt_frame(chunk, channels, packet_len)
        };
        granule += (chunk.len() / ch) as i64;
        frames.push((packet, granule));
    }
    if frames.is_empty() {
        let silence = vec![0.0f32; frame_len];
        frames.push((encode_celt_frame(&silence, channels, packet_len), granule));
    }

    let serial: u32 = 0x5368_7261;
    let mut stream = Vec::new();

    let head_packet = serialize_opus_head(channels as u8, ENCODER_PRE_SKIP, sample_rate);
    let head_page =
        crate::ogg::build_page(crate::ogg::HEADER_FLAG_BOS, 0, serial, 0, &head_packet);
    stream.extend_from_slice(&head_page);

    let tags_packet = serialize_opus_tags();
    let tags_page = crate::ogg::build_page(0, 0, serial, 1, &tags_packet);
    stream.extend_from_slice(&tags_page);

    let last_index = frames.len() - 1;
    for (index, (packet, granule_pos)) in frames.iter().enumerate() {
        let flags = if index == last_index {
            crate::ogg::HEADER_FLAG_EOS
        } else {
            0
        };
        // Audio pages are numbered after the two header pages.
        let page =
            crate::ogg::build_page(flags, *granule_pos, serial, (index + 2) as u32, packet);
        stream.extend_from_slice(&page);
    }
    Ok(stream)
}
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Builds a 19-byte `OpusHead` packet with version 1, zero output gain and
    /// mapping family 0.
    fn make_opus_head(channels: u8, pre_skip: u16, sample_rate: u32) -> Vec<u8> {
        let mut head = Vec::new();
        head.extend_from_slice(b"OpusHead");
        head.push(1);
        head.push(channels);
        head.extend_from_slice(&pre_skip.to_le_bytes());
        head.extend_from_slice(&sample_rate.to_le_bytes());
        head.extend_from_slice(&0i16.to_le_bytes());
        head.push(0);
        head
    }

    // ---- OpusHead parsing ----

    #[test]
    fn parse_valid_opus_head() {
        let parsed = parse_opus_head(&make_opus_head(2, 312, 48000)).unwrap();
        let expected = OpusHead {
            version: 1,
            channel_count: 2,
            pre_skip: 312,
            input_sample_rate: 48000,
            output_gain: 0,
            channel_mapping_family: 0,
        };
        assert_eq!(parsed, expected);
    }

    #[test]
    fn reject_wrong_magic() {
        let mut packet = make_opus_head(2, 312, 48000);
        packet[0..8].copy_from_slice(b"NotOpus!");
        assert!(parse_opus_head(&packet).is_err());
    }

    #[test]
    fn reject_short_data() {
        let truncated = b"OpusHead1234567";
        assert!(parse_opus_head(truncated).is_err());
    }

    #[test]
    fn reject_zero_channels() {
        let packet = make_opus_head(0, 312, 48000);
        assert!(parse_opus_head(&packet).is_err());
    }

    #[test]
    fn opus_head_serde_roundtrip() {
        let original = parse_opus_head(&make_opus_head(2, 312, 44100)).unwrap();
        let json = serde_json::to_string(&original).unwrap();
        let restored: OpusHead = serde_json::from_str(&json).unwrap();
        assert_eq!(original, restored);
    }

    // ---- granule scanning ----

    #[test]
    fn find_last_granule_none_on_empty() {
        assert_eq!(find_last_granule(&[]), None);
    }

    #[test]
    fn find_last_granule_finds_page() {
        let granule: i64 = 96000;
        let mut page = Vec::new();
        page.extend_from_slice(b"OggS");
        page.push(0);
        page.push(0x04);
        page.extend_from_slice(&granule.to_le_bytes());
        page.resize(40, 0);
        assert_eq!(find_last_granule(&page), Some(96000));
    }

    // ---- OpusTags parsing ----

    #[test]
    fn opus_tags_reject_short() {
        let too_short = b"Opus";
        assert!(parse_opus_tags(too_short).is_err());
    }

    #[test]
    fn opus_tags_reject_wrong_magic() {
        let wrong_magic = b"NotOpusTags_data";
        assert!(parse_opus_tags(wrong_magic).is_err());
    }

    #[cfg(not(feature = "tag"))]
    #[test]
    fn opus_tags_no_tag_feature() {
        let mut packet = Vec::new();
        packet.extend_from_slice(b"OpusTags");
        packet.extend_from_slice(&[0; 20]);
        assert!(parse_opus_tags(&packet).is_ok());
    }

    // ---- header serialisation ----

    #[test]
    fn serialize_opus_head_roundtrip() {
        let parsed = parse_opus_head(&serialize_opus_head(2, 312, 48000)).unwrap();
        let expected = OpusHead {
            version: 1,
            channel_count: 2,
            pre_skip: 312,
            input_sample_rate: 48000,
            output_gain: 0,
            channel_mapping_family: 0,
        };
        assert_eq!(parsed, expected);
    }

    #[test]
    fn serialize_opus_tags_valid() {
        let packet = serialize_opus_tags();
        assert!(packet.starts_with(b"OpusTags"));
        assert!(parse_opus_tags(&packet).is_ok());
    }

    // ---- encoder argument validation ----

    #[test]
    fn encode_rejects_wrong_sample_rate() {
        let pcm = vec![0.0f32; 960];
        assert!(encode(&pcm, 44100, 1, 64000).is_err());
    }

    #[test]
    fn encode_rejects_zero_channels() {
        let pcm = vec![0.0f32; 960];
        assert!(encode(&pcm, 48000, 0, 64000).is_err());
    }

    #[test]
    fn encode_rejects_too_many_channels() {
        let pcm = vec![0.0f32; 960 * 3];
        assert!(encode(&pcm, 48000, 3, 64000).is_err());
    }

    #[test]
    fn encode_rejects_invalid_bitrate() {
        let pcm = vec![0.0f32; 960];
        assert!(encode(&pcm, 48000, 1, 1000).is_err());
        assert!(encode(&pcm, 48000, 1, 500000).is_err());
    }

    // ---- encoder output ----

    #[test]
    fn encode_silence_mono() {
        let pcm = vec![0.0f32; 48000];
        let stream = encode(&pcm, 48000, 1, 64000).unwrap();
        assert!(stream.starts_with(b"OggS"));

        let packets = crate::ogg::extract_packets(&stream).unwrap();
        assert!(packets.len() >= 3);

        let head = parse_opus_head(&packets[0]).unwrap();
        assert_eq!(head.channel_count, 1);
        assert_eq!(head.pre_skip, ENCODER_PRE_SKIP);
        assert_eq!(head.input_sample_rate, 48000);
        assert!(packets[1].starts_with(b"OpusTags"));
    }

    #[test]
    fn encode_silence_stereo() {
        let pcm = vec![0.0f32; 48000 * 2];
        let stream = encode(&pcm, 48000, 2, 128000).unwrap();
        let packets = crate::ogg::extract_packets(&stream).unwrap();
        let head = parse_opus_head(&packets[0]).unwrap();
        assert_eq!(head.channel_count, 2);
    }

    #[test]
    fn encode_sine_wave() {
        let pcm: Vec<f32> = (0..48000)
            .map(|i| libm::sinf(2.0 * core::f32::consts::PI * 440.0 * i as f32 / 48000.0))
            .collect();
        let stream = encode(&pcm, 48000, 1, 96000).unwrap();
        assert!(stream.starts_with(b"OggS"));
        assert!(stream.len() > 100);
    }

    #[test]
    fn encode_empty_input() {
        let pcm: Vec<f32> = Vec::new();
        let stream = encode(&pcm, 48000, 1, 64000).unwrap();
        let packets = crate::ogg::extract_packets(&stream).unwrap();
        assert!(packets.len() >= 3);
    }

    #[test]
    fn encode_short_input_padded() {
        let pcm = vec![0.5f32; 100];
        let stream = encode(&pcm, 48000, 1, 64000).unwrap();
        let packets = crate::ogg::extract_packets(&stream).unwrap();
        assert!(packets.len() >= 3);
    }

    #[test]
    fn encode_decode_headers_match() {
        let pcm = vec![0.0f32; 9600];
        let stream = encode(&pcm, 48000, 1, 64000).unwrap();
        let (info, _) = decode(&stream).unwrap();
        assert_eq!(info.format, AudioFormat::Opus);
        assert_eq!(info.sample_rate, 48000);
        assert_eq!(info.channels, 1);
    }

    #[test]
    fn encode_granule_positions_increase() {
        let pcm = vec![0.0f32; 48000 * 2];
        let stream = encode(&pcm, 48000, 1, 64000).unwrap();
        let last = find_last_granule(&stream);
        assert!(last.is_some());
        let g = last.unwrap();
        assert!(g > 48000, "granule should be > 48000, got {g}");
    }

    // ---- range encoder ----

    #[test]
    fn range_encoder_basic() {
        let mut coder = RangeEncoder::new();
        coder.encode_bit(true);
        coder.encode_bit(false);
        coder.encode_uint(3, 8);
        let encoded = coder.finish();
        assert!(!encoded.is_empty());
    }
}