use std::{
borrow::Cow,
net::SocketAddr,
ops::Deref,
sync::{Arc, RwLock},
time::SystemTime,
};
use bytes::Bytes;
use prost::Message;
use quinn::{ConnectionError, RecvStream, SendStream};
use scion_sdk_reqwest_connect_rpc::token_source::{self, TokenSource};
use scion_sdk_utils::backoff::ExponentialBackoff;
use tokio::{select, task::JoinHandle};
use crate::requests::{
AddrError, SocketAddrAssignmentRequest, SocketAddrAssignmentResponse, TokenUpdateResponse,
system_time_from_unix_epoch_secs,
};
/// Size in bytes of the buffer used to receive a control-stream response.
///
/// The buffer holds the HTTP envelope and the body together; a response that
/// fills it completely is rejected as "response too large".
pub const MAX_CTRL_MESSAGE_SIZE: usize = 4096;
/// Builder that sets up the client side of a session on an existing QUIC
/// connection (see [`ClientBuilder::connect`]).
pub struct ClientBuilder {
/// Source of the SNAP tokens sent with control requests.
token_source: Arc<dyn TokenSource>,
}
impl ClientBuilder {
    /// Creates a builder that takes its SNAP tokens from `token_source`.
    pub fn new(token_source: Arc<dyn TokenSource>) -> Self {
        Self { token_source }
    }

