use std::{
net::SocketAddr,
sync::Arc,
time::{Duration, SystemTime, UNIX_EPOCH},
};
use assert_matches::assert_matches;
use bytes::Bytes;
use quinn::{Endpoint, TransportConfig, crypto::rustls::QuicClientConfig, rustls};
use rustls::ClientConfig;
use scion_sdk_observability::metrics::registry::MetricsRegistry;
use scion_sdk_reqwest_connect_rpc::token_source::{
mock::MockTokenSource, static_token::StaticTokenSource,
};
use scion_sdk_token_validator::validator::{Token, TokenValidator, TokenValidatorError};
use serde::{Deserialize, Serialize};
use snap_tun::{
client::{ClientBuilder, Control, Receiver, Sender},
metrics::Metrics,
server_deprecated::ControlError,
};
use tokio::task::JoinSet;
#[test_log::test(tokio::test)]
pub async fn assign_address_and_retrieve_echoed_packet() {
scion_sdk_utils::test::install_rustls_crypto_provider();
let (quic_client, quic_srv) = quic_endpoint_pair();
let srv_addr = quic_srv.local_addr().expect("no fail");
let srv = prepare_snaptun_server(MagicAuthorizer::default());
let mut js = JoinSet::<()>::new();
js.spawn(run_server(quic_srv, srv));
let (tx, rx, _ctrl) = prepare_snaptun_client(quic_client, srv_addr).await;
let n_packets = 64u16;
js.spawn(async move {
for i in 0..n_packets {
let p = gen_packet(i, n_packets);
tx.send_datagram(p).expect("no fail");
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
}
});
for i in 0..n_packets {
let p = rx.read_datagram().await.expect("no fail");
assert_eq!(gen_packet(i, n_packets), p);
}
}
#[test_log::test(tokio::test)]
pub async fn assign_sock_addr_and_retrieve_echoed_packet() {
scion_sdk_utils::test::install_rustls_crypto_provider();
let (quic_client, quic_srv) = quic_endpoint_pair();
let srv_addr = quic_srv.local_addr().expect("no fail");
let srv = prepare_snaptun_server(MagicAuthorizer::default());
let mut js = JoinSet::<()>::new();
js.spawn(run_server(quic_srv, srv));
let (tx, rx, _ctrl) = prepare_snaptun_client(quic_client, srv_addr).await;
let n_packets = 64u16;
js.spawn(async move {
for i in 0..n_packets {
let p = gen_packet(i, n_packets);
tx.send_datagram(p).expect("no fail");
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
}
});
for i in 0..n_packets {
let p = rx.read_datagram().await.expect("no fail");
assert_eq!(gen_packet(i, n_packets), p);
}
}
/// Checks that the server ends a session once the client's token expires.
///
/// The server issues tokens valid for one second and the client never renews.
/// The server-side control future must therefore resolve to
/// `Err(ControlError::TokenExpired)`.
#[test_log::test(tokio::test)]
pub async fn token_enforcement() {
    scion_sdk_utils::test::install_rustls_crypto_provider();
    let (client_ep, server_ep) = quic_endpoint_pair();
    let server_addr = server_ep.local_addr().expect("no fail");
    let snaptun_server = prepare_snaptun_server(MagicAuthorizer::new(1));

    // Accept exactly one session and wait until its control future finishes.
    let server_task = tokio::spawn(async move {
        let connection = server_ep
            .accept()
            .await
            .expect("no fail")
            .await
            .expect("no fail");
        let (_tx, _rx, control) = snaptun_server
            .accept_with_timeout(connection)
            .await
            .expect("no fail");
        let outcome = control.await;
        assert_matches!(outcome, Err(ControlError::TokenExpired));
    });

    // Hold the client handles until the server task is done.
    let (_tx, _rx, _ctrl) = prepare_snaptun_client(client_ep, server_addr).await;
    server_task.await.expect("no fail");
}
/// Checks that an explicit `update_token` call on the client control handle
/// succeeds and moves the reported token expiry forward.
#[test_log::test(tokio::test)]
pub async fn manual_token_renewal() {
    scion_sdk_utils::test::install_rustls_crypto_provider();
    let (client_ep, server_ep) = quic_endpoint_pair();
    let server_addr = server_ep.local_addr().expect("no fail");
    let snaptun_server = prepare_snaptun_server(MagicAuthorizer::default());
    // Held for the whole test so the server is not aborted.
    let mut tasks = JoinSet::<()>::new();
    tasks.spawn(run_server(server_ep, snaptun_server));
    let (_tx, _rx, mut control) = prepare_snaptun_client(client_ep, server_addr).await;

    let expiry_before = control.token_expiry();
    // The test validator stamps expiries in whole seconds; waiting a second
    // presumably ensures the renewed expiry is strictly later.
    tokio::time::sleep(Duration::from_secs(1)).await;
    let update_result = control.update_token().await;
    assert!(
        update_result.is_ok(),
        "token update should succeed: {update_result:?}"
    );
    let expiry_after = control.token_expiry();

    assert!(
        expiry_after > expiry_before,
        "token expiry must be extended {:?} > {:?}",
        chrono::DateTime::<chrono::Utc>::from(expiry_after),
        chrono::DateTime::<chrono::Utc>::from(expiry_before)
    );
}
/// Checks that the client's reported token expiry moves forward after the
/// token source is given a new token, without any explicit `update_token`
/// call. Server tokens are valid for three seconds.
#[test_log::test(tokio::test)]
pub async fn auto_token_renewal() {
    scion_sdk_utils::test::install_rustls_crypto_provider();
    let (client_ep, server_ep) = quic_endpoint_pair();
    let server_addr = server_ep.local_addr().expect("no fail");
    let snaptun_server = prepare_snaptun_server(MagicAuthorizer::new(3));
    // Held for the whole test so the server is not aborted.
    let mut tasks = JoinSet::<()>::new();
    tasks.spawn(run_server(server_ep, snaptun_server));

    // A mock source lets the test hand the client a fresh token mid-session.
    let token_source = Arc::new(MockTokenSource::new(MAGIC_TOKEN.to_string()));
    let connection = client_ep
        .connect(server_addr, "localhost")
        .expect("no fail")
        .await
        .expect("no_fail");
    let (_tx, _rx, control) = ClientBuilder::new(token_source.clone())
        .connect(connection)
        .await
        .expect("no fail");

    let expiry_before = control.token_expiry();
    tokio::time::sleep(Duration::from_secs(1)).await;
    token_source.update_token(MAGIC_TOKEN.to_string());
    // Give the client time to pick up the new token on its own.
    tokio::time::sleep(Duration::from_secs(1)).await;
    let expiry_after = control.token_expiry();

    assert!(
        expiry_after > expiry_before,
        "token expiry should have changed {:?} > {:?}",
        chrono::DateTime::<chrono::Utc>::from(expiry_after),
        chrono::DateTime::<chrono::Utc>::from(expiry_before)
    );
}
/// Builds a `server_deprecated` snap-tun server.
///
/// The server validates tokens with `validator` and records metrics through
/// a freshly created registry that is not kept anywhere else.
fn prepare_snaptun_server(
    validator: MagicAuthorizer,
) -> snap_tun::server_deprecated::Server<DummyToken> {
    let shared_validator = Arc::new(validator);
    let registry = MetricsRegistry::new();
    let metrics = Metrics::new(&registry);
    snap_tun::server_deprecated::Server::new(shared_validator, metrics)
}
/// Connects `quic_client` to `srv_addr` and establishes a snap-tun session
/// authenticated with `MAGIC_TOKEN` from a static token source.
///
/// Asserts that the session's assigned socket address equals the client
/// endpoint's own local address. Returns the session's sender, receiver and
/// control handles.
async fn prepare_snaptun_client(
    quic_client: Endpoint,
    srv_addr: SocketAddr,
) -> (Sender, Receiver, Control) {
    let connection = quic_client
        .connect(srv_addr, "localhost")
        .expect("no fail")
        .await
        .expect("no_fail");
    let token_source = Arc::new(StaticTokenSource::from(MAGIC_TOKEN));
    let (sender, receiver, control) = ClientBuilder::new(token_source)
        .connect(connection)
        .await
        .expect("no fail");

    let expected_addr = quic_client.local_addr().unwrap();
    assert_eq!(control.assigned_sock_addr(), Some(expected_addr));

    (sender, receiver, control)
}
/// Accepts QUIC connections on `ep` until `accept` yields `None`.
///
/// Each connection is upgraded to a snap-tun session via `srv`, and every
/// packet received on that session is echoed back on the same session.
///
/// Per-session tasks live in a local `JoinSet`, so they are aborted when this
/// future is dropped.
///
/// # Panics
/// Panics if a QUIC handshake or snap-tun session setup fails. An echo task
/// panics inside its own spawned task if receiving or sending fails.
async fn run_server(ep: Endpoint, srv: snap_tun::server_deprecated::Server<DummyToken>) {
    let mut sessions = JoinSet::<()>::new();
    while let Some(incoming) = ep.accept().await {
        let connection = incoming.await.expect("no fail");
        let (tx, rx, ctrl) = srv.accept_with_timeout(connection).await.expect("no fail");

        // Log how the session's control stream ended.
        sessions.spawn(async move {
            if let Err(e) = ctrl.await {
                tracing::warn!("Session control stream closed with error: {}", e);
            } else {
                tracing::info!("Session control stream closed gracefully");
            }
        });

        // Echo every received packet back to the peer.
        sessions.spawn(async move {
            loop {
                let packet = rx.receive().await.expect("no fail");
                tx.send_wait(packet).await.expect("no fail");
            }
        });
    }
}
fn quic_endpoint_pair() -> (quinn::Endpoint, quinn::Endpoint) {
let (_cert, config) = scion_sdk_utils::test::generate_cert(
[42u8; 32],
vec!["localhost".into()],
vec![b"snaptun".to_vec()],
);
let sock_addr = "127.0.0.1:0".parse().expect("no fail");
let server_ep = quinn::Endpoint::server(config, sock_addr).expect("no fail");
let mut client_ep = quinn::Endpoint::client(sock_addr).expect("no fail");
client_ep.set_default_client_config(client_config());
(client_ep, server_ep)
}
/// Builds the QUIC client configuration used by the tests.
///
/// The config trusts only the certificate that `generate_cert` produces for
/// the fixed seed, `localhost` and `snaptun` parameters (the same ones
/// `quic_endpoint_pair` uses). It also sets the `snaptun` ALPN protocol and
/// disables the QUIC idle timeout.
fn client_config() -> quinn::ClientConfig {
    let (trusted_cert, _server_config) = scion_sdk_utils::test::generate_cert(
        [42u8; 32],
        vec!["localhost".into()],
        vec![b"snaptun".to_vec()],
    );
    let mut root_store = rustls::RootCertStore::empty();
    root_store.add(trusted_cert).unwrap();

    let mut tls = ClientConfig::builder()
        .with_root_certificates(root_store)
        .with_no_client_auth();
    tls.alpn_protocols = vec![b"snaptun".into()];
    let quic_crypto = QuicClientConfig::try_from(tls).unwrap();

    // No idle timeout, presumably so quiet sessions (e.g. while a test
    // sleeps) are never torn down for inactivity.
    let mut transport = TransportConfig::default();
    transport.max_idle_timeout(None);

    let mut config = quinn::ClientConfig::new(Arc::new(quic_crypto));
    config.transport_config(Arc::new(transport));
    config
}
/// Builds the payload for the `idx`-th (0-based) of `total` test packets.
///
/// The label is 1-based, e.g. `gen_packet(0, 3)` yields `"Packet 1/3"`.
///
/// The 1-based index is computed in `u32`. For `idx == u16::MAX`, a `u16`
/// addition would panic in debug builds and wrap to 0 in release builds.
fn gen_packet(idx: u16, total: u16) -> Bytes {
    Bytes::from(format!("Packet {}/{}", u32::from(idx) + 1, total))
}
/// Minimal token type for these tests; it carries only an expiry timestamp.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct DummyToken {
    /// Expiry, in whole seconds since the Unix epoch.
    pub exp: u64,
}

