use std::{error::Error, str::FromStr, time::Duration};
use duration_string::DurationString;
use serde::{Deserialize, Serialize};
use crate::audio::SampleFormat;
/// Fallback delay handed to `super::parse_playback_delay` by `Audio::playback_delay`.
const DEFAULT_AUDIO_PLAYBACK_DELAY: Duration = Duration::ZERO;
/// Value returned by `Audio::buffer_size` when no `buffer_size` is configured.
const DEFAULT_BUFFER_SIZE: usize = 1024;
/// Value used by `Audio::buffer_threads` when no `buffer_threads` is configured.
const DEFAULT_BUFFER_THREADS: usize = 2;
/// Resampler selection stored in the audio configuration.
///
/// Variants are (de)serialized in `snake_case`, i.e. `sinc` and `fft`.
/// `Sinc` is the `Default` variant and is what `Audio::resampler` returns when
/// nothing is configured.
#[derive(Deserialize, Serialize, Clone, Copy, Debug, Default, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum ResamplerType {
/// Default variant; written as `sinc` in configuration.
// NOTE(review): presumably selects sinc-interpolation resampling — confirm against the consumer.
#[default]
Sinc,
/// Written as `fft` in configuration.
// NOTE(review): presumably selects FFT-based resampling — confirm against the consumer.
Fft,
}
#[derive(Deserialize, Serialize, Clone, Debug)]
#[serde(untagged)]
pub enum StreamBufferSize {
#[serde(rename = "default")]
Default,
#[serde(rename = "min")]
Min,
Fixed(usize),
}
/// Audio configuration as read from (and written to) serialized settings.
///
/// Only `device` is mandatory. Every other field is optional and the accessor
/// methods on `Audio` supply the fallback value when a field is absent.
#[derive(Deserialize, Serialize, Clone)]
pub struct Audio {
/// Audio device identifier. `Audio::validate` rejects empty or whitespace-only values.
device: String,
/// Optional playback delay as a duration string (the tests use e.g. `500ms`).
playback_delay: Option<String>,
/// Optional sample rate; `Audio::sample_rate` falls back to 44100.
// NOTE(review): presumably in Hz — confirm against the consumer.
sample_rate: Option<u32>,
/// Optional sample format name parsed with `SampleFormat::from_str`;
/// `Audio::sample_format` falls back to `SampleFormat::Int`.
sample_format: Option<String>,
/// Optional bits per sample; `Audio::bits_per_sample` falls back to 32.
bits_per_sample: Option<u16>,
/// Optional buffer size; `Audio::buffer_size` falls back to `DEFAULT_BUFFER_SIZE`.
buffer_size: Option<usize>,
/// Optional stream buffer size; `Audio::stream_buffer_size` returns `None` when unset.
stream_buffer_size: Option<StreamBufferSize>,
/// Optional buffer thread count; `Audio::buffer_threads` falls back to
/// `DEFAULT_BUFFER_THREADS` and never returns less than 1.
buffer_threads: Option<usize>,
/// Optional resampler; `Audio::resampler` falls back to `ResamplerType::default()`.
resampler: Option<ResamplerType>,
}
impl Audio {
pub fn new(device: &str) -> Audio {
Audio {
device: device.to_string(),
playback_delay: None,
sample_rate: None,
sample_format: None,
bits_per_sample: None,
buffer_size: None,
stream_buffer_size: None,
buffer_threads: None,
resampler: None,
}
}
pub fn device(&self) -> &str {
&self.device
}
pub fn playback_delay(&self) -> Result<Duration, Box<dyn Error>> {
super::parse_playback_delay(&self.playback_delay, DEFAULT_AUDIO_PLAYBACK_DELAY)
}
pub fn sample_rate(&self) -> u32 {
self.sample_rate.unwrap_or(44100)
}
pub fn sample_format(&self) -> Result<SampleFormat, Box<dyn Error>> {
match self.sample_format.as_deref() {
Some(format) => SampleFormat::from_str(format),
None => Ok(SampleFormat::Int),
}
}
pub fn bits_per_sample(&self) -> u16 {
self.bits_per_sample.unwrap_or(32)
}
pub fn buffer_size(&self) -> usize {
self.buffer_size.unwrap_or(DEFAULT_BUFFER_SIZE)
}
pub fn buffer_threads(&self) -> usize {
self.buffer_threads.unwrap_or(DEFAULT_BUFFER_THREADS).max(1)
}
pub fn stream_buffer_size(&self) -> Option<StreamBufferSize> {
self.stream_buffer_size.clone()
}
pub fn resampler(&self) -> ResamplerType {
self.resampler.unwrap_or_default()
}
#[allow(dead_code)]
pub fn with_sample_rate(mut self, sample_rate: u32) -> Self {
self.sample_rate = Some(sample_rate);
self
}
#[allow(dead_code)]
pub fn with_buffer_size(mut self, buffer_size: usize) -> Self {
self.buffer_size = Some(buffer_size);
self
}
#[allow(dead_code)]
pub fn with_sample_format(mut self, format: &str) -> Self {
self.sample_format = Some(format.to_string());
self
}
#[allow(dead_code)]
pub fn with_bits_per_sample(mut self, bits: u16) -> Self {
self.bits_per_sample = Some(bits);
self
}
#[allow(dead_code)]
pub fn with_stream_buffer_size(mut self, sbs: StreamBufferSize) -> Self {
self.stream_buffer_size = Some(sbs);
self
}
#[allow(dead_code)]
pub fn with_resampler(mut self, resampler: ResamplerType) -> Self {
self.resampler = Some(resampler);
self
}
pub fn validate(&self) -> Result<(), Vec<String>> {
let mut errors = Vec::new();
if self.device.trim().is_empty() {
errors.push("audio device must not be empty".to_string());
}
if let Some(rate) = self.sample_rate {
if rate == 0 {
errors.push("audio sample_rate must be greater than 0".to_string());
}
}
if let Some(bits) = self.bits_per_sample {
if bits == 0 {
errors.push("audio bits_per_sample must be greater than 0".to_string());
}
}
if let Some(ref fmt) = self.sample_format {
if SampleFormat::from_str(fmt).is_err() {
errors.push(format!(
"audio sample_format '{}' is invalid (expected 'int' or 'float')",
fmt
));
}
}
if let Some(ref delay) = self.playback_delay {
if DurationString::from_string(delay.clone()).is_err() {
errors.push(format!(
"audio playback_delay '{}' is not a valid duration",
delay
));
}
}
if let Some(size) = self.buffer_size {
if size == 0 {
errors.push("audio buffer_size must be greater than 0".to_string());
}
}
if errors.is_empty() {
Ok(())
} else {
Err(errors)
}
}
}
#[cfg(test)]
mod test {
    use std::time::Duration;

