use crate::ContainerFormat;
use oximedia_core::OxiError;
/// Outcome of probing a byte buffer for its container format.
#[derive(Debug, Clone)]
pub struct ProbeResult {
    /// Container format that was recognised.
    pub format: ContainerFormat,
    /// Score attached to the match. `probe_format` in this file assigns
    /// values between 0.60 and 0.99; the range is not enforced by this type.
    /// NOTE(review): presumably a higher value means a more certain match.
    pub confidence: f32,
}
impl ProbeResult {
#[must_use]
pub const fn new(format: ContainerFormat, confidence: f32) -> Self {
Self { format, confidence }
}
}
/// Four-byte signature looked for at the very start of a Matroska stream.
const MATROSKA_MAGIC: &[u8] = b"\x1A\x45\xDF\xA3";
/// Signature looked for at the start of an Ogg stream.
const OGG_MAGIC: &[u8] = b"OggS";
/// Signature looked for at the start of a FLAC stream.
const FLAC_MAGIC: &[u8] = b"fLaC";
/// Tag looked for at offset 0 when detecting WAV.
const RIFF_MAGIC: &[u8] = b"RIFF";
/// Tag looked for at offset 8 (after `RIFF`) when detecting WAV.
const WAVE_MAGIC: &[u8] = b"WAVE";
/// Tag looked for at offset 4 when detecting MP4.
const ISOBMFF_FTYP: &[u8] = b"ftyp";
/// Leading ASCII tag of a WebVTT file.
const WEBVTT_MAGIC: &[u8] = b"WEBVTT";
/// Signature looked for at the start of a Y4M stream.
const Y4M_MAGIC: &[u8] = b"YUV4MPEG2";
/// Byte expected at the start of every MPEG-TS packet.
const MPEG_TS_SYNC: u8 = 0x47;
/// Distance in bytes between consecutive MPEG-TS sync bytes.
const TS_PACKET_SIZE: usize = 188;
pub fn probe_format(data: &[u8]) -> Result<ProbeResult, OxiError> {
if data.len() < 4 {
return Err(OxiError::UnknownFormat);
}
if data.starts_with(MATROSKA_MAGIC) {
return Ok(ProbeResult {
format: ContainerFormat::Matroska,
confidence: 0.95,
});
}
if data.starts_with(OGG_MAGIC) {
return Ok(ProbeResult {
format: ContainerFormat::Ogg,
confidence: 0.99,
});
}
if data.len() >= Y4M_MAGIC.len() && data.starts_with(Y4M_MAGIC) {
return Ok(ProbeResult {
format: ContainerFormat::Y4m,
confidence: 0.99,
});
}
if data.starts_with(FLAC_MAGIC) {
return Ok(ProbeResult {
format: ContainerFormat::Flac,
confidence: 0.99,
});
}
if data.len() >= 12 && data.starts_with(RIFF_MAGIC) && &data[8..12] == WAVE_MAGIC {
return Ok(ProbeResult {
format: ContainerFormat::Wav,
confidence: 0.99,
});
}
if data.len() >= 8 && &data[4..8] == ISOBMFF_FTYP {
return Ok(ProbeResult {
format: ContainerFormat::Mp4,
confidence: 0.90,
});
}
if data.len() >= TS_PACKET_SIZE * 2 {
let mut sync_count = 0;
let max_checks = (data.len() / TS_PACKET_SIZE).min(3);
for i in 0..max_checks {
if data[i * TS_PACKET_SIZE] == MPEG_TS_SYNC {
sync_count += 1;
} else {
break;
}
}
if sync_count >= 2 {
return Ok(ProbeResult {
format: ContainerFormat::MpegTs,
confidence: 0.95,
});
}
} else if data.len() >= TS_PACKET_SIZE && data[0] == MPEG_TS_SYNC {
return Ok(ProbeResult {
format: ContainerFormat::MpegTs,
confidence: 0.60,
});
}
if data.starts_with(WEBVTT_MAGIC) {
return Ok(ProbeResult {
format: ContainerFormat::WebVtt,
confidence: 0.99,
});
}
if data.len() >= 20 {
if let Ok(text) = std::str::from_utf8(&data[..data.len().min(100)]) {
let lines: Vec<&str> = text.lines().take(3).collect();
if lines.len() >= 2
&& lines[0].trim().chars().all(|c| c.is_ascii_digit())
&& lines[1].contains("-->")
&& lines[1].contains(',')
{
return Ok(ProbeResult {
format: ContainerFormat::Srt,
confidence: 0.85,
});
}
}
}
Err(OxiError::UnknownFormat)
}
/// Unit tests for `probe_format`: one positive case per detected format plus
/// the two rejection paths (unrecognised bytes, too-short input).
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_probe_matroska() {
        let data = [0x1A, 0x45, 0xDF, 0xA3, 0x01, 0x00, 0x00, 0x00];
        let result = probe_format(&data).expect("operation should succeed");
        assert_eq!(result.format, ContainerFormat::Matroska);
        assert!(result.confidence > 0.9);
    }

    #[test]
    fn test_probe_ogg() {
        let data = b"OggS\x00\x02\x00\x00\x00\x00\x00\x00";
        let result = probe_format(data).expect("operation should succeed");
        assert_eq!(result.format, ContainerFormat::Ogg);
    }

    #[test]
    fn test_probe_flac() {
        let data = b"fLaC\x00\x00\x00\x22";
        let result = probe_format(data).expect("operation should succeed");
        assert_eq!(result.format, ContainerFormat::Flac);
    }

    #[test]
    fn test_probe_wav() {
        let data = b"RIFF\x00\x00\x00\x00WAVEfmt ";
        let result = probe_format(data).expect("operation should succeed");
        assert_eq!(result.format, ContainerFormat::Wav);
    }

    #[test]
    fn test_probe_y4m() {
        let data = b"YUV4MPEG2 W320 H240 F25:1";
        let result = probe_format(data).expect("Y4M signature should be recognised");
        assert_eq!(result.format, ContainerFormat::Y4m);
    }

    #[test]
    fn test_probe_mp4() {
        // Four size bytes followed by the `ftyp` tag at offset 4.
        let data = b"\x00\x00\x00\x18ftypisom";
        let result = probe_format(data).expect("ftyp tag should be recognised");
        assert_eq!(result.format, ContainerFormat::Mp4);
    }

    #[test]
    fn test_probe_mpeg_ts_two_packets() {
        let mut data = vec![0u8; TS_PACKET_SIZE * 2];
        data[0] = MPEG_TS_SYNC;
        data[TS_PACKET_SIZE] = MPEG_TS_SYNC;
        let result = probe_format(&data).expect("two sync bytes should be recognised");
        assert_eq!(result.format, ContainerFormat::MpegTs);
        assert!(result.confidence > 0.9);
    }

    #[test]
    fn test_probe_mpeg_ts_single_packet() {
        // Only one packet is available, so the match is reported with a
        // lower confidence than the multi-packet case.
        let mut data = vec![0u8; TS_PACKET_SIZE];
        data[0] = MPEG_TS_SYNC;
        let result = probe_format(&data).expect("single sync byte should be recognised");
        assert_eq!(result.format, ContainerFormat::MpegTs);
        assert!(result.confidence < 0.9);
    }

    #[test]
    fn test_probe_webvtt() {
        let data = b"WEBVTT\n\n00:01.000 --> 00:02.000\nHello\n";
        let result = probe_format(data).expect("WEBVTT tag should be recognised");
        assert_eq!(result.format, ContainerFormat::WebVtt);
    }

    #[test]
    fn test_probe_srt() {
        let data = b"1\n00:00:01,000 --> 00:00:04,000\nHello\n";
        let result = probe_format(data).expect("SRT cue header should be recognised");
        assert_eq!(result.format, ContainerFormat::Srt);
    }

    #[test]
    fn test_probe_unknown() {
        let data = [0x00, 0x00, 0x00, 0x00];
        assert!(probe_format(&data).is_err());
    }

    #[test]
    fn test_probe_too_short() {
        let data = [0x1A, 0x45];
        assert!(probe_format(&data).is_err());
    }
}