Skip to main content

media_format_isomp4/
demuxer.rs

1//! ISO Base Media File Format (MP4/MOV) demuxer
2
3use std::{io::SeekFrom, num::NonZeroU32};
4
5#[cfg(feature = "audio")]
6use media_codec_types::AudioParameters;
7use media_codec_types::{
8    decoder::DecoderParameters,
9    packet::{Packet, PacketFlags},
10    CodecID, CodecParameters,
11};
12#[cfg(feature = "video")]
13use media_codec_types::{decoder::ExtraData, VideoParameters};
14#[cfg(feature = "audio")]
15use media_core::audio::ChannelLayout;
16#[cfg(feature = "video")]
17use media_core::video::ColorRange;
18use media_core::{invalid_error, not_found_error, rational::Rational64, time::USEC_PER_SEC, variant::Variant, MediaType, Result};
19use media_format_types::{
20    demuxer::{Demuxer, DemuxerBuilder, DemuxerState, Reader, SeekFlags},
21    stream::Stream,
22    track::Track,
23    Format, FormatBuilder,
24};
25use mp4_atom::{Atom, Codec as Mp4Codec, Ftyp, Header, Mdat, Moov, ReadAtom, ReadFrom, Stbl, StszSamples};
26#[cfg(feature = "audio")]
27use mp4_atom::{Audio, Esds};
28#[cfg(feature = "video")]
29use mp4_atom::{Avcc, Colr, Hvcc, Visual};
30
31/// MP4 demuxer implementation
32pub struct Mp4Demuxer {
33    /// File type box
34    pub ftyp: Option<Ftyp>,
35    /// Movie box containing all metadata
36    pub moov: Option<Moov>,
37    /// Track current sample index for each track
38    track_sample_indices: Vec<usize>,
39}
40
41impl Default for Mp4Demuxer {
42    fn default() -> Self {
43        Self::new()
44    }
45}
46
47impl Mp4Demuxer {
48    pub fn new() -> Self {
49        Self {
50            ftyp: None,
51            moov: None,
52            track_sample_indices: Vec::new(),
53        }
54    }
55
56    #[cfg(feature = "video")]
57    fn make_video_params(visual: &Visual, colr: Option<&Colr>) -> VideoParameters {
58        let mut video_params = VideoParameters {
59            width: NonZeroU32::new(visual.width as u32),
60            height: NonZeroU32::new(visual.height as u32),
61            ..Default::default()
62        };
63
64        let Some(colr) = colr else { return video_params };
65
66        let (primaries, transfer, matrix, range) = match colr {
67            Colr::Nclx {
68                colour_primaries,
69                transfer_characteristics,
70                matrix_coefficients,
71                full_range_flag,
72            } => (
73                *colour_primaries,
74                *transfer_characteristics,
75                *matrix_coefficients,
76                Some(if *full_range_flag {
77                    ColorRange::Full
78                } else {
79                    ColorRange::Video
80                }),
81            ),
82            _ => return video_params,
83        };
84
85        video_params.color_primaries = (primaries as usize).try_into().ok();
86        video_params.color_transfer_characteristics = (transfer as usize).try_into().ok();
87        video_params.color_matrix = (matrix as usize).try_into().ok();
88        video_params.color_range = range;
89
90        video_params
91    }
92
93    #[cfg(feature = "audio")]
94    fn make_audio_params(audio: &Audio) -> AudioParameters {
95        AudioParameters {
96            sample_rate: NonZeroU32::new(audio.sample_rate.integer() as u32),
97            channel_layout: ChannelLayout::default_from_channels(audio.channel_count as u8).ok(),
98            ..Default::default()
99        }
100    }
101
102    #[cfg(feature = "audio")]
103    fn make_asc_codec_params(esds: &Esds) -> DecoderParameters {
104        let asc = &esds.es_desc.dec_config.dec_specific;
105        DecoderParameters {
106            extra_data: Some(ExtraData::ASC {
107                object_type: asc.profile,
108                channel_config: asc.chan_conf,
109            }),
110            ..Default::default()
111        }
112    }
113
114    #[cfg(feature = "video")]
115    fn make_avc_codec_params(avc: &Avcc) -> DecoderParameters {
116        DecoderParameters {
117            extra_data: Some(ExtraData::AVC {
118                sps: avc.sequence_parameter_sets.clone(),
119                pps: avc.picture_parameter_sets.clone(),
120                nalu_length_size: avc.length_size,
121            }),
122            ..Default::default()
123        }
124    }
125
126    #[cfg(feature = "video")]
127    fn make_hevc_codec_params(hvcc: &Hvcc) -> DecoderParameters {
128        let mut decoder_params = DecoderParameters::default();
129
130        let mut vps: Option<Vec<Vec<u8>>> = None;
131        let mut sps = Vec::new();
132        let mut pps = Vec::new();
133
134        for array in &hvcc.arrays {
135            match array.nal_unit_type {
136                32 => vps.get_or_insert_with(Vec::new).extend(array.nalus.iter().cloned()),
137                33 => sps.extend(array.nalus.iter().cloned()),
138                34 => pps.extend(array.nalus.iter().cloned()),
139                _ => {}
140            }
141        }
142
143        decoder_params.extra_data = Some(ExtraData::HEVC {
144            vps,
145            sps,
146            pps,
147            nalu_length_size: hvcc.length_size_minus_one + 1,
148        });
149
150        decoder_params
151    }
152
153    fn codec_to_params(codec: &Mp4Codec) -> Option<(CodecID, CodecParameters)> {
154        match codec {
155            #[cfg(feature = "video")]
156            Mp4Codec::Avc1(avc1) => {
157                let video_params = Self::make_video_params(&avc1.visual, avc1.colr.as_ref());
158                let decoder_params = Self::make_avc_codec_params(&avc1.avcc);
159                Some((CodecID::H264, CodecParameters::new(video_params, decoder_params)))
160            }
161            #[cfg(feature = "video")]
162            Mp4Codec::Hev1(hev1) => {
163                let video_params = Self::make_video_params(&hev1.visual, hev1.colr.as_ref());
164                let decoder_params = Self::make_hevc_codec_params(&hev1.hvcc);
165                Some((CodecID::HEVC, CodecParameters::new(video_params, decoder_params)))
166            }
167            #[cfg(feature = "video")]
168            Mp4Codec::Hvc1(hvc1) => {
169                let video_params = Self::make_video_params(&hvc1.visual, hvc1.colr.as_ref());
170                let decoder_params = Self::make_hevc_codec_params(&hvc1.hvcc);
171                Some((CodecID::HEVC, CodecParameters::new(video_params, decoder_params)))
172            }
173            #[cfg(feature = "video")]
174            Mp4Codec::Vp08(vp08) => {
175                let video_params = Self::make_video_params(&vp08.visual, vp08.colr.as_ref());
176                Some((CodecID::VP8, CodecParameters::new(video_params, DecoderParameters::default())))
177            }
178            #[cfg(feature = "video")]
179            Mp4Codec::Vp09(vp09) => {
180                let video_params = Self::make_video_params(&vp09.visual, vp09.colr.as_ref());
181                Some((CodecID::VP9, CodecParameters::new(video_params, DecoderParameters::default())))
182            }
183            #[cfg(feature = "video")]
184            Mp4Codec::Av01(av01) => {
185                let video_params = Self::make_video_params(&av01.visual, av01.colr.as_ref());
186                Some((CodecID::AV1, CodecParameters::new(video_params, DecoderParameters::default())))
187            }
188            #[cfg(feature = "audio")]
189            Mp4Codec::Mp4a(mp4a) => {
190                let audio_params = Self::make_audio_params(&mp4a.audio);
191                let decoder_params = Self::make_asc_codec_params(&mp4a.esds);
192                Some((CodecID::AAC, CodecParameters::new(audio_params, decoder_params)))
193            }
194            #[cfg(feature = "audio")]
195            Mp4Codec::Opus(opus) => {
196                let audio_params = Self::make_audio_params(&opus.audio);
197                Some((CodecID::OPUS, CodecParameters::new(audio_params, DecoderParameters::default())))
198            }
199            #[cfg(feature = "audio")]
200            Mp4Codec::Flac(flac) => {
201                let audio_params = Self::make_audio_params(&flac.audio);
202                Some((CodecID::FLAC, CodecParameters::new(audio_params, DecoderParameters::default())))
203            }
204            #[cfg(feature = "audio")]
205            Mp4Codec::Ac3(ac3) => {
206                let audio_params = Self::make_audio_params(&ac3.audio);
207                Some((CodecID::AC3, CodecParameters::new(audio_params, DecoderParameters::default())))
208            }
209            #[cfg(feature = "audio")]
210            Mp4Codec::Eac3(eac3) => {
211                let audio_params = Self::make_audio_params(&eac3.audio);
212                Some((CodecID::EAC3, CodecParameters::new(audio_params, DecoderParameters::default())))
213            }
214            _ => None,
215        }
216    }
217
218    fn find_sample_index(stbl: &Stbl, target_dts: i64) -> usize {
219        let mut accumulated_dts = 0i64;
220        let mut sample_index = 0usize;
221
222        for entry in &stbl.stts.entries {
223            let samples_in_entry = entry.sample_count as usize;
224            let entry_duration = entry.sample_count as i64 * entry.sample_delta as i64;
225
226            if accumulated_dts + entry_duration > target_dts {
227                let offset = (target_dts - accumulated_dts) / entry.sample_delta as i64;
228                sample_index += offset as usize;
229                break;
230            }
231
232            accumulated_dts += entry_duration;
233            sample_index += samples_in_entry;
234        }
235
236        // Clamp to valid range
237        let total_samples = match &stbl.stsz.samples {
238            StszSamples::Identical {
239                count, ..
240            } => *count as usize,
241            StszSamples::Different {
242                sizes,
243            } => sizes.len(),
244        };
245        sample_index.min(total_samples.saturating_sub(1))
246    }
247}
248
249impl Format for Mp4Demuxer {
250    fn set_option(&mut self, _key: &str, _value: &Variant) -> Result<()> {
251        Ok(())
252    }
253}
254
255impl Demuxer for Mp4Demuxer {
256    fn read_header(&mut self, reader: &mut dyn Reader, state: &mut DemuxerState) -> Result<()> {
257        // Read atoms until find moov
258        loop {
259            let header = match Header::read_from(reader) {
260                Ok(h) => h,
261                Err(e) => {
262                    if self.moov.is_none() {
263                        return Err(not_found_error!("moov"));
264                    }
265                    return Err(invalid_error!(e.to_string()));
266                }
267            };
268
269            match header.kind {
270                Ftyp::KIND => {
271                    let ftyp = Ftyp::read_atom(&header, reader).map_err(|e| invalid_error!(e.to_string()))?;
272                    self.ftyp = Some(ftyp);
273                }
274                Moov::KIND => {
275                    let moov = Moov::read_atom(&header, reader).map_err(|e| invalid_error!(e.to_string()))?;
276
277                    // Initialize track_sample_indices with the number of tracks
278                    self.track_sample_indices = vec![0; moov.trak.len()];
279
280                    // Create a single stream
281                    let mut stream = Stream::new(0);
282
283                    // Process each track and add to stream
284                    for trak in &moov.trak {
285                        let track_id = trak.tkhd.track_id as isize;
286                        let timescale = trak.mdia.mdhd.timescale;
287                        let time_base = Rational64::new(1, timescale as i64);
288
289                        // Get codec info from stsd
290                        if let Some(codec) = trak.mdia.minf.stbl.stsd.codecs.first() {
291                            if let Some((codec_id, params)) = Self::codec_to_params(codec) {
292                                let mut track = Track::new(track_id, codec_id, params, time_base);
293                                track.duration = Some(trak.mdia.mdhd.duration as i64);
294                                stream.add_track(state.tracks.add_track(track));
295                            }
296                        }
297                    }
298
299                    state.streams.add_stream(stream);
300
301                    let timescale = moov.mvhd.timescale as i64;
302                    let duration = moov.mvhd.duration as i64;
303                    if timescale > 0 && duration > 0 {
304                        state.duration = Some(duration * USEC_PER_SEC / timescale);
305                    }
306
307                    self.moov = Some(moov);
308
309                    return Ok(());
310                }
311                Mdat::KIND => {
312                    // Skip mdat atom, read data later
313                    let skip_size = header.size.unwrap_or(0) as i64;
314                    reader.seek(SeekFrom::Current(skip_size))?;
315                }
316                _ => {
317                    // Skip unknown atoms
318                    if let Some(size) = header.size {
319                        reader.seek(SeekFrom::Current(size as i64))?;
320                    }
321                }
322            }
323        }
324    }
325
326    fn read_packet(&mut self, reader: &mut dyn Reader, state: &DemuxerState) -> Result<Packet<'static>> {
327        let moov = self.moov.as_ref().ok_or_else(|| not_found_error!("moov"))?;
328
329        // Find the track with the earliest next sample
330        let mut earliest_track_idx: Option<usize> = None;
331        let mut earliest_dts_us = i64::MAX;
332        let mut earliest_dts_raw = 0i64; // DTS in track's native timescale
333
334        for (track_idx, trak) in moov.trak.iter().enumerate() {
335            let sample_index = self.track_sample_indices[track_idx];
336
337            // Check if this track has more samples
338            let stts = &trak.mdia.minf.stbl.stts;
339            let mut total_samples = 0u32;
340            for entry in &stts.entries {
341                total_samples += entry.sample_count;
342            }
343
344            if sample_index >= total_samples as usize {
345                continue; // This track is exhausted
346            }
347
348            // Calculate DTS for this sample (in track's native timescale)
349            let mut dts = 0i64;
350            let mut accumulated_samples = 0usize;
351            for entry in &stts.entries {
352                if accumulated_samples + entry.sample_count as usize > sample_index {
353                    dts += (sample_index - accumulated_samples) as i64 * entry.sample_delta as i64;
354                    break;
355                }
356                dts += entry.sample_count as i64 * entry.sample_delta as i64;
357                accumulated_samples += entry.sample_count as usize;
358            }
359
360            // Convert DTS to microseconds for cross-track comparison
361            let timescale = trak.mdia.mdhd.timescale as i64;
362            let dts_us = dts * USEC_PER_SEC / timescale;
363
364            if dts_us < earliest_dts_us {
365                earliest_dts_us = dts_us;
366                earliest_dts_raw = dts;
367                earliest_track_idx = Some(track_idx);
368            }
369        }
370
371        let track_idx = earliest_track_idx.ok_or_else(|| not_found_error!("no more samples"))?;
372
373        // Find the corresponding trak
374        let trak = &moov.trak[track_idx];
375        let track_id = trak.tkhd.track_id;
376
377        let track = state.tracks.find_track(track_id as isize).ok_or_else(|| not_found_error!("track", track_id))?;
378
379        let sample_index = self.track_sample_indices[track_idx];
380        let stbl = &trak.mdia.minf.stbl;
381
382        // Calculate sample duration from stts
383        let mut duration = 0i64;
384        let mut accumulated_samples = 0usize;
385        for entry in &stbl.stts.entries {
386            if accumulated_samples + entry.sample_count as usize > sample_index {
387                duration = entry.sample_delta as i64;
388                break;
389            }
390            accumulated_samples += entry.sample_count as usize;
391        }
392
393        // Calculate PTS offset from ctts (Composition Time to Sample)
394        let pts_offset = if let Some(ref ctts) = stbl.ctts {
395            let mut accumulated_samples = 0usize;
396            let mut offset = 0i32;
397            for entry in &ctts.entries {
398                if accumulated_samples + entry.sample_count as usize > sample_index {
399                    offset = entry.sample_offset;
400                    break;
401                }
402                accumulated_samples += entry.sample_count as usize;
403            }
404            offset as i64
405        } else {
406            0i64
407        };
408
409        let sample_size = match &stbl.stsz.samples {
410            StszSamples::Identical {
411                size, ..
412            } => *size as usize,
413            StszSamples::Different {
414                sizes,
415            } => *sizes.get(sample_index).ok_or_else(|| not_found_error!("sample size"))? as usize,
416        };
417
418        // Get chunk and offset
419        let mut chunk_index = 0usize;
420        let mut sample_in_chunk = sample_index;
421
422        for (i, entry) in stbl.stsc.entries.iter().enumerate() {
423            let next_first_chunk = stbl.stsc.entries.get(i + 1).map(|e| e.first_chunk).unwrap_or(u32::MAX);
424
425            let chunks_in_this_group = next_first_chunk - entry.first_chunk;
426            let samples_per_chunk = entry.samples_per_chunk as usize;
427            let samples_in_this_group = chunks_in_this_group as usize * samples_per_chunk;
428
429            if sample_in_chunk < samples_in_this_group {
430                chunk_index = (entry.first_chunk - 1) as usize + sample_in_chunk / samples_per_chunk;
431                sample_in_chunk %= samples_per_chunk;
432                break;
433            }
434            sample_in_chunk -= samples_in_this_group;
435        }
436
437        let chunk_offset = if let Some(ref stco) = stbl.stco {
438            *stco.entries.get(chunk_index).ok_or_else(|| not_found_error!("chunk offset"))? as u64
439        } else if let Some(ref co64) = stbl.co64 {
440            *co64.entries.get(chunk_index).ok_or_else(|| not_found_error!("chunk offset"))?
441        } else {
442            return Err(not_found_error!("chunk offset"));
443        };
444
445        // Calculate sample offset within chunk
446        let mut sample_offset = chunk_offset;
447        for i in 0..sample_in_chunk {
448            let prev_sample_idx = sample_index - sample_in_chunk + i;
449            let prev_size = match &stbl.stsz.samples {
450                StszSamples::Identical {
451                    size, ..
452                } => *size as u64,
453                StszSamples::Different {
454                    sizes,
455                } => *sizes.get(prev_sample_idx).ok_or_else(|| not_found_error!("sample size"))? as u64,
456            };
457            sample_offset += prev_size;
458        }
459
460        let mut packet = Packet::from_buffer(track.pool.get_buffer_with_length(sample_size));
461        let buffer = packet.data_mut().ok_or_else(|| invalid_error!("packet buffer is not mutable"))?;
462
463        reader.seek(SeekFrom::Start(sample_offset))?;
464        reader.read_exact(buffer)?;
465
466        let timescale = trak.mdia.mdhd.timescale;
467        let time_base = Rational64::new(1, timescale as i64);
468
469        packet.track_index = Some(track.index());
470        packet.dts = Some(earliest_dts_raw);
471        packet.pts = Some(earliest_dts_raw + pts_offset);
472        packet.duration = Some(duration);
473        packet.time_base = Some(time_base);
474
475        // Check if this is a keyframe (sync sample)
476        packet.flags = if stbl.stss.is_some() {
477            let key = stbl.stss.as_ref().map(|stss| stss.entries.contains(&((sample_index + 1) as u32))).unwrap_or(false);
478
479            if key {
480                PacketFlags::Key
481            } else {
482                PacketFlags::empty()
483            }
484        } else {
485            PacketFlags::Key // If no stss, all samples are keyframes
486        };
487
488        // Update sample index
489        self.track_sample_indices[track_idx] = sample_index + 1;
490
491        Ok(packet)
492    }
493
494    fn seek(
495        &mut self,
496        _reader: &mut dyn Reader,
497        state: &DemuxerState,
498        track_index: Option<usize>,
499        timestamp_us: i64,
500        flags: SeekFlags,
501    ) -> Result<()> {
502        let moov = self.moov.as_ref().ok_or_else(|| not_found_error!("moov"))?;
503
504        // Determine the target track index
505        let track_index = track_index.unwrap_or_else(|| {
506            // Find the first video track, or fall back to the first track
507            state.tracks.into_iter().find(|t| t.media_type() == MediaType::Video).map(|t| t.index()).unwrap_or(0)
508        });
509
510        let target_trak = moov.trak.get(track_index).ok_or_else(|| not_found_error!("track at index {}", track_index))?;
511        let target_timescale = target_trak.mdia.mdhd.timescale;
512        let target_stbl = &target_trak.mdia.minf.stbl;
513
514        // Convert timestamp (in microseconds) to target track's timescale
515        let track_target_dts = timestamp_us * target_timescale as i64 / USEC_PER_SEC;
516
517        let mut target_sample_index = Self::find_sample_index(target_stbl, track_target_dts);
518
519        // Apply keyframe seeking (skip if ANY flag is set)
520        if !flags.contains(SeekFlags::ANY) {
521            if let Some(ref stss) = target_stbl.stss {
522                let target_sample_number = (target_sample_index + 1) as u32;
523
524                let keyframe_sample = if flags.contains(SeekFlags::BACKWARD) {
525                    // Find the largest sync sample that is <= target
526                    match stss.entries.partition_point(|s| *s <= target_sample_number) {
527                        0 => 1,
528                        i => stss.entries[i - 1],
529                    }
530                } else {
531                    // Find the nearest keyframe (before or after)
532                    let pos = stss.entries.partition_point(|s| *s < target_sample_number);
533                    let candidates = [pos.checked_sub(1).and_then(|i| stss.entries.get(i)), stss.entries.get(pos)];
534                    candidates.into_iter().flatten().min_by_key(|s| s.abs_diff(target_sample_number)).copied().unwrap_or(1)
535                };
536
537                target_sample_index = (keyframe_sample - 1) as usize;
538            }
539        }
540        // Keep the original target_sample_index (may be non-keyframe)
541
542        // Calculate the actual DTS of the selected keyframe
543        let mut actual_dts = 0i64;
544        let mut accumulated_samples = 0usize;
545        for entry in &target_stbl.stts.entries {
546            if accumulated_samples + entry.sample_count as usize > target_sample_index {
547                actual_dts += (target_sample_index - accumulated_samples) as i64 * entry.sample_delta as i64;
548                break;
549            }
550            actual_dts += entry.sample_count as i64 * entry.sample_delta as i64;
551            accumulated_samples += entry.sample_count as usize;
552        }
553
554        // Synchronize all tracks
555        for (trak_idx, trak) in moov.trak.iter().enumerate() {
556            let sample_index = if trak_idx == track_index {
557                // Target track: use keyframe-aligned position
558                target_sample_index
559            } else {
560                // Other tracks: find sample at the actual timestamp
561                let timescale = trak.mdia.mdhd.timescale;
562                let track_dts = actual_dts * timescale as i64 / target_timescale as i64;
563                Self::find_sample_index(&trak.mdia.minf.stbl, track_dts)
564            };
565
566            self.track_sample_indices[trak_idx] = sample_index;
567        }
568
569        Ok(())
570    }
571}
572
573/// Builder for MP4 demuxer
574pub struct Mp4DemuxerBuilder;
575
576impl FormatBuilder for Mp4DemuxerBuilder {
577    fn name(&self) -> &'static str {
578        "mp4"
579    }
580
581    fn extensions(&self) -> &[&'static str] {
582        &["mp4", "mov", "m4v", "m4a"]
583    }
584}
585
586impl DemuxerBuilder for Mp4DemuxerBuilder {
587    fn new_demuxer(&self) -> Result<Box<dyn Demuxer>> {
588        Ok(Box::new(Mp4Demuxer::new()))
589    }
590
591    fn probe(&self, reader: &mut dyn Reader) -> bool {
592        let mut buf = [0u8; 8];
593        reader.read_exact(&mut buf).ok();
594
595        matches!(&buf[4..8], b"ftyp" | b"moov" | b"mdat")
596    }
597}