1use std::path::Path;
2
3use audio_blocks::AudioBlock;
4use hound::{SampleFormat, WavSpec, WavWriter};
5use num::Float;
6use thiserror::Error;
7
8#[derive(Debug, Error)]
9pub enum AudioWriteError {
10 #[error("could not decode audio")]
11 DecodingError(#[from] hound::Error),
12}
13
/// Sample encoding used when writing a WAV file with [`audio_write`].
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub enum WriteSampleFormat {
    /// 16-bit signed integer PCM (the default).
    #[default]
    Int16,
    /// 32-bit IEEE-754 floating point.
    Float32,
}
23
24#[derive(Default)]
26pub struct AudioWriteConfig {
27 pub sample_format: WriteSampleFormat,
29}
30
31pub fn audio_write<P: AsRef<Path>, F: Float + 'static>(
32 path: P,
33 audio_block: impl AudioBlock<F>,
34 sample_rate: u32,
35 config: AudioWriteConfig,
36) -> Result<(), AudioWriteError> {
37 let spec = WavSpec {
38 channels: audio_block.num_channels(),
39 sample_rate,
40 bits_per_sample: match config.sample_format {
41 WriteSampleFormat::Int16 => 16,
42 WriteSampleFormat::Float32 => 32,
43 },
44 sample_format: match config.sample_format {
45 WriteSampleFormat::Int16 => SampleFormat::Int,
46 WriteSampleFormat::Float32 => SampleFormat::Float,
47 },
48 };
49
50 let mut writer = WavWriter::create(path.as_ref(), spec)?;
51
52 match config.sample_format {
53 WriteSampleFormat::Int16 => {
54 for frame in audio_block.frame_iters() {
56 for sample in frame {
57 let sample_i16 = (sample.clamp(F::one().neg(), F::one())
58 * F::from(i16::MAX).unwrap_or(F::zero()))
59 .to_i16()
60 .unwrap_or(0);
61 writer.write_sample(sample_i16)?;
62 }
63 }
64 }
65 WriteSampleFormat::Float32 => {
66 for frame in audio_block.frame_iters() {
68 for sample in frame {
69 writer.write_sample(sample.to_f32().unwrap_or(0.0))?;
70 }
71 }
72 }
73 }
74
75 writer.finalize()?;
76
77 Ok(())
78}
79
#[cfg(test)]
mod tests {
    /// Deletes the wrapped file when dropped, so a failing assertion does not
    /// leave temporary WAV files behind.
    #[cfg(all(feature = "read", feature = "write"))]
    struct TempFile(std::path::PathBuf);

    #[cfg(all(feature = "read", feature = "write"))]
    impl Drop for TempFile {
        fn drop(&mut self) {
            // Best effort: the file does not exist if writing failed early.
            let _ = std::fs::remove_file(&self.0);
        }
    }

    /// Reads the reference file, writes it back with `sample_format`, re-reads the
    /// result, and asserts that the sample rate is identical and every sample is
    /// within `epsilon` of the original.
    #[cfg(all(feature = "read", feature = "write"))]
    fn assert_round_trip(sample_format: super::WriteSampleFormat, tag: &str, epsilon: f32) {
        use super::*;
        use crate::reader::{AudioReadConfig, audio_read};

        // Use the OS temp directory rather than the working directory. The process
        // id keeps concurrent test runs from clobbering each other's files.
        let tmp = TempFile(
            std::env::temp_dir().join(format!("audio_write_{}_{tag}.wav", std::process::id())),
        );

        let data1 =
            audio_read::<_, f32>("test_data/test_1ch.wav", AudioReadConfig::default()).unwrap();

        audio_write(
            tmp.0.as_path(),
            data1.audio_block(),
            data1.sample_rate,
            AudioWriteConfig { sample_format },
        )
        .unwrap();

        let data2 = audio_read::<_, f32>(tmp.0.as_path(), AudioReadConfig::default()).unwrap();
        assert_eq!(data1.sample_rate, data2.sample_rate);
        approx::assert_abs_diff_eq!(
            data1.audio_block().raw_data(),
            data2.audio_block().raw_data(),
            epsilon = epsilon
        );
    }

    #[test]
    #[cfg(all(feature = "read", feature = "write"))]
    fn test_round_trip_i16() {
        // 16-bit quantisation loses precision, hence the looser tolerance.
        assert_round_trip(super::WriteSampleFormat::Int16, "i16", 1e-4);
    }

    #[test]
    #[cfg(all(feature = "read", feature = "write"))]
    fn test_round_trip_f32() {
        assert_round_trip(super::WriteSampleFormat::Float32, "f32", 1e-6);
    }
}