use std::{io, time::Duration};
use rodio::source::SeekError;
use symphonia::{
core::{
audio::SampleBuffer,
codecs::{CodecParameters, CodecRegistry, DecoderOptions},
errors::Error as SymphoniaError,
formats::{FormatOptions, FormatReader, SeekMode, SeekTo},
io::{MediaSourceStream, MediaSourceStreamOptions},
meta::{MetadataOptions, StandardTagKey, Value},
probe::{Hint, Probe},
},
default::{
codecs::{AacDecoder, FlacDecoder, MpaDecoder, PcmDecoder},
formats::{AdtsReader, FlacReader, IsoMp4Reader, MpaReader, WavReader},
},
};
use crate::{
audio_file::{AudioFile, BUFFER_LEN},
error::{Error, Result},
normalize::{self, Normalize},
player::SampleFormat,
protocol::Codec,
track::{Track, DEFAULT_SAMPLE_RATE},
util::ToF32,
};
/// Audio decoder that pulls encoded packets from a Symphonia format reader,
/// decodes them, and hands out the resulting interleaved samples one by one.
pub struct Decoder {
    // --- Symphonia pipeline ---
    /// Container reader producing encoded packets.
    demuxer: Box<dyn FormatReader>,
    /// Codec decoder turning packets into PCM audio.
    decoder: Box<dyn symphonia::core::codecs::Decoder>,
    /// Track whose packets are decoded; packets for any other track end decoding.
    track_id: u32,

    // --- Decoded sample state ---
    /// Interleaved samples of the most recently decoded packet, if any.
    buffer: Option<SampleBuffer<SampleFormat>>,
    /// Index of the next sample to hand out from `buffer`.
    position: usize,

    // --- Stream description ---
    /// Number of interleaved channels reported to consumers.
    channels: u16,
    /// Sample rate in Hz reported to consumers.
    sample_rate: u32,
    /// Total playing time, when the codec parameters allow computing it.
    total_duration: Option<Duration>,
    /// Upper bound on the number of samples, used as the iterator size hint.
    total_samples: Option<usize>,
    /// Codec-specific maximum frame length, when the track's codec is known.
    max_frame_length: Option<usize>,
}