    use super::*;

    /// Deserializes an `Audio` from a YAML document via the `config` crate.
    fn from_yaml(yaml: &str) -> Audio {
        let source = config::File::from_str(yaml, config::FileFormat::Yaml);
        let built = config::Config::builder()
            .add_source(source)
            .build()
            .expect("build config");
        built.try_deserialize::<Audio>().expect("deserialize")
    }

    /// Every accessor reports its fallback on a freshly created configuration.
    #[test]
    fn defaults() {
        let cfg = Audio::new("test-device");

        assert_eq!(cfg.device(), "test-device");
        assert_eq!(cfg.sample_rate(), 44100);
        assert_eq!(cfg.bits_per_sample(), 32);
        assert_eq!(cfg.sample_format().unwrap(), SampleFormat::Int);
        assert_eq!(cfg.buffer_size(), DEFAULT_BUFFER_SIZE);
        assert_eq!(cfg.buffer_threads(), DEFAULT_BUFFER_THREADS);
        assert_eq!(cfg.playback_delay().unwrap(), Duration::ZERO);
        assert!(cfg.stream_buffer_size().is_none());
        assert_eq!(cfg.resampler(), ResamplerType::Sinc);
    }

    #[test]
    fn builder_sample_rate() {
        let cfg = Audio::new("dev").with_sample_rate(48000);
        assert_eq!(cfg.sample_rate(), 48000);
    }

    #[test]
    fn builder_buffer_size() {
        let cfg = Audio::new("dev").with_buffer_size(2048);
        assert_eq!(cfg.buffer_size(), 2048);
    }

    #[test]
    fn builder_bits_per_sample() {
        let cfg = Audio::new("dev").with_bits_per_sample(16);
        assert_eq!(cfg.bits_per_sample(), 16);
    }

    #[test]
    fn builder_sample_format_float() {
        let cfg = Audio::new("dev").with_sample_format("float");
        assert_eq!(cfg.sample_format().unwrap(), SampleFormat::Float);
    }

