Skip to main content

media_format_matroska/
demuxer.rs

//! Matroska/WebM demuxer
2
3use std::{io::SeekFrom, num::NonZeroU32};
4
5#[cfg(feature = "video")]
6use media_codec_h264::avcc::Avcc;
7#[cfg(feature = "video")]
8use media_codec_h265::hvcc::Hvcc;
9#[cfg(any(feature = "audio", feature = "video"))]
10use media_codec_types::decoder::ExtraData;
11#[cfg(feature = "audio")]
12use media_codec_types::AudioParameters;
13#[cfg(feature = "video")]
14use media_codec_types::VideoParameters;
15use media_codec_types::{
16    decoder::DecoderParameters,
17    packet::{Packet, PacketFlags},
18    CodecID, CodecParameters,
19};
20#[cfg(feature = "audio")]
21use media_core::audio::ChannelLayout;
22use media_core::{
23    not_found_error,
24    rational::Rational64,
25    read_failed_error,
26    time::{MSEC_PER_SEC, NSEC_PER_SEC, USEC_PER_SEC},
27    unsupported_error,
28    variant::Variant,
29    MediaType, Result,
30};
31use media_format_types::{
32    demuxer::{Demuxer, DemuxerBuilder, DemuxerState, Reader, SeekFlags},
33    stream::Stream,
34    track::Track,
35    Format,
36};
37use mkv_element::{
38    io::blocking_impl::{ReadElement, ReadFrom},
39    prelude::{
40        BlockGroup, Cluster, Cues, Ebml, Element, Header, Info, Position, PrevSize, SeekHead, Segment, SimpleBlock, Tags, Timestamp, TimestampScale,
41        TrackEntry, Tracks, VInt64, Void,
42    },
43    ClusterBlock, FrameData,
44};
45
/// Upper bound on children parsed from a single unknown-size cluster.
const MAX_ELEMENTS: u32 = 8192;
/// Consecutive unrecognised children tolerated inside an unknown-size cluster before
/// the data is treated as corrupted.
const MAX_UNKNOWN_ELEMENTS: u32 = 8;
/// Largest element body (256 MiB) the demuxer is willing to read or skip.
const MAX_ELEMENT_SIZE: u64 = 256 << 20;
/// Upper bound on Segment children examined while reading the header.
const MAX_HEADER_ELEMENTS: u32 = 256;
/// Maximum number of elements skipped while looking for the next cluster.
const MAX_SKIPS: u32 = 8;
/// Seek look-ahead window in seconds (NOTE(review): consumed by `seek`, confirm usage).
const SEEK_LOOKAHEAD_SEC: i64 = 10;
52
/// Container flavour declared by the EBML `DocType` field.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum DocType {
    /// Generic Matroska (`.mkv`); also the fallback for any unrecognised DocType
    #[default]
    Matroska,
    /// WebM (`.webm`), which only admits the WebM codec subset
    WebM,
}

