use fs_err as fs;
use std::net::SocketAddr;
use std::time::Duration;
use std::io::BufReader;
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use assert_cmd::cargo::cargo_bin_cmd;
use axum::{routing::get, Router};
use axum::extract::State;
use axum::response::{Response, IntoResponse};
use axum::http::{header, StatusCode};
use axum::body::Body;
use axum_server::tls_rustls::RustlsConfig;
use rustls::{RootCertStore, ServerConfig};
use rustls_pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer, pem::PemObject};
use rustls::server::WebPkiClientVerifier;
use dash_mpd::{MPD, Period, AdaptationSet, Representation, BaseURL};
use anyhow::{Context, Result};
use test_log::test;
/// Shared state for the test HTTPS server's handlers.
#[derive(Debug, Default)]
struct AppState {
    /// Count of responses produced by the `/init.mp4` handler.
    counter: AtomicUsize,
}

impl AppState {
    /// Builds a fresh state whose counter starts at zero.
    fn new() -> AppState {
        // The derived Default initialises the AtomicUsize to 0.
        Self::default()
    }
}
/// End-to-end check of TLS client-certificate authentication.
///
/// Starts a local HTTPS server on `localhost:6666` whose client-certificate verifier trusts
/// `tests/fixtures/root-CA.crt`. The server has three routes:
///
/// * `/mpd` serves a one-Representation static MPD.
/// * `/init.mp4` serves an MP4 fixture and increments a counter.
/// * `/status` reports that counter.
///
/// The test then checks that:
///
/// * the CLI fails to fetch the manifest when no client identity is supplied;
/// * in rustls builds, the CLI succeeds when given `--client-identity-certificate`, after which
///   the counter reads 1.
#[test(tokio::test(flavor = "multi_thread", worker_threads = 2))]
async fn test_add_client_identity() -> Result<(), anyhow::Error> {
    // Handler for /init.mp4: counts the request, then returns the MP4 fixture.
    async fn send_mp4(State(state): State<Arc<AppState>>) -> Response {
        state.counter.fetch_add(1, Ordering::SeqCst);
        Response::builder()
            .status(StatusCode::OK)
            .header(header::CONTENT_TYPE, "video/mp4")
            .body(Body::from(include_bytes!("fixtures/minimal-valid.mp4").as_slice()))
            .unwrap()
    }

    // Handler for /status: reports how many times /init.mp4 has been served.
    async fn send_status(State(state): State<Arc<AppState>>) -> impl IntoResponse {
        ([(header::CONTENT_TYPE, "text/plain")], state.counter.load(Ordering::SeqCst).to_string())
    }

    // Reads the /status counter using the (client-authenticated) reqwest client.
    async fn fetch_status(client: &reqwest::Client) -> Result<String> {
        client
            .get("https://localhost:6666/status")
            .send()
            .await?
            .error_for_status()?
            .text()
            .await
            .context("fetching status")
    }

    let base = BaseURL {
        base: "https://localhost:6666/init.mp4".to_string(),
        ..Default::default()
    };
    let rep = Representation {
        id: Some("1".to_string()),
        mimeType: Some("video/mp4".to_string()),
        codecs: Some("avc1.640028".to_string()),
        width: Some(1920),
        height: Some(800),
        bandwidth: Some(1980081),
        BaseURL: vec!(base),
        ..Default::default()
    };
    let adapt = AdaptationSet {
        id: Some("1".to_string()),
        contentType: Some("video".to_string()),
        representations: vec!(rep),
        ..Default::default()
    };
    let period = Period {
        id: Some("1".to_string()),
        duration: Some(Duration::new(5, 0)),
        adaptations: vec!(adapt),
        ..Default::default()
    };
    let mpd = MPD {
        mpdtype: Some("static".to_string()),
        periods: vec!(period),
        ..Default::default()
    };
    let xml = quick_xml::se::to_string(&mpd).context("serializing MPD struct")?;
    let shared_state = Arc::new(AppState::new());

    // Select aws-lc-rs as rustls' process-wide crypto provider. install_default() returns Err
    // when a provider has already been installed (e.g. by another test running in the same
    // binary); that is harmless here, so the error is deliberately ignored rather than unwrapped.
    let _ = rustls::crypto::aws_lc_rs::default_provider().install_default();

    let app = Router::new()
        .route("/mpd", get(|| async { ([(header::CONTENT_TYPE, "application/dash+xml")], xml) }))
        .route("/init.mp4", get(send_mp4))
        .route("/status", get(send_status))
        .with_state(shared_state);
    let addr = SocketAddr::from(([127, 0, 0, 1], 6666));

    // Trust store for verifying client certificates: the test root CA.
    let mut client_auth_roots = RootCertStore::empty();
    let root_cert = fs::File::open("tests/fixtures/root-CA.crt")?;
    for cert in CertificateDer::pem_reader_iter(&mut BufReader::new(root_cert)) {
        client_auth_roots
            .add(cert.context("parsing root CA certificate")?)
            .context("adding root CA certificate to client-auth trust store")?;
    }
    let client_verifier = WebPkiClientVerifier::builder(client_auth_roots.into())
        .build()
        .context("building client certificate verifier")?;

    // Server certificate chain and PKCS#8 private key; the first key in the file is used.
    let cert_file = fs::File::open("tests/fixtures/localhost-cert.crt")?;
    let certificates: Vec<CertificateDer<'static>> =
        CertificateDer::pem_reader_iter(&mut BufReader::new(cert_file))
            .collect::<Result<Vec<_>, _>>()
            .context("parsing server certificate")?;
    let key_file = fs::File::open("tests/fixtures/localhost-cert.key")?;
    let key = PrivatePkcs8KeyDer::pem_reader_iter(&mut BufReader::new(key_file))
        .next()
        .context("no private key")?
        .context("parsing server private key")?;
    let mut config = ServerConfig::builder()
        .with_client_cert_verifier(client_verifier)
        .with_single_cert(certificates, PrivateKeyDer::Pkcs8(key))
        .context("bad server certificate/key")?;
    config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];

    let backend = async move {
        axum_server::bind_rustls(addr, RustlsConfig::from_config(config.into()))
            .serve(app.into_make_service())
            .await
            .unwrap()
    };
    tokio::spawn(backend);
    // Crude wait for the spawned server to start listening before any client connects.
    tokio::time::sleep(Duration::from_millis(1000)).await;

    // Client identity for the reqwest client; the PEM layout differs per TLS backend.
    #[cfg(feature = "rustls")]
    let client_id = fs::read("tests/fixtures/client-id.pem")?;
    #[cfg(feature = "rustls")]
    let id = reqwest::Identity::from_pem(&client_id)
        .context("reading PEM client identity from certificate")?;
    #[cfg(feature = "native-tls")]
    let client_cert = fs::read("tests/fixtures/client-cert.crt")?;
    #[cfg(feature = "native-tls")]
    let client_key = fs::read("tests/fixtures/client-cert.key")?;
    #[cfg(feature = "native-tls")]
    let id = reqwest::Identity::from_pkcs8_pem(&client_cert, &client_key)
        .context("reading PKCS8 client identity from certificate")?;
    let crt = fs::read("tests/fixtures/root-CA.crt")?;
    let root_cert = reqwest::Certificate::from_pem(&crt)?;
    let client = reqwest::Client::builder()
        .timeout(Duration::new(30, 0))
        .identity(id)
        .add_root_certificate(root_cert)
        .build()
        .context("creating HTTP client")?;

    assert_eq!(fetch_status(&client).await?, "0");

    // No client identity supplied: the download is expected to fail.
    cargo_bin_cmd!()
        .args(["--add-root-certificate", "tests/fixtures/root-CA.crt",
               "https://localhost:6666/mpd"])
        .assert()
        .failure();

    #[cfg(feature = "rustls")]
    cargo_bin_cmd!()
        .args(["-v", "-v", "-v",
               "--add-root-certificate", "tests/fixtures/root-CA.crt",
               "--client-identity-certificate", "tests/fixtures/client-id.pem",
               "https://localhost:6666/mpd"])
        .assert()
        .success();

    // Only rustls builds run the authenticated download above, so only in that configuration
    // is the /init.mp4 counter expected to have been incremented.
    let expected = if cfg!(feature = "rustls") { "1" } else { "0" };
    assert_eq!(fetch_status(&client).await?, expected);
    Ok(())
}