Skip to main content

scion_stack/scionstack/
socket.rs

1// Copyright 2025 Anapaya Systems
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//   http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//! SCION socket types.
15
16use std::{sync::Arc, time::Duration};
17
18use bytes::Bytes;
19use chrono::Utc;
20use futures::future::BoxFuture;
21use scion_proto::{
22    address::{ScionAddr, SocketAddr},
23    datagram::UdpMessage,
24    packet::{ByEndpoint, ScionPacketRaw, ScionPacketScmp, ScionPacketUdp},
25    path::Path,
26    scmp::{SCMP_PROTOCOL_NUMBER, ScmpMessage},
27};
28
29use super::{NetworkError, UnderlaySocket};
30use crate::{
31    path::manager::{
32        MultiPathManager,
33        traits::{PathManager, PathWaitError},
34    },
35    scionstack::{
36        ScionSocketConnectError, ScionSocketReceiveError, ScionSocketSendError,
37        scmp_handler::ScmpHandler,
38    },
39    types::Subscribers,
40};
41
/// A path unaware UDP SCION socket.
///
/// The caller supplies the path for every outgoing datagram; received SCMP
/// packets are dispatched to the configured SCMP handlers instead of being
/// returned to the caller.
pub struct PathUnawareUdpScionSocket {
    /// The underlay socket used to send and receive raw SCION packets.
    inner: Box<dyn UnderlaySocket + Sync + Send>,
    /// The SCMP handlers. Each received SCMP packet is passed to every handler.
    scmp_handlers: Vec<Box<dyn ScmpHandler>>,
}
48
49impl std::fmt::Debug for PathUnawareUdpScionSocket {
50    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
51        f.debug_struct("PathUnawareUdpScionSocket")
52            .field("local_addr", &self.inner.local_addr())
53            .finish()
54    }
55}
56
57impl PathUnawareUdpScionSocket {
58    pub(crate) fn new(
59        socket: Box<dyn UnderlaySocket + Sync + Send>,
60        scmp_handlers: Vec<Box<dyn ScmpHandler>>,
61    ) -> Self {
62        Self {
63            inner: socket,
64            scmp_handlers,
65        }
66    }
67
68    /// Send a SCION UDP datagram via the given path.
69    pub fn send_to_via<'a>(
70        &'a self,
71        payload: &[u8],
72        destination: SocketAddr,
73        path: &Path<&[u8]>,
74    ) -> BoxFuture<'a, Result<(), ScionSocketSendError>> {
75        let packet = match ScionPacketUdp::new(
76            ByEndpoint {
77                source: self.inner.local_addr(),
78                destination,
79            },
80            path.data_plane_path.to_bytes_path(),
81            Bytes::copy_from_slice(payload),
82        ) {
83            Ok(packet) => packet,
84            Err(e) => {
85                return Box::pin(async move {
86                    Err(ScionSocketSendError::InvalidPacket(
87                        format!("error encoding packet: {e}").into(),
88                    ))
89                });
90            }
91        }
92        .into();
93        self.inner.send(packet)
94    }
95
96    /// Receive a SCION packet with the sender and path.
97    #[allow(clippy::type_complexity)]
98    pub fn recv_from_with_path<'a>(
99        &'a self,
100        buffer: &'a mut [u8],
101        path_buffer: &'a mut [u8],
102    ) -> BoxFuture<'a, Result<(usize, SocketAddr, Path<&'a mut [u8]>), ScionSocketReceiveError>>
103    {
104        Box::pin(async move {
105            loop {
106                let packet = self.inner.recv().await?;
107
108                let packet = match packet.headers.common.next_header {
109                    UdpMessage::PROTOCOL_NUMBER => packet,
110                    SCMP_PROTOCOL_NUMBER => {
111                        tracing::debug!("SCMP packet received, forwarding to SCMP handlers");
112                        for handler in &self.scmp_handlers {
113                            if let Some(reply) = handler.handle(packet.clone())
114                                && let Err(e) = self.inner.try_send(reply)
115                            {
116                                tracing::warn!(error = %e, "failed to send SCMP reply");
117                            }
118                        }
119                        continue;
120                    }
121                    _ => {
122                        tracing::debug!(next_header = %packet.headers.common.next_header, "Packet with unknown next layer protocol, skipping");
123                        continue;
124                    }
125                };
126
127                let packet: ScionPacketUdp = match packet.try_into() {
128                    Ok(packet) => packet,
129                    Err(e) => {
130                        tracing::debug!(error = %e, "Received invalid UDP packet, skipping");
131                        continue;
132                    }
133                };
134                let src_addr = match packet.headers.address.source() {
135                    Some(source) => SocketAddr::new(source, packet.src_port()),
136                    None => {
137                        tracing::debug!("Received packet without source address header, skipping");
138                        continue;
139                    }
140                };
141                tracing::trace!(
142                    src = %src_addr,
143                    length = packet.datagram.payload.len(),
144                    "received packet",
145                );
146
147                let max_read = std::cmp::min(buffer.len(), packet.datagram.payload.len());
148                buffer[..max_read].copy_from_slice(&packet.datagram.payload[..max_read]);
149
150                if path_buffer.len() < packet.headers.path.raw().len() {
151                    return Err(ScionSocketReceiveError::PathBufTooSmall);
152                }
153
154                let dataplane_path = packet
155                    .headers
156                    .path
157                    .copy_to_slice(&mut path_buffer[..packet.headers.path.raw().len()]);
158
159                // Note, that we do not have the next hop address of the path.
160                // A socket that uses more than one tunnel will need to distinguish between
161                // packets received on different tunnels.
162                let path = Path::new(dataplane_path, packet.headers.address.ia, None);
163
164                return Ok((packet.datagram.payload.len(), src_addr, path));
165            }
166        })
167    }
168
169    /// Receive a SCION packet with the sender.
170    pub fn recv_from<'a>(
171        &'a self,
172        buffer: &'a mut [u8],
173    ) -> BoxFuture<'a, Result<(usize, SocketAddr), ScionSocketReceiveError>> {
174        Box::pin(async move {
175            loop {
176                let packet = self.inner.recv().await?;
177
178                let packet = match packet.headers.common.next_header {
179                    UdpMessage::PROTOCOL_NUMBER => packet,
180                    SCMP_PROTOCOL_NUMBER => {
181                        tracing::debug!("SCMP packet received, forwarding to SCMP handlers");
182                        for handler in &self.scmp_handlers {
183                            if let Some(reply) = handler.handle(packet.clone())
184                                && let Err(e) = self.inner.try_send(reply)
185                            {
186                                tracing::warn!(error = %e, "failed to send SCMP reply");
187                            }
188                        }
189                        continue;
190                    }
191                    _ => {
192                        tracing::debug!(next_header = %packet.headers.common.next_header, "Packet with unknown next layer protocol, skipping");
193                        continue;
194                    }
195                };
196
197                let packet: ScionPacketUdp = match packet.try_into() {
198                    Ok(packet) => packet,
199                    Err(e) => {
200                        tracing::debug!(error = %e, "Received invalid UDP packet, dropping");
201                        continue;
202                    }
203                };
204                let src_addr = match packet.headers.address.source() {
205                    Some(source) => SocketAddr::new(source, packet.src_port()),
206                    None => {
207                        tracing::debug!("Received packet without source address header, dropping");
208                        continue;
209                    }
210                };
211
212                tracing::trace!(
213                    src = %src_addr,
214                    length = packet.datagram.payload.len(),
215                    buffer_size = buffer.len(),
216                    "received packet",
217                );
218
219                let max_read = std::cmp::min(buffer.len(), packet.datagram.payload.len());
220                buffer[..max_read].copy_from_slice(&packet.datagram.payload[..max_read]);
221
222                return Ok((packet.datagram.payload.len(), src_addr));
223            }
224        })
225    }
226
227    /// The local address the socket is bound to.
228    fn local_addr(&self) -> SocketAddr {
229        self.inner.local_addr()
230    }
231}
232
/// A SCMP SCION socket.
///
/// Sends and receives SCMP messages; received packets that are not valid SCMP
/// packets are discarded.
pub struct ScmpScionSocket {
    /// The underlay socket used to send and receive raw SCION packets.
    inner: Box<dyn UnderlaySocket + Sync + Send>,
}
237
impl ScmpScionSocket {
    /// Creates a SCMP socket on top of the given underlay socket.
    pub(crate) fn new(socket: Box<dyn UnderlaySocket + Sync + Send>) -> Self {
        Self { inner: socket }
    }
}
243
244impl ScmpScionSocket {
245    /// Send a SCMP message to the destination via the given path.
246    pub fn send_to_via<'a>(
247        &'a self,
248        message: ScmpMessage,
249        destination: ScionAddr,
250        path: &Path<&[u8]>,
251    ) -> BoxFuture<'a, Result<(), ScionSocketSendError>> {
252        let packet = match ScionPacketScmp::new(
253            ByEndpoint {
254                source: self.inner.local_addr().scion_address(),
255                destination,
256            },
257            path.data_plane_path.to_bytes_path(),
258            message,
259        ) {
260            Ok(packet) => packet,
261            Err(e) => {
262                return Box::pin(async move {
263                    Err(ScionSocketSendError::InvalidPacket(
264                        format!("error encoding packet: {e}").into(),
265                    ))
266                });
267            }
268        };
269        let packet = packet.into();
270        Box::pin(async move { self.inner.send(packet).await })
271    }
272
273    /// Receive a SCMP message with the sender and path.
274    #[allow(clippy::type_complexity)]
275    pub fn recv_from_with_path<'a>(
276        &'a self,
277        path_buffer: &'a mut [u8],
278    ) -> BoxFuture<'a, Result<(ScmpMessage, ScionAddr, Path<&'a mut [u8]>), ScionSocketReceiveError>>
279    {
280        Box::pin(async move {
281            loop {
282                let packet = self.inner.recv().await?;
283                let packet: ScionPacketScmp = match packet.try_into() {
284                    Ok(packet) => packet,
285                    Err(e) => {
286                        tracing::debug!(error = %e, "Received invalid SCMP packet, dropping");
287                        continue;
288                    }
289                };
290                let src_addr = match packet.headers.address.source() {
291                    Some(source) => source,
292                    None => {
293                        tracing::debug!("Received packet without source address header, dropping");
294                        continue;
295                    }
296                };
297
298                if path_buffer.len() < packet.headers.path.raw().len() {
299                    return Err(ScionSocketReceiveError::PathBufTooSmall);
300                }
301                let dataplane_path = packet
302                    .headers
303                    .path
304                    .copy_to_slice(&mut path_buffer[..packet.headers.path.raw().len()]);
305                let path = Path::new(dataplane_path, packet.headers.address.ia, None);
306
307                return Ok((packet.message, src_addr, path));
308            }
309        })
310    }
311
312    /// Receive a SCMP message with the sender.
313    pub fn recv_from<'a>(
314        &'a self,
315    ) -> BoxFuture<'a, Result<(ScmpMessage, ScionAddr), ScionSocketReceiveError>> {
316        Box::pin(async move {
317            loop {
318                let packet = self.inner.recv().await?;
319                let packet: ScionPacketScmp = match packet.try_into() {
320                    Ok(packet) => packet,
321                    Err(e) => {
322                        tracing::debug!(error = %e, "Received invalid SCMP packet, skipping");
323                        continue;
324                    }
325                };
326                let src_addr = match packet.headers.address.source() {
327                    Some(source) => source,
328                    None => {
329                        tracing::debug!("Received packet without source address header, skipping");
330                        continue;
331                    }
332                };
333                return Ok((packet.message, src_addr));
334            }
335        })
336    }
337
    /// Return the local socket address, as reported by the underlay socket.
    pub fn local_addr(&self) -> SocketAddr {
        self.inner.local_addr()
    }
