Skip to main content

snap_tun/
client.rs

1// Copyright 2025 Anapaya Systems
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//   http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//! SNAP tunnel client.
15
16use std::{
17    borrow::Cow,
18    net::SocketAddr,
19    ops::Deref,
20    sync::{Arc, RwLock},
21    time::SystemTime,
22};
23
24use bytes::Bytes;
25use prost::Message;
26use quinn::{ConnectionError, RecvStream, SendStream};
27use scion_sdk_reqwest_connect_rpc::token_source::{self, TokenSource};
28use scion_sdk_utils::backoff::ExponentialBackoff;
29use tokio::{select, task::JoinHandle};
30
31use crate::requests::{
32    AddrError, SocketAddrAssignmentRequest, SocketAddrAssignmentResponse, TokenUpdateResponse,
33    system_time_from_unix_epoch_secs,
34};
35
/// Maximum size, in bytes, of a single control message (request or response).
///
/// Control responses are received into a buffer of exactly this size.
pub const MAX_CTRL_MESSAGE_SIZE: usize = 4 * 1024;
38
39/// SNAP tunnel client builder.
40pub struct ClientBuilder {
41    token_source: Arc<dyn TokenSource>,
42}
43
44impl ClientBuilder {
45    /// Client builder with an initial SNAP token to be used to authenticate requests.
46    pub fn new(token_source: Arc<dyn TokenSource>) -> Self {
47        ClientBuilder { token_source }
48    }
49
50    /// Establish a SNAP tunnel using the provided QUIC connection using the builder's settings.
51    pub async fn connect(
52        self,
53        conn: quinn::Connection,
54    ) -> Result<(Sender, Receiver, Control), SnapTunError> {
55        let conn_state = SharedConnState::new(ConnState::new());
56        let mut ctrl = Control {
57            conn: conn.clone(),
58            state: conn_state.clone(),
59            token_renewal_task: None,
60        };
61
62        let mut token_watch = self.token_source.watch();
63
64        // Try to get the current token
65        let mut initial_token = match token_watch.borrow_and_update().as_ref() {
66            Some(Ok(token)) => Some(token.clone()),
67            Some(Err(e)) => return Err(SnapTunError::InitialTokenError(e.to_string())),
68            None => None,
69        };
70
71        // Wait for the initial token if not already available.
72        if initial_token.is_none() {
73            token_watch
74                .changed()
75                .await
76                .map_err(|e| SnapTunError::InitialTokenError(e.to_string()))?;
77
78            initial_token = match token_watch.borrow().as_ref() {
79                Some(Ok(token)) => Some(token.clone()),
80                Some(Err(e)) => return Err(SnapTunError::InitialTokenError(e.to_string())),
81                None => None,
82            };
83        }
84
85        let initial_token = initial_token.ok_or_else(|| {
86            SnapTunError::InitialTokenError("failed to obtain initial token".into())
87        })?;
88
89        ctrl.state.write().unwrap().snap_token = initial_token;
90        ctrl.update_token().await?;
91        ctrl.request_socket_addr().await?;
92
93        // If our token source supports notifications for new tokens, spawn a task to
94        // inform the server whenever the token is updated.
95        tracing::trace!("Starting token update task");
96        ctrl.session_token_update_task(token_watch);
97
98        Ok((Sender::new(conn.clone()), Receiver { conn }, ctrl))
99    }
100}
101
102/// Control can be used to send control messages to the server
103pub struct Control {
104    conn: quinn::Connection,
105    state: SharedConnState,
106    token_renewal_task: Option<JoinHandle<Result<(), RenewTaskError>>>,
107}
108
109impl Control {
110    /// Returns the socket address assigned by the server. This typically
111    /// corresponds to the client's _remote_ socket address; i.e. the possibly
112    /// NAT'ed address of the client visible to the server.
113    ///
114    /// It is up to the client to use the correct ISD-AS for this tunnel.
115    pub fn assigned_sock_addr(&self) -> Option<SocketAddr> {
116        self.state.read().expect("no fail").assigned_sock_addr
117    }
118
119    /// Returns the token expiry time.
120    pub fn token_expiry(&self) -> SystemTime {
121        self.state.read().expect("no fail").token_expiry
122    }
123
124    /// Returns the current SNAP token.
125    pub fn snap_token(&self) -> String {
126        self.state.read().expect("no fail").snap_token.clone()
127    }
128
129    /// Sends a socket address assign request to the snaptun server.
130    async fn request_socket_addr(&mut self) -> Result<(), ControlError> {
131        tracing::debug!("Requesting socket address assignment");
132        let (mut snd, mut rcv) = self.conn.open_bi().await?;
133
134        let request = SocketAddrAssignmentRequest {};
135
136        let body = request.encode_to_vec();
137        let token = self.state.read().expect("no fail").snap_token.clone();
138        send_control_request(
139            &mut snd,
140            crate::PATH_SOCK_ADDR_ASSIGNMENT,
141            body.as_ref(),
142            &token,
143        )
144        .await?;
145
146        // Parse address assignment response
147        let mut resp_buf = [0u8; MAX_CTRL_MESSAGE_SIZE];
148        let response =
149            recv_response::<SocketAddrAssignmentResponse>(&mut resp_buf[..], &mut rcv).await?;
150
151        let sock_addr = response
152            .socket_addr()
153            .map_err(|e| ControlError::AddressAssignmentFailed(AddrAssignError::InvalidAddr(e)))?;
154
155        let mut sstate = self.state.0.write().expect("no fail");
156        sstate.assigned_sock_addr = Some(sock_addr);
157
158        Ok(())
159    }
160
161    /// Sends a new SNAP token to keep the snaptun connection with the server established.
162    pub async fn update_token(&mut self) -> Result<(), ControlError> {
163        let token = self.state.read().unwrap().snap_token.clone();
164        self.set_token_expiry(update_token(&self.conn.clone(), &token).await?);
165        Ok(())
166    }
167
168    /// Spawns a task which informs the server whenever the client's token was updated.
169    fn session_token_update_task(&mut self, mut token_watch: token_source::TokenSourceWatch) {
170        let conn = self.conn.clone();
171        let conn_state = self.state.clone();
172
173        self.token_renewal_task = Some(tokio::spawn(async move {
174            loop {
175                let expiry = conn_state.read().expect("no fail").token_expiry;
176                let now = SystemTime::now();
177                let dur_until_expiry = expiry
178                    .duration_since(now)
179                    .unwrap_or_else(|_| std::time::Duration::from_secs(0));
180
181                let expiry_timeout = tokio::time::Instant::now() + dur_until_expiry;
182
183                select! {
184                    // A new token is available.
185                    _ = token_watch.changed() => {}
186                    // Our token has expired.
187                    _ = tokio::time::sleep_until(expiry_timeout) => {
188                        tracing::error!("SNAP token has expired but no new token was received from the token source");
189                        return Err(RenewTaskError::TokenExpired);
190                    },
191                }
192
193                // Try to get a new token from the token source. Can fail if the token source
194                // expired and failed fetching a new token in time
195                let new_token = token_watch
196                    .borrow_and_update()
197                    .as_ref()
198                    .ok_or_else(|| {
199                        RenewTaskError::TokenSourceError(
200                            "token source watch channel has no value".into(),
201                        )
202                    })?
203                    .as_ref()
204                    .map_err(|e| RenewTaskError::TokenSourceError(e.to_string().into()))?
205                    .clone();
206
207                // Try to update the token on the server.
208                let mut attempt = 0;
209                // Maximum number of retries for token renewal.
210                const MAX_RETRIES: u32 = 5;
211                // Update backoff
212                const BACKOFF: ExponentialBackoff = ExponentialBackoff::new(3.0, 30.0, 2.0, 1.0);
213
214                tracing::info!("Updating SNAP token on server");
215                // Note: Unlikely edgecase - If the token lifetime is very short, we might run into
216                // the situation where the token expires before we could successfully update it on
217                // the server.
218                loop {
219                    match update_token(&conn, &new_token).await {
220                        Ok(new_expiry) => {
221                            tracing::info!("Successfully updated SNAP token on server");
222                            // Update the token in the connection state.
223                            {
224                                let mut conn_state = conn_state.write().unwrap();
225                                conn_state.token_expiry = new_expiry;
226                                conn_state.snap_token = new_token.clone();
227                            }
228                            break;
229                        }
230                        Err(err) if attempt > MAX_RETRIES => {
231                            attempt += 1;
232                            tracing::error!(
233                                %attempt,
234                                %err,
235                                "Failed to update SNAP token on server, max retries reached",
236                            );
237
238                            return Err(RenewTaskError::MaxRetriesReached);
239                        }
240                        Err(err) => {
241                            attempt += 1;
242
243                            let delay = BACKOFF.duration(attempt);
244                            let next_try = delay.as_secs();
245                            tracing::warn!(
246                                %attempt,
247                                %err,
248                                %next_try,
249                                "Failed to update SNAP token on server",
250                            );
251
252                            if expiry_timeout <= tokio::time::Instant::now() + delay {
253                                tracing::error!(
254                                    "SNAP token has expired before it could be renewed"
255                                );
256                                return Err(RenewTaskError::TokenExpired);
257                            }
258
259                            tokio::time::sleep(delay).await;
260                        }
261                    }
262                }
263            }
264        }));
265    }
266
267    fn set_token_expiry(&mut self, expiry: SystemTime) {
268        self.state.write().expect("no fail").token_expiry = expiry;
269    }
270
271    /// An async function that returns when the underlying connection is closed.
272    pub async fn closed(&self) -> ConnectionError {
273        self.conn.closed().await
274    }
275
276    /// Returns the underlying QUIC connection.
277    pub fn inner_conn(&self) -> quinn::Connection {
278        self.conn.clone()
279    }
280
281    /// This is a helper function that returns a debug-printable object
282    /// containing metrics about the underlying QUIC-connection.
283    // XXX(dsd): We are overcautious here and do not want to commit to an
284    // implementation-specific type.
285    pub fn debug_path_stats(&self) -> impl std::fmt::Debug + 'static + use<> {
286        self.conn.stats().path
287    }
288}
289
290/// Token renew task error.
291#[derive(Debug, thiserror::Error)]
292pub enum RenewTaskError {
293    /// Token expired.
294    #[error("token expired")]
295    TokenExpired,
296    /// Maximum number of retries reached.
297    #[error("maximum number of retries reached")]
298    MaxRetriesReached,
299    /// Token source error.
300    #[error("token source failed: {0}")]
301    TokenSourceError(#[from] token_source::TokenSourceError),
302}
303
304/// Update SNAP token.
305///
306/// This opens a new bi-directional stream to the server, sends a update SNAP token request, and
307/// waits for the response. On success, it returns the new token expiry time.
308pub async fn update_token(
309    conn: &quinn::Connection,
310    token: &str,
311) -> Result<SystemTime, ControlError> {
312    let (mut snd, mut rcv) = conn.open_bi().await?;
313
314    let body = vec![];
315    send_control_request(&mut snd, crate::PATH_UPDATE_TOKEN, &body, token).await?;
316    let mut resp_buf = [0u8; MAX_CTRL_MESSAGE_SIZE];
317    let response: TokenUpdateResponse = recv_response(&mut resp_buf[..], &mut rcv).await?;
318
319    Ok(system_time_from_unix_epoch_secs(response.valid_until))
320}
321
322impl Drop for Control {
323    fn drop(&mut self) {
324        if let Some(task) = self.token_renewal_task.take() {
325            // Cancel the token renewal task
326            task.abort();
327        }
328    }
329}
330
/// Per-connection state shared between the [`Control`] handle and the token
/// renewal task.
#[derive(Clone, Debug)]
struct ConnState {
    /// SNAP token used to authenticate control requests.
    snap_token: String,
    /// Expiry time of `snap_token`.
    token_expiry: SystemTime,
    /// Socket address assigned by the remote, to be used as the endhost socket
    /// address for this tunnel. `None` until an assignment has been received.
    assigned_sock_addr: Option<SocketAddr>,
}

impl ConnState {
    /// Creates a blank state: an empty token, an expiry at the Unix epoch, and no
    /// assigned address.
    fn new() -> Self {
        ConnState {
            assigned_sock_addr: None,
            token_expiry: SystemTime::UNIX_EPOCH,
            snap_token: String::default(),
        }
    }
}
350
351#[derive(Debug, Clone)]
352struct SharedConnState(Arc<RwLock<ConnState>>);
353
354impl SharedConnState {
355    fn new(conn_state: ConnState) -> Self {
356        Self(Arc::new(RwLock::new(conn_state)))
357    }
358}
359
360impl Deref for SharedConnState {
361    type Target = Arc<RwLock<ConnState>>;
362
363    fn deref(&self) -> &Self::Target {
364        &self.0
365    }
366}
367
368/// SNAP tunnel sender.
369#[derive(Debug, Clone)]
370pub struct Sender {
371    conn: quinn::Connection,
372}
373
374impl Sender {
375    /// Creates a new sender.
376    pub fn new(conn: quinn::Connection) -> Self {
377        Self { conn }
378    }
379
380    /// Sends a datagram to the connection.
381    pub fn send_datagram(&self, data: Bytes) -> Result<(), quinn::SendDatagramError> {
382        self.conn.send_datagram(data)
383    }
384
385    /// Sends a datagram to the connection and waits for the datagram to be sent.
386    pub async fn send_datagram_wait(&self, data: Bytes) -> Result<(), quinn::SendDatagramError> {
387        self.conn.send_datagram_wait(data).await
388    }
389}
390
391/// SNAP tunnel receiver.
392#[derive(Debug, Clone)]
393pub struct Receiver {
394    conn: quinn::Connection,
395}
396
397impl Receiver {
398    /// Reads a datagram from the connection.
399    pub async fn read_datagram(&self) -> Result<Bytes, quinn::ConnectionError> {
400        self.conn.read_datagram().await
401    }
402}
403
404/// Parse response error.
405#[derive(Debug, thiserror::Error)]
406pub enum ParseResponseError {
407    /// Parsing HTTP envelope failed.
408    #[error("parsing HTTP envelope failed: {0}")]
409    HTTParseError(#[from] httparse::Error),
410    /// QUIC read error.
411    #[error("read error: {0}")]
412    ReadError(#[from] quinn::ReadError),
413    /// Protobuf decode error.
414    #[error("parsing control message failed: {0}")]
415    ParseError(#[from] prost::DecodeError),
416    /// Received a bad response.
417    #[error("received bad response: {0}")]
418    ResponseError(Cow<'static, str>),
419}
420
421async fn recv_response<M: prost::Message + Default>(
422    buf: &mut [u8],
423    rcv: &mut RecvStream,
424) -> Result<M, ParseResponseError> {
425    let mut cursor = 0;
426    let mut body_offset = 0;
427    let mut code = 0;
428
429    // Parse HTTP response headers.
430    while let Some(n) = rcv.read(&mut buf[cursor..]).await? {
431        cursor += n;
432
433        let mut headers = [httparse::EMPTY_HEADER; 16];
434        let mut resp = httparse::Response::new(&mut headers);
435
436        match resp.parse(&buf[..cursor])? {
437            httparse::Status::Partial => {}
438            httparse::Status::Complete(n) => {
439                body_offset = n;
440                code = resp.code.unwrap_or(0);
441                break;
442            }
443        };
444
445        // Only keep reading if we have enough space in buffer
446        if cursor >= buf.len() {
447            return Err(ParseResponseError::ResponseError(
448                "response too large".into(),
449            ));
450        }
451    }
452
453    // We only have a single message on the stream, so the rest we expect to be the body.
454    while let Some(n) = rcv.read(&mut buf[cursor..]).await? {
455        cursor += n;
456        if cursor >= buf.len() {
457            return Err(ParseResponseError::ResponseError(
458                "response too large".into(),
459            ));
460        }
461    }
462
463    // If the response code is not 200, return an error with the response body as message.
464    if code != 200 {
465        let msg = String::from_utf8_lossy(&buf[body_offset..cursor]).to_string();
466        return Err(ParseResponseError::ResponseError(msg.into()));
467    }
468
469    // Otherwise, parse the body as protobuf message.
470    let m = M::decode(&buf[body_offset..cursor])?;
471
472    Ok(m)
473}
474
475/// Send control request error.
476#[derive(Debug, thiserror::Error)]
477pub enum SendControlRequestError {
478    /// I/O error.
479    #[error("i/o error: {0}")]
480    IoError(#[from] std::io::Error),
481    /// QUIC closed stream error.
482    #[error("stream closed: {0}")]
483    ClosedStream(#[from] quinn::ClosedStream),
484}
485
486/// Send a control request to the server using `snd` as the request-stream.
487async fn send_control_request(
488    snd: &mut SendStream,
489    method: &str,
490    body: &[u8],
491    token: &str,
492) -> Result<(), SendControlRequestError> {
493    write_all(
494        snd,
495        format!(
496            "POST {method} HTTP/1.1\r\n\
497content-type: application/proto\r\n\
498connect-protocol-version: 1\r\n\
499content-encoding: identity\r\n\
500accept-encoding: identity\r\n\
501content-length: {}\r\n\
502Authorization: Bearer {token}\r\n\r\n",
503            body.len()
504        )
505        .as_bytes(),
506    )
507    .await?;
508    write_all(snd, body).await?;
509    snd.finish()?;
510    Ok(())
511}
512
513// SendStream::write_all is not cancel-safe, so we use loops instead.
514async fn write_all(stream: &mut SendStream, data: &[u8]) -> std::io::Result<()> {
515    let mut cursor = 0;
516    while cursor < data.len() {
517        cursor += stream.write(&data[cursor..]).await?;
518    }
519    Ok(())
520}
521
522/// SNAP tunnel errors.
523#[derive(Debug, thiserror::Error)]
524pub enum SnapTunError {
525    /// Initial token error.
526    #[error("initial token error: {0}")]
527    InitialTokenError(String),
528    /// Control error.
529    #[error("control error: {0}")]
530    ControlError(#[from] ControlError),
531}
532
533/// SNAP tunnel control errors.
534#[derive(Debug, thiserror::Error)]
535pub enum ControlError {
536    /// QUIC connection error.
537    #[error("quinn connection error: {0}")]
538    ConnectionError(#[from] quinn::ConnectionError),
539    /// Address assignment failed.
540    #[error("address assignment failed: {0}")]
541    AddressAssignmentFailed(#[from] AddrAssignError),
542    /// Parse control request response error.
543    #[error("parse control request response: {0}")]
544    ParseResponse(#[from] ParseResponseError),
545    /// Send control request error.
546    #[error("send control request error: {0}")]
547    SendRequestError(#[from] SendControlRequestError),
548}
549
550/// Address assignment error.
551#[derive(Debug, thiserror::Error)]
552pub enum AddrAssignError {
553    /// Invalid address.
554    #[error("invalid addr: {0}")]
555    InvalidAddr(#[from] AddrError),
556    /// No address assigned.
557    #[error("no address assigned")]
558    NoAddressAssigned,
559}