1use std::{sync::Arc, time::Duration};
17
18use bytes::Bytes;
19use chrono::Utc;
20use futures::future::BoxFuture;
21use scion_proto::{
22 address::{ScionAddr, SocketAddr},
23 datagram::UdpMessage,
24 packet::{ByEndpoint, ScionPacketRaw, ScionPacketScmp, ScionPacketUdp},
25 path::Path,
26 scmp::{SCMP_PROTOCOL_NUMBER, ScmpMessage},
27};
28
29use super::{NetworkError, UnderlaySocket};
30use crate::{
31 path::manager::{
32 MultiPathManager,
33 traits::{PathManager, PathWaitError},
34 },
35 scionstack::{
36 ScionSocketConnectError, ScionSocketReceiveError, ScionSocketSendError,
37 scmp_handler::ScmpHandler,
38 },
39 types::Subscribers,
40};
41
/// A SCION UDP socket that leaves path selection to the caller.
///
/// Sending requires an explicit path (see `send_to_via`). SCMP packets that
/// arrive while receiving are passed to the stored SCMP handlers instead of
/// being returned to the caller.
pub struct PathUnawareUdpScionSocket {
    /// Underlay socket used to send and receive raw SCION packets.
    inner: Box<dyn UnderlaySocket + Sync + Send>,
    /// Handlers that are offered every received SCMP packet.
    scmp_handlers: Vec<Box<dyn ScmpHandler>>,
}
48
49impl std::fmt::Debug for PathUnawareUdpScionSocket {
50 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
51 f.debug_struct("PathUnawareUdpScionSocket")
52 .field("local_addr", &self.inner.local_addr())
53 .finish()
54 }
55}
56
57impl PathUnawareUdpScionSocket {
    /// Creates a socket that sends and receives through `socket`.
    ///
    /// `scmp_handlers` are stored and offered every SCMP packet that arrives
    /// while receiving.
    pub(crate) fn new(
        socket: Box<dyn UnderlaySocket + Sync + Send>,
        scmp_handlers: Vec<Box<dyn ScmpHandler>>,
    ) -> Self {
        Self {
            inner: socket,
            scmp_handlers,
        }
    }
67
68 pub fn send_to_via<'a>(
70 &'a self,
71 payload: &[u8],
72 destination: SocketAddr,
73 path: &Path<&[u8]>,
74 ) -> BoxFuture<'a, Result<(), ScionSocketSendError>> {
75 let packet = match ScionPacketUdp::new(
76 ByEndpoint {
77 source: self.inner.local_addr(),
78 destination,
79 },
80 path.data_plane_path.to_bytes_path(),
81 Bytes::copy_from_slice(payload),
82 ) {
83 Ok(packet) => packet,
84 Err(e) => {
85 return Box::pin(async move {
86 Err(ScionSocketSendError::InvalidPacket(
87 format!("error encoding packet: {e}").into(),
88 ))
89 });
90 }
91 }
92 .into();
93 self.inner.send(packet)
94 }
95
96 #[allow(clippy::type_complexity)]
98 pub fn recv_from_with_path<'a>(
99 &'a self,
100 buffer: &'a mut [u8],
101 path_buffer: &'a mut [u8],
102 ) -> BoxFuture<'a, Result<(usize, SocketAddr, Path<&'a mut [u8]>), ScionSocketReceiveError>>
103 {
104 Box::pin(async move {
105 loop {
106 let packet = self.inner.recv().await?;
107
108 let packet = match packet.headers.common.next_header {
109 UdpMessage::PROTOCOL_NUMBER => packet,
110 SCMP_PROTOCOL_NUMBER => {
111 tracing::debug!("SCMP packet received, forwarding to SCMP handlers");
112 for handler in &self.scmp_handlers {
113 if let Some(reply) = handler.handle(packet.clone())
114 && let Err(e) = self.inner.try_send(reply)
115 {
116 tracing::warn!(error = %e, "failed to send SCMP reply");
117 }
118 }
119 continue;
120 }
121 _ => {
122 tracing::debug!(next_header = %packet.headers.common.next_header, "Packet with unknown next layer protocol, skipping");
123 continue;
124 }
125 };
126
127 let packet: ScionPacketUdp = match packet.try_into() {
128 Ok(packet) => packet,
129 Err(e) => {
130 tracing::debug!(error = %e, "Received invalid UDP packet, skipping");
131 continue;
132 }
133 };
134 let src_addr = match packet.headers.address.source() {
135 Some(source) => SocketAddr::new(source, packet.src_port()),
136 None => {
137 tracing::debug!("Received packet without source address header, skipping");
138 continue;
139 }
140 };
141 tracing::trace!(
142 src = %src_addr,
143 length = packet.datagram.payload.len(),
144 "received packet",
145 );
146
147 let max_read = std::cmp::min(buffer.len(), packet.datagram.payload.len());
148 buffer[..max_read].copy_from_slice(&packet.datagram.payload[..max_read]);
149
150 if path_buffer.len() < packet.headers.path.raw().len() {
151 return Err(ScionSocketReceiveError::PathBufTooSmall);
152 }
153
154 let dataplane_path = packet
155 .headers
156 .path
157 .copy_to_slice(&mut path_buffer[..packet.headers.path.raw().len()]);
158
159 let path = Path::new(dataplane_path, packet.headers.address.ia, None);
163
164 return Ok((packet.datagram.payload.len(), src_addr, path));
165 }
166 })
167 }
168
169 pub fn recv_from<'a>(
171 &'a self,
172 buffer: &'a mut [u8],
173 ) -> BoxFuture<'a, Result<(usize, SocketAddr), ScionSocketReceiveError>> {
174 Box::pin(async move {
175 loop {
176 let packet = self.inner.recv().await?;
177
178 let packet = match packet.headers.common.next_header {
179 UdpMessage::PROTOCOL_NUMBER => packet,
180 SCMP_PROTOCOL_NUMBER => {
181 tracing::debug!("SCMP packet received, forwarding to SCMP handlers");
182 for handler in &self.scmp_handlers {
183 if let Some(reply) = handler.handle(packet.clone())
184 && let Err(e) = self.inner.try_send(reply)
185 {
186 tracing::warn!(error = %e, "failed to send SCMP reply");
187 }
188 }
189 continue;
190 }
191 _ => {
192 tracing::debug!(next_header = %packet.headers.common.next_header, "Packet with unknown next layer protocol, skipping");
193 continue;
194 }
195 };
196
197 let packet: ScionPacketUdp = match packet.try_into() {
198 Ok(packet) => packet,
199 Err(e) => {
200 tracing::debug!(error = %e, "Received invalid UDP packet, dropping");
201 continue;
202 }
203 };
204 let src_addr = match packet.headers.address.source() {
205 Some(source) => SocketAddr::new(source, packet.src_port()),
206 None => {
207 tracing::debug!("Received packet without source address header, dropping");
208 continue;
209 }
210 };
211
212 tracing::trace!(
213 src = %src_addr,
214 length = packet.datagram.payload.len(),
215 buffer_size = buffer.len(),
216 "received packet",
217 );
218
219 let max_read = std::cmp::min(buffer.len(), packet.datagram.payload.len());
220 buffer[..max_read].copy_from_slice(&packet.datagram.payload[..max_read]);
221
222 return Ok((packet.datagram.payload.len(), src_addr));
223 }
224 })
225 }
226
    /// Returns the local address reported by the underlay socket.
    // NOTE(review): this accessor is private, while `ScmpScionSocket`,
    // `RawScionSocket` and `UdpScionSocket` expose `pub fn local_addr` —
    // confirm whether the difference is intentional.
    fn local_addr(&self) -> SocketAddr {
        self.inner.local_addr()
    }
