pub mod common;
use std::fs;
use std::env;
use std::net::SocketAddr;
use std::time::Duration;
use std::str::FromStr;
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use axum::{routing::get, Router};
use axum::extract::{Path, State};
use axum::response::{Response, IntoResponse};
use axum::http::{header, StatusCode};
use axum::body::Body;
use axum_server::{Handle, bind};
use dash_mpd::{MPD, Period, AdaptationSet, Representation, SegmentList};
use dash_mpd::{SegmentTemplate, SegmentURL};
use dash_mpd::fetch::{DashDownloader, parse_resolving_xlinks};
use anyhow::{Context, Result};
use common::{generate_minimal_mp4, setup_logging};
/// State shared with the test HTTP server's handlers.
///
/// `counter` is bumped by the media-segment handler and reported by the status handler.
#[derive(Debug, Default)]
struct AppState {
    counter: AtomicUsize,
}

impl AppState {
    /// Creates a state whose request counter starts at zero.
    fn new() -> AppState {
        Self::default()
    }
}
/// Recursively visits `element` and all its descendants in `doc`, adding an
/// `xmlns:xlink="http://www.w3.org/1999/xlink"` attribute to every element that has an
/// `href` attribute.
///
/// NOTE(review): presumably needed because the serialized fragments use an `xlink:`-prefixed
/// href without declaring the prefix; confirm against the serializer's output.
fn add_xml_namespaces_recurse(element: &xmlem::Element, doc: &mut xmlem::Document) {
    let has_href = element.attribute(doc, "href").is_some();
    if has_href {
        element.set_attribute(doc, "xmlns:xlink", "http://www.w3.org/1999/xlink");
    }
    // `children` yields an owned collection, so `doc` is free to be borrowed mutably below.
    let children = element.children(doc);
    for child in &children {
        add_xml_namespaces_recurse(child, doc);
    }
}
fn add_xml_namespaces(xml: &str) -> Result<String> {
let mut doc = xmlem::Document::from_str(xml).expect("xmlem parsing");
add_xml_namespaces_recurse(&doc.root(), &mut doc);
Ok(doc.to_string_pretty())
}
/// Builds a `SegmentList` with one `SegmentURL` per entry of `urls`, using the entry as that
/// segment's `media` URL; every other field keeps its default value.
fn make_segment_list(urls: Vec<&str>) -> SegmentList {
    let segment_urls = urls
        .into_iter()
        .map(|media| SegmentURL { media: Some(media.to_owned()), ..Default::default() })
        .collect();
    SegmentList { segment_urls, ..Default::default() }
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn test_xlink_retrieval() -> Result<()> {
setup_logging();
if env::var("CI").is_ok() {
return Ok(());
}
let segment_template1 = SegmentTemplate {
initialization: Some("/media/f1.mp4".to_string()),
..Default::default()
};
let rep1 = Representation {
id: Some("1".to_string()),
mimeType: Some("video/mp4".to_string()),
codecs: Some("avc1.640028".to_string()),
width: Some(1920),
height: Some(800),
bandwidth: Some(1980081),
SegmentTemplate: Some(segment_template1),
..Default::default()
};
let rep2 = Representation {
href: Some("http://localhost:6666/remote/representation.xml".to_string()),
actuate: Some("onLoad".to_string()),
..Default::default()
};
let remote_rep = Representation {
id: Some("rr1".to_string()),
width: Some(600),
height: Some(400),
SegmentList: Some(make_segment_list(vec!("/media/f2.mp4", "/media/f3.mp4"))),
..Default::default()
};
let adapt1 = AdaptationSet {
id: Some("1".to_string()),
contentType: Some("video".to_string()),
representations: vec!(rep1),
..Default::default()
};
let adapt2 = AdaptationSet {
id: Some("2".to_string()),
contentType: Some("video".to_string()),
representations: vec!(rep2),
..Default::default()
};
let period1 = Period {
id: Some("1".to_string()),
duration: Some(Duration::new(5, 0)),
adaptations: vec!(adapt1.clone()),
..Default::default()
};
let period2 = Period {
id: Some("2".to_string()),
href: Some("/remote/period2.xml".to_string()),
actuate: Some("onLoad".to_string()),
..Default::default()
};
let period3 = Period {
id: Some("3".to_string()),
href: Some("urn:mpeg:dash:resolve-to-zero:2013".to_string()),
..Default::default()
};
let remote_period1 = Period {
id: Some("r1".to_string()),
duration: Some(Duration::new(5, 0)),
adaptations: vec!(adapt1),
..Default::default()
};
let remote_period2 = Period {
id: Some("r2".to_string()),
duration: Some(Duration::new(5, 0)),
adaptations: vec!(adapt2),
..Default::default()
};
let mpd = MPD {
xmlns: Some("urn:mpeg:dash:schema:mpd:2011".to_string()),
mpdtype: Some("static".to_string()),
xlink: Some("http://www.w3.org/1999/xlink".to_string()),
periods: vec!(period1, period2, period3),
..Default::default()
};
let xml = mpd.to_string();
let xml = add_xml_namespaces(&xml)?;
let remote_period1_xml = quick_xml::se::to_string(&remote_period1)?;
let remote_period1_xml = add_xml_namespaces(&remote_period1_xml)?;
let remote_period2_xml = quick_xml::se::to_string(&remote_period2)?;
let remote_period2_xml = add_xml_namespaces(&remote_period2_xml)?;
let remote_period_xml = remote_period1_xml.clone() + &remote_period2_xml;
let remote_rep = quick_xml::se::to_string(&remote_rep)?;
let remote_representation_xml = add_xml_namespaces(&remote_rep)?;
println!("xlink3 XML> {}", remote_representation_xml);
let shared_state = Arc::new(AppState::new());
async fn send_mp4(Path(_): Path<String>, State(state): State<Arc<AppState>>) -> Response {
state.counter.fetch_add(1, Ordering::SeqCst);
let data = generate_minimal_mp4();
Response::builder()
.status(StatusCode::OK)
.header(header::CONTENT_TYPE, "video/mp4")
.body(Body::from(data))
.unwrap()
}
async fn send_status(State(state): State<Arc<AppState>>) -> impl IntoResponse {
([(header::CONTENT_TYPE, "text/plain")], format!("{}", state.counter.load(Ordering::Relaxed)))
}
let app = Router::new()
.route("/mpd", get(
|| async { ([(header::CONTENT_TYPE, "application/dash+xml")], xml) }))
.route("/remote/period2.xml", get(
|| async { ([(header::CONTENT_TYPE, "application/dash+xml")], remote_period_xml) }))
.route("/remote/representation.xml", get(
|| async { ([(header::CONTENT_TYPE, "application/dash+xml")], remote_representation_xml) }))
.route("/media/{segment}", get(send_mp4))
.route("/status", get(send_status))
.with_state(shared_state);
let server_handle: Handle<SocketAddr> = Handle::new();
let backend_handle = server_handle.clone();
let backend = async move {
bind("127.0.0.1:6666".parse().unwrap())
.handle(backend_handle)
.serve(app.into_make_service()).await
.unwrap()
};
tokio::spawn(backend);
tokio::time::sleep(Duration::from_millis(1000)).await;
let client = reqwest::Client::builder()
.timeout(Duration::new(10, 0))
.build()
.context("creating HTTP client")?;
let txt = client.get("http://localhost:6666/status")
.send().await?
.error_for_status()?
.text().await
.context("fetching status")?;
assert!(txt.eq("0"), "Expecting 0 segment requests, got {txt}");
let mpd_url = "http://localhost:6666/mpd";
let dl = DashDownloader::new(mpd_url)
.with_http_client(client.clone());
let xml = client.get(mpd_url)
.send().await?
.error_for_status()?
.bytes().await
.context("fetching MPD")?;
let mpd: MPD = parse_resolving_xlinks(&dl, &xml).await
.context("parsing DASH XML resolving xlinks")?;
assert_eq!(mpd.periods.len(), 3);
assert!(mpd.periods.iter().any(|p| p.id.as_ref().is_some_and(|id| id.eq("r2"))));
let outpath = env::temp_dir().join("xlinked.mp4");
DashDownloader::new(mpd_url)
.verbosity(0)
.download_to(&outpath).await
.unwrap();
assert!(fs::metadata(outpath).is_ok());
let txt = client.get("http://localhost:6666/status")
.send().await?
.error_for_status()?
.text().await
.context("fetching status")?;
assert!(txt.eq("4"), "Expecting 4 segment requests, got {txt}");
server_handle.shutdown();
Ok(())
}
/// Checks that a failure while resolving a nested XLink makes the download fail.
///
/// The MPD's only period is fetched via XLink from `/remote/period.xml`. That remote period
/// has its own `xlink:href` pointing at `/remote/failure.xml`, which the test server does not
/// route, so `download_to` is expected to return an error.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn test_xlink_errors() -> Result<()> {
    // Initialise logging first so that everything the test does is covered.
    setup_logging();
    let period1 = Period {
        id: Some("2".to_string()),
        href: Some("/remote/period.xml".to_string()),
        actuate: Some("onLoad".to_string()),
        ..Default::default()
    };
    let remote_period = Period {
        id: Some("r1".to_string()),
        href: Some("/remote/failure.xml".to_string()),
        actuate: Some("onLoad".to_string()),
        ..Default::default()
    };
    let mpd = MPD {
        xmlns: Some("urn:mpeg:dash:schema:mpd:2011".to_string()),
        mpdtype: Some("static".to_string()),
        xlink: Some("http://www.w3.org/1999/xlink".to_string()),
        periods: vec!(period1),
        ..Default::default()
    };
    let xml = mpd.to_string();
    let xml = add_xml_namespaces(&xml)?;
    let remote_period_xml = quick_xml::se::to_string(&remote_period)?;
    let remote_period_xml = add_xml_namespaces(&remote_period_xml)?;
    let app = Router::new()
        .route("/mpd", get(
            || async { ([(header::CONTENT_TYPE, "application/dash+xml")], xml) }))
        .route("/remote/period.xml", get(
            || async { ([(header::CONTENT_TYPE, "application/dash+xml")], remote_period_xml) }));
    let server_handle: Handle<SocketAddr> = Handle::new();
    let backend_handle = server_handle.clone();
    let backend = async move {
        bind("127.0.0.1:6669".parse().unwrap())
            .handle(backend_handle)
            .serve(app.into_make_service()).await
            .unwrap()
    };
    tokio::spawn(backend);
    // Wait for the server to be listening instead of sleeping for a fixed interval;
    // `listening()` resolves to `None` if the bind failed.
    server_handle.listening().await
        .context("test HTTP server failed to start on 127.0.0.1:6669")?;
    let outpath = env::temp_dir().join("nonexistent.mp4");
    let result = DashDownloader::new("http://localhost:6669/mpd")
        .download_to(&outpath).await;
    // Stop the server before asserting so that a failing assertion doesn't skip shutdown.
    server_handle.shutdown();
    assert!(result.is_err(), "expected the download to fail when an XLink cannot be resolved");
    Ok(())
}