// audio_file/writer.rs

1use std::path::Path;
2
3use num::Float;
4use thiserror::Error;
5
6#[derive(Debug, Error)]
7pub enum WriteError {
8    #[error("could not encode audio")]
9    Encode(#[from] hound::Error),
10}
11
/// Sample encoding used for the samples stored in a written WAV file.
///
/// Defaults to [`SampleFormat::Int16`]. The type is `Copy` and can be compared and hashed,
/// so it can be matched against, stored in sets or used as a map key.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub enum SampleFormat {
    /// 8-bit integer samples
    Int8,
    /// 16-bit integer samples (the default)
    #[default]
    Int16,
    /// 32-bit integer samples
    Int32,
    /// 32-bit floating-point samples (written as `f32`)
    Float32,
}
25
26/// Configuration for writing audio to WAV files
27#[derive(Default)]
28pub struct WriteConfig {
29    /// Sample format to use when writing
30    pub sample_format: SampleFormat,
31}
32
33/// Write interleaved audio samples to a WAV file
34pub fn write<F: Float>(
35    path: impl AsRef<Path>,
36    samples: &[F],
37    num_channels: u16,
38    sample_rate: u32,
39    config: WriteConfig,
40) -> Result<(), WriteError> {
41    let spec = hound::WavSpec {
42        channels: num_channels,
43        sample_rate,
44        bits_per_sample: match config.sample_format {
45            SampleFormat::Int8 => 8,
46            SampleFormat::Int16 => 16,
47            SampleFormat::Int32 => 32,
48            SampleFormat::Float32 => 32,
49        },
50        sample_format: match config.sample_format {
51            SampleFormat::Int8 | SampleFormat::Int16 | SampleFormat::Int32 => {
52                hound::SampleFormat::Int
53            }
54            SampleFormat::Float32 => hound::SampleFormat::Float,
55        },
56    };
57
58    let mut writer = hound::WavWriter::create(path.as_ref(), spec)?;
59
60    match config.sample_format {
61        SampleFormat::Int8 => {
62            for &sample in samples {
63                let sample_i8 = (sample.clamp(F::one().neg(), F::one())
64                    * F::from(i8::MAX).unwrap_or(F::zero()))
65                .to_i8()
66                .unwrap_or(0);
67                writer.write_sample(sample_i8)?;
68            }
69        }
70        SampleFormat::Int16 => {
71            for &sample in samples {
72                let sample_i16 = (sample.clamp(F::one().neg(), F::one())
73                    * F::from(i16::MAX).unwrap_or(F::zero()))
74                .to_i16()
75                .unwrap_or(0);
76                writer.write_sample(sample_i16)?;
77            }
78        }
79        SampleFormat::Int32 => {
80            for &sample in samples {
81                let sample_i32 = (sample.clamp(F::one().neg(), F::one())
82                    * F::from(i32::MAX).unwrap_or(F::zero()))
83                .to_i32()
84                .unwrap_or(0);
85                writer.write_sample(sample_i32)?;
86            }
87        }
88        SampleFormat::Float32 => {
89            for &sample in samples {
90                writer.write_sample(sample.to_f32().unwrap_or(0.0))?;
91            }
92        }
93    }
94
95    writer.finalize()?;
96
97    Ok(())
98}
99
/// Write the contents of an audio block to a WAV file.
///
/// The block is converted to interleaved layout with `audio_blocks::Interleaved::from_block`
/// and the resulting samples are passed to [`write`] together with the block's channel
/// count. See [`write`] for how samples are encoded and which errors can be returned.
#[cfg(feature = "audio-blocks")]
pub fn write_block<P: AsRef<Path>, F: Float + 'static>(
    path: P,
    audio_block: impl audio_blocks::AudioBlock<F>,
    sample_rate: u32,
    config: WriteConfig,
) -> Result<(), WriteError> {
    let interleaved = audio_blocks::Interleaved::from_block(&audio_block);
    let interleaved_samples = interleaved.raw_data();
    let channel_count = audio_block.num_channels();
    write(path, interleaved_samples, channel_count, sample_rate, config)
}
117
#[cfg(test)]
mod tests {
    use super::*;
    use crate::reader::{ReadConfig, read};

    /// Reference recording every round-trip test starts from.
    const REFERENCE_WAV: &str = "test_data/test_1ch.wav";

    /// Temporary output file that is deleted when dropped, so it is cleaned up even when an
    /// assertion in the test panics.
    struct TempFile(std::path::PathBuf);

    impl TempFile {
        /// Path inside the system temp directory, prefixed with the process id so that
        /// concurrent test runs do not overwrite each other's files.
        fn new(file_name: &str) -> Self {
            Self(std::env::temp_dir().join(format!("{}_{}", std::process::id(), file_name)))
        }

        /// The path as `&str`, the same argument type the tests have always passed to
        /// `read` and `write`.
        fn path(&self) -> &str {
            self.0
                .to_str()
                .expect("temporary file path is not valid UTF-8")
        }
    }

    impl Drop for TempFile {
        fn drop(&mut self) {
            // Best effort: the file does not exist if `write` failed before creating it,
            // and panicking inside `drop` while already unwinding would abort the run.
            let _ = std::fs::remove_file(&self.0);
        }
    }

    /// Read the reference file, write it back out as `sample_format`, read the result again
    /// and check that sample rate, channel count and samples (within `epsilon`) survived.
    fn assert_round_trip(file_name: &str, sample_format: SampleFormat, epsilon: f32) {
        let tmp = TempFile::new(file_name);

        let audio1 = read::<f32>(REFERENCE_WAV, ReadConfig::default()).unwrap();

        write(
            tmp.path(),
            &audio1.samples_interleaved,
            audio1.num_channels,
            audio1.sample_rate,
            WriteConfig { sample_format },
        )
        .unwrap();

        let audio2 = read::<f32>(tmp.path(), ReadConfig::default()).unwrap();
        assert_eq!(audio1.sample_rate, audio2.sample_rate);
        assert_eq!(audio1.num_channels, audio2.num_channels);
        approx::assert_abs_diff_eq!(
            audio1.samples_interleaved.as_slice(),
            audio2.samples_interleaved.as_slice(),
            epsilon = epsilon
        );
    }

    #[test]
    fn test_round_trip_i8() {
        // 8-bit PCM has low precision. Symphonia normalizes by dividing by 128
        // (not 127), so the max representable value is 127/128 ≈ 0.992, giving an
        // error of up to ~0.016 near full scale.
        assert_round_trip("tmp0.wav", SampleFormat::Int8, 2e-2);
    }

    #[test]
    fn test_round_trip_i16() {
        assert_round_trip("tmp1.wav", SampleFormat::Int16, 1e-4);
    }

    #[test]
    fn test_round_trip_i32() {
        assert_round_trip("tmp3.wav", SampleFormat::Int32, 1e-4);
    }

    #[test]
    fn test_round_trip_f32() {
        assert_round_trip("tmp2.wav", SampleFormat::Float32, 1e-6);
    }
}