1use std::io::Cursor;
25
26use crate::error::{Error, Result};
27use symphonia::core::io::MediaSourceStream;
28use symphonia::core::probe::Hint;
29
/// Audio formats that [`FormatDetector`] can report.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AudioFormat {
    /// RIFF/WAVE audio (PCM).
    WavPcm,
    /// FLAC (Free Lossless Audio Codec).
    Flac,
    /// MPEG audio Layer III.
    Mp3,
    /// Opus audio.
    Opus,
    /// WebM / Matroska (EBML) container.
    WebM,
    /// AAC audio, including MP4/M4A (`ftyp`) files.
    Aac,
}

impl AudioFormat {
    /// Returns the short lowercase identifier of the format (the same text `Display` writes).
    #[must_use]
    pub const fn as_str(self) -> &'static str {
        match self {
            Self::WavPcm => "wav",
            Self::Flac => "flac",
            Self::Mp3 => "mp3",
            Self::Opus => "opus",
            Self::WebM => "webm",
            Self::Aac => "aac",
        }
    }

    /// Returns `true` for lossless encodings: WAV PCM and FLAC.
    #[must_use]
    pub const fn is_lossless(self) -> bool {
        matches!(self, Self::WavPcm | Self::Flac)
    }

    /// Returns `true` for the formats this crate classifies as container formats:
    /// WAV, Opus, WebM and AAC. FLAC and MP3 are classified as non-container.
    #[must_use]
    pub const fn is_container_format(self) -> bool {
        matches!(self, Self::WavPcm | Self::Opus | Self::WebM | Self::Aac)
    }
}

