use alloc::{boxed::Box, vec, vec::Vec};
use smoltcp::{
iface::SocketSet,
phy::{DeviceCapabilities, Medium},
storage::PacketMetadata,
time::Instant,
wire::{
IpAddress, IpCidr, IpProtocol, IpVersion, Ipv4Address, Ipv4Cidr, Ipv4Packet, Ipv6Packet,
TcpPacket,
},
};
use crate::{
LISTEN_TABLE,
consts::{SOCKET_BUFFER_SIZE, STANDARD_MTU},
device::{ArpEntry, Device},
};
#[derive(Debug)]
pub struct Rule {
pub filter: IpCidr,
pub via: Option<IpAddress>,
pub dev: usize,
pub src: IpAddress,
}
impl Rule {
pub fn new(filter: IpCidr, via: Option<IpAddress>, dev: usize, src: IpAddress) -> Self {
Self {
filter,
via,
dev,
src,
}
}
}
type PacketBuffer = smoltcp::storage::PacketBuffer<'static, ()>;
pub struct RouteTable {
rules: Vec<Rule>,
}
impl RouteTable {
pub fn new() -> Self {
Self { rules: Vec::new() }
}
pub fn add_rule(&mut self, rule: Rule) {
let idx = self
.rules
.iter()
.position(|it| it.filter.prefix_len() < rule.filter.prefix_len())
.unwrap_or(self.rules.len());
self.rules.insert(idx, rule);
}
pub fn lookup(&self, dst: &IpAddress) -> Option<&Rule> {
self.rules
.iter()
.find(|rule| rule.filter.contains_addr(dst))
}
pub fn remove_ipv4_rules_for_dev(&mut self, dev: usize) {
self.rules.retain(|rule| {
!matches!(
rule.filter,
IpCidr::Ipv4(_) if rule.dev == dev
)
});
}
}
pub struct Router {
rx_buffer: PacketBuffer,
tx_buffer: PacketBuffer,
pub(crate) devices: Vec<Box<dyn Device>>,
pub(crate) table: RouteTable,
}
impl Router {
pub fn new() -> Self {
let rx_buffer = PacketBuffer::new(
vec![PacketMetadata::EMPTY; SOCKET_BUFFER_SIZE],
vec![0u8; STANDARD_MTU * SOCKET_BUFFER_SIZE],
);
let tx_buffer = PacketBuffer::new(
vec![PacketMetadata::EMPTY; SOCKET_BUFFER_SIZE],
vec![0u8; STANDARD_MTU * SOCKET_BUFFER_SIZE],
);
Self {
rx_buffer,
tx_buffer,
devices: Vec::new(),
table: RouteTable::new(),
}
}
pub fn add_rule(&mut self, rule: Rule) {
self.table.add_rule(rule);
}
pub fn add_device(&mut self, device: Box<dyn Device>) -> usize {
self.devices.push(device);
self.devices.len() - 1
}
pub fn set_ipv4_config(
&mut self,
dev: usize,
address: Option<Ipv4Cidr>,
gateway: Option<IpAddress>,
) {
self.table.remove_ipv4_rules_for_dev(dev);
self.devices[dev].set_ipv4_addr(address);
if let Some(address) = address {
self.add_rule(Rule::new(
address.into(),
None,
dev,
address.address().into(),
));
self.add_rule(Rule::new(
Ipv4Cidr::new(Ipv4Address::UNSPECIFIED, 0).into(),
gateway,
dev,
address.address().into(),
));
}
}
pub fn poll(
&mut self,
timestamp: Instant,
sockets: &mut SocketSet<'_>,
mut snoop: impl FnMut(usize, &[u8]),
) {
for (dev_idx, dev) in self.devices.iter_mut().enumerate() {
let mut packet_snoop = |packet: &[u8]| {
snoop_tcp_packet(packet, sockets);
snoop(dev_idx, packet);
};
while !self.rx_buffer.is_full()
&& dev.recv(&mut self.rx_buffer, timestamp, &mut packet_snoop)
{}
}
}
pub fn send_on_device(
&mut self,
dev: usize,
next_hop: IpAddress,
packet: &[u8],
timestamp: Instant,
) -> bool {
self.devices[dev].send(next_hop, packet, timestamp)
}
pub fn arp_entries(&self, timestamp: Instant) -> Vec<ArpEntry> {
let mut entries = Vec::new();
for device in &self.devices {
entries.extend(device.arp_entries(timestamp));
}
entries
}
pub fn dispatch(&mut self, timestamp: Instant) -> bool {
let mut poll_next = false;
while let Ok(((), packet)) = self.tx_buffer.dequeue() {
match IpVersion::of_packet(packet).expect("got invalid IP packet") {
IpVersion::Ipv4 => {
let packet = smoltcp::wire::Ipv4Packet::new_checked(packet)
.expect("got invalid IPv4 packet");
let dst_addr = IpAddress::Ipv4(packet.dst_addr());
if packet.dst_addr().is_broadcast() {
let buf = packet.into_inner();
for dev in &mut self.devices {
poll_next |= dev.send(dst_addr, buf, timestamp);
}
} else {
let Some(rule) = self.table.lookup(&dst_addr) else {
warn!("No route found for destination: {}", dst_addr);
continue;
};
assert_eq!(rule.src, IpAddress::Ipv4(packet.src_addr()));
let next_hop = rule.via.unwrap_or(dst_addr);
let dev = &mut self.devices[rule.dev];
poll_next |= dev.send(next_hop, packet.into_inner(), timestamp);
}
}
IpVersion::Ipv6 => {
let packet = smoltcp::wire::Ipv6Packet::new_checked(packet)
.expect("got invalid IPv6 packet");
let dst_addr = IpAddress::Ipv6(packet.dst_addr());
if packet.dst_addr().is_multicast() {
let buf = packet.into_inner();
for dev in &mut self.devices {
poll_next |= dev.send(dst_addr, buf, timestamp);
}
} else {
let Some(rule) = self.table.lookup(&dst_addr) else {
warn!("No route found for destination: {}", dst_addr);
continue;
};
assert_eq!(rule.src, IpAddress::Ipv6(packet.src_addr()));
let next_hop = rule.via.unwrap_or(dst_addr);
let dev = &mut self.devices[rule.dev];
poll_next |= dev.send(next_hop, packet.into_inner(), timestamp);
}
}
}
}
poll_next
}
}
pub struct TxToken<'a>(&'a mut PacketBuffer);
impl smoltcp::phy::TxToken for TxToken<'_> {
fn consume<R, F>(self, len: usize, f: F) -> R
where
F: FnOnce(&mut [u8]) -> R,
{
f(self
.0
.enqueue(len, ())
.expect("This was checked before creating the TxToken"))
}
}
fn snoop_tcp_packet(buf: &[u8], sockets: &mut SocketSet<'_>) {
let (protocol, src_addr, dst_addr, payload) = match IpVersion::of_packet(buf).unwrap() {
IpVersion::Ipv4 => {
let packet = Ipv4Packet::new_unchecked(buf);
(
packet.next_header(),
IpAddress::Ipv4(packet.src_addr()),
IpAddress::Ipv4(packet.dst_addr()),
packet.payload(),
)
}
IpVersion::Ipv6 => {
let packet = Ipv6Packet::new_unchecked(buf);
(
packet.next_header(),
IpAddress::Ipv6(packet.src_addr()),
IpAddress::Ipv6(packet.dst_addr()),
packet.payload(),
)
}
};
if protocol == IpProtocol::Tcp {
let tcp_packet = TcpPacket::new_unchecked(payload);
let src_addr = (src_addr, tcp_packet.src_port()).into();
let dst_addr = (dst_addr, tcp_packet.dst_port()).into();
let is_first = tcp_packet.syn() && !tcp_packet.ack();
if is_first {
LISTEN_TABLE.incoming_tcp_packet(src_addr, dst_addr, sockets);
}
}
}
pub struct RxToken<'a>(&'a [u8]);
impl<'a> smoltcp::phy::RxToken for RxToken<'a> {
fn consume<R, F>(self, f: F) -> R
where
F: FnOnce(&[u8]) -> R,
{
f(self.0)
}
}
impl smoltcp::phy::Device for Router {
type RxToken<'a> = RxToken<'a>;
type TxToken<'a> = TxToken<'a>;
fn receive(&mut self, _timestamp: Instant) -> Option<(Self::RxToken<'_>, Self::TxToken<'_>)> {
if self.rx_buffer.is_empty() || self.tx_buffer.is_full() {
None
} else {
Some((
RxToken(self.rx_buffer.dequeue().unwrap().1),
TxToken(&mut self.tx_buffer),
))
}
}
fn transmit(&mut self, _timestamp: Instant) -> Option<Self::TxToken<'_>> {
if self.tx_buffer.is_full() {
None
} else {
Some(TxToken(&mut self.tx_buffer))
}
}
fn capabilities(&self) -> DeviceCapabilities {
let mut caps = DeviceCapabilities::default();
caps.medium = Medium::Ip;
caps.max_transmission_unit = STANDARD_MTU;
caps.max_burst_size = Some(SOCKET_BUFFER_SIZE);
caps
}
}