use std::{
collections::HashMap,
io,
net::IpAddr,
sync::{Arc, RwLock},
};
use bytes::Bytes;
use scion_proto::{
address::{IsdAsn, ScionAddr, SocketAddr},
packet::{ByEndpoint, NextHeader, ScionPacketRaw, ScionPacketUdp},
path::DataPlanePath,
};
use scion_sdk_quic_scion::socket::{BoxedSocketError, GenericScionUdpSocket};
use tokio::sync::{Mutex, mpsc};
use crate::{
network::{local::receivers::Receiver, scion::routing::ScionNetworkTime},
state::SharedPocketScionState,
};
#[derive(Clone)]
pub struct NetSimStack(Arc<NetSimStackInner>);
struct NetSimStackInner {
state: SharedPocketScionState,
udp_receivers: RwLock<HashMap<u16, mpsc::Sender<ScionPacketUdp>>>,
raw_recevers: RwLock<Vec<mpsc::Sender<ScionPacketRaw>>>,
local_as: IsdAsn,
bind_addr: IpAddr,
rx_queue_size: usize,
}
impl NetSimStack {
pub fn bind(
state: SharedPocketScionState,
local_as: IsdAsn,
bind_addr: IpAddr,
queue_size: usize,
) -> anyhow::Result<Self> {
let this = Self(Arc::new(NetSimStackInner {
state: state.clone(),
udp_receivers: RwLock::new(HashMap::new()),
raw_recevers: RwLock::new(Vec::new()),
local_as,
bind_addr,
rx_queue_size: queue_size,
}));
state.add_sim_receiver(local_as, bind_addr.into(), this.0.clone())?;
Ok(this)
}
pub fn bind_udp(&self, mut port: u16) -> anyhow::Result<NetSimUdpSocket> {
let mut udp_receivers = self.0.udp_receivers.write().unwrap();
if port == 0 {
for check_port in 1024..65535 {
if !udp_receivers.contains_key(&check_port) {
port = check_port;
break;
}
}
if port == 0 {
anyhow::bail!("No available ports");
}
}
if udp_receivers.contains_key(&port) {
anyhow::bail!("Port {} already in use", port);
}
let (socket, receiver) = NetSimUdpSocket::new(self.clone(), self.0.rx_queue_size, port);
udp_receivers.insert(port, receiver);
Ok(socket)
}
pub fn bind_raw(&self) -> NetSimRawSocket {
let mut raw_receivers = self.0.raw_recevers.write().unwrap();
let (socket, receiver) = NetSimRawSocket::new(self.clone(), self.0.rx_queue_size);
raw_receivers.push(receiver);
socket
}
pub fn clean(&self) {
let mut udp_receivers = self.0.udp_receivers.write().unwrap();
udp_receivers.retain(|_, rx| !rx.is_closed());
let mut raw_receivers = self.0.raw_recevers.write().unwrap();
raw_receivers.retain(|rx| !rx.is_closed());
}
fn send(&self, packet: ScionPacketRaw, timestamp: ScionNetworkTime) {
self.0
.state
.dispatch_to_network_sim(self.0.local_as, 0, timestamp, packet);
}
}
impl Receiver for NetSimStackInner {
fn receive_packet(&self, packet: ScionPacketRaw) {
let dest_addr = packet.headers.address.destination();
if !dest_addr
.iter()
.flat_map(|addr| addr.local_address())
.any(|addr| addr == self.bind_addr)
{
tracing::warn!(
packet_destination = ?dest_addr,
local_address = ?self.bind_addr,
"Received packet with destination address that does not match socket's bind address, dropping packet"
);
return;
}
let mut forwarded_once = false;
{
let raw_recv = self.raw_recevers.read().unwrap();
for raw_rx in raw_recv.iter() {
match raw_rx.try_reserve() {
Ok(permit) => {
permit.send(packet.clone());
forwarded_once = true;
}
Err(e) => {
tracing::warn!(
error = ?e,
"Raw socket receiver is full, dropping packet for this receiver"
);
}
}
}
}
if packet.headers.common.next_header == NextHeader::UDP {
let pkt = match ScionPacketUdp::try_from(packet.clone()) {
Ok(pkt) => pkt,
Err(e) => {
tracing::warn!(
error = ?e,
"Failed to parse received packet as SCION UDP, not forwarding to UDP receivers"
);
return;
}
};
let udp_receivers = self.udp_receivers.read().unwrap();
let Some(udp) = udp_receivers.get(&pkt.dst_port()) else {
if !forwarded_once {
tracing::warn!(
port = pkt.dst_port(),
"Received UDP packet for port that has no receiver, and no raw receivers to forward to, dropping packet"
);
}
return;
};
match udp.try_reserve() {
Ok(permit) => permit.send(pkt),
Err(e) => {
tracing::warn!(
error = ?e,
port = pkt.dst_port(),
"UDP socket receiver is full, dropping packet for this receiver"
);
}
}
}
}
}
/// UDP socket bound to one port of a [`NetSimStack`].
pub struct NetSimUdpSocket {
    stack: NetSimStack,
    /// Receiving end of this port's delivery queue. It sits behind a tokio mutex so the
    /// receive methods can take `&self`.
    rx_queue: Mutex<mpsc::Receiver<ScionPacketUdp>>,
    port: u16,
}

