Skip to main content

oxideav_core/registry/
container.rs

1//! Container traits (demuxer + muxer) and a registry.
2//!
3//! This module defines the abstract [`Demuxer`] / [`Muxer`] traits that
4//! every container implementation (oxideav-mp4, oxideav-mkv,
5//! oxideav-flac, oxideav-ogg, …) fulfils, plus a
6//! [`ContainerRegistry`] that consumers of the framework use to pick a
7//! demuxer by probe bytes or filename hint.
8
9use std::collections::HashMap;
10use std::io::{Read, Seek, SeekFrom, Write};
11
12use crate::{CodecResolver, Error, Packet, Result, StreamInfo};
13
14// ───────────────────────── traits ─────────────────────────
15
16/// Reads a container and emits packets per stream.
17pub trait Demuxer: Send {
18    /// Name of the container format (e.g., `"wav"`).
19    fn format_name(&self) -> &str;
20
21    /// Streams in this container. Stable across the lifetime of the demuxer.
22    fn streams(&self) -> &[StreamInfo];
23
24    /// Read the next packet from any stream. Returns `Error::Eof` at end.
25    fn next_packet(&mut self) -> Result<Packet>;
26
27    /// Hint that only the listed stream indices will be consumed by the
28    /// pipeline. Demuxers that can efficiently skip inactive streams at
29    /// the container level (e.g., MKV cluster-aware, MP4 trak-aware)
30    /// should override this. The default is a no-op — the pipeline
31    /// drops unwanted packets on the floor.
32    fn set_active_streams(&mut self, _indices: &[u32]) {}
33
34    /// Seek to the nearest keyframe at or before `pts` (in the given
35    /// stream's time base). Returns the actual timestamp seeked to, or
36    /// `Error::Unsupported` if this demuxer can't seek.
37    fn seek_to(&mut self, _stream_index: u32, _pts: i64) -> Result<i64> {
38        Err(Error::unsupported("this demuxer does not support seeking"))
39    }
40
41    /// Container-level metadata as ordered (key, value) pairs.
42    /// Keys follow a loose convention borrowed from Vorbis comments:
43    /// `title`, `artist`, `album`, `comment`, `date`, `sample_name:<n>`,
44    /// `channels`, `n_patterns`, etc. Demuxers that carry no metadata
45    /// return an empty slice (the default).
46    fn metadata(&self) -> &[(String, String)] {
47        &[]
48    }
49    /// Container-level duration, if known. Default is `None` — callers
50    /// may fall back to the longest per-stream duration. Expressed as
51    /// microseconds for portability; convert to seconds at the edge.
52    fn duration_micros(&self) -> Option<i64> {
53        None
54    }
55
56    /// Attached pictures (cover art, artist photos, ...) embedded in
57    /// the container. Returns an empty slice (the default) when the
58    /// container carries none or doesn't support them. Containers that
59    /// do — ID3v2 on MP3, `METADATA_BLOCK_PICTURE` on FLAC, `covr`
60    /// atoms on MP4, etc. — override this to expose the images.
61    fn attached_pictures(&self) -> &[crate::AttachedPicture] {
62        &[]
63    }
64}
65
66/// Writes packets into a container.
67pub trait Muxer: Send {
68    fn format_name(&self) -> &str;
69
70    /// Write the container header. Must be called after stream configuration
71    /// and before the first `write_packet`.
72    fn write_header(&mut self) -> Result<()>;
73
74    fn write_packet(&mut self, packet: &Packet) -> Result<()>;
75
76    /// Finalize the file (write index, patch in total sizes, etc.).
77    fn write_trailer(&mut self) -> Result<()>;
78}
79
80/// Factory that tries to open a stream as a particular container format.
81///
82/// Implementations should read the minimum needed to confirm the format and
83/// return `Error::InvalidData` if the stream is not in this format.
84///
85/// The `codecs` parameter carries a resolver that converts container-
86/// level codec tags (FourCCs, WAVEFORMATEX wFormatTag, Matroska
87/// CodecIDs, …) into [`CodecId`](crate::CodecId) values.
88pub type OpenDemuxerFn =
89    fn(input: Box<dyn ReadSeek>, codecs: &dyn CodecResolver) -> Result<Box<dyn Demuxer>>;
90
91/// Factory that creates a muxer for a set of streams.
92pub type OpenMuxerFn =
93    fn(output: Box<dyn WriteSeek>, streams: &[StreamInfo]) -> Result<Box<dyn Muxer>>;
94
/// Information passed to a content-based [`ContainerProbeFn`].
///
/// `buf` holds a prefix of the input, taken from offset 0.
/// [`ContainerRegistry::probe_input`] reads up to 256 KB, which is enough to
/// recognise the magic bytes of any container we know about. The prefix is
/// shorter, possibly empty, when the input itself is shorter.
///
/// `ext` carries the file extension as a hint (ASCII-lowercased, no leading
/// dot). Some containers need it to break ties with otherwise weak
/// signatures, such as raw MP3 with no ID3v2 or headerless tracker formats.
#[derive(Debug, Clone, Copy)]
pub struct ProbeData<'a> {
    /// Leading bytes of the input (may be fewer than the probe window).
    pub buf: &'a [u8],
    /// Lowercased file extension without the leading dot, if known.
    pub ext: Option<&'a str>,
}

