1use std::{
17 net::IpAddr,
18 ops::Deref,
19 pin::Pin,
20 sync::{Arc, RwLock},
21 time::{Duration, SystemTime},
22};
23
24use bytes::Bytes;
25use prost::Message;
26use quinn::{RecvStream, SendStream};
27use scion_proto::address::EndhostAddr;
28use tokio::{sync::watch, task::JoinHandle};
29use tracing::debug;
30
31use crate::requests::{
32 AddrError, AddressAssignRequest, AddressAssignResponse, AddressRange, SessionRenewalResponse,
33 system_time_from_unix_epoch_secs,
34};
35
/// Size in bytes of the stack buffer used to receive a single control response
/// (HTTP envelope plus protobuf body).
pub const CTRL_RESPONSE_BUF_SIZE: usize = 4 * 1024;
39
/// Default threshold for automatic session renewal: once the session expires in less
/// than this duration, the renewal task renews immediately instead of sleeping first.
pub const DEFAULT_RENEWAL_WAIT_THRESHOLD: Duration = Duration::from_secs(5 * 60);

/// Boxed, thread-safe error returned by a token renewal callback.
pub type TokenRenewError = Box<dyn std::error::Error + Send + Sync>;

/// Callback that asynchronously produces a fresh session token.
///
/// Every call returns a new boxed future resolving to the token string or a
/// [`TokenRenewError`].
pub type TokenRenewFn = Box<
    dyn Fn() -> Pin<Box<dyn Future<Output = Result<String, TokenRenewError>> + Send>>
        + Send
        + Sync,
