use std::collections::VecDeque;
use std::fs::File;
use std::io::Write;
use std::path::Path;
use std::time::Duration;
use anyhow::{Context, Result, bail};
use reqwest::Client;
use symphonia::core::audio::SampleBuffer;
use symphonia::core::codecs::{CODEC_TYPE_NULL, CodecParameters, DecoderOptions};
use symphonia::core::errors::Error as SymphoniaError;
use symphonia::core::formats::{FormatOptions, FormatReader};
use symphonia::core::io::MediaSourceStream;
use symphonia::core::meta::MetadataOptions;
use symphonia::core::probe::Hint;
use tempfile::{Builder, NamedTempFile};
use tokio::fs;
use url::Url;
use crate::video::is_video_file;
/// Decoded audio together with the metadata describing where it came from.
pub struct PreparedAudio {
/// Human-readable origin: the resolved local path or the source URL.
pub display_name: String,
/// Properties of the source stream and the rate the samples were converted to.
pub metadata: AudioMetadata,
/// Mono samples, resampled to `metadata.target_sample_rate` when the source rate differed.
pub samples: Vec<f32>,
/// Temporary file holding a downloaded input; `None` for local files.
/// Stored here so the temporary file lives as long as this value.
_temp_file: Option<NamedTempFile>,
}
/// Layout of headerless (raw) audio input supplied by the caller.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RawAudioSpec {
/// Source sample rate; the raw decoders in this file reject zero.
pub sample_rate: u32,
/// Number of interleaved channels per frame; the raw decoders in this file reject zero.
pub channels: u16,
/// Binary encoding of each sample.
pub encoding: RawAudioEncoding,
}
/// Sample encodings accepted for raw audio input.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RawAudioEncoding {
/// 32-bit floating point samples in little-endian byte order (4 bytes per sample).
F32Le,
}
/// Incremental decoder that turns chunks of raw audio bytes into mono samples
/// at a target sample rate.
#[derive(Debug)]
pub struct RawAudioStreamDecoder {
/// Layout of the incoming raw byte stream.
spec: RawAudioSpec,
/// Sample rate of the samples handed back to the caller.
target_sample_rate: u32,
/// Bytes received so far that do not yet form a complete frame.
pending_bytes: Vec<u8>,
/// Decoded mono samples at the source rate that are still needed for resampling.
source_buffer: VecDeque<f32>,
/// Fractional read position of the next output sample, measured in source
/// samples from the front of `source_buffer`.
next_source_position: f64,
}
/// Descriptive properties of a decoded audio input.
#[derive(Debug, Clone)]
pub struct AudioMetadata {
/// Sample rate of the source stream, when known.
pub source_sample_rate: Option<u32>,
/// Sample rate the returned samples were converted to.
pub target_sample_rate: u32,
/// Channel count of the source stream, when known.
pub channels: Option<u16>,
/// Length of the source audio, when it could be determined.
pub duration: Option<Duration>,
/// Debug-formatted description of the source codec or raw encoding.
pub codec: String,
}
/// Loads `input` (a local path or an HTTP(S) URL) as mono audio at 16 kHz.
///
/// Convenience wrapper around [`prepare_audio_source_for_rate`].
pub async fn prepare_audio_source(input: &str) -> Result<PreparedAudio> {
    const DEFAULT_TARGET_SAMPLE_RATE: u32 = 16_000;
    prepare_audio_source_for_rate(input, DEFAULT_TARGET_SAMPLE_RATE).await
}
/// Loads `input` as mono audio resampled to `target_sample_rate`.
///
/// Inputs that parse as `http`/`https` URLs are downloaded first; everything
/// else is treated as a local file path.
///
/// # Errors
///
/// Returns an error when the download fails, the path cannot be resolved, or
/// the audio cannot be decoded.
pub async fn prepare_audio_source_for_rate(
    input: &str,
    target_sample_rate: u32,
) -> Result<PreparedAudio> {
    // Only web URLs go down the remote path; other schemes (or unparsable
    // input) fall through to local file handling.
    match Url::parse(input) {
        Ok(remote) if matches!(remote.scheme(), "http" | "https") => {
            return download_remote_audio(remote, target_sample_rate).await;
        }
        _ => {}
    }

    let local_path = resolve_local_audio_path(input).await?;
    let (metadata, samples) = inspect_audio_file(&local_path, target_sample_rate)?;
    let display_name = local_path.display().to_string();
    Ok(PreparedAudio {
        display_name,
        metadata,
        samples,
        _temp_file: None,
    })
}
/// Turns `input` into a canonical absolute path.
///
/// Relative paths are joined onto the current working directory before
/// canonicalization.
///
/// # Errors
///
/// Fails when the working directory cannot be read or the path cannot be
/// canonicalized.
async fn resolve_local_audio_path(input: &str) -> Result<std::path::PathBuf> {
    let requested = Path::new(input);
    let candidate = if requested.is_relative() {
        let working_dir = std::env::current_dir()
            .context("failed to resolve current working directory")?;
        working_dir.join(requested)
    } else {
        requested.to_path_buf()
    };

    let canonical = fs::canonicalize(&candidate).await.with_context(|| {
        format!(
            "failed to resolve audio path `{}` from current directory",
            input
        )
    })?;
    Ok(canonical)
}
/// Downloads the audio at `url` into a temporary file and decodes it.
///
/// The temporary file reuses the extension of the URL's last path segment
/// (falling back to `.audio`) and is stored inside the returned
/// [`PreparedAudio`].
///
/// # Errors
///
/// Fails when the temporary file cannot be created or written, the HTTP
/// request fails or returns an error status, or decoding fails.
async fn download_remote_audio(url: Url, target_sample_rate: u32) -> Result<PreparedAudio> {
    // Derive the temp-file suffix from the remote file name.
    let remote_extension = url
        .path_segments()
        .and_then(|segments| segments.last())
        .and_then(|file_name| Path::new(file_name).extension())
        .and_then(|ext| ext.to_str());
    let suffix = match remote_extension {
        Some(ext) => format!(".{ext}"),
        None => ".audio".to_string(),
    };

    let mut temp_file = Builder::new()
        .prefix("transcribe-cli-")
        .suffix(&suffix)
        .tempfile()
        .context("failed to create temporary audio file")?;

    let client = Client::builder()
        .user_agent("transcribe-cli/0.1.0")
        .build()
        .context("failed to build HTTP client")?;

    let request = client.get(url.clone());
    let response = request
        .send()
        .await
        .with_context(|| format!("failed to download audio from `{url}`"))?;
    let response = response
        .error_for_status()
        .with_context(|| format!("audio download returned an error for `{url}`"))?;

    // The whole body is buffered in memory before being written out.
    let body = response
        .bytes()
        .await
        .with_context(|| format!("failed to read audio body from `{url}`"))?;
    temp_file
        .write_all(body.as_ref())
        .context("failed to save downloaded audio")?;

    let downloaded_path = temp_file.path().to_path_buf();
    let (metadata, samples) = inspect_audio_file(&downloaded_path, target_sample_rate)?;
    Ok(PreparedAudio {
        display_name: url.to_string(),
        metadata,
        samples,
        _temp_file: Some(temp_file),
    })
}
/// Decodes the audio file at `path` into mono samples at `target_sample_rate`.
///
/// The file is probed with Symphonia (using the file extension as a hint), the
/// first track accepted by `select_audio_track` is decoded packet by packet,
/// every frame is averaged across channels and clamped to `[-1.0, 1.0]`, and
/// the result is linearly resampled when the source rate differs from the
/// target rate.
///
/// Returns the stream metadata together with the converted samples.
///
/// # Errors
///
/// Fails when the file cannot be opened or probed, no decodable audio track is
/// found, the decoder cannot be created, a packet cannot be read or decoded
/// (other than skippable decode errors), the stream requires a reset, or the
/// sample rate / channel count never becomes known.
pub(crate) fn inspect_audio_file(
path: &Path,
target_sample_rate: u32,
) -> Result<(AudioMetadata, Vec<f32>)> {
let file = File::open(path)
.with_context(|| format!("failed to open audio file `{}`", path.display()))?;
let source = MediaSourceStream::new(Box::new(file), Default::default());
// Give the probe the file extension, when there is one, as a format hint.
let mut hint = Hint::new();
if let Some(extension) = path.extension().and_then(|ext| ext.to_str()) {
hint.with_extension(extension);
}
let probe = symphonia::default::get_probe()
.format(
&hint,
source,
&FormatOptions::default(),
&MetadataOptions::default(),
)
.map_err(|error| {
anyhow::anyhow!(
"failed to parse audio input `{}`: {error}. This build uses the Rust Symphonia core with support for mp3, wav, flac, ogg, mkv/webm audio, m4a/mp4, aac, caf, aiff, alac, adpcm, and pcm",
path.display()
)
})?;
let mut format = probe.format;
// Video containers get a more specific error message when no audio track is usable.
let (mut codec_params, track_id) = select_audio_track(format.as_ref()).with_context(|| {
if is_video_file(path) {
format!(
"failed to find a supported audio track in video input `{}`",
path.display()
)
} else {
format!(
"failed to find a supported audio track in `{}`",
path.display()
)
}
})?;
let mut decoder = symphonia::default::get_codecs()
.make(&codec_params, &DecoderOptions::default())
.context("failed to create audio decoder")?;
// Either value may be missing from the container; they are filled in from
// the first decoded packet below.
let mut source_sample_rate = codec_params.sample_rate;
let mut source_channels = codec_params.channels.map(|channels| channels.count());
let mut mono_samples = Vec::new();
loop {
let packet = match format.next_packet() {
Ok(packet) => packet,
// An unexpected-EOF I/O error is treated as the normal end of the stream.
Err(SymphoniaError::IoError(error))
if error.kind() == std::io::ErrorKind::UnexpectedEof =>
{
break;
}
Err(SymphoniaError::ResetRequired) => {
bail!("audio stream reset is required and is not supported")
}
Err(error) => {
return Err(error)
.with_context(|| format!("failed to read packet from `{}`", path.display()));
}
};
// Ignore packets that belong to other tracks (e.g. video or secondary audio).
if packet.track_id() != track_id {
continue;
}
let decoded = match decoder.decode(&packet) {
Ok(decoded) => decoded,
// A packet that fails to decode is skipped rather than aborting the whole file.
Err(SymphoniaError::DecodeError(_)) => continue,
Err(error) => {
return Err(error)
.with_context(|| format!("failed to decode `{}`", path.display()));
}
};
let spec = *decoded.spec();
if source_sample_rate.is_none() {
source_sample_rate = Some(spec.rate);
codec_params.with_sample_rate(spec.rate);
}
if source_channels.is_none() {
source_channels = Some(spec.channels.count());
codec_params.with_channels(spec.channels);
}
let source_channels =
source_channels.context("audio stream does not expose channel information")?;
let mut sample_buffer = SampleBuffer::<f32>::new(decoded.capacity() as u64, spec);
sample_buffer.copy_interleaved_ref(decoded);
// Downmix: average the interleaved channels of each frame and clamp to [-1, 1].
// NOTE(review): frames are split by the channel count declared in the codec
// parameters, not the decoded packet's `spec`; confirm these cannot disagree.
for frame in sample_buffer.samples().chunks(source_channels) {
let mono = frame.iter().copied().sum::<f32>() / source_channels as f32;
mono_samples.push(mono.clamp(-1.0, 1.0));
}
}
let source_sample_rate =
source_sample_rate.context("audio stream does not expose a sample rate")?;
// The duration is only available when both the frame count and the sample rate are known.
let duration = codec_params
.n_frames
.zip(codec_params.sample_rate)
.map(|(frames, sample_rate)| Duration::from_secs_f64(frames as f64 / sample_rate as f64));
let resampled_samples = if source_sample_rate == target_sample_rate {
mono_samples
} else {
linear_resample(&mono_samples, source_sample_rate, target_sample_rate)
};
Ok((
extract_metadata(
&codec_params,
duration,
source_sample_rate,
target_sample_rate,
),
resampled_samples,
))
}
/// Reads a headerless audio file described by `spec` and converts it to mono
/// samples at `target_sample_rate`.
///
/// The reported duration is derived from the number of decoded frames and the
/// source sample rate.
///
/// # Errors
///
/// Fails when `spec` has a zero sample rate or channel count, the file cannot
/// be read, or it does not contain at least one complete frame.
pub(crate) fn inspect_raw_audio_file(
    path: &Path,
    target_sample_rate: u32,
    spec: &RawAudioSpec,
) -> Result<(AudioMetadata, Vec<f32>)> {
    let source_rate = spec.sample_rate;
    if source_rate == 0 {
        bail!("raw audio sample rate must be greater than zero");
    }
    if spec.channels == 0 {
        bail!("raw audio channel count must be greater than zero");
    }

    let raw_bytes = std::fs::read(path)
        .with_context(|| format!("failed to read raw audio input `{}`", path.display()))?;
    let mono = match spec.encoding {
        RawAudioEncoding::F32Le => decode_raw_f32le_to_mono(&raw_bytes, spec.channels as usize)?,
    };

    // One mono sample per source frame, so the frame count gives the duration.
    let seconds = mono.len() as f64 / source_rate as f64;
    let duration = Some(Duration::from_secs_f64(seconds));

    let samples = if source_rate != target_sample_rate {
        linear_resample(&mono, source_rate, target_sample_rate)
    } else {
        mono
    };

    let metadata = AudioMetadata {
        source_sample_rate: Some(source_rate),
        target_sample_rate,
        channels: Some(spec.channels),
        duration,
        codec: format!("raw {:?}", spec.encoding),
    };
    Ok((metadata, samples))
}
impl RawAudioStreamDecoder {
    /// Creates a streaming decoder for raw audio laid out as described by `spec`,
    /// producing mono samples at `target_sample_rate`.
    ///
    /// # Errors
    ///
    /// Fails when the source sample rate, the channel count, or the target
    /// sample rate is zero. A zero target rate would make the resampling step
    /// infinite and panic on the first resampled chunk, so it is rejected here.
    pub fn new(spec: RawAudioSpec, target_sample_rate: u32) -> Result<Self> {
        if spec.sample_rate == 0 {
            bail!("raw audio sample rate must be greater than zero");
        }
        if spec.channels == 0 {
            bail!("raw audio channel count must be greater than zero");
        }
        if target_sample_rate == 0 {
            bail!("target sample rate must be greater than zero");
        }
        Ok(Self {
            spec,
            target_sample_rate,
            pending_bytes: Vec::new(),
            source_buffer: VecDeque::new(),
            next_source_position: 0.0,
        })
    }

