1use std::{sync::Arc, time::Duration};
17
18use bytes::Bytes;
19use chrono::Utc;
20use futures::future::BoxFuture;
21use scion_proto::{
22 address::{ScionAddr, SocketAddr},
23 datagram::UdpMessage,
24 packet::{ByEndpoint, ScionPacketRaw, ScionPacketScmp, ScionPacketUdp},
25 path::Path,
26 scmp::{SCMP_PROTOCOL_NUMBER, ScmpMessage},
27};
28
29use super::UnderlaySocket;
30use crate::{
31 path::manager::{MultiPathManager, traits::PathManager},
32 scionstack::{
33 ScionSocketConnectError, ScionSocketReceiveError, ScionSocketSendError,
34 scmp_handler::ScmpHandler,
35 },
36 types::Subscribers,
37};
38
39pub struct PathUnawareUdpScionSocket {
41 inner: Box<dyn UnderlaySocket + Sync + Send>,
42 scmp_handlers: Vec<Box<dyn ScmpHandler>>,
44}
45
46impl std::fmt::Debug for PathUnawareUdpScionSocket {
47 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
48 f.debug_struct("PathUnawareUdpScionSocket")
49 .field("local_addr", &self.inner.local_addr())
50 .finish()
51 }
52}
53
54impl PathUnawareUdpScionSocket {
55 pub(crate) fn new(
56 socket: Box<dyn UnderlaySocket + Sync + Send>,
57 scmp_handlers: Vec<Box<dyn ScmpHandler>>,
58 ) -> Self {
59 Self {
60 inner: socket,
61 scmp_handlers,
62 }
63 }
64
65 pub fn send_to_via<'a>(
67 &'a self,
68 payload: &[u8],
69 destination: SocketAddr,
70 path: &Path<&[u8]>,
71 ) -> BoxFuture<'a, Result<(), ScionSocketSendError>> {
72 let packet = match ScionPacketUdp::new(
73 ByEndpoint {
74 source: self.inner.local_addr(),
75 destination,
76 },
77 path.data_plane_path.to_bytes_path(),
78 Bytes::copy_from_slice(payload),
79 ) {
80 Ok(packet) => packet,
81 Err(e) => {
82 return Box::pin(async move {
83 Err(ScionSocketSendError::InvalidPacket(
84 format!("error encoding packet: {e}").into(),
85 ))
86 });
87 }
88 }
89 .into();
90 self.inner.send(packet)
91 }
92
93 #[allow(clippy::type_complexity)]
95 pub fn recv_from_with_path<'a>(
96 &'a self,
97 buffer: &'a mut [u8],
98 path_buffer: &'a mut [u8],
99 ) -> BoxFuture<'a, Result<(usize, SocketAddr, Path<&'a mut [u8]>), ScionSocketReceiveError>>
100 {
101 Box::pin(async move {
102 loop {
103 let packet = self.inner.recv().await?;
104
105 let packet = match packet.headers.common.next_header {
106 UdpMessage::PROTOCOL_NUMBER => packet,
107 SCMP_PROTOCOL_NUMBER => {
108 tracing::debug!("SCMP packet received, forwarding to SCMP handlers");
109 for handler in &self.scmp_handlers {
110 if let Some(reply) = handler.handle(packet.clone())
111 && let Err(e) = self.inner.try_send(reply)
112 {
113 tracing::warn!(error = %e, "failed to send SCMP reply");
114 }
115 }
116 continue;
117 }
118 _ => {
119 tracing::debug!(next_header = %packet.headers.common.next_header, "Packet with unknown next layer protocol, skipping");
120 continue;
121 }
122 };
123
124 let packet: ScionPacketUdp = match packet.try_into() {
125 Ok(packet) => packet,
126 Err(e) => {
127 tracing::debug!(error = %e, "Received invalid UDP packet, skipping");
128 continue;
129 }
130 };
131 let src_addr = match packet.headers.address.source() {
132 Some(source) => SocketAddr::new(source, packet.src_port()),
133 None => {
134 tracing::debug!("Received packet without source address header, skipping");
135 continue;
136 }
137 };
138 tracing::trace!(
139 src = %src_addr,
140 length = packet.datagram.payload.len(),
141 "received packet",
142 );
143
144 let max_read = std::cmp::min(buffer.len(), packet.datagram.payload.len());
145 buffer[..max_read].copy_from_slice(&packet.datagram.payload[..max_read]);
146
147 if path_buffer.len() < packet.headers.path.raw().len() {
148 return Err(ScionSocketReceiveError::PathBufTooSmall);
149 }
150
151 let dataplane_path = packet
152 .headers
153 .path
154 .copy_to_slice(&mut path_buffer[..packet.headers.path.raw().len()]);
155
156 let path = Path::new(dataplane_path, packet.headers.address.ia, None);
160
161 return Ok((packet.datagram.payload.len(), src_addr, path));
162 }
163 })
164 }
165
166 pub fn recv_from<'a>(
168 &'a self,
169 buffer: &'a mut [u8],
170 ) -> BoxFuture<'a, Result<(usize, SocketAddr), ScionSocketReceiveError>> {
171 Box::pin(async move {
172 loop {
173 let packet = self.inner.recv().await?;
174
175 let packet = match packet.headers.common.next_header {
176 UdpMessage::PROTOCOL_NUMBER => packet,
177 SCMP_PROTOCOL_NUMBER => {
178 tracing::debug!("SCMP packet received, forwarding to SCMP handlers");
179 for handler in &self.scmp_handlers {
180 if let Some(reply) = handler.handle(packet.clone())
181 && let Err(e) = self.inner.try_send(reply)
182 {
183 tracing::warn!(error = %e, "failed to send SCMP reply");
184 }
185 }
186 continue;
187 }
188 _ => {
189 tracing::debug!(next_header = %packet.headers.common.next_header, "Packet with unknown next layer protocol, skipping");
190 continue;
191 }
192 };
193
194 let packet: ScionPacketUdp = match packet.try_into() {
195 Ok(packet) => packet,
196 Err(e) => {
197 tracing::debug!(error = %e, "Received invalid UDP packet, dropping");
198 continue;
199 }
200 };
201 let src_addr = match packet.headers.address.source() {
202 Some(source) => SocketAddr::new(source, packet.src_port()),
203 None => {
204 tracing::debug!("Received packet without source address header, dropping");
205 continue;
206 }
207 };
208
209 tracing::trace!(
210 src = %src_addr,
211 length = packet.datagram.payload.len(),
212 buffer_size = buffer.len(),
213 "received packet",
214 );
215
216 let max_read = std::cmp::min(buffer.len(), packet.datagram.payload.len());
217 buffer[..max_read].copy_from_slice(&packet.datagram.payload[..max_read]);
218
219 return Ok((packet.datagram.payload.len(), src_addr));
220 }
221 })
222 }
223
224 fn local_addr(&self) -> SocketAddr {
226 self.inner.local_addr()
227 }
228}
229
230pub struct ScmpScionSocket {
232 inner: Box<dyn UnderlaySocket + Sync + Send>,
233}
234
235impl ScmpScionSocket {
236 pub(crate) fn new(socket: Box<dyn UnderlaySocket + Sync + Send>) -> Self {
237 Self { inner: socket }
238 }
239}
240
241impl ScmpScionSocket {
242 pub fn send_to_via<'a>(
244 &'a self,
245 message: ScmpMessage,
246 destination: ScionAddr,
247 path: &Path<&[u8]>,
248 ) -> BoxFuture<'a, Result<(), ScionSocketSendError>> {
249 let packet = match ScionPacketScmp::new(
250 ByEndpoint {
251 source: self.inner.local_addr().scion_address(),
252 destination,
253 },
254 path.data_plane_path.to_bytes_path(),
255 message,
256 ) {
257 Ok(packet) => packet,
258 Err(e) => {
259 return Box::pin(async move {
260 Err(ScionSocketSendError::InvalidPacket(
261 format!("error encoding packet: {e}").into(),
262 ))
263 });
264 }
265 };
266 let packet = packet.into();
267 Box::pin(async move { self.inner.send(packet).await })
268 }
269
270 #[allow(clippy::type_complexity)]
272 pub fn recv_from_with_path<'a>(
273 &'a self,
274 path_buffer: &'a mut [u8],
275 ) -> BoxFuture<'a, Result<(ScmpMessage, ScionAddr, Path<&'a mut [u8]>), ScionSocketReceiveError>>
276 {
277 Box::pin(async move {
278 loop {
279 let packet = self.inner.recv().await?;
280 let packet: ScionPacketScmp = match packet.try_into() {
281 Ok(packet) => packet,
282 Err(e) => {
283 tracing::debug!(error = %e, "Received invalid SCMP packet, dropping");
284 continue;
285 }
286 };
287 let src_addr = match packet.headers.address.source() {
288 Some(source) => source,
289 None => {
290 tracing::debug!("Received packet without source address header, dropping");
291 continue;
292 }
293 };
294
295 if path_buffer.len() < packet.headers.path.raw().len() {
296 return Err(ScionSocketReceiveError::PathBufTooSmall);
297 }
298 let dataplane_path = packet
299 .headers
300 .path
301 .copy_to_slice(&mut path_buffer[..packet.headers.path.raw().len()]);
302 let path = Path::new(dataplane_path, packet.headers.address.ia, None);
303
304 return Ok((packet.message, src_addr, path));
305 }
306 })
307 }
308
309 pub fn recv_from<'a>(
311 &'a self,
312 ) -> BoxFuture<'a, Result<(ScmpMessage, ScionAddr), ScionSocketReceiveError>> {
313 Box::pin(async move {
314 loop {
315 let packet = self.inner.recv().await?;
316 let packet: ScionPacketScmp = match packet.try_into() {
317 Ok(packet) => packet,
318 Err(e) => {
319 tracing::debug!(error = %e, "Received invalid SCMP packet, skipping");
320 continue;
321 }
322 };
323 let src_addr = match packet.headers.address.source() {
324 Some(source) => source,
325 None => {
326 tracing::debug!("Received packet without source address header, skipping");
327 continue;
328 }
329 };
330 return Ok((packet.message, src_addr));
331 }
332 })
333 }
334
335 pub fn local_addr(&self) -> SocketAddr {
337 self.inner.local_addr()
338 }
339}
340
341pub struct RawScionSocket {
343 inner: Box<dyn UnderlaySocket>,
344}
345
346impl RawScionSocket {
347 pub(crate) fn new(socket: Box<dyn UnderlaySocket + Sync + Send>) -> Self {
348 Self { inner: socket }
349 }
350}
351
352impl RawScionSocket {
353 pub fn send<'a>(
355 &'a self,
356 packet: ScionPacketRaw,
357 ) -> BoxFuture<'a, Result<(), ScionSocketSendError>> {
358 self.inner.send(packet)
359 }
360
361 pub fn recv<'a>(&'a self) -> BoxFuture<'a, Result<ScionPacketRaw, ScionSocketReceiveError>> {
363 self.inner.recv()
364 }
365
366 pub fn local_addr(&self) -> SocketAddr {
368 self.inner.local_addr()
369 }
370}
371
372pub trait SendErrorReceiver: Send + Sync {
374 fn report_send_error(&self, error: &ScionSocketSendError);
377}
378
379pub struct UdpScionSocket<P: PathManager = MultiPathManager> {
381 socket: PathUnawareUdpScionSocket,
382 pather: Arc<P>,
383 connect_timeout: Duration,
384 remote_addr: Option<SocketAddr>,
385 send_error_receivers: Subscribers<dyn SendErrorReceiver>,
386}
387
388impl<P: PathManager> std::fmt::Debug for UdpScionSocket<P> {
389 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
390 f.debug_struct("UdpScionSocket")
391 .field("local_addr", &self.socket.local_addr())
392 .field("remote_addr", &self.remote_addr)
393 .finish()
394 }
395}
396
397impl<P: PathManager> UdpScionSocket<P> {
398 pub fn new(
400 socket: PathUnawareUdpScionSocket,
401 pather: Arc<P>,
402 connect_timeout: Duration,
403 send_error_receivers: Subscribers<dyn SendErrorReceiver>,
404 ) -> Self {
405 Self {
406 socket,
407 pather,
408 connect_timeout,
409 remote_addr: None,
410 send_error_receivers,
411 }
412 }
413
414 pub async fn connect(self, remote_addr: SocketAddr) -> Result<Self, ScionSocketConnectError> {
420 let _path = self
422 .pather
423 .path_timeout(
424 self.socket.local_addr().isd_asn(),
425 remote_addr.isd_asn(),
426 Utc::now(),
427 self.connect_timeout,
428 )
429 .await?;
430
431 Ok(Self {
432 remote_addr: Some(remote_addr),
433 ..self
434 })
435 }
436
437 pub async fn send(&self, payload: &[u8]) -> Result<(), ScionSocketSendError> {
439 if let Some(remote_addr) = self.remote_addr {
440 self.send_to(payload, remote_addr).await
441 } else {
442 Err(ScionSocketSendError::NotConnected)
443 }
444 }
445
446 pub async fn send_to(
448 &self,
449 payload: &[u8],
450 destination: SocketAddr,
451 ) -> Result<(), ScionSocketSendError> {
452 let path = &self
453 .pather
454 .path_wait(
455 self.socket.local_addr().isd_asn(),
456 destination.isd_asn(),
457 Utc::now(),
458 )
459 .await?;
460 self.socket
461 .send_to_via(payload, destination, &path.to_slice_path())
462 .await
463 }
464
465 pub async fn send_to_via(
467 &self,
468 payload: &[u8],
469 destination: SocketAddr,
470 path: &Path<&[u8]>,
471 ) -> Result<(), ScionSocketSendError> {
472 self.socket
473 .send_to_via(payload, destination, path)
474 .await
475 .inspect_err(|e| {
476 self.send_error_receivers
477 .for_each(|receiver| receiver.report_send_error(e));
478 })
479 }
480
481 pub async fn recv_from_with_path<'a>(
483 &'a self,
484 buffer: &'a mut [u8],
485 path_buffer: &'a mut [u8],
486 ) -> Result<(usize, SocketAddr, Path<&'a mut [u8]>), ScionSocketReceiveError> {
487 let (len, sender_addr, path): (usize, SocketAddr, Path<&mut [u8]>) =
488 self.socket.recv_from_with_path(buffer, path_buffer).await?;
489
490 match path.to_reversed() {
491 Ok(reversed_path) => {
492 self.pather.register_path(
494 self.socket.local_addr().isd_asn(),
495 sender_addr.isd_asn(),
496 Utc::now(),
497 reversed_path,
498 );
499 }
500 Err(e) => {
501 tracing::trace!(error = ?e, "Failed to reverse path for registration")
502 }
503 }
504
505 tracing::trace!(
506 src = %self.socket.local_addr(),
507 dst = %sender_addr,
508 "Registered reverse path",
509 );
510
511 Ok((len, sender_addr, path))
512 }
513
514 pub async fn recv_from(
516 &self,
517 buffer: &mut [u8],
518 ) -> Result<(usize, SocketAddr), ScionSocketReceiveError> {
519 let mut path_buffer = [0u8; 1024]; let (len, sender_addr, _) = self.recv_from_with_path(buffer, &mut path_buffer).await?;
522 Ok((len, sender_addr))
523 }
524
525 pub async fn recv(&self, buffer: &mut [u8]) -> Result<usize, ScionSocketReceiveError> {
529 if self.remote_addr.is_none() {
530 return Err(ScionSocketReceiveError::NotConnected);
531 }
532 loop {
533 let (len, sender_addr) = self.recv_from(buffer).await?;
534 match self.remote_addr {
535 Some(remote_addr) => {
536 if sender_addr == remote_addr {
537 return Ok(len);
538 }
539 }
540 None => return Err(ScionSocketReceiveError::NotConnected),
541 }
542 }
543 }
544
545 pub fn local_addr(&self) -> SocketAddr {
547 self.socket.local_addr()
548 }
549}