use std::{
io,
net::{self},
sync::Arc,
task::{Poll, ready},
};
use anyhow::Context;
use bytes::BytesMut;
use futures::future::BoxFuture;
use scion_proto::{
address::{IsdAsn, SocketAddr},
packet::{
ByEndpoint, PacketClassification, ScionPacketRaw, ScionPacketUdp, classify_scion_packet,
},
path::{DataPlanePath, Path, PathInterface},
scmp::SCMP_PROTOCOL_NUMBER,
wire_encoding::{WireDecode as _, WireEncodeVec as _},
};
use tokio::{io::ReadBuf, net::UdpSocket};
use crate::{
scionstack::{
AsyncUdpUnderlaySocket, ScionSocketSendError, UnderlaySocket, scmp_handler::ScmpHandler,
udp_polling::UdpPollHelper,
},
underlays::{discovery::UnderlayDiscovery, source_ip_towards},
};
/// Receive buffer size in bytes: 65535 is the largest value of the 16-bit UDP length field, so
/// any single received datagram fits into a buffer of this size.
const UDP_DATAGRAM_BUFFER_SIZE: usize = 65535;
/// Source of local IP addresses for this end host.
///
/// Implementations in this file either return a fixed list or derive the address from the route
/// towards a target endpoint.
#[async_trait::async_trait]
pub trait LocalIpResolver: Send + Sync {
    /// Returns the currently known local IP addresses. The result may be empty.
    async fn local_ips(&self) -> Vec<net::IpAddr>;
}
/// A plain list of addresses acts as a static resolver: every call yields a copy of the list.
#[async_trait::async_trait]
impl LocalIpResolver for Vec<net::IpAddr> {
    async fn local_ips(&self) -> Vec<net::IpAddr> {
        // `IpAddr` is `Copy`, so copying element-wise yields the same vector as `clone`.
        self.iter().copied().collect()
    }
}
/// [`LocalIpResolver`] that reports the local IP selected by `source_ip_towards` for a fixed API
/// endpoint.
pub(crate) struct TargetAddrLocalIpResolver {
    /// Resolved socket address of the API endpoint the source IP is determined towards.
    api_socket_address: net::SocketAddr,
}
impl TargetAddrLocalIpResolver {
    /// Creates a resolver targeting the endpoint described by `api_address`.
    ///
    /// The URL is resolved with [`url::Url::socket_addrs`] and the first resolved address is
    /// kept. No fallback port is supplied, so the URL must carry an explicit port or use a scheme
    /// with a known default port. Resolution may perform a (blocking) DNS lookup.
    ///
    /// # Errors
    ///
    /// Returns an error if the URL cannot be resolved ("invalid api address") or resolves to no
    /// address at all ("failed to resolve api socket address").
    pub fn new(api_address: url::Url) -> anyhow::Result<Self> {
        // Take ownership of the first address instead of borrowing and copying it out; the
        // error for the empty case is only constructed when it is actually needed.
        let api_socket_address = api_address
            .socket_addrs(|| None)
            .context("invalid api address")?
            .into_iter()
            .next()
            .context("failed to resolve api socket address")?;
        Ok(Self { api_socket_address })
    }
}
/// Reports at most one address: whatever `source_ip_towards` yields for the API endpoint, or an
/// empty list when it yields nothing.
#[async_trait::async_trait]
impl LocalIpResolver for TargetAddrLocalIpResolver {
    async fn local_ips(&self) -> Vec<net::IpAddr> {
        let selected = source_ip_towards(self.api_socket_address).await;
        // `Option` iterates over zero or one element, giving `[]` or `[ip]`.
        selected.into_iter().collect()
    }
}
/// Underlay socket that exchanges raw SCION packets over a plain UDP socket.
///
/// Packets for the local ISD-AS are sent straight to the destination end host. All other packets
/// go to the UDP next hop that `underlay_discovery` resolves for the path's first-hop interface.
pub struct UdpUnderlaySocket {
    /// UDP socket used for both sending and receiving.
    pub(crate) socket: UdpSocket,
    /// SCION address this socket is bound to; received packets addressed elsewhere are skipped.
    pub(crate) bind_addr: SocketAddr,
    /// Resolves a first-hop interface to the UDP address of the next hop.
    pub(crate) underlay_discovery: Arc<dyn UnderlayDiscovery>,
}
impl UdpUnderlaySocket {
    /// Wraps an already-bound UDP socket.
    ///
    /// * `socket` - UDP socket used for all sends and receives.
    /// * `bind_addr` - SCION address reported as the local address and used to filter receives.
    /// * `underlay_discovery` - maps first-hop interfaces to UDP next-hop addresses.
    pub(crate) fn new(
        socket: UdpSocket,
        bind_addr: SocketAddr,
        underlay_discovery: Arc<dyn UnderlayDiscovery>,
    ) -> Self {
        Self {
            underlay_discovery,
            bind_addr,
            socket,
        }
    }

