use std::net::{
Ipv4Addr,
Ipv6Addr
};
use futures::StreamExt;
use ipnetwork::{
Ipv4Network,
Ipv6Network
};
use prefix_trie::PrefixMap;
use tracing::{
info_span,
Span,
Instrument
};
use tokio::sync::watch;
use tokio::task::JoinError;
use crate::control::{
KeyArg,
RouterReq
};
use crate::peer::Address;
use crate::util::*;
use super::{
InterfaceAddress,
InterfaceRoute,
ChaCha20Server
};
use super::interface::{
Interface,
InterfaceLauncher,
Terminated
};
/// Central packet router: owns the running interfaces and forwards outbound
/// packets to the interface whose routes match each packet's destination.
pub struct Router {
/// Cloned into every interface's `ServerChannel`; inbound packets from all
/// interfaces arrive on its paired receiver.
inbound_pkt_tx: Sender<PktFrom>,
/// Publishes whether the router currently has any interfaces (see `run`).
monitor_tx: watch::Sender<bool>,
/// Packets submitted through `RouterSetup::outbound_packet_sender`.
outbound_pkt_rx: Receiver<PktTo>,
/// Control requests; `run` leaves its main loop when this channel yields `None`.
query_rx: Receiver<RouterReq>,
/// Key passed to `ChaCha20Server::bind`; replaced by `rekey`.
encryption_key: [u8; 32],
/// Registered interfaces together with the tasks running them.
interfaces: Registry<Terminated, Interface>,
/// Destination -> interface-channel routing tables.
tables: Tables,
/// The `router` span used for this router's log events.
span: Span
}
/// Receiving side of the router's activity flag: `true` while the router has
/// at least one registered interface, `false` otherwise.
pub type RouterMonitor = watch::Receiver<bool>;
/// Everything produced by `Router::setup`: the router itself plus the
/// external ends of the channels it communicates over.
pub struct RouterSetup {
/// The router; drive it by awaiting `Router::run`.
pub router: Router,
/// Receives the inbound packets forwarded by every interface.
pub inbound_packet_receiver: Receiver<PktFrom>,
/// Watches whether the router currently has any interfaces.
pub router_monitor: RouterMonitor,
/// Submits packets for the router to forward to a matching interface.
pub outbound_packet_sender: Sender<PktTo>,
/// Sends control requests; when this channel closes, `Router::run` shuts down.
pub router_query_sender: Sender<RouterReq>
}
/// Routing tables that map packet destinations to interface outbound channels.
#[derive(Debug, Default)]
struct Tables {
/// IPv4 routes, looked up by longest-prefix match.
ip_v4: PrefixMap<Ipv4Network, Sender<PktTo>>,
/// IPv6 routes, looked up by longest-prefix match.
ip_v6: PrefixMap<Ipv6Network, Sender<PktTo>>,
/// Channels of dummy interfaces; only the first one receives dummy traffic.
dummy: Vec<Sender<PktTo>>
}
impl Tables {
fn route(&self, to: Address) -> Option<&Sender<PktTo>> {
match to {
Address::V4UdpChaCha20(s_addr) => {
let addr_as_net = Ipv4Network::new(
*s_addr.ip(),
Ipv4Addr::BITS as u8
).unwrap();
self.ip_v4.get_lpm(&addr_as_net).map(|(_p, iface)| iface)
},
Address::V6UdpChaCha20(s_addr) => {
let addr_as_net = Ipv6Network::new(
*s_addr.ip(),
Ipv6Addr::BITS as u8
).unwrap();
self.ip_v6.get_lpm(&addr_as_net).map(|(_p, iface)| iface)
},
Address::Dummy(_) => self.dummy.first()
}
}
fn add_routes(&mut self, interface: &Interface) {
let if_addr = interface.get_interface_address();
let Some(sender) = interface.get_sender() else {
error!(%if_addr, "can't add routes for interface without sender");
return;
};
match if_addr.get_routes() {
InterfaceRoute::V4UdpChaCha20(routes) => for &route in routes {
let cloned = sender.clone();
if let Some(_old) = self.ip_v4.insert(route, cloned) {
}
},
InterfaceRoute::V6UdpChaCha20(routes) => for &route in routes {
let cloned = sender.clone();
if let Some(_old) = self.ip_v6.insert(route, cloned) {
}
},
InterfaceRoute::Dummy => {
self.dummy.push(sender.clone());
}
}
}
fn drop_routes(&mut self, interface: &Interface) {
let if_addr = interface.get_interface_address();
let Some(sender) = interface.get_sender() else {
error!(%if_addr, "can't drop routes for interface without sender");
return;
};
match if_addr.get_routes() {
InterfaceRoute::V4UdpChaCha20(routes) => for route in routes {
if self.ip_v4.get(route)
.is_some_and(|s| s.same_channel(sender))
{
self.ip_v4.remove(route);
}
},
InterfaceRoute::V6UdpChaCha20(routes) => for route in routes {
if self.ip_v6.get(route)
.is_some_and(|s| s.same_channel(sender))
{
self.ip_v6.remove(route);
}
},
InterfaceRoute::Dummy => {
self.dummy.retain(|s| !s.same_channel(sender));
}
}
}
}
impl Router {
pub fn setup(encryption_key: KeyArg, parent_span: &Span) -> RouterSetup {
let (inbound_packet_sender, inbound_packet_receiver)
= new_channel::<PktFrom>();
let (outbound_packet_sender, outbound_packet_receiver)
= new_channel::<PktTo>();
let (router_query_sender, query_receiver)
= new_channel::<RouterReq>();
let (monitor_tx, router_monitor)
= watch::channel::<bool>(false);
let interfaces = Registry::new();
let router = Self {
inbound_pkt_tx: inbound_packet_sender,
monitor_tx,
outbound_pkt_rx: outbound_packet_receiver,
query_rx: query_receiver,
encryption_key: encryption_key.get(),
interfaces,
tables: Tables::default(),
span: parent_span.in_scope(|| info_span!("router"))
};
RouterSetup {
router,
inbound_packet_receiver,
router_monitor,
outbound_packet_sender,
router_query_sender
}
}
fn send_packet(&self, pkt_to: PktTo) {
match self.tables.route(pkt_to.0) {
Some(channel) => match channel.send(pkt_to) {
Ok(()) => {},
Err(err) => warn!(?err, "failed to send packet to channel"),
},
None => warn!(
destination=%pkt_to.0,
packet=?pkt_to.1,
"no interface to send packet"
)
}
}
async fn rekey(&mut self, key: KeyArg) {
self.encryption_key = key.get();
let mut addrs_to_restart = Vec::new();
let mut await_stop = Registry::new();
for (mut interface, handle) in std::mem::take(&mut self.interfaces) {
match interface.get_interface_address() {
address @ InterfaceAddress::V4UdpChaCha20 { .. } |
address @ InterfaceAddress::V6UdpChaCha20 { .. } => {
trace!(:self.span, %interface, "terminating");
self.tables.drop_routes(&interface);
interface.shutdown();
addrs_to_restart.push(address);
await_stop.insert(interface, handle);
},
InterfaceAddress::Dummy(_) => self.interfaces
.insert(interface, handle)
}
}
while let Some((iface, reason)) = await_stop.next().await {
self.log_termination(iface, reason);
}
for address in addrs_to_restart {
trace!(:self.span, %address, "restarting");
self.bind_interface(address).await;
}
}
fn terminate_all(&mut self) {
std::mem::take(&mut self.tables);
for interface in self.interfaces.iter_mut() {
trace!(:self.span, %interface, "terminating");
interface.shutdown();
}
}
fn remove(&mut self, addr: Address) -> bool {
if let Some(interface) = self.interfaces.iter_mut()
.find(|interface| interface.get_address() == addr)
{
trace!(:self.span, %interface, "terminating");
self.tables.drop_routes(interface);
interface.shutdown();
true
}
else {false}
}
async fn bind_interface(&mut self, address: InterfaceAddress) -> bool {
let _g = self.span.enter();
let already_bound = self.interfaces
.iter_mut()
.find(|iface| iface.get_address() == address.get_address());
if let Some(interface) = already_bound {
self.tables.drop_routes(interface);
trace!(%interface, %address, "already bound, replacing routes");
interface.replace_address(address);
self.tables.add_routes(interface);
return true;
}
let Some(server) = ChaCha20Server::bind(
&self.encryption_key,
address.clone(),
&self.span
).in_current_span().await else {return false};
let (outbound_packet_sender, outbound_packet_receiver)
= new_channel::<PktTo>();
let server_channel = ServerChannel::combine(
self.inbound_pkt_tx.clone(),
outbound_packet_receiver
);
let interface = Interface::new(address, outbound_packet_sender);
self.tables.add_routes(&interface);
trace!(@self, %interface, "spawned chacha20 task");
self.interfaces.launch(interface, server.launch(server_channel));
true
}
fn launch_interface(&mut self, launcher: InterfaceLauncher) {
let address = launcher.get_address();
let (outbound_packet_sender, outbound_packet_receiver)
= new_channel::<PktTo>();
let interface = Interface::new(address, outbound_packet_sender);
self.tables.add_routes(&interface);
let channel = ServerChannel::combine(
self.inbound_pkt_tx.clone(),
outbound_packet_receiver
);
match launcher {
InterfaceLauncher::Chacha20Udp(server) => {
let address = server.get_address();
self.interfaces.launch(interface, server.launch(channel));
trace!(@self, %address, "spawned chacha20 task");
},
InterfaceLauncher::Dummy(server) => {
let address = server.get_address();
self.interfaces.launch(interface, server.run(channel));
trace!(@self, %address, "spawned dummy server task");
}
}
}
fn log_termination(
&self,
interface: Interface,
reason: Result<Terminated, JoinError>)
{
match reason {
Ok(Terminated::Shutdown) =>
info!(@self, %interface, "shut down"),
Ok(Terminated::Crashed(error_msg)) =>
error!(@self, %interface, %error_msg, "crashed"),
Ok(Terminated::Panic) =>
error!(@self, %interface, "task panicked"),
Err(error) =>
error!(@self, %interface, %error, "crashed"),
}
}
pub async fn run(mut self) {
info!(@self, "started");
let mut is_active = false;
loop {
if is_active == self.interfaces.is_empty() {
if self.interfaces.is_empty() {
debug!(@self, "entering dormant state");
is_active = false;
}
else {
debug!(@self, "entering active state");
is_active = true;
}
let _ = self.monitor_tx.send(is_active);
}
tokio::select! {
q = self.query_rx.recv() => match q {
Some(request) => self.handle_request(request)
.in_current_span()
.await,
None => {
info!(@self, "received shutdown signal");
break
}
},
r = self.interfaces.next(), if is_active => {
if let Some((interface, reason)) = r {
self.tables.drop_routes(&interface);
self.log_termination(interface, reason);
}
},
p = self.outbound_pkt_rx.recv() => match p {
Some(pkt_to) => self.send_packet(pkt_to),
None => break
}
}
}
while let Ok(pkt_to) = self.outbound_pkt_rx.try_recv() {
self.send_packet(pkt_to);
}
std::mem::take(&mut self.tables);
self.terminate_all();
while let Some((addr, reason)) = self.interfaces.next().await {
self.log_termination(addr, reason);
}
info!(@self, "finished");
}
async fn handle_request(&mut self, req: RouterReq) {
match req {
RouterReq::GetInterfaces(req) => {
let addrs = self.interfaces
.iter()
.map(Interface::get_interface_address)
.collect::<Vec<_>>();
req.reply(addrs);
},
RouterReq::CanSend((addr, replier)) => {
replier.reply(self.tables.route(addr).is_some());
},
RouterReq::BindInterface((address, replier)) => {
let worked = self.bind_interface(address)
.in_current_span()
.await;
replier.reply(worked);
},
RouterReq::LaunchInterface((launcher, replier)) => {
self.launch_interface(launcher);
replier.reply(true);
},
RouterReq::Rekey((key, replier)) => {
self.rekey(key).await;
replier.reply(true);
},
RouterReq::RemoveInterface((addr, replier)) => {
replier.reply(self.remove(addr));
}
}
}
}