Skip to main content

audio_io/
writer.rs

1use std::path::Path;
2
3use hound::{SampleFormat, WavSpec, WavWriter};
4use num::Float;
5use thiserror::Error;
6
7#[derive(Debug, Error)]
8pub enum AudioWriteError {
9    #[error("could not decode audio")]
10    DecodingError(#[from] hound::Error),
11}
12
/// Sample format for writing audio.
///
/// Selects both the WAV `bits_per_sample` and the integer/float sample encoding.
/// Defaults to [`WriteSampleFormat::Int16`].
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub enum WriteSampleFormat {
    /// 16-bit signed integer samples.
    #[default]
    Int16,
    /// 32-bit IEEE float samples.
    Float32,
}
22
23/// Configuration for writing audio to WAV files
24#[derive(Default)]
25pub struct AudioWriteConfig {
26    /// Sample format to use when writing
27    pub sample_format: WriteSampleFormat,
28}
29
30/// Write interleaved audio samples to a WAV file
31pub fn audio_write<F: Float>(
32    path: impl AsRef<Path>,
33    samples: &[F],
34    num_channels: u16,
35    sample_rate: u32,
36    config: AudioWriteConfig,
37) -> Result<(), AudioWriteError> {
38    let spec = WavSpec {
39        channels: num_channels,
40        sample_rate,
41        bits_per_sample: match config.sample_format {
42            WriteSampleFormat::Int16 => 16,
43            WriteSampleFormat::Float32 => 32,
44        },
45        sample_format: match config.sample_format {
46            WriteSampleFormat::Int16 => SampleFormat::Int,
47            WriteSampleFormat::Float32 => SampleFormat::Float,
48        },
49    };
50
51    let mut writer = WavWriter::create(path.as_ref(), spec)?;
52
53    match config.sample_format {
54        WriteSampleFormat::Int16 => {
55            // Convert samples to i16
56            for &sample in samples {
57                let sample_i16 = (sample.clamp(F::one().neg(), F::one())
58                    * F::from(i16::MAX).unwrap_or(F::zero()))
59                .to_i16()
60                .unwrap_or(0);
61                writer.write_sample(sample_i16)?;
62            }
63        }
64        WriteSampleFormat::Float32 => {
65            // Write f32 samples directly
66            for &sample in samples {
67                writer.write_sample(sample.to_f32().unwrap_or(0.0))?;
68            }
69        }
70    }
71
72    writer.finalize()?;
73
74    Ok(())
75}
76
/// Write the contents of an `audio_blocks::AudioBlock` to a WAV file.
///
/// The block is converted into an `AudioBlockInterleaved`. Its raw sample data,
/// the source block's channel count, `sample_rate`, and `config` are then passed
/// on to [`audio_write`].
///
/// # Errors
///
/// Propagates any [`AudioWriteError`] returned by [`audio_write`].
#[cfg(feature = "audio-blocks")]
pub fn audio_write_block<P: AsRef<Path>, F: Float + 'static>(
    path: P,
    audio_block: impl audio_blocks::AudioBlock<F>,
    sample_rate: u32,
    config: AudioWriteConfig,
) -> Result<(), AudioWriteError> {
    let interleaved = audio_blocks::AudioBlockInterleaved::from_block(&audio_block);
    let interleaved_samples = interleaved.raw_data();
    let channel_count = audio_block.num_channels();
    audio_write(path, interleaved_samples, channel_count, sample_rate, config)
}
94
#[cfg(test)]
mod tests {
    use super::*;
    use crate::reader::{AudioReadConfig, audio_read};

    /// A uniquely named file path in the system temp directory.
    ///
    /// The file is deleted on drop, so cleanup happens even when an assertion panics.
    struct TempWav(String);

    impl TempWav {
        fn new(name: &str) -> Self {
            // Include the process id so concurrent test runs do not clobber each other.
            let path = std::env::temp_dir()
                .join(format!("audio_io_writer_{}_{name}", std::process::id()));
            Self(
                path.to_str()
                    .expect("temp dir path should be valid UTF-8")
                    .to_owned(),
            )
        }

        fn path(&self) -> &str {
            &self.0
        }
    }

    impl Drop for TempWav {
        fn drop(&mut self) {
            // Best-effort: the file may not exist if the write itself failed.
            let _ = std::fs::remove_file(&self.0);
        }
    }

    /// Read the 1-channel fixture and write it back with `sample_format`.
    /// Then re-read it and check that the sample rate, channel count, and samples
    /// (within `epsilon`) survive the round trip.
    fn assert_round_trip(sample_format: WriteSampleFormat, file_name: &str, epsilon: f32) {
        let audio1 =
            audio_read::<f32>("test_data/test_1ch.wav", AudioReadConfig::default()).unwrap();
        let tmp = TempWav::new(file_name);

        audio_write(
            tmp.path(),
            &audio1.samples_interleaved,
            audio1.num_channels,
            audio1.sample_rate,
            AudioWriteConfig { sample_format },
        )
        .unwrap();

        let audio2 = audio_read::<f32>(tmp.path(), AudioReadConfig::default()).unwrap();
        assert_eq!(audio1.sample_rate, audio2.sample_rate);
        assert_eq!(audio1.num_channels, audio2.num_channels);
        approx::assert_abs_diff_eq!(
            audio1.samples_interleaved.as_slice(),
            audio2.samples_interleaved.as_slice(),
            epsilon = epsilon
        );
    }

    #[test]
    fn test_round_trip_i16() {
        assert_round_trip(WriteSampleFormat::Int16, "round_trip_i16.wav", 1e-4);
    }

    #[test]
    fn test_round_trip_f32() {
        assert_round_trip(WriteSampleFormat::Float32, "round_trip_f32.wav", 1e-6);
    }
}