impl std::fmt::Display for AudioFormat {
    /// Writes [`AudioFormat::as_str`], honouring width/alignment/precision flags
    /// (e.g. `{:>8}`) like a plain `&str` would.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.pad(self.as_str())
    }
}
79
80#[derive(Debug, Clone, Copy, PartialEq)]
82pub struct AudioMetadata {
83 pub format: AudioFormat,
85 pub channels: Option<u16>,
87 pub sample_rate: Option<u32>,
89 pub bit_depth: Option<u16>,
91 pub duration_sec: Option<f64>,
93}
94
95impl AudioMetadata {
96 #[must_use]
98 pub const fn format_only(format: AudioFormat) -> Self {
99 Self {
100 format,
101 channels: None,
102 sample_rate: None,
103 bit_depth: None,
104 duration_sec: None,
105 }
106 }
107
108 #[must_use]
110 pub const fn with_properties(
111 format: AudioFormat,
112 channels: u16,
113 sample_rate: u32,
114 bit_depth: Option<u16>,
115 ) -> Self {
116 Self {
117 format,
118 channels: Some(channels),
119 sample_rate: Some(sample_rate),
120 bit_depth,
121 duration_sec: None,
122 }
123 }
124}
125
126#[derive(Debug, Default, Clone, Copy)]
128pub struct FormatDetector;
129
130impl FormatDetector {
131 #[must_use]
133 pub const fn new() -> Self {
134 Self
135 }
136
137 pub fn detect(data: &[u8]) -> Result<AudioMetadata> {
150 if data.len() < 4 {
151 return Err(Error::InvalidInput(
152 "audio payload too short (minimum 4 bytes required)".into(),
153 ));
154 }
155
156 if let Some(format) = Self::detect_magic_bytes(data) {
157 return Ok(AudioMetadata::format_only(format));
158 }
159
160 Self::detect_with_symphonia(data)
161 }
162
163 pub fn detect_with_metadata(data: &[u8]) -> Result<AudioMetadata> {
173 Self::detect_with_symphonia(data)
174 }
175
176 fn detect_magic_bytes(data: &[u8]) -> Option<AudioFormat> {
181 let len = data.len();
182
183 if len >= 12 {
185 if let (Some(riff), Some(wave)) = (data.get(0..4), data.get(8..12)) {
186 if riff == b"RIFF" && wave == b"WAVE" {
187 return Some(AudioFormat::WavPcm);
188 }
189 }
190 }
191
192 if len >= 4 {
194 if let Some(header) = data.get(0..4) {
195 if header == b"fLaC" {
196 return Some(AudioFormat::Flac);
197 }
198 }
199 }
200
201 if len >= 2 {
203 if let (Some(&first), Some(&second)) = (data.first(), data.get(1)) {
204 if first == 0xFF && (second & 0xE0) == 0xE0 {
205 let layer = (second >> 1) & 0x03;
206 if layer == 0x01 {
207 return Some(AudioFormat::Mp3);
208 }
209 }
210 }
211 }
212
213 if len >= 4 {
215 if let Some(header) = data.get(0..4) {
216 if header == b"OggS" {
217 return None;
218 }
219 }
220 }
221
222 if len >= 4 {
224 if let Some(header) = data.get(0..4) {
225 if header == [0x1A, 0x45, 0xDF, 0xA3] {
226 return Some(AudioFormat::WebM);
227 }
228 }
229 }
230
231 if len >= 12 {
233 if let (Some(ftyp), Some(brand)) = (data.get(4..8), data.get(8..12)) {
234 if ftyp == b"ftyp" && (brand == b"M4A " || brand == b"mp42" || brand == b"isom") {
235 return Some(AudioFormat::Aac);
236 }
237 }
238 }
239
240 None
241 }
242
243 fn detect_with_symphonia(data: &[u8]) -> Result<AudioMetadata> {
247 let data_vec = data.to_vec();
248 let cursor = Cursor::new(data_vec);
249 let mss = MediaSourceStream::new(
250 Box::new(cursor),
251 symphonia::core::io::MediaSourceStreamOptions::default(),
252 );
253
254 let hint = Hint::new();
255 let probe_result = symphonia::default::get_probe()
256 .format(
257 &hint,
258 mss,
259 &symphonia::core::formats::FormatOptions::default(),
260 &symphonia::core::meta::MetadataOptions::default(),
261 )
262 .map_err(|err| {
263 Error::InvalidInput(format!("unsupported or malformed audio format: {err}"))
264 })?;
265
266 let format_reader = probe_result.format;
267 let codec_params = &format_reader
268 .default_track()
269 .ok_or_else(|| Error::InvalidInput("no audio track found in container".into()))?
270 .codec_params;
271
272 let format = match codec_params.codec {
273 symphonia::core::codecs::CODEC_TYPE_PCM_S16LE
274 | symphonia::core::codecs::CODEC_TYPE_PCM_S24LE
275 | symphonia::core::codecs::CODEC_TYPE_PCM_S32LE
276 | symphonia::core::codecs::CODEC_TYPE_PCM_F32LE => AudioFormat::WavPcm,
277 symphonia::core::codecs::CODEC_TYPE_FLAC => AudioFormat::Flac,
278 symphonia::core::codecs::CODEC_TYPE_MP3 => AudioFormat::Mp3,
279 symphonia::core::codecs::CODEC_TYPE_OPUS => AudioFormat::Opus,
280 symphonia::core::codecs::CODEC_TYPE_VORBIS => {
281 return Err(Error::InvalidInput(
282 "Vorbis codec not supported (use Opus instead)".into(),
283 ));
284 }
285 symphonia::core::codecs::CODEC_TYPE_AAC => AudioFormat::Aac,
286 _ => {
287 return Err(Error::InvalidInput(format!(
288 "unsupported codec: {:?}",
289 codec_params.codec
290 )));
291 }
292 };
293
294 let channels = codec_params.channels.map(|ch| ch.count() as u16);
295 let sample_rate = codec_params.sample_rate;
296 let bit_depth = codec_params.bits_per_sample.map(|b| b as u16);
297 let duration_sec = codec_params
298 .n_frames
299 .and_then(|frames| sample_rate.map(|rate| frames as f64 / f64::from(rate)));
300
301 Ok(AudioMetadata {
302 format,
303 channels,
304 sample_rate,
305 bit_depth,
306 duration_sec,
307 })
308 }
309}
310
#[cfg(test)]
mod tests {
    use super::*;

