use serde::{Deserialize, Serialize};
use thiserror::Error;
/// Encoding of an audio stream.
///
/// The raw-sample variants carry their own sample rate and channel count, while
/// [`AudioFormat::Wav`] carries no parameters at all.
///
/// NOTE(review): keep the variant order stable — index-based serde formats may encode the
/// variant position rather than its name.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Deserialize, Serialize)]
pub enum AudioFormat {
    /// Raw PCM with 2-byte samples.
    Pcm16 { sample_rate: u32, channels: u32 },
    /// Raw PCM with 4-byte samples.
    Pcm32 { sample_rate: u32, channels: u32 },
    /// Raw floating-point samples, 4 bytes each.
    Float32 { sample_rate: u32, channels: u32 },
    /// A WAV (RIFF/WAVE) container; no stream parameters are stored in the variant.
    Wav,
}
impl AudioFormat {
    /// Sample rate of a raw-sample format (for example `16000`).
    ///
    /// Returns `None` for [`AudioFormat::Wav`], which carries no parameters of its own.
    pub fn sample_rate(&self) -> Option<u32> {
        match self {
            AudioFormat::Pcm16 { sample_rate, .. }
            | AudioFormat::Pcm32 { sample_rate, .. }
            | AudioFormat::Float32 { sample_rate, .. } => Some(*sample_rate),
            AudioFormat::Wav => None,
        }
    }

    /// Channel count of a raw-sample format.
    ///
    /// Returns `None` for [`AudioFormat::Wav`], which carries no parameters of its own.
    pub fn channels(&self) -> Option<u32> {
        match self {
            AudioFormat::Pcm16 { channels, .. }
            | AudioFormat::Pcm32 { channels, .. }
            | AudioFormat::Float32 { channels, .. } => Some(*channels),
            AudioFormat::Wav => None,
        }
    }

    /// Size in bytes of one sample: 2 for `Pcm16`, 4 for `Pcm32` and `Float32`.
    ///
    /// Returns `None` for [`AudioFormat::Wav`].
    pub fn bytes_per_sample(&self) -> Option<u32> {
        match self {
            AudioFormat::Pcm16 { .. } => Some(2),
            AudioFormat::Pcm32 { .. } | AudioFormat::Float32 { .. } => Some(4),
            AudioFormat::Wav => None,
        }
    }

    /// Short lowercase name of the variant: `"pcm16"`, `"pcm32"`, `"float32"` or `"wav"`.
    ///
    /// The sample rate and channel count are not part of the name.
    pub fn as_str(&self) -> &'static str {
        match self {
            AudioFormat::Pcm16 { .. } => "pcm16",
            AudioFormat::Pcm32 { .. } => "pcm32",
            AudioFormat::Float32 { .. } => "float32",
            AudioFormat::Wav => "wav",
        }
    }

    /// Builds an [`AudioFormat::Pcm16`] with the given sample rate and channel count.
    pub fn pcm16(sample_rate: u32, channels: u32) -> Self {
        AudioFormat::Pcm16 {
            sample_rate,
            channels,
        }
    }

    /// Builds an [`AudioFormat::Pcm32`] with the given sample rate and channel count.
    pub fn pcm32(sample_rate: u32, channels: u32) -> Self {
        AudioFormat::Pcm32 {
            sample_rate,
            channels,
        }
    }

    /// Builds an [`AudioFormat::Float32`] with the given sample rate and channel count.
    pub fn float32(sample_rate: u32, channels: u32) -> Self {
        AudioFormat::Float32 {
            sample_rate,
            channels,
        }
    }

    /// Mono `Pcm16` at a sample rate of `16000`.
    ///
    /// NOTE(review): "asr" presumably refers to the speech-recognition input path — confirm.
    pub fn asr_default() -> Self {
        AudioFormat::Pcm16 {
            sample_rate: 16000,
            channels: 1,
        }
    }
}
impl Default for AudioFormat {
fn default() -> Self {
Self::asr_default()
}
}
/// Identifies the audio container from the leading bytes of `data`.
///
/// Only RIFF/WAVE is recognised. The buffer must be at least 12 bytes long, begin with
/// `RIFF`, and contain `WAVE` at byte offset 8. Bytes 4..8 are not inspected.
///
/// # Errors
///
/// Returns [`AudioFormatError::UnknownFormat`] when the header does not match.
pub fn detect_format(data: &[u8]) -> Result<AudioFormat, AudioFormatError> {
    // The slice pattern also enforces the minimum length of 12 bytes.
    match data {
        [b'R', b'I', b'F', b'F', _, _, _, _, b'W', b'A', b'V', b'E', ..] => Ok(AudioFormat::Wav),
        _ => Err(AudioFormatError::UnknownFormat(
            "Could not detect audio format from header".to_string(),
        )),
    }
}
#[derive(Error, Debug)]
pub enum AudioFormatError {
#[error("Unknown audio format: {0}")]
UnknownFormat(String),
#[error("Unsupported format: {0}")]
UnsupportedFormat(String),
#[error("Invalid format parameters: {0}")]
InvalidParameters(String),
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_audio_format_pcm16() {
        let format = AudioFormat::pcm16(16000, 1);
        assert_eq!(format.sample_rate(), Some(16000));
        assert_eq!(format.channels(), Some(1));
        assert_eq!(format.bytes_per_sample(), Some(2));
        assert_eq!(format.as_str(), "pcm16");
    }