>;
51
52pub struct AutoSessionRenewal {
54 token_renewer: TokenRenewFn,
55 renew_wait_threshold: Duration,
56}
57
58impl AutoSessionRenewal {
59 pub fn new(renew_wait_threshold: Duration, token_renewer: TokenRenewFn) -> Self {
65 AutoSessionRenewal {
66 token_renewer,
67 renew_wait_threshold,
68 }
69 }
70}
71
72pub struct ClientBuilder {
74 desired_addresses: Vec<EndhostAddr>,
75 initial_session_token: String,
76 auto_session_renewal: Option<AutoSessionRenewal>,
77}
78
79impl ClientBuilder {
80 pub fn new<S: AsRef<str>>(initial_session_token: S) -> Self {
82 ClientBuilder {
83 desired_addresses: Vec::new(),
84 initial_session_token: initial_session_token.as_ref().into(),
85 auto_session_renewal: None,
86 }
87 }
88
89 pub fn with_desired_addresses(mut self, desired_addresses: Vec<EndhostAddr>) -> Self {
92 self.desired_addresses = desired_addresses;
93 self
94 }
95
96 pub fn with_auto_session_renewal(mut self, session_renewal: AutoSessionRenewal) -> Self {
98 self.auto_session_renewal = Some(session_renewal);
99 self
100 }
101
102 pub async fn connect(
104 self,
105 conn: quinn::Connection,
106 ) -> Result<(Sender, Receiver, Control), SnapTunError> {
107 let (expiry_sender, expiry_receiver) = watch::channel(());
108 let conn_state = SharedConnState::new(ConnState::new(expiry_sender.clone()));
109 let mut ctrl = Control {
110 conn: conn.clone(),
111 state: conn_state.clone(),
112 session_renewal_task: None,
113 };
114
115 ctrl.state.write().expect("no fail").session_token = self.initial_session_token;
116 ctrl.renew_session().await?;
117 ctrl.request_address(self.desired_addresses).await?;
118
119 if let Some(auto_session_renewal) = self.auto_session_renewal {
120 ctrl.start_auto_session_renewal(auto_session_renewal, expiry_receiver);
121 }
122
123 Ok((Sender::new(conn.clone()), Receiver { conn }, ctrl))
124 }
125}
126
127pub struct Control {
129 conn: quinn::Connection,
130 state: SharedConnState,
131 session_renewal_task: Option<JoinHandle<Result<(), RenewTaskError>>>,
132}
133
134impl Control {
135 pub fn assigned_addresses(&self) -> Vec<EndhostAddr> {
137 self.state
138 .read()
139 .expect("no fail")
140 .assigned_addresses
141 .clone()
142 }
143
144 pub fn session_expiry(&self) -> SystemTime {
146 self.state.read().expect("no fail").session_expiry
147 }
148
149 async fn request_address(
156 &mut self,
157 desired_addresses: Vec<EndhostAddr>,
158 ) -> Result<(), ControlError> {
159 debug!(?desired_addresses, "Requesting address assignment");
160 let (mut snd, mut rcv) = self.conn.open_bi().await?;
161 let request = AddressAssignRequest {
162 requested_addresses: desired_addresses
163 .into_iter()
164 .map(|addr| {
165 let (version, prefix_length, octets) = match addr.local_address() {
166 IpAddr::V4(a) => (4, 32, a.octets().to_vec()),
167 IpAddr::V6(a) => (6, 128, a.octets().to_vec()),
168 };
169 AddressRange {
170 isd_as: addr.isd_asn().into(),
171 ip_version: version as u32,
172 prefix_length: prefix_length as u32,
173 address: octets,
174 }
175 })
176 .collect::<Vec<_>>(),
177 };
178 let body = request.encode_to_vec();
179 let token = self.state.read().expect("no fail").session_token.clone();
180 send_control_request(&mut snd, crate::PATH_ADDR_ASSIGNMENT, body.as_ref(), &token).await?;
181 let mut resp_buf = [0u8; CTRL_RESPONSE_BUF_SIZE];
182 let response: AddressAssignResponse =
183 parse_http_response(&mut resp_buf[..], &mut rcv).await?;
184
185 if response.assigned_addresses.is_empty() {
186 return Err(ControlError::AddressAssignmentFailed(
187 AddrAssignError::NoAddressAssigned,
188 ));
189 }
190 let assigned_addresses = response
191 .assigned_addresses
192 .iter()
193 .map(|address_range| {
194 TryInto::<EndhostAddr>::try_into(address_range).map_err(|e| {
195 ControlError::AddressAssignmentFailed(AddrAssignError::InvalidAddr(e))
196 })
197 })
198 .collect::<Result<Vec<_>, _>>()?;
199 debug!(?assigned_addresses, "Got address assignment");
200
201 self.state.write().expect("no fail").assigned_addresses = assigned_addresses;
202 Ok(())
203 }
204
205 pub async fn renew_session(&mut self) -> Result<(), ControlError> {
207 let token = self.state.read().expect("no fail").session_token.clone();
208 self.set_session_expiry(renew_session(&self.conn.clone(), &token).await?);
209 Ok(())
210 }
211
212 fn start_auto_session_renewal(
213 &mut self,
214 config: AutoSessionRenewal,
215 mut expiry_notifier: watch::Receiver<()>,
216 ) {
217 let conn = self.conn.clone();
218 let conn_state = self.state.clone();
219
220 self.session_renewal_task = Some(tokio::spawn(async move {
221 const MAX_RETRIES: u32 = 5;
223 const BASE_RETRY_DELAY_SECS: u64 = 3;
225 const SLEEP_FRACTION: f32 = 0.75; let mut retries: u32 = 0;
229 loop {
230 let secs_until_expiry = {
231 let expiry = conn_state.read().expect("no fail").session_expiry;
232 match expiry.duration_since(SystemTime::now()) {
234 Ok(duration) => duration.as_secs(),
235 Err(_) => {
236 tracing::error!("Session expiry already passed, stopping auto-renewal");
239 return Err(RenewTaskError::SessionExpired);
240 }
241 }
242 };
243
244 let sleep_secs = if secs_until_expiry < config.renew_wait_threshold.as_secs() {
246 0
247 } else {
248 (secs_until_expiry as f32 * SLEEP_FRACTION) as u64
249 };
250 debug!("Next session renewal in {sleep_secs} seconds");
251
252 tokio::select! {
253 _ = expiry_notifier.changed() => continue,
254 _ = tokio::time::sleep(Duration::from_secs(sleep_secs)) => {
255 debug!("Renewing token and snaptun session");
256
257 let token = match (config.token_renewer)().await {
259 Ok(token) => token,
260 Err(err) => {
261 debug!(%err, "Failed to renew token, retry");
262 retries += 1;
263 if retries >= MAX_RETRIES {
264 return Err(RenewTaskError::MaxRetriesReached);
265 }
266 tokio::time::sleep(Duration::from_secs(BASE_RETRY_DELAY_SECS.pow(retries))).await;
267 continue;
268 },
269 };
270
271 let new_expiry = match renew_session(&conn, &token).await {
273 Ok(exp) => exp,
274 Err(err) => {
275 debug!(%err, "Failed to renew session, retry");
276 retries += 1;
277 if retries >= MAX_RETRIES {
278 return Err(RenewTaskError::MaxRetriesReached);
279 }
280 tokio::time::sleep(Duration::from_secs(BASE_RETRY_DELAY_SECS.pow(retries))).await;
281 continue;
282 }
283 };
284
285 debug!(new_expiry=%chrono::DateTime::<chrono::Utc>::from(new_expiry).to_rfc3339(), "auto session renewal successful");
286 conn_state.write().expect("no fail").session_expiry = new_expiry;
287 retries = 0;
288 }
289 }
290 }
291 }));
292 }
293
294 fn set_session_expiry(&mut self, expiry: SystemTime) {
295 self.state.write().expect("no fail").session_expiry = expiry;
296 if self
297 .state
298 .read()
299 .expect("no fail")
300 .expiry_notifier
301 .send(())
302 .is_err()
303 {
304 debug!("Failed to notify session expiry update");
307 }
308 }
309}
310
311#[derive(Debug, thiserror::Error)]
313pub enum RenewTaskError {
314 #[error("session expired")]
316 SessionExpired,
317 #[error("maximum number of retries reached")]
319 MaxRetriesReached,
320}
321
322pub async fn renew_session(
327 conn: &quinn::Connection,
328 token: &str,
329) -> Result<SystemTime, ControlError> {
330 let (mut snd, mut rcv) = conn.open_bi().await?;
331
332 let body = vec![];
333 send_control_request(&mut snd, crate::PATH_SESSION_RENEWAL, &body, token).await?;
334 let mut resp_buf = [0u8; CTRL_RESPONSE_BUF_SIZE];
335 let response: SessionRenewalResponse = parse_http_response(&mut resp_buf[..], &mut rcv).await?;
336
337 Ok(system_time_from_unix_epoch_secs(response.valid_until))
338}
339
340impl Drop for Control {
341 fn drop(&mut self) {
342 if let Some(task) = self.session_renewal_task.take() {
343 task.abort();
345 }
346 }
347}
348
349#[derive(Debug, Clone)]
351struct ConnState {
352 session_token: String,
353 session_expiry: SystemTime,
354 assigned_addresses: Vec<EndhostAddr>,
355 expiry_notifier: watch::Sender<()>,
356}
357
358impl ConnState {
359 fn new(expiry_notifier: watch::Sender<()>) -> Self {
360 Self {
361 session_token: String::new(),
362 session_expiry: SystemTime::UNIX_EPOCH,
363 assigned_addresses: Vec::new(),
364 expiry_notifier,
365 }
366 }
367}
368
369#[derive(Debug, Clone)]
370struct SharedConnState(Arc<RwLock<ConnState>>);
371
372impl SharedConnState {
373 fn new(conn_state: ConnState) -> Self {
374 Self(Arc::new(RwLock::new(conn_state)))
375 }
376}
377
378impl Deref for SharedConnState {
379 type Target = Arc<RwLock<ConnState>>;
380
381 fn deref(&self) -> &Self::Target {
382 &self.0
383 }
384}
385
386pub struct Sender {
388 conn: quinn::Connection,
389}
390
391impl Sender {
392 pub fn new(conn: quinn::Connection) -> Self {
394 Self { conn }
395 }
396
397 pub fn send_datagram(&self, data: Bytes) -> Result<(), quinn::SendDatagramError> {
399 self.conn.send_datagram(data)?;
400 Ok(())
401 }
402
403 pub async fn send_datagram_wait(&self, data: Bytes) -> Result<(), quinn::SendDatagramError> {
405 self.conn.send_datagram_wait(data).await?;
406 Ok(())
407 }
408}
409
410pub struct Receiver {
412 conn: quinn::Connection,
413}
414
415impl Receiver {
416 pub async fn read_datagram(&self) -> Result<Bytes, quinn::ConnectionError> {
418 let packet = self.conn.read_datagram().await?;
419 Ok(packet)
420 }
421}
422
423#[derive(Debug, thiserror::Error)]
425pub enum ParseResponseError {
426 #[error("parsing HTTP envelope failed: {0}")]
428 HTTParseError(#[from] httparse::Error),
429 #[error("read error: {0}")]
431 ReadError(#[from] quinn::ReadError),
432 #[error("parsing control message failed: {0}")]
434 ParseError(#[from] prost::DecodeError),
435}
436
437async fn parse_http_response<M: prost::Message + Default>(
438 buf: &mut [u8],
439 rcv: &mut RecvStream,
440) -> Result<M, ParseResponseError> {
441 let mut cursor = 0usize;
442 let mut body_offset = 0usize;
443 while let Some(n) = rcv.read(&mut buf[cursor..]).await? {
444 cursor += n;
445 let mut headers = [httparse::EMPTY_HEADER; 16];
446 let mut resp = httparse::Response::new(&mut headers);
447 body_offset = match resp.parse(&buf[..cursor]) {
448 Ok(httparse::Status::Partial) => continue,
449 Ok(httparse::Status::Complete(n)) => n,
450 Err(e) => return Err(ParseResponseError::HTTParseError(e)),
451 };
452 }
453 while let Some(n) = rcv.read(&mut buf[cursor..]).await? {
455 cursor += n;
456 }
457 let m = M::decode(&buf[body_offset..cursor])?;
458 Ok(m)
459}
460
461#[derive(Debug, thiserror::Error)]
463pub enum SendControlRequestError {
464 #[error("i/o error: {0}")]
466 IoError(#[from] std::io::Error),
467 #[error("stream closed: {0}")]
469 ClosedStream(#[from] quinn::ClosedStream),
470}
471
472async fn send_control_request(
474 snd: &mut SendStream,
475 method: &str,
476 body: &[u8],
477 token: &str,
478) -> Result<(), SendControlRequestError> {
479 write_all(
480 snd,
481 format!(
482 "POST {method} HTTP/1.1\r\n\
483content-type: application/proto\r\n\
484connect-protocol-version: 1\r\n\
485content-encoding: identity\r\n\
486accept-encoding: identity\r\n\
487content-length: {}\r\n\
488Authorization: Bearer {token}\r\n\r\n",
489 body.len()
490 )
491 .as_bytes(),
492 )
493 .await?;
494 write_all(snd, body).await?;
495 snd.finish()?;
496 Ok(())
497}
498
499async fn write_all(stream: &mut SendStream, data: &[u8]) -> std::io::Result<()> {
501 let mut cursor = 0;
502 while cursor < data.len() {
503 cursor += stream.write(&data[cursor..]).await?;
504 }
505 Ok(())
506}
507
508#[derive(Debug, thiserror::Error)]
510pub enum SnapTunError {
511 #[error("initial token error: {0}")]
513 InitialTokenError(#[from] TokenRenewError),
514 #[error("control error: {0}")]
516 ControlError(#[from] ControlError),
517}
518
519#[derive(Debug, thiserror::Error)]
521pub enum ControlError {
522 #[error("quinn connection error: {0}")]
524 ConnectionError(#[from] quinn::ConnectionError),
525 #[error("address assignment failed: {0}")]
527 AddressAssignmentFailed(#[from] AddrAssignError),
528 #[error("parse control request response: {0}")]
530 ParseResponse(#[from] ParseResponseError),
531 #[error("send control request error: {0}")]
533 SendRequestError(#[from] SendControlRequestError),
534}
535
536#[derive(Debug, thiserror::Error)]
538pub enum AddrAssignError {
539 #[error("invalid addr: {0}")]
541 InvalidAddr(#[from] AddrError),
542 #[error("no address assigned")]
544 NoAddressAssigned,
545}