Skip to main content

scion_stack/scionstack/
socket.rs

1// Copyright 2025 Anapaya Systems
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//   http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//! SCION socket types.
15
16use std::{sync::Arc, time::Duration};
17
18use bytes::Bytes;
19use chrono::Utc;
20use futures::future::BoxFuture;
21use scion_proto::{
22    address::{ScionAddr, SocketAddr},
23    datagram::UdpMessage,
24    packet::{ByEndpoint, ScionPacketRaw, ScionPacketScmp, ScionPacketUdp},
25    path::Path,
26    scmp::{SCMP_PROTOCOL_NUMBER, ScmpMessage},
27};
28
29use super::UnderlaySocket;
30use crate::{
31    path::manager::{MultiPathManager, traits::PathManager},
32    scionstack::{
33        ScionSocketConnectError, ScionSocketReceiveError, ScionSocketSendError,
34        scmp_handler::ScmpHandler,
35    },
36    types::Subscribers,
37};
38
39/// A path unaware UDP SCION socket.
40pub struct PathUnawareUdpScionSocket {
41    inner: Box<dyn UnderlaySocket + Sync + Send>,
42    /// The SCMP handlers.
43    scmp_handlers: Vec<Box<dyn ScmpHandler>>,
44}
45
46impl std::fmt::Debug for PathUnawareUdpScionSocket {
47    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
48        f.debug_struct("PathUnawareUdpScionSocket")
49            .field("local_addr", &self.inner.local_addr())
50            .finish()
51    }
52}
53
54impl PathUnawareUdpScionSocket {
55    pub(crate) fn new(
56        socket: Box<dyn UnderlaySocket + Sync + Send>,
57        scmp_handlers: Vec<Box<dyn ScmpHandler>>,
58    ) -> Self {
59        Self {
60            inner: socket,
61            scmp_handlers,
62        }
63    }
64
65    /// Send a SCION UDP datagram via the given path.
66    pub fn send_to_via<'a>(
67        &'a self,
68        payload: &[u8],
69        destination: SocketAddr,
70        path: &Path<&[u8]>,
71    ) -> BoxFuture<'a, Result<(), ScionSocketSendError>> {
72        let packet = match ScionPacketUdp::new(
73            ByEndpoint {
74                source: self.inner.local_addr(),
75                destination,
76            },
77            path.data_plane_path.to_bytes_path(),
78            Bytes::copy_from_slice(payload),
79        ) {
80            Ok(packet) => packet,
81            Err(e) => {
82                return Box::pin(async move {
83                    Err(ScionSocketSendError::InvalidPacket(
84                        format!("error encoding packet: {e}").into(),
85                    ))
86                });
87            }
88        }
89        .into();
90        self.inner.send(packet)
91    }
92
93    /// Receive a SCION packet with the sender and path.
94    #[allow(clippy::type_complexity)]
95    pub fn recv_from_with_path<'a>(
96        &'a self,
97        buffer: &'a mut [u8],
98        path_buffer: &'a mut [u8],
99    ) -> BoxFuture<'a, Result<(usize, SocketAddr, Path<&'a mut [u8]>), ScionSocketReceiveError>>
100    {
101        Box::pin(async move {
102            loop {
103                let packet = self.inner.recv().await?;
104
105                let packet = match packet.headers.common.next_header {
106                    UdpMessage::PROTOCOL_NUMBER => packet,
107                    SCMP_PROTOCOL_NUMBER => {
108                        tracing::debug!("SCMP packet received, forwarding to SCMP handlers");
109                        for handler in &self.scmp_handlers {
110                            if let Some(reply) = handler.handle(packet.clone())
111                                && let Err(e) = self.inner.try_send(reply)
112                            {
113                                tracing::warn!(error = %e, "failed to send SCMP reply");
114                            }
115                        }
116                        continue;
117                    }
118                    _ => {
119                        tracing::debug!(next_header = %packet.headers.common.next_header, "Packet with unknown next layer protocol, skipping");
120                        continue;
121                    }
122                };
123
124                let packet: ScionPacketUdp = match packet.try_into() {
125                    Ok(packet) => packet,
126                    Err(e) => {
127                        tracing::debug!(error = %e, "Received invalid UDP packet, skipping");
128                        continue;
129                    }
130                };
131                let src_addr = match packet.headers.address.source() {
132                    Some(source) => SocketAddr::new(source, packet.src_port()),
133                    None => {
134                        tracing::debug!("Received packet without source address header, skipping");
135                        continue;
136                    }
137                };
138                tracing::trace!(
139                    src = %src_addr,
140                    length = packet.datagram.payload.len(),
141                    "received packet",
142                );
143
144                let max_read = std::cmp::min(buffer.len(), packet.datagram.payload.len());
145                buffer[..max_read].copy_from_slice(&packet.datagram.payload[..max_read]);
146
147                if path_buffer.len() < packet.headers.path.raw().len() {
148                    return Err(ScionSocketReceiveError::PathBufTooSmall);
149                }
150
151                let dataplane_path = packet
152                    .headers
153                    .path
154                    .copy_to_slice(&mut path_buffer[..packet.headers.path.raw().len()]);
155
156                // Note, that we do not have the next hop address of the path.
157                // A socket that uses more than one tunnel will need to distinguish between
158                // packets received on different tunnels.
159                let path = Path::new(dataplane_path, packet.headers.address.ia, None);
160
161                return Ok((packet.datagram.payload.len(), src_addr, path));
162            }
163        })
164    }
165
166    /// Receive a SCION packet with the sender.
167    pub fn recv_from<'a>(
168        &'a self,
169        buffer: &'a mut [u8],
170    ) -> BoxFuture<'a, Result<(usize, SocketAddr), ScionSocketReceiveError>> {
171        Box::pin(async move {
172            loop {
173                let packet = self.inner.recv().await?;
174
175                let packet = match packet.headers.common.next_header {
176                    UdpMessage::PROTOCOL_NUMBER => packet,
177                    SCMP_PROTOCOL_NUMBER => {
178                        tracing::debug!("SCMP packet received, forwarding to SCMP handlers");
179                        for handler in &self.scmp_handlers {
180                            if let Some(reply) = handler.handle(packet.clone())
181                                && let Err(e) = self.inner.try_send(reply)
182                            {
183                                tracing::warn!(error = %e, "failed to send SCMP reply");
184                            }
185                        }
186                        continue;
187                    }
188                    _ => {
189                        tracing::debug!(next_header = %packet.headers.common.next_header, "Packet with unknown next layer protocol, skipping");
190                        continue;
191                    }
192                };
193
194                let packet: ScionPacketUdp = match packet.try_into() {
195                    Ok(packet) => packet,
196                    Err(e) => {
197                        tracing::debug!(error = %e, "Received invalid UDP packet, dropping");
198                        continue;
199                    }
200                };
201                let src_addr = match packet.headers.address.source() {
202                    Some(source) => SocketAddr::new(source, packet.src_port()),
203                    None => {
204                        tracing::debug!("Received packet without source address header, dropping");
205                        continue;
206                    }
207                };
208
209                tracing::trace!(
210                    src = %src_addr,
211                    length = packet.datagram.payload.len(),
212                    buffer_size = buffer.len(),
213                    "received packet",
214                );
215
216                let max_read = std::cmp::min(buffer.len(), packet.datagram.payload.len());
217                buffer[..max_read].copy_from_slice(&packet.datagram.payload[..max_read]);
218
219                return Ok((packet.datagram.payload.len(), src_addr));
220            }
221        })
222    }
223
224    /// The local address the socket is bound to.
225    fn local_addr(&self) -> SocketAddr {
226        self.inner.local_addr()
227    }
228}
229
230/// A SCMP SCION socket.
231pub struct ScmpScionSocket {
232    inner: Box<dyn UnderlaySocket + Sync + Send>,
233}
234
235impl ScmpScionSocket {
236    pub(crate) fn new(socket: Box<dyn UnderlaySocket + Sync + Send>) -> Self {
237        Self { inner: socket }
238    }
239}
240
241impl ScmpScionSocket {
242    /// Send a SCMP message to the destination via the given path.
243    pub fn send_to_via<'a>(
244        &'a self,
245        message: ScmpMessage,
246        destination: ScionAddr,
247        path: &Path<&[u8]>,
248    ) -> BoxFuture<'a, Result<(), ScionSocketSendError>> {
249        let packet = match ScionPacketScmp::new(
250            ByEndpoint {
251                source: self.inner.local_addr().scion_address(),
252                destination,
253            },
254            path.data_plane_path.to_bytes_path(),
255            message,
256        ) {
257            Ok(packet) => packet,
258            Err(e) => {
259                return Box::pin(async move {
260                    Err(ScionSocketSendError::InvalidPacket(
261                        format!("error encoding packet: {e}").into(),
262                    ))
263                });
264            }
265        };
266        let packet = packet.into();
267        Box::pin(async move { self.inner.send(packet).await })
268    }
269
270    /// Receive a SCMP message with the sender and path.
271    #[allow(clippy::type_complexity)]
272    pub fn recv_from_with_path<'a>(
273        &'a self,
274        path_buffer: &'a mut [u8],
275    ) -> BoxFuture<'a, Result<(ScmpMessage, ScionAddr, Path<&'a mut [u8]>), ScionSocketReceiveError>>
276    {
277        Box::pin(async move {
278            loop {
279                let packet = self.inner.recv().await?;
280                let packet: ScionPacketScmp = match packet.try_into() {
281                    Ok(packet) => packet,
282                    Err(e) => {
283                        tracing::debug!(error = %e, "Received invalid SCMP packet, dropping");
284                        continue;
285                    }
286                };
287                let src_addr = match packet.headers.address.source() {
288                    Some(source) => source,
289                    None => {
290                        tracing::debug!("Received packet without source address header, dropping");
291                        continue;
292                    }
293                };
294
295                if path_buffer.len() < packet.headers.path.raw().len() {
296                    return Err(ScionSocketReceiveError::PathBufTooSmall);
297                }
298                let dataplane_path = packet
299                    .headers
300                    .path
301                    .copy_to_slice(&mut path_buffer[..packet.headers.path.raw().len()]);
302                let path = Path::new(dataplane_path, packet.headers.address.ia, None);
303
304                return Ok((packet.message, src_addr, path));
305            }
306        })
307    }
308
309    /// Receive a SCMP message with the sender.
310    pub fn recv_from<'a>(
311        &'a self,
312    ) -> BoxFuture<'a, Result<(ScmpMessage, ScionAddr), ScionSocketReceiveError>> {
313        Box::pin(async move {
314            loop {
315                let packet = self.inner.recv().await?;
316                let packet: ScionPacketScmp = match packet.try_into() {
317                    Ok(packet) => packet,
318                    Err(e) => {
319                        tracing::debug!(error = %e, "Received invalid SCMP packet, skipping");
320                        continue;
321                    }
322                };
323                let src_addr = match packet.headers.address.source() {
324                    Some(source) => source,
325                    None => {
326                        tracing::debug!("Received packet without source address header, skipping");
327                        continue;
328                    }
329                };
330                return Ok((packet.message, src_addr));
331            }
332        })
333    }
334
335    /// Return the local socket address.
336    pub fn local_addr(&self) -> SocketAddr {
337        self.inner.local_addr()
338    }
339}
340
341/// A raw SCION socket.
342pub struct RawScionSocket {
343    inner: Box<dyn UnderlaySocket>,
344}
345
346impl RawScionSocket {
347    pub(crate) fn new(socket: Box<dyn UnderlaySocket + Sync + Send>) -> Self {
348        Self { inner: socket }
349    }
350}
351
352impl RawScionSocket {
353    /// Send a raw SCION packet.
354    pub fn send<'a>(
355        &'a self,
356        packet: ScionPacketRaw,
357    ) -> BoxFuture<'a, Result<(), ScionSocketSendError>> {
358        self.inner.send(packet)
359    }
360
361    /// Receive a raw SCION packet.
362    pub fn recv<'a>(&'a self) -> BoxFuture<'a, Result<ScionPacketRaw, ScionSocketReceiveError>> {
363        self.inner.recv()
364    }
365
366    /// Return the local socket address.
367    pub fn local_addr(&self) -> SocketAddr {
368        self.inner.local_addr()
369    }
370}
371
372/// A trait for receiving socket send errors.
373pub trait SendErrorReceiver: Send + Sync {
374    /// Reports an error when sending a packet.
375    /// This function must return immediately and not block.
376    fn report_send_error(&self, error: &ScionSocketSendError);
377}
378
379/// A path aware UDP socket generic over the underlay socket and path manager.
380pub struct UdpScionSocket<P: PathManager = MultiPathManager> {
381    socket: PathUnawareUdpScionSocket,
382    pather: Arc<P>,
383    connect_timeout: Duration,
384    remote_addr: Option<SocketAddr>,
385    send_error_receivers: Subscribers<dyn SendErrorReceiver>,
386}
387
388impl<P: PathManager> std::fmt::Debug for UdpScionSocket<P> {
389    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
390        f.debug_struct("UdpScionSocket")
391            .field("local_addr", &self.socket.local_addr())
392            .field("remote_addr", &self.remote_addr)
393            .finish()
394    }
395}
396
397impl<P: PathManager> UdpScionSocket<P> {
398    /// Creates a new path aware UDP SCION socket.
399    pub fn new(
400        socket: PathUnawareUdpScionSocket,
401        pather: Arc<P>,
402        connect_timeout: Duration,
403        send_error_receivers: Subscribers<dyn SendErrorReceiver>,
404    ) -> Self {
405        Self {
406            socket,
407            pather,
408            connect_timeout,
409            remote_addr: None,
410            send_error_receivers,
411        }
412    }
413
414    /// Connects the socket to a remote address.
415    ///
416    /// Ensures a Path to the Destination exists, returns an error if not.
417    ///
418    /// Timeouts after configured `connect_timeout`
419    pub async fn connect(self, remote_addr: SocketAddr) -> Result<Self, ScionSocketConnectError> {
420        // Check that a path exists to destination
421        let _path = self
422            .pather
423            .path_timeout(
424                self.socket.local_addr().isd_asn(),
425                remote_addr.isd_asn(),
426                Utc::now(),
427                self.connect_timeout,
428            )
429            .await?;
430
431        Ok(Self {
432            remote_addr: Some(remote_addr),
433            ..self
434        })
435    }
436
437    /// Send a datagram to the connected remote address.
438    pub async fn send(&self, payload: &[u8]) -> Result<(), ScionSocketSendError> {
439        if let Some(remote_addr) = self.remote_addr {
440            self.send_to(payload, remote_addr).await
441        } else {
442            Err(ScionSocketSendError::NotConnected)
443        }
444    }
445
446    /// Send a datagram to the specified destination.
447    pub async fn send_to(
448        &self,
449        payload: &[u8],
450        destination: SocketAddr,
451    ) -> Result<(), ScionSocketSendError> {
452        let path = &self
453            .pather
454            .path_wait(
455                self.socket.local_addr().isd_asn(),
456                destination.isd_asn(),
457                Utc::now(),
458            )
459            .await?;
460        self.socket
461            .send_to_via(payload, destination, &path.to_slice_path())
462            .await
463    }
464
465    /// Send a datagram to the specified destination via the specified path.
466    pub async fn send_to_via(
467        &self,
468        payload: &[u8],
469        destination: SocketAddr,
470        path: &Path<&[u8]>,
471    ) -> Result<(), ScionSocketSendError> {
472        self.socket
473            .send_to_via(payload, destination, path)
474            .await
475            .inspect_err(|e| {
476                self.send_error_receivers
477                    .for_each(|receiver| receiver.report_send_error(e));
478            })
479    }
480
481    /// Receive a datagram from any address, along with the sender address and path.
482    pub async fn recv_from_with_path<'a>(
483        &'a self,
484        buffer: &'a mut [u8],
485        path_buffer: &'a mut [u8],
486    ) -> Result<(usize, SocketAddr, Path<&'a mut [u8]>), ScionSocketReceiveError> {
487        let (len, sender_addr, path): (usize, SocketAddr, Path<&mut [u8]>) =
488            self.socket.recv_from_with_path(buffer, path_buffer).await?;
489
490        match path.to_reversed() {
491            Ok(reversed_path) => {
492                // Register the path for future use
493                self.pather.register_path(
494                    self.socket.local_addr().isd_asn(),
495                    sender_addr.isd_asn(),
496                    Utc::now(),
497                    reversed_path,
498                );
499            }
500            Err(e) => {
501                tracing::trace!(error = ?e, "Failed to reverse path for registration")
502            }
503        }
504
505        tracing::trace!(
506            src = %self.socket.local_addr(),
507            dst = %sender_addr,
508            "Registered reverse path",
509        );
510
511        Ok((len, sender_addr, path))
512    }
513
514    /// Receive a datagram from the connected remote address and write it into the provided buffer.
515    pub async fn recv_from(
516        &self,
517        buffer: &mut [u8],
518    ) -> Result<(usize, SocketAddr), ScionSocketReceiveError> {
519        // For this method, we need to get the path to register it, but we don't return it
520        let mut path_buffer = [0u8; 1024]; // Temporary buffer for path
521        let (len, sender_addr, _) = self.recv_from_with_path(buffer, &mut path_buffer).await?;
522        Ok((len, sender_addr))
523    }
524
525    /// Receive a datagram from the connected remote address.
526    ///
527    /// Datagrams from other addresses are silently discarded.
528    pub async fn recv(&self, buffer: &mut [u8]) -> Result<usize, ScionSocketReceiveError> {
529        if self.remote_addr.is_none() {
530            return Err(ScionSocketReceiveError::NotConnected);
531        }
532        loop {
533            let (len, sender_addr) = self.recv_from(buffer).await?;
534            match self.remote_addr {
535                Some(remote_addr) => {
536                    if sender_addr == remote_addr {
537                        return Ok(len);
538                    }
539                }
540                None => return Err(ScionSocketReceiveError::NotConnected),
541            }
542        }
543    }
544
545    /// Returns the local socket address.
546    pub fn local_addr(&self) -> SocketAddr {
547        self.socket.local_addr()
548    }
549}