    /// Feeds another chunk of raw bytes into the decoder.
    ///
    /// Bytes that do not yet complete a frame are kept for the next call.
    /// Returns the resampled mono samples that can already be produced; the
    /// result is empty when no complete frame is available yet.
    ///
    /// # Errors
    ///
    /// Propagates failures from decoding the buffered frames.
    pub fn push_bytes(&mut self, bytes: &[u8]) -> Result<Vec<f32>> {
        self.pending_bytes.extend_from_slice(bytes);
        if !self.decode_pending_frames()? {
            return Ok(Vec::new());
        }
        Ok(self.pull_resampled_samples(false))
    }

    /// Flushes the decoder at the end of the stream.
    ///
    /// Decodes any complete frames still pending, discards a trailing partial
    /// frame, and returns the remaining resampled samples. The decoder is left
    /// empty afterwards.
    ///
    /// # Errors
    ///
    /// Propagates failures from decoding the buffered frames.
    pub fn finish(&mut self) -> Result<Vec<f32>> {
        self.decode_pending_frames()?;
        // Whatever is left cannot form a whole frame any more.
        self.pending_bytes.clear();
        Ok(self.pull_resampled_samples(true))
    }

    /// Decodes every complete frame in `pending_bytes` into `source_buffer`,
    /// leaving an incomplete trailing frame in place.
    ///
    /// Returns `true` when at least one frame was decoded.
    fn decode_pending_frames(&mut self) -> Result<bool> {
        let channels = usize::from(self.spec.channels);
        // `new` guarantees a non-zero channel count, so `frame_bytes` is never zero.
        let frame_bytes = channels * 4;
        let usable_len = self.pending_bytes.len() - (self.pending_bytes.len() % frame_bytes);
        if usable_len == 0 {
            return Ok(false);
        }
        let mono_samples = match self.spec.encoding {
            RawAudioEncoding::F32Le => {
                decode_raw_f32le_to_mono(&self.pending_bytes[..usable_len], channels)?
            }
        };
        self.pending_bytes.drain(..usable_len);
        self.source_buffer.extend(mono_samples);
        Ok(true)
    }