    /// Computes the UDP address of the local end host that should receive `packet` directly.
    ///
    /// The IP is the packet's SCION destination host address. The port is the UDP destination
    /// port or, for SCMP packets, the port reported by `classify_scion_packet`.
    ///
    /// # Errors
    ///
    /// `ScionSocketSendError::InvalidPacket` if the packet has no destination, targets a
    /// service address, cannot be classified, or carries no deducible port.
    fn resolve_local_dispatch_addr(
        &self,
        packet: &ScionPacketRaw,
    ) -> Result<net::SocketAddr, ScionSocketSendError> {
        let Some(destination) = packet.headers.address.destination() else {
            return Err(ScionSocketSendError::InvalidPacket(
                "Packet to local endhost has no destination address".into(),
            ));
        };
        let Some(host_ip) = destination.local_address() else {
            return Err(ScionSocketSendError::InvalidPacket(
                "Cannot forward packet to local service address".into(),
            ));
        };
        // Classification takes the packet by value, so the borrowed input is cloned.
        let port = match classify_scion_packet(packet.clone()) {
            Ok(PacketClassification::Udp(udp_packet)) => udp_packet.dst_port(),
            Ok(PacketClassification::ScmpWithDestination(port, _)) => port,
            Ok(
                PacketClassification::ScmpWithoutDestination(_) | PacketClassification::Other(_),
            ) => {
                return Err(ScionSocketSendError::InvalidPacket(
                    "Cannot deduce port for packet to local endhost".into(),
                ));
            }
            Err(e) => {
                return Err(ScionSocketSendError::InvalidPacket(
                    format!("Cannot classify packet to local endhost: {e:#}").into(),
                ));
            }
        };
        Ok(net::SocketAddr::new(host_ip, port))
    }

    /// Non-blocking local delivery: a single `try_send_to` towards the local end host.
    ///
    /// Send failures are mapped by `map_send_io_error` with interface id `0`, because local
    /// delivery does not use a first-hop interface.
    fn try_dispatch_local(&self, packet: ScionPacketRaw) -> Result<(), ScionSocketSendError> {
        let target = self.resolve_local_dispatch_addr(&packet)?;
        let wire_bytes = packet.encode_to_bytes_vec().concat();
        if let Err(e) = self.socket.try_send_to(&wire_bytes, target) {
            return Err(Self::map_send_io_error(
                e,
                packet.headers.address.ia.source,
                0,
                target,
            ));
        }
        Ok(())
    }

    /// Asynchronous local delivery: awaits `send_to` towards the local end host.
    ///
    /// Send failures are mapped by `map_send_io_error` with interface id `0`.
    async fn dispatch_local(
        &self,
        packet: ScionPacketRaw,
    ) -> Result<(), crate::scionstack::ScionSocketSendError> {
        let target = self.resolve_local_dispatch_addr(&packet)?;
        let wire_bytes = packet.encode_to_bytes_vec().concat();
        match self.socket.send_to(&wire_bytes, target).await {
            Ok(_) => Ok(()),
            Err(e) => Err(Self::map_send_io_error(
                e,
                packet.headers.address.ia.source,
                0,
                target,
            )),
        }
    }

