use alloc::{boxed::Box, vec, vec::Vec};
use smoltcp::{
iface::SocketSet,
phy::{DeviceCapabilities, Medium},
storage::PacketMetadata,
time::Instant,
wire::{IpAddress, IpCidr, IpProtocol, IpVersion, Ipv4Packet, Ipv6Packet, TcpPacket},
};
use crate::{
LISTEN_TABLE,
consts::{SOCKET_BUFFER_SIZE, STANDARD_MTU},
device::Device,
};
#[derive(Debug)]
pub struct Rule {
pub filter: IpCidr,
pub via: Option<IpAddress>,
pub dev: usize,
pub src: IpAddress,
}
impl Rule {
pub fn new(filter: IpCidr, via: Option<IpAddress>, dev: usize, src: IpAddress) -> Self {
Self {
filter,
via,
dev,
src,
}
}
}
type PacketBuffer = smoltcp::storage::PacketBuffer<'static, ()>;
pub struct RouteTable {
rules: Vec<Rule>,
}
impl RouteTable {
pub fn new() -> Self {
Self { rules: Vec::new() }
}
pub fn add_rule(&mut self, rule: Rule) {
let idx = self
.rules
.binary_search_by(|it| rule.filter.prefix_len().cmp(&it.filter.prefix_len()))
.unwrap_or_else(|idx| idx);
self.rules.insert(idx, rule);
}
pub fn lookup(&self, dst: &IpAddress) -> Option<&Rule> {
self.rules
.iter()
.find(|rule| rule.filter.contains_addr(dst))
}
}
pub struct Router {
rx_buffer: PacketBuffer,
tx_buffer: PacketBuffer,
pub(crate) devices: Vec<Box<dyn Device>>,
pub(crate) table: RouteTable,
}
impl Router {
pub fn new() -> Self {
let rx_buffer = PacketBuffer::new(
vec![PacketMetadata::EMPTY; SOCKET_BUFFER_SIZE],
vec![0u8; STANDARD_MTU * SOCKET_BUFFER_SIZE],
);
let tx_buffer = PacketBuffer::new(
vec![PacketMetadata::EMPTY; SOCKET_BUFFER_SIZE],
vec![0u8; STANDARD_MTU * SOCKET_BUFFER_SIZE],
);
Self {
rx_buffer,
tx_buffer,
devices: Vec::new(),
table: RouteTable::new(),
}
}
pub fn add_rule(&mut self, rule: Rule) {
self.table.add_rule(rule);
}
pub fn add_device(&mut self, device: Box<dyn Device>) -> usize {
self.devices.push(device);
self.devices.len() - 1
}
pub fn poll(&mut self, timestamp: Instant) {
for dev in &mut self.devices {
while !self.rx_buffer.is_full() && dev.recv(&mut self.rx_buffer, timestamp) {}
}
}
pub fn dispatch(&mut self, timestamp: Instant) -> bool {
let mut poll_next = false;
while let Ok(((), packet)) = self.tx_buffer.dequeue() {
match IpVersion::of_packet(packet).expect("got invalid IP packet") {
IpVersion::Ipv4 => {
let packet = smoltcp::wire::Ipv4Packet::new_checked(packet)
.expect("got invalid IPv4 packet");
let dst_addr = IpAddress::Ipv4(packet.dst_addr());
if packet.dst_addr().is_broadcast() {
let buf = packet.into_inner();
for dev in &mut self.devices {
poll_next |= dev.send(dst_addr, buf, timestamp);
}
} else {
let Some(rule) = self.table.lookup(&dst_addr) else {
warn!("No route found for destination: {}", dst_addr);
continue;
};
assert_eq!(rule.src, IpAddress::Ipv4(packet.src_addr()));
let next_hop = rule.via.unwrap_or(dst_addr);
let dev = &mut self.devices[rule.dev];
poll_next |= dev.send(next_hop, packet.into_inner(), timestamp);
}
}
IpVersion::Ipv6 => {
let packet = smoltcp::wire::Ipv6Packet::new_checked(packet)
.expect("got invalid IPv6 packet");
let dst_addr = IpAddress::Ipv6(packet.dst_addr());
if packet.dst_addr().is_multicast() {
let buf = packet.into_inner();
for dev in &mut self.devices {
poll_next |= dev.send(dst_addr, buf, timestamp);
}
} else {
let Some(rule) = self.table.lookup(&dst_addr) else {
warn!("No route found for destination: {}", dst_addr);
continue;
};
assert_eq!(rule.src, IpAddress::Ipv6(packet.src_addr()));
let next_hop = rule.via.unwrap_or(dst_addr);
let dev = &mut self.devices[rule.dev];
poll_next |= dev.send(next_hop, packet.into_inner(), timestamp);
}
}
}
}
poll_next
}
}
pub struct TxToken<'a>(&'a mut PacketBuffer);
impl smoltcp::phy::TxToken for TxToken<'_> {
fn consume<R, F>(self, len: usize, f: F) -> R
where
F: FnOnce(&mut [u8]) -> R,
{
f(self
.0
.enqueue(len, ())
.expect("This was checked before creating the TxToken"))
}
}
fn snoop_tcp_packet(buf: &[u8], sockets: &mut SocketSet<'_>) {
let (protocol, src_addr, dst_addr, payload) = match IpVersion::of_packet(buf).unwrap() {
IpVersion::Ipv4 => {
let packet = Ipv4Packet::new_unchecked(buf);
(
packet.next_header(),
IpAddress::Ipv4(packet.src_addr()),
IpAddress::Ipv4(packet.dst_addr()),
packet.payload(),
)
}
IpVersion::Ipv6 => {
let packet = Ipv6Packet::new_unchecked(buf);
(
packet.next_header(),
IpAddress::Ipv6(packet.src_addr()),
IpAddress::Ipv6(packet.dst_addr()),
packet.payload(),
)
}
};
if protocol == IpProtocol::Tcp {
let tcp_packet = TcpPacket::new_unchecked(payload);
let src_addr = (src_addr, tcp_packet.src_port()).into();
let dst_addr = (dst_addr, tcp_packet.dst_port()).into();
let is_first = tcp_packet.syn() && !tcp_packet.ack();
if is_first {
LISTEN_TABLE.incoming_tcp_packet(src_addr, dst_addr, sockets);
}
}
}
pub struct RxToken<'a>(&'a [u8]);
impl<'a> smoltcp::phy::RxToken for RxToken<'a> {
fn consume<R, F>(self, f: F) -> R
where
F: FnOnce(&[u8]) -> R,
{
f(self.0)
}
fn preprocess(&self, sockets: &mut SocketSet) {
snoop_tcp_packet(self.0, sockets);
}
}
impl smoltcp::phy::Device for Router {
type RxToken<'a> = RxToken<'a>;
type TxToken<'a> = TxToken<'a>;
fn receive(&mut self, _timestamp: Instant) -> Option<(Self::RxToken<'_>, Self::TxToken<'_>)> {
if self.rx_buffer.is_empty() || self.tx_buffer.is_full() {
None
} else {
Some((
RxToken(self.rx_buffer.dequeue().unwrap().1),
TxToken(&mut self.tx_buffer),
))
}
}
fn transmit(&mut self, _timestamp: Instant) -> Option<Self::TxToken<'_>> {
if self.tx_buffer.is_full() {
None
} else {
Some(TxToken(&mut self.tx_buffer))
}
}
fn capabilities(&self) -> DeviceCapabilities {
let mut caps = DeviceCapabilities::default();
caps.medium = Medium::Ip;
caps.max_transmission_unit = STANDARD_MTU;
caps.max_burst_size = Some(SOCKET_BUFFER_SIZE);
caps
}
}