342}
343
344/// A raw SCION socket.
345pub struct RawScionSocket {
346    inner: Box<dyn UnderlaySocket>,
347}
348
impl RawScionSocket {
    /// Creates a raw socket on top of the given underlay socket.
    pub(crate) fn new(socket: Box<dyn UnderlaySocket + Sync + Send>) -> Self {
        Self { inner: socket }
    }
}
354
impl RawScionSocket {
    /// Send a raw SCION packet.
    ///
    /// The packet is handed to the underlay socket unchanged.
    pub fn send<'a>(
        &'a self,
        packet: ScionPacketRaw,
    ) -> BoxFuture<'a, Result<(), ScionSocketSendError>> {
        self.inner.send(packet)
    }

    /// Receive a raw SCION packet.
    ///
    /// Returns the next packet from the underlay socket without any filtering.
    pub fn recv<'a>(&'a self) -> BoxFuture<'a, Result<ScionPacketRaw, ScionSocketReceiveError>> {
        self.inner.recv()
    }

    /// Return the local socket address, as reported by the underlay socket.
    pub fn local_addr(&self) -> SocketAddr {
        self.inner.local_addr()
    }
}
374
/// A trait for receiving socket send errors.
///
/// Receivers are notified by [`UdpScionSocket::send_to_via`] when sending a
/// datagram fails.
pub trait SendErrorReceiver: Send + Sync {
    /// Reports an error when sending a packet.
    /// This function must return immediately and not block.
    fn report_send_error(&self, error: &ScionSocketSendError);
}
381
/// A path aware UDP socket generic over the path manager.
///
/// Wraps a [`PathUnawareUdpScionSocket`] and uses the path manager to look up
/// paths for outgoing datagrams and to register the reversed paths of
/// received datagrams.
pub struct UdpScionSocket<P: PathManager = MultiPathManager> {
    /// The path unaware socket used for all packet I/O.
    socket: PathUnawareUdpScionSocket,
    /// The path manager used to look up and register paths.
    pather: Arc<P>,
    /// Timeout passed to the path lookup performed by `connect`.
    connect_timeout: Duration,
    /// The connected remote address; `None` until `connect` has been called.
    remote_addr: Option<SocketAddr>,
    /// Receivers notified about send errors.
    send_error_receivers: Subscribers<dyn SendErrorReceiver>,
}
390
391impl<P: PathManager> std::fmt::Debug for UdpScionSocket<P> {
392    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
393        f.debug_struct("UdpScionSocket")
394            .field("local_addr", &self.socket.local_addr())
395            .field("remote_addr", &self.remote_addr)
396            .finish()
397    }
398}
399
400impl<P: PathManager> UdpScionSocket<P> {
    /// Creates a new path aware UDP SCION socket.
    ///
    /// The socket starts out unconnected (`remote_addr` is `None`).
    /// `connect_timeout` is the timeout used for the path lookup in `connect`,
    /// and `send_error_receivers` are notified when sending fails.
    pub fn new(
        socket: PathUnawareUdpScionSocket,
        pather: Arc<P>,
        connect_timeout: Duration,
        send_error_receivers: Subscribers<dyn SendErrorReceiver>,
    ) -> Self {
        Self {
            socket,
            pather,
            connect_timeout,
            remote_addr: None,
            send_error_receivers,
        }
    }
416
417    /// Connects the socket to a remote address.
418    ///
419    /// Ensures a Path to the Destination exists, returns an error if not.
420    ///
421    /// Timeouts after configured `connect_timeout`
422    pub async fn connect(self, remote_addr: SocketAddr) -> Result<Self, ScionSocketConnectError> {
423        // Check that a path exists to destination
424        let _path = self
425            .pather
426            .path_timeout(
427                self.socket.local_addr().isd_asn(),
428                remote_addr.isd_asn(),
429                Utc::now(),
430                self.connect_timeout,
431            )
432            .await
433            .map_err(|e| e.to_string());
434
435        Ok(Self {
436            remote_addr: Some(remote_addr),
437            ..self
438        })
439    }
440
441    /// Send a datagram to the connected remote address.
442    pub async fn send(&self, payload: &[u8]) -> Result<(), ScionSocketSendError> {
443        if let Some(remote_addr) = self.remote_addr {
444            self.send_to(payload, remote_addr).await
445        } else {
446            Err(ScionSocketSendError::NotConnected)
447        }
448    }
449
450    /// Send a datagram to the specified destination.
451    pub async fn send_to(
452        &self,
453        payload: &[u8],
454        destination: SocketAddr,
455    ) -> Result<(), ScionSocketSendError> {
456        let path = &self
457            .pather
458            .path_wait(
459                self.socket.local_addr().isd_asn(),
460                destination.isd_asn(),
461                Utc::now(),
462            )
463            .await
464            .map_err(|e| {
465                match e {
466                    PathWaitError::FetchFailed(e) => {
467                        ScionSocketSendError::PathLookupError(e.into())
468                    }
469                    PathWaitError::NoPathFound => {
470                        ScionSocketSendError::NetworkUnreachable(
471                            NetworkError::DestinationUnreachable("No path found".into()),
472                        )
473                    }
474                }
475            })?;
476        self.socket
477            .send_to_via(payload, destination, &path.to_slice_path())
478            .await
479    }
480
481    /// Send a datagram to the specified destination via the specified path.
482    pub async fn send_to_via(
483        &self,
484        payload: &[u8],
485        destination: SocketAddr,
486        path: &Path<&[u8]>,
487    ) -> Result<(), ScionSocketSendError> {
488        self.socket
489            .send_to_via(payload, destination, path)
490            .await
491            .inspect_err(|e| {
492                self.send_error_receivers
493                    .for_each(|receiver| receiver.report_send_error(e));
494            })
495    }
496
497    /// Receive a datagram from any address, along with the sender address and path.
498    pub async fn recv_from_with_path<'a>(
499        &'a self,
500        buffer: &'a mut [u8],
501        path_buffer: &'a mut [u8],
502    ) -> Result<(usize, SocketAddr, Path<&'a mut [u8]>), ScionSocketReceiveError> {
503        let (len, sender_addr, path): (usize, SocketAddr, Path<&mut [u8]>) =
504            self.socket.recv_from_with_path(buffer, path_buffer).await?;
505
506        match path.to_reversed() {
507            Ok(reversed_path) => {
508                // Register the path for future use
509                self.pather.register_path(
510                    self.socket.local_addr().isd_asn(),
511                    sender_addr.isd_asn(),
512                    Utc::now(),
513                    reversed_path,
514                );
515            }
516            Err(e) => {
517                tracing::trace!(error = ?e, "Failed to reverse path for registration")
518            }
519        }
520
521        tracing::trace!(
522            src = %self.socket.local_addr(),
523            dst = %sender_addr,
524            "Registered reverse path",
525        );
526
527        Ok((len, sender_addr, path))
528    }
529
530    /// Receive a datagram from the connected remote address and write it into the provided buffer.
531    pub async fn recv_from(
532        &self,
533        buffer: &mut [u8],
534    ) -> Result<(usize, SocketAddr), ScionSocketReceiveError> {
535        // For this method, we need to get the path to register it, but we don't return it
536        let mut path_buffer = [0u8; 1024]; // Temporary buffer for path
537        let (len, sender_addr, _) = self.recv_from_with_path(buffer, &mut path_buffer).await?;
538        Ok((len, sender_addr))
539    }
540
541    /// Receive a datagram from the connected remote address.
542    ///
543    /// Datagrams from other addresses are silently discarded.
544    pub async fn recv(&self, buffer: &mut [u8]) -> Result<usize, ScionSocketReceiveError> {
545        if self.remote_addr.is_none() {
546            return Err(ScionSocketReceiveError::NotConnected);
547        }
548        loop {
549            let (len, sender_addr) = self.recv_from(buffer).await?;
550            match self.remote_addr {
551                Some(remote_addr) => {
552                    if sender_addr == remote_addr {
553                        return Ok(len);
554                    }
555                }
556                None => return Err(ScionSocketReceiveError::NotConnected),
557            }
558        }
559    }
560
    /// Returns the local socket address of the underlying path unaware socket.
    pub fn local_addr(&self) -> SocketAddr {
        self.socket.local_addr()
    }
565}