    /// Translates a UDP send error into a `ScionSocketSendError`.
    ///
    /// * Host/network unreachable becomes `UnderlayNextHopUnreachable`, carrying `src`,
    ///   `interface_id` and `next_hop`.
    /// * Aborted/reset connections and broken pipes become `Closed`.
    /// * Anything else is wrapped as `IoError`.
    fn map_send_io_error(
        e: io::Error,
        src: IsdAsn,
        interface_id: u16,
        next_hop: net::SocketAddr,
    ) -> ScionSocketSendError {
        match e.kind() {
            io::ErrorKind::HostUnreachable | io::ErrorKind::NetworkUnreachable => {
                ScionSocketSendError::UnderlayNextHopUnreachable {
                    isd_as: src,
                    interface_id,
                    address: Some(next_hop),
                    msg: e.to_string(),
                }
            }
            io::ErrorKind::ConnectionAborted
            | io::ErrorKind::ConnectionReset
            | io::ErrorKind::BrokenPipe => ScionSocketSendError::Closed,
            _ => ScionSocketSendError::IoError(e),
        }
    }
}
impl UnderlaySocket for UdpUnderlaySocket {
    /// Sends `packet` over the UDP underlay.
    ///
    /// If the destination ISD-AS equals the source ISD-AS, the packet goes directly to the local
    /// end host. Otherwise it goes to the UDP next hop resolved for the first interface of the
    /// packet's standard data-plane path.
    ///
    /// Path inspection, next-hop resolution and encoding happen before the future is returned.
    /// Only the socket write is deferred.
    fn send<'a>(
        &'a self,
        packet: ScionPacketRaw,
    ) -> BoxFuture<'a, Result<(), ScionSocketSendError>> {
        let source_ia = packet.headers.address.ia.source;
        if packet.headers.address.ia.destination == source_ia {
            return Box::pin(async move { self.dispatch_local(packet).await });
        }
        let interface_id = if let DataPlanePath::Standard(standard_path) = &packet.headers.path
            && let Some(interface_id) = standard_path.iter_interfaces().next()
        {
            interface_id.get()
        } else {
            return Box::pin(async move {
                Err(ScionSocketSendError::InvalidPacket(
                    "Path does not contain first hop.".into(),
                ))
            });
        };
        // Build the error only on failure; this runs for every outgoing packet.
        let Some(next_hop) = self
            .underlay_discovery
            .resolve_udp_underlay_next_hop(PathInterface {
                isd_asn: source_ia,
                id: interface_id,
            })
        else {
            return Box::pin(async move {
                Err(ScionSocketSendError::UnderlayNextHopUnreachable {
                    isd_as: source_ia,
                    interface_id,
                    address: None,
                    msg: "next hop not found".to_string(),
                })
            });
        };
        let packet_bytes = packet.encode_to_bytes_vec().concat();
        Box::pin(async move {
            self.socket
                .send_to(&packet_bytes, next_hop)
                .await
                .map_err(|e| Self::map_send_io_error(e, source_ia, interface_id, next_hop))?;
            Ok(())
        })
    }

    /// Non-blocking counterpart of `send`.
    ///
    /// Send errors are mapped by `map_send_io_error`. A full socket buffer (`WouldBlock`)
    /// therefore surfaces as `ScionSocketSendError::IoError`.
    fn try_send(&self, packet: ScionPacketRaw) -> Result<(), ScionSocketSendError> {
        let source_ia = packet.headers.address.ia.source;
        if packet.headers.address.ia.destination == source_ia {
            return self.try_dispatch_local(packet);
        }
        let interface_id = if let DataPlanePath::Standard(standard_path) = &packet.headers.path
            && let Some(interface_id) = standard_path.iter_interfaces().next()
        {
            interface_id.get()
        } else {
            return Err(ScionSocketSendError::InvalidPacket(
                "Path does not contain first hop.".into(),
            ));
        };
        let next_hop = self
            .underlay_discovery
            .resolve_udp_underlay_next_hop(PathInterface {
                isd_asn: source_ia,
                id: interface_id,
            })
            .ok_or_else(|| ScionSocketSendError::UnderlayNextHopUnreachable {
                isd_as: source_ia,
                interface_id,
                address: None,
                msg: "next hop not found".to_string(),
            })?;
        let packet_bytes = packet.encode_to_bytes_vec().concat();
        self.socket
            .try_send_to(&packet_bytes, next_hop)
            .map_err(|e| Self::map_send_io_error(e, source_ia, interface_id, next_hop))?;
        Ok(())
    }

    /// Receives the next SCION packet for this socket.
    ///
    /// Datagrams that fail to decode are logged and skipped. Packets whose destination address
    /// differs from `bind_addr` are also skipped. Packets without a destination address are
    /// returned as-is.
    fn recv<'a>(
        &'a self,
    ) -> BoxFuture<'a, Result<ScionPacketRaw, crate::scionstack::ScionSocketReceiveError>> {
        Box::pin(async move {
            // The buffer lives across `.await`. Keeping it on the heap stops the boxed future
            // from embedding (and moving around) a 64 KiB array.
            let mut buf = vec![0u8; UDP_DATAGRAM_BUFFER_SIZE];
            loop {
                let (n, _) = self.socket.recv_from(&mut buf).await?;
                let packet = match ScionPacketRaw::decode(&mut BytesMut::from(&buf[..n])) {
                    Ok(packet) => packet,
                    Err(e) => {
                        tracing::error!(error = %e, "Failed to decode SCION packet");
                        continue;
                    }
                };
                let dst = packet.headers.address.destination();
                if let Some(dst) = dst
                    && dst != self.bind_addr.scion_address()
                {
                    tracing::debug!(destination = ?dst, assigned_addr = %self.bind_addr.scion_address(), "Packet destination does not match assigned address, skipping");
                    continue;
                }
                return Ok(packet);
            }
        })
    }

    /// Returns the SCION address this socket is bound to.
    fn local_addr(&self) -> scion_proto::address::SocketAddr {
        self.bind_addr
    }

    /// Always `None` for the plain UDP underlay.
    fn snap_data_plane(&self) -> Option<net::SocketAddr> {
        None
    }
}
/// UDP underlay socket implementing [`AsyncUdpUnderlaySocket`].
///
/// Received SCMP packets are offered to `scmp_handlers` (replies are sent back) instead of being
/// returned. Only UDP payloads addressed to `local_addr` are surfaced to the caller.
pub struct UdpAsyncUdpUnderlaySocket {
    /// SCION address of this socket; received packets must be addressed to it.
    local_addr: SocketAddr,
    /// Resolves a first-hop interface to the UDP address of the next hop.
    discovery: Arc<dyn UnderlayDiscovery>,
    /// UDP socket used for all I/O.
    inner: UdpSocket,
    /// Handlers every received SCMP packet is offered to.
    scmp_handlers: Vec<Box<dyn ScmpHandler>>,
}
impl UdpAsyncUdpUnderlaySocket {
    /// Creates the socket from its parts.
    ///
    /// * `local_addr` - SCION address of this socket; received packets must target it.
    /// * `discovery` - maps first-hop interfaces to UDP next-hop addresses.
    /// * `inner` - UDP socket used for all I/O.
    /// * `scmp_handlers` - handlers each received SCMP packet is offered to.
    pub(crate) fn new(
        local_addr: SocketAddr,
        discovery: Arc<dyn UnderlayDiscovery>,
        inner: UdpSocket,
        scmp_handlers: Vec<Box<dyn ScmpHandler>>,
    ) -> Self {
        Self {
            local_addr,
            discovery,
            inner,
            scmp_handlers,
        }
    }

