audio_io/
writer.rs

1use std::path::Path;
2
3use audio_blocks::AudioBlock;
4use hound::{SampleFormat, WavSpec, WavWriter};
5use num::Float;
6use thiserror::Error;
7
8#[derive(Debug, Error)]
9pub enum AudioWriteError {
10    #[error("could not decode audio")]
11    DecodingError(#[from] hound::Error),
12}
13
/// Sample format used when writing audio.
///
/// Defaults to [`WriteSampleFormat::Int16`].
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub enum WriteSampleFormat {
    /// 16-bit integer samples
    #[default]
    Int16,
    /// 32-bit float samples
    Float32,
}
23
24/// Configuration for writing audio to WAV files
25#[derive(Default)]
26pub struct AudioWriteConfig {
27    /// Sample format to use when writing
28    pub sample_format: WriteSampleFormat,
29}
30
31pub fn audio_write<P: AsRef<Path>, F: Float + 'static>(
32    path: P,
33    audio_block: impl AudioBlock<F>,
34    sample_rate: u32,
35    config: AudioWriteConfig,
36) -> Result<(), AudioWriteError> {
37    let spec = WavSpec {
38        channels: audio_block.num_channels(),
39        sample_rate,
40        bits_per_sample: match config.sample_format {
41            WriteSampleFormat::Int16 => 16,
42            WriteSampleFormat::Float32 => 32,
43        },
44        sample_format: match config.sample_format {
45            WriteSampleFormat::Int16 => SampleFormat::Int,
46            WriteSampleFormat::Float32 => SampleFormat::Float,
47        },
48    };
49
50    let mut writer = WavWriter::create(path.as_ref(), spec)?;
51
52    match config.sample_format {
53        WriteSampleFormat::Int16 => {
54            // Convert f32 samples to i16
55            for frame in audio_block.frame_iters() {
56                for sample in frame {
57                    let sample_i16 = (sample.clamp(F::one().neg(), F::one())
58                        * F::from(i16::MAX).unwrap_or(F::zero()))
59                    .to_i16()
60                    .unwrap_or(0);
61                    writer.write_sample(sample_i16)?;
62                }
63            }
64        }
65        WriteSampleFormat::Float32 => {
66            // Write f32 samples directly
67            for frame in audio_block.frame_iters() {
68                for sample in frame {
69                    writer.write_sample(sample.to_f32().unwrap_or(0.0))?;
70                }
71            }
72        }
73    }
74
75    writer.finalize()?;
76
77    Ok(())
78}
79
#[cfg(test)]
mod tests {

    /// Writes the reference file as 16-bit integer WAV, reads it back and checks
    /// that the sample rate and the samples (within `1e-4`) are preserved.
    #[test]
    #[cfg(all(feature = "read", feature = "write"))]
    fn test_round_trip_i16() {
        use super::*;
        use crate::reader::{AudioReadConfig, audio_read};

        let data1 =
            audio_read::<_, f32>("test_data/test_1ch.wav", AudioReadConfig::default()).unwrap();

        let write_result = audio_write(
            "tmp1.wav",
            data1.audio_block(),
            data1.sample_rate,
            AudioWriteConfig {
                sample_format: WriteSampleFormat::Int16,
            },
        );
        let read_result = audio_read::<_, f32>("tmp1.wav", AudioReadConfig::default());

        // Remove the temporary file before any unwrap or assertion can panic, so
        // a failing test does not leave it behind in the working directory.
        let _ = std::fs::remove_file("tmp1.wav");

        write_result.unwrap();
        let data2 = read_result.unwrap();

        assert_eq!(data1.sample_rate, data2.sample_rate);
        approx::assert_abs_diff_eq!(
            data1.audio_block().raw_data(),
            data2.audio_block().raw_data(),
            epsilon = 1e-4
        );
    }

    /// Writes the reference file as 32-bit float WAV, reads it back and checks
    /// that the sample rate and the samples (within `1e-6`) are preserved.
    #[test]
    #[cfg(all(feature = "read", feature = "write"))]
    fn test_round_trip_f32() {
        use super::*;
        use crate::reader::{AudioReadConfig, audio_read};

        let data1 =
            audio_read::<_, f32>("test_data/test_1ch.wav", AudioReadConfig::default()).unwrap();

        let write_result = audio_write(
            "tmp2.wav",
            data1.audio_block(),
            data1.sample_rate,
            AudioWriteConfig {
                sample_format: WriteSampleFormat::Float32,
            },
        );
        let read_result = audio_read::<_, f32>("tmp2.wav", AudioReadConfig::default());

        // Remove the temporary file before any unwrap or assertion can panic, so
        // a failing test does not leave it behind in the working directory.
        let _ = std::fs::remove_file("tmp2.wav");

        write_result.unwrap();
        let data2 = read_result.unwrap();

        assert_eq!(data1.sample_rate, data2.sample_rate);
        approx::assert_abs_diff_eq!(
            data1.audio_block().raw_data(),
            data2.audio_block().raw_data(),
            epsilon = 1e-6
        );
    }
}