use std::net::SocketAddr;
use ana_gotatun::packet::Packet;
use sciparse::{core::view::View, packet::view::ScionPacketView};
use tokio::sync::mpsc::{Receiver, Sender, channel, error::TrySendError};
use crate::{
dispatcher::Dispatcher,
tunnel_gateway::{gateway::PacketPool, metrics::TunnelGatewayDispatcherMetrics},
};
/// Capacity of the bounded channel carrying `(destination, packet)` pairs from
/// [`TunnelGatewayDispatcher`] to [`TunnelGatewayDispatcherReceiver`]. When the channel
/// is full, `try_dispatch` drops the packet and counts it as a full-queue error.
const OUTBOUND_QUEUE_SIZE: usize = 1024;
/// Argument passed to `PacketBufPool::new` when the shared packet pool is created.
// NOTE(review): presumably the number of buffers the pool starts with — confirm against
// the ana_gotatun `PacketBufPool::new` documentation.
const BUFFER_POOL_INIT_SIZE: usize = 1024;
/// Sending half of the tunnel gateway's dispatch path.
///
/// Implements [`Dispatcher`] by copying each accepted SCION packet into a buffer taken
/// from `pool` and pushing it, together with its destination socket address, onto
/// `outbound_queue` without blocking. Cloning yields another sender handle onto the same
/// outbound queue.
#[derive(Clone)]
pub struct TunnelGatewayDispatcher {
    /// Metrics updated for every dispatch outcome (invalid packet, enqueued, full, closed).
    metrics: TunnelGatewayDispatcherMetrics,
    /// Pool the dispatcher takes packet buffers from; a clone of it is handed to the
    /// paired [`TunnelGatewayDispatcherReceiver`] by `new`.
    pool: PacketPool,
    /// Bounded queue of `(destination, packet)` pairs read by the receiver half.
    outbound_queue: Sender<(SocketAddr, Packet)>,
}
impl TunnelGatewayDispatcher {
    /// Builds a dispatcher and its matching receiver.
    ///
    /// Both halves are wired to one bounded channel of capacity [`OUTBOUND_QUEUE_SIZE`].
    /// They also share clones of a single packet pool created with
    /// [`BUFFER_POOL_INIT_SIZE`]. The dispatcher (sender side) comes first in the
    /// returned tuple.
    pub fn new(metrics: TunnelGatewayDispatcherMetrics) -> (Self, TunnelGatewayDispatcherReceiver) {
        let (sender, receiver) = channel(OUTBOUND_QUEUE_SIZE);
        let shared_pool = ana_gotatun::packet::PacketBufPool::new(BUFFER_POOL_INIT_SIZE);

        let dispatcher = Self {
            metrics,
            outbound_queue: sender,
            pool: shared_pool.clone(),
        };
        let dispatcher_rx = TunnelGatewayDispatcherReceiver {
            pool: shared_pool,
            outbound_queue: receiver,
        };

        (dispatcher, dispatcher_rx)
    }
}
impl Dispatcher for TunnelGatewayDispatcher {
    /// Copies `packet` into a pooled buffer and enqueues it for the receiver without blocking.
    ///
    /// The packet is dropped, and the matching metric incremented, when:
    /// - it cannot be classified, or no destination socket address can be derived from it
    ///   (`invalid_packets_errors`);
    /// - it is longer than the pooled buffer (`invalid_packets_errors`);
    /// - the outbound queue is full (`full_dispatch_queue_errors`) or closed
    ///   (`closed_dispatch_queue_errors`).
    ///
    /// On a successful enqueue `dispatch_queue_size` is incremented.
    fn try_dispatch(&self, packet: &ScionPacketView) {
        let classification = match packet.classify() {
            Ok(c) => c,
            Err(e) => {
                self.metrics.invalid_packets_errors.inc();
                tracing::debug!(error=%e, "Failed to classify packet");
                return;
            }
        };
        let Some(sock_addr) = classification
            .dst_socket_addr()
            .and_then(|a| a.socket_addr())
        else {
            self.metrics.invalid_packets_errors.inc();
            tracing::debug!("Could not deduce destination socket address from packet");
            return;
        };

        let raw_bytes = packet.as_bytes();
        let len = raw_bytes.len();
        let mut pooled_packet = self.pool.get();
        let buf: &mut [u8] = pooled_packet.as_mut();
        // Pooled buffers have a finite length. Slicing `buf[..len]` past it would panic,
        // so an oversized packet is dropped rather than taking down the dispatch path.
        if len > buf.len() {
            self.metrics.invalid_packets_errors.inc();
            tracing::debug!(len, capacity = buf.len(), "Packet does not fit into pooled buffer");
            return;
        }
        buf[..len].copy_from_slice(raw_bytes);
        // Shrink the packet to exactly the bytes that were copied in.
        pooled_packet.truncate(len);

        match self.outbound_queue.try_send((sock_addr, pooled_packet)) {
            Ok(()) => self.metrics.dispatch_queue_size.inc(),
            Err(TrySendError::Closed(_)) => self.metrics.closed_dispatch_queue_errors.inc(),
            Err(TrySendError::Full(_)) => self.metrics.full_dispatch_queue_errors.inc(),
        }
    }
}
/// Receiving half returned by [`TunnelGatewayDispatcher::new`].
///
/// Holds the consumer end of the dispatcher's outbound queue and a clone of the packet
/// pool the dispatcher copies packets into.
pub struct TunnelGatewayDispatcherReceiver {
    /// Clone of the pool given to the paired dispatcher in `new`.
    pub(crate) pool: PacketPool,
    /// Consumer end of the bounded `(destination, packet)` queue filled by `try_dispatch`.
    pub(crate) outbound_queue: Receiver<(SocketAddr, Packet)>,
}