impl Token for DummyToken {
    /// Every `DummyToken` reports the same fixed id.
    fn id(&self) -> String {
        String::from("dummy_token")
    }

    /// Converts `exp` (seconds since the epoch) into a `SystemTime`.
    fn exp_time(&self) -> SystemTime {
        let since_epoch = Duration::from_secs(self.exp);
        UNIX_EPOCH + since_epoch
    }

    /// Only the `exp` claim is required.
    fn required_claims() -> Vec<&'static str> {
        ["exp"].to_vec()
    }
}
/// The only token string that `MagicAuthorizer` accepts.
const MAGIC_TOKEN: &str = "ANAPAYA";

/// Test token validator.
///
/// It accepts exactly `MAGIC_TOKEN` and issues a `DummyToken` that expires
/// `token_validity` seconds after the validation time.
struct MagicAuthorizer {
    /// Lifetime, in seconds, of each issued token.
    token_validity: u64,
}

impl MagicAuthorizer {
    /// Token lifetime, in seconds, used by `MagicAuthorizer::default()`.
    const DEFAULT_TOKEN_VALIDITY_SECS: u64 = 60;

    /// Creates an authorizer whose tokens expire `token_validity` seconds
    /// after validation.
    pub fn new(token_validity: u64) -> Self {
        Self { token_validity }
    }
}

impl Default for MagicAuthorizer {
    fn default() -> Self {
        Self::new(Self::DEFAULT_TOKEN_VALIDITY_SECS)
    }
}

impl TokenValidator<DummyToken> for MagicAuthorizer {
    /// Accepts only `MAGIC_TOKEN`.
    ///
    /// The issued token's expiry is `now` (in whole seconds since the epoch)
    /// plus `token_validity`. Any other input is rejected with
    /// `TokenExpired(UNIX_EPOCH)`. Panics if `now` is before the Unix epoch.
    fn validate(
        &self,
        now: std::time::SystemTime,
        token: &str,
    ) -> Result<DummyToken, TokenValidatorError> {
        if token != MAGIC_TOKEN {
            // NOTE(review): a wrong token is reported as "expired at the
            // epoch" rather than as invalid; confirm that is intended.
            return Err(TokenValidatorError::TokenExpired(UNIX_EPOCH));
        }
        let now_secs = now.duration_since(UNIX_EPOCH).unwrap().as_secs();
        Ok(DummyToken {
            exp: now_secs + self.token_validity,
        })
    }
}