use std::{
net::IpAddr,
ops::Deref,
pin::Pin,
sync::{Arc, RwLock},
time::{Duration, SystemTime},
};
use bytes::Bytes;
use prost::Message;
use quinn::{RecvStream, SendStream};
use scion_proto::address::EndhostAddr;
use tokio::{sync::watch, task::JoinHandle};
use tracing::debug;
use crate::requests::{
AddrError, AddressAssignRequest, AddressAssignResponse, AddressRange, SessionRenewalResponse,
system_time_from_unix_epoch_secs,
};
/// Size in bytes of the buffer used to receive a control-plane HTTP response.
pub const CTRL_RESPONSE_BUF_SIZE: usize = 4096;
/// A 5-minute renewal wait threshold, usable as the `renew_wait_threshold` argument of
/// `AutoSessionRenewal::new`.
pub const DEFAULT_RENEWAL_WAIT_THRESHOLD: Duration = Duration::from_secs(5 * 60);
/// Boxed, thread-safe error returned by a [`TokenRenewFn`].
pub type TokenRenewError = Box<dyn std::error::Error + Send + Sync>;
/// Asynchronous callback producing a fresh session token.
///
/// The auto-renewal task invokes it before every session renewal attempt.
pub type TokenRenewFn = Box<
    dyn Fn() -> Pin<Box<dyn Future<Output = Result<String, TokenRenewError>> + Send>>
        + Send
        + Sync,
>;
/// Settings for renewing the snaptun session automatically in the background.
pub struct AutoSessionRenewal {
    /// Callback producing a fresh token before each renewal attempt.
    token_renewer: TokenRenewFn,
    /// When less session lifetime than this remains, the renewal task does not sleep first.
    renew_wait_threshold: Duration,
}

impl AutoSessionRenewal {
    /// Bundles the wait threshold and the token-producing callback into one configuration.
    pub fn new(renew_wait_threshold: Duration, token_renewer: TokenRenewFn) -> Self {
        Self {
            renew_wait_threshold,
            token_renewer,
        }
    }
}
/// Builder that establishes a snaptun session on an existing QUIC connection.
pub struct ClientBuilder {
    /// Addresses requested from the server while connecting; empty by default.
    desired_addresses: Vec<EndhostAddr>,
    /// Token used to authenticate the first session renewal.
    initial_session_token: String,
    /// Background renewal settings; `None` disables automatic renewal.
    auto_session_renewal: Option<AutoSessionRenewal>,
}

impl ClientBuilder {
    /// Creates a builder that authenticates with `initial_session_token`.
    ///
    /// By default, no specific addresses are requested and the session is not renewed
    /// automatically.
    pub fn new<S: AsRef<str>>(initial_session_token: S) -> Self {
        ClientBuilder {
            desired_addresses: Vec::new(),
            initial_session_token: initial_session_token.as_ref().into(),
            auto_session_renewal: None,
        }
    }

    /// Sets the addresses to request from the server in [`ClientBuilder::connect`].
    pub fn with_desired_addresses(mut self, desired_addresses: Vec<EndhostAddr>) -> Self {
        self.desired_addresses = desired_addresses;
        self
    }

    /// Enables automatic background session renewal with the given settings.
    pub fn with_auto_session_renewal(mut self, session_renewal: AutoSessionRenewal) -> Self {
        self.auto_session_renewal = Some(session_renewal);
        self
    }

