Skip to main content

media_format_matroska/
demuxer.rs

1//! Matroska/WebM demuxer
2
3use std::{io::SeekFrom, num::NonZeroU32};
4
5#[cfg(feature = "video")]
6use media_codec_h264::avcc::Avcc;
7#[cfg(feature = "video")]
8use media_codec_h265::hvcc::Hvcc;
9#[cfg(feature = "video")]
10use media_codec_types::decoder::ExtraData;
11#[cfg(feature = "audio")]
12use media_codec_types::AudioParameters;
13#[cfg(feature = "video")]
14use media_codec_types::VideoParameters;
15use media_codec_types::{
16    decoder::DecoderParameters,
17    packet::{Packet, PacketFlags},
18    CodecID, CodecParameters,
19};
20#[cfg(feature = "audio")]
21use media_core::audio::ChannelLayout;
22use media_core::{
23    not_found_error,
24    rational::Rational64,
25    read_failed_error,
26    time::{MSEC_PER_SEC, NSEC_PER_SEC, USEC_PER_SEC},
27    unsupported_error,
28    variant::Variant,
29    MediaType, Result,
30};
31use media_format_types::{
32    demuxer::{Demuxer, DemuxerBuilder, DemuxerState, Reader, SeekFlags},
33    stream::Stream,
34    track::Track,
35    Format,
36};
37use mkv_element::{
38    io::blocking_impl::{ReadElement, ReadFrom},
39    prelude::{
40        BlockGroup, Cluster, Cues, Ebml, Element, Header, Info, Position, PrevSize, SeekHead, Segment, SimpleBlock, Tags, Timestamp, TimestampScale,
41        TrackEntry, Tracks, VInt64, Void,
42    },
43    ClusterBlock, FrameData,
44};
45
/// Upper bound on the number of sub-elements processed while parsing a single
/// unknown-size cluster.
const MAX_ELEMENTS: u32 = 8192;
/// Number of consecutive unrecognized sub-elements tolerated inside an
/// unknown-size cluster before parsing of that cluster stops.
const MAX_UNKNOWN_ELEMENTS: u32 = 8;
/// Largest element payload (256 MiB) the demuxer is willing to read or skip;
/// anything bigger is treated as a corrupted size field.
const MAX_ELEMENT_SIZE: u64 = 256 * 1024 * 1024;
/// Upper bound on the number of Segment-level elements examined by
/// `read_header` before it stops scanning.
const MAX_HEADER_ELEMENTS: u32 = 256;
/// Maximum number of elements (non-cluster elements or unreadable clusters)
/// skipped by one `read_next_cluster` call before it gives up.
const MAX_SKIPS: u32 = 8;
/// NOTE(review): not referenced in this part of the file; presumably a
/// look-ahead window in seconds used by seeking — confirm against `seek`.
const SEEK_LOOKAHEAD_SEC: i64 = 10;
52
/// Document type for Matroska container formats
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum DocType {
    /// Generic Matroska; also the fallback for any unrecognized DocType string
    #[default]
    Matroska,
    /// WebM
    WebM,
}

