use std::num::NonZeroU32;
use audio_samples::{AudioSample, AudioSamples, ConvertTo, I24};
use ndarray::{Array1, Array2};
use crate::error::{AudioIOError, AudioIOResult};
/// Decoded integer audio samples stored as one `Vec<i32>` per channel, together
/// with the bit depth of the stored integers and the sample rate.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DecodedAudio {
    /// Per-channel sample buffers; `channels[0]` is the first channel.
    channels: Vec<Vec<i32>>,
    /// Number of significant bits carried by each stored sample.
    bits_per_sample: u8,
    /// Sample rate as supplied at construction (presumably Hz — TODO confirm).
    sample_rate: u32,
}
impl DecodedAudio {
pub fn new(channels: Vec<Vec<i32>>, bits_per_sample: u8, sample_rate: u32) -> Self {
DecodedAudio {
channels,
bits_per_sample,
sample_rate,
}
}
pub fn num_channels(&self) -> usize {
self.channels.len()
}
pub fn samples_per_channel(&self) -> usize {
self.channels.first().map(|c| c.len()).unwrap_or(0)
}
pub fn total_samples(&self) -> usize {
self.num_channels() * self.samples_per_channel()
}
pub fn bits_per_sample(&self) -> u8 {
self.bits_per_sample
}
pub fn sample_rate(&self) -> u32 {
self.sample_rate
}
pub fn read_samples<'a, T>(&self, sample_rate: NonZeroU32) -> AudioIOResult<AudioSamples<'a, T>>
where
T: AudioSample + 'static,
i16: ConvertTo<T>,
I24: ConvertTo<T>,
i32: ConvertTo<T>,
f32: ConvertTo<T>,
f64: ConvertTo<T>,
{
let num_channels = self.num_channels();
let samples_per_channel = self.samples_per_channel();
if num_channels == 0 || samples_per_channel == 0 {
return Err(AudioIOError::corrupted_data_simple(
"Empty audio data",
"No channels or samples",
));
}
let converted_channels: Vec<Vec<T>> = self
.channels
.iter()
.map(|ch| self.convert_channel_samples::<T>(ch))
.collect::<AudioIOResult<_>>()?;
if num_channels == 1 {
let data = Array1::from(converted_channels.into_iter().next().unwrap());
Ok(AudioSamples::new_mono(data, sample_rate))
} else {
let flat: Vec<T> = converted_channels.into_iter().flatten().collect();
let arr =
Array2::from_shape_vec((num_channels, samples_per_channel), flat).map_err(|e| {
AudioIOError::corrupted_data_simple("Array shape error", e.to_string())
})?;
Ok(AudioSamples::new_multi_channel(arr, sample_rate))
}
}
pub fn read_samples_planar<T>(&self) -> AudioIOResult<Vec<T>>
where
T: AudioSample + 'static,
i16: ConvertTo<T>,
I24: ConvertTo<T>,
i32: ConvertTo<T>,
f32: ConvertTo<T>,
f64: ConvertTo<T>,
{
let mut result = Vec::with_capacity(self.total_samples());
for channel in &self.channels {
let converted = self.convert_channel_samples::<T>(channel)?;
result.extend(converted);
}
Ok(result)
}
pub fn read_channel_samples<T>(&self, channel: usize) -> AudioIOResult<Vec<T>>
where
T: AudioSample + 'static,
i16: ConvertTo<T>,
I24: ConvertTo<T>,
i32: ConvertTo<T>,
f32: ConvertTo<T>,
f64: ConvertTo<T>,
{
let samples = self.channels.get(channel).ok_or_else(|| {
AudioIOError::corrupted_data_simple(
"Channel index out of bounds",
format!(
"Requested channel {}, have {}",
channel,
self.num_channels()
),
)
})?;
self.convert_channel_samples::<T>(samples)
}
pub fn read_samples_interleaved<T>(&self) -> AudioIOResult<Vec<T>>
where
T: AudioSample + 'static,
i16: ConvertTo<T>,
I24: ConvertTo<T>,
i32: ConvertTo<T>,
f32: ConvertTo<T>,
f64: ConvertTo<T>,
{
let num_channels = self.num_channels();
let samples_per_channel = self.samples_per_channel();
if num_channels == 0 || samples_per_channel == 0 {
return Ok(Vec::new());
}
let mut result = Vec::with_capacity(num_channels * samples_per_channel);
let converted_channels: Vec<Vec<T>> = self
.channels
.iter()
.map(|ch| self.convert_channel_samples::<T>(ch))
.collect::<AudioIOResult<_>>()?;
for i in 0..samples_per_channel {
for ch in &converted_channels {
result.push(ch[i]);
}
}
Ok(result)
}
fn convert_channel_samples<T>(&self, samples: &[i32]) -> AudioIOResult<Vec<T>>
where
T: AudioSample + 'static,
i16: ConvertTo<T>,
I24: ConvertTo<T>,
i32: ConvertTo<T>,
f32: ConvertTo<T>,
f64: ConvertTo<T>,
{
match self.bits_per_sample {
1..=8 => {
let shift = 16 - self.bits_per_sample;
Ok(samples
.iter()
.map(|&s| {
let scaled = (s << shift) as i16;
T::convert_from(scaled)
})
.collect())
}
9..=16 => {
Ok(samples.iter().map(|&s| T::convert_from(s as i16)).collect())
}
17..=24 => {
Ok(samples
.iter()
.map(|&s| T::convert_from(I24::wrapping_from_i32(s)))
.collect())
}
25..=32 => {
Ok(samples.iter().map(|&s| T::convert_from(s)).collect())
}
_ => Err(AudioIOError::corrupted_data_simple(
"Invalid bits per sample",
format!("{} bits", self.bits_per_sample),
)),
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_decoded_audio_basic() {
        let channels = vec![vec![1000i32, 2000, 3000], vec![-1000i32, -2000, -3000]];
        let audio = DecodedAudio::new(channels, 16, 44100);
        assert_eq!(audio.num_channels(), 2);
        assert_eq!(audio.samples_per_channel(), 3);
        assert_eq!(audio.total_samples(), 6);
        assert_eq!(audio.bits_per_sample(), 16);
        assert_eq!(audio.sample_rate(), 44100);
    }

    #[test]
    fn test_read_samples_planar_i16() {
        let channels = vec![vec![1000i32, 2000], vec![-1000i32, -2000]];
        let audio = DecodedAudio::new(channels, 16, 44100);
        let samples: Vec<i16> = audio.read_samples_planar().unwrap();
        assert_eq!(samples, vec![1000i16, 2000, -1000, -2000]);
    }

    #[test]
    fn test_read_samples_interleaved() {
        let channels = vec![vec![100i32, 200], vec![300i32, 400]];
        let audio = DecodedAudio::new(channels, 16, 44100);
        let samples: Vec<i16> = audio.read_samples_interleaved().unwrap();
        assert_eq!(samples, vec![100i16, 300, 200, 400]);
    }

    #[test]
    fn test_read_samples_interleaved_empty() {
        let audio = DecodedAudio::new(Vec::new(), 16, 44100);
        let samples: Vec<i16> = audio.read_samples_interleaved().unwrap();
        assert!(samples.is_empty());
    }

    #[test]
    fn test_read_channel_samples() {
        let channels = vec![vec![100i32, 200], vec![300i32, 400]];
        let audio = DecodedAudio::new(channels, 16, 44100);
        let ch0: Vec<i16> = audio.read_channel_samples(0).unwrap();
        let ch1: Vec<i16> = audio.read_channel_samples(1).unwrap();
        assert_eq!(ch0, vec![100i16, 200]);
        assert_eq!(ch1, vec![300i16, 400]);
    }

    #[test]
    fn test_read_channel_samples_out_of_bounds() {
        let audio = DecodedAudio::new(vec![vec![1i32, 2]], 16, 44100);
        assert!(audio.read_channel_samples::<i16>(1).is_err());
    }

    #[test]
    fn test_8bit_samples_are_scaled_to_i16_range() {
        // 8-bit samples are left-justified into i16 (shifted by 8 bits).
        let channels = vec![vec![1i32, -1, 127, -128]];
        let audio = DecodedAudio::new(channels, 8, 8000);
        let samples: Vec<i16> = audio.read_samples_planar().unwrap();
        assert_eq!(samples, vec![256i16, -256, 32512, -32768]);
    }

    #[test]
    fn test_invalid_bits_per_sample_is_error() {
        let audio = DecodedAudio::new(vec![vec![1i32]], 0, 44100);
        assert!(audio.read_samples_planar::<i16>().is_err());
        let audio = DecodedAudio::new(vec![vec![1i32]], 33, 44100);
        assert!(audio.read_samples_planar::<i16>().is_err());
    }

    #[test]
    fn test_24bit_conversion() {
        let channels = vec![vec![0x7FFFFFi32, -0x800000i32]];
        let audio = DecodedAudio::new(channels, 24, 48000);
        let samples: Vec<I24> = audio.read_samples_planar().unwrap();
        assert_eq!(samples.len(), 2);
    }

    #[test]
    fn test_read_samples_to_audio_samples() {
        let channels = vec![vec![100i32, 200], vec![300i32, 400]];
        let audio = DecodedAudio::new(channels, 16, 44100);
        let sample_rate = NonZeroU32::new(44100).unwrap();
        let samples: AudioSamples<'static, i16> = audio.read_samples(sample_rate).unwrap();
        assert_eq!(samples.num_channels(), 2);
        assert_eq!(samples.samples_per_channel(), 2);
        assert_eq!(samples.sample_rate(), sample_rate);
    }

    #[test]
    fn test_read_samples_mono() {
        let audio = DecodedAudio::new(vec![vec![5i32, 6, 7]], 16, 22050);
        let sample_rate = NonZeroU32::new(22050).unwrap();
        let samples: AudioSamples<'static, i16> = audio.read_samples(sample_rate).unwrap();
        assert_eq!(samples.num_channels(), 1);
        assert_eq!(samples.samples_per_channel(), 3);
        assert_eq!(samples.sample_rate(), sample_rate);
    }

    #[test]
    fn test_read_samples_empty_is_error() {
        let audio = DecodedAudio::new(Vec::new(), 16, 44100);
        let sample_rate = NonZeroU32::new(44100).unwrap();
        let result: AudioIOResult<AudioSamples<'static, i16>> = audio.read_samples(sample_rate);
        assert!(result.is_err());
    }
}