netsim-embed-router 0.4.5

Implementation of a router for netsim embed.
Documentation
use futures::channel::{mpsc, oneshot};
use futures::future::{poll_fn, FutureExt};
use futures::stream::{FuturesUnordered, StreamExt};
use libpacket::ipv4::Ipv4Packet;
use netsim_embed_core::{Ipv4Route, Plug};
use std::net::Ipv4Addr;
use std::task::Poll;

#[derive(Debug)]
#[allow(clippy::enum_variant_names)]
enum RouterCtrl {
    AddRoute(usize, Plug, Vec<Ipv4Route>),
    RemoveRoute(usize, oneshot::Sender<Option<Plug>>),
    EnableRoute(usize),
    DisableRoute(usize),
}

#[derive(Debug)]
pub struct Ipv4Router {
    #[allow(unused)]
    addr: Ipv4Addr,
    ctrl: mpsc::UnboundedSender<RouterCtrl>,
}

impl Ipv4Router {
    pub fn new(addr: Ipv4Addr) -> Self {
        let (tx, rx) = mpsc::unbounded();
        router(addr, rx);
        Self { addr, ctrl: tx }
    }

    pub fn add_connection(&self, id: usize, plug: Plug, routes: Vec<Ipv4Route>) {
        self.ctrl
            .unbounded_send(RouterCtrl::AddRoute(id, plug, routes))
            .ok();
    }

    pub async fn remove_connection(&self, id: usize) -> Option<Plug> {
        let (tx, rx) = oneshot::channel();
        self.ctrl
            .unbounded_send(RouterCtrl::RemoveRoute(id, tx))
            .unwrap();
        rx.await.unwrap()
    }

    pub fn enable_route(&self, id: usize) {
        self.ctrl
            .unbounded_send(RouterCtrl::EnableRoute(id))
            .unwrap();
    }

    pub fn disable_route(&self, id: usize) {
        self.ctrl
            .unbounded_send(RouterCtrl::DisableRoute(id))
            .unwrap();
    }
}

fn router(addr: Ipv4Addr, mut ctrl: mpsc::UnboundedReceiver<RouterCtrl>) {
    async_global_executor::spawn(async move {
        let mut conns = vec![];
        loop {
            futures::select! {
                ctrl = ctrl.next() => match ctrl {
                    Some(RouterCtrl::AddRoute(id, plug, routes)) => {
                        conns.push((id, plug, routes, true));
                    }
                    Some(RouterCtrl::RemoveRoute(id, ch)) => {
                        let plug = if let Some(idx) = conns.iter().position(|(id2, _, _, _)| *id2 == id) {
                            let (_, plug, _, _) = conns.swap_remove(idx);
                            Some(plug)
                        } else {
                            None
                        };
                        ch.send(plug).ok();
                    }
                    Some(RouterCtrl::EnableRoute(id)) => {
                        if let Some((_, _, _, en)) = conns.iter_mut().find(|(id2, _, _, _)| id == *id2) {
                            *en = true;
                        }
                    }
                    Some(RouterCtrl::DisableRoute(id)) => {
                        if let Some((_, _, _, en)) = conns.iter_mut().find(|(id2, _, _, _)| id == *id2) {
                            *en = false;
                        }
                    }
                    None => break,
                },
                incoming = incoming(&mut conns).fuse() => match incoming {
                    (_, Some(packet)) => forward_packet(addr, &mut conns, packet),
                    (i, None) => { conns.swap_remove(i); }
                }
            }
        }
    }).detach()
}

async fn incoming(conns: &mut [(usize, Plug, Vec<Ipv4Route>, bool)]) -> (usize, Option<Vec<u8>>) {
    let mut futures = conns
        .iter_mut()
        .enumerate()
        .filter(|(_, (_, _, _, en))| *en)
        .map(|(i, (_, plug, _, _))| async move { (i, plug.incoming().await) })
        .collect::<FuturesUnordered<_>>();
    if futures.is_empty() {
        poll_fn(|_| Poll::Pending).await
    } else {
        futures.next().await.unwrap()
    }
}

fn forward_packet(
    addr: Ipv4Addr,
    conns: &mut [(usize, Plug, Vec<Ipv4Route>, bool)],
    bytes: Vec<u8>,
) {
    let packet = if let Some(packet) = Ipv4Packet::new(&bytes) {
        packet
    } else {
        log::info!("router {}: dropping invalid ipv4 packet", addr);
        return;
    };
    let dest = packet.get_destination();
    if dest == addr {
        log::info!("router {}: dropping packet addressed to me", addr);
        return;
    }
    let mut forwarded = false;
    for (_, tx, routes, en) in conns {
        for route in routes {
            if route.dest().contains(dest) || dest.is_broadcast() || dest.is_multicast() {
                if !*en {
                    log::trace!("router {}: route {:?} disabled", addr, route);
                } else {
                    log::trace!("router {}: routing packet on route {:?}", addr, route);
                    let _ = tx.unbounded_send(bytes.clone());
                    forwarded = true;
                }
            }
        }
    }
    if !forwarded {
        let src = packet.get_source();
        log::debug!(
            "router {}: dropping unroutable packet from {} to {}",
            addr,
            src,
            dest
        );
    }
}