impl NetSimUdpSocket {
    /// Builds a socket for `port`.
    ///
    /// Also returns the sender half that the owning stack uses to deliver packets to it.
    fn new(
        stack: NetSimStack,
        rx_queue_size: usize,
        port: u16,
    ) -> (Self, mpsc::Sender<ScionPacketUdp>) {
        let (tx, rx) = mpsc::channel(rx_queue_size);
        let socket = Self {
            stack,
            rx_queue: Mutex::new(rx),
            port,
        };
        (socket, tx)
    }

    /// Wraps this socket so that every send looks up its path via `path_provider`.
    pub fn into_path_aware<P: NetSimPathProvider>(
        self,
        path_provider: P,
    ) -> PathAwareNetSimUdpSocket<P> {
        PathAwareNetSimUdpSocket::new(self, path_provider)
    }

    /// Builds a UDP packet from this socket's address to `dst` over `path` and injects it
    /// into the simulated network with `timestamp`.
    ///
    /// # Errors
    ///
    /// `InvalidInput` if the packet cannot be constructed.
    pub fn try_send(
        &self,
        dst: scion_proto::address::SocketAddr,
        path: DataPlanePath<Bytes>,
        payload: Bytes,
        timestamp: ScionNetworkTime,
    ) -> io::Result<()> {
        let endpoints = ByEndpoint {
            source: self.socket_addr(),
            destination: dst,
        };
        let packet = ScionPacketUdp::new(endpoints, path, payload).map_err(|e| {
            io::Error::new(
                io::ErrorKind::InvalidInput,
                format!("Failed to construct SCION packet: {e}"),
            )
        })?;
        self.stack.send(packet.into(), timestamp);
        Ok(())
    }

    /// Takes the next queued packet without waiting.
    ///
    /// # Errors
    ///
    /// - `WouldBlock` if the queue is empty or another task currently holds it.
    /// - `ConnectionReset` if the delivering side of the queue is gone.
    pub fn try_recv(&self) -> io::Result<ScionPacketUdp> {
        let mut queue = self
            .rx_queue
            .try_lock()
            .map_err(|_| io::Error::new(io::ErrorKind::WouldBlock, "Failed to acquire lock"))?;
        queue.try_recv().map_err(|err| match err {
            mpsc::error::TryRecvError::Empty => {
                io::Error::new(io::ErrorKind::WouldBlock, "No packet available")
            }
            mpsc::error::TryRecvError::Disconnected => {
                io::Error::new(io::ErrorKind::ConnectionReset, "Socket receiver disconnected")
            }
        })
    }

    /// Waits for the next queued packet.
    ///
    /// # Errors
    ///
    /// `ConnectionReset` once the delivering side of the queue is gone.
    pub async fn recv(&self) -> io::Result<ScionPacketUdp> {
        let mut queue = self.rx_queue.lock().await;
        queue.recv().await.ok_or_else(|| {
            io::Error::new(io::ErrorKind::ConnectionReset, "Socket receiver disconnected")
        })
    }