impl DocType {
    /// Interprets an EBML DocType string.
    ///
    /// `"webm"` in any ASCII letter case selects [`DocType::WebM`]; every other value
    /// (including an empty string) falls back to [`DocType::Matroska`].
    fn from_doc_type(doc_type: &str) -> Self {
        if doc_type.eq_ignore_ascii_case("webm") {
            DocType::WebM
        } else {
            DocType::Matroska
        }
    }
}
69
70/// Matroska/WebM demuxer implementation
71pub struct MkvDemuxer {
72    /// Document type (.mkv or .webm)
73    doc_type: DocType,
74    /// Time base in seconds
75    time_base: Rational64,
76    /// Duration in timestamp units
77    duration: Option<f64>,
78    /// Cues (index table) for fast seeking
79    cues: Option<Cues>,
80    /// Segment data start position (for calculating absolute positions from cue
81    /// positions)
82    segment_data_position: u64,
83    /// Current cluster being read
84    current_cluster: Option<Cluster>,
85    /// Current frame index within the current cluster
86    current_frame_index: usize,
87    /// Current lace index within the current frame (for laced frames with
88    /// multiple sub-frames)
89    current_lace_index: usize,
90}
91
92impl Default for MkvDemuxer {
93    fn default() -> Self {
94        Self::new()
95    }
96}
97
98impl MkvDemuxer {
99    /// Creates a new Matroska demuxer
100    pub fn new() -> Self {
101        Self {
102            doc_type: DocType::default(),
103            time_base: Rational64::new(TimestampScale::default().0 as i64, NSEC_PER_SEC),
104            duration: None,
105            cues: None,
106            segment_data_position: 0,
107            current_cluster: None,
108            current_frame_index: 0,
109            current_lace_index: 0,
110        }
111    }
112
113    /// Codec ID mapping based on:
114    /// - Matroska Specs: https://www.matroska.org/technical/codec_specs.html
115    /// - IETF Draft: https://datatracker.ietf.org/doc/draft-ietf-cellar-codec/
116    ///
117    /// Note: `A_EAC3` and `V_MPEGI/ISO/VVC` are widely used in practice but
118    /// not yet in the official specification.
119    fn str_to_codec_id(codec_id_str: &str, media_type: MediaType, doc_type: DocType) -> Option<CodecID> {
120        match (media_type, doc_type) {
121            #[cfg(feature = "video")]
122            (MediaType::Video, DocType::WebM) => match codec_id_str {
123                "V_VP8" => Some(CodecID::VP8),
124                "V_VP9" => Some(CodecID::VP9),
125                "V_AV1" => Some(CodecID::AV1),
126                _ => None,
127            },
128            #[cfg(feature = "video")]
129            (MediaType::Video, DocType::Matroska) => match codec_id_str {
130                "V_MPEG4/ISO/AVC" => Some(CodecID::H264),
131                "V_MPEGH/ISO/HEVC" => Some(CodecID::HEVC),
132                "V_MPEGI/ISO/VVC" => Some(CodecID::VVC),
133                "V_VP8" => Some(CodecID::VP8),
134                "V_VP9" => Some(CodecID::VP9),
135                "V_AV1" => Some(CodecID::AV1),
136                "V_MPEG1" => Some(CodecID::MPEG1),
137                "V_MPEG2" => Some(CodecID::MPEG2),
138                "V_MPEG4/ISO/SP" | "V_MPEG4/ISO/ASP" | "V_MPEG4/ISO/AP" => Some(CodecID::MPEG4),
139                _ => None,
140            },
141            #[cfg(feature = "audio")]
142            (MediaType::Audio, DocType::WebM) => match codec_id_str {
143                "A_VORBIS" => Some(CodecID::VORBIS),
144                "A_OPUS" => Some(CodecID::OPUS),
145                _ => None,
146            },
147            #[cfg(feature = "audio")]
148            (MediaType::Audio, DocType::Matroska) => match codec_id_str {
149                "A_AAC" | "A_AAC/MPEG2/MAIN" | "A_AAC/MPEG2/LC" | "A_AAC/MPEG2/LC/SBR" | "A_AAC/MPEG2/SSR" | "A_AAC/MPEG4/MAIN" |
150                "A_AAC/MPEG4/LC" | "A_AAC/MPEG4/LC/SBR" | "A_AAC/MPEG4/SSR" | "A_AAC/MPEG4/LTP" => Some(CodecID::AAC),
151                "A_OPUS" => Some(CodecID::OPUS),
152                "A_VORBIS" => Some(CodecID::VORBIS),
153                "A_FLAC" => Some(CodecID::FLAC),
154                "A_AC3" | "A_AC3/BSID9" | "A_AC3/BSID10" => Some(CodecID::AC3),
155                "A_EAC3" => Some(CodecID::EAC3),
156                "A_DTS" | "A_DTS/EXPRESS" | "A_DTS/LOSSLESS" => Some(CodecID::DTS),
157                "A_MPEG/L1" => Some(CodecID::MP1),
158                "A_MPEG/L2" => Some(CodecID::MP2),
159                "A_MPEG/L3" => Some(CodecID::MP3),
160                _ => None,
161            },
162            #[allow(unreachable_patterns)]
163            _ => None,
164        }
165    }
166
167    /// Convert Matroska track type to MediaType
168    fn track_type_to_media_type(track_type: u64) -> Option<MediaType> {
169        match track_type {
170            #[cfg(feature = "video")]
171            1 => Some(MediaType::Video),
172            #[cfg(feature = "audio")]
173            2 => Some(MediaType::Audio),
174            _ => None,
175        }
176    }
177
178    /// Create video parameters from track entry
179    #[cfg(feature = "video")]
180    fn make_video_params(track: &TrackEntry) -> Option<VideoParameters> {
181        let video = track.video.as_ref()?;
182
183        Some(VideoParameters {
184            width: NonZeroU32::new(video.pixel_width.0 as u32),
185            height: NonZeroU32::new(video.pixel_height.0 as u32),
186            ..Default::default()
187        })
188    }
189
190    /// Create audio parameters from track entry
191    #[cfg(feature = "audio")]
192    fn make_audio_params(track: &TrackEntry) -> Option<AudioParameters> {
193        let audio = track.audio.as_ref()?;
194
195        Some(AudioParameters {
196            sample_rate: NonZeroU32::new(audio.sampling_frequency.0 as u32),
197            channel_layout: ChannelLayout::default_from_channels(audio.channels.0 as u8).ok(),
198            ..Default::default()
199        })
200    }
201
202    /// Create decoder parameters from track entry (codec private data)
203    #[cfg(feature = "video")]
204    fn make_video_decoder_params(track: &TrackEntry) -> DecoderParameters {
205        let codec_id_str: &str = &track.codec_id.0;
206        let codec_private = track.codec_private.as_ref().map(|cp| cp.0.as_ref());
207
208        match codec_id_str {
209            "V_MPEG4/ISO/AVC" => {
210                if let Some(data) = codec_private {
211                    if let Some(extra) = Self::parse_avcc(data) {
212                        return DecoderParameters {
213                            extra_data: Some(extra),
214                            ..Default::default()
215                        };
216                    }
217                }
218                DecoderParameters::default()
219            }
220            "V_MPEGH/ISO/HEVC" => {
221                if let Some(data) = codec_private {
222                    if let Some(extra) = Self::parse_hvcc(data) {
223                        return DecoderParameters {
224                            extra_data: Some(extra),
225                            ..Default::default()
226                        };
227                    }
228                }
229                DecoderParameters::default()
230            }
231            _ => {
232                if let Some(data) = codec_private {
233                    DecoderParameters {
234                        extra_data: Some(ExtraData::Raw(data.to_vec())),
235                        ..Default::default()
236                    }
237                } else {
238                    DecoderParameters::default()
239                }
240            }
241        }
242    }
243
244    #[cfg(feature = "video")]
245    fn parse_avcc(data: &[u8]) -> Option<ExtraData> {
246        let avcc = Avcc::parse(data).ok()?;
247
248        Some(ExtraData::AVC {
249            sps: avcc.sequence_parameter_sets,
250            pps: avcc.picture_parameter_sets,
251            nalu_length_size: avcc.length_size,
252        })
253    }
254
255    #[cfg(feature = "video")]
256    fn parse_hvcc(data: &[u8]) -> Option<ExtraData> {
257        let hvcc = Hvcc::parse(data).ok()?;
258
259        let vps: Vec<Vec<u8>> = hvcc.vps().map(|v| v.to_vec()).collect();
260        let sps: Vec<Vec<u8>> = hvcc.sps().map(|v| v.to_vec()).collect();
261        let pps: Vec<Vec<u8>> = hvcc.pps().map(|v| v.to_vec()).collect();
262
263        Some(ExtraData::HEVC {
264            vps: if vps.is_empty() {
265                None
266            } else {
267                Some(vps)
268            },
269            sps,
270            pps,
271            nalu_length_size: hvcc.length_size,
272        })
273    }
274
275    /// Convert track entry to codec parameters
276    fn track_to_params(track: &TrackEntry, doc_type: DocType) -> Option<(CodecID, CodecParameters)> {
277        let codec_id_str: &str = &track.codec_id.0;
278        let media_type = Self::track_type_to_media_type(track.track_type.0)?;
279        let codec_id = Self::str_to_codec_id(codec_id_str, media_type, doc_type)?;
280
281        match media_type {
282            #[cfg(feature = "video")]
283            MediaType::Video => {
284                let video_params = Self::make_video_params(track)?;
285                let decoder_params = Self::make_video_decoder_params(track);
286                Some((codec_id, CodecParameters::new(video_params, decoder_params)))
287            }
288            #[cfg(feature = "audio")]
289            MediaType::Audio => {
290                let audio_params = Self::make_audio_params(track)?;
291                let decoder_params = if let Some(ref cp) = track.codec_private {
292                    DecoderParameters {
293                        extra_data: Some(ExtraData::Raw(cp.0.to_vec())),
294                        ..Default::default()
295                    }
296                } else {
297                    DecoderParameters::default()
298                };
299                Some((codec_id, CodecParameters::new(audio_params, decoder_params)))
300            }
301            #[allow(unreachable_patterns)]
302            _ => None,
303        }
304    }
305
306    /// Find Cues position from SeekHead entries
307    fn find_cues_position(seek_heads: &[SeekHead], segment_data_position: u64) -> Option<u64> {
308        for seek_head in seek_heads {
309            for seek in &seek_head.seek {
310                // Parse the SeekID to get the element ID
311                let mut id_bytes = &seek.seek_id[..];
312                if let Ok(element_id) = VInt64::read_from(&mut id_bytes) {
313                    // Check if this is a Cues element
314                    if element_id == Cues::ID {
315                        // SeekPosition is relative to segment data start
316                        return Some(*seek.seek_position + segment_data_position);
317                    }
318                }
319            }
320        }
321        None
322    }
323
324    /// Read an element body into a Vec<u8> for decoding
325    fn read_element_body(reader: &mut dyn Reader, size: u64) -> Result<Vec<u8>> {
326        let mut buf = vec![0u8; size as usize];
327        reader.read_exact(&mut buf)?;
328        Ok(buf)
329    }
330
331    /// Read the next cluster from the reader
332    /// Handles both known-size and unknown-size clusters
333    fn read_next_cluster(&mut self, reader: &mut dyn Reader) -> Result<bool> {
334        // Safety counter to prevent infinite loops
335        let mut skip_count = 0;
336
337        loop {
338            // Read cluster header
339            let header = match Header::read_from(reader) {
340                Ok(h) => h,
341                Err(_) => return Ok(false), // End of file
342            };
343
344            // Handle non-cluster elements between clusters
345            if header.id == Cluster::ID {
346                // Check if this cluster has unknown size
347                if header.size.is_unknown {
348                    // Unknown size cluster - need to parse element by element
349                    return self.read_unknown_size_cluster(reader);
350                }
351
352                // Known size cluster - use normal read_element
353                let cluster = match Cluster::read_element(&header, reader) {
354                    Ok(c) => c,
355                    Err(e) => {
356                        // Try to skip failed cluster if size is known
357                        if header.size.value > 0 {
358                            let _ = reader.seek(SeekFrom::Current(header.size.value as i64));
359                            skip_count += 1;
360                            if skip_count >= MAX_SKIPS {
361                                return Ok(false);
362                            }
363                            continue;
364                        }
365                        return Err(read_failed_error!("cluster", e));
366                    }
367                };
368
369                self.current_frame_index = 0;
370                self.current_lace_index = 0;
371                self.current_cluster = Some(cluster);
372                return Ok(true);
373            }
374
375            // Skip non-cluster element
376            let size = header.size.value;
377
378            // Check for unknown size - can't skip these reliably
379            if header.size.is_unknown {
380                // Can't skip unknown size, assume end of readable data
381                return Ok(false);
382            }
383
384            // Sanity check: don't try to skip unreasonably large elements
385            // This prevents hangs from corrupted size values
386            if size > MAX_ELEMENT_SIZE {
387                return Ok(false);
388            }
389
390            // Skip the element
391            if size > 0 && reader.seek(SeekFrom::Current(size as i64)).is_err() {
392                return Ok(false);
393            }
394
395            skip_count += 1;
396            if skip_count >= MAX_SKIPS {
397                return Ok(false);
398            }
399        }
400    }
401
402    /// Read unknown-size cluster
403    fn read_unknown_size_cluster(&mut self, reader: &mut dyn Reader) -> Result<bool> {
404        let mut timestamp: Option<u64> = None;
405        let mut blocks: Vec<ClusterBlock> = Vec::new();
406
407        // Safety counter
408        let mut element_count = 0;
409        let mut unknown_element_count = 0;
410
411        while let (Ok(pos_before_header), Ok(sub_header)) = (reader.stream_position(), Header::read_from(reader)) {
412            // Check for next top-level element
413            if sub_header.id == Cluster::ID {
414                // Hit next cluster - seek back to cluster start and exit
415                let _ = reader.seek(SeekFrom::Start(pos_before_header));
416                break;
417            }
418
419            // Check for other top-level elements that indicate end of cluster
420            let id_val = sub_header.id.value;
421            if id_val == Cues::ID.value || id_val == Tags::ID.value || id_val == Info::ID.value || id_val == Tracks::ID.value {
422                // Hit another top-level element - seek back and exit
423                let _ = reader.seek(SeekFrom::Start(pos_before_header));
424                break;
425            }
426
427            // Parse cluster sub-elements
428            let element_size = sub_header.size.value;
429
430            // Safety check for element size
431            if sub_header.size.is_unknown || element_size > MAX_ELEMENT_SIZE {
432                // Can't process unknown size or very large elements
433                break;
434            }
435
436            match sub_header.id {
437                Timestamp::ID => {
438                    // Read timestamp
439                    if let Ok(body) = Self::read_element_body(reader, element_size) {
440                        if let Ok(ts) = Timestamp::decode_body(&mut body.as_slice()) {
441                            timestamp = Some(ts.0);
442                        }
443                    } else {
444                        // Read failed, skip by seeking
445                        if reader.seek(SeekFrom::Current(element_size as i64)).is_err() {
446                            break;
447                        }
448                    }
449                    unknown_element_count = 0;
450                }
451                SimpleBlock::ID => {
452                    // Read SimpleBlock
453                    if let Ok(body) = Self::read_element_body(reader, element_size) {
454                        if let Ok(block) = SimpleBlock::decode_body(&mut body.as_slice()) {
455                            blocks.push(ClusterBlock::Simple(block));
456                        }
457                    } else {
458                        // Read failed, skip by seeking
459                        if reader.seek(SeekFrom::Current(element_size as i64)).is_err() {
460                            break;
461                        }
462                    }
463                    unknown_element_count = 0;
464                }
465                BlockGroup::ID => {
466                    // Read BlockGroup
467                    if let Ok(body) = Self::read_element_body(reader, element_size) {
468                        if let Ok(block) = BlockGroup::decode_body(&mut body.as_slice()) {
469                            blocks.push(ClusterBlock::Group(block));
470                        }
471                    } else {
472                        // Read failed, skip by seeking
473                        if reader.seek(SeekFrom::Current(element_size as i64)).is_err() {
474                            break;
475                        }
476                    }
477                    unknown_element_count = 0;
478                }
479                Position::ID | PrevSize::ID | Void::ID => {
480                    // Skip known but unimportant elements
481                    if element_size > 0 && reader.seek(SeekFrom::Current(element_size as i64)).is_err() {
482                        break;
483                    }
484                    unknown_element_count = 0;
485                }
486                _ => {
487                    unknown_element_count += 1;
488                    if unknown_element_count >= MAX_UNKNOWN_ELEMENTS {
489                        // Too many unknown elements, may be corrupted
490                        break;
491                    }
492
493                    // Skip unknown element
494                    if element_size > 0 && reader.seek(SeekFrom::Current(element_size as i64)).is_err() {
495                        break;
496                    }
497                }
498            }
499
500            element_count += 1;
501            if element_count >= MAX_ELEMENTS {
502                break;
503            }
504        }
505
506        // Must have at least timestamp to form a valid cluster
507        let timestamp = match timestamp {
508            Some(ts) => ts,
509            None => return Ok(false),
510        };
511
512        // Build the cluster
513        let cluster = Cluster {
514            crc32: None,
515            void: None,
516            timestamp: Timestamp(timestamp),
517            position: None,
518            prev_size: None,
519            blocks,
520        };
521
522        self.current_frame_index = 0;
523        self.current_lace_index = 0;
524        self.current_cluster = Some(cluster);
525        Ok(true)
526    }
527}
528
529impl Format for MkvDemuxer {
530    fn set_option(&mut self, _key: &str, _value: &Variant) -> Result<()> {
531        Ok(())
532    }
533}
534
535impl Demuxer for MkvDemuxer {
536    fn read_header(&mut self, reader: &mut dyn Reader, state: &mut DemuxerState) -> Result<()> {
537        // Read EBML header and extract doc_type
538        let ebml = Ebml::read_from(reader).map_err(|e| read_failed_error!("EBML header", e))?;
539
540        // Cache the document type (mkv or webm)
541        self.doc_type = ebml.doc_type.as_ref().map(|dt| DocType::from_doc_type(&dt.0)).unwrap_or(DocType::Matroska);
542
543        // Safety limit for header reading
544        let mut element_count = 0;
545
546        // Read top-level elements until Segment is found
547        loop {
548            let header = match Header::read_from(reader) {
549                Ok(h) => h,
550                Err(e) => {
551                    return Err(read_failed_error!("header", e));
552                }
553            };
554
555            if header.id == Segment::ID {
556                // Record segment data start position
557                self.segment_data_position = reader.stream_position()?;
558
559                // Read sub-elements until Cluster is found
560                let mut info: Option<Info> = None;
561                let mut tracks: Option<Tracks> = None;
562                let mut cues: Option<Cues> = None;
563                let mut seek_head: Vec<SeekHead> = Vec::new();
564
565                loop {
566                    let current_pos = reader.stream_position()?;
567
568                    let sub_header = match Header::read_from(reader) {
569                        Ok(h) => h,
570                        Err(_) => break, // End of file or segment
571                    };
572
573                    if sub_header.id == Info::ID {
574                        info = Some(Info::read_element(&sub_header, reader).map_err(|e| read_failed_error!(e.to_string()))?);
575                    } else if sub_header.id == Tracks::ID {
576                        tracks = Some(Tracks::read_element(&sub_header, reader).map_err(|e| read_failed_error!(e.to_string()))?);
577                    } else if sub_header.id == Cues::ID {
578                        cues = Some(Cues::read_element(&sub_header, reader).map_err(|e| read_failed_error!(e.to_string()))?);
579                    } else if sub_header.id == SeekHead::ID {
580                        // Parse SeekHead to get positions of other elements
581                        if let Ok(sh) = SeekHead::read_element(&sub_header, reader) {
582                            seek_head.push(sh);
583                        }
584                    } else if sub_header.id == Cluster::ID {
585                        // Try to locate Cues using SeekHead
586                        if cues.is_none() {
587                            if let Some(cues_pos) = Self::find_cues_position(&seek_head, self.segment_data_position) {
588                                if cues_pos > current_pos && reader.seek(SeekFrom::Start(cues_pos)).is_ok() {
589                                    if let Ok(cues_header) = Header::read_from(reader) {
590                                        if cues_header.id == Cues::ID {
591                                            if let Ok(c) = Cues::read_element(&cues_header, reader) {
592                                                cues = Some(c);
593                                            }
594                                        }
595                                    }
596                                }
597                            }
598                        }
599
600                        // Seek back to cluster start for later reading
601                        reader.seek(SeekFrom::Start(current_pos))?;
602                        break;
603                    } else {
604                        // Skip unknown elements
605                        let size = sub_header.size.value;
606
607                        // Check for unknown size or unreasonably large size
608                        if sub_header.size.is_unknown || size > MAX_ELEMENT_SIZE {
609                            // Can't reliably skip - stop header reading
610                            break;
611                        }
612
613                        if size > 0 && reader.seek(SeekFrom::Current(size as i64)).is_err() {
614                            break;
615                        }
616                    }
617
618                    element_count += 1;
619                    if element_count >= MAX_HEADER_ELEMENTS {
620                        break;
621                    }
622                }
623
624                // Validate required elements
625                let info = info.ok_or_else(|| not_found_error!("info element"))?;
626
627                let timestamp_scale = info.timestamp_scale.0;
628
629                // Store segment info
630                self.time_base = Rational64::new(info.timestamp_scale.0 as i64, NSEC_PER_SEC);
631                self.duration = info.duration.as_ref().map(|d| d.0);
632                self.cues = cues;
633
634                // Calculate duration in microseconds
635                if let Some(duration) = self.duration {
636                    state.duration = Some((duration * timestamp_scale as f64 / MSEC_PER_SEC as f64) as i64);
637                }
638
639                // Process tracks
640                if let Some(ref tracks) = tracks {
641                    let mut stream = Stream::new(0);
642
643                    for track_entry in &tracks.track_entry {
644                        let track_number = track_entry.track_number.0 as isize;
645
646                        if let Some((codec_id, params)) = Self::track_to_params(track_entry, self.doc_type) {
647                            let mut track = Track::new(track_number, codec_id, params, self.time_base);
648
649                            // Set track duration if available
650                            if let Some(duration) = self.duration {
651                                track.duration = Some(duration as i64);
652                            }
653
654                            stream.add_track(state.tracks.add_track(track));
655                        }
656                    }
657
658                    state.streams.add_stream(stream);
659                }
660
661                return Ok(());
662            } else {
663                // Skip unknown top-level elements
664                let size = header.size.value;
665                if size > 0 {
666                    reader.seek(SeekFrom::Current(size as i64))?;
667                }
668            }
669        }
670    }
671
672    fn read_packet(&mut self, reader: &mut dyn Reader, state: &DemuxerState) -> Result<Packet<'static>> {
673        loop {
674            if let Some(ref cluster) = self.current_cluster {
675                // Collect frames from current cluster starting at current_frame_index
676                let frames: Vec<_> = cluster.frames().filter_map(|r| r.ok()).collect();
677
678                while self.current_frame_index < frames.len() {
679                    let frame = &frames[self.current_frame_index];
680
681                    let track = state.tracks.find_track(frame.track_number as isize).ok_or_else(|| not_found_error!("track", frame.track_number))?;
682
683                    // Extract frame data from FrameData enum, handling laced frames
684                    let (frame_bytes, advance_frame): (&[u8], bool) = match &frame.data {
685                        FrameData::Single(data) => (data, true),
686                        FrameData::Multiple(data_vec) => {
687                            // For laced frames, return each sub-frame as a separate packet
688                            if self.current_lace_index < data_vec.len() {
689                                let data = data_vec[self.current_lace_index];
690                                self.current_lace_index += 1;
691                                // Only advance to next frame when all laced sub-frames are consumed
692                                let advance = self.current_lace_index >= data_vec.len();
693                                if advance {
694                                    self.current_lace_index = 0;
695                                }
696                                (data, advance)
697                            } else {
698                                // Should not happen, but reset and move to next frame
699                                self.current_lace_index = 0;
700                                self.current_frame_index += 1;
701                                continue;
702                            }
703                        }
704                    };
705
706                    if advance_frame {
707                        self.current_frame_index += 1;
708                    }
709
710                    let mut packet = Packet::from_buffer(track.pool.get_buffer_with_length(frame_bytes.len()));
711                    if let Some(buffer) = packet.data_mut() {
712                        buffer.copy_from_slice(frame_bytes);
713                    }
714
715                    packet.track_index = Some(track.index());
716                    packet.dts = Some(frame.timestamp);
717                    packet.pts = Some(frame.timestamp);
718                    packet.time_base = Some(self.time_base);
719
720                    if let Some(duration) = frame.duration {
721                        packet.duration = Some(duration.get() as i64);
722                    }
723
724                    // Set packet flags - only the first laced frame should be marked as keyframe
725                    packet.flags = if frame.is_keyframe &&
726                        (matches!(&frame.data, FrameData::Single(_)) ||
727                            self.current_lace_index == 1 ||
728                            (self.current_lace_index == 0 && advance_frame))
729                    {
730                        PacketFlags::Key
731                    } else {
732                        PacketFlags::empty()
733                    };
734
735                    return Ok(packet);
736                }
737            }
738
739            // Need to read the next cluster
740            if !self.read_next_cluster(reader)? {
741                return Err(not_found_error!("more packets"));
742            }
743        }
744    }
745
746    fn seek(&mut self, reader: &mut dyn Reader, state: &DemuxerState, track_index: Option<usize>, timestamp_us: i64, flags: SeekFlags) -> Result<()> {
747        // Convert timestamp to segment ticks
748        let target_ticks = Rational64::from_integer(timestamp_us) / (Rational64::from_integer(USEC_PER_SEC) * self.time_base);
749        let target_ticks = target_ticks.to_integer();
750
751        // Save current position for rollback
752        let saved_position = reader.stream_position()?;
753
754        // Determine target track
755        let target_track_index =
756            track_index.unwrap_or_else(|| state.tracks.into_iter().find(|t| t.media_type() == MediaType::Video).map(|t| t.index()).unwrap_or(0));
757
758        let target_track = state.tracks.get_track(target_track_index).ok_or_else(|| not_found_error!("track", target_track_index))?;
759
760        let target_track_number = target_track.id as u64;
761
762        // Use Cues for seeking
763        let Some(ref cues) = self.cues else {
764            return Err(unsupported_error!("seek without Cues index"));
765        };
766
767        // Save current cluster state for potential rollback on failure
768        let saved_cluster = self.current_cluster.take();
769        let saved_frame_index = self.current_frame_index;
770
771        // Collect all cue points for our target track
772        let mut track_cues: Vec<(i64, u64)> = Vec::new(); // (cue_time, cluster_pos)
773
774        for cue_point in &cues.cue_point {
775            let cue_time = cue_point.cue_time.0 as i64;
776            for track_pos in &cue_point.cue_track_positions {
777                if track_pos.cue_track.0 == target_track_number {
778                    let absolute_pos = self.segment_data_position + track_pos.cue_cluster_position.0;
779                    track_cues.push((cue_time, absolute_pos));
780                    break;
781                }
782            }
783        }
784
785        // Find the best cue point based on seek flags
786        let best_cue = if flags.contains(SeekFlags::BACKWARD) {
787            // BACKWARD: Find the largest cue_time <= target
788            track_cues.iter().filter(|(t, _)| *t <= target_ticks).max_by_key(|(t, _)| *t).copied()
789        } else {
790            // Default: Find the nearest cue point (before or after)
791            track_cues.iter().min_by_key(|(t, _)| (t - target_ticks).abs()).copied()
792        };
793
794        if let Some((_best_cue_time, cluster_pos)) = best_cue {
795            // Scan forward to find the frame closest to target
796            reader.seek(SeekFrom::Start(cluster_pos))?;
797
798            // Track the best frame before and after target
799            let mut best_before: Option<(u64, i64)> = None;
800            let mut best_after: Option<(u64, i64)> = None;
801
802            // For ANY mode, search any frame, otherwise search only key frame
803            let search_any = flags.contains(SeekFlags::ANY);
804
805            loop {
806                let cluster_start_pos = reader.stream_position()?;
807
808                // Try to read next cluster
809                if !self.read_next_cluster(reader)? {
810                    break; // End of file
811                }
812
813                let cluster_timestamp = self.current_cluster.as_ref().map(|c| c.timestamp.0 as i64).unwrap_or(0);
814
815                // Find frames for our target track in this cluster
816                if let Some(ref cluster) = self.current_cluster {
817                    for frame in cluster.frames().flatten() {
818                        // For ANY mode, consider all frames; otherwise only keyframes
819                        if frame.track_number == target_track_number && (search_any || frame.is_keyframe) {
820                            if frame.timestamp <= target_ticks {
821                                // Frame before or at target - update best_before
822                                if best_before.is_none() || frame.timestamp > best_before.as_ref().unwrap().1 {
823                                    best_before = Some((cluster_start_pos, frame.timestamp));
824                                }
825                            } else if best_after.is_none() {
826                                // First frame after target
827                                best_after = Some((cluster_start_pos, frame.timestamp));
828                            }
829                        }
830                    }
831                }
832
833                // Decide when to stop scanning
834                if flags.contains(SeekFlags::BACKWARD) {
835                    // For BACKWARD: stop when cluster timestamp is past target
836                    if cluster_timestamp > target_ticks {
837                        break;
838                    }
839                } else {
840                    // For default: stop after finding the first frame after target
841                    if best_after.is_some() {
842                        break;
843                    }
844                    // Safety: don't search too far past target
845                    let lookahead_ticks = (Rational64::from_integer(SEEK_LOOKAHEAD_SEC) / self.time_base).to_integer();
846                    if cluster_timestamp > target_ticks + lookahead_ticks {
847                        break;
848                    }
849                }
850            }
851
852            // Choose the best frame based on flags
853            let chosen = if flags.contains(SeekFlags::BACKWARD) {
854                best_before
855            } else {
856                // Find the nearest frame
857                match (best_before, best_after) {
858                    (Some(before), Some(after)) => {
859                        let dist_before = target_ticks - before.1;
860                        let dist_after = after.1 - target_ticks;
861                        if dist_before <= dist_after {
862                            Some(before)
863                        } else {
864                            Some(after)
865                        }
866                    }
867                    (Some(before), None) => Some(before),
868                    (None, Some(after)) => Some(after),
869                    (None, None) => None,
870                }
871            };
872
873            // Seek to the chosen cluster and set frame index
874            if let Some((chosen_cluster_pos, _chosen_ts)) = chosen {
875                reader.seek(SeekFrom::Start(chosen_cluster_pos))?;
876                self.current_frame_index = 0;
877                self.current_lace_index = 0;
878                self.current_cluster = None;
879
880                // Re-read the chosen cluster
881                self.read_next_cluster(reader)?;
882
883                return Ok(());
884            }
885        }
886
887        // Seek failed - restore previous state so read_packet can continue
888        let _ = reader.seek(SeekFrom::Start(saved_position));
889        self.current_cluster = saved_cluster;
890        self.current_frame_index = saved_frame_index;
891
892        Err(not_found_error!("seek position"))
893    }
894}
895
896/// Builder for Matroska/WebM demuxer
897pub struct MkvDemuxerBuilder;
898
899impl media_format_types::FormatBuilder for MkvDemuxerBuilder {
900    fn name(&self) -> &'static str {
901        "matroska"
902    }
903
904    fn extensions(&self) -> &[&'static str] {
905        &["mkv", "webm"]
906    }
907}
908
909impl DemuxerBuilder for MkvDemuxerBuilder {
910    fn new_demuxer(&self) -> Result<Box<dyn Demuxer>> {
911        Ok(Box::new(MkvDemuxer::new()))
912    }
913
914    fn probe(&self, reader: &mut dyn Reader) -> bool {
915        let mut buf = [0; 4];
916        reader.read_exact(&mut buf).ok();
917        buf == [0x1A, 0x45, 0xDF, 0xA3]
918    }
919}