231}
232
/// A socket that sends and receives SCMP messages over an underlay socket.
pub struct ScmpScionSocket {
    /// Underlay socket used to send and receive raw SCION packets.
    inner: Box<dyn UnderlaySocket + Sync + Send>,
}
237
impl ScmpScionSocket {
    /// Creates an SCMP socket that sends and receives through `socket`.
    pub(crate) fn new(socket: Box<dyn UnderlaySocket + Sync + Send>) -> Self {
        Self { inner: socket }
    }
}
243
244impl ScmpScionSocket {
245 pub fn send_to_via<'a>(
247 &'a self,
248 message: ScmpMessage,
249 destination: ScionAddr,
250 path: &Path<&[u8]>,
251 ) -> BoxFuture<'a, Result<(), ScionSocketSendError>> {
252 let packet = match ScionPacketScmp::new(
253 ByEndpoint {
254 source: self.inner.local_addr().scion_address(),
255 destination,
256 },
257 path.data_plane_path.to_bytes_path(),
258 message,
259 ) {
260 Ok(packet) => packet,
261 Err(e) => {
262 return Box::pin(async move {
263 Err(ScionSocketSendError::InvalidPacket(
264 format!("error encoding packet: {e}").into(),
265 ))
266 });
267 }
268 };
269 let packet = packet.into();
270 Box::pin(async move { self.inner.send(packet).await })
271 }
272
273 #[allow(clippy::type_complexity)]
275 pub fn recv_from_with_path<'a>(
276 &'a self,
277 path_buffer: &'a mut [u8],
278 ) -> BoxFuture<'a, Result<(ScmpMessage, ScionAddr, Path<&'a mut [u8]>), ScionSocketReceiveError>>
279 {
280 Box::pin(async move {
281 loop {
282 let packet = self.inner.recv().await?;
283 let packet: ScionPacketScmp = match packet.try_into() {
284 Ok(packet) => packet,
285 Err(e) => {
286 tracing::debug!(error = %e, "Received invalid SCMP packet, dropping");
287 continue;
288 }
289 };
290 let src_addr = match packet.headers.address.source() {
291 Some(source) => source,
292 None => {
293 tracing::debug!("Received packet without source address header, dropping");
294 continue;
295 }
296 };
297
298 if path_buffer.len() < packet.headers.path.raw().len() {
299 return Err(ScionSocketReceiveError::PathBufTooSmall);
300 }
301 let dataplane_path = packet
302 .headers
303 .path
304 .copy_to_slice(&mut path_buffer[..packet.headers.path.raw().len()]);
305 let path = Path::new(dataplane_path, packet.headers.address.ia, None);
306
307 return Ok((packet.message, src_addr, path));
308 }
309 })
310 }
311
312 pub fn recv_from<'a>(
314 &'a self,
315 ) -> BoxFuture<'a, Result<(ScmpMessage, ScionAddr), ScionSocketReceiveError>> {
316 Box::pin(async move {
317 loop {
318 let packet = self.inner.recv().await?;
319 let packet: ScionPacketScmp = match packet.try_into() {
320 Ok(packet) => packet,
321 Err(e) => {
322 tracing::debug!(error = %e, "Received invalid SCMP packet, skipping");
323 continue;
324 }
325 };
326 let src_addr = match packet.headers.address.source() {
327 Some(source) => source,
328 None => {
329 tracing::debug!("Received packet without source address header, skipping");
330 continue;
331 }
332 };
333 return Ok((packet.message, src_addr));
334 }
335 })
336 }
337
    /// Returns the local address reported by the underlay socket.
    pub fn local_addr(&self) -> SocketAddr {
        self.inner.local_addr()
    }