/// Confidence score returned by a [`ContainerProbeFn`].
///
/// `0` means "no match" and higher means more certain. Conventional values:
///
/// * `100` – unambiguous magic bytes at a known offset
/// * `75`  – signature match corroborated by file extension
/// * `50`  – signature match without extension corroboration
/// * `25`  – extension match only (no content signature available)
///
/// Scores above [`MAX_PROBE_SCORE`] are outside the convention.
pub type ProbeScore = u8;

/// Maximum probe score (alias for `100`).
pub const MAX_PROBE_SCORE: ProbeScore = 100;
/// Default score returned when only the file extension matches.
pub const PROBE_SCORE_EXTENSION: ProbeScore = 25;

/// Content-based format detection function.
///
/// Returns a [`ProbeScore`] in `0..=100`. Implementations should be pure
/// (no I/O, no allocation beyond the stack) and fast, because
/// [`ContainerRegistry::probe_input`] invokes every registered probe once
/// per input.
pub type ContainerProbeFn = fn(probe: &ProbeData) -> ProbeScore;
127
/// Bundle of [`Read`] + [`Seek`] + [`Send`] usable as a single trait object
/// (`Box<dyn ReadSeek>`), e.g. for demuxer inputs.
///
/// Every sized type that satisfies all three bounds implements it
/// automatically.
pub trait ReadSeek: Read + Seek + Send {}

impl<T> ReadSeek for T where T: Read + Seek + Send {}

/// Bundle of [`Write`] + [`Seek`] + [`Send`] usable as a single trait object
/// (`Box<dyn WriteSeek>`), e.g. for muxer outputs.
///
/// Every sized type that satisfies all three bounds implements it
/// automatically.
pub trait WriteSeek: Write + Seek + Send {}