    #[test]
    fn builder_sample_format_int() {
        let cfg = Audio::new("dev").with_sample_format("int");
        assert_eq!(cfg.sample_format().unwrap(), SampleFormat::Int);
    }

    #[test]
    fn sample_format_invalid() {
        let cfg = Audio::new("dev").with_sample_format("wav");
        assert!(cfg.sample_format().is_err());
    }

    #[test]
    fn playback_delay_valid() {
        let cfg = Audio {
            playback_delay: Some("500ms".to_string()),
            ..Audio::new("dev")
        };
        assert_eq!(cfg.playback_delay().unwrap(), Duration::from_millis(500));
    }

    #[test]
    fn playback_delay_invalid() {
        let cfg = Audio {
            playback_delay: Some("not-a-duration".to_string()),
            ..Audio::new("dev")
        };
        assert!(cfg.playback_delay().is_err());
    }

    /// A configured thread count of zero is raised to one by the accessor.
    #[test]
    fn buffer_threads_clamped_to_one() {
        let cfg = Audio {
            buffer_threads: Some(0),
            ..Audio::new("dev")
        };
        assert_eq!(cfg.buffer_threads(), 1);
    }

    #[test]
    fn buffer_threads_custom() {
        let cfg = Audio {
            buffer_threads: Some(4),
            ..Audio::new("dev")
        };
        assert_eq!(cfg.buffer_threads(), 4);
    }

    #[test]
    fn builder_resampler_fft() {
        let cfg = Audio::new("dev").with_resampler(ResamplerType::Fft);
        assert_eq!(cfg.resampler(), ResamplerType::Fft);
    }

    #[test]
    fn builder_stream_buffer_size() {
        let cfg = Audio::new("dev").with_stream_buffer_size(StreamBufferSize::Min);
        let configured = cfg.stream_buffer_size();
        assert!(matches!(configured, Some(StreamBufferSize::Min)));
    }

    /// Builders compose: each setting applied in a chain is visible afterwards.
    #[test]
    fn builder_chaining() {
        let cfg = Audio::new("dev")
            .with_sample_rate(96000)
            .with_buffer_size(512)
            .with_bits_per_sample(24)
            .with_sample_format("float")
            .with_resampler(ResamplerType::Fft);

        assert_eq!(cfg.sample_rate(), 96000);
        assert_eq!(cfg.buffer_size(), 512);
        assert_eq!(cfg.bits_per_sample(), 24);
        assert_eq!(cfg.sample_format().unwrap(), SampleFormat::Float);
        assert_eq!(cfg.resampler(), ResamplerType::Fft);
    }

    /// A document containing only `device` deserializes and yields the fallbacks.
    #[test]
    fn serde_defaults_from_minimal_yaml() {
        let cfg = from_yaml("device: minimal-device\n");

        assert_eq!(cfg.device(), "minimal-device");
        assert_eq!(cfg.sample_rate(), 44100);
        assert_eq!(cfg.bits_per_sample(), 32);
        assert_eq!(cfg.buffer_size(), DEFAULT_BUFFER_SIZE);
        assert_eq!(cfg.resampler(), ResamplerType::Sinc);
    }

    #[test]
    fn serde_full_yaml() {
        // Same bytes as a raw multi-line literal with one key per line.
        let document = concat!(
            "\n",
            "device: my-device\n",
            "sample_rate: 48000\n",
            "buffer_size: 512\n",
            "bits_per_sample: 24\n",
            "sample_format: float\n",
            "resampler: fft\n",
            "buffer_threads: 4\n",
            "playback_delay: 100ms\n",
        );
        let cfg = from_yaml(document);

        assert_eq!(cfg.device(), "my-device");
        assert_eq!(cfg.sample_rate(), 48000);
        assert_eq!(cfg.buffer_size(), 512);
        assert_eq!(cfg.bits_per_sample(), 24);
        assert_eq!(cfg.sample_format().unwrap(), SampleFormat::Float);
        assert_eq!(cfg.resampler(), ResamplerType::Fft);
        assert_eq!(cfg.buffer_threads(), 4);
        assert_eq!(cfg.playback_delay().unwrap(), Duration::from_millis(100));
    }

    #[test]
    fn serde_resampler_variants() {
        let sinc = from_yaml("device: dev\nresampler: sinc\n");
        assert_eq!(sinc.resampler(), ResamplerType::Sinc);

        let fft = from_yaml("device: dev\nresampler: fft\n");
        assert_eq!(fft.resampler(), ResamplerType::Fft);
    }
}