    type TestResult<T> = std::result::Result<T, String>;

    /// Returns a fresh detector; detection is stateless, so this exists only for API symmetry.
    fn create_detector() -> FormatDetector {
        FormatDetector::new()
    }

    /// Runs `FormatDetector::detect` and stringifies errors so tests can use `?`.
    fn detect_format(_detector: FormatDetector, data: &[u8]) -> TestResult<AudioMetadata> {
        FormatDetector::detect(data).map_err(|err| err.to_string())
    }

    /// 36-byte RIFF/WAVE header with a 16-byte `fmt ` chunk describing PCM,
    /// stereo, 44.1 kHz, 16-bit audio. No `data` chunk follows.
    fn wav_header() -> Vec<u8> {
        let sample_rate: u32 = 44_100;
        let channels: u16 = 2;
        let bits_per_sample: u16 = 16;
        let block_align: u16 = channels * (bits_per_sample / 8);
        let byte_rate: u32 = sample_rate * u32::from(block_align);
        let chunks: [&[u8]; 11] = [
            b"RIFF",
            &36u32.to_le_bytes(),
            b"WAVE",
            b"fmt ",
            &16u32.to_le_bytes(),
            &1u16.to_le_bytes(),
            &channels.to_le_bytes(),
            &sample_rate.to_le_bytes(),
            &byte_rate.to_le_bytes(),
            &block_align.to_le_bytes(),
            &bits_per_sample.to_le_bytes(),
        ];
        chunks.concat()
    }

    /// Bare FLAC stream marker.
    fn flac_header() -> Vec<u8> {
        Vec::from(&b"fLaC"[..])
    }

    /// MPEG-1 Layer III frame header (128 kbps, 44.1 kHz).
    fn mp3_header() -> Vec<u8> {
        [0xFF, 0xFB, 0x90, 0x00].to_vec()
    }

    /// EBML magic followed by four more header bytes.
    fn webm_header() -> Vec<u8> {
        [0x1A, 0x45, 0xDF, 0xA3, 0x00, 0x00, 0x00, 0x20].to_vec()
    }

    /// 20-byte ISO-BMFF `ftyp` box: major brand `M4A `, minor version 0,
    /// compatible brand `mp42`.
    fn aac_header() -> Vec<u8> {
        let parts: [&[u8]; 5] = [
            &20u32.to_be_bytes(),
            b"ftyp",
            b"M4A ",
            &0u32.to_be_bytes(),
            b"mp42",
        ];
        parts.concat()
    }

    #[test]
    fn test_detect_wav_format() -> TestResult<()> {
        let AudioMetadata { format, .. } = detect_format(create_detector(), &wav_header())?;
        assert_eq!(format, AudioFormat::WavPcm);
        assert_eq!(format.as_str(), "wav");
        assert!(format.is_lossless());
        Ok(())
    }

    #[test]
    fn test_detect_flac_format() -> TestResult<()> {
        let AudioMetadata { format, .. } = detect_format(create_detector(), &flac_header())?;
        assert_eq!(format, AudioFormat::Flac);
        assert_eq!(format.as_str(), "flac");
        assert!(format.is_lossless());
        Ok(())
    }

    #[test]
    fn test_detect_mp3_format() -> TestResult<()> {
        let AudioMetadata { format, .. } = detect_format(create_detector(), &mp3_header())?;
        assert_eq!(format, AudioFormat::Mp3);
        assert_eq!(format.as_str(), "mp3");
        assert!(!format.is_lossless());
        Ok(())
    }

    #[test]
    fn test_detect_webm_format() -> TestResult<()> {
        let AudioMetadata { format, .. } = detect_format(create_detector(), &webm_header())?;
        assert_eq!(format, AudioFormat::WebM);
        assert_eq!(format.as_str(), "webm");
        Ok(())
    }

