wow_sharedmedia/converter/
audio.rs1use std::num::NonZeroU8;
4use std::num::NonZeroU32;
5use std::path::Path;
6
7use symphonia::core::audio::SampleBuffer;
8use symphonia::core::codecs::{CODEC_TYPE_NULL, DecoderOptions};
9use symphonia::core::formats::FormatOptions;
10use symphonia::core::io::MediaSourceStream;
11use symphonia::core::meta::MetadataOptions;
12use symphonia::core::probe::Hint;
13
14use crate::Error;
15
16pub fn convert_to_ogg(input: &Path, output: &Path) -> Result<AudioConvertResult, Error> {
22 convert_to_ogg_with_quality(input, output, 0.4)
23}
24
25pub fn convert_to_ogg_with_quality(input: &Path, output: &Path, quality: f32) -> Result<AudioConvertResult, Error> {
30 let ext = input
31 .extension()
32 .and_then(|e| e.to_str())
33 .map(|e| e.to_lowercase())
34 .unwrap_or_default();
35
36 if ext == "ogg" {
38 std::fs::copy(input, output).map_err(|e| Error::Io {
39 source: e,
40 path: input.to_path_buf(),
41 })?;
42 return probe_audio(output);
43 }
44
45 let src = std::fs::File::open(input).map_err(|e| Error::Io {
47 source: e,
48 path: input.to_path_buf(),
49 })?;
50
51 let mss = MediaSourceStream::new(Box::new(src), Default::default());
53
54 let mut hint = Hint::new();
56 if !ext.is_empty() {
57 hint.with_extension(&ext);
58 }
59
60 let meta_opts: MetadataOptions = Default::default();
62 let fmt_opts: FormatOptions = Default::default();
63
64 let probed = symphonia::default::get_probe()
65 .format(&hint, mss, &fmt_opts, &meta_opts)
66 .map_err(|e| Error::InvalidAudio(format!("Cannot detect audio format: {e}")))?;
67
68 let mut format = probed.format;
69
70 let track = format
72 .tracks()
73 .iter()
74 .find(|t| t.codec_params.codec != CODEC_TYPE_NULL)
75 .ok_or_else(|| Error::InvalidAudio("No supported audio track found".to_string()))?;
76
77 let track_id = track.id;
78 let sample_rate = track.codec_params.sample_rate.unwrap_or(44100);
79 let channels = track.codec_params.channels.map(|c| c.count() as u32).unwrap_or(2);
80
81 let dec_opts: DecoderOptions = Default::default();
83 let mut decoder = symphonia::default::get_codecs()
84 .make(&track.codec_params, &dec_opts)
85 .map_err(|e| Error::InvalidAudio(format!("Unsupported codec: {e}")))?;
86
87 let mut all_samples: Vec<f32> = Vec::new();
89 let mut total_duration_frames: u64 = 0;
90
91 loop {
92 let packet = match format.next_packet() {
93 Ok(packet) => packet,
94 Err(symphonia::core::errors::Error::IoError(e)) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
95 break;
96 }
97 Err(e) => {
98 return Err(Error::InvalidAudio(format!("Format error: {e}")));
99 }
100 };
101
102 if packet.track_id() != track_id {
103 continue;
104 }
105
106 match decoder.decode(&packet) {
107 Ok(decoded) => {
108 let spec = *decoded.spec();
109 let frames = decoded.frames() as u64;
110 let mut sample_buf = SampleBuffer::<f32>::new(decoded.capacity() as u64, spec);
111 sample_buf.copy_interleaved_ref(decoded);
112 all_samples.extend_from_slice(sample_buf.samples());
113 total_duration_frames += frames;
114 }
115 Err(symphonia::core::errors::Error::DecodeError(_)) => continue,
116 Err(symphonia::core::errors::Error::IoError(_)) => continue,
117 Err(e) => {
118 return Err(Error::InvalidAudio(format!("Decode error: {e}")));
119 }
120 }
121 }
122
123 let duration_secs = if sample_rate > 0 && channels > 0 {
125 total_duration_frames as f64 / sample_rate as f64
126 } else {
127 0.0
128 };
129
130 if let Some(parent) = output.parent() {
132 std::fs::create_dir_all(parent).map_err(|e| Error::Io {
133 source: e,
134 path: parent.to_path_buf(),
135 })?;
136 }
137
138 let ogg_file = std::fs::File::create(output).map_err(|e| Error::Io {
139 source: e,
140 path: output.to_path_buf(),
141 })?;
142
143 let nz_sample_rate = NonZeroU32::new(sample_rate).unwrap_or(NonZeroU32::new(44100).unwrap());
145 let nz_channels = NonZeroU8::new(channels as u8).unwrap_or(NonZeroU8::new(2).unwrap());
146
147 let mut builder = vorbis_rs::VorbisEncoderBuilder::new_with_serial(nz_sample_rate, nz_channels, ogg_file, 1);
148 builder.bitrate_management_strategy(vorbis_rs::VorbisBitrateManagementStrategy::QualityVbr {
149 target_quality: quality.clamp(0.0, 1.0),
150 });
151
152 let mut encoder = builder
153 .build()
154 .map_err(|e| Error::AudioConversion(format!("Failed to build encoder: {e}")))?;
155
156 let num_channels = channels as usize;
159 if num_channels > 0 && !all_samples.is_empty() {
160 let samples_per_channel = all_samples.len() / num_channels;
161
162 let mut planar: Vec<Vec<f32>> = vec![Vec::with_capacity(samples_per_channel); num_channels];
164 for (i, sample) in all_samples.iter().enumerate() {
165 planar[i % num_channels].push(*sample);
166 }
167
168 let block_size = 1024;
170 let mut offset = 0;
171 while offset < samples_per_channel {
172 let end = (offset + block_size).min(samples_per_channel);
173 let block: Vec<&[f32]> = planar.iter().map(|ch| &ch[offset..end]).collect();
174 encoder
175 .encode_audio_block(&block)
176 .map_err(|e| Error::AudioConversion(format!("Encoding error: {e}")))?;
177 offset = end;
178 }
179 }
180
181 encoder
183 .finish()
184 .map_err(|e| Error::AudioConversion(format!("Failed to finalize: {e}")))?;
185
186 Ok(AudioConvertResult {
187 duration_secs,
188 sample_rate,
189 channels: channels as u32,
190 })
191}
192
193pub(crate) fn probe_audio(input: &Path) -> Result<AudioConvertResult, Error> {
198 let ext = input
199 .extension()
200 .and_then(|e| e.to_str())
201 .map(|e| e.to_lowercase())
202 .unwrap_or_default();
203
204 if ext != "ogg" {
205 let tmp_output = input.with_extension("tmp.ogg");
207 match convert_to_ogg(input, &tmp_output) {
208 Ok(result) => {
209 let _ = std::fs::remove_file(&tmp_output);
210 return Ok(result);
211 }
212 Err(error) => {
213 let _ = std::fs::remove_file(&tmp_output);
214 return Err(error);
215 }
216 }
217 }
218
219 let src = std::fs::File::open(input).map_err(|e| Error::Io {
221 source: e,
222 path: input.to_path_buf(),
223 })?;
224
225 let mss = MediaSourceStream::new(Box::new(src), Default::default());
226
227 let mut hint = Hint::new();
228 hint.with_extension("ogg");
229
230 let meta_opts: MetadataOptions = Default::default();
231 let fmt_opts: FormatOptions = Default::default();
232
233 let probed = match symphonia::default::get_probe().format(&hint, mss, &fmt_opts, &meta_opts) {
234 Ok(p) => p,
235 Err(error) => {
236 return Err(Error::InvalidAudio(format!("Cannot probe ogg metadata: {error}")));
237 }
238 };
239
240 let mut format = probed.format;
241
242 let track = match format.tracks().iter().find(|t| t.codec_params.codec != CODEC_TYPE_NULL) {
244 Some(t) => t,
245 None => {
246 return Err(Error::InvalidAudio(
247 "No supported audio track found in ogg file".to_string(),
248 ));
249 }
250 };
251
252 let sample_rate = track.codec_params.sample_rate.unwrap_or(0);
253 let channels = track.codec_params.channels.map(|c| c.count() as u32).unwrap_or(2);
254
255 let dec_opts: DecoderOptions = Default::default();
257 let track_id = track.id;
258
259 let mut decoder = match symphonia::default::get_codecs().make(&track.codec_params, &dec_opts) {
260 Ok(d) => d,
261 Err(error) => {
262 return Err(Error::InvalidAudio(format!("Unsupported ogg codec: {error}")));
263 }
264 };
265
266 let mut total_frames: u64 = 0;
267
268 loop {
269 let packet = match format.next_packet() {
270 Ok(p) => p,
271 Err(symphonia::core::errors::Error::IoError(e)) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
272 break;
273 }
274 Err(_) => break,
275 };
276
277 if packet.track_id() != track_id {
278 continue;
279 }
280
281 match decoder.decode(&packet) {
282 Ok(decoded) => {
283 total_frames += decoded.frames() as u64;
284 }
285 Err(_) => continue,
286 }
287 }
288
289 let duration_secs = if sample_rate > 0 {
290 total_frames as f64 / sample_rate as f64
291 } else {
292 0.0
293 };
294
295 Ok(AudioConvertResult {
296 duration_secs,
297 sample_rate,
298 channels,
299 })
300}
301
/// Audio properties reported by a conversion or a probe.
#[derive(Clone, Debug, PartialEq)]
pub struct AudioConvertResult {
    /// Duration in seconds; `0.0` when the sample rate is unknown.
    pub duration_secs: f64,
    /// Sample rate in Hz.
    pub sample_rate: u32,
    /// Number of audio channels.
    pub channels: u32,
}