    /// The SCION address (stack AS, stack host, this port) this socket sends from.
    pub fn socket_addr(&self) -> SocketAddr {
        let host = ScionAddr::new(self.stack.0.local_as, self.stack.0.bind_addr.into());
        SocketAddr::new(host, self.port)
    }
}
/// Raw socket on a [`NetSimStack`].
///
/// It receives a copy of every packet the stack accepts (queue space permitting) and sends
/// prebuilt packets unchanged.
pub struct NetSimRawSocket {
    stack: NetSimStack,
    /// Receiving end of this socket's delivery queue. It sits behind a tokio mutex so the
    /// receive methods can take `&self`.
    rx_queue: Mutex<mpsc::Receiver<ScionPacketRaw>>,
}

impl NetSimRawSocket {
    /// Builds a raw socket.
    ///
    /// Also returns the sender half that the owning stack uses to deliver packets to it.
    fn new(stack: NetSimStack, rx_queue_size: usize) -> (Self, mpsc::Sender<ScionPacketRaw>) {
        let (tx, rx) = mpsc::channel(rx_queue_size);
        let socket = Self {
            stack,
            rx_queue: Mutex::new(rx),
        };
        (socket, tx)
    }

    /// Injects `packet` into the simulated network exactly as given, passing `timestamp`
    /// through.
    ///
    /// The packet's headers are not checked against this socket's address. Always returns `Ok`.
    pub fn try_send(&self, packet: ScionPacketRaw, timestamp: ScionNetworkTime) -> io::Result<()> {
        self.stack.send(packet, timestamp);
        Ok(())
    }

    /// Takes the next queued packet without waiting.
    ///
    /// # Errors
    ///
    /// - `WouldBlock` if the queue is empty or another task currently holds it.
    /// - `ConnectionReset` if the delivering side of the queue is gone.
    pub fn try_recv(&self) -> io::Result<ScionPacketRaw> {
        let mut queue = self
            .rx_queue
            .try_lock()
            .map_err(|_| io::Error::new(io::ErrorKind::WouldBlock, "Failed to acquire lock"))?;
        queue.try_recv().map_err(|err| match err {
            mpsc::error::TryRecvError::Empty => {
                io::Error::new(io::ErrorKind::WouldBlock, "No packet available")
            }
            mpsc::error::TryRecvError::Disconnected => {
                io::Error::new(io::ErrorKind::ConnectionReset, "Socket receiver disconnected")
            }
        })
    }

    /// Waits for the next queued packet.
    ///
    /// # Errors
    ///
    /// `ConnectionReset` once the delivering side of the queue is gone.
    pub async fn recv(&self) -> io::Result<ScionPacketRaw> {
        let mut queue = self.rx_queue.lock().await;
        queue.recv().await.ok_or_else(|| {
            io::Error::new(io::ErrorKind::ConnectionReset, "Socket receiver disconnected")
        })
    }

    /// The SCION address (stack AS and host) of the stack this socket belongs to.
    pub fn scion_addr(&self) -> ScionAddr {
        let inner = &self.stack.0;
        ScionAddr::new(inner.local_as, inner.bind_addr.into())
    }
}
/// Supplies data-plane paths between ASes to a [`PathAwareNetSimUdpSocket`].
pub trait NetSimPathProvider: Send + Sync + 'static {
    /// Returns a path from `src_as` to `dst_as`, or `None` if none is available.
    fn get_path(&self, src_as: IsdAsn, dst_as: IsdAsn) -> Option<DataPlanePath>;
}

/// A [`NetSimUdpSocket`] that asks a [`NetSimPathProvider`] for a path on every send.
pub struct PathAwareNetSimUdpSocket<P: NetSimPathProvider> {
    socket: NetSimUdpSocket,
    /// Path source consulted by [`PathAwareNetSimUdpSocket::try_send`].
    pub path_provider: P,
}

impl<P: NetSimPathProvider> PathAwareNetSimUdpSocket<P> {
    /// Wraps `socket` so that sends use paths from `path_provider`.
    pub fn new(socket: NetSimUdpSocket, path_provider: P) -> Self {
        Self {
            socket,
            path_provider,
        }
    }

