1use std::path::Path;
2
3use num::Float;
4use thiserror::Error;
5
6#[derive(Debug, Error)]
7pub enum WriteError {
8 #[error("could not encode audio")]
9 Encode(#[from] hound::Error),
10}
11
/// Sample encoding used when writing a WAV file.
///
/// The integer variants store fixed-point PCM whose full scale corresponds to
/// `[-1.0, 1.0]` in the input; `Float32` stores single-precision floats as-is.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub enum SampleFormat {
    /// 8-bit integer PCM.
    Int8,
    /// 16-bit integer PCM. This is the default format.
    #[default]
    Int16,
    /// 32-bit integer PCM.
    Int32,
    /// 32-bit IEEE floating-point samples.
    Float32,
}
25
26#[derive(Default)]
28pub struct WriteConfig {
29 pub sample_format: SampleFormat,
31}
32
33pub fn write<F: Float>(
35 path: impl AsRef<Path>,
36 samples: &[F],
37 num_channels: u16,
38 sample_rate: u32,
39 config: WriteConfig,
40) -> Result<(), WriteError> {
41 let spec = hound::WavSpec {
42 channels: num_channels,
43 sample_rate,
44 bits_per_sample: match config.sample_format {
45 SampleFormat::Int8 => 8,
46 SampleFormat::Int16 => 16,
47 SampleFormat::Int32 => 32,
48 SampleFormat::Float32 => 32,
49 },
50 sample_format: match config.sample_format {
51 SampleFormat::Int8 | SampleFormat::Int16 | SampleFormat::Int32 => {
52 hound::SampleFormat::Int
53 }
54 SampleFormat::Float32 => hound::SampleFormat::Float,
55 },
56 };
57
58 let mut writer = hound::WavWriter::create(path.as_ref(), spec)?;
59
60 match config.sample_format {
61 SampleFormat::Int8 => {
62 for &sample in samples {
63 let sample_i8 = (sample.clamp(F::one().neg(), F::one())
64 * F::from(i8::MAX).unwrap_or(F::zero()))
65 .to_i8()
66 .unwrap_or(0);
67 writer.write_sample(sample_i8)?;
68 }
69 }
70 SampleFormat::Int16 => {
71 for &sample in samples {
72 let sample_i16 = (sample.clamp(F::one().neg(), F::one())
73 * F::from(i16::MAX).unwrap_or(F::zero()))
74 .to_i16()
75 .unwrap_or(0);
76 writer.write_sample(sample_i16)?;
77 }
78 }
79 SampleFormat::Int32 => {
80 for &sample in samples {
81 let sample_i32 = (sample.clamp(F::one().neg(), F::one())
82 * F::from(i32::MAX).unwrap_or(F::zero()))
83 .to_i32()
84 .unwrap_or(0);
85 writer.write_sample(sample_i32)?;
86 }
87 }
88 SampleFormat::Float32 => {
89 for &sample in samples {
90 writer.write_sample(sample.to_f32().unwrap_or(0.0))?;
91 }
92 }
93 }
94
95 writer.finalize()?;
96
97 Ok(())
98}
99
/// Writes an `audio_blocks::AudioBlock` to a WAV file at `path`.
///
/// The block is converted to interleaved layout with
/// `audio_blocks::Interleaved::from_block`. The resulting data and the block's
/// channel count are then passed to [`write`].
///
/// # Errors
/// Propagates any [`WriteError`] returned by [`write`].
#[cfg(feature = "audio-blocks")]
pub fn write_block<P: AsRef<Path>, F: Float + 'static>(
    path: P,
    audio_block: impl audio_blocks::AudioBlock<F>,
    sample_rate: u32,
    config: WriteConfig,
) -> Result<(), WriteError> {
    let interleaved = audio_blocks::Interleaved::from_block(&audio_block);
    let channel_count = audio_block.num_channels();
    write(
        path,
        interleaved.raw_data(),
        channel_count,
        sample_rate,
        config,
    )
}
117
#[cfg(test)]
mod tests {
    use super::*;
    use crate::reader::{ReadConfig, read};

    /// Reference recording used as the source for every round trip.
    const REFERENCE_WAV: &str = "test_data/test_1ch.wav";

    /// A WAV path in the system temp directory that is deleted on drop. This
    /// cleans the file up even when an assertion panics mid-test.
    struct TempWav(std::path::PathBuf);

    impl TempWav {
        fn new(name: &str) -> Self {
            // The process id keeps concurrent test runs from clobbering each other.
            let file_name = format!("writer_test_{}_{name}.wav", std::process::id());
            Self(std::env::temp_dir().join(file_name))
        }

        fn path(&self) -> &str {
            self.0
                .to_str()
                .expect("temporary file path is not valid UTF-8")
        }
    }

    impl Drop for TempWav {
        fn drop(&mut self) {
            // Best-effort: a cleanup failure must not mask the actual test outcome.
            let _ = std::fs::remove_file(&self.0);
        }
    }

    /// Writes the reference file back out as `sample_format` and re-reads it.
    /// Checks that sample rate and channel count are preserved and that the
    /// samples agree within `epsilon`.
    fn assert_round_trip(name: &str, sample_format: SampleFormat, epsilon: f32) {
        let original = read::<f32>(REFERENCE_WAV, ReadConfig::default()).unwrap();
        let tmp = TempWav::new(name);

        write(
            tmp.path(),
            &original.samples_interleaved,
            original.num_channels,
            original.sample_rate,
            WriteConfig { sample_format },
        )
        .unwrap();

        let decoded = read::<f32>(tmp.path(), ReadConfig::default()).unwrap();
        assert_eq!(original.sample_rate, decoded.sample_rate);
        assert_eq!(original.num_channels, decoded.num_channels);
        approx::assert_abs_diff_eq!(
            original.samples_interleaved.as_slice(),
            decoded.samples_interleaved.as_slice(),
            epsilon = epsilon
        );
    }

    #[test]
    fn test_round_trip_i8() {
        assert_round_trip("i8", SampleFormat::Int8, 2e-2);
    }

    #[test]
    fn test_round_trip_i16() {
        assert_round_trip("i16", SampleFormat::Int16, 1e-4);
    }

    #[test]
    fn test_round_trip_i32() {
        assert_round_trip("i32", SampleFormat::Int32, 1e-4);
    }

    #[test]
    fn test_round_trip_f32() {
        assert_round_trip("f32", SampleFormat::Float32, 1e-6);
    }
}