/// Maximum number of extra attempts `get_next_packet` makes after a packet is
/// discarded or the decoder is reset, before giving up.
const MAX_RETRIES: usize = 3;
impl Decoder {
pub fn new(track: &Track, file: AudioFile) -> Result<Self> {
let buffer_len = usize::max(64 * 1024, BUFFER_LEN * 2);
let stream =
MediaSourceStream::new(Box::new(file), MediaSourceStreamOptions { buffer_len });
let mut hint = Hint::new();
let mut codecs = CodecRegistry::default();
let mut probes = Probe::default();
let (codecs, probe) = if let Some(codec) = track.codec() {
match codec {
Codec::ADTS => {
codecs.register_all::<AacDecoder>();
probes.register_all::<AdtsReader>();
}
Codec::FLAC => {
codecs.register_all::<FlacDecoder>();
probes.register_all::<FlacReader>();
}
Codec::MP3 => {
codecs.register_all::<MpaDecoder>();
probes.register_all::<MpaReader>();
}
Codec::MP4 => {
codecs.register_all::<AacDecoder>();
probes.register_all::<IsoMp4Reader>();
}
Codec::WAV => {
codecs.register_all::<PcmDecoder>();
probes.register_all::<WavReader>();
}
}
hint.with_extension(codec.extension());
hint.mime_type(codec.mime_type());
(&codecs, &probes)
} else {
(
symphonia::default::get_codecs(),
symphonia::default::get_probe(),
)
};
let demuxer = probe
.format(
&hint,
stream,
&FormatOptions {
enable_gapless: true,
..Default::default()
},
&MetadataOptions::default(),
)?
.format;
let default_track = demuxer
.default_track()
.ok_or_else(|| Error::not_found("default track not found"))?;
let track_id = default_track.id;
let codec_params = &default_track.codec_params;
let decoder = codecs.make(codec_params, &DecoderOptions::default())?;
let codec_params = decoder.codec_params();
let total_duration = Self::calc_total_duration(codec_params);
let channels = Self::calc_channels(codec_params).unwrap_or(track.typ().default_channels());
let sample_rate = Self::calc_sample_rate(codec_params);
let max_frame_length = track
.codec()
.map(|codec| codec.max_frame_length(sample_rate, channels));
let total_samples = Self::calc_total_samples(codec_params, max_frame_length);
Ok(Self {
demuxer,
decoder,
buffer: None,
position: 0,
track_id,
channels,
sample_rate,
total_duration,
total_samples,
max_frame_length,
})
}
#[must_use]
pub fn normalize(
self,
ratio: f32,
threshold: f32,
knee_width: f32,
attack: Duration,
release: Duration,
) -> Normalize<Self>
where
Self: Sized,
{
normalize::normalize(self, ratio, threshold, knee_width, attack, release)
}
pub fn replay_gain(&mut self) -> Option<f32> {
self.demuxer
.metadata()
.skip_to_latest()
.and_then(|metadata| {
for tag in metadata.tags() {
if tag
.std_key
.is_some_and(|key| key == StandardTagKey::ReplayGainTrackGain)
{
if let Value::Float(gain) = tag.value {
return Some(gain.to_f32_lossy());
}
}
}
None
})
}
#[must_use]
pub fn bits_per_sample(&self) -> Option<u32> {
self.decoder.codec_params().bits_per_sample
}
#[must_use]
fn calc_channels(codec_params: &CodecParameters) -> Option<u16> {
codec_params
.channels
.map(|channels| channels.count().try_into().expect("channel count overflow"))
}
#[must_use]
fn calc_sample_rate(codec_params: &CodecParameters) -> u32 {
codec_params.sample_rate.unwrap_or(DEFAULT_SAMPLE_RATE)
}
#[must_use]
fn calc_total_samples(
codec_params: &CodecParameters,
max_frame_length: Option<usize>,
) -> Option<usize> {
if let (Some(n_frames), Some(max_frame_length)) = (codec_params.n_frames, max_frame_length)
{
usize::try_from(n_frames)
.ok()
.and_then(|frames| frames.checked_mul(max_frame_length))
} else {
None
}
}
#[must_use]
fn calc_total_duration(codec_params: &CodecParameters) -> Option<Duration> {
if let (Some(time_base), Some(frames)) = (codec_params.time_base, codec_params.n_frames) {
Some(time_base.calc_time(frames).into())
} else {
None
}
}
fn reload_spec(&mut self) {
let codec_params = self.decoder.codec_params();
self.sample_rate = Self::calc_sample_rate(codec_params);
self.total_samples = Self::calc_total_samples(codec_params, self.max_frame_length);
self.total_duration = Self::calc_total_duration(codec_params);
if let Some(channels) = Self::calc_channels(codec_params) {
self.channels = channels;
}
self.buffer = None;
debug!(
"decoder reloaded with sample rate: {} kHz; channels: {}",
self.sample_rate, self.channels,
);
}
fn get_next_packet(&mut self) -> Result<u64> {
let mut discarded = 0;
loop {
if discarded > MAX_RETRIES {
break Err(Error::cancelled("discarded too many packets, giving up"));
}
if discarded > 0 {
if let Some(buffer) = self.buffer.as_mut() {
buffer.clear();
}
}
discarded += 1;
match self.demuxer.next_packet() {
Ok(packet) => {
if packet.track_id() != self.track_id {
return Err(io::Error::new(
io::ErrorKind::UnexpectedEof,
"track id mismatch",
)
.into());
}
let decoded = match self.decoder.decode(&packet) {
Ok(decoded) => decoded,
Err(SymphoniaError::DecodeError(e)) => {
error!("discarding malformed packet: {e}");
continue;
}
Err(SymphoniaError::IoError(e)) => {
error!("discarding unreadable packet: {e}");
continue;
}
Err(SymphoniaError::ResetRequired) => {
self.decoder.reset();
self.reload_spec();
continue;
}
Err(e) => {
break Err(e.into());
}
};
let buffer = match self.buffer.as_mut() {
Some(buffer) => buffer,
None => {
self.buffer.insert(SampleBuffer::new(
decoded.capacity() as u64,
*decoded.spec(),
))
}
};
buffer.copy_interleaved_ref(decoded);
self.position = 0;
break Ok(packet.dur());
}
Err(SymphoniaError::ResetRequired) => {
trace!("re-creating decoder");
let track = self
.demuxer
.default_track()
.ok_or_else(|| Error::not_found("default track not found"))?;
let codecs = symphonia::default::get_codecs();
self.decoder = codecs.make(&track.codec_params, &DecoderOptions::default())?;
self.reload_spec();
continue;
}
Err(e) => {
break Err(e.into());
}
}
}
}
#[expect(clippy::cast_possible_truncation)]
#[expect(clippy::cast_sign_loss)]
fn ts_to_samples(&self, ts: u64) -> Option<usize> {
if ts == 0 {
Some(0)
} else {
self.decoder.codec_params().time_base.map(|time_base| {
(Duration::from(time_base.calc_time(ts)).as_secs_f32()
* self.sample_rate.to_f32_lossy()
* f32::from(self.channels))
.ceil() as usize
})
}
}
}
impl rodio::Source for Decoder {
    /// Returns the number of samples in the current decoded packet buffer.
    ///
    /// This is the length of the whole buffer, not the number of samples left in
    /// it. Returns `None` before any packet has been decoded and after decoding
    /// stopped.
    fn current_frame_len(&self) -> Option<usize> {
        self.buffer.as_ref().map(SampleBuffer::len)
    }

