1use std::{io::SeekFrom, num::NonZeroU32};
4
5#[cfg(feature = "audio")]
6use media_codec_types::AudioParameters;
7use media_codec_types::{
8 decoder::DecoderParameters,
9 packet::{Packet, PacketFlags},
10 CodecID, CodecParameters,
11};
12#[cfg(feature = "video")]
13use media_codec_types::{decoder::ExtraData, VideoParameters};
14#[cfg(feature = "audio")]
15use media_core::audio::ChannelLayout;
16#[cfg(feature = "video")]
17use media_core::video::ColorRange;
18use media_core::{invalid_error, not_found_error, rational::Rational64, time::USEC_PER_SEC, variant::Variant, MediaType, Result};
19use media_format_types::{
20 demuxer::{Demuxer, DemuxerBuilder, DemuxerState, Reader, SeekFlags},
21 stream::Stream,
22 track::Track,
23 Format, FormatBuilder,
24};
25use mp4_atom::{Atom, Codec as Mp4Codec, Ftyp, Header, Mdat, Moov, ReadAtom, ReadFrom, Stbl, StszSamples};
26#[cfg(feature = "audio")]
27use mp4_atom::{Audio, Esds};
28#[cfg(feature = "video")]
29use mp4_atom::{Avcc, Colr, Hvcc, Visual};
30
/// MP4 demuxer state.
///
/// The atoms parsed by `read_header` are kept so that later calls can locate
/// samples through the sample tables stored inside `moov`.
pub struct Mp4Demuxer {
    /// The `ftyp` atom, if one was encountered while reading the header.
    pub ftyp: Option<Ftyp>,
    /// The `moov` atom; `None` until `read_header` has parsed it.
    pub moov: Option<Moov>,
    /// Zero-based index of the next sample to read for each entry of
    /// `moov.trak` (same order, one slot per `trak`).
    track_sample_indices: Vec<usize>,
}
40
41impl Default for Mp4Demuxer {
42 fn default() -> Self {
43 Self::new()
44 }
45}
46
47impl Mp4Demuxer {
48 pub fn new() -> Self {
49 Self {
50 ftyp: None,
51 moov: None,
52 track_sample_indices: Vec::new(),
53 }
54 }
55
56 #[cfg(feature = "video")]
57 fn make_video_params(visual: &Visual, colr: Option<&Colr>) -> VideoParameters {
58 let mut video_params = VideoParameters {
59 width: NonZeroU32::new(visual.width as u32),
60 height: NonZeroU32::new(visual.height as u32),
61 ..Default::default()
62 };
63
64 let Some(colr) = colr else { return video_params };
65
66 let (primaries, transfer, matrix, range) = match colr {
67 Colr::Nclx {
68 colour_primaries,
69 transfer_characteristics,
70 matrix_coefficients,
71 full_range_flag,
72 } => (
73 *colour_primaries,
74 *transfer_characteristics,
75 *matrix_coefficients,
76 Some(if *full_range_flag {
77 ColorRange::Full
78 } else {
79 ColorRange::Video
80 }),
81 ),
82 _ => return video_params,
83 };
84
85 video_params.color_primaries = (primaries as usize).try_into().ok();
86 video_params.color_transfer_characteristics = (transfer as usize).try_into().ok();
87 video_params.color_matrix = (matrix as usize).try_into().ok();
88 video_params.color_range = range;
89
90 video_params
91 }
92
93 #[cfg(feature = "audio")]
94 fn make_audio_params(audio: &Audio) -> AudioParameters {
95 AudioParameters {
96 sample_rate: NonZeroU32::new(audio.sample_rate.integer() as u32),
97 channel_layout: ChannelLayout::default_from_channels(audio.channel_count as u8).ok(),
98 ..Default::default()
99 }
100 }
101
102 #[cfg(feature = "audio")]
103 fn make_asc_codec_params(esds: &Esds) -> DecoderParameters {
104 let asc = &esds.es_desc.dec_config.dec_specific;
105 DecoderParameters {
106 extra_data: Some(ExtraData::ASC {
107 object_type: asc.profile,
108 channel_config: asc.chan_conf,
109 }),
110 ..Default::default()
111 }
112 }
113
114 #[cfg(feature = "video")]
115 fn make_avc_codec_params(avc: &Avcc) -> DecoderParameters {
116 DecoderParameters {
117 extra_data: Some(ExtraData::AVC {
118 sps: avc.sequence_parameter_sets.clone(),
119 pps: avc.picture_parameter_sets.clone(),
120 nalu_length_size: avc.length_size,
121 }),
122 ..Default::default()
123 }
124 }
125
126 #[cfg(feature = "video")]
127 fn make_hevc_codec_params(hvcc: &Hvcc) -> DecoderParameters {
128 let mut decoder_params = DecoderParameters::default();
129
130 let mut vps: Option<Vec<Vec<u8>>> = None;
131 let mut sps = Vec::new();
132 let mut pps = Vec::new();
133
134 for array in &hvcc.arrays {
135 match array.nal_unit_type {
136 32 => vps.get_or_insert_with(Vec::new).extend(array.nalus.iter().cloned()),
137 33 => sps.extend(array.nalus.iter().cloned()),
138 34 => pps.extend(array.nalus.iter().cloned()),
139 _ => {}
140 }
141 }
142
143 decoder_params.extra_data = Some(ExtraData::HEVC {
144 vps,
145 sps,
146 pps,
147 nalu_length_size: hvcc.length_size_minus_one + 1,
148 });
149
150 decoder_params
151 }
152
153 fn codec_to_params(codec: &Mp4Codec) -> Option<(CodecID, CodecParameters)> {
154 match codec {
155 #[cfg(feature = "video")]
156 Mp4Codec::Avc1(avc1) => {
157 let video_params = Self::make_video_params(&avc1.visual, avc1.colr.as_ref());
158 let decoder_params = Self::make_avc_codec_params(&avc1.avcc);
159 Some((CodecID::H264, CodecParameters::new(video_params, decoder_params)))
160 }
161 #[cfg(feature = "video")]
162 Mp4Codec::Hev1(hev1) => {
163 let video_params = Self::make_video_params(&hev1.visual, hev1.colr.as_ref());
164 let decoder_params = Self::make_hevc_codec_params(&hev1.hvcc);
165 Some((CodecID::HEVC, CodecParameters::new(video_params, decoder_params)))
166 }
167 #[cfg(feature = "video")]
168 Mp4Codec::Hvc1(hvc1) => {
169 let video_params = Self::make_video_params(&hvc1.visual, hvc1.colr.as_ref());
170 let decoder_params = Self::make_hevc_codec_params(&hvc1.hvcc);
171 Some((CodecID::HEVC, CodecParameters::new(video_params, decoder_params)))
172 }
173 #[cfg(feature = "video")]
174 Mp4Codec::Vp08(vp08) => {
175 let video_params = Self::make_video_params(&vp08.visual, vp08.colr.as_ref());
176 Some((CodecID::VP8, CodecParameters::new(video_params, DecoderParameters::default())))
177 }
178 #[cfg(feature = "video")]
179 Mp4Codec::Vp09(vp09) => {
180 let video_params = Self::make_video_params(&vp09.visual, vp09.colr.as_ref());
181 Some((CodecID::VP9, CodecParameters::new(video_params, DecoderParameters::default())))
182 }
183 #[cfg(feature = "video")]
184 Mp4Codec::Av01(av01) => {
185 let video_params = Self::make_video_params(&av01.visual, av01.colr.as_ref());
186 Some((CodecID::AV1, CodecParameters::new(video_params, DecoderParameters::default())))
187 }
188 #[cfg(feature = "audio")]
189 Mp4Codec::Mp4a(mp4a) => {
190 let audio_params = Self::make_audio_params(&mp4a.audio);
191 let decoder_params = Self::make_asc_codec_params(&mp4a.esds);
192 Some((CodecID::AAC, CodecParameters::new(audio_params, decoder_params)))
193 }
194 #[cfg(feature = "audio")]
195 Mp4Codec::Opus(opus) => {
196 let audio_params = Self::make_audio_params(&opus.audio);
197 Some((CodecID::OPUS, CodecParameters::new(audio_params, DecoderParameters::default())))
198 }
199 #[cfg(feature = "audio")]
200 Mp4Codec::Flac(flac) => {
201 let audio_params = Self::make_audio_params(&flac.audio);
202 Some((CodecID::FLAC, CodecParameters::new(audio_params, DecoderParameters::default())))
203 }
204 #[cfg(feature = "audio")]
205 Mp4Codec::Ac3(ac3) => {
206 let audio_params = Self::make_audio_params(&ac3.audio);
207 Some((CodecID::AC3, CodecParameters::new(audio_params, DecoderParameters::default())))
208 }
209 #[cfg(feature = "audio")]
210 Mp4Codec::Eac3(eac3) => {
211 let audio_params = Self::make_audio_params(&eac3.audio);
212 Some((CodecID::EAC3, CodecParameters::new(audio_params, DecoderParameters::default())))
213 }
214 _ => None,
215 }
216 }
217
218 fn find_sample_index(stbl: &Stbl, target_dts: i64) -> usize {
219 let mut accumulated_dts = 0i64;
220 let mut sample_index = 0usize;
221
222 for entry in &stbl.stts.entries {
223 let samples_in_entry = entry.sample_count as usize;
224 let entry_duration = entry.sample_count as i64 * entry.sample_delta as i64;
225
226 if accumulated_dts + entry_duration > target_dts {
227 let offset = (target_dts - accumulated_dts) / entry.sample_delta as i64;
228 sample_index += offset as usize;
229 break;
230 }
231
232 accumulated_dts += entry_duration;
233 sample_index += samples_in_entry;
234 }
235
236 let total_samples = match &stbl.stsz.samples {
238 StszSamples::Identical {
239 count, ..
240 } => *count as usize,
241 StszSamples::Different {
242 sizes,
243 } => sizes.len(),
244 };
245 sample_index.min(total_samples.saturating_sub(1))
246 }
247}
248
impl Format for Mp4Demuxer {
    /// Accepts any option and ignores it; always returns `Ok(())`.
    fn set_option(&mut self, _key: &str, _value: &Variant) -> Result<()> {
        Ok(())
    }
}
254
255impl Demuxer for Mp4Demuxer {
256 fn read_header(&mut self, reader: &mut dyn Reader, state: &mut DemuxerState) -> Result<()> {
257 loop {
259 let header = match Header::read_from(reader) {
260 Ok(h) => h,
261 Err(e) => {
262 if self.moov.is_none() {
263 return Err(not_found_error!("moov"));
264 }
265 return Err(invalid_error!(e.to_string()));
266 }
267 };
268
269 match header.kind {
270 Ftyp::KIND => {
271 let ftyp = Ftyp::read_atom(&header, reader).map_err(|e| invalid_error!(e.to_string()))?;
272 self.ftyp = Some(ftyp);
273 }
274 Moov::KIND => {
275 let moov = Moov::read_atom(&header, reader).map_err(|e| invalid_error!(e.to_string()))?;
276
277 self.track_sample_indices = vec![0; moov.trak.len()];
279
280 let mut stream = Stream::new(0);
282
283 for trak in &moov.trak {
285 let track_id = trak.tkhd.track_id as isize;
286 let timescale = trak.mdia.mdhd.timescale;
287 let time_base = Rational64::new(1, timescale as i64);
288
289 if let Some(codec) = trak.mdia.minf.stbl.stsd.codecs.first() {
291 if let Some((codec_id, params)) = Self::codec_to_params(codec) {
292 let mut track = Track::new(track_id, codec_id, params, time_base);
293 track.duration = Some(trak.mdia.mdhd.duration as i64);
294 stream.add_track(state.tracks.add_track(track));
295 }
296 }
297 }
298
299 state.streams.add_stream(stream);
300
301 let timescale = moov.mvhd.timescale as i64;
302 let duration = moov.mvhd.duration as i64;
303 if timescale > 0 && duration > 0 {
304 state.duration = Some(duration * USEC_PER_SEC / timescale);
305 }
306
307 self.moov = Some(moov);
308
309 return Ok(());
310 }
311 Mdat::KIND => {
312 let skip_size = header.size.unwrap_or(0) as i64;
314 reader.seek(SeekFrom::Current(skip_size))?;
315 }
316 _ => {
317 if let Some(size) = header.size {
319 reader.seek(SeekFrom::Current(size as i64))?;
320 }
321 }
322 }
323 }
324 }
325
326 fn read_packet(&mut self, reader: &mut dyn Reader, state: &DemuxerState) -> Result<Packet<'static>> {
327 let moov = self.moov.as_ref().ok_or_else(|| not_found_error!("moov"))?;
328
329 let mut earliest_track_idx: Option<usize> = None;
331 let mut earliest_dts_us = i64::MAX;
332 let mut earliest_dts_raw = 0i64; for (track_idx, trak) in moov.trak.iter().enumerate() {
335 let sample_index = self.track_sample_indices[track_idx];
336
337 let stts = &trak.mdia.minf.stbl.stts;
339 let mut total_samples = 0u32;
340 for entry in &stts.entries {
341 total_samples += entry.sample_count;
342 }
343
344 if sample_index >= total_samples as usize {
345 continue; }
347
348 let mut dts = 0i64;
350 let mut accumulated_samples = 0usize;
351 for entry in &stts.entries {
352 if accumulated_samples + entry.sample_count as usize > sample_index {
353 dts += (sample_index - accumulated_samples) as i64 * entry.sample_delta as i64;
354 break;
355 }
356 dts += entry.sample_count as i64 * entry.sample_delta as i64;
357 accumulated_samples += entry.sample_count as usize;
358 }
359
360 let timescale = trak.mdia.mdhd.timescale as i64;
362 let dts_us = dts * USEC_PER_SEC / timescale;
363
364 if dts_us < earliest_dts_us {
365 earliest_dts_us = dts_us;
366 earliest_dts_raw = dts;
367 earliest_track_idx = Some(track_idx);
368 }
369 }
370
371 let track_idx = earliest_track_idx.ok_or_else(|| not_found_error!("no more samples"))?;
372
373 let trak = &moov.trak[track_idx];
375 let track_id = trak.tkhd.track_id;
376
377 let track = state.tracks.find_track(track_id as isize).ok_or_else(|| not_found_error!("track", track_id))?;
378
379 let sample_index = self.track_sample_indices[track_idx];
380 let stbl = &trak.mdia.minf.stbl;
381
382 let mut duration = 0i64;
384 let mut accumulated_samples = 0usize;
385 for entry in &stbl.stts.entries {
386 if accumulated_samples + entry.sample_count as usize > sample_index {
387 duration = entry.sample_delta as i64;
388 break;
389 }
390 accumulated_samples += entry.sample_count as usize;
391 }
392
393 let pts_offset = if let Some(ref ctts) = stbl.ctts {
395 let mut accumulated_samples = 0usize;
396 let mut offset = 0i32;
397 for entry in &ctts.entries {
398 if accumulated_samples + entry.sample_count as usize > sample_index {
399 offset = entry.sample_offset;
400 break;
401 }
402 accumulated_samples += entry.sample_count as usize;
403 }
404 offset as i64
405 } else {
406 0i64
407 };
408
409 let sample_size = match &stbl.stsz.samples {
410 StszSamples::Identical {
411 size, ..
412 } => *size as usize,
413 StszSamples::Different {
414 sizes,
415 } => *sizes.get(sample_index).ok_or_else(|| not_found_error!("sample size"))? as usize,
416 };
417
418 let mut chunk_index = 0usize;
420 let mut sample_in_chunk = sample_index;
421
422 for (i, entry) in stbl.stsc.entries.iter().enumerate() {
423 let next_first_chunk = stbl.stsc.entries.get(i + 1).map(|e| e.first_chunk).unwrap_or(u32::MAX);
424
425 let chunks_in_this_group = next_first_chunk - entry.first_chunk;
426 let samples_per_chunk = entry.samples_per_chunk as usize;
427 let samples_in_this_group = chunks_in_this_group as usize * samples_per_chunk;
428
429 if sample_in_chunk < samples_in_this_group {
430 chunk_index = (entry.first_chunk - 1) as usize + sample_in_chunk / samples_per_chunk;
431 sample_in_chunk %= samples_per_chunk;
432 break;
433 }
434 sample_in_chunk -= samples_in_this_group;
435 }
436
437 let chunk_offset = if let Some(ref stco) = stbl.stco {
438 *stco.entries.get(chunk_index).ok_or_else(|| not_found_error!("chunk offset"))? as u64
439 } else if let Some(ref co64) = stbl.co64 {
440 *co64.entries.get(chunk_index).ok_or_else(|| not_found_error!("chunk offset"))?
441 } else {
442 return Err(not_found_error!("chunk offset"));
443 };
444
445 let mut sample_offset = chunk_offset;
447 for i in 0..sample_in_chunk {
448 let prev_sample_idx = sample_index - sample_in_chunk + i;
449 let prev_size = match &stbl.stsz.samples {
450 StszSamples::Identical {
451 size, ..
452 } => *size as u64,
453 StszSamples::Different {
454 sizes,
455 } => *sizes.get(prev_sample_idx).ok_or_else(|| not_found_error!("sample size"))? as u64,
456 };
457 sample_offset += prev_size;
458 }
459
460 let mut packet = Packet::from_buffer(track.pool.get_buffer_with_length(sample_size));
461 let buffer = packet.data_mut().ok_or_else(|| invalid_error!("packet buffer is not mutable"))?;
462
463 reader.seek(SeekFrom::Start(sample_offset))?;
464 reader.read_exact(buffer)?;
465
466 let timescale = trak.mdia.mdhd.timescale;
467 let time_base = Rational64::new(1, timescale as i64);
468
469 packet.track_index = Some(track.index());
470 packet.dts = Some(earliest_dts_raw);
471 packet.pts = Some(earliest_dts_raw + pts_offset);
472 packet.duration = Some(duration);
473 packet.time_base = Some(time_base);
474
475 packet.flags = if stbl.stss.is_some() {
477 let key = stbl.stss.as_ref().map(|stss| stss.entries.contains(&((sample_index + 1) as u32))).unwrap_or(false);
478
479 if key {
480 PacketFlags::Key
481 } else {
482 PacketFlags::empty()
483 }
484 } else {
485 PacketFlags::Key };
487
488 self.track_sample_indices[track_idx] = sample_index + 1;
490
491 Ok(packet)
492 }
493
494 fn seek(
495 &mut self,
496 _reader: &mut dyn Reader,
497 state: &DemuxerState,
498 track_index: Option<usize>,
499 timestamp_us: i64,
500 flags: SeekFlags,
501 ) -> Result<()> {
502 let moov = self.moov.as_ref().ok_or_else(|| not_found_error!("moov"))?;
503
504 let track_index = track_index.unwrap_or_else(|| {
506 state.tracks.into_iter().find(|t| t.media_type() == MediaType::Video).map(|t| t.index()).unwrap_or(0)
508 });
509
510 let target_trak = moov.trak.get(track_index).ok_or_else(|| not_found_error!("track at index {}", track_index))?;
511 let target_timescale = target_trak.mdia.mdhd.timescale;
512 let target_stbl = &target_trak.mdia.minf.stbl;
513
514 let track_target_dts = timestamp_us * target_timescale as i64 / USEC_PER_SEC;
516
517 let mut target_sample_index = Self::find_sample_index(target_stbl, track_target_dts);
518
519 if !flags.contains(SeekFlags::ANY) {
521 if let Some(ref stss) = target_stbl.stss {
522 let target_sample_number = (target_sample_index + 1) as u32;
523
524 let keyframe_sample = if flags.contains(SeekFlags::BACKWARD) {
525 match stss.entries.partition_point(|s| *s <= target_sample_number) {
527 0 => 1,
528 i => stss.entries[i - 1],
529 }
530 } else {
531 let pos = stss.entries.partition_point(|s| *s < target_sample_number);
533 let candidates = [pos.checked_sub(1).and_then(|i| stss.entries.get(i)), stss.entries.get(pos)];
534 candidates.into_iter().flatten().min_by_key(|s| s.abs_diff(target_sample_number)).copied().unwrap_or(1)
535 };
536
537 target_sample_index = (keyframe_sample - 1) as usize;
538 }
539 }
540 let mut actual_dts = 0i64;
544 let mut accumulated_samples = 0usize;
545 for entry in &target_stbl.stts.entries {
546 if accumulated_samples + entry.sample_count as usize > target_sample_index {
547 actual_dts += (target_sample_index - accumulated_samples) as i64 * entry.sample_delta as i64;
548 break;
549 }
550 actual_dts += entry.sample_count as i64 * entry.sample_delta as i64;
551 accumulated_samples += entry.sample_count as usize;
552 }
553
554 for (trak_idx, trak) in moov.trak.iter().enumerate() {
556 let sample_index = if trak_idx == track_index {
557 target_sample_index
559 } else {
560 let timescale = trak.mdia.mdhd.timescale;
562 let track_dts = actual_dts * timescale as i64 / target_timescale as i64;
563 Self::find_sample_index(&trak.mdia.minf.stbl, track_dts)
564 };
565
566 self.track_sample_indices[trak_idx] = sample_index;
567 }
568
569 Ok(())
570 }
571}
572
/// Builder that describes the `mp4` format and creates [`Mp4Demuxer`] instances.
pub struct Mp4DemuxerBuilder;
575
576impl FormatBuilder for Mp4DemuxerBuilder {
577 fn name(&self) -> &'static str {
578 "mp4"
579 }
580
581 fn extensions(&self) -> &[&'static str] {
582 &["mp4", "mov", "m4v", "m4a"]
583 }
584}
585
586impl DemuxerBuilder for Mp4DemuxerBuilder {
587 fn new_demuxer(&self) -> Result<Box<dyn Demuxer>> {
588 Ok(Box::new(Mp4Demuxer::new()))
589 }
590
591 fn probe(&self, reader: &mut dyn Reader) -> bool {
592 let mut buf = [0u8; 8];
593 reader.read_exact(&mut buf).ok();
594
595 matches!(&buf[4..8], b"ftyp" | b"moov" | b"mdat")
596 }
597}