1use std::path::Path;
2
3use audio_blocks::AudioBlock;
4use hound::{SampleFormat, WavSpec, WavWriter};
5use num::Float;
6use thiserror::Error;
7
8#[derive(Debug, Error)]
9pub enum AudioWriteError {
10 #[error("could not decode audio")]
11 DecodingError(#[from] hound::Error),
12}
13
/// Sample encoding used when writing a WAV file with [`audio_write`].
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub enum WriteSampleFormat {
    /// 16-bit signed integer PCM (the default). Samples are clamped to `[-1, 1]`
    /// and scaled to the `i16` range before writing.
    #[default]
    Int16,
    /// 32-bit IEEE float samples, written without scaling.
    Float32,
}
23
24#[derive(Default)]
26pub struct AudioWriteConfig {
27 pub sample_format: WriteSampleFormat,
29}
30
31pub fn audio_write<P: AsRef<Path>, F: Float + 'static>(
32 path: P,
33 audio_block: impl AudioBlock<F>,
34 sample_rate: u32,
35 config: AudioWriteConfig,
36) -> Result<(), AudioWriteError> {
37 let spec = WavSpec {
38 channels: audio_block.num_channels(),
39 sample_rate,
40 bits_per_sample: match config.sample_format {
41 WriteSampleFormat::Int16 => 16,
42 WriteSampleFormat::Float32 => 32,
43 },
44 sample_format: match config.sample_format {
45 WriteSampleFormat::Int16 => SampleFormat::Int,
46 WriteSampleFormat::Float32 => SampleFormat::Float,
47 },
48 };
49
50 let mut writer = WavWriter::create(path.as_ref(), spec)?;
51
52 match config.sample_format {
53 WriteSampleFormat::Int16 => {
54 for frame in audio_block.frame_iters() {
56 for sample in frame {
57 let sample_i16 = (sample.clamp(F::one().neg(), F::one())
58 * F::from(i16::MAX).unwrap_or(F::zero()))
59 .to_i16()
60 .unwrap_or(0);
61 writer.write_sample(sample_i16)?;
62 }
63 }
64 }
65 WriteSampleFormat::Float32 => {
66 for frame in audio_block.frame_iters() {
68 for sample in frame {
69 writer.write_sample(sample.to_f32().unwrap_or(0.0))?;
70 }
71 }
72 }
73 }
74
75 writer.finalize()?;
76
77 Ok(())
78}
79
/// Round-trip tests: read `test.wav`, write it back with `audio_write`, re-read it
/// and compare. They need both the `read` and `write` features.
#[cfg(all(test, feature = "read", feature = "write"))]
mod tests {
    use std::path::PathBuf;

    use super::*;
    use crate::reader::{AudioReadConfig, audio_read};

    /// Temporary WAV path in the system temp directory, removed when dropped.
    ///
    /// Because removal happens in `Drop`, the file is also cleaned up when an
    /// assertion panics. The process id keeps concurrent test runs from clashing.
    struct TempWav(PathBuf);

    impl TempWav {
        fn new(tag: &str) -> Self {
            let name = format!("audio_write_test_{tag}_{}.wav", std::process::id());
            Self(std::env::temp_dir().join(name))
        }
    }

    impl Drop for TempWav {
        fn drop(&mut self) {
            // Best-effort cleanup; the file may not exist if writing failed.
            let _ = std::fs::remove_file(&self.0);
        }
    }

    /// Writes `test.wav` back out in `sample_format` and checks that re-reading
    /// yields the same sample rate and samples within `epsilon`.
    fn assert_round_trip(sample_format: WriteSampleFormat, tag: &str, epsilon: f32) {
        let tmp = TempWav::new(tag);

        let original = audio_read::<_, f32>("test.wav", AudioReadConfig::default()).unwrap();

        audio_write(
            &tmp.0,
            original.audio_block(),
            original.sample_rate,
            AudioWriteConfig { sample_format },
        )
        .unwrap();

        let reread = audio_read::<_, f32>(&tmp.0, AudioReadConfig::default()).unwrap();
        assert_eq!(original.sample_rate, reread.sample_rate);
        approx::assert_abs_diff_eq!(
            original.audio_block().raw_data(),
            reread.audio_block().raw_data(),
            epsilon = epsilon
        );
    }

    #[test]
    fn test_round_trip_i16() {
        assert_round_trip(WriteSampleFormat::Int16, "i16", 1e-4);
    }

    #[test]
    fn test_round_trip_f32() {
        assert_round_trip(WriteSampleFormat::Float32, "f32", 1e-6);
    }
}