use std::sync::{Arc, Mutex};
use bytes::Bytes;
use scion_proto::{address::EndhostAddr, packet::ScionPacketRaw, wire_encoding::WireEncodeVec};
use scion_sdk_token_validator::validator::Token;
use serde::Deserialize;
use snap_tun::server::{AddressAssignmentError, SendPacketError};
use tokio::sync::mpsc::{Receiver, Sender, error::TrySendError};
use tracing::{debug, info, span, trace};
use crate::{
dispatcher::Dispatcher,
tunnel_gateway::{metrics::TunnelGatewayDispatcherMetrics, state::SharedTunnelGatewayState},
};
/// Capacity, in packets, of the bounded channel that buffers packets handed to
/// the dispatcher until the dispatch loop picks them up.
const DISPATCHER_CHANNEL_SIZE: usize = 10000;
/// Queues SCION packets and forwards each one to the tunnel registered for the
/// packet's destination address.
///
/// Packets enter through the sending half of a bounded channel and are drained
/// by the dispatch loop, which owns the receiving half. Cloning yields a handle
/// onto the same channel: the `Sender` is cloned and the receiver slot sits
/// behind an `Arc`, so it is shared between all clones.
#[derive(Debug, Clone)]
pub struct TunnelGatewayDispatcher<T>
where
T: for<'de> Deserialize<'de> + Token,
{
/// Producer half of the dispatch channel.
sender: Sender<ScionPacketRaw>,
/// Consumer half of the dispatch channel. Stored as an `Option` so the
/// dispatch loop can take ownership of it; `None` once it has been taken.
receiver: Arc<Mutex<Option<Receiver<ScionPacketRaw>>>>,
/// Shared gateway state, queried for the tunnel serving a destination address.
state: SharedTunnelGatewayState<T>,
/// Counters and gauges updated while queueing and dispatching packets.
metrics: TunnelGatewayDispatcherMetrics,
}
impl<T> TunnelGatewayDispatcher<T>
where
T: for<'de> Deserialize<'de> + Token,
{
pub fn new(
state: SharedTunnelGatewayState<T>,
metrics: TunnelGatewayDispatcherMetrics,
) -> Self {
let (sender, receiver) =
tokio::sync::mpsc::channel::<ScionPacketRaw>(DISPATCHER_CHANNEL_SIZE);
Self {
sender,
receiver: Arc::new(Mutex::new(Some(receiver))),
state,
metrics,
}
}
pub async fn start_dispatching(&self) -> std::io::Result<()> {
let mut receiver = self.receiver.lock().unwrap().take().unwrap();
while let Some(packet) = receiver.recv().await {
self.metrics.dispatch_queue_size.dec();
let dest_addr = match packet.headers.address.destination() {
Some(addr) => addr,
None => {
self.metrics.invalid_packets_errors.inc();
debug!("Destination address couldn't be decoded.");
continue;
}
};
let dest_addr: EndhostAddr = match EndhostAddr::try_from(dest_addr) {
Ok(addr) => addr,
Err(err) => {
self.metrics.invalid_packets_errors.inc();
debug!(%err, "Destination address is not a valid endhost address");
continue;
}
};
match self.state.get_tunnel(dest_addr) {
Some(tun) => {
span!(tracing::Level::INFO, "connection", remote_underlay_address = %tun.remote_underlay_address(), dest_addr = %dest_addr).in_scope(|| {
let raw: Bytes = packet.encode_to_bytes_vec().concat().into();
trace!(dst=%dest_addr, pkt_len=%raw.len(), "dispatching packet");
if let Err(e) = tun.send(raw) {
match e {
SendPacketError::ConnectionClosed => {
self.metrics.connection_closed_errors.inc()
}
SendPacketError::NewAssignedAddress(_) => {
self.metrics.new_assigned_address_errors.inc()
}
SendPacketError::AddressAssignmentError(
AddressAssignmentError::NoAddressAssigned,
) => self.metrics.no_address_assigned_errors.inc(),
SendPacketError::SendDatagramError(_) => {
self.metrics.send_datagram_errors.inc()
}
}
}
});
}
_ => {
self.metrics.missing_tunnel_errors.inc();
debug!(dest_addr=%dest_addr, "no connection found");
}
}
}
info!("Tunnel gateway dispatcher stopped");
Ok(())
}
}
impl<T> Dispatcher for TunnelGatewayDispatcher<T>
where
    T: for<'de> Deserialize<'de> + Token + Clone,
{
    /// Attempts to enqueue `packet` for dispatching without waiting.
    ///
    /// On success the queue-size gauge is incremented. If the queue is full or
    /// closed, the packet is discarded and the corresponding error counter is
    /// incremented instead; the failure is not reported to the caller.
    fn try_dispatch(&self, packet: ScionPacketRaw) {
        let metrics = &self.metrics;
        match self.sender.try_send(packet) {
            Ok(()) => metrics.dispatch_queue_size.inc(),
            Err(TrySendError::Full(_)) => metrics.full_dispatch_queue_errors.inc(),
            Err(TrySendError::Closed(_)) => metrics.closed_dispatch_queue_errors.inc(),
        }
    }
}