1use std::{
16 future::Future,
17 io,
18 net::SocketAddr,
19 pin::Pin,
20 sync::{Arc, Mutex},
21 time::{Duration, Instant},
22};
23
24use ana_gotatun::{
25 noise::{Tunn, TunnResult, errors::WireGuardError, rate_limiter::RateLimiter},
26 packet::{Packet, PacketBufPool, WgKind},
27 x25519::{self},
28};
29use bytes::{Bytes, BytesMut};
30use scion_sdk_utils::backoff::ExponentialBackoff;
31use tokio::{select, task::JoinHandle, time::Interval};
32use tracing::instrument;
33use zerocopy::IntoBytes as _;
34
35use super::{PACKET_BUF_POOL_SIZE, TunnelGuard};
36use crate::udp_batch::{QueuePacketError, RecvBatchError, UdpBatchReceiver, UdpBatchSender};
37
// Handshake budget handed to the WireGuard `RateLimiter` (see `create_tunn`).
const HANDSHAKE_RATE_LIMIT: u64 = 20;
// Maximum number of datagrams handled per `recv_batch` call on the underlay socket.
const RECEIVE_BATCH_SIZE: usize = 64;
40
/// Errors produced while driving a SNAP tunnel's WireGuard state machine.
#[derive(Debug, thiserror::Error)]
pub enum SnapTunnelDriverError {
    /// I/O failure while sending on the underlay socket.
    #[error("send i/o error: {0}")]
    SendIoError(#[from] std::io::Error),
    /// I/O failure while receiving from the underlay socket.
    #[error("receive i/o error: {0}")]
    ReceiveIoError(std::io::Error),
    /// The channel delivering decapsulated packets to the consumer was closed.
    #[error("receive queue closed")]
    ReceiveQueueClosed,
    /// The WireGuard session expired; the tunnel must re-handshake.
    #[error("connection expired")]
    ConnectionExpired,
    /// Protocol-level error while handling an incoming WireGuard packet.
    #[error("error receiving a Wireguard packet: {0:?}")]
    WireguardError(WireGuardError),
}
61
/// Owns the WireGuard state machine and the underlay socket I/O loop for one
/// tunnel. Created by `SnapTunnel::new` and consumed by `main_loop`.
struct SnapTunnelDriver {
    // Shared with `SnapTunnel` so the handle can encapsulate outgoing packets.
    pub tunn: Arc<Mutex<Tunn>>,
    // Kept so the Tunn state can be recreated after `ConnectionExpired`.
    pub static_private: x25519::StaticSecret,
    pub peer_public: x25519::PublicKey,
    pub underlay_socket: Arc<tokio::net::UdpSocket>,
    // Remote address all WireGuard traffic is sent to / accepted from.
    pub dataplane_address: SocketAddr,
    pub persistent_keepalive_seconds: Option<u16>,
    // 250ms cadence for `Tunn::update_timers` (keepalives, retransmits, expiry).
    pub update_timers_interval: Interval,
    // Decapsulated payloads are pushed here for `SnapTunnel::recv`.
    pub packet_sender: async_channel::Sender<BytesMut>,
    // Set once the first handshake completes; used to detect address changes.
    pub local_sockaddr: Option<SocketAddr>,
    pub pool: PacketBufPool<PACKET_BUF_POOL_SIZE>,
    pub receiver: UdpBatchReceiver<RECEIVE_BATCH_SIZE, PACKET_BUF_POOL_SIZE>,
    pub sender: UdpBatchSender<RECEIVE_BATCH_SIZE, PACKET_BUF_POOL_SIZE>,
}
76
77impl SnapTunnelDriver {
78 fn new(
79 static_private: x25519::StaticSecret,
80 peer_public: x25519::PublicKey,
81 underlay_socket: Arc<tokio::net::UdpSocket>,
82 dataplane_address: SocketAddr,
83 persistent_keepalive_seconds: Option<u16>,
84 packet_sender: async_channel::Sender<BytesMut>,
85 pool: PacketBufPool<PACKET_BUF_POOL_SIZE>,
86 ) -> io::Result<Self> {
87 let update_timers_interval = tokio::time::interval_at(
88 tokio::time::Instant::now() + Duration::from_millis(250),
89 Duration::from_millis(250),
90 );
91 let receiver = UdpBatchReceiver::<RECEIVE_BATCH_SIZE, PACKET_BUF_POOL_SIZE>::new(
92 underlay_socket.as_ref(),
93 &pool,
94 )?;
95 let sender = UdpBatchSender::<RECEIVE_BATCH_SIZE, PACKET_BUF_POOL_SIZE>::new(
96 underlay_socket.as_ref(),
97 )?;
98 Ok(Self {
99 tunn: Arc::new(Mutex::new(Self::create_tunn(
100 static_private.clone(),
101 peer_public,
102 dataplane_address,
103 persistent_keepalive_seconds,
104 ))),
105 static_private,
106 peer_public,
107 underlay_socket,
108 dataplane_address,
109 persistent_keepalive_seconds,
110 update_timers_interval,
111 packet_sender,
112 local_sockaddr: None,
113 receiver,
114 sender,
115 pool,
116 })
117 }
118
119 #[instrument(name = "st-client", skip(self), fields(socket_addr= ?self.local_sockaddr))]
120 async fn initiate_connection(&mut self) -> Result<SocketAddr, SnapTunnelDriverError> {
121 let handshake_init = self.tunn.lock().unwrap().format_handshake_initiation(false);
122 if let Some(wg_init) = handshake_init
123 && let Err(e) = self
124 .underlay_socket
125 .send_to(
126 to_bytes(WgKind::HandshakeInit(wg_init)).as_bytes(),
127 self.dataplane_address,
128 )
129 .await
130 {
131 return Err(SnapTunnelDriverError::SendIoError(e));
132 }
133 loop {
135 self.drive_once().await?;
136 if let Some(sockaddr) = self.tunn.lock().unwrap().get_initiator_remote_sockaddr() {
137 if self.local_sockaddr.is_none() {
138 self.local_sockaddr = Some(sockaddr);
139 }
140 tracing::debug!(local_addr=?sockaddr, "handshake completed, local address assigned");
141 return Ok(sockaddr);
142 }
143 }
144 }
145
146 #[instrument(name = "st-client", skip(self), fields(socket_addr= ?self.local_sockaddr))]
147 async fn main_loop(mut self) {
148 let local_sockaddr = self
149 .local_sockaddr
150 .expect("local address must be set before main_loop()");
151 loop {
152 match self.drive_once().await {
153 Err(SnapTunnelDriverError::ReceiveQueueClosed) => {
154 tracing::info!("receive queue closed, snap tunnel driver shutting down");
155 return;
156 }
157 Err(SnapTunnelDriverError::ConnectionExpired) => {
158 loop {
159 let mut backoff = BackoffState::new();
160 *self.tunn.lock().expect("poison") = Self::create_tunn(
162 self.static_private.clone(),
163 self.peer_public,
164 self.dataplane_address,
165 self.persistent_keepalive_seconds,
166 );
167 match self.initiate_connection().await {
168 Ok(addr) if addr == local_sockaddr => break,
169 Ok(addr) => {
170 tracing::error!(expected_addr=?local_sockaddr, new_addr=?addr, "local socket address changed");
171 }
172 Err(err) => {
173 tracing::error!(?err, "error driving tunnel");
174 }
175 }
176 backoff.backoff().await;
177 }
178 }
179 Err(ref e) => tracing::error!(err=?e, "error driving tunnel"),
180 _ => {}
181 }
182 }
183 }
184
    /// Advances the tunnel by one step: either services a 250ms timer tick
    /// (keepalives, retransmissions, expiry detection) or receives and
    /// processes one batch of datagrams from the underlay socket.
    ///
    /// `biased` gives timer maintenance priority over packet reception so the
    /// WireGuard timers cannot be starved by a busy socket.
    async fn drive_once(&mut self) -> Result<(), SnapTunnelDriverError> {
        select! {
            biased;
            _ = self.update_timers_interval.tick() => {
                // Advance the WireGuard timer state; this may emit a packet
                // (keepalive, retransmission, ...) that must go on the wire.
                let p = match self.tunn.lock().unwrap().update_timers() {
                    Ok(Some(wg)) => { Some(wg) },
                    Ok(None) => None,
                    Err(WireGuardError::ConnectionExpired) => {
                        // Surfaced to main_loop, which recreates the session.
                        return Err(SnapTunnelDriverError::ConnectionExpired);
                    }
                    Err(e) => {
                        // Other timer errors are logged and swallowed; the
                        // next tick will retry.
                        tracing::error!(err=?e, "unexpected error updating timers on tunnel");
                        None
                    }
                };
                // The mutex guard is released above, so awaiting the send here is safe.
                if let Some(wg) = p && let Err(e) = self.underlay_socket.send_to(to_bytes(wg).as_bytes(), self.dataplane_address).await {
                    return Err(SnapTunnelDriverError::SendIoError(e));
                }
            },
            // The closure is invoked synchronously per received datagram.
            recv = self.receiver.recv_batch(&self.underlay_socket, &self.pool, |buf, sender_addr| {
                // Drop datagrams that do not come from the dataplane peer.
                if sender_addr != self.dataplane_address {
                    return Ok(());
                }
                let Ok(wg) = buf.try_into_wg() else {
                    tracing::debug!("received packet that is not a valid WireGuard packet, ignoring");
                    return Ok(());
                };
                let result = self.tunn.lock().unwrap().handle_incoming_packet(wg);
                match result {
                    TunnResult::Done => {}
                    TunnResult::Err(e) => {
                        return Err(SnapTunnelDriverError::WireguardError(e));
                    }
                    // Protocol response (e.g. handshake reply) to send back out.
                    TunnResult::WriteToNetwork(p) => {
                        if let Err(error) = self
                            .sender
                            .try_queue_packet(to_bytes(p), self.dataplane_address)
                        {
                            match error {
                                // Batch queue full: flush best-effort, then retry
                                // once; if still full, drop the packet (WireGuard
                                // will retransmit protocol traffic as needed).
                                QueuePacketError::Full { packet, target } => {
                                    let err = self.sender.try_flush_best_effort(&self.underlay_socket);
                                    if let Err(ref flush_err) = err
                                        && flush_err.kind() != io::ErrorKind::WouldBlock
                                    {
                                        return Err(SnapTunnelDriverError::SendIoError(io::Error::new(
                                            flush_err.kind(),
                                            flush_err.to_string(),
                                        )));
                                    }
                                    if self.sender.try_queue_packet(packet, target).is_err() {
                                        tracing::debug!(?target, "dropping outbound packet because batched sender remains full");
                                    }
                                }
                                QueuePacketError::PacketTooLarge {
                                    packet_len,
                                    max_packet_size,
                                    ..
                                } => {
                                    return Err(SnapTunnelDriverError::SendIoError(io::Error::new(
                                        io::ErrorKind::InvalidInput,
                                        format!(
                                            "outbound packet length {packet_len} exceeds batched sender max of {max_packet_size}"
                                        ),
                                    )));
                                }
                            }
                        }
                        // Packets queued in the Tunn while the handshake was
                        // pending become sendable now; same overflow handling
                        // as above.
                        for queued in self.tunn.lock().unwrap().get_queued_packets() {
                            if let Err(error) = self
                                .sender
                                .try_queue_packet(to_bytes(queued), self.dataplane_address)
                            {
                                match error {
                                    QueuePacketError::Full { packet, target } => {
                                        let err = self.sender.try_flush_best_effort(&self.underlay_socket);
                                        if let Err(ref flush_err) = err
                                            && flush_err.kind() != io::ErrorKind::WouldBlock
                                        {
                                            return Err(SnapTunnelDriverError::SendIoError(io::Error::new(
                                                flush_err.kind(),
                                                flush_err.to_string(),
                                            )));
                                        }
                                        if self.sender.try_queue_packet(packet, target).is_err() {
                                            tracing::debug!(?target, "dropping queued outbound packet because batched sender remains full");
                                        }
                                    }
                                    QueuePacketError::PacketTooLarge {
                                        packet_len,
                                        max_packet_size,
                                        ..
                                    } => {
                                        return Err(SnapTunnelDriverError::SendIoError(io::Error::new(
                                            io::ErrorKind::InvalidInput,
                                            format!(
                                                "queued outbound packet length {packet_len} exceeds batched sender max of {max_packet_size}"
                                            ),
                                        )));
                                    }
                                }
                            }
                        }
                    }
                    // Decrypted payload destined for the local consumer.
                    TunnResult::WriteToTunnel(mut p) => {
                        // Copy out of the pooled buffer before handing off.
                        let buf = p.buf_mut().to_owned();
                        if !buf.is_empty() {
                            match self.packet_sender.try_send(buf) {
                                Ok(()) => {}
                                Err(async_channel::TrySendError::Full(_)) => {
                                    // Consumer is slow: drop rather than block the driver.
                                    tracing::debug!("receive channel is full, dropping packet");
                                }
                                Err(_) => {
                                    return Err(SnapTunnelDriverError::ReceiveQueueClosed);
                                }
                            }
                        }
                    }
                }
                Ok(())
            }) => {
                match recv {
                    Ok(()) => {
                        // Batch processed: push anything queued by the handler
                        // onto the wire.
                        self.sender.flush(&self.underlay_socket).await?;
                    }
                    Err(RecvBatchError::Io(e)) => {
                        return Err(SnapTunnelDriverError::ReceiveIoError(e));
                    }
                    Err(RecvBatchError::Handler(e)) => {
                        return Err(e);
                    }
                }
            }
        }
        Ok(())
    }
326
327 fn create_tunn(
328 static_private: x25519::StaticSecret,
329 peer_public: x25519::PublicKey,
330 dataplane_address: SocketAddr,
331 persistent_keepalive_seconds: Option<u16>,
332 ) -> Tunn {
333 let local_public = x25519::PublicKey::from(&static_private);
334 Tunn::new(
335 static_private,
336 peer_public,
337 None,
338 persistent_keepalive_seconds,
339 0,
340 Arc::new(RateLimiter::new(&local_public, HANDSHAKE_RATE_LIMIT)),
341 dataplane_address,
342 )
343 }
344}
345
/// Errors returned by [`SnapTunnel::recv`] / [`SnapTunnel::poll_recv`].
#[derive(Debug, thiserror::Error)]
pub enum SnapTunnelReceiveError {
    /// The background driver shut down and closed the packet channel.
    #[error("receive queue closed")]
    ReceiveQueueClosed,
}
353
354type RecvFuture = Pin<Box<dyn Future<Output = Result<BytesMut, async_channel::RecvError>> + Send>>;
355
/// Handle to an established SNAP tunnel.
///
/// Sending encrypts through the shared `Tunn` state; receiving reads
/// decapsulated packets produced by the background driver task.
pub struct SnapTunnel {
    // Held for its side effects; released when the tunnel is dropped.
    _guard: TunnelGuard,
    // Shared with the driver task.
    tunn: Arc<Mutex<Tunn>>,
    underlay_socket: Arc<tokio::net::UdpSocket>,
    dataplane_address: SocketAddr,
    // Assigned by the dataplane during the initial handshake.
    local_sockaddr: SocketAddr,
    receive_queue: async_channel::Receiver<BytesMut>,
    // In-flight recv future cached by `poll_recv` between polls.
    recv_future: Mutex<Option<RecvFuture>>,
    // Background driver; aborted on drop.
    driver_task: JoinHandle<()>,
}
370
impl Drop for SnapTunnel {
    fn drop(&mut self) {
        // Stop the background driver; otherwise it would keep running (and
        // keep the socket busy) after the last tunnel handle is dropped.
        self.driver_task.abort();
    }
}
376
377impl SnapTunnel {
378 pub(super) async fn new(
389 guard: TunnelGuard,
390 static_private: x25519::StaticSecret,
391 peer_public: x25519::PublicKey,
392 underlay_socket: Arc<tokio::net::UdpSocket>,
393 dataplane_address: SocketAddr,
394 receive_queue_capacity: usize,
395 persistent_keepalive_seconds: Option<u16>,
396 pool: PacketBufPool<PACKET_BUF_POOL_SIZE>,
397 ) -> Result<Self, SnapTunnelDriverError> {
398 let (packet_sender, packet_receiver) = async_channel::bounded(receive_queue_capacity);
399 let mut driver = SnapTunnelDriver::new(
400 static_private,
401 peer_public,
402 underlay_socket.clone(),
403 dataplane_address,
404 persistent_keepalive_seconds,
405 packet_sender,
406 pool.clone(),
407 )?;
408 let socket_addr = driver.initiate_connection().await?;
409 Ok(Self {
410 _guard: guard,
411 tunn: driver.tunn.clone(),
412 underlay_socket,
413 dataplane_address,
414 local_sockaddr: socket_addr,
415 receive_queue: packet_receiver,
416 recv_future: Mutex::new(None),
417 driver_task: tokio::spawn(driver.main_loop()),
418 })
419 }
420
421 #[instrument(name = "st-client", skip_all, fields(socket_addr= ?self.local_sockaddr, payload_len= packet.len()))]
424 pub async fn send(&self, packet: Packet) -> io::Result<()> {
425 let encapsulated_packet = self.tunn.lock().unwrap().handle_outgoing_packet(packet);
426 match encapsulated_packet {
427 Some(wg) => {
428 let bytes = match wg {
429 WgKind::HandshakeInit(p) => p.into_bytes(),
430 WgKind::HandshakeResp(p) => p.into_bytes(),
431 WgKind::CookieReply(p) => p.into_bytes(),
432 WgKind::Data(p) => p.into_bytes(),
433 };
434 tracing::trace!(dataplane_address=?self.dataplane_address, "sending packet");
435 self.underlay_socket
436 .send_to(bytes.as_bytes(), self.dataplane_address)
437 .await?;
438 Ok(())
439 }
440 None => {
441 tracing::trace!("handshake ongoing, queueing packet");
445 Ok(())
446 }
447 }
448 }
449
450 #[instrument(name = "st-client", skip_all, fields(socket_addr= ?self.local_sockaddr, payload_len= packet.len()))]
452 pub fn try_send(&self, packet: Packet) -> io::Result<()> {
453 match self.tunn.lock().unwrap().handle_outgoing_packet(packet) {
454 Some(wg) => {
455 let bytes = match wg {
456 WgKind::HandshakeInit(p) => p.into_bytes(),
457 WgKind::HandshakeResp(p) => p.into_bytes(),
458 WgKind::CookieReply(p) => p.into_bytes(),
459 WgKind::Data(p) => p.into_bytes(),
460 };
461 tracing::trace!(dataplane_address=?self.dataplane_address, "trying to send packet");
462 self.underlay_socket
463 .try_send_to(bytes.as_bytes(), self.dataplane_address)?;
464 Ok(())
465 }
466 None => {
467 Ok(())
471 }
472 }
473 }
474
475 pub async fn recv(&self) -> Result<Bytes, SnapTunnelReceiveError> {
477 match self.receive_queue.recv().await {
478 Ok(packet) => Ok(packet.into()),
479 Err(_) => Err(SnapTunnelReceiveError::ReceiveQueueClosed),
480 }
481 }
482
483 pub fn poll_recv(
485 &self,
486 cx: &mut std::task::Context<'_>,
487 ) -> std::task::Poll<Result<Bytes, SnapTunnelReceiveError>> {
488 let mut fut_guard = self.recv_future.lock().expect("lock poisoned");
489
490 if fut_guard.is_none() {
492 let receiver = self.receive_queue.clone();
494 *fut_guard = Some(Box::pin(async move { receiver.recv().await }));
495 }
496
497 let fut = fut_guard.as_mut().expect("future cannot be none");
499 match fut.as_mut().poll(cx) {
500 std::task::Poll::Ready(Ok(packet)) => {
501 *fut_guard = None;
503 std::task::Poll::Ready(Ok(packet.into()))
504 }
505 std::task::Poll::Ready(Err(_)) => {
506 tracing::trace!("receive queue closed, returning error");
507 *fut_guard = None;
508 std::task::Poll::Ready(Err(SnapTunnelReceiveError::ReceiveQueueClosed))
509 }
510 std::task::Poll::Pending => std::task::Poll::Pending,
511 }
512 }
513
    /// The local socket address assigned by the dataplane during the handshake.
    pub fn local_addr(&self) -> SocketAddr {
        self.local_sockaddr
    }

    /// Waits until the underlay socket is ready to accept writes.
    pub async fn writable(&self) -> io::Result<()> {
        self.underlay_socket.writable().await
    }

    /// The remote dataplane address this tunnel exchanges traffic with.
    pub fn data_plane_address(&self) -> SocketAddr {
        self.dataplane_address
    }
528}
529
/// Tracks exponential-backoff state across reconnect attempts.
struct BackoffState {
    // When the previous attempt was made.
    last: Instant,
    exp_backoff: ExponentialBackoff,
    // Number of attempts made so far; indexes into the backoff schedule.
    attempt: usize,
}
535
536impl BackoffState {
537 fn new() -> Self {
538 Self {
539 last: Instant::now(),
540 exp_backoff: ExponentialBackoff::new(
541 5.0, 180.0, 1.3, 0.5,
543 ),
544 attempt: 0,
545 }
546 }
547
548 fn backoff(&mut self) -> impl Future<Output = ()> {
549 let now = Instant::now();
550 let until_next = (self.last + self.exp_backoff.duration(self.attempt as u32))
551 .checked_duration_since(now);
552 self.attempt += 1;
553 self.last = now;
554
555 async move {
556 if let Some(d) = until_next {
557 tokio::time::sleep(d).await;
558 }
559 }
560 }
561}
562
/// Serializes any WireGuard packet variant into its raw on-the-wire bytes.
fn to_bytes(wg: WgKind) -> Packet<[u8]> {
    match wg {
        WgKind::HandshakeInit(p) => p.into_bytes(),
        WgKind::HandshakeResp(p) => p.into_bytes(),
        WgKind::CookieReply(p) => p.into_bytes(),
        WgKind::Data(p) => p.into_bytes(),
    }
}