audio_io/
writer.rs

1use std::path::Path;
2
3use audio_blocks::AudioBlock;
4use hound::{SampleFormat, WavSpec, WavWriter};
5use num::Float;
6use thiserror::Error;
7
8#[derive(Debug, Error)]
9pub enum AudioWriteError {
10    #[error("could not decode audio")]
11    DecodingError(#[from] hound::Error),
12}
13
/// Sample format used when writing audio to a WAV file.
///
/// Defaults to [`WriteSampleFormat::Int16`].
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub enum WriteSampleFormat {
    /// 16-bit signed integer PCM samples (the default).
    #[default]
    Int16,
    /// 32-bit floating-point samples.
    Float32,
}
23
24/// Configuration for writing audio to WAV files
25#[derive(Default)]
26pub struct AudioWriteConfig {
27    /// Sample format to use when writing
28    pub sample_format: WriteSampleFormat,
29}
30
31pub fn audio_write<P: AsRef<Path>, F: Float + 'static>(
32    path: P,
33    audio_block: impl AudioBlock<F>,
34    sample_rate: u32,
35    config: AudioWriteConfig,
36) -> Result<(), AudioWriteError> {
37    let spec = WavSpec {
38        channels: audio_block.num_channels(),
39        sample_rate,
40        bits_per_sample: match config.sample_format {
41            WriteSampleFormat::Int16 => 16,
42            WriteSampleFormat::Float32 => 32,
43        },
44        sample_format: match config.sample_format {
45            WriteSampleFormat::Int16 => SampleFormat::Int,
46            WriteSampleFormat::Float32 => SampleFormat::Float,
47        },
48    };
49
50    let mut writer = WavWriter::create(path.as_ref(), spec)?;
51
52    match config.sample_format {
53        WriteSampleFormat::Int16 => {
54            // Convert f32 samples to i16
55            for frame in audio_block.frame_iters() {
56                for sample in frame {
57                    let sample_i16 = (sample.clamp(F::one().neg(), F::one())
58                        * F::from(i16::MAX).unwrap_or(F::zero()))
59                    .to_i16()
60                    .unwrap_or(0);
61                    writer.write_sample(sample_i16)?;
62                }
63            }
64        }
65        WriteSampleFormat::Float32 => {
66            // Write f32 samples directly
67            for frame in audio_block.frame_iters() {
68                for sample in frame {
69                    writer.write_sample(sample.to_f32().unwrap_or(0.0))?;
70                }
71            }
72        }
73    }
74
75    writer.finalize()?;
76
77    Ok(())
78}
79
#[cfg(test)]
mod tests {
    /// Path of a temporary WAV file that is deleted when the guard is dropped.
    ///
    /// Deletion also happens while unwinding from a failed assertion, so failing tests do not
    /// leave stray files behind.
    #[cfg(all(feature = "read", feature = "write"))]
    struct TempWav(std::path::PathBuf);

    #[cfg(all(feature = "read", feature = "write"))]
    impl TempWav {
        /// Creates a guard for `name` inside the system temp directory.
        ///
        /// The file name is prefixed with the process id so concurrent test runs do not
        /// clobber each other's files.
        fn new(name: &str) -> Self {
            Self(std::env::temp_dir().join(format!("audio_io_{}_{name}", std::process::id())))
        }
    }

    #[cfg(all(feature = "read", feature = "write"))]
    impl Drop for TempWav {
        fn drop(&mut self) {
            // Best effort: the file may not exist if writing failed before it was created.
            let _ = std::fs::remove_file(&self.0);
        }
    }

    /// Reads `test.wav`, writes it back with `sample_format`, and reads the result again.
    ///
    /// Checks that the sample rate is identical and that all samples match within `epsilon`.
    #[cfg(all(feature = "read", feature = "write"))]
    fn assert_round_trip(tmp_name: &str, sample_format: super::WriteSampleFormat, epsilon: f32) {
        use super::*;
        use crate::reader::{AudioReadConfig, audio_read};

        let data1 = audio_read::<_, f32>("test.wav", AudioReadConfig::default()).unwrap();
        let tmp = TempWav::new(tmp_name);

        audio_write(
            &tmp.0,
            data1.audio_block(),
            data1.sample_rate,
            AudioWriteConfig { sample_format },
        )
        .unwrap();

        let data2 = audio_read::<_, f32>(&tmp.0, AudioReadConfig::default()).unwrap();
        assert_eq!(data1.sample_rate, data2.sample_rate);
        approx::assert_abs_diff_eq!(
            data1.audio_block().raw_data(),
            data2.audio_block().raw_data(),
            epsilon = epsilon
        );
    }

    #[test]
    #[cfg(all(feature = "read", feature = "write"))]
    fn test_round_trip_i16() {
        assert_round_trip("round_trip_i16.wav", super::WriteSampleFormat::Int16, 1e-4);
    }

    #[test]
    #[cfg(all(feature = "read", feature = "write"))]
    fn test_round_trip_f32() {
        assert_round_trip("round_trip_f32.wav", super::WriteSampleFormat::Float32, 1e-6);
    }
}