    /// Establishes the snaptun session on `conn`.
    ///
    /// Renews the session with the initial token and requests address assignment. If
    /// configured, it also spawns the background renewal task, which uses `tokio::spawn`
    /// and therefore needs a Tokio runtime.
    ///
    /// # Errors
    /// Returns [`SnapTunError::ControlError`] if the session renewal or the address
    /// assignment fails.
    pub async fn connect(
        self,
        conn: quinn::Connection,
    ) -> Result<(Sender, Receiver, Control), SnapTunError> {
        let (expiry_sender, expiry_receiver) = watch::channel(());
        // The shared state takes ownership of the only sender; no clone is needed.
        let conn_state = SharedConnState::new(ConnState::new(expiry_sender));
        let mut ctrl = Control {
            conn: conn.clone(),
            state: conn_state,
            session_renewal_task: None,
        };
        ctrl.state.write().expect("no fail").session_token = self.initial_session_token;
        ctrl.renew_session().await?;
        ctrl.request_address(self.desired_addresses).await?;
        if let Some(auto_session_renewal) = self.auto_session_renewal {
            ctrl.start_auto_session_renewal(auto_session_renewal, expiry_receiver);
        }
        Ok((Sender::new(conn.clone()), Receiver { conn }, ctrl))
    }
}
/// Handle for control-plane operations on an established snaptun session.
///
/// Dropping the handle aborts the background session renewal task, if one was started.
pub struct Control {
    /// QUIC connection that carries the control requests.
    conn: quinn::Connection,
    /// Session state, shared with the renewal task.
    state: SharedConnState,
    /// Handle of the auto-renewal task; `None` unless auto-renewal was enabled.
    session_renewal_task: Option<JoinHandle<Result<(), RenewTaskError>>>,
}
impl Control {
pub fn assigned_addresses(&self) -> Vec<EndhostAddr> {
self.state
.read()
.expect("no fail")
.assigned_addresses
.clone()
}
pub fn session_expiry(&self) -> SystemTime {
self.state.read().expect("no fail").session_expiry
}
async fn request_address(
&mut self,
desired_addresses: Vec<EndhostAddr>,
) -> Result<(), ControlError> {
debug!(?desired_addresses, "Requesting address assignment");
let (mut snd, mut rcv) = self.conn.open_bi().await?;
let request = AddressAssignRequest {
requested_addresses: desired_addresses
.into_iter()
.map(|addr| {
let (version, prefix_length, octets) = match addr.local_address() {
IpAddr::V4(a) => (4, 32, a.octets().to_vec()),
IpAddr::V6(a) => (6, 128, a.octets().to_vec()),
};
AddressRange {
isd_as: addr.isd_asn().into(),
ip_version: version as u32,
prefix_length: prefix_length as u32,
address: octets,
}
})
.collect::<Vec<_>>(),
};
let body = request.encode_to_vec();
let token = self.state.read().expect("no fail").session_token.clone();
send_control_request(&mut snd, crate::PATH_ADDR_ASSIGNMENT, body.as_ref(), &token).await?;
let mut resp_buf = [0u8; CTRL_RESPONSE_BUF_SIZE];
let response: AddressAssignResponse =
parse_http_response(&mut resp_buf[..], &mut rcv).await?;
if response.assigned_addresses.is_empty() {
return Err(ControlError::AddressAssignmentFailed(
AddrAssignError::NoAddressAssigned,
));
}
let assigned_addresses = response
.assigned_addresses
.iter()
.map(|address_range| {
TryInto::<EndhostAddr>::try_into(address_range).map_err(|e| {
ControlError::AddressAssignmentFailed(AddrAssignError::InvalidAddr(e))
})
})
.collect::<Result<Vec<_>, _>>()?;
debug!(?assigned_addresses, "Got address assignment");
self.state.write().expect("no fail").assigned_addresses = assigned_addresses;
Ok(())
}
pub async fn renew_session(&mut self) -> Result<(), ControlError> {
let token = self.state.read().expect("no fail").session_token.clone();
self.set_session_expiry(renew_session(&self.conn.clone(), &token).await?);
Ok(())
}
fn start_auto_session_renewal(
&mut self,
config: AutoSessionRenewal,
mut expiry_notifier: watch::Receiver<()>,
) {
let conn = self.conn.clone();
let conn_state = self.state.clone();
self.session_renewal_task = Some(tokio::spawn(async move {
const MAX_RETRIES: u32 = 5;
const BASE_RETRY_DELAY_SECS: u64 = 3;
const SLEEP_FRACTION: f32 = 0.75;
let mut retries: u32 = 0;
loop {
let secs_until_expiry = {
let expiry = conn_state.read().expect("no fail").session_expiry;
match expiry.duration_since(SystemTime::now()) {
Ok(duration) => duration.as_secs(),
Err(_) => {
tracing::error!("Session expiry already passed, stopping auto-renewal");
return Err(RenewTaskError::SessionExpired);
}
}
};
let sleep_secs = if secs_until_expiry < config.renew_wait_threshold.as_secs() {
0
} else {
(secs_until_expiry as f32 * SLEEP_FRACTION) as u64
};
debug!("Next session renewal in {sleep_secs} seconds");
tokio::select! {
_ = expiry_notifier.changed() => continue,
_ = tokio::time::sleep(Duration::from_secs(sleep_secs)) => {
debug!("Renewing token and snaptun session");
let token = match (config.token_renewer)().await {
Ok(token) => token,
Err(err) => {
debug!(%err, "Failed to renew token, retry");
retries += 1;
if retries >= MAX_RETRIES {
return Err(RenewTaskError::MaxRetriesReached);
}
tokio::time::sleep(Duration::from_secs(BASE_RETRY_DELAY_SECS.pow(retries))).await;
continue;
},
};
let new_expiry = match renew_session(&conn, &token).await {
Ok(exp) => exp,
Err(err) => {
debug!(%err, "Failed to renew session, retry");
retries += 1;
if retries >= MAX_RETRIES {
return Err(RenewTaskError::MaxRetriesReached);
}
tokio::time::sleep(Duration::from_secs(BASE_RETRY_DELAY_SECS.pow(retries))).await;
continue;
}
};
debug!(new_expiry=%chrono::DateTime::<chrono::Utc>::from(new_expiry).to_rfc3339(), "auto session renewal successful");
conn_state.write().expect("no fail").session_expiry = new_expiry;
retries = 0;
}
}
}
}));
}
fn set_session_expiry(&mut self, expiry: SystemTime) {
self.state.write().expect("no fail").session_expiry = expiry;
if self
.state
.read()
.expect("no fail")
.expiry_notifier
.send(())
.is_err()
{
debug!("Failed to notify session expiry update");
}
}
}
/// Reasons why the background session renewal task stopped.
#[derive(Debug, thiserror::Error)]
pub enum RenewTaskError {
    /// The session expiry was already in the past when the task checked it.
    #[error("session expired")]
    SessionExpired,
    /// Token acquisition or session renewal failed too many times in a row.
    #[error("maximum number of retries reached")]
    MaxRetriesReached,
}
/// Renews the snaptun session authenticated by `token` over `conn`.
///
/// Sends an empty-bodied request to the session renewal endpoint, with `token` as the
/// bearer token in the `Authorization` header. Returns the expiry time reported by the
/// server.
///
/// # Errors
/// Returns a [`ControlError`] if opening the stream, sending the request, or reading and
/// decoding the response fails.
pub async fn renew_session(
    conn: &quinn::Connection,
    token: &str,
) -> Result<SystemTime, ControlError> {
    let (mut snd, mut rcv) = conn.open_bi().await?;
    // The renewal request has no payload; a borrowed empty slice avoids allocating one.
    send_control_request(&mut snd, crate::PATH_SESSION_RENEWAL, &[], token).await?;
    let mut resp_buf = [0u8; CTRL_RESPONSE_BUF_SIZE];
    let response: SessionRenewalResponse = parse_http_response(&mut resp_buf[..], &mut rcv).await?;
    Ok(system_time_from_unix_epoch_secs(response.valid_until))
}
impl Drop for Control {
    /// Aborts the background session renewal task, if one was started.
    fn drop(&mut self) {
        let Some(renewal_task) = self.session_renewal_task.take() else {
            return;
        };
        renewal_task.abort();
    }
}
/// Mutable session state shared between [`Control`] and its auto-renewal task.
#[derive(Debug, Clone)]
struct ConnState {
    /// Bearer token sent with control requests.
    session_token: String,
    /// When the current session expires; `UNIX_EPOCH` until the first renewal.
    session_expiry: SystemTime,
    /// Addresses granted by the server.
    assigned_addresses: Vec<EndhostAddr>,
    /// Used to signal the auto-renewal task that `session_expiry` was updated.
    expiry_notifier: watch::Sender<()>,
}

