1#![warn(rust_2018_idioms)]
2#![warn(rust_2021_compatibility)]
3#![warn(clippy::missing_panics_doc)]
4#![warn(clippy::clone_on_ref_ptr)]
5#![deny(trivial_numeric_casts)]
6#![forbid(unsafe_code)]
7
8use std::path::PathBuf;
9use std::{
10 fs::{File, OpenOptions},
11 io::{Seek, SeekFrom, Write},
12};
13
14use creek_core::{write, Encoder, FileInfo, WriteBlock, WriteStatus};
15
16pub mod error;
17mod header;
18
19#[cfg(test)]
20mod tests;
21
22pub mod wav_bit_depth;
23
24use error::{WavFatalError, WavOpenError};
25use header::Header;
26use wav_bit_depth::WavBitDepth;
27
/// How the samples of a [`Format`] are encoded.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum FormatType {
    /// Integer PCM samples (`Uint8`, `Int16`, `Int24`).
    Pcm,
    /// Floating-point samples (`Float32`, `Float64`).
    Float,
}

/// Sample formats supported by the WAV encoder.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Format {
    /// 8-bit unsigned integer PCM.
    Uint8,
    /// 16-bit integer PCM.
    Int16,
    /// 24-bit integer PCM.
    Int24,
    /// 32-bit floating point.
    Float32,
    /// 64-bit floating point.
    Float64,
}

impl Format {
    /// Number of bits used to store one sample in this format.
    pub const fn bits_per_sample(&self) -> u16 {
        match self {
            Format::Uint8 => 8,
            Format::Int16 => 16,
            Format::Int24 => 24,
            Format::Float32 => 32,
            Format::Float64 => 64,
        }
    }

    /// Number of bytes used to store one sample in this format.
    ///
    /// Derived from [`Format::bits_per_sample`] so the two size tables cannot drift
    /// apart; every supported format is a whole number of bytes wide.
    pub const fn bytes_per_sample(&self) -> u16 {
        self.bits_per_sample() / 8
    }

