use std::cmp::{min, max};
use std::fs::File;
use ac_ffmpeg::codec::CodecParameters;
use ac_ffmpeg::packet::Packet;
use ac_ffmpeg::time::Timestamp;
use ac_ffmpeg::format::io::IO;
use ac_ffmpeg::format::demuxer::Demuxer;
use ac_ffmpeg::format::demuxer::DemuxerWithStreamInfo;
use ac_ffmpeg::format::muxer::Muxer;
use ac_ffmpeg::format::muxer::OutputFormat;
use crate::DashMpdError;
use crate::fetch::DashDownloader;
/// Open the media file at `path` with libav and probe it for stream information.
///
/// The stream-info probe is given a 2-second limit via `find_stream_info`
/// (presumably libav's maximum analyze duration — confirm against ac_ffmpeg docs).
///
/// # Errors
/// Returns [`DashMpdError::Muxing`] when the file cannot be opened, when libav
/// cannot build a demuxer for it, or when stream information cannot be found.
/// Each case has its own message and includes the underlying error text.
fn libav_open_input(path: &str) -> Result<DemuxerWithStreamInfo<File>, DashMpdError> {
    let input = File::open(path)
        .map_err(|e| DashMpdError::Muxing(format!("opening libav input path: {}", e)))?;
    let io = IO::from_seekable_read_stream(input);
    Demuxer::builder()
        .build(io)
        .map_err(|e| DashMpdError::Muxing(format!("building libav demuxer: {:?}", e)))?
        .find_stream_info(Some(std::time::Duration::new(2, 0)))
        // On failure ac_ffmpeg hands back the demuxer together with the error; only
        // the error is needed here.
        .map_err(|(_, e)| DashMpdError::Muxing(format!("finding libav stream info: {:?}", e)))
}
/// Create a libav muxer that writes to a new file at `path`.
///
/// The container format is guessed from the file name, falling back to MP4 when no
/// format can be guessed. One output stream is added per entry of
/// `elementary_streams`, in slice order. `File::create` truncates any existing file.
///
/// # Errors
/// Returns [`DashMpdError::Io`] when the output file cannot be created.
/// Returns [`DashMpdError::Muxing`] when no output format is available, or when libav
/// rejects a stream or the muxer configuration.
fn libav_open_output(path: &str, elementary_streams: &[CodecParameters]) -> Result<Muxer<File>, DashMpdError> {
    let format = match OutputFormat::guess_from_file_name(path) {
        Some(guessed) => guessed,
        None => OutputFormat::find_by_name("mp4")
            .ok_or_else(|| DashMpdError::Muxing(String::from("guessing libav output format")))?,
    };
    let sink = File::create(path)
        .map_err(|e| DashMpdError::Io(e, String::from("creating output file")))?;
    let mut builder = Muxer::builder();
    elementary_streams.iter().try_for_each(|params| {
        builder
            .add_stream(params)
            .map(|_| ())
            .map_err(|_| DashMpdError::Muxing(String::from("adding libav stream to muxer")))
    })?;
    let built = builder.build(IO::from_seekable_write_stream(sink), format);
    built.map_err(|e| DashMpdError::Muxing(format!("building libav muxer: {:?}", e)))
}
/// Returns `true` when `p` cannot follow a packet whose DTS was `last_dts` without
/// breaking DTS monotonicity. That is the case when `p`'s own DTS is null or is not
/// strictly greater than `last_dts`.
///
/// A null `last_dts` (e.g. before any packet of the stream has been written) never
/// flags a packet.
fn has_invalid_timestamps(p: &Packet, last_dts: Timestamp) -> bool {
    if last_dts.is_null() {
        return false;
    }
    let dts = p.dts();
    dts.is_null() || dts <= last_dts
}
pub fn mux_audio_video(
downloader: &DashDownloader,
audio_path: &str,
video_path: &str) -> Result<(), DashMpdError> {
ac_ffmpeg::set_log_callback(|_count, msg: &str| log::info!("ffmpeg: {}", msg));
let mut video_demuxer = libav_open_input(video_path)
.map_err(|_| DashMpdError::Muxing(String::from("opening input video stream")))?;
let (video_pos, video_codec) = video_demuxer
.streams()
.iter()
.enumerate()
.find_map(|(pos, stream)| {
let params = stream.codec_parameters();
if params.is_video_codec() {
return Some((pos, params));
}
None
})
.ok_or_else(|| DashMpdError::Muxing(String::from("finding libav video codec")))?;
let mut audio_demuxer = libav_open_input(audio_path)?;
let (audio_pos, audio_codec) = audio_demuxer
.streams()
.iter()
.enumerate()
.find_map(|(pos, stream)| {
let params = stream.codec_parameters();
if params.is_audio_codec() {
return Some((pos, params));
}
None
})
.ok_or_else(|| DashMpdError::Muxing(String::from("finding libav audio codec")))?;
let out = &downloader.output_path.as_ref().unwrap().to_str()
.ok_or_else(|| DashMpdError::Muxing(String::from("converting output path")))?;
let mut muxer = libav_open_output(out, &[video_codec, audio_codec])?;
let mut last_dts: Timestamp = Timestamp::null();
while let Some(mut pkt) = video_demuxer.take()
.map_err(|_| DashMpdError::Muxing(String::from("fetching video packet from libav demuxer")))? {
if pkt.stream_index() == video_pos {
if has_invalid_timestamps(&pkt, last_dts) {
let next_dts = Timestamp::new(last_dts.timestamp() + 1, last_dts.time_base());
if !pkt.pts().is_null() && pkt.pts() > pkt.dts() {
let mut max = next_dts;
if pkt.pts() > max {
max = pkt.pts();
}
pkt = pkt.with_pts(max);
}
if pkt.pts().is_null() {
pkt = pkt.with_pts(next_dts);
}
pkt = pkt.with_dts(next_dts);
}
if !pkt.pts().is_null() && !pkt.dts().is_null() && pkt.dts() > pkt.pts() {
log::info!("Fixing invalid DTS (dts > pts) in DASH video stream");
let pts_ts = pkt.pts().timestamp();
let dts_ts = pkt.dts().timestamp();
let next_ts = last_dts.timestamp() + 1;
let fixed_dts_ts = pts_ts + dts_ts + next_ts
- min(pts_ts, min(dts_ts, next_ts))
- max(pts_ts, max(dts_ts, next_ts));
let fixed_dts = Timestamp::new(fixed_dts_ts, last_dts.time_base());
pkt = pkt.with_dts(fixed_dts).with_pts(fixed_dts);
}
last_dts = pkt.dts();
muxer.push(pkt.with_stream_index(0))
.map_err(|_| DashMpdError::Muxing(String::from("pushing video packet to libav muxer")))?;
}
}
muxer.flush()
.map_err(|_| DashMpdError::Muxing(String::from("flushing libav muxer")))?;
last_dts = Timestamp::null();
while let Some(mut pkt) = audio_demuxer.take()
.map_err(|_| DashMpdError::Muxing(String::from("fetching audio packet from libav demuxer")))? {
if pkt.stream_index() == audio_pos {
if has_invalid_timestamps(&pkt, last_dts) {
let next_dts = Timestamp::new(last_dts.timestamp() + 1, last_dts.time_base());
if !pkt.pts().is_null() && (pkt.pts() > pkt.dts()) {
let mut max = next_dts;
if pkt.pts() > max {
max = pkt.pts();
}
pkt = pkt.with_pts(max);
}
if pkt.pts().is_null() {
pkt = pkt.with_pts(next_dts);
}
pkt = pkt.with_dts(next_dts);
}
if !pkt.pts().is_null() && !pkt.dts().is_null() && pkt.dts() > pkt.pts() {
log::info!("Fixing invalid DTS (dts > pts) in DASH audio stream");
let pts_ts = pkt.pts().timestamp();
let dts_ts = pkt.dts().timestamp();
let next_ts = last_dts.timestamp() + 1;
let fixed_dts_ts = pts_ts + dts_ts + next_ts
- min(pts_ts, min(dts_ts, next_ts))
- max(pts_ts, max(dts_ts, next_ts));
let fixed_dts = Timestamp::new(fixed_dts_ts, last_dts.time_base());
pkt = pkt.with_dts(fixed_dts).with_pts(fixed_dts);
}
last_dts = pkt.dts();
muxer.push(pkt.with_stream_index(1))
.map_err(|_| DashMpdError::Muxing(String::from("pushing audio packet to libav muxer")))?;
}
}
muxer.flush()
.map_err(|_| DashMpdError::Muxing(String::from("flushing libav muxer")))?;
muxer.close()
.map_err(|_| DashMpdError::Muxing(String::from("closing libav muxer")))?;
Ok(())
}