Skip to main content

scion_stack/scionstack/
quic.rs

1// Copyright 2025 Anapaya Systems
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//   http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
//! SCION stack QUIC endpoint.
15
16use std::{
17    collections::HashMap,
18    fmt::{self, Debug},
19    hash::{BuildHasher, Hash as _, Hasher as _},
20    io::ErrorKind,
21    net::{IpAddr, Ipv6Addr},
22    pin::Pin,
23    sync::{Arc, Mutex},
24    task::{Poll, ready},
25    time::{Duration, Instant},
26};
27
28use bytes::BufMut as _;
29use chrono::Utc;
30use foldhash::fast::FixedState;
31use quinn::{AsyncUdpSocket, udp::RecvMeta};
32use scion_proto::{
33    address::SocketAddr,
34    packet::{ByEndpoint, ScionPacketUdp},
35};
36
37use super::{AsyncUdpUnderlaySocket, udp_polling::UdpPoller};
38use crate::{
39    path::manager::traits::{PathPrefetcher, SyncPathManager},
40    quic::ScionQuinnConn,
41};
42
/// Minimum time between two logged I/O errors that share the same debounce slot
/// (i.e. at most one send error and one receive error every three seconds).
const IO_ERROR_LOG_INTERVAL: Duration = Duration::from_millis(3_000);
45
/// A wrapper around a [`quinn::Endpoint`] that translates between SCION and ip:port
/// addresses.
///
/// This is necessary because quinn expects a [`std::net::SocketAddr`], but SCION uses
/// [`scion_proto::address::SocketAddr`].
///
/// SCION addresses are mapped to synthetic IPv6 addresses by the shared
/// [`AddressTranslator`]; the socket handed to quinn is expected to use the same
/// translator so that both directions agree on the mapping.
pub struct Endpoint {
    /// The wrapped quinn endpoint; it only ever sees translated IP addresses.
    inner: quinn::Endpoint,
    /// Used when connecting to prefetch a path towards the destination ISD-AS.
    path_prefetcher: Arc<dyn PathPrefetcher + Send + Sync>,
    /// Mapping between SCION addresses and the IP addresses seen by quinn.
    address_translator: Arc<AddressTranslator>,
    /// The endpoint's local SCION address, as passed at construction.
    local_scion_addr: scion_proto::address::SocketAddr,
}
58
59impl Endpoint {
60    /// Creates a new endpoint.
61    pub fn new_with_abstract_socket(
62        config: quinn::EndpointConfig,
63        server_config: Option<quinn::ServerConfig>,
64        socket: Arc<dyn quinn::AsyncUdpSocket>,
65        local_scion_addr: scion_proto::address::SocketAddr,
66        runtime: Arc<dyn quinn::Runtime>,
67        pather: Arc<dyn PathPrefetcher + Send + Sync>,
68        address_translator: Arc<AddressTranslator>,
69    ) -> std::io::Result<Self> {
70        Ok(Self {
71            inner: quinn::Endpoint::new_with_abstract_socket(
72                config,
73                server_config,
74                socket,
75                runtime,
76            )?,
77            path_prefetcher: pather,
78            address_translator,
79            local_scion_addr,
80        })
81    }
82
83    /// Connect to the address.
84    pub fn connect(
85        &self,
86        addr: scion_proto::address::SocketAddr,
87        server_name: &str,
88    ) -> Result<quinn::Connecting, quinn::ConnectError> {
89        let mapped_addr = self
90            .address_translator
91            .register_scion_address(addr.scion_address());
92        let local_addr = self
93            .address_translator
94            .lookup_scion_address(self.inner.local_addr().unwrap().ip())
95            .unwrap();
96        self.path_prefetcher
97            .prefetch_path(local_addr.isd_asn(), addr.isd_asn());
98        self.inner.connect(
99            std::net::SocketAddr::new(mapped_addr, addr.port()),
100            server_name,
101        )
102    }
103
104    /// Accepts a new incoming connection.
105    pub async fn accept(&self) -> Result<Option<ScionQuinnConn>, quinn::ConnectionError> {
106        let incoming = self.inner.accept().await;
107        if let Some(incoming) = incoming {
108            let remote_socket_addr = incoming.remote_address();
109            let local_scion_addr = incoming
110                .local_ip()
111                .and_then(|ip| self.address_translator.lookup_scion_address(ip));
112            let conn = ScionQuinnConn {
113                inner: incoming.await?,
114                // XXX(uniquefine): For now the ScionAsyncUdpSocket does not have access to a
115                // packets destination address, so we cannot lookup the local SCION
116                // address.
117                local_addr: local_scion_addr,
118                remote_addr: scion_proto::address::SocketAddr::new(
119                    self.address_translator
120                        .lookup_scion_address(remote_socket_addr.ip())
121                        .or_else(|| {
122                            panic!(
123                                "no scion address mapped for ip, this should never happen: {}",
124                                remote_socket_addr.ip(),
125                            );
126                        })
127                        .unwrap(),
128                    remote_socket_addr.port(),
129                ),
130            };
131            Ok(Some(conn))
132        } else {
133            Ok(None)
134        }
135    }
136
137    /// Set the default QUIC client configuration.
138    pub fn set_default_client_config(&mut self, config: quinn::ClientConfig) {
139        self.inner.set_default_client_config(config);
140    }
141
142    /// Wait until all connections on the endpoint cleanly shut down.
143    pub async fn wait_idle(&self) {
144        self.inner.wait_idle().await;
145    }
146
147    /// Returns the local socket address of the endpoint.
148    pub fn local_addr(&self) -> std::io::Result<std::net::SocketAddr> {
149        self.inner.local_addr()
150    }
151
152    /// Returns the local SCION address of the endpoint.
153    pub fn local_scion_addr(&self) -> scion_proto::address::SocketAddr {
154        self.local_scion_addr
155    }
156}
157
158/// Type that can translate between SCION and IP addresses.
159// TODO(uniquefine): Expiration or cleanup of translated addresses
160pub struct AddressTranslator {
161    build_hasher: FixedState,
162    addr_map: Mutex<HashMap<std::net::Ipv6Addr, scion_proto::address::ScionAddr>>,
163}
164
165impl Debug for AddressTranslator {
166    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
167        write!(
168            f,
169            "AddressTranslatorImpl {{ {} }}",
170            self.addr_map
171                .lock()
172                .unwrap()
173                .iter()
174                .map(|(ip, addr)| format!("{ip} -> {addr}"))
175                .collect::<Vec<_>>()
176                .join(", ")
177        )
178    }
179}
180
181impl AddressTranslator {
182    /// Creates a new address translator.
183    pub fn new(build_hasher: FixedState) -> Self {
184        Self {
185            build_hasher,
186            addr_map: Mutex::new(HashMap::new()),
187        }
188    }
189
190    fn hash_scion_address(&self, addr: scion_proto::address::ScionAddr) -> std::net::Ipv6Addr {
191        let mut hasher = self.build_hasher.build_hasher();
192        hasher.write_u64(addr.isd_asn().to_u64());
193        addr.local_address().hash(&mut hasher);
194        Ipv6Addr::from(hasher.finish() as u128)
195    }
196
197    /// Registers the SCION address and returns the corresponding IP address.
198    pub fn register_scion_address(
199        &self,
200        addr: scion_proto::address::ScionAddr,
201    ) -> std::net::IpAddr {
202        let ip = self.hash_scion_address(addr);
203        let mut addr_map = self.addr_map.lock().unwrap();
204        addr_map.entry(ip).or_insert(addr);
205        IpAddr::V6(ip)
206    }
207
208    /// Looks up the SCION address for the given IP address.
209    pub fn lookup_scion_address(
210        &self,
211        ip: std::net::IpAddr,
212    ) -> Option<scion_proto::address::ScionAddr> {
213        let ip = match ip {
214            IpAddr::V6(ip) => ip,
215            IpAddr::V4(_) => return None,
216        };
217        self.addr_map.lock().unwrap().get(&ip).cloned()
218    }
219}
220
221impl Default for AddressTranslator {
222    fn default() -> Self {
223        Self {
224            build_hasher: FixedState::with_seed(42),
225            addr_map: Mutex::new(HashMap::new()),
226        }
227    }
228}
229
/// A path-aware UDP socket that implements the [quinn::AsyncUdpSocket] trait.
///
/// The socket translates the SCION addresses of incoming packets to IP addresses that
/// are used by quinn, and translates quinn's IP destinations back to SCION addresses when
/// sending.
/// To connect to a SCION destination, the destination SCION address must first be registered
/// with the [AddressTranslator].
pub(crate) struct ScionAsyncUdpSocket {
    /// The SCION socket that actually sends and receives datagrams.
    socket: Arc<dyn AsyncUdpUnderlaySocket>,
    /// Provides cached paths for sending; reversed paths of received packets are
    /// registered here.
    path_manager: Arc<dyn SyncPathManager + Send + Sync>,
    /// SCION <-> IP mapping used in both directions.
    address_translator: Arc<AddressTranslator>,
    /// The last time a poll_recv error was logged.
    last_recv_error: Mutex<Instant>,
    /// The last time a try_send error was logged.
    last_send_error: Mutex<Instant>,
}
245
246impl ScionAsyncUdpSocket {
247    pub fn new(
248        socket: Arc<dyn AsyncUdpUnderlaySocket>,
249        path_manager: Arc<dyn SyncPathManager + Send + Sync>,
250        address_translator: Arc<AddressTranslator>,
251    ) -> Self {
252        let now = Instant::now();
253        let instant = now.checked_sub(2 * IO_ERROR_LOG_INTERVAL).unwrap_or(now);
254        Self {
255            socket,
256            path_manager,
257            address_translator,
258            last_recv_error: Mutex::new(instant),
259            last_send_error: Mutex::new(instant),
260        }
261    }
262}
263
264impl std::fmt::Debug for ScionAsyncUdpSocket {
265    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
266        f.write_fmt(format_args!(
267            "ScionAsyncUdpSocket({})",
268            match self.local_addr() {
269                Ok(addr) => addr.to_string(),
270                Err(e) => e.to_string(),
271            }
272        ))
273    }
274}
275
/// Adapter that implements [`quinn::UdpPoller`] by delegating to the scionstack
/// [`UdpPoller`].
///
/// This allows scionstack to remain decoupled from the quinn crate.
struct QuinnUdpPollerWrapper(Pin<Box<dyn UdpPoller>>);
279
280impl std::fmt::Debug for QuinnUdpPollerWrapper {
281    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
282        self.0.fmt(f)
283    }
284}
285
286impl QuinnUdpPollerWrapper {
287    fn new(inner: Pin<Box<dyn UdpPoller>>) -> Self {
288        Self(inner)
289    }
290}
291
292impl quinn::UdpPoller for QuinnUdpPollerWrapper {
293    fn poll_writable(
294        mut self: Pin<&mut Self>,
295        cx: &mut std::task::Context,
296    ) -> Poll<std::io::Result<()>> {
297        self.0.as_mut().poll_writable(cx)
298    }
299}
300
301impl AsyncUdpSocket for ScionAsyncUdpSocket {
302    fn create_io_poller(self: Arc<Self>) -> std::pin::Pin<Box<dyn quinn::UdpPoller>> {
303        let socket = self.socket.clone();
304        let inner_poller = socket.create_io_poller();
305        let wrapper = QuinnUdpPollerWrapper::new(inner_poller);
306        Box::pin(wrapper)
307    }
308
309    fn try_send(&self, transmit: &quinn::udp::Transmit) -> std::io::Result<()> {
310        let buf = bytes::Bytes::copy_from_slice(transmit.contents);
311        let remote_scion_addr = SocketAddr::new(
312            self.address_translator
313                .lookup_scion_address(transmit.destination.ip())
314                .ok_or(std::io::Error::other(format!(
315                    "no scion address mapped for ip, this should never happen: {}",
316                    transmit.destination.ip(),
317                )))?,
318            transmit.destination.port(),
319        );
320        let path = self.path_manager.try_cached_path(
321            self.socket.local_addr().isd_asn(),
322            remote_scion_addr.isd_asn(),
323            Utc::now(),
324        )?;
325
326        let path = match path {
327            Some(path) => path,
328            None => return Ok(()),
329        };
330
331        let packet = ScionPacketUdp::new(
332            ByEndpoint {
333                source: self.socket.local_addr(),
334                destination: remote_scion_addr,
335            },
336            path.data_plane_path.to_bytes_path(),
337            buf,
338        )
339        .map_err(|_| std::io::Error::other("failed to encode packet"))?;
340
341        match self.socket.try_send(packet.into()) {
342            Ok(_) => Ok(()),
343            Err(e) if e.kind() == ErrorKind::WouldBlock => Err(e),
344            Err(e) => {
345                // XXX: We only log the error such that the quinn connection driver doesn't quit.
346                debounced_warn(
347                    &self.last_send_error,
348                    "Failed to send on the underlying socket",
349                    e,
350                );
351                Ok(())
352            }
353        }
354    }
355
356    fn poll_recv(
357        &self,
358        cx: &mut std::task::Context,
359        bufs: &mut [std::io::IoSliceMut<'_>],
360        meta: &mut [quinn::udp::RecvMeta],
361    ) -> std::task::Poll<std::io::Result<usize>> {
362        match ready!(self.socket.poll_recv_from_with_path(cx)) {
363            Ok((remote, bytes, path)) => {
364                match path.to_reversed() {
365                    Ok(path) => {
366                        // Register the path for later reuse
367                        self.path_manager.register_path(
368                            remote.isd_asn(),
369                            self.socket.local_addr().isd_asn(),
370                            Utc::now(),
371                            path,
372                        );
373                    }
374                    Err(e) => {
375                        tracing::trace!("Failed to reverse path for registration: {}", e)
376                    }
377                }
378
379                let remote_ip = self
380                    .address_translator
381                    .register_scion_address(remote.scion_address());
382
383                meta[0] = RecvMeta {
384                    addr: std::net::SocketAddr::new(remote_ip, remote.port()),
385                    len: bytes.len(),
386                    ecn: None,
387                    stride: bytes.len(),
388                    dst_ip: self.socket.local_addr().local_address().map(|s| s.ip()),
389                };
390                bufs[0].as_mut().put_slice(&bytes);
391
392                Poll::Ready(Ok(1))
393            }
394            Err(e) if e.kind() == ErrorKind::WouldBlock => Poll::Ready(Err(e)),
395            Err(e) => {
396                // XXX: We only log the error such that the endpoint driver doesn't quit.
397                debounced_warn(
398                    &self.last_recv_error,
399                    "Failed to receive on the underlying socket",
400                    e,
401                );
402
403                Poll::Pending
404            }
405        }
406    }
407
408    fn local_addr(&self) -> std::io::Result<std::net::SocketAddr> {
409        Ok(std::net::SocketAddr::new(
410            self.address_translator
411                .register_scion_address(self.socket.local_addr().scion_address()),
412            self.socket.local_addr().port(),
413        ))
414    }
415}
416
417/// Logs a warning message when an error occurs.
418///
419/// Logging will only be performed if at least [`IO_ERROR_LOG_INTERVAL`]
420/// has elapsed since the last error was logged.
421// Inspired by quinn's `log_sendmsg_error`.
422fn debounced_warn(last_send_error: &Mutex<Instant>, msg: &str, err: impl core::fmt::Debug) {
423    let now = Instant::now();
424    let last_send_error = &mut *last_send_error.lock().expect("poisoned lock");
425    if now.saturating_duration_since(*last_send_error) > IO_ERROR_LOG_INTERVAL {
426        *last_send_error = now;
427        tracing::warn!(?err, "{msg}");
428    }
429}