Skip to main content

snap_tun/client/
tunnel.rs

1// Copyright 2026 Anapaya Systems
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//   http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::{
16    future::Future,
17    io,
18    net::SocketAddr,
19    pin::Pin,
20    sync::{Arc, Mutex},
21    time::{Duration, Instant},
22};
23
24use ana_gotatun::{
25    noise::{Tunn, TunnResult, errors::WireGuardError, rate_limiter::RateLimiter},
26    packet::{Packet, PacketBufPool, WgKind},
27    x25519::{self},
28};
29use bytes::{Bytes, BytesMut};
30use scion_sdk_utils::backoff::ExponentialBackoff;
31use tokio::{select, task::JoinHandle, time::Interval};
32use tracing::instrument;
33use zerocopy::IntoBytes as _;
34
35use super::{PACKET_BUF_POOL_SIZE, TunnelGuard};
36use crate::udp_batch::{QueuePacketError, RecvBatchError, UdpBatchReceiver, UdpBatchSender};
37
/// Limit handed to the handshake [`RateLimiter`] of every tunnel created by this module.
// NOTE(review): presumably "handshake messages per second", as in upstream boringtun — confirm.
const HANDSHAKE_RATE_LIMIT: u64 = 20;
/// Batch size used as the first const generic of both the batched UDP receiver and the batched
/// UDP sender owned by the driver.
const RECEIVE_BATCH_SIZE: usize = 64;
40
/// Error when sending or receiving packets on the SNAP tunnel.
#[derive(Debug, thiserror::Error)]
pub enum SnapTunnelDriverError {
    /// I/O error when sending packets on the underlay socket.
    ///
    /// Because of `#[from]`, this is also the variant that the `?` operator produces for any
    /// [`std::io::Error`] inside a function returning this error type.
    #[error("send i/o error: {0}")]
    SendIoError(#[from] std::io::Error),
    /// I/O error when receiving packets on the underlay socket.
    #[error("receive i/o error: {0}")]
    ReceiveIoError(std::io::Error),
    /// Receive queue closed.
    #[error("receive queue closed")]
    ReceiveQueueClosed,
    /// Connection expired.
    #[error("connection expired")]
    ConnectionExpired,
    /// Error receiving a Wireguard packet.
    /// This will never be WireGuardError::ConnectionExpired.
    // NOTE(review): expiry reported by the timer update is mapped to `ConnectionExpired`, but
    // the receive path wraps whatever error the tunnel returns for an incoming packet, so the
    // claim above relies on that path never yielding `ConnectionExpired` — confirm.
    #[error("error receiving a Wireguard packet: {0:?}")]
    WireguardError(WireGuardError),
}
61
/// Background driver of a SNAP tunnel: performs the handshake, updates the tunnel timers,
/// receives datagrams from the underlay socket and forwards decapsulated payloads to the
/// receive queue.
struct SnapTunnelDriver {
    /// Tunnel state, shared with the [`SnapTunnel`] handle. On reconnect the value behind the
    /// mutex is replaced, so the handle keeps using the same `Arc`.
    pub tunn: Arc<Mutex<Tunn>>,
    /// Client static private key; kept so that a fresh tunnel can be created on reconnect.
    pub static_private: x25519::StaticSecret,
    /// Server static public key; kept so that a fresh tunnel can be created on reconnect.
    pub peer_public: x25519::PublicKey,
    /// UDP socket over which all tunnel datagrams are sent and received.
    pub underlay_socket: Arc<tokio::net::UdpSocket>,
    /// Address of the remote data plane. Datagrams from any other sender are ignored.
    pub dataplane_address: SocketAddr,
    /// Persistent keepalive setting forwarded to every tunnel this driver creates.
    pub persistent_keepalive_seconds: Option<u16>,
    /// Ticker (250 ms period) that triggers the tunnel timer update.
    pub update_timers_interval: Interval,
    /// Producer side of the receive queue; decapsulated payloads are pushed here.
    pub packet_sender: async_channel::Sender<BytesMut>,
    /// Address reported by the tunnel after the first completed handshake, `None` before that.
    pub local_sockaddr: Option<SocketAddr>,
    /// Packet buffer pool used by the batched receiver.
    pub pool: PacketBufPool<PACKET_BUF_POOL_SIZE>,
    /// Batched receiver bound to `underlay_socket`.
    pub receiver: UdpBatchReceiver<RECEIVE_BATCH_SIZE, PACKET_BUF_POOL_SIZE>,
    /// Batched sender used for datagrams produced while handling a receive batch.
    pub sender: UdpBatchSender<RECEIVE_BATCH_SIZE, PACKET_BUF_POOL_SIZE>,
}
76
77impl SnapTunnelDriver {
78    fn new(
79        static_private: x25519::StaticSecret,
80        peer_public: x25519::PublicKey,
81        underlay_socket: Arc<tokio::net::UdpSocket>,
82        dataplane_address: SocketAddr,
83        persistent_keepalive_seconds: Option<u16>,
84        packet_sender: async_channel::Sender<BytesMut>,
85        pool: PacketBufPool<PACKET_BUF_POOL_SIZE>,
86    ) -> io::Result<Self> {
87        let update_timers_interval = tokio::time::interval_at(
88            tokio::time::Instant::now() + Duration::from_millis(250),
89            Duration::from_millis(250),
90        );
91        let receiver = UdpBatchReceiver::<RECEIVE_BATCH_SIZE, PACKET_BUF_POOL_SIZE>::new(
92            underlay_socket.as_ref(),
93            &pool,
94        )?;
95        let sender = UdpBatchSender::<RECEIVE_BATCH_SIZE, PACKET_BUF_POOL_SIZE>::new(
96            underlay_socket.as_ref(),
97        )?;
98        Ok(Self {
99            tunn: Arc::new(Mutex::new(Self::create_tunn(
100                static_private.clone(),
101                peer_public,
102                dataplane_address,
103                persistent_keepalive_seconds,
104            ))),
105            static_private,
106            peer_public,
107            underlay_socket,
108            dataplane_address,
109            persistent_keepalive_seconds,
110            update_timers_interval,
111            packet_sender,
112            local_sockaddr: None,
113            receiver,
114            sender,
115            pool,
116        })
117    }
118
119    #[instrument(name = "st-client", skip(self), fields(socket_addr= ?self.local_sockaddr))]
120    async fn initiate_connection(&mut self) -> Result<SocketAddr, SnapTunnelDriverError> {
121        let handshake_init = self.tunn.lock().unwrap().format_handshake_initiation(false);
122        if let Some(wg_init) = handshake_init
123            && let Err(e) = self
124                .underlay_socket
125                .send_to(
126                    to_bytes(WgKind::HandshakeInit(wg_init)).as_bytes(),
127                    self.dataplane_address,
128                )
129                .await
130        {
131            return Err(SnapTunnelDriverError::SendIoError(e));
132        }
133        // Drive the tunnel until any error occurs or the handshake is completed.
134        loop {
135            self.drive_once().await?;
136            if let Some(sockaddr) = self.tunn.lock().unwrap().get_initiator_remote_sockaddr() {
137                if self.local_sockaddr.is_none() {
138                    self.local_sockaddr = Some(sockaddr);
139                }
140                tracing::debug!(local_addr=?sockaddr, "handshake completed, local address assigned");
141                return Ok(sockaddr);
142            }
143        }
144    }
145
146    #[instrument(name = "st-client", skip(self), fields(socket_addr= ?self.local_sockaddr))]
147    async fn main_loop(mut self) {
148        let local_sockaddr = self
149            .local_sockaddr
150            .expect("local address must be set before main_loop()");
151        loop {
152            match self.drive_once().await {
153                Err(SnapTunnelDriverError::ReceiveQueueClosed) => {
154                    tracing::info!("receive queue closed, snap tunnel driver shutting down");
155                    return;
156                }
157                Err(SnapTunnelDriverError::ConnectionExpired) => {
158                    loop {
159                        let mut backoff = BackoffState::new();
160                        // reset tunnel
161                        *self.tunn.lock().expect("poison") = Self::create_tunn(
162                            self.static_private.clone(),
163                            self.peer_public,
164                            self.dataplane_address,
165                            self.persistent_keepalive_seconds,
166                        );
167                        match self.initiate_connection().await {
168                            Ok(addr) if addr == local_sockaddr => break,
169                            Ok(addr) => {
170                                tracing::error!(expected_addr=?local_sockaddr, new_addr=?addr, "local socket address changed");
171                            }
172                            Err(err) => {
173                                tracing::error!(?err, "error driving tunnel");
174                            }
175                        }
176                        backoff.backoff().await;
177                    }
178                }
179                Err(ref e) => tracing::error!(err=?e, "error driving tunnel"),
180                _ => {}
181            }
182        }
183    }
184
    /// Drives the tunnel once. Returns Ok(()) if no error occurred in the drive, otherwise returns
    /// the error. This method is called repeatedly by the main loop and by the handshake loop to
    /// update the timers and receive packets.
    ///
    /// Exactly one of two things happens per call, whichever is ready first (timer preferred):
    ///
    /// * the timer ticks: the tunnel timers are updated and any datagram they produce is sent
    ///   to the data plane;
    /// * a batch of datagrams is received: each datagram from the data plane is handed to the
    ///   tunnel, replies are queued on the batched sender, decapsulated payloads are pushed to
    ///   the receive queue, and finally the batched sender is flushed.
    ///
    /// # Errors
    ///
    /// * `ConnectionExpired` if the timer update reports an expired connection.
    /// * `SendIoError` if sending or flushing fails, or a reply exceeds the sender's packet size.
    /// * `ReceiveIoError` if receiving the batch fails.
    /// * `WireguardError` if the tunnel rejects an incoming packet.
    /// * `ReceiveQueueClosed` if the receive queue has no receiver left.
    async fn drive_once(&mut self) -> Result<(), SnapTunnelDriverError> {
        select! {
            // bias to ensure that high receive load cannot starve the timer
            biased;
            _ = self.update_timers_interval.tick() => {
                // The mutex guard is a temporary of this `let` statement, so it is released
                // before the send below awaits.
                let p = match self.tunn.lock().unwrap().update_timers() {
                    Ok(Some(wg)) => { Some(wg) },
                    Ok(None) => None,
                    Err(WireGuardError::ConnectionExpired) => {
                        return Err(SnapTunnelDriverError::ConnectionExpired);
                    }
                    Err(e) => {
                        // At the time of writing, update_timers does not return any error
                        // other than ConnectionExpired.
                        tracing::error!(err=?e, "unexpected error updating timers on tunnel");
                        None
                    }
                };
                if let Some(wg) = p && let Err(e) = self.underlay_socket.send_to(to_bytes(wg).as_bytes(), self.dataplane_address).await {
                    return Err(SnapTunnelDriverError::SendIoError(e));
                }
            },
            // The closure is invoked for the received datagrams; it is synchronous, so no lock
            // taken inside it is ever held across an await point.
            recv = self.receiver.recv_batch(&self.underlay_socket, &self.pool, |buf, sender_addr| {
                // Silently ignore datagrams that do not come from the data plane.
                if sender_addr != self.dataplane_address {
                    return Ok(());
                }
                let Ok(wg) = buf.try_into_wg() else {
                    tracing::debug!("received packet that is not a valid WireGuard packet, ignoring");
                    return Ok(());
                };
                let result = self.tunn.lock().unwrap().handle_incoming_packet(wg);
                match result {
                    TunnResult::Done => {}
                    TunnResult::Err(e) => {
                        return Err(SnapTunnelDriverError::WireguardError(e));
                    }
                    TunnResult::WriteToNetwork(p) => {
                        // Queue the reply. If the batched sender is full, flush what it holds
                        // (best effort, `WouldBlock` tolerated) and retry once; if it is still
                        // full the reply is dropped.
                        if let Err(error) = self
                            .sender
                            .try_queue_packet(to_bytes(p), self.dataplane_address)
                        {
                            match error {
                                QueuePacketError::Full { packet, target } => {
                                    let err = self.sender.try_flush_best_effort(&self.underlay_socket);
                                    if let Err(ref flush_err) = err
                                        && flush_err.kind() != io::ErrorKind::WouldBlock
                                    {
                                        return Err(SnapTunnelDriverError::SendIoError(io::Error::new(
                                            flush_err.kind(),
                                            flush_err.to_string(),
                                        )));
                                    }
                                    if self.sender.try_queue_packet(packet, target).is_err() {
                                        tracing::debug!(?target, "dropping outbound packet because batched sender remains full");
                                    }
                                }
                                QueuePacketError::PacketTooLarge {
                                    packet_len,
                                    max_packet_size,
                                    ..
                                } => {
                                    return Err(SnapTunnelDriverError::SendIoError(io::Error::new(
                                        io::ErrorKind::InvalidInput,
                                        format!(
                                            "outbound packet length {packet_len} exceeds batched sender max of {max_packet_size}"
                                        ),
                                    )));
                                }
                            }
                        }
                        // Packets the tunnel had queued get the same queue/flush/retry treatment.
                        // Note: the mutex guard is a temporary of the iterator expression, so the
                        // tunnel stays locked for the whole loop.
                        for queued in self.tunn.lock().unwrap().get_queued_packets() {
                            if let Err(error) = self
                                .sender
                                .try_queue_packet(to_bytes(queued), self.dataplane_address)
                            {
                                match error {
                                    QueuePacketError::Full { packet, target } => {
                                        let err = self.sender.try_flush_best_effort(&self.underlay_socket);
                                        if let Err(ref flush_err) = err
                                            && flush_err.kind() != io::ErrorKind::WouldBlock
                                        {
                                            return Err(SnapTunnelDriverError::SendIoError(io::Error::new(
                                                flush_err.kind(),
                                                flush_err.to_string(),
                                            )));
                                        }
                                        if self.sender.try_queue_packet(packet, target).is_err() {
                                            tracing::debug!(?target, "dropping queued outbound packet because batched sender remains full");
                                        }
                                    }
                                    QueuePacketError::PacketTooLarge {
                                        packet_len,
                                        max_packet_size,
                                        ..
                                    } => {
                                        return Err(SnapTunnelDriverError::SendIoError(io::Error::new(
                                            io::ErrorKind::InvalidInput,
                                            format!(
                                                "queued outbound packet length {packet_len} exceeds batched sender max of {max_packet_size}"
                                            ),
                                        )));
                                    }
                                }
                            }
                        }
                    }
                    TunnResult::WriteToTunnel(mut p) => {
                        // Copy the payload out of the packet buffer; empty payloads are not
                        // forwarded.
                        let buf = p.buf_mut().to_owned();
                        if !buf.is_empty() {
                            match self.packet_sender.try_send(buf) {
                                Ok(()) => {}
                                // A slow consumer costs packets, it never blocks the driver.
                                Err(async_channel::TrySendError::Full(_)) => {
                                    tracing::debug!("receive channel is full, dropping packet");
                                }
                                Err(_) => {
                                    return Err(SnapTunnelDriverError::ReceiveQueueClosed);
                                }
                            }
                        }
                    }
                }
                Ok(())
            }) => {
                match recv {
                    Ok(()) => {
                        // Send everything the handler queued; `?` maps an I/O error to
                        // `SendIoError`.
                        self.sender.flush(&self.underlay_socket).await?;
                    }
                    Err(RecvBatchError::Io(e)) => {
                        return Err(SnapTunnelDriverError::ReceiveIoError(e));
                    }
                    Err(RecvBatchError::Handler(e)) => {
                        return Err(e);
                    }
                }
            }
        }
        Ok(())
    }
326
327    fn create_tunn(
328        static_private: x25519::StaticSecret,
329        peer_public: x25519::PublicKey,
330        dataplane_address: SocketAddr,
331        persistent_keepalive_seconds: Option<u16>,
332    ) -> Tunn {
333        let local_public = x25519::PublicKey::from(&static_private);
334        Tunn::new(
335            static_private,
336            peer_public,
337            None,
338            persistent_keepalive_seconds,
339            0,
340            Arc::new(RateLimiter::new(&local_public, HANDSHAKE_RATE_LIMIT)),
341            dataplane_address,
342        )
343    }
344}
345
/// Error when receiving a packet from the SNAP tunnel connection.
///
/// Returned by `SnapTunnel::recv` and `SnapTunnel::poll_recv`.
#[derive(Debug, thiserror::Error)]
pub enum SnapTunnelReceiveError {
    /// The receive queue is closed, so receiving from it failed.
    #[error("receive queue closed")]
    ReceiveQueueClosed,
}
353
/// Boxed, pinned and `Send` future resolving to the next payload from the receive queue. It is
/// stored between calls of `SnapTunnel::poll_recv`.
type RecvFuture = Pin<Box<dyn Future<Output = Result<BytesMut, async_channel::RecvError>> + Send>>;
355
/// A SNAP tunnel connection.
///
/// Packets are encapsulated and sent directly on the underlay socket by this handle, while a
/// background driver task receives, decapsulates and queues incoming packets.
pub struct SnapTunnel {
    /// Guard held for the lifetime of the tunnel; never read.
    _guard: TunnelGuard,
    /// Tunnel state, shared with the driver task.
    tunn: Arc<Mutex<Tunn>>,
    /// UDP socket over which encapsulated packets are sent.
    underlay_socket: Arc<tokio::net::UdpSocket>,
    /// Address of the remote data plane all packets are sent to.
    dataplane_address: SocketAddr,
    /// Local address reported by the tunnel once the initial handshake completed.
    local_sockaddr: SocketAddr,
    /// Consumer side of the queue the driver task fills with decapsulated payloads.
    receive_queue: async_channel::Receiver<BytesMut>,
    /// Stored receive future for poll_recv. Protected by Mutex for interior mutability.
    recv_future: Mutex<Option<RecvFuture>>,
    /// Task that drives the SNAP tunnel.
    /// Aborted when the tunnel is dropped.
    driver_task: JoinHandle<()>,
}
370
/// Stops the background driver together with the tunnel handle.
impl Drop for SnapTunnel {
    fn drop(&mut self) {
        // The driver loops until its receive queue is closed, so abort it explicitly here
        // instead of relying on that.
        self.driver_task.abort();
    }
}
376
377impl SnapTunnel {
378    /// Creates a new SNAP tunnel and waits for the handshake to complete.
379    ///
380    /// # Arguments
381    ///
382    /// * `static_private` - The client's static private key
383    /// * `peer_public` - The server's static public key (needed for handshake)
384    /// * `rate_limiter` - Rate limiter for the tunnel
385    /// * `underlay_socket` - UDP socket for sending/receiving packets
386    /// * `dataplane_address` - Address of the remote server
387    /// * `receive_queue_capacity` - Capacity of the receive queue
388    pub(super) async fn new(
389        guard: TunnelGuard,
390        static_private: x25519::StaticSecret,
391        peer_public: x25519::PublicKey,
392        underlay_socket: Arc<tokio::net::UdpSocket>,
393        dataplane_address: SocketAddr,
394        receive_queue_capacity: usize,
395        persistent_keepalive_seconds: Option<u16>,
396        pool: PacketBufPool<PACKET_BUF_POOL_SIZE>,
397    ) -> Result<Self, SnapTunnelDriverError> {
398        let (packet_sender, packet_receiver) = async_channel::bounded(receive_queue_capacity);
399        let mut driver = SnapTunnelDriver::new(
400            static_private,
401            peer_public,
402            underlay_socket.clone(),
403            dataplane_address,
404            persistent_keepalive_seconds,
405            packet_sender,
406            pool.clone(),
407        )?;
408        let socket_addr = driver.initiate_connection().await?;
409        Ok(Self {
410            _guard: guard,
411            tunn: driver.tunn.clone(),
412            underlay_socket,
413            dataplane_address,
414            local_sockaddr: socket_addr,
415            receive_queue: packet_receiver,
416            recv_future: Mutex::new(None),
417            driver_task: tokio::spawn(driver.main_loop()),
418        })
419    }
420
421    /// Send a packet to the remote server.
422    // xxx(dsd): during a connection reset, packets will be silently dropped.
423    #[instrument(name = "st-client", skip_all, fields(socket_addr= ?self.local_sockaddr, payload_len= packet.len()))]
424    pub async fn send(&self, packet: Packet) -> io::Result<()> {
425        let encapsulated_packet = self.tunn.lock().unwrap().handle_outgoing_packet(packet);
426        match encapsulated_packet {
427            Some(wg) => {
428                let bytes = match wg {
429                    WgKind::HandshakeInit(p) => p.into_bytes(),
430                    WgKind::HandshakeResp(p) => p.into_bytes(),
431                    WgKind::CookieReply(p) => p.into_bytes(),
432                    WgKind::Data(p) => p.into_bytes(),
433                };
434                tracing::trace!(dataplane_address=?self.dataplane_address, "sending packet");
435                self.underlay_socket
436                    .send_to(bytes.as_bytes(), self.dataplane_address)
437                    .await?;
438                Ok(())
439            }
440            None => {
441                // None is returned if a handshake is ongoing but not yet complete.
442                // In this case the packet is queued and will be sent when the handshake is
443                // complete.
444                tracing::trace!("handshake ongoing, queueing packet");
445                Ok(())
446            }
447        }
448    }
449
450    /// Try to send a packet to the remote server. Returns error of try_send_to.
451    #[instrument(name = "st-client", skip_all, fields(socket_addr= ?self.local_sockaddr, payload_len= packet.len()))]
452    pub fn try_send(&self, packet: Packet) -> io::Result<()> {
453        match self.tunn.lock().unwrap().handle_outgoing_packet(packet) {
454            Some(wg) => {
455                let bytes = match wg {
456                    WgKind::HandshakeInit(p) => p.into_bytes(),
457                    WgKind::HandshakeResp(p) => p.into_bytes(),
458                    WgKind::CookieReply(p) => p.into_bytes(),
459                    WgKind::Data(p) => p.into_bytes(),
460                };
461                tracing::trace!(dataplane_address=?self.dataplane_address, "trying to send packet");
462                self.underlay_socket
463                    .try_send_to(bytes.as_bytes(), self.dataplane_address)?;
464                Ok(())
465            }
466            None => {
467                // None is returned if a handshake is ongoing but not yet complete.
468                // In this case the packet is queued and will be sent when the handshake is
469                // complete.
470                Ok(())
471            }
472        }
473    }
474
475    /// Receive a packet from the remote server.
476    pub async fn recv(&self) -> Result<Bytes, SnapTunnelReceiveError> {
477        match self.receive_queue.recv().await {
478            Ok(packet) => Ok(packet.into()),
479            Err(_) => Err(SnapTunnelReceiveError::ReceiveQueueClosed),
480        }
481    }
482
483    /// Poll for a packet from the remote server.
484    pub fn poll_recv(
485        &self,
486        cx: &mut std::task::Context<'_>,
487    ) -> std::task::Poll<Result<Bytes, SnapTunnelReceiveError>> {
488        let mut fut_guard = self.recv_future.lock().expect("lock poisoned");
489
490        // Create future if it doesn't exist
491        if fut_guard.is_none() {
492            // Clone the receiver (cheap with async-channel) to avoid borrowing self
493            let receiver = self.receive_queue.clone();
494            *fut_guard = Some(Box::pin(async move { receiver.recv().await }));
495        }
496
497        // Poll the stored future
498        let fut = fut_guard.as_mut().expect("future cannot be none");
499        match fut.as_mut().poll(cx) {
500            std::task::Poll::Ready(Ok(packet)) => {
501                // Clear the future so a new one is created on next poll
502                *fut_guard = None;
503                std::task::Poll::Ready(Ok(packet.into()))
504            }
505            std::task::Poll::Ready(Err(_)) => {
506                tracing::trace!("receive queue closed, returning error");
507                *fut_guard = None;
508                std::task::Poll::Ready(Err(SnapTunnelReceiveError::ReceiveQueueClosed))
509            }
510            std::task::Poll::Pending => std::task::Poll::Pending,
511        }
512    }
513
    /// Get the local socket address. Assigned by the remote server.
    ///
    /// This is the address obtained from the initial handshake; it is stored once at
    /// construction and does not change afterwards.
    pub fn local_addr(&self) -> SocketAddr {
        self.local_sockaddr
    }
518
    /// Waits until the underlay UDP socket is writable.
    ///
    /// # Errors
    ///
    /// Returns the error of the socket's `writable` call.
    pub async fn writable(&self) -> io::Result<()> {
        self.underlay_socket.writable().await
    }
523
    /// The address of the data plane the tunnel is connected to, i.e. the destination of every
    /// packet sent through this tunnel.
    pub fn data_plane_address(&self) -> SocketAddr {
        self.dataplane_address
    }
528}
529
530struct BackoffState {
531    last: Instant,
532    exp_backoff: ExponentialBackoff,
533    attempt: usize,
534}
535
536impl BackoffState {
537    fn new() -> Self {
538        Self {
539            last: Instant::now(),
540            exp_backoff: ExponentialBackoff::new(
541                5.0, 180.0, // max 3 mins
542                1.3, 0.5,
543            ),
544            attempt: 0,
545        }
546    }
547
548    fn backoff(&mut self) -> impl Future<Output = ()> {
549        let now = Instant::now();
550        let until_next = (self.last + self.exp_backoff.duration(self.attempt as u32))
551            .checked_duration_since(now);
552        self.attempt += 1;
553        self.last = now;
554
555        async move {
556            if let Some(d) = until_next {
557                tokio::time::sleep(d).await;
558            }
559        }
560    }
561}
562
/// Converts any kind of WireGuard message into its raw byte packet, ready to be written to the
/// underlay socket.
fn to_bytes(wg: WgKind) -> Packet<[u8]> {
    // Each variant wraps its own packet type, so every arm calls that type's `into_bytes`.
    match wg {
        WgKind::HandshakeInit(p) => p.into_bytes(),
        WgKind::HandshakeResp(p) => p.into_bytes(),
        WgKind::CookieReply(p) => p.into_bytes(),
        WgKind::Data(p) => p.into_bytes(),
    }
}