    /// Produces as many output samples as the buffered source samples allow.
    ///
    /// Uses linear interpolation between neighbouring source samples. Without
    /// `finalize`, the last buffered sample is held back because its right-hand
    /// neighbour has not arrived yet. With `finalize`, positions that still fall
    /// on the last sample repeat it, and the internal state is reset afterwards.
    fn pull_resampled_samples(&mut self, finalize: bool) -> Vec<f32> {
        if self.source_buffer.is_empty() {
            if finalize {
                // End of stream: forget any read position carried over from earlier chunks.
                self.next_source_position = 0.0;
            }
            return Vec::new();
        }
        if self.spec.sample_rate == self.target_sample_rate {
            let samples = self.source_buffer.drain(..).collect::<Vec<_>>();
            self.next_source_position = 0.0;
            return samples;
        }

        let step = f64::from(self.spec.sample_rate) / f64::from(self.target_sample_rate);
        let buffered = self.source_buffer.len();
        let mut output = Vec::new();
        loop {
            let left_index = self.next_source_position.floor() as usize;
            let right_index = left_index + 1;
            if right_index >= buffered {
                // No right-hand neighbour is available for interpolation.
                if finalize && self.next_source_position < buffered as f64 {
                    if let Some(&last) = self.source_buffer.back() {
                        output.push(last);
                        self.next_source_position += step;
                        continue;
                    }
                }
                break;
            }
            let fraction = (self.next_source_position - left_index as f64) as f32;
            let left = self.source_buffer[left_index];
            let right = self.source_buffer[right_index];
            output.push(left + (right - left) * fraction);
            self.next_source_position += step;
        }

        if finalize {
            self.source_buffer.clear();
            self.next_source_position = 0.0;
            return output;
        }

        // Discard the source samples that lie entirely before the read position.
        // When downsampling, the position can jump past the end of the buffer;
        // only the samples actually removed may be subtracted, otherwise the
        // remaining distance is lost and later chunks are read too early.
        let drop_count = (self.next_source_position.floor() as usize).min(buffered);
        self.source_buffer.drain(..drop_count);
        self.next_source_position -= drop_count as f64;
        output
    }
}
/// Picks the first track of `format` for which the default Symphonia codec
/// registry can build a decoder.
///
/// # Errors
///
/// Fails when no track is decodable.
fn select_audio_track(format: &dyn FormatReader) -> Result<(CodecParameters, u32)> {
    let options = DecoderOptions::default();
    let can_decode = |params: &CodecParameters| {
        if params.codec == CODEC_TYPE_NULL {
            return false;
        }
        // A track counts as supported only if a decoder can actually be built for it.
        symphonia::default::get_codecs()
            .make(params, &options)
            .is_ok()
    };
    select_audio_track_with(format, can_decode)
}
fn select_audio_track_with<F>(
format: &dyn FormatReader,
is_supported: F,
) -> Result<(CodecParameters, u32)>
where
F: Fn(&CodecParameters) -> bool,
{
if let Some(track) = format
.tracks()
.iter()
.find(|track| is_supported(&track.codec_params))
{
return Ok((track.codec_params.clone(), track.id));
}
bail!("media input does not contain a decodable audio track")
}
/// Builds the [`AudioMetadata`] for a decoded stream from its codec parameters
/// and the rates involved in the conversion.
fn extract_metadata(
    codec: &CodecParameters,
    duration: Option<Duration>,
    source_sample_rate: u32,
    target_sample_rate: u32,
) -> AudioMetadata {
    let channel_count = codec.channels.map(|layout| layout.count() as u16);
    let codec_label = format!("{:?}", codec.codec);
    AudioMetadata {
        source_sample_rate: Some(source_sample_rate),
        target_sample_rate,
        channels: channel_count,
        duration,
        codec: codec_label,
    }
}
/// Converts `samples` from `source_rate` to `target_rate` using linear
/// interpolation between neighbouring samples.
///
/// Empty input and equal rates return a copy of the input. The output length
/// is the input length scaled by the rate ratio, rounded to the nearest
/// integer; positions on the last sample reuse it as their right neighbour.
fn linear_resample(samples: &[f32], source_rate: u32, target_rate: u32) -> Vec<f32> {
    if source_rate == target_rate || samples.is_empty() {
        return samples.to_vec();
    }

    let last_index = samples.len() - 1;
    let ratio = target_rate as f64 / source_rate as f64;
    let output_len = ((samples.len() as f64) * ratio).round() as usize;

    (0..output_len)
        .map(|output_index| {
            let position = output_index as f64 / ratio;
            let lower = position.floor() as usize;
            // Clamp the upper neighbour so the final sample interpolates with itself.
            let upper = (lower + 1).min(last_index);
            let weight = (position - lower as f64) as f32;
            let start = samples[lower];
            let end = samples[upper];
            start + (end - start) * weight
        })
        .collect()
}
/// Decodes interleaved little-endian `f32` samples into mono by averaging the
/// `channels` samples of each frame and clamping the result to `[-1.0, 1.0]`.
///
/// Trailing bytes that do not form a complete frame are ignored.
///
/// # Errors
///
/// Fails when `bytes` holds less than one sample or less than one complete
/// frame.
fn decode_raw_f32le_to_mono(bytes: &[u8], channels: usize) -> Result<Vec<f32>> {
    const SAMPLE_BYTES: usize = 4;

    let whole_samples = bytes.len() / SAMPLE_BYTES;
    if whole_samples == 0 {
        bail!("not enough buffered raw audio data yet");
    }
    let whole_frames = whole_samples / channels;
    if whole_frames == 0 {
        bail!("not enough buffered raw audio frames yet");
    }

    let frame_bytes = channels * SAMPLE_BYTES;
    bytes[..whole_frames * frame_bytes]
        .chunks_exact(frame_bytes)
        .map(|frame| -> Result<f32> {
            // Accumulate from 0.0 in channel order.
            let mut total = 0.0f32;
            for raw in frame.chunks_exact(SAMPLE_BYTES) {
                let raw: [u8; 4] = raw
                    .try_into()
                    .context("failed to read raw float32 sample")?;
                total += f32::from_le_bytes(raw);
            }
            Ok((total / channels as f32).clamp(-1.0, 1.0))
        })
        .collect()
}
// Unit tests for track selection and raw float32 decoding.
#[cfg(test)]
mod tests {
use std::io::Write;
use symphonia::core::codecs::{CODEC_TYPE_NULL, CodecParameters, decl_codec_type};
use symphonia::core::formats::{
Cue, FormatOptions, FormatReader, Packet, SeekMode, SeekTo, Track,
};
use symphonia::core::io::MediaSourceStream;
use symphonia::core::meta::{Metadata, MetadataLog};
use tempfile::NamedTempFile;
use super::{RawAudioEncoding, RawAudioSpec, inspect_raw_audio_file, select_audio_track_with};
/// Minimal `FormatReader` stub that only exposes a fixed track list.
/// Every operation other than `tracks`, `cues` and `metadata` is unreachable
/// because the tests below never read packets or seek.
struct TestFormatReader {
tracks: Vec<Track>,
metadata: MetadataLog,
}
impl FormatReader for TestFormatReader {
fn try_new(
_source: MediaSourceStream,
_options: &FormatOptions,
) -> symphonia::core::errors::Result<Self>
where
Self: Sized,
{
unreachable!()
}
fn cues(&self) -> &[Cue] {
&[]
}
fn metadata(&mut self) -> Metadata<'_> {
self.metadata.metadata()
}
fn seek(
&mut self,
_mode: SeekMode,
_to: SeekTo,
) -> symphonia::core::errors::Result<symphonia::core::formats::SeekedTo> {
unreachable!()
}
fn next_packet(&mut self) -> symphonia::core::errors::Result<Packet> {
unreachable!()
}
fn tracks(&self) -> &[Track] {
&self.tracks
}
fn into_inner(self: Box<Self>) -> MediaSourceStream {
unreachable!()
}
}
// The selector must skip a leading unsupported track and return the first
// one the predicate accepts.
#[test]
fn selects_first_decodable_audio_track_instead_of_default_track() {
let fake_audio_codec = decl_codec_type(b"aud");
let video_track = Track::new(
1,
CodecParameters::new()
.for_codec(decl_codec_type(b"vid"))
.clone(),
);
let audio_track = Track::new(
2,
CodecParameters::new().for_codec(fake_audio_codec).clone(),
);
let format = TestFormatReader {
tracks: vec![video_track, audio_track],
metadata: MetadataLog::default(),
};
let (codec_params, track_id) = select_audio_track_with(&format, |codec_params| {
codec_params.codec == fake_audio_codec
})
.expect("audio track");
assert_eq!(track_id, 2);
assert_eq!(codec_params.codec, fake_audio_codec);
}
// With no track accepted by the predicate, selection must fail.
#[test]
fn rejects_inputs_without_decodable_audio_tracks() {
let unknown_track =
Track::new(1, CodecParameters::new().for_codec(CODEC_TYPE_NULL).clone());
let format = TestFormatReader {
tracks: vec![unknown_track],
metadata: MetadataLog::default(),
};
assert!(select_audio_track_with(&format, |_| false).is_err());
}
// Raw mono float32 input at the target rate passes through unchanged, except
// that out-of-range samples (1.5 here) are clamped to 1.0.
#[test]
fn decodes_raw_f32le_mono_input() {
let mut temp = NamedTempFile::new().expect("temp file");
for sample in [0.25f32, -0.5, 1.5] {
temp.write_all(&sample.to_le_bytes()).expect("write sample");
}
let (metadata, samples) = inspect_raw_audio_file(
temp.path(),
16_000,
&RawAudioSpec {
sample_rate: 16_000,
channels: 1,
encoding: RawAudioEncoding::F32Le,
},
)
.expect("raw audio");
assert_eq!(metadata.source_sample_rate, Some(16_000));
assert_eq!(metadata.channels, Some(1));
assert_eq!(samples, vec![0.25, -0.5, 1.0]);
}
}