    /// Looks up a path from the local AS to `dst`'s AS and sends `payload` over it, stamped
    /// with the current `ScionNetworkTime`.
    ///
    /// # Errors
    ///
    /// - `NotFound` if the provider has no path.
    /// - Otherwise, any error from [`NetSimUdpSocket::try_send`].
    pub fn try_send(
        &self,
        dst: scion_proto::address::SocketAddr,
        payload: Bytes,
    ) -> io::Result<()> {
        let src_as = self.socket.stack.0.local_as;
        let dst_as = dst.isd_asn();
        let path = self.path_provider.get_path(src_as, dst_as).ok_or_else(|| {
            io::Error::new(
                io::ErrorKind::NotFound,
                format!("No path found from AS {src_as} to destination AS {dst_as}"),
            )
        })?;
        self.socket
            .try_send(dst, path, payload, ScionNetworkTime::now())
    }

    /// Takes the next queued packet without waiting; see [`NetSimUdpSocket::try_recv`].
    ///
    /// Takes `&self` (the inner socket needs no exclusive access), consistent with `recv`.
    pub fn try_recv(&self) -> io::Result<ScionPacketUdp> {
        self.socket.try_recv()
    }

    /// Waits for the next queued packet; see [`NetSimUdpSocket::recv`].
    pub async fn recv(&self) -> io::Result<ScionPacketUdp> {
        self.socket.recv().await
    }
}
#[async_trait::async_trait]
impl<P: NetSimPathProvider> GenericScionUdpSocket for PathAwareNetSimUdpSocket<P> {
    /// Copies `payload` into an owned buffer and sends it to `destination` over a
    /// provider-supplied path.
    async fn send_to(
        &self,
        payload: &[u8],
        destination: SocketAddr,
    ) -> Result<(), BoxedSocketError> {
        let owned = Bytes::copy_from_slice(payload);
        self.try_send(destination, owned)
            .map_err(|e| Box::new(e) as BoxedSocketError)
    }

    /// Waits for a packet with a known source address and copies its payload into `buf`.
    ///
    /// Packets without a source address are dropped with a warning. The payload is truncated
    /// to `buf.len()`; the returned length is the number of bytes copied.
    async fn recv_from(&self, buf: &mut [u8]) -> Result<(usize, SocketAddr), BoxedSocketError> {
        loop {
            let pkt = self
                .recv()
                .await
                .map_err(|e| Box::new(e) as BoxedSocketError)?;
            let Some(src_host) = pkt.headers.address.source() else {
                tracing::warn!("Received packet with unknown source address, dropping packet");
                continue;
            };
            let src_addr = SocketAddr::new(src_host, pkt.src_port());
            let payload = pkt.payload();
            let copied = payload.len().min(buf.len());
            buf[..copied].copy_from_slice(&payload[..copied]);
            return Ok((copied, src_addr));
        }
    }

    /// The address packets from this socket are sent from.
    fn local_addr(&self) -> SocketAddr {
        self.socket.socket_addr()
    }
}
impl SharedPocketScionState {
    /// Creates a [`NetSimStack`] for host `bind_addr` in `local_as`, registered with this
    /// state.
    ///
    /// Convenience wrapper around [`NetSimStack::bind`]; see it for the error conditions.
    pub fn bind_sim_network_stack(
        &self,
        local_as: IsdAsn,
        bind_addr: IpAddr,
        queue_size: usize,
    ) -> anyhow::Result<NetSimStack> {
        let shared_state = self.clone();
        NetSimStack::bind(shared_state, local_as, bind_addr, queue_size)
    }
}
#[cfg(test)]
mod tests {
    use std::{net::IpAddr, time::SystemTime};

    use bytes::Bytes;
    use scion_proto::{
        address::{IsdAsn, ScionAddr, SocketAddr},
        packet::{ByEndpoint, ScionPacketUdp},
        path::DataPlanePath,
    };
    use tokio::time::{Duration, timeout};