    #[test]
    fn test_audio_format_float32() {
        let format = AudioFormat::float32(44100, 2);
        assert_eq!(format.sample_rate(), Some(44100));
        assert_eq!(format.channels(), Some(2));
        assert_eq!(format.bytes_per_sample(), Some(4));
        assert_eq!(format.as_str(), "float32");
    }

    #[test]
    fn test_audio_format_pcm32() {
        let format = AudioFormat::Pcm32 {
            sample_rate: 48000,
            channels: 2,
        };
        assert_eq!(format.sample_rate(), Some(48000));
        assert_eq!(format.channels(), Some(2));
        assert_eq!(format.bytes_per_sample(), Some(4));
        assert_eq!(format.as_str(), "pcm32");
    }

    #[test]
    fn test_audio_format_wav() {
        let format = AudioFormat::Wav;
        assert_eq!(format.sample_rate(), None);
        assert_eq!(format.channels(), None);
        assert_eq!(format.bytes_per_sample(), None);
        assert_eq!(format.as_str(), "wav");
    }

    #[test]
    fn test_audio_format_default() {
        let format = AudioFormat::default();
        // `Default` must stay in sync with `asr_default`.
        assert_eq!(format, AudioFormat::asr_default());
        assert_eq!(format, AudioFormat::pcm16(16000, 1));
        assert_eq!(format.sample_rate(), Some(16000));
        assert_eq!(format.channels(), Some(1));
    }

    #[test]
    fn test_detect_format_wav() {
        let wav_header = b"RIFF\x00\x00\x00\x00WAVEfmt ";
        assert_eq!(detect_format(wav_header).unwrap(), AudioFormat::Wav);
    }

    #[test]
    fn test_detect_format_wav_minimal_header() {
        // Exactly 12 bytes is the shortest header the detector accepts.
        assert_eq!(
            detect_format(b"RIFF\x00\x00\x00\x00WAVE").unwrap(),
            AudioFormat::Wav
        );
    }

    #[test]
    fn test_detect_format_unknown() {
        let result = detect_format(b"unknown_format_data");
        assert!(matches!(result, Err(AudioFormatError::UnknownFormat(_))));
    }

    #[test]
    fn test_detect_format_rejects_truncated_and_mismatched_headers() {
        let cases: [&[u8]; 5] = [
            b"",
            b"RIFF",
            b"RIFF\x00\x00\x00\x00WAV",  // 11 bytes: one short of a full header
            b"RIFF\x00\x00\x00\x00AVI ", // RIFF container, but not WAVE
            b"RIFX\x00\x00\x00\x00WAVE", // wrong leading magic
        ];
        for data in cases {
            assert!(
                matches!(detect_format(data), Err(AudioFormatError::UnknownFormat(_))),
                "expected UnknownFormat for {:?}",
                data
            );
        }
    }
}