1use std::{io::SeekFrom, num::NonZeroU32};
4
5#[cfg(any(feature = "audio", feature = "video"))]
6use media_codec_types::decoder::ExtraData;
7#[cfg(feature = "audio")]
8use media_codec_types::AudioParameters;
9#[cfg(feature = "video")]
10use media_codec_types::VideoParameters;
11use media_codec_types::{
12 decoder::DecoderParameters,
13 packet::{Packet, PacketFlags},
14 CodecID, CodecParameters,
15};
16#[cfg(feature = "audio")]
17use media_core::audio::ChannelLayout;
18#[cfg(feature = "video")]
19use media_core::video::ColorRange;
20use media_core::{invalid_error, not_found_error, rational::Rational64, time::USEC_PER_SEC, variant::Variant, MediaType, Result};
21use media_format_types::{
22 demuxer::{Demuxer, DemuxerBuilder, DemuxerState, Reader, SeekFlags},
23 stream::Stream,
24 track::Track,
25 Format, FormatBuilder,
26};
27use mp4_atom::{Atom, Codec as Mp4Codec, Ftyp, Header, Mdat, Moov, ReadAtom, ReadFrom, Stbl, StszSamples};
28#[cfg(feature = "audio")]
29use mp4_atom::{Audio, Esds};
30#[cfg(feature = "video")]
31use mp4_atom::{Avcc, Colr, Hvcc, Visual};
32
/// Demuxer for ISO base media files (MP4 / QuickTime family).
///
/// The header atoms are kept in memory after `read_header`; packets are then
/// produced by looking samples up in each track's sample tables and reading
/// them from their byte offsets in the input.
pub struct Mp4Demuxer {
    /// File type atom, recorded if it appears before the `moov` atom.
    pub ftyp: Option<Ftyp>,
    /// Movie atom with every track's sample tables; `None` until the header has been read.
    pub moov: Option<Moov>,
    /// Index of the next sample to emit for each `trak`, in the same order as `moov.trak`.
    track_sample_indices: Vec<usize>,
}
42
43impl Default for Mp4Demuxer {
44 fn default() -> Self {
45 Self::new()
46 }
47}
48
49impl Mp4Demuxer {
50 pub fn new() -> Self {
51 Self {
52 ftyp: None,
53 moov: None,
54 track_sample_indices: Vec::new(),
55 }
56 }
57
58 #[cfg(feature = "video")]
59 fn make_video_params(visual: &Visual, colr: Option<&Colr>) -> VideoParameters {
60 let mut video_params = VideoParameters {
61 width: NonZeroU32::new(visual.width as u32),
62 height: NonZeroU32::new(visual.height as u32),
63 ..Default::default()
64 };
65
66 let Some(colr) = colr else { return video_params };
67
68 let (primaries, transfer, matrix, range) = match colr {
69 Colr::Nclx {
70 colour_primaries,
71 transfer_characteristics,
72 matrix_coefficients,
73 full_range_flag,
74 } => (
75 *colour_primaries,
76 *transfer_characteristics,
77 *matrix_coefficients,
78 Some(if *full_range_flag {
79 ColorRange::Full
80 } else {
81 ColorRange::Video
82 }),
83 ),
84 _ => return video_params,
85 };
86
87 video_params.color_primaries = (primaries as usize).try_into().ok();
88 video_params.color_transfer_characteristics = (transfer as usize).try_into().ok();
89 video_params.color_matrix = (matrix as usize).try_into().ok();
90 video_params.color_range = range;
91
92 video_params
93 }
94
95 #[cfg(feature = "audio")]
96 fn make_audio_params(audio: &Audio) -> AudioParameters {
97 AudioParameters {
98 sample_rate: NonZeroU32::new(audio.sample_rate.integer() as u32),
99 channel_layout: ChannelLayout::default_from_channels(audio.channel_count as u8).ok(),
100 ..Default::default()
101 }
102 }
103
104 #[cfg(feature = "audio")]
105 fn make_asc_codec_params(esds: &Esds) -> DecoderParameters {
106 let asc = &esds.es_desc.dec_config.dec_specific;
107 DecoderParameters {
108 extra_data: Some(ExtraData::ASC {
109 object_type: asc.profile,
110 channel_config: asc.chan_conf,
111 }),
112 ..Default::default()
113 }
114 }
115
116 #[cfg(feature = "video")]
117 fn make_avc_codec_params(avc: &Avcc) -> DecoderParameters {
118 DecoderParameters {
119 extra_data: Some(ExtraData::AVC {
120 sps: avc.sequence_parameter_sets.clone(),
121 pps: avc.picture_parameter_sets.clone(),
122 nalu_length_size: avc.length_size,
123 }),
124 ..Default::default()
125 }
126 }
127
128 #[cfg(feature = "video")]
129 fn make_hevc_codec_params(hvcc: &Hvcc) -> DecoderParameters {
130 let mut decoder_params = DecoderParameters::default();
131
132 let mut vps: Option<Vec<Vec<u8>>> = None;
133 let mut sps = Vec::new();
134 let mut pps = Vec::new();
135
136 for array in &hvcc.arrays {
137 match array.nal_unit_type {
138 32 => vps.get_or_insert_with(Vec::new).extend(array.nalus.iter().cloned()),
139 33 => sps.extend(array.nalus.iter().cloned()),
140 34 => pps.extend(array.nalus.iter().cloned()),
141 _ => {}
142 }
143 }
144
145 decoder_params.extra_data = Some(ExtraData::HEVC {
146 vps,
147 sps,
148 pps,
149 nalu_length_size: hvcc.length_size_minus_one + 1,
150 });
151
152 decoder_params
153 }
154
155 fn codec_to_params(codec: &Mp4Codec) -> Option<(CodecID, CodecParameters)> {
156 match codec {
157 #[cfg(feature = "video")]
158 Mp4Codec::Avc1(avc1) => {
159 let video_params = Self::make_video_params(&avc1.visual, avc1.colr.as_ref());
160 let decoder_params = Self::make_avc_codec_params(&avc1.avcc);
161 Some((CodecID::H264, CodecParameters::new(video_params, decoder_params)))
162 }
163 #[cfg(feature = "video")]
164 Mp4Codec::Hev1(hev1) => {
165 let video_params = Self::make_video_params(&hev1.visual, hev1.colr.as_ref());
166 let decoder_params = Self::make_hevc_codec_params(&hev1.hvcc);
167 Some((CodecID::HEVC, CodecParameters::new(video_params, decoder_params)))
168 }
169 #[cfg(feature = "video")]
170 Mp4Codec::Hvc1(hvc1) => {
171 let video_params = Self::make_video_params(&hvc1.visual, hvc1.colr.as_ref());
172 let decoder_params = Self::make_hevc_codec_params(&hvc1.hvcc);
173 Some((CodecID::HEVC, CodecParameters::new(video_params, decoder_params)))
174 }
175 #[cfg(feature = "video")]
176 Mp4Codec::Vp08(vp08) => {
177 let video_params = Self::make_video_params(&vp08.visual, vp08.colr.as_ref());
178 Some((CodecID::VP8, CodecParameters::new(video_params, DecoderParameters::default())))
179 }
180 #[cfg(feature = "video")]
181 Mp4Codec::Vp09(vp09) => {
182 let video_params = Self::make_video_params(&vp09.visual, vp09.colr.as_ref());
183 Some((CodecID::VP9, CodecParameters::new(video_params, DecoderParameters::default())))
184 }
185 #[cfg(feature = "video")]
186 Mp4Codec::Av01(av01) => {
187 let video_params = Self::make_video_params(&av01.visual, av01.colr.as_ref());
188 Some((CodecID::AV1, CodecParameters::new(video_params, DecoderParameters::default())))
189 }
190 #[cfg(feature = "audio")]
191 Mp4Codec::Mp4a(mp4a) => {
192 let audio_params = Self::make_audio_params(&mp4a.audio);
193 let decoder_params = Self::make_asc_codec_params(&mp4a.esds);
194 Some((CodecID::AAC, CodecParameters::new(audio_params, decoder_params)))
195 }
196 #[cfg(feature = "audio")]
197 Mp4Codec::Opus(opus) => {
198 let audio_params = Self::make_audio_params(&opus.audio);
199 Some((CodecID::OPUS, CodecParameters::new(audio_params, DecoderParameters::default())))
200 }
201 #[cfg(feature = "audio")]
202 Mp4Codec::Flac(flac) => {
203 let audio_params = Self::make_audio_params(&flac.audio);
204 Some((CodecID::FLAC, CodecParameters::new(audio_params, DecoderParameters::default())))
205 }
206 #[cfg(feature = "audio")]
207 Mp4Codec::Ac3(ac3) => {
208 let audio_params = Self::make_audio_params(&ac3.audio);
209 Some((CodecID::AC3, CodecParameters::new(audio_params, DecoderParameters::default())))
210 }
211 #[cfg(feature = "audio")]
212 Mp4Codec::Eac3(eac3) => {
213 let audio_params = Self::make_audio_params(&eac3.audio);
214 Some((CodecID::EAC3, CodecParameters::new(audio_params, DecoderParameters::default())))
215 }
216 _ => None,
217 }
218 }
219
220 fn find_sample_index(stbl: &Stbl, target_dts: i64) -> usize {
221 let mut accumulated_dts = 0i64;
222 let mut sample_index = 0usize;
223
224 for entry in &stbl.stts.entries {
225 let samples_in_entry = entry.sample_count as usize;
226 let entry_duration = entry.sample_count as i64 * entry.sample_delta as i64;
227
228 if accumulated_dts + entry_duration > target_dts {
229 let offset = (target_dts - accumulated_dts) / entry.sample_delta as i64;
230 sample_index += offset as usize;
231 break;
232 }
233
234 accumulated_dts += entry_duration;
235 sample_index += samples_in_entry;
236 }
237
238 let total_samples = match &stbl.stsz.samples {
240 StszSamples::Identical {
241 count, ..
242 } => *count as usize,
243 StszSamples::Different {
244 sizes,
245 } => sizes.len(),
246 };
247 sample_index.min(total_samples.saturating_sub(1))
248 }
249}
250
251impl Format for Mp4Demuxer {
252 fn set_option(&mut self, _key: &str, _value: &Variant) -> Result<()> {
253 Ok(())
254 }
255}
256
257impl Demuxer for Mp4Demuxer {
258 fn read_header(&mut self, reader: &mut dyn Reader, state: &mut DemuxerState) -> Result<()> {
259 loop {
261 let header = match Header::read_from(reader) {
262 Ok(h) => h,
263 Err(e) => {
264 if self.moov.is_none() {
265 return Err(not_found_error!("moov"));
266 }
267 return Err(invalid_error!(e.to_string()));
268 }
269 };
270
271 match header.kind {
272 Ftyp::KIND => {
273 let ftyp = Ftyp::read_atom(&header, reader).map_err(|e| invalid_error!(e.to_string()))?;
274 self.ftyp = Some(ftyp);
275 }
276 Moov::KIND => {
277 let moov = Moov::read_atom(&header, reader).map_err(|e| invalid_error!(e.to_string()))?;
278
279 self.track_sample_indices = vec![0; moov.trak.len()];
281
282 let mut stream = Stream::new(0);
284
285 for trak in &moov.trak {
287 let track_id = trak.tkhd.track_id as isize;
288 let timescale = trak.mdia.mdhd.timescale;
289 let time_base = Rational64::new(1, timescale as i64);
290
291 if let Some(codec) = trak.mdia.minf.stbl.stsd.codecs.first() {
293 if let Some((codec_id, params)) = Self::codec_to_params(codec) {
294 let mut track = Track::new(track_id, codec_id, params, time_base);
295 track.duration = Some(trak.mdia.mdhd.duration as i64);
296 stream.add_track(state.tracks.add_track(track));
297 }
298 }
299 }
300
301 state.streams.add_stream(stream);
302
303 let timescale = moov.mvhd.timescale as i64;
304 let duration = moov.mvhd.duration as i64;
305 if timescale > 0 && duration > 0 {
306 state.duration = Some(duration * USEC_PER_SEC / timescale);
307 }
308
309 self.moov = Some(moov);
310
311 return Ok(());
312 }
313 Mdat::KIND => {
314 let skip_size = header.size.unwrap_or(0) as i64;
316 reader.seek(SeekFrom::Current(skip_size))?;
317 }
318 _ => {
319 if let Some(size) = header.size {
321 reader.seek(SeekFrom::Current(size as i64))?;
322 }
323 }
324 }
325 }
326 }
327
328 fn read_packet(&mut self, reader: &mut dyn Reader, state: &DemuxerState) -> Result<Packet<'static>> {
329 let moov = self.moov.as_ref().ok_or_else(|| not_found_error!("moov"))?;
330
331 let mut earliest_track_idx: Option<usize> = None;
333 let mut earliest_dts_us = i64::MAX;
334 let mut earliest_dts_raw = 0i64; for (track_idx, trak) in moov.trak.iter().enumerate() {
337 let sample_index = self.track_sample_indices[track_idx];
338
339 let stts = &trak.mdia.minf.stbl.stts;
341 let mut total_samples = 0u32;
342 for entry in &stts.entries {
343 total_samples += entry.sample_count;
344 }
345
346 if sample_index >= total_samples as usize {
347 continue; }
349
350 let mut dts = 0i64;
352 let mut accumulated_samples = 0usize;
353 for entry in &stts.entries {
354 if accumulated_samples + entry.sample_count as usize > sample_index {
355 dts += (sample_index - accumulated_samples) as i64 * entry.sample_delta as i64;
356 break;
357 }
358 dts += entry.sample_count as i64 * entry.sample_delta as i64;
359 accumulated_samples += entry.sample_count as usize;
360 }
361
362 let timescale = trak.mdia.mdhd.timescale as i64;
364 let dts_us = dts * USEC_PER_SEC / timescale;
365
366 if dts_us < earliest_dts_us {
367 earliest_dts_us = dts_us;
368 earliest_dts_raw = dts;
369 earliest_track_idx = Some(track_idx);
370 }
371 }
372
373 let track_idx = earliest_track_idx.ok_or_else(|| not_found_error!("no more samples"))?;
374
375 let trak = &moov.trak[track_idx];
377 let track_id = trak.tkhd.track_id;
378
379 let track = state.tracks.find_track(track_id as isize).ok_or_else(|| not_found_error!("track", track_id))?;
380
381 let sample_index = self.track_sample_indices[track_idx];
382 let stbl = &trak.mdia.minf.stbl;
383
384 let mut duration = 0i64;
386 let mut accumulated_samples = 0usize;
387 for entry in &stbl.stts.entries {
388 if accumulated_samples + entry.sample_count as usize > sample_index {
389 duration = entry.sample_delta as i64;
390 break;
391 }
392 accumulated_samples += entry.sample_count as usize;
393 }
394
395 let pts_offset = if let Some(ref ctts) = stbl.ctts {
397 let mut accumulated_samples = 0usize;
398 let mut offset = 0i32;
399 for entry in &ctts.entries {
400 if accumulated_samples + entry.sample_count as usize > sample_index {
401 offset = entry.sample_offset;
402 break;
403 }
404 accumulated_samples += entry.sample_count as usize;
405 }
406 offset as i64
407 } else {
408 0i64
409 };
410
411 let sample_size = match &stbl.stsz.samples {
412 StszSamples::Identical {
413 size, ..
414 } => *size as usize,
415 StszSamples::Different {
416 sizes,
417 } => *sizes.get(sample_index).ok_or_else(|| not_found_error!("sample size"))? as usize,
418 };
419
420 let mut chunk_index = 0usize;
422 let mut sample_in_chunk = sample_index;
423
424 for (i, entry) in stbl.stsc.entries.iter().enumerate() {
425 let next_first_chunk = stbl.stsc.entries.get(i + 1).map(|e| e.first_chunk).unwrap_or(u32::MAX);
426
427 let chunks_in_this_group = next_first_chunk - entry.first_chunk;
428 let samples_per_chunk = entry.samples_per_chunk as usize;
429 let samples_in_this_group = chunks_in_this_group as usize * samples_per_chunk;
430
431 if sample_in_chunk < samples_in_this_group {
432 chunk_index = (entry.first_chunk - 1) as usize + sample_in_chunk / samples_per_chunk;
433 sample_in_chunk %= samples_per_chunk;
434 break;
435 }
436 sample_in_chunk -= samples_in_this_group;
437 }
438
439 let chunk_offset = if let Some(ref stco) = stbl.stco {
440 *stco.entries.get(chunk_index).ok_or_else(|| not_found_error!("chunk offset"))? as u64
441 } else if let Some(ref co64) = stbl.co64 {
442 *co64.entries.get(chunk_index).ok_or_else(|| not_found_error!("chunk offset"))?
443 } else {
444 return Err(not_found_error!("chunk offset"));
445 };
446
447 let mut sample_offset = chunk_offset;
449 for i in 0..sample_in_chunk {
450 let prev_sample_idx = sample_index - sample_in_chunk + i;
451 let prev_size = match &stbl.stsz.samples {
452 StszSamples::Identical {
453 size, ..
454 } => *size as u64,
455 StszSamples::Different {
456 sizes,
457 } => *sizes.get(prev_sample_idx).ok_or_else(|| not_found_error!("sample size"))? as u64,
458 };
459 sample_offset += prev_size;
460 }
461
462 let mut packet = Packet::from_buffer(track.pool.get_buffer_with_length(sample_size));
463 let buffer = packet.data_mut().ok_or_else(|| invalid_error!("packet buffer is not mutable"))?;
464
465 reader.seek(SeekFrom::Start(sample_offset))?;
466 reader.read_exact(buffer)?;
467
468 let timescale = trak.mdia.mdhd.timescale;
469 let time_base = Rational64::new(1, timescale as i64);
470
471 packet.track_index = Some(track.index());
472 packet.dts = Some(earliest_dts_raw);
473 packet.pts = Some(earliest_dts_raw + pts_offset);
474 packet.duration = Some(duration);
475 packet.time_base = Some(time_base);
476
477 packet.flags = if stbl.stss.is_some() {
479 let key = stbl.stss.as_ref().map(|stss| stss.entries.contains(&((sample_index + 1) as u32))).unwrap_or(false);
480
481 if key {
482 PacketFlags::Key
483 } else {
484 PacketFlags::empty()
485 }
486 } else {
487 PacketFlags::Key };
489
490 self.track_sample_indices[track_idx] = sample_index + 1;
492
493 Ok(packet)
494 }
495
496 fn seek(
497 &mut self,
498 _reader: &mut dyn Reader,
499 state: &DemuxerState,
500 track_index: Option<usize>,
501 timestamp_us: i64,
502 flags: SeekFlags,
503 ) -> Result<()> {
504 let moov = self.moov.as_ref().ok_or_else(|| not_found_error!("moov"))?;
505
506 let track_index = track_index.unwrap_or_else(|| {
508 state.tracks.into_iter().find(|t| t.media_type() == MediaType::Video).map(|t| t.index()).unwrap_or(0)
510 });
511
512 let target_trak = moov.trak.get(track_index).ok_or_else(|| not_found_error!("track at index {}", track_index))?;
513 let target_timescale = target_trak.mdia.mdhd.timescale;
514 let target_stbl = &target_trak.mdia.minf.stbl;
515
516 let track_target_dts = timestamp_us * target_timescale as i64 / USEC_PER_SEC;
518
519 let mut target_sample_index = Self::find_sample_index(target_stbl, track_target_dts);
520
521 if !flags.contains(SeekFlags::ANY) {
523 if let Some(ref stss) = target_stbl.stss {
524 let target_sample_number = (target_sample_index + 1) as u32;
525
526 let keyframe_sample = if flags.contains(SeekFlags::BACKWARD) {
527 match stss.entries.partition_point(|s| *s <= target_sample_number) {
529 0 => 1,
530 i => stss.entries[i - 1],
531 }
532 } else {
533 let pos = stss.entries.partition_point(|s| *s < target_sample_number);
535 let candidates = [pos.checked_sub(1).and_then(|i| stss.entries.get(i)), stss.entries.get(pos)];
536 candidates.into_iter().flatten().min_by_key(|s| s.abs_diff(target_sample_number)).copied().unwrap_or(1)
537 };
538
539 target_sample_index = (keyframe_sample - 1) as usize;
540 }
541 }
542 let mut actual_dts = 0i64;
546 let mut accumulated_samples = 0usize;
547 for entry in &target_stbl.stts.entries {
548 if accumulated_samples + entry.sample_count as usize > target_sample_index {
549 actual_dts += (target_sample_index - accumulated_samples) as i64 * entry.sample_delta as i64;
550 break;
551 }
552 actual_dts += entry.sample_count as i64 * entry.sample_delta as i64;
553 accumulated_samples += entry.sample_count as usize;
554 }
555
556 for (trak_idx, trak) in moov.trak.iter().enumerate() {
558 let sample_index = if trak_idx == track_index {
559 target_sample_index
561 } else {
562 let timescale = trak.mdia.mdhd.timescale;
564 let track_dts = actual_dts * timescale as i64 / target_timescale as i64;
565 Self::find_sample_index(&trak.mdia.minf.stbl, track_dts)
566 };
567
568 self.track_sample_indices[trak_idx] = sample_index;
569 }
570
571 Ok(())
572 }
573}
574
/// Registers the MP4 demuxer: provides its name, file extensions, content
/// probe and a factory for [`Mp4Demuxer`] instances.
pub struct Mp4DemuxerBuilder;
577
578impl FormatBuilder for Mp4DemuxerBuilder {
579 fn name(&self) -> &'static str {
580 "mp4"
581 }
582
583 fn extensions(&self) -> &[&'static str] {
584 &["mp4", "mov", "m4v", "m4a"]
585 }
586}
587
588impl DemuxerBuilder for Mp4DemuxerBuilder {
589 fn new_demuxer(&self) -> Result<Box<dyn Demuxer>> {
590 Ok(Box::new(Mp4Demuxer::new()))
591 }
592
593 fn probe(&self, reader: &mut dyn Reader) -> bool {
594 let mut buf = [0u8; 8];
595 reader.read_exact(&mut buf).ok();
596
597 matches!(&buf[4..8], b"ftyp" | b"moov" | b"mdat")
598 }
599}