use std::{net::SocketAddr, sync::Arc};
use scion_proto::{
packet::{ScionPacketRaw, classify_scion_packet},
wire_encoding::WireDecode,
};
use snap_dataplane::dispatcher::Dispatcher;
use crate::network::local::receivers::Receiver;
/// A bound UDP socket paired with a packet dispatcher.
///
/// Datagrams received on `socket` are decoded as SCION packets and handed to
/// `dispatcher` (see `run`); packets passed to the shared wrapper's
/// `receive_packet` are sent out of the same `socket`.
pub struct RouterSocket<D> {
    /// UDP socket used both for receiving and for sending SCION packets.
    socket: tokio::net::UdpSocket,
    /// Shared dispatcher that successfully decoded packets are passed to.
    dispatcher: Arc<D>,
}
impl<D> RouterSocket<D> {
    /// Creates a router socket from an already-bound UDP `socket` and the
    /// `dispatcher` that received packets will be forwarded to.
    ///
    /// # Errors
    ///
    /// This implementation never returns `Err` and awaits nothing; the
    /// `async`/`Result` signature presumably leaves room for fallible setup
    /// (TODO confirm with callers).
    pub async fn new(socket: tokio::net::UdpSocket, dispatcher: Arc<D>) -> std::io::Result<Self> {
        let router_socket = Self { socket, dispatcher };
        Ok(router_socket)
    }

    /// Returns the local address the underlying UDP socket is bound to.
    ///
    /// # Panics
    ///
    /// Panics if the socket's local address cannot be queried.
    pub fn addr(&self) -> SocketAddr {
        let local_addr = self.socket.local_addr();
        local_addr.expect("socket should be bound")
    }
}
impl<D: Dispatcher> RouterSocket<D> {
pub async fn run(&self) -> std::io::Result<()> {
let mut buf = vec![0u8; 65536]; loop {
match self.socket.recv_from(&mut buf).await {
Ok((size, src)) => {
let packet = match ScionPacketRaw::decode(&mut buf[..size].as_ref()) {
Ok(packet) => packet,
Err(e) => {
tracing::error!(error=%e, src=?src, "Failed to decode SCION packet");
continue;
}
};
self.dispatcher.try_dispatch(packet);
}
Err(e) => {
tracing::error!(error=%e, "Failed to receive packet");
}
}
}
}
}
pub struct SharedRouterSocket<D: Dispatcher>(Arc<RouterSocket<D>>);
impl<D: Dispatcher> SharedRouterSocket<D> {
    /// Moves `router_socket` behind an [`Arc`] so it can be shared between the
    /// receive loop and packet senders.
    pub fn new(router_socket: RouterSocket<D>) -> Self {
        let shared = Arc::new(router_socket);
        Self(shared)
    }

    /// Runs the receive loop of the wrapped [`RouterSocket`]; see
    /// [`RouterSocket::run`].
    pub async fn run(&self) -> std::io::Result<()> {
        let inner: &RouterSocket<D> = &self.0;
        inner.run().await
    }
}
impl<D: Dispatcher> Clone for SharedRouterSocket<D> {
fn clone(&self) -> Self {
Self(self.0.clone())
}
}
impl<D: Dispatcher> Receiver for SharedRouterSocket<D> {
    /// Sends `packet` out of the router socket to the packet's local destination
    /// address.
    ///
    /// The packet is classified, its destination host address is extracted,
    /// and the re-encoded bytes are sent with a non-blocking `try_send_to`.
    /// The packet is dropped, with an error log, in any of these cases:
    ///
    /// - classification fails;
    /// - no destination can be extracted;
    /// - the destination is not a local (socket) address, i.e. an SVC address;
    /// - sending fails, including when the send would block.
    fn receive_packet(&self, packet: ScionPacketRaw) {
        let classified_packet = match classify_scion_packet(packet) {
            Ok(classification) => classification,
            Err(e) => {
                tracing::error!(error=%e, "Failed to classify SCION packet");
                return;
            }
        };
        let Some(dst_addr) = classified_packet.destination() else {
            tracing::error!("Could not extract destination address from SCION packet");
            return;
        };
        let Some(dst_addr) = dst_addr.local_address() else {
            tracing::error!("SVC address not supported");
            return;
        };
        // The source address is evaluated inside the macro so the local-address
        // lookup only runs when debug logging is actually enabled, instead of a
        // syscall (and a potential panic) on every forwarded packet.
        tracing::debug!(
            ?dst_addr,
            src_addr = ?self.0.addr(),
            "Router socket dispatching packet"
        );
        let raw = classified_packet.encode_to_vec();
        // `try_send_to` never waits: if the socket is not ready for writing it
        // returns `WouldBlock` and this packet is dropped (best-effort send).
        if let Err(e) = self.0.socket.try_send_to(&raw, dst_addr) {
            tracing::error!(error=%e, "Failed to send packet to {}", dst_addr);
        }
    }
}