    /// Returns the number of interleaved channels.
    fn channels(&self) -> u16 {
        self.channels
    }

    /// Returns the sample rate in Hz.
    fn sample_rate(&self) -> u32 {
        self.sample_rate
    }

    /// Returns the total duration, when known.
    fn total_duration(&self) -> Option<Duration> {
        self.total_duration
    }

    /// Seeks to `pos`, clamped to the total duration when that is known.
    ///
    /// The demuxer seeks accurately on the decoder's track and the decoder is
    /// reset. Decoded samples between the landing point and the requested
    /// timestamp are then discarded. The channel the caller was about to read is
    /// preserved, so the interleaved output stays aligned.
    ///
    /// # Errors
    ///
    /// Returns [`SeekError::Other`] if the demuxer fails to seek.
    fn try_seek(&mut self, pos: Duration) -> std::result::Result<(), SeekError> {
        let target = match self.total_duration {
            Some(total_duration) => pos.min(total_duration),
            None => pos,
        };

        // Guard the modulo operations below against a zero channel count.
        let channels = usize::from(self.channels).max(1);
        // Channel the caller will read next. Every buffer starts on channel 0.
        let active_channel = self.position % channels;

        let seeked_to = self
            .demuxer
            .seek(
                SeekMode::Accurate,
                SeekTo::Time {
                    track_id: Some(self.track_id),
                    time: target.into(),
                },
            )
            .map_err(|e| SeekError::Other(Box::new(e)))?;
        self.decoder.reset();

        // Mark the stale buffer as fully consumed, so the next sample comes from a
        // freshly decoded packet. Unlike a `usize::MAX` sentinel, this keeps
        // `position % channels` meaningful if another seek follows immediately.
        self.position = self.buffer.as_ref().map_or(0, SampleBuffer::len);

        // Whole frames between where the demuxer landed and where we asked to be.
        let time_gap = seeked_to.required_ts.saturating_sub(seeked_to.actual_ts);
        let samples_to_skip = self
            .ts_to_samples(time_gap)
            .map_or(0, |samples| samples - samples % channels);

        for _ in 0..samples_to_skip.saturating_add(active_channel) {
            // Stop early at end of stream instead of re-polling the demuxer.
            if self.next().is_none() {
                break;
            }
        }
        Ok(())
    }
}
impl Iterator for Decoder {
    type Item = SampleFormat;

    /// Returns the next interleaved sample, decoding further packets as needed.
    ///
    /// Packets that decode to zero samples are skipped in a loop. This avoids the
    /// unbounded recursion of the previous implementation. Once fetching a packet
    /// fails, the buffer is dropped and `None` is returned. The error is logged
    /// unless it is an `UnexpectedEof` I/O error.
    fn next(&mut self) -> Option<Self::Item> {
        loop {
            let next_sample = self
                .buffer
                .as_ref()
                .and_then(|buffer| buffer.samples().get(self.position).copied());
            if let Some(sample) = next_sample {
                self.position += 1;
                return Some(sample);
            }

            // Current buffer is exhausted (or absent); decode another packet.
            if let Err(e) = self.get_next_packet() {
                self.buffer = None;
                if e.downcast::<io::Error>()
                    .is_none_or(|e| e.kind() != std::io::ErrorKind::UnexpectedEof)
                {
                    error!("{e}");
                }
                return None;
            }
        }
    }

    /// Lower bound 0. The upper bound is the estimated total sample count, when known.
    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        (0, self.total_samples)
    }
}