    /// Sets up a session on `conn`.
    ///
    /// The steps are, in order:
    /// 1. obtain an initial token from the token source (waiting for at most
    ///    one update if none has been published yet),
    /// 2. send the token to the server via a token update request,
    /// 3. request a socket address assignment,
    /// 4. spawn the background token renewal task.
    ///
    /// # Errors
    ///
    /// Returns [`SnapTunError::InitialTokenError`] if the token source reports
    /// an error, its watch fails, or it still has no token after one update.
    /// Returns [`SnapTunError::ControlError`] if the token update or the
    /// address assignment request fails.
    pub async fn connect(
        self,
        conn: quinn::Connection,
    ) -> Result<(Sender, Receiver, Control), SnapTunError> {
        let mut ctrl = Control {
            conn: conn.clone(),
            state: SharedConnState::new(ConnState::new()),
            token_renewal_task: None,
        };

        let mut token_watch = self.token_source.watch();

        // Take whatever the source has published so far. The watch guard is
        // released at the end of this statement, before any `.await`.
        let published = match token_watch.borrow_and_update().as_ref() {
            Some(Ok(token)) => Some(token.clone()),
            Some(Err(e)) => return Err(SnapTunError::InitialTokenError(e.to_string())),
            None => None,
        };

        let initial_token = match published {
            Some(token) => token,
            None => {
                // Nothing published yet: wait for exactly one update and
                // require that it carries a token.
                if let Err(e) = token_watch.changed().await {
                    return Err(SnapTunError::InitialTokenError(e.to_string()));
                }
                match token_watch.borrow().as_ref() {
                    Some(Ok(token)) => token.clone(),
                    Some(Err(e)) => {
                        return Err(SnapTunError::InitialTokenError(e.to_string()));
                    }
                    None => {
                        return Err(SnapTunError::InitialTokenError(
                            "failed to obtain initial token".into(),
                        ));
                    }
                }
            }
        };

        ctrl.state.write().unwrap().snap_token = initial_token;
        ctrl.update_token().await?;
        ctrl.request_socket_addr().await?;

        tracing::trace!("Starting token update task");
        ctrl.session_token_update_task(token_watch);

        let sender = Sender::new(conn.clone());
        let receiver = Receiver { conn };
        Ok((sender, receiver, ctrl))
    }
}
/// Control handle of an established session.
///
/// Gives access to the session state (token, token expiry, assigned socket
/// address) and owns the background token renewal task, which is aborted when
/// this value is dropped.
pub struct Control {
/// QUIC connection on which control streams are opened.
conn: quinn::Connection,
/// State shared with the token renewal task.
state: SharedConnState,
/// Handle of the spawned token renewal task; `None` until it is started.
token_renewal_task: Option<JoinHandle<Result<(), RenewTaskError>>>,
}
impl Control {
pub fn assigned_sock_addr(&self) -> Option<SocketAddr> {
self.state.read().expect("no fail").assigned_sock_addr
}
pub fn token_expiry(&self) -> SystemTime {
self.state.read().expect("no fail").token_expiry
}
pub fn snap_token(&self) -> String {
self.state.read().expect("no fail").snap_token.clone()
}
async fn request_socket_addr(&mut self) -> Result<(), ControlError> {
tracing::debug!("Requesting socket address assignment");
let (mut snd, mut rcv) = self.conn.open_bi().await?;
let request = SocketAddrAssignmentRequest {};
let body = request.encode_to_vec();
let token = self.state.read().expect("no fail").snap_token.clone();
send_control_request(
&mut snd,
crate::PATH_SOCK_ADDR_ASSIGNMENT,
body.as_ref(),
&token,
)
.await?;
let mut resp_buf = [0u8; MAX_CTRL_MESSAGE_SIZE];
let response =
recv_response::<SocketAddrAssignmentResponse>(&mut resp_buf[..], &mut rcv).await?;
let sock_addr = response
.socket_addr()
.map_err(|e| ControlError::AddressAssignmentFailed(AddrAssignError::InvalidAddr(e)))?;
let mut sstate = self.state.0.write().expect("no fail");
sstate.assigned_sock_addr = Some(sock_addr);
Ok(())
}
pub async fn update_token(&mut self) -> Result<(), ControlError> {
let token = self.state.read().unwrap().snap_token.clone();
self.set_token_expiry(update_token(&self.conn.clone(), &token).await?);
Ok(())
}
fn session_token_update_task(&mut self, mut token_watch: token_source::TokenSourceWatch) {
let conn = self.conn.clone();
let conn_state = self.state.clone();
self.token_renewal_task = Some(tokio::spawn(async move {
loop {
let expiry = conn_state.read().expect("no fail").token_expiry;
let now = SystemTime::now();
let dur_until_expiry = expiry
.duration_since(now)
.unwrap_or_else(|_| std::time::Duration::from_secs(0));
let expiry_timeout = tokio::time::Instant::now() + dur_until_expiry;
select! {
_ = token_watch.changed() => {}
_ = tokio::time::sleep_until(expiry_timeout) => {
tracing::error!("SNAP token has expired but no new token was received from the token source");
return Err(RenewTaskError::TokenExpired);
},
}
let new_token = token_watch
.borrow_and_update()
.as_ref()
.ok_or_else(|| {
RenewTaskError::TokenSourceError(
"token source watch channel has no value".into(),
)
})?
.as_ref()
.map_err(|e| RenewTaskError::TokenSourceError(e.to_string().into()))?
.clone();
let mut attempt = 0;
const MAX_RETRIES: u32 = 5;
const BACKOFF: ExponentialBackoff = ExponentialBackoff::new(3.0, 30.0, 2.0, 1.0);
tracing::info!("Updating SNAP token on server");
loop {
match update_token(&conn, &new_token).await {
Ok(new_expiry) => {
tracing::info!("Successfully updated SNAP token on server");
{
let mut conn_state = conn_state.write().unwrap();
conn_state.token_expiry = new_expiry;
conn_state.snap_token = new_token.clone();
}
break;
}
Err(err) if attempt > MAX_RETRIES => {
attempt += 1;
tracing::error!(
%attempt,
%err,
"Failed to update SNAP token on server, max retries reached",
);
return Err(RenewTaskError::MaxRetriesReached);
}
Err(err) => {
attempt += 1;
let delay = BACKOFF.duration(attempt);
let next_try = delay.as_secs();
tracing::warn!(
%attempt,
%err,
%next_try,
"Failed to update SNAP token on server",
);
if expiry_timeout <= tokio::time::Instant::now() + delay {
tracing::error!(
"SNAP token has expired before it could be renewed"
);
return Err(RenewTaskError::TokenExpired);
}
tokio::time::sleep(delay).await;
}
}
}
}
}));
}
fn set_token_expiry(&mut self, expiry: SystemTime) {
self.state.write().expect("no fail").token_expiry = expiry;
}
pub async fn closed(&self) -> ConnectionError {
self.conn.closed().await
}
pub fn inner_conn(&self) -> quinn::Connection {
self.conn.clone()
}
pub fn debug_path_stats(&self) -> impl std::fmt::Debug + 'static + use<> {
self.conn.stats().path
}
}
/// Reasons for which the background token renewal task ends with an error.
#[derive(Debug, thiserror::Error)]
pub enum RenewTaskError {
/// The current token expired before a new one could be registered with the server.
#[error("token expired")]
TokenExpired,
/// Registering a new token with the server kept failing until the retry limit was hit.
#[error("maximum number of retries reached")]
MaxRetriesReached,
/// The token source reported an error or had no token value available.
#[error("token source failed: {0}")]
TokenSourceError(#[from] token_source::TokenSourceError),
}
/// Sends a token update request carrying `token` on a new bidirectional stream
/// of `conn` and returns the validity end reported by the server.
///
/// The request body is empty; the returned time is derived from the
/// `valid_until` field of the response.
///
/// # Errors
///
/// Fails if the stream cannot be opened, the request cannot be sent, or the
/// response cannot be received or parsed.
pub async fn update_token(
    conn: &quinn::Connection,
    token: &str,
) -> Result<SystemTime, ControlError> {
    let (mut snd, mut rcv) = conn.open_bi().await?;

    // No payload is needed for this request.
    send_control_request(&mut snd, crate::PATH_UPDATE_TOKEN, &[], token).await?;

    let mut resp_buf = [0u8; MAX_CTRL_MESSAGE_SIZE];
    let response = recv_response::<TokenUpdateResponse>(&mut resp_buf[..], &mut rcv).await?;
    let valid_until = system_time_from_unix_epoch_secs(response.valid_until);
    Ok(valid_until)
}
impl Drop for Control {
    /// Aborts the background token renewal task, if one was started.
    fn drop(&mut self) {
        let Some(task) = self.token_renewal_task.take() else {
            return;
        };
        task.abort();
    }
}
/// Per-connection state that `Control` shares with the token renewal task.
#[derive(Debug, Clone)]
struct ConnState {
    /// SNAP token currently recorded for the session.
    snap_token: String,
    /// Expiry time recorded for `snap_token`.
    token_expiry: SystemTime,
    /// Socket address stored by the address assignment request, if any.
    assigned_sock_addr: Option<SocketAddr>,
}

impl ConnState {
    /// Creates the state of a fresh connection: empty token, expiry at the
    /// Unix epoch, and no assigned address.
    fn new() -> Self {
        let snap_token = String::new();
        let token_expiry = SystemTime::UNIX_EPOCH;
        ConnState {
            snap_token,
            token_expiry,
            assigned_sock_addr: None,
        }
    }
}
/// Clonable handle to a [`ConnState`] behind an `Arc<RwLock<_>>`; all clones
/// refer to the same state.
#[derive(Debug, Clone)]
struct SharedConnState(Arc<RwLock<ConnState>>);

impl SharedConnState {
    /// Wraps `conn_state` so that it can be shared between clones of the handle.
    fn new(conn_state: ConnState) -> Self {
        let guarded = RwLock::new(conn_state);
        Self(Arc::new(guarded))
    }
}

impl Deref for SharedConnState {
    type Target = Arc<RwLock<ConnState>>;