impl<T> WriteSeek for T where T: Write + Seek + Send {}
135
136// ───────────────────────── ContainerRegistry ─────────────────────────
137
138#[derive(Default)]
139pub struct ContainerRegistry {
140    demuxers: HashMap<String, OpenDemuxerFn>,
141    muxers: HashMap<String, OpenMuxerFn>,
142    /// Lowercase file extension → container name (e.g. "wav" → "wav").
143    extensions: HashMap<String, String>,
144    /// Container name → content-probe function. Optional — containers
145    /// without a probe still work but require an extension hint or an
146    /// explicit format name.
147    probes: HashMap<String, ContainerProbeFn>,
148}
149
150impl ContainerRegistry {
151    pub fn new() -> Self {
152        Self::default()
153    }
154
155    pub fn register_demuxer(&mut self, name: &str, open: OpenDemuxerFn) {
156        self.demuxers.insert(name.to_owned(), open);
157    }
158
159    pub fn register_muxer(&mut self, name: &str, open: OpenMuxerFn) {
160        self.muxers.insert(name.to_owned(), open);
161    }
162
163    pub fn register_extension(&mut self, ext: &str, container_name: &str) {
164        self.extensions
165            .insert(ext.to_lowercase(), container_name.to_owned());
166    }
167
168    /// Attach a content-based probe to a registered demuxer. Called by
169    /// the registry's [`probe_input`](Self::probe_input) to detect the
170    /// container format from the first few KB of an input stream.
171    pub fn register_probe(&mut self, container_name: &str, probe: ContainerProbeFn) {
172        self.probes.insert(container_name.to_owned(), probe);
173    }
174
175    pub fn demuxer_names(&self) -> impl Iterator<Item = &str> {
176        self.demuxers.keys().map(|s| s.as_str())
177    }
178
179    pub fn muxer_names(&self) -> impl Iterator<Item = &str> {
180        self.muxers.keys().map(|s| s.as_str())
181    }
182
183    /// Open a demuxer explicitly by format name. The `codecs` resolver
184    /// is passed through to the demuxer so it can translate the
185    /// container's in-stream codec tags (FourCCs / wFormatTag /
186    /// Matroska CodecIDs) into [`CodecId`](crate::CodecId)
187    /// values. Demuxers that don't need tag resolution can ignore it.
188    pub fn open_demuxer(
189        &self,
190        name: &str,
191        input: Box<dyn ReadSeek>,
192        codecs: &dyn CodecResolver,
193    ) -> Result<Box<dyn Demuxer>> {
194        let open = self
195            .demuxers
196            .get(name)
197            .ok_or_else(|| Error::FormatNotFound(name.to_owned()))?;
198        open(input, codecs)
199    }
200
201    /// Open a muxer by format name.
202    pub fn open_muxer(
203        &self,
204        name: &str,
205        output: Box<dyn WriteSeek>,
206        streams: &[StreamInfo],
207    ) -> Result<Box<dyn Muxer>> {
208        let open = self
209            .muxers
210            .get(name)
211            .ok_or_else(|| Error::FormatNotFound(name.to_owned()))?;
212        open(output, streams)
213    }
214
215    /// Look up a container name from a file extension (no leading dot).
216    pub fn container_for_extension(&self, ext: &str) -> Option<&str> {
217        self.extensions.get(&ext.to_lowercase()).map(|s| s.as_str())
218    }
219
220    /// Detect the container format by reading the first ~256 KB of the
221    /// input, scoring each registered probe, and returning the highest-
222    /// scoring container's name. The extension is passed to probes as a
223    /// hint — they may use it to break ties when their signature is weak.
224    ///
225    /// Falls back to the extension table if no probe scores above zero.
226    /// The input cursor is restored to its starting position on success
227    /// and on the I/O failure paths that allow it.
228    pub fn probe_input(&self, input: &mut dyn ReadSeek, ext_hint: Option<&str>) -> Result<String> {
229        const PROBE_BUF_SIZE: usize = 256 * 1024;
230
231        let saved_pos = input.stream_position()?;
232        input.seek(SeekFrom::Start(0))?;
233        let mut buf = vec![0u8; PROBE_BUF_SIZE];
234        let mut got = 0;
235        while got < buf.len() {
236            match input.read(&mut buf[got..]) {
237                Ok(0) => break,
238                Ok(n) => got += n,
239                Err(e) if e.kind() == std::io::ErrorKind::Interrupted => continue,
240                Err(e) => {
241                    let _ = input.seek(SeekFrom::Start(saved_pos));
242                    return Err(e.into());
243                }
244            }
245        }
246        buf.truncate(got);
247        input.seek(SeekFrom::Start(saved_pos))?;
248
249        let ext_lower = ext_hint.map(|s| s.to_ascii_lowercase());
250        let probe_data = ProbeData {
251            buf: &buf,
252            ext: ext_lower.as_deref(),
253        };
254
255        let mut best: Option<(&str, ProbeScore)> = None;
256        for (name, probe) in &self.probes {
257            let score = probe(&probe_data);
258            if score == 0 {
259                continue;
260            }
261            match best {
262                Some((_, prev)) if score <= prev => {}
263                _ => best = Some((name.as_str(), score)),
264            }
265        }
266        if let Some((name, _)) = best {
267            return Ok(name.to_owned());
268        }
269
270        // Fall back to extension lookup with the conventional weak score.
271        if let Some(ext) = ext_hint {
272            if let Some(name) = self.container_for_extension(ext) {
273                let _ = PROBE_SCORE_EXTENSION; // export retained for symmetry
274                return Ok(name.to_owned());
275            }
276        }
277
278        Err(Error::FormatNotFound(
279            "no registered demuxer recognises this input".into(),
280        ))
281    }
282}
283
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Cursor;

    /// Minimal demuxer that relies on every provided trait method.
    struct DummyDemuxer;

    impl Demuxer for DummyDemuxer {
        fn format_name(&self) -> &str {
            "dummy"
        }
        fn streams(&self) -> &[StreamInfo] {
            &[]
        }
        fn next_packet(&mut self) -> Result<Packet> {
            Err(Error::Eof)
        }
    }

    /// Probe that only recognises inputs starting with `MAGC`.
    fn strong_probe(probe: &ProbeData) -> ProbeScore {
        if probe.buf.starts_with(b"MAGC") {
            MAX_PROBE_SCORE
        } else {
            0
        }
    }

    /// Probe with a weak two-byte signature overlapping `strong_probe`.
    fn weak_probe(probe: &ProbeData) -> ProbeScore {
        if probe.buf.starts_with(b"MA") {
            50
        } else {
            0
        }
    }

    #[test]
    fn default_seek_to_is_unsupported() {
        let mut d = DummyDemuxer;
        match d.seek_to(0, 0) {
            Err(Error::Unsupported(_)) => {}
            other => panic!(
                "expected default seek_to to return Unsupported, got {:?}",
                other
            ),
        }
    }

    #[test]
    fn default_optional_accessors_are_empty() {
        let d = DummyDemuxer;
        assert!(d.metadata().is_empty());
        assert!(d.attached_pictures().is_empty());
        assert_eq!(d.duration_micros(), None);
    }

    #[test]
    fn probe_input_prefers_highest_score_and_restores_cursor() {
        let mut reg = ContainerRegistry::new();
        reg.register_probe("strong", strong_probe);
        reg.register_probe("weak", weak_probe);

        let mut input = Cursor::new(b"MAGC rest of the file".to_vec());
        input.set_position(3);
        let name = reg.probe_input(&mut input, None).unwrap();
        assert_eq!(name, "strong");
        assert_eq!(input.position(), 3, "cursor must be restored");
    }

    #[test]
    fn probe_input_falls_back_to_extension_case_insensitively() {
        let mut reg = ContainerRegistry::new();
        reg.register_probe("strong", strong_probe);
        reg.register_extension("WAV", "wav");

        let mut input = Cursor::new(vec![0u8; 16]);
        assert_eq!(reg.probe_input(&mut input, Some("Wav")).unwrap(), "wav");
    }

    #[test]
    fn probe_input_reports_format_not_found() {
        let reg = ContainerRegistry::new();
        let mut input = Cursor::new(Vec::new());
        match reg.probe_input(&mut input, Some("xyz")) {
            Err(Error::FormatNotFound(_)) => {}
            other => panic!("expected FormatNotFound, got {:?}", other),
        }
    }

    #[test]
    fn open_muxer_unknown_name_is_format_not_found() {
        let reg = ContainerRegistry::new();
        match reg.open_muxer("nope", Box::new(Cursor::new(Vec::new())), &[]) {
            Err(Error::FormatNotFound(name)) => assert_eq!(name, "nope"),
            Err(other) => panic!("expected FormatNotFound, got {:?}", other),
            Ok(_) => panic!("expected FormatNotFound, got a muxer"),
        }
    }
}