impl ConnState {
    /// Creates an empty state and reports expiry updates through `expiry_notifier`.
    ///
    /// The empty state has no token, no addresses, and its expiry set to `UNIX_EPOCH`.
    fn new(expiry_notifier: watch::Sender<()>) -> Self {
        let session_expiry = SystemTime::UNIX_EPOCH;
        ConnState {
            expiry_notifier,
            assigned_addresses: Vec::default(),
            session_expiry,
            session_token: String::default(),
        }
    }
}
/// Cheaply clonable, lock-protected handle to a [`ConnState`].
#[derive(Debug, Clone)]
struct SharedConnState(Arc<RwLock<ConnState>>);

impl SharedConnState {
    /// Wraps `conn_state` for shared access behind a read-write lock.
    fn new(conn_state: ConnState) -> Self {
        let lock = RwLock::new(conn_state);
        SharedConnState(Arc::new(lock))
    }
}

impl Deref for SharedConnState {
    type Target = Arc<RwLock<ConnState>>;

    /// Exposes the inner lock so callers can use `read()` and `write()` directly.
    fn deref(&self) -> &Self::Target {
        let Self(inner) = self;
        inner
    }
}
/// Sending half of a snaptun session: transmits packets as QUIC datagrams.
pub struct Sender {
    conn: quinn::Connection,
}

impl Sender {
    /// Creates a sender that transmits datagrams on `conn`.
    pub fn new(conn: quinn::Connection) -> Self {
        Sender { conn }
    }

    /// Sends `data` as a datagram without waiting; see [`quinn::Connection::send_datagram`].
    pub fn send_datagram(&self, data: Bytes) -> Result<(), quinn::SendDatagramError> {
        self.conn.send_datagram(data)
    }

    /// Sends `data` as a datagram, waiting for send buffer space if needed; see
    /// [`quinn::Connection::send_datagram_wait`].
    pub async fn send_datagram_wait(&self, data: Bytes) -> Result<(), quinn::SendDatagramError> {
        self.conn.send_datagram_wait(data).await
    }
}
/// Receiving half of a snaptun session: yields packets received as QUIC datagrams.
pub struct Receiver {
    conn: quinn::Connection,
}