    #[test]
    fn test_detect_aac_format() -> TestResult<()> {
        let AudioMetadata { format, .. } = detect_format(create_detector(), &aac_header())?;
        assert_eq!(format, AudioFormat::Aac);
        assert_eq!(format.as_str(), "aac");
        assert!(!format.is_lossless());
        Ok(())
    }

    #[test]
    fn test_reject_empty_payload() {
        assert!(FormatDetector::detect(&[]).is_err());
    }

    #[test]
    fn test_reject_too_short_payload() {
        assert!(FormatDetector::detect(&[0xFF, 0xFE]).is_err());
    }

    #[test]
    fn test_reject_random_bytes() {
        let noise: [u8; 8] = [0xDE, 0xAD, 0xBE, 0xEF, 0xCA, 0xFE, 0xBA, 0xBE];
        assert!(FormatDetector::detect(&noise).is_err());
    }

    #[test]
    fn test_reject_truncated_wav_header() {
        assert!(FormatDetector::detect(b"RIFF").is_err());
    }

    #[test]
    fn test_reject_mismatched_riff_signature() {
        let parts: [&[u8]; 3] = [b"RIFF", &36u32.to_le_bytes(), b"AVI "];
        let bad_wav = parts.concat();
        assert!(FormatDetector::detect(&bad_wav).is_err());
    }

    #[test]
    fn test_handle_exact_minimum_length() -> TestResult<()> {
        let metadata = detect_format(create_detector(), b"fLaC")?;
        assert_eq!(metadata.format, AudioFormat::Flac);
        Ok(())
    }

    #[test]
    fn test_handle_large_payload_prefix() -> TestResult<()> {
        let mut payload = wav_header();
        payload.resize(payload.len() + 1024 * 1024, 0);
        let metadata = detect_format(create_detector(), &payload)?;
        assert_eq!(metadata.format, AudioFormat::WavPcm);
        Ok(())
    }

    #[test]
    fn test_format_display_matches_as_str() {
        let all = [
            AudioFormat::WavPcm,
            AudioFormat::Flac,
            AudioFormat::Mp3,
            AudioFormat::Opus,
            AudioFormat::WebM,
            AudioFormat::Aac,
        ];
        for format in all {
            assert_eq!(format.to_string(), format.as_str());
        }
    }

    #[test]
    fn test_lossless_formats_identified() {
        let expectations = [
            (AudioFormat::WavPcm, true),
            (AudioFormat::Flac, true),
            (AudioFormat::Mp3, false),
            (AudioFormat::Opus, false),
            (AudioFormat::Aac, false),
        ];
        for (format, lossless) in expectations {
            assert_eq!(format.is_lossless(), lossless, "{format}");
        }
    }

    #[test]
    fn test_container_formats_identified() {
        let expectations = [
            (AudioFormat::WavPcm, true),
            (AudioFormat::Opus, true),
            (AudioFormat::WebM, true),
            (AudioFormat::Aac, true),
            (AudioFormat::Flac, false),
            (AudioFormat::Mp3, false),
        ];
        for (format, container) in expectations {
            assert_eq!(format.is_container_format(), container, "{format}");
        }
    }

    #[test]
    fn test_metadata_format_only_constructor() {
        let metadata = AudioMetadata::format_only(AudioFormat::Mp3);
        assert_eq!(metadata.format, AudioFormat::Mp3);
        assert_eq!(
            (metadata.channels, metadata.sample_rate, metadata.bit_depth, metadata.duration_sec),
            (None, None, None, None)
        );
    }

    #[test]
    fn test_metadata_with_properties_constructor() {
        let metadata = AudioMetadata::with_properties(AudioFormat::WavPcm, 2, 44100, Some(16));
        assert_eq!(metadata.format, AudioFormat::WavPcm);
        assert_eq!(
            (metadata.channels, metadata.sample_rate, metadata.bit_depth, metadata.duration_sec),
            (Some(2), Some(44100), Some(16), None)
        );
    }
}