#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Writes a minimal 16-bit PCM WAV file (canonical 44-byte header) whose
    /// data chunk holds `samples` as given (interleaved for multi-channel).
    fn write_test_wav(path: &Path, sample_rate: u32, channels: u16, samples: &[i16]) {
        let bits_per_sample: u16 = 16;
        let block_align: u16 = channels * (bits_per_sample / 8);
        let byte_rate: u32 = sample_rate * block_align as u32;
        let data_size: u32 = std::mem::size_of_val(samples) as u32;
        let riff_size: u32 = 36 + data_size;

        let mut bytes = Vec::with_capacity((44 + data_size) as usize);
        bytes.extend_from_slice(b"RIFF");
        bytes.extend_from_slice(&riff_size.to_le_bytes());
        bytes.extend_from_slice(b"WAVE");
        bytes.extend_from_slice(b"fmt ");
        bytes.extend_from_slice(&16u32.to_le_bytes());
        bytes.extend_from_slice(&1u16.to_le_bytes());
        bytes.extend_from_slice(&channels.to_le_bytes());
        bytes.extend_from_slice(&sample_rate.to_le_bytes());
        bytes.extend_from_slice(&byte_rate.to_le_bytes());
        bytes.extend_from_slice(&block_align.to_le_bytes());
        bytes.extend_from_slice(&bits_per_sample.to_le_bytes());
        bytes.extend_from_slice(b"data");
        bytes.extend_from_slice(&data_size.to_le_bytes());
        for sample in samples {
            bytes.extend_from_slice(&sample.to_le_bytes());
        }

        std::fs::write(path, bytes).unwrap();
    }

    /// Asserts that `actual` equals `expected` up to floating-point noise.
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < 1e-9,
            "expected {expected}, got {actual}"
        );
    }

    #[test]
    fn test_convert_wav_to_ogg_and_probe() {
        let dir = TempDir::new().unwrap();
        let input = dir.path().join("input.wav");
        let output = dir.path().join("output.ogg");

        let samples = [0i16, 8192, -8192, 4096, -4096, 0, 2048, -2048];
        write_test_wav(&input, 44_100, 1, &samples);

        let result = convert_to_ogg(&input, &output).unwrap();
        assert!(output.exists());
        assert_eq!(result.sample_rate, 44_100);
        assert_eq!(result.channels, 1);
        // Eight mono samples are eight frames.
        assert_close(result.duration_secs, 8.0 / 44_100.0);

        let probed = probe_audio(&output).unwrap();
        assert_eq!(probed.sample_rate, 44_100);
        assert_eq!(probed.channels, 1);
    }

    #[test]
    fn test_convert_stereo_wav_keeps_channel_layout() {
        let dir = TempDir::new().unwrap();
        let input = dir.path().join("stereo.wav");
        let output = dir.path().join("stereo.ogg");

        // Interleaved left/right pairs: four frames of two channels.
        let samples = [1000i16, -1000, 2000, -2000, 3000, -3000, 4000, -4000];
        write_test_wav(&input, 22_050, 2, &samples);

        let result = convert_to_ogg(&input, &output).unwrap();
        assert_eq!(result.sample_rate, 22_050);
        assert_eq!(result.channels, 2);
        assert_close(result.duration_secs, 4.0 / 22_050.0);

        let probed = probe_audio(&output).unwrap();
        assert_eq!(probed.sample_rate, 22_050);
        assert_eq!(probed.channels, 2);
    }

    #[test]
    fn test_convert_ogg_passthrough() {
        let dir = TempDir::new().unwrap();
        let wav = dir.path().join("input.wav");
        let ogg = dir.path().join("input.ogg");
        let copied = dir.path().join("copied.ogg");

        let samples = [0i16, 4096, -4096, 0];
        write_test_wav(&wav, 22_050, 1, &samples);
        convert_to_ogg(&wav, &ogg).unwrap();

        let original_bytes = std::fs::read(&ogg).unwrap();
        let result = convert_to_ogg(&ogg, &copied).unwrap();
        let copied_bytes = std::fs::read(&copied).unwrap();

        assert_eq!(original_bytes, copied_bytes);
        assert_eq!(result.sample_rate, 22_050);
        assert_eq!(result.channels, 1);
    }

    #[test]
    fn test_probe_wav_reads_metadata_without_leftovers() {
        let dir = TempDir::new().unwrap();
        let input = dir.path().join("input.wav");
        write_test_wav(&input, 44_100, 1, &[0i16, 100, -100, 0]);

        let probed = probe_audio(&input).unwrap();
        assert_eq!(probed.sample_rate, 44_100);
        assert_eq!(probed.channels, 1);
        assert_close(probed.duration_secs, 4.0 / 44_100.0);

        // Probing must leave nothing but the input in the directory.
        let entries = std::fs::read_dir(dir.path()).unwrap().count();
        assert_eq!(entries, 1);
    }

    #[test]
    fn test_invalid_audio_errors() {
        let dir = TempDir::new().unwrap();
        let input = dir.path().join("bad.wav");
        let output = dir.path().join("bad.ogg");
        std::fs::write(&input, b"not really a wav").unwrap();

        let result = convert_to_ogg(&input, &output);
        assert!(result.is_err());
        match result.unwrap_err() {
            Error::InvalidAudio(_) => {}
            other => panic!("Expected InvalidAudio, got: {other}"),
        }
        // Nothing is written for input that cannot be decoded.
        assert!(!output.exists());
    }

    #[test]
    fn test_missing_input_is_io_error() {
        let dir = TempDir::new().unwrap();
        let input = dir.path().join("missing.wav");
        let output = dir.path().join("missing.ogg");

        match convert_to_ogg(&input, &output) {
            Err(Error::Io { .. }) => {}
            Err(other) => panic!("Expected Io, got: {other}"),
            Ok(result) => panic!("Expected an error, got: {result:?}"),
        }
        assert!(!output.exists());
    }

    #[test]
    fn test_probe_invalid_ogg_errors() {
        let dir = TempDir::new().unwrap();
        let input = dir.path().join("bad.ogg");
        std::fs::write(&input, b"not really an ogg").unwrap();

        let result = probe_audio(&input);
        assert!(result.is_err());
        match result.unwrap_err() {
            Error::InvalidAudio(_) => {}
            other => panic!("Expected InvalidAudio, got: {other}"),
        }
    }
}