impl DocType {
    /// Maps the EBML `DocType` string to a [`DocType`].
    ///
    /// The comparison is ASCII case-insensitive and allocation-free. Only
    /// `"webm"` selects [`DocType::WebM`]; every other value (including the
    /// empty string) is treated as [`DocType::Matroska`].
    fn from_doc_type(doc_type: &str) -> Self {
        if doc_type.eq_ignore_ascii_case("webm") {
            DocType::WebM
        } else {
            DocType::Matroska
        }
    }
}
69
/// Matroska/WebM demuxer implementation
///
/// Holds the segment-level information gathered by `read_header` together
/// with the cursor (cluster / frame / lace) used while emitting packets.
pub struct MkvDemuxer {
    /// Document type (.mkv or .webm), taken from the EBML header
    doc_type: DocType,
    /// Time base in seconds (timestamp scale in nanoseconds over
    /// `NSEC_PER_SEC`)
    time_base: Rational64,
    /// Segment duration in timestamp units, if the Info element provides one
    duration: Option<f64>,
    /// Cues (index table) for fast seeking
    cues: Option<Cues>,
    /// Segment data start position (for calculating absolute positions from cue
    /// positions)
    segment_data_position: u64,
    /// Current cluster being read
    current_cluster: Option<Cluster>,
    /// Current frame index within the current cluster
    current_frame_index: usize,
    /// Current lace index within the current frame (for laced frames with
    /// multiple sub-frames)
    current_lace_index: usize,
}
91
92impl Default for MkvDemuxer {
93    fn default() -> Self {
94        Self::new()
95    }
96}
97
98impl MkvDemuxer {
99    /// Creates a new Matroska demuxer
100    pub fn new() -> Self {
101        Self {
102            doc_type: DocType::default(),
103            time_base: Rational64::new(TimestampScale::default().0 as i64, NSEC_PER_SEC),
104            duration: None,
105            cues: None,
106            segment_data_position: 0,
107            current_cluster: None,
108            current_frame_index: 0,
109            current_lace_index: 0,
110        }
111    }
112
113    /// Codec ID mapping based on:
114    /// - Matroska Specs: https://www.matroska.org/technical/codec_specs.html
115    /// - IETF Draft: https://datatracker.ietf.org/doc/draft-ietf-cellar-codec/
116    ///
117    /// Note: `A_EAC3` and `V_MPEGI/ISO/VVC` are widely used in practice but
118    /// not yet in the official specification.
119    fn str_to_codec_id(codec_id_str: &str, media_type: MediaType, doc_type: DocType) -> Option<CodecID> {
120        match (media_type, doc_type) {
121            #[cfg(feature = "video")]
122            (MediaType::Video, DocType::WebM) => match codec_id_str {
123                "V_VP8" => Some(CodecID::VP8),
124                "V_VP9" => Some(CodecID::VP9),
125                "V_AV1" => Some(CodecID::AV1),
126                _ => None,
127            },
128            #[cfg(feature = "video")]
129            (MediaType::Video, DocType::Matroska) => match codec_id_str {
130                "V_MPEG4/ISO/AVC" => Some(CodecID::H264),
131                "V_MPEGH/ISO/HEVC" => Some(CodecID::HEVC),
132                "V_MPEGI/ISO/VVC" => Some(CodecID::VVC),
133                "V_VP8" => Some(CodecID::VP8),
134                "V_VP9" => Some(CodecID::VP9),
135                "V_AV1" => Some(CodecID::AV1),
136                "V_MPEG1" => Some(CodecID::MPEG1),
137                "V_MPEG2" => Some(CodecID::MPEG2),
138                "V_MPEG4/ISO/SP" | "V_MPEG4/ISO/ASP" | "V_MPEG4/ISO/AP" => Some(CodecID::MPEG4),
139                _ => None,
140            },
141            #[cfg(feature = "audio")]
142            (MediaType::Audio, DocType::WebM) => match codec_id_str {
143                "A_VORBIS" => Some(CodecID::VORBIS),
144                "A_OPUS" => Some(CodecID::OPUS),
145                _ => None,
146            },
147            #[cfg(feature = "audio")]
148            (MediaType::Audio, DocType::Matroska) => match codec_id_str {
149                "A_AAC" | "A_AAC/MPEG2/MAIN" | "A_AAC/MPEG2/LC" | "A_AAC/MPEG2/LC/SBR" | "A_AAC/MPEG2/SSR" | "A_AAC/MPEG4/MAIN" |
150                "A_AAC/MPEG4/LC" | "A_AAC/MPEG4/LC/SBR" | "A_AAC/MPEG4/SSR" | "A_AAC/MPEG4/LTP" => Some(CodecID::AAC),
151                "A_OPUS" => Some(CodecID::OPUS),
152                "A_VORBIS" => Some(CodecID::VORBIS),
153                "A_FLAC" => Some(CodecID::FLAC),
154                "A_AC3" | "A_AC3/BSID9" | "A_AC3/BSID10" => Some(CodecID::AC3),
155                "A_EAC3" => Some(CodecID::EAC3),
156                "A_DTS" | "A_DTS/EXPRESS" | "A_DTS/LOSSLESS" => Some(CodecID::DTS),
157                "A_MPEG/L1" => Some(CodecID::MP1),
158                "A_MPEG/L2" => Some(CodecID::MP2),
159                "A_MPEG/L3" => Some(CodecID::MP3),
160                _ => None,
161            },
162            #[allow(unreachable_patterns)]
163            _ => None,
164        }
165    }
166
167    /// Convert Matroska track type to MediaType
168    fn track_type_to_media_type(track_type: u64) -> Option<MediaType> {
169        match track_type {
170            #[cfg(feature = "video")]
171            1 => Some(MediaType::Video),
172            #[cfg(feature = "audio")]
173            2 => Some(MediaType::Audio),
174            _ => None,
175        }
176    }
177
178    /// Create video parameters from track entry
179    #[cfg(feature = "video")]
180    fn make_video_params(track: &TrackEntry) -> Option<VideoParameters> {
181        let video = track.video.as_ref()?;
182
183        Some(VideoParameters {
184            width: NonZeroU32::new(video.pixel_width.0 as u32),
185            height: NonZeroU32::new(video.pixel_height.0 as u32),
186            ..Default::default()
187        })
188    }
189
190    /// Create audio parameters from track entry
191    #[cfg(feature = "audio")]
192    fn make_audio_params(track: &TrackEntry) -> Option<AudioParameters> {
193        let audio = track.audio.as_ref()?;
194
195        Some(AudioParameters {
196            sample_rate: NonZeroU32::new(audio.sampling_frequency.0 as u32),
197            channel_layout: ChannelLayout::default_from_channels(audio.channels.0 as u8).ok(),
198            ..Default::default()
199        })
200    }
201
202    /// Create decoder parameters from track entry (codec private data)
203    #[cfg(feature = "video")]
204    fn make_video_decoder_params(track: &TrackEntry) -> DecoderParameters {
205        let codec_id_str: &str = &track.codec_id.0;
206        let codec_private = track.codec_private.as_ref().map(|cp| cp.0.as_ref());
207
208        match codec_id_str {
209            "V_MPEG4/ISO/AVC" => {
210                if let Some(data) = codec_private {
211                    if let Some(extra) = Self::parse_avcc(data) {
212                        return DecoderParameters {
213                            extra_data: Some(extra),
214                            ..Default::default()
215                        };
216                    }
217                }
218                DecoderParameters::default()
219            }
220            "V_MPEGH/ISO/HEVC" => {
221                if let Some(data) = codec_private {
222                    if let Some(extra) = Self::parse_hvcc(data) {
223                        return DecoderParameters {
224                            extra_data: Some(extra),
225                            ..Default::default()
226                        };
227                    }
228                }
229                DecoderParameters::default()
230            }
231            _ => {
232                if let Some(data) = codec_private {
233                    DecoderParameters {
234                        extra_data: Some(ExtraData::Raw(data.to_vec())),
235                        ..Default::default()
236                    }
237                } else {
238                    DecoderParameters::default()
239                }
240            }
241        }
242    }
243
244    #[cfg(feature = "video")]
245    fn parse_avcc(data: &[u8]) -> Option<ExtraData> {
246        let avcc = Avcc::parse(data).ok()?;
247
248        Some(ExtraData::AVC {
249            sps: avcc.sequence_parameter_sets,
250            pps: avcc.picture_parameter_sets,
251            nalu_length_size: avcc.length_size,
252        })
253    }
254
255    #[cfg(feature = "video")]
256    fn parse_hvcc(data: &[u8]) -> Option<ExtraData> {
257        let hvcc = Hvcc::parse(data).ok()?;
258
259        let vps: Vec<Vec<u8>> = hvcc.vps().map(|v| v.to_vec()).collect();
260        let sps: Vec<Vec<u8>> = hvcc.sps().map(|v| v.to_vec()).collect();
261        let pps: Vec<Vec<u8>> = hvcc.pps().map(|v| v.to_vec()).collect();
262
263        Some(ExtraData::HEVC {
264            vps: if vps.is_empty() {
265                None
266            } else {
267                Some(vps)
268            },
269            sps,
270            pps,
271            nalu_length_size: hvcc.length_size,
272        })
273    }
274
275    /// Convert track entry to codec parameters
276    fn track_to_params(track: &TrackEntry, doc_type: DocType) -> Option<(CodecID, CodecParameters)> {
277        let codec_id_str: &str = &track.codec_id.0;
278        let media_type = Self::track_type_to_media_type(track.track_type.0)?;
279        let codec_id = Self::str_to_codec_id(codec_id_str, media_type, doc_type)?;
280
281        match media_type {
282            #[cfg(feature = "video")]
283            MediaType::Video => {
284                let video_params = Self::make_video_params(track)?;
285                let decoder_params = Self::make_video_decoder_params(track);
286                Some((codec_id, CodecParameters::new(video_params, decoder_params)))
287            }
288            #[cfg(feature = "audio")]
289            MediaType::Audio => {
290                let audio_params = Self::make_audio_params(track)?;
291                let decoder_params = if let Some(ref cp) = track.codec_private {
292                    DecoderParameters {
293                        extra_data: Some(ExtraData::Raw(cp.0.to_vec())),
294                        ..Default::default()
295                    }
296                } else {
297                    DecoderParameters::default()
298                };
299                Some((codec_id, CodecParameters::new(audio_params, decoder_params)))
300            }
301            #[allow(unreachable_patterns)]
302            _ => None,
303        }
304    }
305
306    /// Find Cues position from SeekHead entries
307    fn find_cues_position(seek_heads: &[SeekHead], segment_data_position: u64) -> Option<u64> {
308        for seek_head in seek_heads {
309            for seek in &seek_head.seek {
310                // Parse the SeekID to get the element ID
311                let mut id_bytes = &seek.seek_id[..];
312                if let Ok(element_id) = VInt64::read_from(&mut id_bytes) {
313                    // Check if this is a Cues element
314                    if element_id == Cues::ID {
315                        // SeekPosition is relative to segment data start
316                        return Some(*seek.seek_position + segment_data_position);
317                    }
318                }
319            }
320        }
321        None
322    }
323
324    /// Read an element body into a Vec<u8> for decoding
325    fn read_element_body(reader: &mut dyn Reader, size: u64) -> Result<Vec<u8>> {
326        let mut buf = vec![0u8; size as usize];
327        reader.read_exact(&mut buf)?;
328        Ok(buf)
329    }
330
331    /// Read the next cluster from the reader
332    /// Handles both known-size and unknown-size clusters
333    fn read_next_cluster(&mut self, reader: &mut dyn Reader) -> Result<bool> {
334        // Safety counter to prevent infinite loops
335        let mut skip_count = 0;
336
337        loop {
338            // Read cluster header
339            let header = match Header::read_from(reader) {
340                Ok(h) => h,
341                Err(_) => return Ok(false), // End of file
342            };
343
344            // Handle non-cluster elements between clusters
345            if header.id == Cluster::ID {
346                // Check if this cluster has unknown size
347                if header.size.is_unknown {
348                    // Unknown size cluster - need to parse element by element
349                    return self.read_unknown_size_cluster(reader);
350                }
351
352                // Known size cluster - use normal read_element
353                let cluster = match Cluster::read_element(&header, reader) {
354                    Ok(c) => c,
355                    Err(e) => {
356                        // Try to skip failed cluster if size is known
357                        if header.size.value > 0 {
358                            let _ = reader.seek(SeekFrom::Current(header.size.value as i64));
359                            skip_count += 1;
360                            if skip_count >= MAX_SKIPS {
361                                return Ok(false);
362                            }
363                            continue;
364                        }
365                        return Err(read_failed_error!("cluster", e));
366                    }
367                };
368
369                self.current_frame_index = 0;
370                self.current_lace_index = 0;
371                self.current_cluster = Some(cluster);
372                return Ok(true);
373            }
374
375            // Skip non-cluster element
376            let size = header.size.value;
377
378            // Check for unknown size - can't skip these reliably
379            if header.size.is_unknown {
380                // Can't skip unknown size, assume end of readable data
381                return Ok(false);
382            }
383
384            // Sanity check: don't try to skip unreasonably large elements
385            // This prevents hangs from corrupted size values
386            if size > MAX_ELEMENT_SIZE {
387                return Ok(false);
388            }
389
390            // Skip the element
391            if size > 0 {
392                if reader.seek(SeekFrom::Current(size as i64)).is_err() {
393                    return Ok(false);
394                }
395            }
396
397            skip_count += 1;
398            if skip_count >= MAX_SKIPS {
399                return Ok(false);
400            }
401        }
402    }
403
404    /// Read unknown-size cluster
405    fn read_unknown_size_cluster(&mut self, reader: &mut dyn Reader) -> Result<bool> {
406        let mut timestamp: Option<u64> = None;
407        let mut blocks: Vec<ClusterBlock> = Vec::new();
408
409        // Safety counter
410        let mut element_count = 0;
411        let mut unknown_element_count = 0;
412
413        loop {
414            // Save position for potential seek back
415            let pos_before_header = match reader.stream_position() {
416                Ok(p) => p,
417                Err(_) => break,
418            };
419
420            let sub_header = match Header::read_from(reader) {
421                Ok(h) => h,
422                Err(_) => break, // End of file or read error
423            };
424
425            // Check for next top-level element
426            if sub_header.id == Cluster::ID {
427                // Hit next cluster - seek back to cluster start and exit
428                let _ = reader.seek(SeekFrom::Start(pos_before_header));
429                break;
430            }
431
432            // Check for other top-level elements that indicate end of cluster
433            let id_val = sub_header.id.value;
434            if id_val == Cues::ID.value || id_val == Tags::ID.value || id_val == Info::ID.value || id_val == Tracks::ID.value {
435                // Hit another top-level element - seek back and exit
436                let _ = reader.seek(SeekFrom::Start(pos_before_header));
437                break;
438            }
439
440            // Parse cluster sub-elements
441            let element_size = sub_header.size.value;
442
443            // Safety check for element size
444            if sub_header.size.is_unknown || element_size > MAX_ELEMENT_SIZE {
445                // Can't process unknown size or very large elements
446                break;
447            }
448
449            match sub_header.id {
450                Timestamp::ID => {
451                    // Read timestamp
452                    if let Ok(body) = Self::read_element_body(reader, element_size) {
453                        if let Ok(ts) = Timestamp::decode_body(&mut body.as_slice()) {
454                            timestamp = Some(ts.0);
455                        }
456                    } else {
457                        // Read failed, skip by seeking
458                        if reader.seek(SeekFrom::Current(element_size as i64)).is_err() {
459                            break;
460                        }
461                    }
462                    unknown_element_count = 0;
463                }
464                SimpleBlock::ID => {
465                    // Read SimpleBlock
466                    if let Ok(body) = Self::read_element_body(reader, element_size) {
467                        if let Ok(block) = SimpleBlock::decode_body(&mut body.as_slice()) {
468                            blocks.push(ClusterBlock::Simple(block));
469                        }
470                    } else {
471                        // Read failed, skip by seeking
472                        if reader.seek(SeekFrom::Current(element_size as i64)).is_err() {
473                            break;
474                        }
475                    }
476                    unknown_element_count = 0;
477                }
478                BlockGroup::ID => {
479                    // Read BlockGroup
480                    if let Ok(body) = Self::read_element_body(reader, element_size) {
481                        if let Ok(block) = BlockGroup::decode_body(&mut body.as_slice()) {
482                            blocks.push(ClusterBlock::Group(block));
483                        }
484                    } else {
485                        // Read failed, skip by seeking
486                        if reader.seek(SeekFrom::Current(element_size as i64)).is_err() {
487                            break;
488                        }
489                    }
490                    unknown_element_count = 0;
491                }
492                Position::ID | PrevSize::ID | Void::ID => {
493                    // Skip known but unimportant elements
494                    if element_size > 0 {
495                        if reader.seek(SeekFrom::Current(element_size as i64)).is_err() {
496                            break;
497                        }
498                    }
499                    unknown_element_count = 0;
500                }
501                _ => {
502                    unknown_element_count += 1;
503                    if unknown_element_count >= MAX_UNKNOWN_ELEMENTS {
504                        // Too many unknown elements, may be corrupted
505                        break;
506                    }
507
508                    // Skip unknown element
509                    if element_size > 0 {
510                        if reader.seek(SeekFrom::Current(element_size as i64)).is_err() {
511                            break;
512                        }
513                    }
514                }
515            }
516
517            element_count += 1;
518            if element_count >= MAX_ELEMENTS {
519                break;
520            }
521        }
522
523        // Must have at least timestamp to form a valid cluster
524        let timestamp = match timestamp {
525            Some(ts) => ts,
526            None => return Ok(false),
527        };
528
529        // Build the cluster
530        let cluster = Cluster {
531            crc32: None,
532            void: None,
533            timestamp: Timestamp(timestamp),
534            position: None,
535            prev_size: None,
536            blocks,
537        };
538
539        self.current_frame_index = 0;
540        self.current_lace_index = 0;
541        self.current_cluster = Some(cluster);
542        Ok(true)
543    }
544}
545
impl Format for MkvDemuxer {
    /// The demuxer has no configurable options: any key/value pair is
    /// accepted and ignored, and the call always succeeds.
    fn set_option(&mut self, _key: &str, _value: &Variant) -> Result<()> {
        Ok(())
    }
}
551
552impl Demuxer for MkvDemuxer {
553    fn read_header(&mut self, reader: &mut dyn Reader, state: &mut DemuxerState) -> Result<()> {
554        // Read EBML header and extract doc_type
555        let ebml = Ebml::read_from(reader).map_err(|e| read_failed_error!("EBML header", e))?;
556
557        // Cache the document type (mkv or webm)
558        self.doc_type = ebml.doc_type.as_ref().map(|dt| DocType::from_doc_type(&dt.0)).unwrap_or(DocType::Matroska);
559
560        // Safety limit for header reading
561        let mut element_count = 0;
562
563        // Read top-level elements until Segment is found
564        loop {
565            let header = match Header::read_from(reader) {
566                Ok(h) => h,
567                Err(e) => {
568                    return Err(read_failed_error!("header", e));
569                }
570            };
571
572            if header.id == Segment::ID {
573                // Record segment data start position
574                self.segment_data_position = reader.stream_position()?;
575
576                // Read sub-elements until Cluster is found
577                let mut info: Option<Info> = None;
578                let mut tracks: Option<Tracks> = None;
579                let mut cues: Option<Cues> = None;
580                let mut seek_head: Vec<SeekHead> = Vec::new();
581
582                loop {
583                    let current_pos = reader.stream_position()?;
584
585                    let sub_header = match Header::read_from(reader) {
586                        Ok(h) => h,
587                        Err(_) => break, // End of file or segment
588                    };
589
590                    if sub_header.id == Info::ID {
591                        info = Some(Info::read_element(&sub_header, reader).map_err(|e| read_failed_error!(e.to_string()))?);
592                    } else if sub_header.id == Tracks::ID {
593                        tracks = Some(Tracks::read_element(&sub_header, reader).map_err(|e| read_failed_error!(e.to_string()))?);
594                    } else if sub_header.id == Cues::ID {
595                        cues = Some(Cues::read_element(&sub_header, reader).map_err(|e| read_failed_error!(e.to_string()))?);
596                    } else if sub_header.id == SeekHead::ID {
597                        // Parse SeekHead to get positions of other elements
598                        if let Ok(sh) = SeekHead::read_element(&sub_header, reader) {
599                            seek_head.push(sh);
600                        }
601                    } else if sub_header.id == Cluster::ID {
602                        // Try to locate Cues using SeekHead
603                        if cues.is_none() {
604                            if let Some(cues_pos) = Self::find_cues_position(&seek_head, self.segment_data_position) {
605                                if cues_pos > current_pos {
606                                    if reader.seek(SeekFrom::Start(cues_pos)).is_ok() {
607                                        if let Ok(cues_header) = Header::read_from(reader) {
608                                            if cues_header.id == Cues::ID {
609                                                if let Ok(c) = Cues::read_element(&cues_header, reader) {
610                                                    cues = Some(c);
611                                                }
612                                            }
613                                        }
614                                    }
615                                }
616                            }
617                        }
618
619                        // Seek back to cluster start for later reading
620                        reader.seek(SeekFrom::Start(current_pos))?;
621                        break;
622                    } else {
623                        // Skip unknown elements
624                        let size = sub_header.size.value;
625
626                        // Check for unknown size or unreasonably large size
627                        if sub_header.size.is_unknown || size > MAX_ELEMENT_SIZE {
628                            // Can't reliably skip - stop header reading
629                            break;
630                        }
631
632                        if size > 0 {
633                            if reader.seek(SeekFrom::Current(size as i64)).is_err() {
634                                break;
635                            }
636                        }
637                    }
638
639                    element_count += 1;
640                    if element_count >= MAX_HEADER_ELEMENTS {
641                        break;
642                    }
643                }
644
645                // Validate required elements
646                let info = info.ok_or_else(|| not_found_error!("info element"))?;
647
648                let timestamp_scale = info.timestamp_scale.0;
649
650                // Store segment info
651                self.time_base = Rational64::new(info.timestamp_scale.0 as i64, NSEC_PER_SEC);
652                self.duration = info.duration.as_ref().map(|d| d.0);
653                self.cues = cues;
654
655                // Calculate duration in microseconds
656                if let Some(duration) = self.duration {
657                    state.duration = Some((duration * timestamp_scale as f64 / MSEC_PER_SEC as f64) as i64);
658                }
659
660                // Process tracks
661                if let Some(ref tracks) = tracks {
662                    let mut stream = Stream::new(0);
663
664                    for track_entry in &tracks.track_entry {
665                        let track_number = track_entry.track_number.0 as isize;
666
667                        if let Some((codec_id, params)) = Self::track_to_params(track_entry, self.doc_type) {
668                            let mut track = Track::new(track_number, codec_id, params, self.time_base);
669
670                            // Set track duration if available
671                            if let Some(duration) = self.duration {
672                                track.duration = Some(duration as i64);
673                            }
674
675                            stream.add_track(state.tracks.add_track(track));
676                        }
677                    }
678
679                    state.streams.add_stream(stream);
680                }
681
682                return Ok(());
683            } else {
684                // Skip unknown top-level elements
685                let size = header.size.value;
686                if size > 0 {
687                    reader.seek(SeekFrom::Current(size as i64))?;
688                }
689            }
690        }
691    }
692
693    fn read_packet(&mut self, reader: &mut dyn Reader, state: &DemuxerState) -> Result<Packet<'static>> {
694        loop {
695            if let Some(ref cluster) = self.current_cluster {
696                // Collect frames from current cluster starting at current_frame_index
697                let frames: Vec<_> = cluster.frames().filter_map(|r| r.ok()).collect();
698
699                while self.current_frame_index < frames.len() {
700                    let frame = &frames[self.current_frame_index];
701
702                    let track = state.tracks.find_track(frame.track_number as isize).ok_or_else(|| not_found_error!("track", frame.track_number))?;
703
704                    // Extract frame data from FrameData enum, handling laced frames
705                    let (frame_bytes, advance_frame): (&[u8], bool) = match &frame.data {
706                        FrameData::Single(data) => (data, true),
707                        FrameData::Multiple(data_vec) => {
708                            // For laced frames, return each sub-frame as a separate packet
709                            if self.current_lace_index < data_vec.len() {
710                                let data = data_vec[self.current_lace_index];
711                                self.current_lace_index += 1;
712                                // Only advance to next frame when all laced sub-frames are consumed
713                                let advance = self.current_lace_index >= data_vec.len();
714                                if advance {
715                                    self.current_lace_index = 0;
716                                }
717                                (data, advance)
718                            } else {
719                                // Should not happen, but reset and move to next frame
720                                self.current_lace_index = 0;
721                                self.current_frame_index += 1;
722                                continue;
723                            }
724                        }
725                    };
726
727                    if advance_frame {
728                        self.current_frame_index += 1;
729                    }
730
731                    let mut packet = Packet::from_buffer(track.pool.get_buffer_with_length(frame_bytes.len()));
732                    if let Some(buffer) = packet.data_mut() {
733                        buffer.copy_from_slice(frame_bytes);
734                    }
735
736                    packet.track_index = Some(track.index());
737                    packet.dts = Some(frame.timestamp);
738                    packet.pts = Some(frame.timestamp);
739                    packet.time_base = Some(self.time_base);
740
741                    if let Some(duration) = frame.duration {
742                        packet.duration = Some(duration.get() as i64);
743                    }
744
745                    // Set packet flags - only the first laced frame should be marked as keyframe
746                    packet.flags = if frame.is_keyframe &&
747                        (matches!(&frame.data, FrameData::Single(_)) ||
748                            self.current_lace_index == 1 ||
749                            (self.current_lace_index == 0 && advance_frame))
750                    {
751                        PacketFlags::Key
752                    } else {
753                        PacketFlags::empty()
754                    };
755
756                    return Ok(packet);
757                }
758            }
759
760            // Need to read the next cluster
761            if !self.read_next_cluster(reader)? {
762                return Err(not_found_error!("more packets"));
763            }
764        }
765    }
766
767    fn seek(&mut self, reader: &mut dyn Reader, state: &DemuxerState, track_index: Option<usize>, timestamp_us: i64, flags: SeekFlags) -> Result<()> {
768        // Convert timestamp to segment ticks
769        let target_ticks = Rational64::from_integer(timestamp_us) / (Rational64::from_integer(USEC_PER_SEC) * self.time_base);
770        let target_ticks = target_ticks.to_integer();
771
772        // Save current position for rollback
773        let saved_position = reader.stream_position()?;
774
775        // Determine target track
776        let target_track_index =
777            track_index.unwrap_or_else(|| state.tracks.into_iter().find(|t| t.media_type() == MediaType::Video).map(|t| t.index()).unwrap_or(0));
778
779        let target_track = state.tracks.get_track(target_track_index).ok_or_else(|| not_found_error!("track", target_track_index))?;
780
781        let target_track_number = target_track.id as u64;
782
783        // Use Cues for seeking
784        let Some(ref cues) = self.cues else {
785            return Err(unsupported_error!("seek without Cues index"));
786        };
787
788        // Save current cluster state for potential rollback on failure
789        let saved_cluster = self.current_cluster.take();
790        let saved_frame_index = self.current_frame_index;
791
792        // Collect all cue points for our target track
793        let mut track_cues: Vec<(i64, u64)> = Vec::new(); // (cue_time, cluster_pos)
794
795        for cue_point in &cues.cue_point {
796            let cue_time = cue_point.cue_time.0 as i64;
797            for track_pos in &cue_point.cue_track_positions {
798                if track_pos.cue_track.0 == target_track_number {
799                    let absolute_pos = self.segment_data_position + track_pos.cue_cluster_position.0;
800                    track_cues.push((cue_time, absolute_pos));
801                    break;
802                }
803            }
804        }
805
806        // Find the best cue point based on seek flags
807        let best_cue = if flags.contains(SeekFlags::BACKWARD) {
808            // BACKWARD: Find the largest cue_time <= target
809            track_cues.iter().filter(|(t, _)| *t <= target_ticks).max_by_key(|(t, _)| *t).copied()
810        } else {
811            // Default: Find the nearest cue point (before or after)
812            track_cues.iter().min_by_key(|(t, _)| (t - target_ticks).abs()).copied()
813        };
814
815        if let Some((_best_cue_time, cluster_pos)) = best_cue {
816            // Scan forward to find the frame closest to target
817            reader.seek(SeekFrom::Start(cluster_pos))?;
818
819            // Track the best frame before and after target
820            let mut best_before: Option<(u64, i64)> = None;
821            let mut best_after: Option<(u64, i64)> = None;
822
823            // For ANY mode, search any frame, otherwise search only key frame
824            let search_any = flags.contains(SeekFlags::ANY);
825
826            loop {
827                let cluster_start_pos = reader.stream_position()?;
828
829                // Try to read next cluster
830                if !self.read_next_cluster(reader)? {
831                    break; // End of file
832                }
833
834                let cluster_timestamp = self.current_cluster.as_ref().map(|c| c.timestamp.0 as i64).unwrap_or(0);
835
836                // Find frames for our target track in this cluster
837                if let Some(ref cluster) = self.current_cluster {
838                    for frame_result in cluster.frames() {
839                        if let Ok(frame) = frame_result {
840                            // For ANY mode, consider all frames; otherwise only keyframes
841                            if frame.track_number == target_track_number && (search_any || frame.is_keyframe) {
842                                if frame.timestamp <= target_ticks {
843                                    // Frame before or at target - update best_before
844                                    if best_before.is_none() || frame.timestamp > best_before.as_ref().unwrap().1 {
845                                        best_before = Some((cluster_start_pos, frame.timestamp));
846                                    }
847                                } else if best_after.is_none() {
848                                    // First frame after target
849                                    best_after = Some((cluster_start_pos, frame.timestamp));
850                                }
851                            }
852                        }
853                    }
854                }
855
856                // Decide when to stop scanning
857                if flags.contains(SeekFlags::BACKWARD) {
858                    // For BACKWARD: stop when cluster timestamp is past target
859                    if cluster_timestamp > target_ticks {
860                        break;
861                    }
862                } else {
863                    // For default: stop after finding the first frame after target
864                    if best_after.is_some() {
865                        break;
866                    }
867                    // Safety: don't search too far past target
868                    let lookahead_ticks = (Rational64::from_integer(SEEK_LOOKAHEAD_SEC) / self.time_base).to_integer();
869                    if cluster_timestamp > target_ticks + lookahead_ticks {
870                        break;
871                    }
872                }
873            }
874
875            // Choose the best frame based on flags
876            let chosen = if flags.contains(SeekFlags::BACKWARD) {
877                best_before
878            } else {
879                // Find the nearest frame
880                match (best_before, best_after) {
881                    (Some(before), Some(after)) => {
882                        let dist_before = target_ticks - before.1;
883                        let dist_after = after.1 - target_ticks;
884                        if dist_before <= dist_after {
885                            Some(before)
886                        } else {
887                            Some(after)
888                        }
889                    }
890                    (Some(before), None) => Some(before),
891                    (None, Some(after)) => Some(after),
892                    (None, None) => None,
893                }
894            };
895
896            // Seek to the chosen cluster and set frame index
897            if let Some((chosen_cluster_pos, _chosen_ts)) = chosen {
898                reader.seek(SeekFrom::Start(chosen_cluster_pos))?;
899                self.current_frame_index = 0;
900                self.current_lace_index = 0;
901                self.current_cluster = None;
902
903                // Re-read the chosen cluster
904                self.read_next_cluster(reader)?;
905
906                return Ok(());
907            }
908        }
909
910        // Seek failed - restore previous state so read_packet can continue
911        let _ = reader.seek(SeekFrom::Start(saved_position));
912        self.current_cluster = saved_cluster;
913        self.current_frame_index = saved_frame_index;
914
915        Err(not_found_error!("seek position"))
916    }
917}
918
/// Builder for Matroska/WebM demuxer
///
/// A stateless unit type: it carries no configuration, so it is cheap to copy and
/// can be created with [`Default`].
#[derive(Clone, Copy, Debug, Default)]
pub struct MkvDemuxerBuilder;
921
922impl media_format_types::FormatBuilder for MkvDemuxerBuilder {
923    fn name(&self) -> &'static str {
924        "matroska"
925    }
926
927    fn extensions(&self) -> &[&'static str] {
928        &["mkv", "webm"]
929    }
930}
931
932impl DemuxerBuilder for MkvDemuxerBuilder {
933    fn new_demuxer(&self) -> Result<Box<dyn Demuxer>> {
934        Ok(Box::new(MkvDemuxer::new()))
935    }
936
937    fn probe(&self, reader: &mut dyn Reader) -> bool {
938        let mut buf = [0; 4];
939        reader.read_exact(&mut buf).ok();
940        buf == [0x1A, 0x45, 0xDF, 0xA3]
941    }
942}