    /// Exposes the inner `Arc<RwLock<ConnState>>`, so `read()` and `write()`
    /// can be called on the handle directly.
    fn deref(&self) -> &Self::Target {
        let Self(inner) = self;
        inner
    }
}
/// Sending half of a session: sends datagrams on the underlying QUIC connection.
#[derive(Debug, Clone)]
pub struct Sender {
/// QUIC connection the datagrams are sent on.
conn: quinn::Connection,
}
impl Sender {
/// Creates a sender that sends datagrams on `conn`.
pub fn new(conn: quinn::Connection) -> Self {
Self { conn }
}
/// Sends `data` as a datagram by forwarding to
/// `quinn::Connection::send_datagram`; its result is returned unchanged.
pub fn send_datagram(&self, data: Bytes) -> Result<(), quinn::SendDatagramError> {
self.conn.send_datagram(data)
}
/// Sends `data` as a datagram by forwarding to
/// `quinn::Connection::send_datagram_wait` and awaiting it; its result is
/// returned unchanged.
pub async fn send_datagram_wait(&self, data: Bytes) -> Result<(), quinn::SendDatagramError> {
self.conn.send_datagram_wait(data).await
}
}
/// Receiving half of a session: reads datagrams from the underlying QUIC connection.
#[derive(Debug, Clone)]
pub struct Receiver {
/// QUIC connection the datagrams are read from.
conn: quinn::Connection,
}
impl Receiver {
/// Reads the next datagram by forwarding to
/// `quinn::Connection::read_datagram` and awaiting it; its result is
/// returned unchanged.
pub async fn read_datagram(&self) -> Result<Bytes, quinn::ConnectionError> {
self.conn.read_datagram().await
}
}
/// Errors that can occur while receiving and decoding a control response.
#[derive(Debug, thiserror::Error)]
pub enum ParseResponseError {
/// The HTTP status line or headers of the response could not be parsed.
#[error("parsing HTTP envelope failed: {0}")]
HTTParseError(#[from] httparse::Error),
/// Reading from the QUIC receive stream failed.
#[error("read error: {0}")]
ReadError(#[from] quinn::ReadError),
/// The response body could not be decoded as the expected protobuf message.
#[error("parsing control message failed: {0}")]
ParseError(#[from] prost::DecodeError),
/// The response was too large for the receive buffer, or its status code
/// was not 200 (in which case the payload is the response body as text).
#[error("received bad response: {0}")]
ResponseError(Cow<'static, str>),
}
/// Reads a complete HTTP-framed control response from `rcv` into `buf` and
/// decodes its body as the protobuf message `M`.
///
/// The stream is read until it ends. `buf` must hold the whole response
/// (status line, headers and body); a response that fills `buf` is rejected.
///
/// # Errors
///
/// * [`ParseResponseError::ReadError`] if reading from the stream fails.
/// * [`ParseResponseError::HTTParseError`] if the HTTP envelope is malformed.
/// * [`ParseResponseError::ResponseError`] if the response is too large, or
///   if the status code is not 200 (this includes a stream that ends before
///   the headers are complete); the message is then the lossily decoded body.
/// * [`ParseResponseError::ParseError`] if the body cannot be decoded as `M`.
async fn recv_response<M: prost::Message + Default>(
    buf: &mut [u8],
    rcv: &mut RecvStream,
) -> Result<M, ParseResponseError> {
    let capacity = buf.len();
    let too_large = || ParseResponseError::ResponseError("response too large".into());

    let mut filled = 0;
    // Both stay 0 if the stream ends before the headers are complete.
    let mut status = 0;
    let mut body_start = 0;

    // Phase 1: read until the status line and headers parse completely.
    while let Some(n) = rcv.read(&mut buf[filled..]).await? {
        filled += n;
        let mut headers = [httparse::EMPTY_HEADER; 16];
        let mut resp = httparse::Response::new(&mut headers);
        if let httparse::Status::Complete(header_len) = resp.parse(&buf[..filled])? {
            body_start = header_len;
            status = resp.code.unwrap_or(0);
            break;
        }
        if filled >= capacity {
            return Err(too_large());
        }
    }

    // Phase 2: drain the rest of the stream; everything after the headers is
    // the body.
    while let Some(n) = rcv.read(&mut buf[filled..]).await? {
        filled += n;
        if filled >= capacity {
            return Err(too_large());
        }
    }

    let body = &buf[body_start..filled];
    if status != 200 {
        let msg = String::from_utf8_lossy(body).into_owned();
        return Err(ParseResponseError::ResponseError(msg.into()));
    }
    Ok(M::decode(body)?)
}
/// Errors that can occur while sending a control request.
#[derive(Debug, thiserror::Error)]
pub enum SendControlRequestError {
/// Writing the request to the send stream failed.
#[error("i/o error: {0}")]
IoError(#[from] std::io::Error),
/// The send stream was already closed when trying to finish it.
#[error("stream closed: {0}")]
ClosedStream(#[from] quinn::ClosedStream),
}
/// Writes an HTTP/1.1-style `POST` request to `snd` and finishes the stream.
///
/// * `method` is placed in the request-target position of the request line
///   (callers in this file pass path constants).
/// * `body` is sent verbatim after the headers, with a matching
///   `content-length`.
/// * `token` is sent as a bearer token in the `Authorization` header.
///
/// # Errors
///
/// Fails if writing to the stream fails or the stream is already closed when
/// it is finished.
async fn send_control_request(
    snd: &mut SendStream,
    method: &str,
    body: &[u8],
    token: &str,
) -> Result<(), SendControlRequestError> {
    // Request line and headers, terminated by the empty line.
    let head = format!(
        "POST {method} HTTP/1.1\r\n\
         content-type: application/proto\r\n\
         connect-protocol-version: 1\r\n\
         content-encoding: identity\r\n\
         accept-encoding: identity\r\n\
         content-length: {}\r\n\
         Authorization: Bearer {token}\r\n\r\n",
        body.len()
    );

    write_all(snd, head.as_bytes()).await?;
    write_all(snd, body).await?;

    // Finishing the stream marks the end of the request.
    snd.finish()?;
    Ok(())
}
/// Writes all of `data` to `stream`, repeating partial writes until nothing
/// is left.
///
/// # Errors
///
/// Returns the first write error, converted to `std::io::Error`.
async fn write_all(stream: &mut SendStream, data: &[u8]) -> std::io::Result<()> {
    let mut pending = data;
    while !pending.is_empty() {
        let written = stream.write(pending).await?;
        pending = &pending[written..];
    }
    Ok(())
}
/// Errors returned when establishing a session.
#[derive(Debug, thiserror::Error)]
pub enum SnapTunError {
/// No initial token could be obtained from the token source.
#[error("initial token error: {0}")]
InitialTokenError(String),
/// A control request failed.
#[error("control error: {0}")]
ControlError(#[from] ControlError),
}
/// Errors that can occur while performing a control request.
#[derive(Debug, thiserror::Error)]
pub enum ControlError {
/// The QUIC connection reported an error (e.g. while opening a stream).
#[error("quinn connection error: {0}")]
ConnectionError(#[from] quinn::ConnectionError),
/// The socket address assignment did not yield a usable address.
#[error("address assignment failed: {0}")]
AddressAssignmentFailed(#[from] AddrAssignError),
/// The response to a control request could not be received or parsed.
#[error("parse control request response: {0}")]
ParseResponse(#[from] ParseResponseError),
/// The control request could not be sent.
#[error("send control request error: {0}")]
SendRequestError(#[from] SendControlRequestError),
}
/// Reasons for which a socket address assignment is considered failed.
#[derive(Debug, thiserror::Error)]
pub enum AddrAssignError {
/// The address contained in the response could not be converted to a socket address.
#[error("invalid addr: {0}")]
InvalidAddr(#[from] AddrError),
/// No address was assigned.
/// NOTE(review): not constructed anywhere in the code visible here — confirm usage.
#[error("no address assigned")]
NoAddressAssigned,
}