    use crate::{
        network::scion::{
            routing::ScionNetworkTime,
            topology::{ScionAs, ScionTopology},
        },
        state::SharedPocketScionState,
    };

    /// Receive-queue capacity used for every stack in these tests.
    const QUEUE_SIZE: usize = 8;
    /// Upper bound on how long a test waits for a packet before failing.
    const RECV_TIMEOUT: Duration = Duration::from_secs(2);

    /// Builds a simulation state whose topology holds only the core AS `isd_as`.
    fn setup_state(isd_as: IsdAsn) -> SharedPocketScionState {
        let mut state = SharedPocketScionState::new(SystemTime::now());
        let mut topology = ScionTopology::new();
        topology
            .add_as(ScionAs::new_core(isd_as))
            .expect("failed to add AS");
        state.set_topology(topology);
        state
    }

    /// Shorthand for the SCION socket address `(isd_as, ip):port`.
    fn sim_addr(isd_as: IsdAsn, ip: IpAddr, port: u16) -> SocketAddr {
        SocketAddr::new(ScionAddr::new(isd_as, ip.into()), port)
    }

    #[tokio::test]
    async fn should_deliver_udp_to_port_and_raw_receiver() {
        let local_as: IsdAsn = "1-ff00:0:110".parse().unwrap();
        let bind_ip: IpAddr = "10.0.0.1".parse().unwrap();
        let src_ip: IpAddr = "10.0.0.9".parse().unwrap();
        let payload = Bytes::from_static(b"hello");

        let state = setup_state(local_as);
        let stack = state
            .bind_sim_network_stack(local_as, bind_ip, QUEUE_SIZE)
            .expect("bind sim stack");
        let udp_socket = stack.bind_udp(40000).expect("bind udp socket");
        let raw_socket = stack.bind_raw();

        let endpoints = ByEndpoint {
            source: sim_addr(local_as, src_ip, 50000),
            destination: sim_addr(local_as, bind_ip, 40000),
        };
        let packet = ScionPacketUdp::new(endpoints, DataPlanePath::EmptyPath, payload.clone())
            .expect("build packet");
        state.dispatch_to_network_sim(local_as, 0, ScionNetworkTime::now(), packet.into());

        let via_udp = timeout(RECV_TIMEOUT, udp_socket.recv())
            .await
            .expect("udp recv timeout")
            .expect("udp recv packet");
        assert_eq!(via_udp.payload(), &payload);

        let via_raw = timeout(RECV_TIMEOUT, raw_socket.recv())
            .await
            .expect("raw recv timeout")
            .expect("raw recv packet");
        let via_raw_udp: ScionPacketUdp = via_raw.try_into().expect("raw packet as UDP");
        assert_eq!(via_raw_udp.payload(), &payload);
    }

    #[tokio::test]
    async fn should_send_udp_between_stacks() {
        let local_as: IsdAsn = "1-ff00:0:110".parse().unwrap();
        let sender_ip: IpAddr = "10.0.0.1".parse().unwrap();
        let receiver_ip: IpAddr = "10.0.0.2".parse().unwrap();
        let payload = Bytes::from_static(b"cross-stack");

        let state = setup_state(local_as);
        let sender_stack = state
            .bind_sim_network_stack(local_as, sender_ip, QUEUE_SIZE)
            .expect("bind sender stack");
        let receiver_stack = state
            .bind_sim_network_stack(local_as, receiver_ip, QUEUE_SIZE)
            .expect("bind receiver stack");
        let sender_socket = sender_stack.bind_udp(0).expect("bind sender udp");
        let receiver_socket = receiver_stack.bind_udp(41000).expect("bind receiver udp");

        sender_socket
            .try_send(
                sim_addr(local_as, receiver_ip, 41000),
                DataPlanePath::EmptyPath,
                payload.clone(),
                ScionNetworkTime::now(),
            )
            .expect("send packet");

        let received = timeout(RECV_TIMEOUT, receiver_socket.recv())
            .await
            .expect("recv timeout")
            .expect("recv packet");
        assert_eq!(received.payload(), &payload);
        assert_eq!(
            received.source().expect("source addr"),
            sender_socket.socket_addr()
        );
    }
}