342}
343
344pub struct RawScionSocket {
346 inner: Box<dyn UnderlaySocket>,
347}
348
impl RawScionSocket {
    /// Creates a raw socket that sends and receives through `socket`.
    pub(crate) fn new(socket: Box<dyn UnderlaySocket + Sync + Send>) -> Self {
        Self { inner: socket }
    }
}
354
impl RawScionSocket {
    /// Sends `packet` through the underlay socket, unchanged.
    pub fn send<'a>(
        &'a self,
        packet: ScionPacketRaw,
    ) -> BoxFuture<'a, Result<(), ScionSocketSendError>> {
        self.inner.send(packet)
    }

    /// Receives the next raw packet from the underlay socket.
    pub fn recv<'a>(&'a self) -> BoxFuture<'a, Result<ScionPacketRaw, ScionSocketReceiveError>> {
        self.inner.recv()
    }

    /// Returns the local address reported by the underlay socket.
    pub fn local_addr(&self) -> SocketAddr {
        self.inner.local_addr()
    }
}
374
/// Receiver for errors that occur while sending on a socket.
///
/// `UdpScionSocket::send_to_via` passes every send error to each subscribed
/// receiver.
pub trait SendErrorReceiver: Send + Sync {
    /// Called with the error of a failed send.
    fn report_send_error(&self, error: &ScionSocketSendError);
}
381
/// A SCION UDP socket that obtains paths from a path manager.
///
/// Wraps a [`PathUnawareUdpScionSocket`] and looks up a path for each send
/// that does not name one explicitly.
pub struct UdpScionSocket<P: PathManager = MultiPathManager> {
    /// Socket that performs the actual sending and receiving.
    socket: PathUnawareUdpScionSocket,
    /// Path manager used to look up paths for sends and to register the
    /// reversed paths of received packets.
    pather: Arc<P>,
    /// Timeout passed to the path lookup performed by `connect`.
    connect_timeout: Duration,
    /// Remote address set by `connect`; `None` while unconnected.
    remote_addr: Option<SocketAddr>,
    /// Receivers that are notified about send errors.
    send_error_receivers: Subscribers<dyn SendErrorReceiver>,
}
390
391impl<P: PathManager> std::fmt::Debug for UdpScionSocket<P> {
392 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
393 f.debug_struct("UdpScionSocket")
394 .field("local_addr", &self.socket.local_addr())
395 .field("remote_addr", &self.remote_addr)
396 .finish()
397 }
398}
399
400impl<P: PathManager> UdpScionSocket<P> {
    /// Creates an unconnected socket from its parts.
    ///
    /// `remote_addr` starts as `None`; use `connect` to set it.
    pub fn new(
        socket: PathUnawareUdpScionSocket,
        pather: Arc<P>,
        connect_timeout: Duration,
        send_error_receivers: Subscribers<dyn SendErrorReceiver>,
    ) -> Self {
        Self {
            socket,
            pather,
            connect_timeout,
            remote_addr: None,
            send_error_receivers,
        }
    }
416
417 pub async fn connect(self, remote_addr: SocketAddr) -> Result<Self, ScionSocketConnectError> {
423 let _path = self
425 .pather
426 .path_timeout(
427 self.socket.local_addr().isd_asn(),
428 remote_addr.isd_asn(),
429 Utc::now(),
430 self.connect_timeout,
431 )
432 .await
433 .map_err(|e| e.to_string());
434
435 Ok(Self {
436 remote_addr: Some(remote_addr),
437 ..self
438 })
439 }
440
441 pub async fn send(&self, payload: &[u8]) -> Result<(), ScionSocketSendError> {
443 if let Some(remote_addr) = self.remote_addr {
444 self.send_to(payload, remote_addr).await
445 } else {
446 Err(ScionSocketSendError::NotConnected)
447 }
448 }
449
450 pub async fn send_to(
452 &self,
453 payload: &[u8],
454 destination: SocketAddr,
455 ) -> Result<(), ScionSocketSendError> {
456 let path = &self
457 .pather
458 .path_wait(
459 self.socket.local_addr().isd_asn(),
460 destination.isd_asn(),
461 Utc::now(),
462 )
463 .await
464 .map_err(|e| {
465 match e {
466 PathWaitError::FetchFailed(e) => {
467 ScionSocketSendError::PathLookupError(e.into())
468 }
469 PathWaitError::NoPathFound => {
470 ScionSocketSendError::NetworkUnreachable(
471 NetworkError::DestinationUnreachable("No path found".into()),
472 )
473 }
474 }
475 })?;
476 self.socket
477 .send_to_via(payload, destination, &path.to_slice_path())
478 .await
479 }
480
481 pub async fn send_to_via(
483 &self,
484 payload: &[u8],
485 destination: SocketAddr,
486 path: &Path<&[u8]>,
487 ) -> Result<(), ScionSocketSendError> {
488 self.socket
489 .send_to_via(payload, destination, path)
490 .await
491 .inspect_err(|e| {
492 self.send_error_receivers
493 .for_each(|receiver| receiver.report_send_error(e));
494 })
495 }
496
497 pub async fn recv_from_with_path<'a>(
499 &'a self,
500 buffer: &'a mut [u8],
501 path_buffer: &'a mut [u8],
502 ) -> Result<(usize, SocketAddr, Path<&'a mut [u8]>), ScionSocketReceiveError> {
503 let (len, sender_addr, path): (usize, SocketAddr, Path<&mut [u8]>) =
504 self.socket.recv_from_with_path(buffer, path_buffer).await?;
505
506 match path.to_reversed() {
507 Ok(reversed_path) => {
508 self.pather.register_path(
510 self.socket.local_addr().isd_asn(),
511 sender_addr.isd_asn(),
512 Utc::now(),
513 reversed_path,
514 );
515 }
516 Err(e) => {
517 tracing::trace!(error = ?e, "Failed to reverse path for registration")
518 }
519 }
520
521 tracing::trace!(
522 src = %self.socket.local_addr(),
523 dst = %sender_addr,
524 "Registered reverse path",
525 );
526
527 Ok((len, sender_addr, path))
528 }
529
530 pub async fn recv_from(
532 &self,
533 buffer: &mut [u8],
534 ) -> Result<(usize, SocketAddr), ScionSocketReceiveError> {
535 let mut path_buffer = [0u8; 1024]; let (len, sender_addr, _) = self.recv_from_with_path(buffer, &mut path_buffer).await?;
538 Ok((len, sender_addr))
539 }
540
541 pub async fn recv(&self, buffer: &mut [u8]) -> Result<usize, ScionSocketReceiveError> {
545 if self.remote_addr.is_none() {
546 return Err(ScionSocketReceiveError::NotConnected);
547 }
548 loop {
549 let (len, sender_addr) = self.recv_from(buffer).await?;
550 match self.remote_addr {
551 Some(remote_addr) => {
552 if sender_addr == remote_addr {
553 return Ok(len);
554 }
555 }
556 None => return Err(ScionSocketReceiveError::NotConnected),
557 }
558 }
559 }
560
    /// Returns the local address of the inner socket.
    pub fn local_addr(&self) -> SocketAddr {
        self.socket.local_addr()
    }
565}