Skip to main content

media_format_isomp4/
demuxer.rs

1//! ISO Base Media File Format (MP4/MOV) demuxer
2
3use std::{io::SeekFrom, num::NonZeroU32};
4
5#[cfg(any(feature = "audio", feature = "video"))]
6use media_codec_types::decoder::ExtraData;
7#[cfg(feature = "audio")]
8use media_codec_types::AudioParameters;
9#[cfg(feature = "video")]
10use media_codec_types::VideoParameters;
11use media_codec_types::{
12    decoder::DecoderParameters,
13    packet::{Packet, PacketFlags},
14    CodecID, CodecParameters,
15};
16#[cfg(feature = "audio")]
17use media_core::audio::ChannelLayout;
18#[cfg(feature = "video")]
19use media_core::video::ColorRange;
20use media_core::{invalid_error, not_found_error, rational::Rational64, time::USEC_PER_SEC, variant::Variant, MediaType, Result};
21use media_format_types::{
22    demuxer::{Demuxer, DemuxerBuilder, DemuxerState, Reader, SeekFlags},
23    stream::Stream,
24    track::Track,
25    Format, FormatBuilder,
26};
27use mp4_atom::{Atom, Codec as Mp4Codec, Ftyp, Header, Mdat, Moov, ReadAtom, ReadFrom, Stbl, StszSamples};
28#[cfg(feature = "audio")]
29use mp4_atom::{Audio, Esds};
30#[cfg(feature = "video")]
31use mp4_atom::{Avcc, Colr, Hvcc, Visual};
32
33/// MP4 demuxer implementation
34pub struct Mp4Demuxer {
35    /// File type box
36    pub ftyp: Option<Ftyp>,
37    /// Movie box containing all metadata
38    pub moov: Option<Moov>,
39    /// Track current sample index for each track
40    track_sample_indices: Vec<usize>,
41}
42
43impl Default for Mp4Demuxer {
44    fn default() -> Self {
45        Self::new()
46    }
47}
48
49impl Mp4Demuxer {
50    pub fn new() -> Self {
51        Self {
52            ftyp: None,
53            moov: None,
54            track_sample_indices: Vec::new(),
55        }
56    }
57
58    #[cfg(feature = "video")]
59    fn make_video_params(visual: &Visual, colr: Option<&Colr>) -> VideoParameters {
60        let mut video_params = VideoParameters {
61            width: NonZeroU32::new(visual.width as u32),
62            height: NonZeroU32::new(visual.height as u32),
63            ..Default::default()
64        };
65
66        let Some(colr) = colr else { return video_params };
67
68        let (primaries, transfer, matrix, range) = match colr {
69            Colr::Nclx {
70                colour_primaries,
71                transfer_characteristics,
72                matrix_coefficients,
73                full_range_flag,
74            } => (
75                *colour_primaries,
76                *transfer_characteristics,
77                *matrix_coefficients,
78                Some(if *full_range_flag {
79                    ColorRange::Full
80                } else {
81                    ColorRange::Video
82                }),
83            ),
84            _ => return video_params,
85        };
86
87        video_params.color_primaries = (primaries as usize).try_into().ok();
88        video_params.color_transfer_characteristics = (transfer as usize).try_into().ok();
89        video_params.color_matrix = (matrix as usize).try_into().ok();
90        video_params.color_range = range;
91
92        video_params
93    }
94
95    #[cfg(feature = "audio")]
96    fn make_audio_params(audio: &Audio) -> AudioParameters {
97        AudioParameters {
98            sample_rate: NonZeroU32::new(audio.sample_rate.integer() as u32),
99            channel_layout: ChannelLayout::default_from_channels(audio.channel_count as u8).ok(),
100            ..Default::default()
101        }
102    }
103
104    #[cfg(feature = "audio")]
105    fn make_asc_codec_params(esds: &Esds) -> DecoderParameters {
106        let asc = &esds.es_desc.dec_config.dec_specific;
107        DecoderParameters {
108            extra_data: Some(ExtraData::ASC {
109                object_type: asc.profile,
110                channel_config: asc.chan_conf,
111            }),
112            ..Default::default()
113        }
114    }
115
116    #[cfg(feature = "video")]
117    fn make_avc_codec_params(avc: &Avcc) -> DecoderParameters {
118        DecoderParameters {
119            extra_data: Some(ExtraData::AVC {
120                sps: avc.sequence_parameter_sets.clone(),
121                pps: avc.picture_parameter_sets.clone(),
122                nalu_length_size: avc.length_size,
123            }),
124            ..Default::default()
125        }
126    }
127
128    #[cfg(feature = "video")]
129    fn make_hevc_codec_params(hvcc: &Hvcc) -> DecoderParameters {
130        let mut decoder_params = DecoderParameters::default();
131
132        let mut vps: Option<Vec<Vec<u8>>> = None;
133        let mut sps = Vec::new();
134        let mut pps = Vec::new();
135
136        for array in &hvcc.arrays {
137            match array.nal_unit_type {
138                32 => vps.get_or_insert_with(Vec::new).extend(array.nalus.iter().cloned()),
139                33 => sps.extend(array.nalus.iter().cloned()),
140                34 => pps.extend(array.nalus.iter().cloned()),
141                _ => {}
142            }
143        }
144
145        decoder_params.extra_data = Some(ExtraData::HEVC {
146            vps,
147            sps,
148            pps,
149            nalu_length_size: hvcc.length_size_minus_one + 1,
150        });
151
152        decoder_params
153    }
154
155    fn codec_to_params(codec: &Mp4Codec) -> Option<(CodecID, CodecParameters)> {
156        match codec {
157            #[cfg(feature = "video")]
158            Mp4Codec::Avc1(avc1) => {
159                let video_params = Self::make_video_params(&avc1.visual, avc1.colr.as_ref());
160                let decoder_params = Self::make_avc_codec_params(&avc1.avcc);
161                Some((CodecID::H264, CodecParameters::new(video_params, decoder_params)))
162            }
163            #[cfg(feature = "video")]
164            Mp4Codec::Hev1(hev1) => {
165                let video_params = Self::make_video_params(&hev1.visual, hev1.colr.as_ref());
166                let decoder_params = Self::make_hevc_codec_params(&hev1.hvcc);
167                Some((CodecID::HEVC, CodecParameters::new(video_params, decoder_params)))
168            }
169            #[cfg(feature = "video")]
170            Mp4Codec::Hvc1(hvc1) => {
171                let video_params = Self::make_video_params(&hvc1.visual, hvc1.colr.as_ref());
172                let decoder_params = Self::make_hevc_codec_params(&hvc1.hvcc);
173                Some((CodecID::HEVC, CodecParameters::new(video_params, decoder_params)))
174            }
175            #[cfg(feature = "video")]
176            Mp4Codec::Vp08(vp08) => {
177                let video_params = Self::make_video_params(&vp08.visual, vp08.colr.as_ref());
178                Some((CodecID::VP8, CodecParameters::new(video_params, DecoderParameters::default())))
179            }
180            #[cfg(feature = "video")]
181            Mp4Codec::Vp09(vp09) => {
182                let video_params = Self::make_video_params(&vp09.visual, vp09.colr.as_ref());
183                Some((CodecID::VP9, CodecParameters::new(video_params, DecoderParameters::default())))
184            }
185            #[cfg(feature = "video")]
186            Mp4Codec::Av01(av01) => {
187                let video_params = Self::make_video_params(&av01.visual, av01.colr.as_ref());
188                Some((CodecID::AV1, CodecParameters::new(video_params, DecoderParameters::default())))
189            }
190            #[cfg(feature = "audio")]
191            Mp4Codec::Mp4a(mp4a) => {
192                let audio_params = Self::make_audio_params(&mp4a.audio);
193                let decoder_params = Self::make_asc_codec_params(&mp4a.esds);
194                Some((CodecID::AAC, CodecParameters::new(audio_params, decoder_params)))
195            }
196            #[cfg(feature = "audio")]
197            Mp4Codec::Opus(opus) => {
198                let audio_params = Self::make_audio_params(&opus.audio);
199                Some((CodecID::OPUS, CodecParameters::new(audio_params, DecoderParameters::default())))
200            }
201            #[cfg(feature = "audio")]
202            Mp4Codec::Flac(flac) => {
203                let audio_params = Self::make_audio_params(&flac.audio);
204                Some((CodecID::FLAC, CodecParameters::new(audio_params, DecoderParameters::default())))
205            }
206            #[cfg(feature = "audio")]
207            Mp4Codec::Ac3(ac3) => {
208                let audio_params = Self::make_audio_params(&ac3.audio);
209                Some((CodecID::AC3, CodecParameters::new(audio_params, DecoderParameters::default())))
210            }
211            #[cfg(feature = "audio")]
212            Mp4Codec::Eac3(eac3) => {
213                let audio_params = Self::make_audio_params(&eac3.audio);
214                Some((CodecID::EAC3, CodecParameters::new(audio_params, DecoderParameters::default())))
215            }
216            _ => None,
217        }
218    }
219
220    fn find_sample_index(stbl: &Stbl, target_dts: i64) -> usize {
221        let mut accumulated_dts = 0i64;
222        let mut sample_index = 0usize;
223
224        for entry in &stbl.stts.entries {
225            let samples_in_entry = entry.sample_count as usize;
226            let entry_duration = entry.sample_count as i64 * entry.sample_delta as i64;
227
228            if accumulated_dts + entry_duration > target_dts {
229                let offset = (target_dts - accumulated_dts) / entry.sample_delta as i64;
230                sample_index += offset as usize;
231                break;
232            }
233
234            accumulated_dts += entry_duration;
235            sample_index += samples_in_entry;
236        }
237
238        // Clamp to valid range
239        let total_samples = match &stbl.stsz.samples {
240            StszSamples::Identical {
241                count, ..
242            } => *count as usize,
243            StszSamples::Different {
244                sizes,
245            } => sizes.len(),
246        };
247        sample_index.min(total_samples.saturating_sub(1))
248    }
249}
250
251impl Format for Mp4Demuxer {
252    fn set_option(&mut self, _key: &str, _value: &Variant) -> Result<()> {
253        Ok(())
254    }
255}
256
257impl Demuxer for Mp4Demuxer {
258    fn read_header(&mut self, reader: &mut dyn Reader, state: &mut DemuxerState) -> Result<()> {
259        // Read atoms until find moov
260        loop {
261            let header = match Header::read_from(reader) {
262                Ok(h) => h,
263                Err(e) => {
264                    if self.moov.is_none() {
265                        return Err(not_found_error!("moov"));
266                    }
267                    return Err(invalid_error!(e.to_string()));
268                }
269            };
270
271            match header.kind {
272                Ftyp::KIND => {
273                    let ftyp = Ftyp::read_atom(&header, reader).map_err(|e| invalid_error!(e.to_string()))?;
274                    self.ftyp = Some(ftyp);
275                }
276                Moov::KIND => {
277                    let moov = Moov::read_atom(&header, reader).map_err(|e| invalid_error!(e.to_string()))?;
278
279                    // Initialize track_sample_indices with the number of tracks
280                    self.track_sample_indices = vec![0; moov.trak.len()];
281
282                    // Create a single stream
283                    let mut stream = Stream::new(0);
284
285                    // Process each track and add to stream
286                    for trak in &moov.trak {
287                        let track_id = trak.tkhd.track_id as isize;
288                        let timescale = trak.mdia.mdhd.timescale;
289                        let time_base = Rational64::new(1, timescale as i64);
290
291                        // Get codec info from stsd
292                        if let Some(codec) = trak.mdia.minf.stbl.stsd.codecs.first() {
293                            if let Some((codec_id, params)) = Self::codec_to_params(codec) {
294                                let mut track = Track::new(track_id, codec_id, params, time_base);
295                                track.duration = Some(trak.mdia.mdhd.duration as i64);
296                                stream.add_track(state.tracks.add_track(track));
297                            }
298                        }
299                    }
300
301                    state.streams.add_stream(stream);
302
303                    let timescale = moov.mvhd.timescale as i64;
304                    let duration = moov.mvhd.duration as i64;
305                    if timescale > 0 && duration > 0 {
306                        state.duration = Some(duration * USEC_PER_SEC / timescale);
307                    }
308
309                    self.moov = Some(moov);
310
311                    return Ok(());
312                }
313                Mdat::KIND => {
314                    // Skip mdat atom, read data later
315                    let skip_size = header.size.unwrap_or(0) as i64;
316                    reader.seek(SeekFrom::Current(skip_size))?;
317                }
318                _ => {
319                    // Skip unknown atoms
320                    if let Some(size) = header.size {
321                        reader.seek(SeekFrom::Current(size as i64))?;
322                    }
323                }
324            }
325        }
326    }
327
328    fn read_packet(&mut self, reader: &mut dyn Reader, state: &DemuxerState) -> Result<Packet<'static>> {
329        let moov = self.moov.as_ref().ok_or_else(|| not_found_error!("moov"))?;
330
331        // Find the track with the earliest next sample
332        let mut earliest_track_idx: Option<usize> = None;
333        let mut earliest_dts_us = i64::MAX;
334        let mut earliest_dts_raw = 0i64; // DTS in track's native timescale
335
336        for (track_idx, trak) in moov.trak.iter().enumerate() {
337            let sample_index = self.track_sample_indices[track_idx];
338
339            // Check if this track has more samples
340            let stts = &trak.mdia.minf.stbl.stts;
341            let mut total_samples = 0u32;
342            for entry in &stts.entries {
343                total_samples += entry.sample_count;
344            }
345
346            if sample_index >= total_samples as usize {
347                continue; // This track is exhausted
348            }
349
350            // Calculate DTS for this sample (in track's native timescale)
351            let mut dts = 0i64;
352            let mut accumulated_samples = 0usize;
353            for entry in &stts.entries {
354                if accumulated_samples + entry.sample_count as usize > sample_index {
355                    dts += (sample_index - accumulated_samples) as i64 * entry.sample_delta as i64;
356                    break;
357                }
358                dts += entry.sample_count as i64 * entry.sample_delta as i64;
359                accumulated_samples += entry.sample_count as usize;
360            }
361
362            // Convert DTS to microseconds for cross-track comparison
363            let timescale = trak.mdia.mdhd.timescale as i64;
364            let dts_us = dts * USEC_PER_SEC / timescale;
365
366            if dts_us < earliest_dts_us {
367                earliest_dts_us = dts_us;
368                earliest_dts_raw = dts;
369                earliest_track_idx = Some(track_idx);
370            }
371        }
372
373        let track_idx = earliest_track_idx.ok_or_else(|| not_found_error!("no more samples"))?;
374
375        // Find the corresponding trak
376        let trak = &moov.trak[track_idx];
377        let track_id = trak.tkhd.track_id;
378
379        let track = state.tracks.find_track(track_id as isize).ok_or_else(|| not_found_error!("track", track_id))?;
380
381        let sample_index = self.track_sample_indices[track_idx];
382        let stbl = &trak.mdia.minf.stbl;
383
384        // Calculate sample duration from stts
385        let mut duration = 0i64;
386        let mut accumulated_samples = 0usize;
387        for entry in &stbl.stts.entries {
388            if accumulated_samples + entry.sample_count as usize > sample_index {
389                duration = entry.sample_delta as i64;
390                break;
391            }
392            accumulated_samples += entry.sample_count as usize;
393        }
394
395        // Calculate PTS offset from ctts (Composition Time to Sample)
396        let pts_offset = if let Some(ref ctts) = stbl.ctts {
397            let mut accumulated_samples = 0usize;
398            let mut offset = 0i32;
399            for entry in &ctts.entries {
400                if accumulated_samples + entry.sample_count as usize > sample_index {
401                    offset = entry.sample_offset;
402                    break;
403                }
404                accumulated_samples += entry.sample_count as usize;
405            }
406            offset as i64
407        } else {
408            0i64
409        };
410
411        let sample_size = match &stbl.stsz.samples {
412            StszSamples::Identical {
413                size, ..
414            } => *size as usize,
415            StszSamples::Different {
416                sizes,
417            } => *sizes.get(sample_index).ok_or_else(|| not_found_error!("sample size"))? as usize,
418        };
419
420        // Get chunk and offset
421        let mut chunk_index = 0usize;
422        let mut sample_in_chunk = sample_index;
423
424        for (i, entry) in stbl.stsc.entries.iter().enumerate() {
425            let next_first_chunk = stbl.stsc.entries.get(i + 1).map(|e| e.first_chunk).unwrap_or(u32::MAX);
426
427            let chunks_in_this_group = next_first_chunk - entry.first_chunk;
428            let samples_per_chunk = entry.samples_per_chunk as usize;
429            let samples_in_this_group = chunks_in_this_group as usize * samples_per_chunk;
430
431            if sample_in_chunk < samples_in_this_group {
432                chunk_index = (entry.first_chunk - 1) as usize + sample_in_chunk / samples_per_chunk;
433                sample_in_chunk %= samples_per_chunk;
434                break;
435            }
436            sample_in_chunk -= samples_in_this_group;
437        }
438
439        let chunk_offset = if let Some(ref stco) = stbl.stco {
440            *stco.entries.get(chunk_index).ok_or_else(|| not_found_error!("chunk offset"))? as u64
441        } else if let Some(ref co64) = stbl.co64 {
442            *co64.entries.get(chunk_index).ok_or_else(|| not_found_error!("chunk offset"))?
443        } else {
444            return Err(not_found_error!("chunk offset"));
445        };
446
447        // Calculate sample offset within chunk
448        let mut sample_offset = chunk_offset;
449        for i in 0..sample_in_chunk {
450            let prev_sample_idx = sample_index - sample_in_chunk + i;
451            let prev_size = match &stbl.stsz.samples {
452                StszSamples::Identical {
453                    size, ..
454                } => *size as u64,
455                StszSamples::Different {
456                    sizes,
457                } => *sizes.get(prev_sample_idx).ok_or_else(|| not_found_error!("sample size"))? as u64,
458            };
459            sample_offset += prev_size;
460        }
461
462        let mut packet = Packet::from_buffer(track.pool.get_buffer_with_length(sample_size));
463        let buffer = packet.data_mut().ok_or_else(|| invalid_error!("packet buffer is not mutable"))?;
464
465        reader.seek(SeekFrom::Start(sample_offset))?;
466        reader.read_exact(buffer)?;
467
468        let timescale = trak.mdia.mdhd.timescale;
469        let time_base = Rational64::new(1, timescale as i64);
470
471        packet.track_index = Some(track.index());
472        packet.dts = Some(earliest_dts_raw);
473        packet.pts = Some(earliest_dts_raw + pts_offset);
474        packet.duration = Some(duration);
475        packet.time_base = Some(time_base);
476
477        // Check if this is a keyframe (sync sample)
478        packet.flags = if stbl.stss.is_some() {
479            let key = stbl.stss.as_ref().map(|stss| stss.entries.contains(&((sample_index + 1) as u32))).unwrap_or(false);
480
481            if key {
482                PacketFlags::Key
483            } else {
484                PacketFlags::empty()
485            }
486        } else {
487            PacketFlags::Key // If no stss, all samples are keyframes
488        };
489
490        // Update sample index
491        self.track_sample_indices[track_idx] = sample_index + 1;
492
493        Ok(packet)
494    }
495
496    fn seek(
497        &mut self,
498        _reader: &mut dyn Reader,
499        state: &DemuxerState,
500        track_index: Option<usize>,
501        timestamp_us: i64,
502        flags: SeekFlags,
503    ) -> Result<()> {
504        let moov = self.moov.as_ref().ok_or_else(|| not_found_error!("moov"))?;
505
506        // Determine the target track index
507        let track_index = track_index.unwrap_or_else(|| {
508            // Find the first video track, or fall back to the first track
509            state.tracks.into_iter().find(|t| t.media_type() == MediaType::Video).map(|t| t.index()).unwrap_or(0)
510        });
511
512        let target_trak = moov.trak.get(track_index).ok_or_else(|| not_found_error!("track at index {}", track_index))?;
513        let target_timescale = target_trak.mdia.mdhd.timescale;
514        let target_stbl = &target_trak.mdia.minf.stbl;
515
516        // Convert timestamp (in microseconds) to target track's timescale
517        let track_target_dts = timestamp_us * target_timescale as i64 / USEC_PER_SEC;
518
519        let mut target_sample_index = Self::find_sample_index(target_stbl, track_target_dts);
520
521        // Apply keyframe seeking (skip if ANY flag is set)
522        if !flags.contains(SeekFlags::ANY) {
523            if let Some(ref stss) = target_stbl.stss {
524                let target_sample_number = (target_sample_index + 1) as u32;
525
526                let keyframe_sample = if flags.contains(SeekFlags::BACKWARD) {
527                    // Find the largest sync sample that is <= target
528                    match stss.entries.partition_point(|s| *s <= target_sample_number) {
529                        0 => 1,
530                        i => stss.entries[i - 1],
531                    }
532                } else {
533                    // Find the nearest keyframe (before or after)
534                    let pos = stss.entries.partition_point(|s| *s < target_sample_number);
535                    let candidates = [pos.checked_sub(1).and_then(|i| stss.entries.get(i)), stss.entries.get(pos)];
536                    candidates.into_iter().flatten().min_by_key(|s| s.abs_diff(target_sample_number)).copied().unwrap_or(1)
537                };
538
539                target_sample_index = (keyframe_sample - 1) as usize;
540            }
541        }
542        // Keep the original target_sample_index (may be non-keyframe)
543
544        // Calculate the actual DTS of the selected keyframe
545        let mut actual_dts = 0i64;
546        let mut accumulated_samples = 0usize;
547        for entry in &target_stbl.stts.entries {
548            if accumulated_samples + entry.sample_count as usize > target_sample_index {
549                actual_dts += (target_sample_index - accumulated_samples) as i64 * entry.sample_delta as i64;
550                break;
551            }
552            actual_dts += entry.sample_count as i64 * entry.sample_delta as i64;
553            accumulated_samples += entry.sample_count as usize;
554        }
555
556        // Synchronize all tracks
557        for (trak_idx, trak) in moov.trak.iter().enumerate() {
558            let sample_index = if trak_idx == track_index {
559                // Target track: use keyframe-aligned position
560                target_sample_index
561            } else {
562                // Other tracks: find sample at the actual timestamp
563                let timescale = trak.mdia.mdhd.timescale;
564                let track_dts = actual_dts * timescale as i64 / target_timescale as i64;
565                Self::find_sample_index(&trak.mdia.minf.stbl, track_dts)
566            };
567
568            self.track_sample_indices[trak_idx] = sample_index;
569        }
570
571        Ok(())
572    }
573}
574
/// Factory that probes for and creates [`Mp4Demuxer`] instances.
pub struct Mp4DemuxerBuilder;
577
578impl FormatBuilder for Mp4DemuxerBuilder {
579    fn name(&self) -> &'static str {
580        "mp4"
581    }
582
583    fn extensions(&self) -> &[&'static str] {
584        &["mp4", "mov", "m4v", "m4a"]
585    }
586}
587
588impl DemuxerBuilder for Mp4DemuxerBuilder {
589    fn new_demuxer(&self) -> Result<Box<dyn Demuxer>> {
590        Ok(Box::new(Mp4Demuxer::new()))
591    }
592
593    fn probe(&self, reader: &mut dyn Reader) -> bool {
594        let mut buf = [0u8; 8];
595        reader.read_exact(&mut buf).ok();
596
597        matches!(&buf[4..8], b"ftyp" | b"moov" | b"mdat")
598    }
599}