impl Receiver {
    /// Waits for the next datagram on the connection; see [`quinn::Connection::read_datagram`].
    pub async fn read_datagram(&self) -> Result<Bytes, quinn::ConnectionError> {
        self.conn.read_datagram().await
    }
}
#[derive(Debug, thiserror::Error)]
pub enum ParseResponseError {
#[error("parsing HTTP envelope failed: {0}")]
HTTParseError(#[from] httparse::Error),
#[error("read error: {0}")]
ReadError(#[from] quinn::ReadError),
#[error("parsing control message failed: {0}")]
ParseError(#[from] prost::DecodeError),
}
async fn parse_http_response<M: prost::Message + Default>(
buf: &mut [u8],
rcv: &mut RecvStream,
) -> Result<M, ParseResponseError> {
let mut cursor = 0usize;
let mut body_offset = 0usize;
while let Some(n) = rcv.read(&mut buf[cursor..]).await? {
cursor += n;
let mut headers = [httparse::EMPTY_HEADER; 16];
let mut resp = httparse::Response::new(&mut headers);
body_offset = match resp.parse(&buf[..cursor]) {
Ok(httparse::Status::Partial) => continue,
Ok(httparse::Status::Complete(n)) => n,
Err(e) => return Err(ParseResponseError::HTTParseError(e)),
};
}
while let Some(n) = rcv.read(&mut buf[cursor..]).await? {
cursor += n;
}
let m = M::decode(&buf[body_offset..cursor])?;
Ok(m)
}
/// Errors that can occur while writing a control request to a QUIC send stream.
#[derive(Debug, thiserror::Error)]
pub enum SendControlRequestError {
    /// Writing the request bytes failed.
    #[error("i/o error: {0}")]
    IoError(#[from] std::io::Error),
    /// The stream was already closed when finishing it.
    #[error("stream closed: {0}")]
    ClosedStream(#[from] quinn::ClosedStream),
}
/// Writes a control request as one HTTP/1.1 `POST` to `snd`, then finishes the stream.
///
/// - `method` is used as the request path.
/// - `body` is sent as `application/proto` with a matching `content-length`.
/// - `token` is sent as a bearer token in the `Authorization` header.
async fn send_control_request(
    snd: &mut SendStream,
    method: &str,
    body: &[u8],
    token: &str,
) -> Result<(), SendControlRequestError> {
    let content_length = body.len();
    let head = format!(
        "POST {method} HTTP/1.1\r\n\
        content-type: application/proto\r\n\
        connect-protocol-version: 1\r\n\
        content-encoding: identity\r\n\
        accept-encoding: identity\r\n\
        content-length: {content_length}\r\n\
        Authorization: Bearer {token}\r\n\r\n"
    );
    write_all(snd, head.as_bytes()).await?;
    write_all(snd, body).await?;
    // Finishing the stream tells the server the request is complete.
    snd.finish()?;
    Ok(())
}
/// Writes all of `data` to `stream`, issuing as many partial writes as needed.
async fn write_all(stream: &mut SendStream, data: &[u8]) -> std::io::Result<()> {
    let mut offset = 0;
    // Continue while an unwritten, non-empty tail remains.
    while let Some(rest) = data.get(offset..).filter(|rest| !rest.is_empty()) {
        offset += stream.write(rest).await?;
    }
    Ok(())
}
/// Errors returned while establishing a snaptun session.
#[derive(Debug, thiserror::Error)]
pub enum SnapTunError {
    /// A token-producing callback failed.
    #[error("initial token error: {0}")]
    InitialTokenError(#[from] TokenRenewError),
    /// A control-plane operation (session renewal or address assignment) failed.
    #[error("control error: {0}")]
    ControlError(#[from] ControlError),
}
/// Errors raised by control-plane requests.
#[derive(Debug, thiserror::Error)]
pub enum ControlError {
    /// Opening a stream on the QUIC connection failed.
    #[error("quinn connection error: {0}")]
    ConnectionError(#[from] quinn::ConnectionError),
    /// The server's address assignment was empty or invalid.
    #[error("address assignment failed: {0}")]
    AddressAssignmentFailed(#[from] AddrAssignError),
    /// The response could not be read or decoded.
    #[error("parse control request response: {0}")]
    ParseResponse(#[from] ParseResponseError),
    /// The request could not be written.
    #[error("send control request error: {0}")]
    SendRequestError(#[from] SendControlRequestError),
}
/// Reasons why an address assignment response was rejected.
#[derive(Debug, thiserror::Error)]
pub enum AddrAssignError {
    /// An assigned address range could not be converted into an `EndhostAddr`.
    #[error("invalid addr: {0}")]
    InvalidAddr(#[from] AddrError),
    /// The server assigned no addresses at all.
    #[error("no address assigned")]
    NoAddressAssigned,
}