1use std::path::Path;
2
3use hound::{SampleFormat, WavSpec, WavWriter};
4use num::Float;
5use thiserror::Error;
6
7#[derive(Debug, Error)]
8pub enum AudioWriteError {
9 #[error("could not decode audio")]
10 DecodingError(#[from] hound::Error),
11}
12
/// Sample encoding used when writing a WAV file.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub enum WriteSampleFormat {
    /// 16-bit signed integer PCM; input samples are clamped to `[-1.0, 1.0]`
    /// and scaled to the i16 range. This is the default.
    #[default]
    Int16,
    /// 32-bit floating-point samples, written without clamping or scaling.
    Float32,
}
22
23#[derive(Default)]
25pub struct AudioWriteConfig {
26 pub sample_format: WriteSampleFormat,
28}
29
30pub fn audio_write<F: Float>(
32 path: impl AsRef<Path>,
33 samples: &[F],
34 num_channels: u16,
35 sample_rate: u32,
36 config: AudioWriteConfig,
37) -> Result<(), AudioWriteError> {
38 let spec = WavSpec {
39 channels: num_channels,
40 sample_rate,
41 bits_per_sample: match config.sample_format {
42 WriteSampleFormat::Int16 => 16,
43 WriteSampleFormat::Float32 => 32,
44 },
45 sample_format: match config.sample_format {
46 WriteSampleFormat::Int16 => SampleFormat::Int,
47 WriteSampleFormat::Float32 => SampleFormat::Float,
48 },
49 };
50
51 let mut writer = WavWriter::create(path.as_ref(), spec)?;
52
53 match config.sample_format {
54 WriteSampleFormat::Int16 => {
55 for &sample in samples {
57 let sample_i16 = (sample.clamp(F::one().neg(), F::one())
58 * F::from(i16::MAX).unwrap_or(F::zero()))
59 .to_i16()
60 .unwrap_or(0);
61 writer.write_sample(sample_i16)?;
62 }
63 }
64 WriteSampleFormat::Float32 => {
65 for &sample in samples {
67 writer.write_sample(sample.to_f32().unwrap_or(0.0))?;
68 }
69 }
70 }
71
72 writer.finalize()?;
73
74 Ok(())
75}
76
/// Writes an `audio_blocks` block to a WAV file at `path`.
///
/// The block is first converted with
/// `audio_blocks::AudioBlockInterleaved::from_block`. Its raw data, together
/// with the block's channel count, is then forwarded to [`audio_write`].
///
/// # Errors
///
/// Propagates any error returned by [`audio_write`].
#[cfg(feature = "audio-blocks")]
pub fn audio_write_block<P: AsRef<Path>, F: Float + 'static>(
    path: P,
    audio_block: impl audio_blocks::AudioBlock<F>,
    sample_rate: u32,
    config: AudioWriteConfig,
) -> Result<(), AudioWriteError> {
    let channel_count = audio_block.num_channels();
    let interleaved = audio_blocks::AudioBlockInterleaved::from_block(&audio_block);
    let samples = interleaved.raw_data();

    audio_write(path, samples, channel_count, sample_rate, config)
}
94
#[cfg(test)]
mod tests {
    use super::*;
    use crate::reader::{AudioReadConfig, audio_read};

    /// Path to a scratch WAV file that is deleted when the guard is dropped.
    ///
    /// Cleanup therefore happens even when an assertion panics first. The
    /// file lives in the system temp dir, so the working directory is never
    /// polluted.
    struct TempFile(std::path::PathBuf);

    impl TempFile {
        /// Builds a path in the temp dir. The process id is part of the name
        /// so concurrent test runs do not collide.
        fn new(name: &str) -> Self {
            Self(std::env::temp_dir().join(format!("audio_write_{}_{name}", std::process::id())))
        }

        /// The path as `&str`, accepted by every path-taking API used below.
        fn as_str(&self) -> &str {
            self.0.to_str().expect("temporary path should be valid UTF-8")
        }
    }

    impl Drop for TempFile {
        fn drop(&mut self) {
            // Best effort: the file may not exist if writing it failed.
            let _ = std::fs::remove_file(&self.0);
        }
    }

    /// Round-trips the `test_1ch.wav` fixture through [`audio_write`] in the
    /// given format.
    ///
    /// Asserts that the sample rate and channel count survive unchanged, and
    /// that every sample matches the original within `epsilon`.
    fn assert_round_trip(file_name: &str, format: WriteSampleFormat, epsilon: f32) {
        let tmp = TempFile::new(file_name);

        let original =
            audio_read::<f32>("test_data/test_1ch.wav", AudioReadConfig::default()).unwrap();

        audio_write(
            tmp.as_str(),
            &original.samples_interleaved,
            original.num_channels,
            original.sample_rate,
            AudioWriteConfig {
                sample_format: format,
            },
        )
        .unwrap();

        let reread = audio_read::<f32>(tmp.as_str(), AudioReadConfig::default()).unwrap();
        assert_eq!(original.sample_rate, reread.sample_rate);
        assert_eq!(original.num_channels, reread.num_channels);
        approx::assert_abs_diff_eq!(
            original.samples_interleaved.as_slice(),
            reread.samples_interleaved.as_slice(),
            epsilon = epsilon
        );
    }

    #[test]
    fn test_round_trip_i16() {
        // 16-bit quantization step is ~3e-5, comfortably inside 1e-4.
        assert_round_trip("round_trip_i16.wav", WriteSampleFormat::Int16, 1e-4);
    }

    #[test]
    fn test_round_trip_f32() {
        assert_round_trip("round_trip_f32.wav", WriteSampleFormat::Float32, 1e-6);
    }
}