    /// Sends `packet` directly to an end host in the local AS with a single `try_send_to`.
    ///
    /// The UDP destination is the packet's SCION destination host IP. The port is the UDP
    /// destination port or, for SCMP packets, the port reported by `classify_scion_packet`.
    ///
    /// # Errors
    ///
    /// `InvalidInput` if the packet has no destination, targets a service address, cannot be
    /// classified, or carries no deducible port. Any other error (including `WouldBlock`) comes
    /// from `try_send_to`.
    fn try_dispatch_local(&self, packet: ScionPacketRaw) -> io::Result<()> {
        // `ok_or_else`: `io::Error::new` boxes its payload, so only build errors on failure.
        let dst_ip = packet
            .headers
            .address
            .destination()
            .ok_or_else(|| {
                io::Error::new(
                    io::ErrorKind::InvalidInput,
                    "Packet to local endhost has no destination address",
                )
            })?
            .local_address()
            .ok_or_else(|| {
                io::Error::new(
                    io::ErrorKind::InvalidInput,
                    "Cannot forward packet with service address",
                )
            })?;
        // Classification takes the packet by value; the original is still needed for encoding.
        let classification = classify_scion_packet(packet.clone()).map_err(|e| {
            io::Error::new(
                io::ErrorKind::InvalidInput,
                format!("Cannot classify packet to local endhost: {e:#}"),
            )
        })?;
        let dst_port = match classification {
            PacketClassification::Udp(udp_packet) => udp_packet.dst_port(),
            PacketClassification::ScmpWithDestination(port, _) => port,
            PacketClassification::ScmpWithoutDestination(_) | PacketClassification::Other(_) => {
                return Err(io::Error::new(
                    io::ErrorKind::InvalidInput,
                    "Cannot deduce port for packet to local endhost",
                ));
            }
        };
        let packet_bytes = packet.encode_to_bytes_vec().concat();
        self.inner
            .try_send_to(&packet_bytes, net::SocketAddr::new(dst_ip, dst_port))?;
        Ok(())
    }
}
impl AsyncUdpUnderlaySocket for UdpAsyncUdpUnderlaySocket {
fn create_io_poller(
self: Arc<Self>,
) -> std::pin::Pin<Box<dyn crate::scionstack::udp_polling::UdpPoller>> {
Box::pin(UdpPollHelper::new(move || {
let self_clone = self.clone();
async move { self_clone.inner.writable().await }
}))
}
fn try_send(&self, packet: ScionPacketRaw) -> Result<(), std::io::Error> {
let source_ia = packet.headers.address.ia.source;
if packet.headers.address.ia.destination == source_ia {
return self.try_dispatch_local(packet);
}
let interface_id = if let DataPlanePath::Standard(standard_path) = &packet.headers.path
&& let Some(interface_id) = standard_path.iter_interfaces().next()
{
interface_id
} else {
return Err(std::io::Error::new(
std::io::ErrorKind::InvalidInput,
"Path does not contain first hop.".to_string(),
));
};
let next_hop = self
.discovery
.resolve_udp_underlay_next_hop(PathInterface {
isd_asn: source_ia,
id: interface_id.get(),
})
.ok_or(std::io::Error::new(
std::io::ErrorKind::InvalidInput,
"could not resolve next hop",
))?;
let packet_bytes = packet.encode_to_bytes_vec().concat();
match self.inner.try_send_to(&packet_bytes, next_hop) {
Ok(_) => Ok(()),
Err(e) if e.kind() == io::ErrorKind::WouldBlock => Err(e),
Err(e) => {
tracing::warn!(err = ?e, "Error sending packet");
Ok(())
}
}?;
Ok(())
}
fn poll_recv_from_with_path(
&self,
cx: &mut std::task::Context,
) -> Poll<std::io::Result<(SocketAddr, bytes::Bytes, scion_proto::path::Path)>> {
loop {
let mut raw_buf = [0u8; UDP_DATAGRAM_BUFFER_SIZE];
let mut buf = ReadBuf::new(&mut raw_buf);
let _ = ready!(self.inner.poll_recv_from(cx, &mut buf))?;
let packet = match ScionPacketRaw::decode(&mut BytesMut::from(buf.initialized())) {
Ok(packet) => packet,
Err(e) => {
tracing::trace!(error = %e, "Received non SCION packet, dropping");
continue;
}
};
if packet.headers.common.next_header == SCMP_PROTOCOL_NUMBER {
tracing::debug!("SCMP packet received, forwarding to SCMP handlers");
for handler in &self.scmp_handlers {
if let Some(reply) = handler.handle(packet.clone())
&& let Err(e) = self.try_send(reply)
{
tracing::warn!(error = %e, "failed to send SCMP reply");
}
}
continue;
};
let fallible = || {
let src = packet
.headers
.address
.source()
.context("reading source address")?;
let dst = packet
.headers
.address
.destination()
.context("reading destination address")?;
if dst != self.local_addr.scion_address() {
anyhow::bail!(
"Packet destination does not match assigned address, skipping (dst: {}, assigned: {})",
dst,
self.local_addr.scion_address()
);
}
let path = Path::new(
packet.headers.path.clone(),
ByEndpoint {
source: src.isd_asn(),
destination: dst.isd_asn(),
},
None,
);
let packet: ScionPacketUdp = packet.try_into().context("parsing UDP packet")?;
anyhow::Ok((
SocketAddr::new(src, packet.src_port()),
packet.datagram.payload,
path,
))
};
match fallible() {
Ok(result) => return Poll::Ready(Ok(result)),
Err(e) => {
tracing::warn!(error = %e, "Received invalid packet, skipping");
continue;
}
}
}
}
fn local_addr(&self) -> SocketAddr {
self.local_addr
}
fn snap_data_plane(&self) -> Option<net::SocketAddr> {
None
}
}