    /// Whether this format stores integer PCM or floating-point samples.
    pub const fn format_type(&self) -> FormatType {
        match self {
            Format::Uint8 | Format::Int16 | Format::Int24 => FormatType::Pcm,
            Format::Float32 | Format::Float64 => FormatType::Float,
        }
    }
}
74
75#[derive(Clone)]
76pub struct Params {
77 _format: Format,
78}
79
80pub struct WavEncoder<B: WavBitDepth + 'static> {
81 interleave_buf: Vec<B::T>,
82 file: Option<File>,
83 header: Header,
84 path: PathBuf,
85 bytes_per_frame: u64,
86 frames_written: u32,
87 max_file_bytes: u64,
88 max_block_bytes: u64,
89 num_channels: usize,
90 num_files: u32,
91 bit_depth: B,
92}
93
94impl<B: WavBitDepth + 'static> Encoder for WavEncoder<B> {
95 type T = B::T;
96 type AdditionalOpts = ();
97 type FileParams = Params;
98 type OpenError = WavOpenError;
99 type FatalError = WavFatalError;
100
101 const DEFAULT_BLOCK_SIZE: usize = 32768;
102 const DEFAULT_NUM_WRITE_BLOCKS: usize = 8;
103
104 fn new(
105 path: PathBuf,
106 num_channels: u16,
107 sample_rate: u32,
108 block_size: usize,
109 _num_write_blocks: usize,
110 _additional_opts: Self::AdditionalOpts,
111 ) -> Result<(Self, FileInfo<Self::FileParams>), Self::OpenError> {
112 let mut file = OpenOptions::new()
113 .write(true)
114 .truncate(true)
115 .create(true)
116 .open(path.clone())?;
117
118 let format = B::format();
119 let header = Header::new(num_channels, sample_rate, format);
120
121 file.write_all(header.buffer())?;
122 file.flush()?;
123
124 let interleave_buf: Vec<B::T> = Vec::with_capacity(block_size * usize::from(num_channels));
125
126 let max_file_bytes = u64::from(header.max_data_bytes());
127 let bytes_per_frame = u64::from(num_channels) * u64::from(format.bytes_per_sample());
128
129 Ok((
130 Self {
131 interleave_buf,
132 file: Some(file),
133 header,
134 path,
135 frames_written: 0,
136 bytes_per_frame,
137 max_file_bytes,
138 max_block_bytes: block_size as u64 * bytes_per_frame,
139 num_channels: usize::from(num_channels),
140 num_files: 1,
141 bit_depth: B::new(block_size, num_channels),
142 },
143 FileInfo {
144 num_frames: 0,
145 num_channels,
146 sample_rate: Some(sample_rate),
147 params: Params { _format: format },
148 },
149 ))
150 }
151
152 fn encode(
153 &mut self,
154 write_block: &WriteBlock<Self::T>,
155 ) -> Result<WriteStatus, Self::FatalError> {
156 let mut status = WriteStatus::Ok;
157
158 let written_frames = write_block.written_frames();
159 if written_frames == 0 {
160 return Ok(status);
161 }
162
163 if let Some(mut file) = self.file.take() {
164 if self.num_channels == 1 {
165 self.bit_depth
166 .write_to_disk(&write_block.block()[0][0..written_frames], &mut file)?;
167 } else {
168 if self.num_channels == 2 {
169 let ch1 = &write_block.block()[0][0..written_frames];
171 let ch2 = &write_block.block()[1][0..written_frames];
172
173 if self.interleave_buf.len() < written_frames * 2 {
174 self.interleave_buf
175 .resize(written_frames * 2, Default::default());
176 }
177
178 let interleave_buf_part = &mut self.interleave_buf[0..written_frames * 2];
179
180 for (i, frame) in interleave_buf_part.chunks_exact_mut(2).enumerate() {
181 frame[0] = ch1[i];
182 frame[1] = ch2[i];
183 }
184 } else {
185 if self.interleave_buf.len() < written_frames * self.num_channels {
186 self.interleave_buf
187 .resize(written_frames * self.num_channels, Default::default());
188 }
189
190 let interleave_buf_part =
191 &mut self.interleave_buf[0..written_frames * self.num_channels];
192
193 for (ch_i, ch) in write_block.block().iter().enumerate() {
194 let ch_slice = &ch[0..written_frames];
195
196 for (dst, src) in interleave_buf_part[ch_i..]
197 .iter_mut()
198 .step_by(self.num_channels)
199 .zip(ch_slice)
200 {
201 *dst = *src;
202 }
203 }
204 }
205
206 self.bit_depth.write_to_disk(
207 &self.interleave_buf[0..written_frames * self.num_channels],
208 &mut file,
209 )?;
210 }
211
212 self.frames_written += written_frames as u32;
213 let bytes_written = u64::from(self.frames_written) * self.bytes_per_frame;
214
215 self.header.set_num_frames(self.frames_written);
216
217 file.seek(SeekFrom::Start(0))?;
219 file.write_all(self.header.buffer())?;
220 file.seek(SeekFrom::Current(bytes_written as i64))?;
221 file.flush()?;
222
223 if bytes_written + self.max_block_bytes >= self.max_file_bytes {
225 let _ = file;
229
230 self.num_files += 1;
231
232 let mut file_name = self
233 .path
234 .file_name()
235 .ok_or_else(|| WavFatalError::CouldNotGetFileName)?
236 .to_os_string();
237 file_name.push(write::num_files_to_file_name_extension(self.num_files));
238 let mut new_file_path = self.path.clone();
239 new_file_path.set_file_name(file_name);
240
241 let mut file = OpenOptions::new()
243 .write(true)
244 .truncate(true)
245 .create(true)
246 .open(new_file_path)?;
247
248 self.frames_written = 0;
249 self.header.set_num_frames(0);
250
251 file.seek(SeekFrom::Start(0))?;
252 file.write_all(self.header.buffer())?;
253 file.flush()?;
254
255 status = WriteStatus::ReachedMaxSize {
256 num_files: self.num_files,
257 };
258 }
259
260 self.file = Some(file);
261 }
262
263 Ok(status)
264 }
265
266 fn finish_file(&mut self) -> Result<(), Self::FatalError> {
267 if let Some(mut file) = self.file.take() {
268 self.header.set_num_frames(self.frames_written);
269
270 file.seek(SeekFrom::Start(0))?;
271 file.write_all(self.header.buffer())?;
272 file.flush()?;
273
274 let _ = file;
276
277 self.num_files = 0;
278 }
279
280 Ok(())
281 }
282
283 fn discard_file(&mut self) -> Result<(), Self::FatalError> {
284 if let Some(file) = self.file.take() {
285 let _ = file;
287
288 std::fs::remove_file(self.path.clone())?;
289
290 if self.num_files > 1 {
292 for i in 2..(self.num_files + 1) {
293 let mut file_name = self
294 .path
295 .file_name()
296 .ok_or_else(|| WavFatalError::CouldNotGetFileName)?
297 .to_os_string();
298 file_name.push(write::num_files_to_file_name_extension(i));
299 let mut new_file_path = self.path.clone();
300 new_file_path.set_file_name(file_name);
301
302 std::fs::remove_file(new_file_path)?;
303 }
304 }
305
306 self.num_files = 0;
307 }
308
309 Ok(())
310 }
311
312 fn discard_and_restart(&mut self) -> Result<(), Self::FatalError> {
313 if let Some(mut file) = self.file.take() {
314 self.frames_written = 0;
315 self.header.set_num_frames(0);
316
317 if self.num_files > 1 {
318 let _ = file;
320
321 for i in 2..(self.num_files + 1) {
323 let mut file_name = self
324 .path
325 .file_name()
326 .ok_or_else(|| WavFatalError::CouldNotGetFileName)?
327 .to_os_string();
328 file_name.push(write::num_files_to_file_name_extension(i));
329 let mut new_file_path = self.path.clone();
330 new_file_path.set_file_name(file_name);
331
332 std::fs::remove_file(new_file_path)?;
333 }
334
335 let mut file = OpenOptions::new()
337 .write(true)
338 .truncate(true)
339 .create(true)
340 .open(self.path.clone())?;
341 file.seek(SeekFrom::Start(0))?;
342 file.write_all(self.header.buffer())?;
343 file.flush()?;
344
345 self.file = Some(file);
346 self.num_files = 1;
347 } else {
348 file.set_len(0)?;
349
350 file.seek(SeekFrom::Start(0))?;
351 file.write_all(self.header.buffer())?;
352 file.flush()?;
353
354 self.file = Some(file);
355 }
356 }
357
358 Ok(())
359 }
360}