1use std::{
17 borrow::Cow,
18 net::SocketAddr,
19 ops::Deref,
20 sync::{Arc, RwLock},
21 time::SystemTime,
22};
23
24use bytes::Bytes;
25use prost::Message;
26use quinn::{ConnectionError, RecvStream, SendStream};
27use scion_sdk_reqwest_connect_rpc::token_source::{self, TokenSource};
28use scion_sdk_utils::backoff::ExponentialBackoff;
29use tokio::{select, task::JoinHandle};
30
31use crate::requests::{
32 AddrError, SocketAddrAssignmentRequest, SocketAddrAssignmentResponse, TokenUpdateResponse,
33 system_time_from_unix_epoch_secs,
34};
35
/// Size in bytes of the buffer used to receive a control-plane response.
///
/// The complete response (status line, headers and body) has to fit into a
/// buffer of this size; a response that fills it is rejected as too large.
pub const MAX_CTRL_MESSAGE_SIZE: usize = 4 * 1024;
38
39pub struct ClientBuilder {
41 token_source: Arc<dyn TokenSource>,
42}
43
44impl ClientBuilder {
45 pub fn new(token_source: Arc<dyn TokenSource>) -> Self {
47 ClientBuilder { token_source }
48 }
49
50 pub async fn connect(
52 self,
53 conn: quinn::Connection,
54 ) -> Result<(Sender, Receiver, Control), SnapTunError> {
55 let conn_state = SharedConnState::new(ConnState::new());
56 let mut ctrl = Control {
57 conn: conn.clone(),
58 state: conn_state.clone(),
59 token_renewal_task: None,
60 };
61
62 let mut token_watch = self.token_source.watch();
63
64 let mut initial_token = match token_watch.borrow_and_update().as_ref() {
66 Some(Ok(token)) => Some(token.clone()),
67 Some(Err(e)) => return Err(SnapTunError::InitialTokenError(e.to_string())),
68 None => None,
69 };
70
71 if initial_token.is_none() {
73 token_watch
74 .changed()
75 .await
76 .map_err(|e| SnapTunError::InitialTokenError(e.to_string()))?;
77
78 initial_token = match token_watch.borrow().as_ref() {
79 Some(Ok(token)) => Some(token.clone()),
80 Some(Err(e)) => return Err(SnapTunError::InitialTokenError(e.to_string())),
81 None => None,
82 };
83 }
84
85 let initial_token = initial_token.ok_or_else(|| {
86 SnapTunError::InitialTokenError("failed to obtain initial token".into())
87 })?;
88
89 ctrl.state.write().unwrap().snap_token = initial_token;
90 ctrl.update_token().await?;
91 ctrl.request_socket_addr().await?;
92
93 tracing::trace!("Starting token update task");
96 ctrl.session_token_update_task(token_watch);
97
98 Ok((Sender::new(conn.clone()), Receiver { conn }, ctrl))
99 }
100}
101
102pub struct Control {
104 conn: quinn::Connection,
105 state: SharedConnState,
106 token_renewal_task: Option<JoinHandle<Result<(), RenewTaskError>>>,
107}
108
109impl Control {
110 pub fn assigned_sock_addr(&self) -> Option<SocketAddr> {
116 self.state.read().expect("no fail").assigned_sock_addr
117 }
118
119 pub fn token_expiry(&self) -> SystemTime {
121 self.state.read().expect("no fail").token_expiry
122 }
123
124 pub fn snap_token(&self) -> String {
126 self.state.read().expect("no fail").snap_token.clone()
127 }
128
129 async fn request_socket_addr(&mut self) -> Result<(), ControlError> {
131 tracing::debug!("Requesting socket address assignment");
132 let (mut snd, mut rcv) = self.conn.open_bi().await?;
133
134 let request = SocketAddrAssignmentRequest {};
135
136 let body = request.encode_to_vec();
137 let token = self.state.read().expect("no fail").snap_token.clone();
138 send_control_request(
139 &mut snd,
140 crate::PATH_SOCK_ADDR_ASSIGNMENT,
141 body.as_ref(),
142 &token,
143 )
144 .await?;
145
146 let mut resp_buf = [0u8; MAX_CTRL_MESSAGE_SIZE];
148 let response =
149 recv_response::<SocketAddrAssignmentResponse>(&mut resp_buf[..], &mut rcv).await?;
150
151 let sock_addr = response
152 .socket_addr()
153 .map_err(|e| ControlError::AddressAssignmentFailed(AddrAssignError::InvalidAddr(e)))?;
154
155 let mut sstate = self.state.0.write().expect("no fail");
156 sstate.assigned_sock_addr = Some(sock_addr);
157
158 Ok(())
159 }
160
161 pub async fn update_token(&mut self) -> Result<(), ControlError> {
163 let token = self.state.read().unwrap().snap_token.clone();
164 self.set_token_expiry(update_token(&self.conn.clone(), &token).await?);
165 Ok(())
166 }
167
168 fn session_token_update_task(&mut self, mut token_watch: token_source::TokenSourceWatch) {
170 let conn = self.conn.clone();
171 let conn_state = self.state.clone();
172
173 self.token_renewal_task = Some(tokio::spawn(async move {
174 loop {
175 let expiry = conn_state.read().expect("no fail").token_expiry;
176 let now = SystemTime::now();
177 let dur_until_expiry = expiry
178 .duration_since(now)
179 .unwrap_or_else(|_| std::time::Duration::from_secs(0));
180
181 let expiry_timeout = tokio::time::Instant::now() + dur_until_expiry;
182
183 select! {
184 _ = token_watch.changed() => {}
186 _ = tokio::time::sleep_until(expiry_timeout) => {
188 tracing::error!("SNAP token has expired but no new token was received from the token source");
189 return Err(RenewTaskError::TokenExpired);
190 },
191 }
192
193 let new_token = token_watch
196 .borrow_and_update()
197 .as_ref()
198 .ok_or_else(|| {
199 RenewTaskError::TokenSourceError(
200 "token source watch channel has no value".into(),
201 )
202 })?
203 .as_ref()
204 .map_err(|e| RenewTaskError::TokenSourceError(e.to_string().into()))?
205 .clone();
206
207 let mut attempt = 0;
209 const MAX_RETRIES: u32 = 5;
211 const BACKOFF: ExponentialBackoff = ExponentialBackoff::new(3.0, 30.0, 2.0, 1.0);
213
214 tracing::info!("Updating SNAP token on server");
215 loop {
219 match update_token(&conn, &new_token).await {
220 Ok(new_expiry) => {
221 tracing::info!("Successfully updated SNAP token on server");
222 {
224 let mut conn_state = conn_state.write().unwrap();
225 conn_state.token_expiry = new_expiry;
226 conn_state.snap_token = new_token.clone();
227 }
228 break;
229 }
230 Err(err) if attempt > MAX_RETRIES => {
231 attempt += 1;
232 tracing::error!(
233 %attempt,
234 %err,
235 "Failed to update SNAP token on server, max retries reached",
236 );
237
238 return Err(RenewTaskError::MaxRetriesReached);
239 }
240 Err(err) => {
241 attempt += 1;
242
243 let delay = BACKOFF.duration(attempt);
244 let next_try = delay.as_secs();
245 tracing::warn!(
246 %attempt,
247 %err,
248 %next_try,
249 "Failed to update SNAP token on server",
250 );
251
252 if expiry_timeout <= tokio::time::Instant::now() + delay {
253 tracing::error!(
254 "SNAP token has expired before it could be renewed"
255 );
256 return Err(RenewTaskError::TokenExpired);
257 }
258
259 tokio::time::sleep(delay).await;
260 }
261 }
262 }
263 }
264 }));
265 }
266
267 fn set_token_expiry(&mut self, expiry: SystemTime) {
268 self.state.write().expect("no fail").token_expiry = expiry;
269 }
270
271 pub async fn closed(&self) -> ConnectionError {
273 self.conn.closed().await
274 }
275
276 pub fn inner_conn(&self) -> quinn::Connection {
278 self.conn.clone()
279 }
280
281 pub fn debug_path_stats(&self) -> impl std::fmt::Debug + 'static + use<> {
286 self.conn.stats().path
287 }
288}
289
290#[derive(Debug, thiserror::Error)]
292pub enum RenewTaskError {
293 #[error("token expired")]
295 TokenExpired,
296 #[error("maximum number of retries reached")]
298 MaxRetriesReached,
299 #[error("token source failed: {0}")]
301 TokenSourceError(#[from] token_source::TokenSourceError),
302}
303
304pub async fn update_token(
309 conn: &quinn::Connection,
310 token: &str,
311) -> Result<SystemTime, ControlError> {
312 let (mut snd, mut rcv) = conn.open_bi().await?;
313
314 let body = vec![];
315 send_control_request(&mut snd, crate::PATH_UPDATE_TOKEN, &body, token).await?;
316 let mut resp_buf = [0u8; MAX_CTRL_MESSAGE_SIZE];
317 let response: TokenUpdateResponse = recv_response(&mut resp_buf[..], &mut rcv).await?;
318
319 Ok(system_time_from_unix_epoch_secs(response.valid_until))
320}
321
322impl Drop for Control {
323 fn drop(&mut self) {
324 if let Some(task) = self.token_renewal_task.take() {
325 task.abort();
327 }
328 }
329}
330
/// Mutable per-session state, shared between [`Control`] and the token
/// renewal task.
#[derive(Debug, Clone)]
struct ConnState {
    /// SNAP token stored for the session (empty until one is set).
    snap_token: String,
    /// Expiry of `snap_token`; `UNIX_EPOCH` until a token update succeeded.
    token_expiry: SystemTime,
    /// Socket address assigned by the server, once known.
    assigned_sock_addr: Option<SocketAddr>,
}

impl ConnState {
    /// Creates an empty state: no token, epoch expiry, no assigned address.
    fn new() -> Self {
        let snap_token = String::new();
        let token_expiry = SystemTime::UNIX_EPOCH;
        let assigned_sock_addr = None;
        ConnState {
            snap_token,
            token_expiry,
            assigned_sock_addr,
        }
    }
}
350
351#[derive(Debug, Clone)]
352struct SharedConnState(Arc<RwLock<ConnState>>);
353
354impl SharedConnState {
355 fn new(conn_state: ConnState) -> Self {
356 Self(Arc::new(RwLock::new(conn_state)))
357 }
358}
359
360impl Deref for SharedConnState {
361 type Target = Arc<RwLock<ConnState>>;
362
363 fn deref(&self) -> &Self::Target {
364 &self.0
365 }
366}
367
368#[derive(Debug, Clone)]
370pub struct Sender {
371 conn: quinn::Connection,
372}
373
374impl Sender {
375 pub fn new(conn: quinn::Connection) -> Self {
377 Self { conn }
378 }
379
380 pub fn send_datagram(&self, data: Bytes) -> Result<(), quinn::SendDatagramError> {
382 self.conn.send_datagram(data)
383 }
384
385 pub async fn send_datagram_wait(&self, data: Bytes) -> Result<(), quinn::SendDatagramError> {
387 self.conn.send_datagram_wait(data).await
388 }
389}
390
391#[derive(Debug, Clone)]
393pub struct Receiver {
394 conn: quinn::Connection,
395}
396
397impl Receiver {
398 pub async fn read_datagram(&self) -> Result<Bytes, quinn::ConnectionError> {
400 self.conn.read_datagram().await
401 }
402}
403
404#[derive(Debug, thiserror::Error)]
406pub enum ParseResponseError {
407 #[error("parsing HTTP envelope failed: {0}")]
409 HTTParseError(#[from] httparse::Error),
410 #[error("read error: {0}")]
412 ReadError(#[from] quinn::ReadError),
413 #[error("parsing control message failed: {0}")]
415 ParseError(#[from] prost::DecodeError),
416 #[error("received bad response: {0}")]
418 ResponseError(Cow<'static, str>),
419}
420
421async fn recv_response<M: prost::Message + Default>(
422 buf: &mut [u8],
423 rcv: &mut RecvStream,
424) -> Result<M, ParseResponseError> {
425 let mut cursor = 0;
426 let mut body_offset = 0;
427 let mut code = 0;
428
429 while let Some(n) = rcv.read(&mut buf[cursor..]).await? {
431 cursor += n;
432
433 let mut headers = [httparse::EMPTY_HEADER; 16];
434 let mut resp = httparse::Response::new(&mut headers);
435
436 match resp.parse(&buf[..cursor])? {
437 httparse::Status::Partial => {}
438 httparse::Status::Complete(n) => {
439 body_offset = n;
440 code = resp.code.unwrap_or(0);
441 break;
442 }
443 };
444
445 if cursor >= buf.len() {
447 return Err(ParseResponseError::ResponseError(
448 "response too large".into(),
449 ));
450 }
451 }
452
453 while let Some(n) = rcv.read(&mut buf[cursor..]).await? {
455 cursor += n;
456 if cursor >= buf.len() {
457 return Err(ParseResponseError::ResponseError(
458 "response too large".into(),
459 ));
460 }
461 }
462
463 if code != 200 {
465 let msg = String::from_utf8_lossy(&buf[body_offset..cursor]).to_string();
466 return Err(ParseResponseError::ResponseError(msg.into()));
467 }
468
469 let m = M::decode(&buf[body_offset..cursor])?;
471
472 Ok(m)
473}
474
475#[derive(Debug, thiserror::Error)]
477pub enum SendControlRequestError {
478 #[error("i/o error: {0}")]
480 IoError(#[from] std::io::Error),
481 #[error("stream closed: {0}")]
483 ClosedStream(#[from] quinn::ClosedStream),
484}
485
486async fn send_control_request(
488 snd: &mut SendStream,
489 method: &str,
490 body: &[u8],
491 token: &str,
492) -> Result<(), SendControlRequestError> {
493 write_all(
494 snd,
495 format!(
496 "POST {method} HTTP/1.1\r\n\
497content-type: application/proto\r\n\
498connect-protocol-version: 1\r\n\
499content-encoding: identity\r\n\
500accept-encoding: identity\r\n\
501content-length: {}\r\n\
502Authorization: Bearer {token}\r\n\r\n",
503 body.len()
504 )
505 .as_bytes(),
506 )
507 .await?;
508 write_all(snd, body).await?;
509 snd.finish()?;
510 Ok(())
511}
512
513async fn write_all(stream: &mut SendStream, data: &[u8]) -> std::io::Result<()> {
515 let mut cursor = 0;
516 while cursor < data.len() {
517 cursor += stream.write(&data[cursor..]).await?;
518 }
519 Ok(())
520}
521
522#[derive(Debug, thiserror::Error)]
524pub enum SnapTunError {
525 #[error("initial token error: {0}")]
527 InitialTokenError(String),
528 #[error("control error: {0}")]
530 ControlError(#[from] ControlError),
531}
532
533#[derive(Debug, thiserror::Error)]
535pub enum ControlError {
536 #[error("quinn connection error: {0}")]
538 ConnectionError(#[from] quinn::ConnectionError),
539 #[error("address assignment failed: {0}")]
541 AddressAssignmentFailed(#[from] AddrAssignError),
542 #[error("parse control request response: {0}")]
544 ParseResponse(#[from] ParseResponseError),
545 #[error("send control request error: {0}")]
547 SendRequestError(#[from] SendControlRequestError),
548}
549
550#[derive(Debug, thiserror::Error)]
552pub enum AddrAssignError {
553 #[error("invalid addr: {0}")]
555 InvalidAddr(#[from] AddrError),
556 #[error("no address assigned")]
558 NoAddressAssigned,
559}