use futures::StreamExt;
use tracing::{
info_span,
Span,
Instrument
};
use tokio::sync::watch;
use tokio::task::JoinError;
use crate::control::{
KeyArg,
RouterReq
};
use crate::peer::Address;
use crate::util::*;
use super::{
InterfaceAddress,
ChaCha20Server
};
use super::interface::{
Interface,
InterfaceLauncher,
Terminated
};
/// Owns the set of running interfaces and moves packets and control
/// requests between them and the rest of the program.
///
/// Constructed with [`Router::setup`] and driven by [`Router::run`].
pub struct Router {
/// Sender for inbound packets; a clone is handed to every interface's
/// server channel when the interface is bound or launched.
inbound_pkt_tx: Sender<PktFrom>,
/// Publishes the activity flag: `true` while at least one interface is
/// registered, `false` while the registry is empty.
monitor_tx: watch::Sender<bool>,
/// Packets received here are forwarded to the first interface that
/// reports it can send them. Closing this channel stops the run loop.
outbound_pkt_rx: Receiver<PktTo>,
/// Control requests handled by the run loop. Closing this channel is
/// treated as the shutdown signal.
query_rx: Receiver<RouterReq>,
/// 32-byte key passed to `ChaCha20Server::bind` for newly bound interfaces.
encryption_key: [u8; 32],
/// Registry of launched interfaces; polled as a stream, it yields each
/// interface together with the reason its task ended.
interfaces: Registry<Terminated, Interface>,
/// The `"router"` tracing span used for this router's log events.
span: Span
}
/// Receiving side of the router's activity flag. The value starts as `false`
/// and is updated by [`Router::run`]: `true` while at least one interface is
/// registered, `false` when there are none.
pub type RouterMonitor = watch::Receiver<bool>;
/// Everything produced by [`Router::setup`]: the router itself plus the
/// channel endpoints that pair with the ones the router keeps.
pub struct RouterSetup {
/// The router, ready to be driven with [`Router::run`].
pub router: Router,
/// Receiving end paired with the sender the router clones into every
/// interface's server channel.
pub inbound_packet_receiver: Receiver<PktFrom>,
/// Watch receiver for the router's active/dormant flag.
pub router_monitor: RouterMonitor,
/// Sending end for packets the router should forward to an interface.
pub outbound_packet_sender: Sender<PktTo>,
/// Sending end for [`RouterReq`] control requests.
pub router_query_sender: Sender<RouterReq>
}
impl Router {
/// Builds a [`Router`] together with the channel endpoints its owner needs.
///
/// Three packet/request channels and one `watch` channel (initial value
/// `false`) are created. The router keeps one end of each and the opposite
/// ends are returned in the [`RouterSetup`]. The router's own tracing span
/// is an `info_span!("router")` created while `parent_span` is entered.
pub fn setup(encryption_key: KeyArg, parent_span: &Span) -> RouterSetup {
    let (inbound_pkt_tx, inbound_packet_receiver) = new_channel::<PktFrom>();
    let (outbound_packet_sender, outbound_pkt_rx) = new_channel::<PktTo>();
    let (router_query_sender, query_rx) = new_channel::<RouterReq>();
    let (monitor_tx, router_monitor) = watch::channel(false);

    RouterSetup {
        router: Self {
            inbound_pkt_tx,
            monitor_tx,
            outbound_pkt_rx,
            query_rx,
            encryption_key: encryption_key.get(),
            interfaces: Registry::new(),
            span: parent_span.in_scope(|| info_span!("router")),
        },
        inbound_packet_receiver,
        router_monitor,
        outbound_packet_sender,
        router_query_sender,
    }
}
/// Hands `pkt_to` to the first registered interface whose `can_send`
/// accepts the packet's first field (`pkt_to.0`, presumably the destination
/// address).
///
/// When no interface accepts it, the packet is dropped without any log.
fn send_packet(&self, pkt_to: PktTo) {
    let target = self.interfaces
        .iter()
        .find(|iface| iface.can_send(pkt_to.0));
    // `if let` instead of `Option::map`: the call is made purely for its
    // side effect, so no clippy suppression is needed.
    if let Some(iface) = target {
        iface.send(pkt_to);
    }
}
/// Asks every registered interface to shut down, logging a trace event for
/// each one. The interfaces stay in the registry; this only calls
/// `shutdown` on them.
fn terminate_all(&mut self) {
    self.interfaces.iter_mut().for_each(|interface| {
        trace!(:self.span, %interface, "terminating");
        interface.shutdown();
    });
}
/// Asks every interface whose address equals `addr` to shut down.
///
/// Returns `true` if at least one interface matched. The interface is not
/// taken out of the registry here; this only calls `shutdown` on it.
fn remove(&mut self, addr: Address) -> bool {
    // An explicit loop keeps the side effects (logging, shutdown) out of
    // iterator adaptors, where `.map(..).count()` hid them.
    let mut found = false;
    for interface in self.interfaces.iter_mut() {
        if interface.get_address() == addr {
            trace!(:self.span, %interface, "terminating");
            interface.shutdown();
            found = true;
        }
    }
    found
}
/// Binds a ChaCha20 server for `address` and registers it as an interface.
///
/// If an interface with the same address is already registered, its
/// address information is replaced via `replace_address` and no new server
/// is bound. Otherwise a server is bound with the router's encryption key,
/// wired to a fresh outbound channel plus a clone of the inbound sender,
/// and launched into the registry.
///
/// Returns `true` when the address was already bound or the new server was
/// launched, and `false` when `ChaCha20Server::bind` yields `None`.
async fn bind_interface(&mut self, address: InterfaceAddress) -> bool {
    {
        // The `enter` guard must not live across an `.await`: the span would
        // stay entered on this thread while the task is suspended. Scope it
        // to this synchronous lookup.
        let _guard = self.span.enter();
        let already_bound = self.interfaces
            .iter_mut()
            .find(|iface| iface.get_address() == address.get_address());
        if let Some(interface) = already_bound {
            trace!(%interface, %address, "already bound, replacing routes");
            interface.replace_address(address);
            return true;
        }
    }

    // Attach the router span to the future itself so it is entered only
    // while the future is being polled.
    let bound = ChaCha20Server::bind(
        &self.encryption_key,
        address.clone(),
        &self.span
    ).instrument(self.span.clone()).await;
    let Some(server) = bound else { return false };

    // No await points remain below, so entering the span again is safe.
    let _guard = self.span.enter();
    let (outbound_packet_sender, outbound_packet_receiver)
        = new_channel::<PktTo>();
    let server_channel = ServerChannel::combine(
        self.inbound_pkt_tx.clone(),
        outbound_packet_receiver
    );
    let interface = Interface::new(address, outbound_packet_sender);
    trace!(@self, %interface, "spawned chacha20 task");
    self.interfaces.launch(interface, server.launch(server_channel));
    true
}
/// Registers an interface for an already-constructed server.
///
/// A fresh outbound channel is created for the interface and combined with
/// a clone of the inbound sender into the server's channel. The server's
/// task (`launch` for the ChaCha20/UDP variant, `run` for the dummy variant)
/// is then handed to the registry together with the new `Interface`.
fn launch_interface(&mut self, launcher: InterfaceLauncher) {
    let (iface_tx, iface_rx) = new_channel::<PktTo>();
    let channel = ServerChannel::combine(self.inbound_pkt_tx.clone(), iface_rx);
    let interface = Interface::new(launcher.get_address(), iface_tx);

    match launcher {
        InterfaceLauncher::Chacha20Udp(server) => {
            // Read the address first: starting the task consumes `server`.
            let address = server.get_address();
            let task = server.launch(channel);
            self.interfaces.launch(interface, task);
            trace!(@self, %address, "spawned chacha20 task");
        },
        InterfaceLauncher::Dummy(server) => {
            let address = server.get_address();
            let task = server.run(channel);
            self.interfaces.launch(interface, task);
            trace!(@self, %address, "spawned dummy server task");
        }
    }
}
fn log_termination(
&self,
interface: Interface,
reason: Result<Terminated, JoinError>)
{
match reason {
Ok(Terminated::Shutdown) =>
info!(@self, %interface, "shut down"),
Ok(Terminated::Crashed(error_msg)) =>
error!(@self, %interface, %error_msg, "crashed"),
Ok(Terminated::Panic) =>
error!(@self, %interface, "task panicked"),
Err(error) =>
error!(@self, %interface, %error, "crashed"),
}
}
/// Drives the router until one of its input channels closes.
///
/// Each loop iteration first reconciles the activity flag with the registry
/// (active when it holds at least one interface) and publishes any change on
/// the monitor channel. It then waits for whichever comes first: a control
/// request, an interface task ending (polled only while active), or an
/// outbound packet. A closed request channel or a closed outbound channel
/// ends the loop.
///
/// On exit, outbound packets that are already queued are still forwarded,
/// every interface is asked to shut down, and each termination is logged.
pub async fn run(mut self) {
    info!(@self, "started");
    let mut is_active = false;
    loop {
        // Publish a state change only when the registry and flag disagree.
        let has_interfaces = !self.interfaces.is_empty();
        if has_interfaces != is_active {
            is_active = has_interfaces;
            if is_active {
                debug!(@self, "entering active state");
            } else {
                debug!(@self, "entering dormant state");
            }
            let _ = self.monitor_tx.send(is_active);
        }
        tokio::select! {
            q = self.query_rx.recv() => {
                let Some(request) = q else {
                    info!(@self, "received shutdown signal");
                    break
                };
                self.handle_request(request).in_current_span().await;
            },
            r = self.interfaces.next(), if is_active => {
                if let Some((addr, reason)) = r {
                    self.log_termination(addr, reason);
                }
            },
            p = self.outbound_pkt_rx.recv() => {
                let Some(pkt_to) = p else { break };
                self.send_packet(pkt_to);
            }
        }
    }
    // Forward whatever is still queued before tearing the interfaces down.
    while let Ok(queued) = self.outbound_pkt_rx.try_recv() {
        self.send_packet(queued);
    }
    self.terminate_all();
    while let Some((iface, reason)) = self.interfaces.next().await {
        self.log_termination(iface, reason);
    }
    info!(@self, "finished");
}
/// Executes one control request and sends the result to its replier.
///
/// * `GetInterfaces` replies with the interface address of every
///   registered interface.
/// * `CanSend` replies whether any interface's `can_send` accepts the
///   request's first field.
/// * `BindInterface` replies with the result of `bind_interface`.
/// * `LaunchInterface` launches the supplied server and replies `true`.
/// * `RemoveInterface` replies with the result of `remove`.
async fn handle_request(&mut self, req: RouterReq) {
    match req {
        RouterReq::GetInterfaces(query) => {
            let addrs: Vec<_> = self.interfaces
                .iter()
                .map(Interface::get_interface_address)
                .collect();
            query.reply(addrs);
        },
        RouterReq::CanSend(query) => {
            let target = query.0;
            let can_send = self.interfaces
                .iter()
                .any(|iface| iface.can_send(target));
            query.reply(can_send);
        },
        RouterReq::BindInterface((address, replier)) => {
            replier.reply(
                self.bind_interface(address)
                    .in_current_span()
                    .await
            );
        },
        RouterReq::LaunchInterface((launcher, replier)) => {
            self.launch_interface(launcher);
            replier.reply(true);
        },
        RouterReq::RemoveInterface((addr, replier)) => {
            let removed = self.remove(addr);
            replier.reply(removed);
        }
    }
}
}