use std::{
io::{self, Read, Write},
net::IpAddr,
num::NonZeroI32,
os::fd::{AsRawFd, RawFd},
};
use libc::{
AF_NETLINK, NETLINK_ROUTE, RTNLGRP_IPV4_ROUTE, RTNLGRP_IPV6_ROUTE, RTNLGRP_MPLS_ROUTE,
SOCK_CLOEXEC, SOCK_RAW, sockaddr_nl,
};
use netlink_packet_core::{
NLM_F_ACK, NLM_F_CREATE, NLM_F_EXCL, NLM_F_REQUEST, NetlinkHeader, NetlinkMessage,
NetlinkPayload,
};
use netlink_packet_route::{
AddressFamily, RouteNetlinkMessage,
route::{
RouteAddress, RouteAttribute, RouteHeader, RouteMessage, RouteProtocol, RouteScope,
RouteType,
},
};
use crate::{Route, RouteAction, RouteChange, syscall};
/// Owning handle to a raw `AF_NETLINK` socket using the `NETLINK_ROUTE`
/// protocol (created by `RouteSock::new`). The descriptor is closed when the
/// handle is dropped.
pub struct RouteSock(RawFd);
impl AsRawFd for RouteSock {
fn as_raw_fd(&self) -> RawFd {
self.0
}
}
impl RouteSock {
pub fn new() -> io::Result<Self> {
let fd = syscall!(socket(AF_NETLINK, SOCK_RAW | SOCK_CLOEXEC, NETLINK_ROUTE))?;
Ok(RouteSock(fd))
}
fn bind(&mut self, local: sockaddr_nl) -> io::Result<()> {
syscall!(bind(
self.as_raw_fd(),
&local as *const sockaddr_nl as *const _,
std::mem::size_of::<sockaddr_nl>() as _
))?;
Ok(())
}
pub fn subscribe(&mut self) -> io::Result<()> {
let mut local = unsafe { std::mem::zeroed::<sockaddr_nl>() };
local.nl_family = AF_NETLINK as u16;
local.nl_groups =
nl_mgrp(RTNLGRP_IPV4_ROUTE) | nl_mgrp(RTNLGRP_IPV6_ROUTE) | nl_mgrp(RTNLGRP_MPLS_ROUTE);
self.bind(local)
}
pub fn new_buf() -> [u8; 16384] {
[0u8; 16384]
}
}
impl Write for RouteSock {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
let n = syscall!(write(self.as_raw_fd(), buf.as_ptr() as *const _, buf.len()))?;
Ok(n as usize)
}
fn flush(&mut self) -> io::Result<()> {
Ok(())
}
}
impl Read for RouteSock {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
let n = syscall!(read(
self.as_raw_fd(),
buf.as_mut_ptr() as *mut _,
buf.len()
))?;
Ok(n as usize)
}
}
impl Drop for RouteSock {
fn drop(&mut self) {
let _ = syscall!(close(self.0));
}
}
impl RouteAction for RouteSock {
fn add(&mut self, route: &Route) -> io::Result<()> {
route.validate()?;
let mut nl_hdr = NetlinkHeader::default();
nl_hdr.flags = NLM_F_REQUEST | NLM_F_EXCL | NLM_F_CREATE | NLM_F_ACK;
nl_hdr.sequence_number = 1;
let mut rt_msg = route_change_message(route);
rt_msg.header.table = RouteHeader::RT_TABLE_MAIN;
rt_msg.header.protocol = RouteProtocol::Boot;
rt_msg.header.scope = RouteScope::Universe;
rt_msg.header.kind = RouteType::Unicast;
if let Some(gateway) = route.gateway {
rt_msg
.attributes
.push(RouteAttribute::Gateway(route_address(gateway)));
}
if let Some(index) = route.ifindex {
if route.gateway.is_none() {
rt_msg.header.scope = RouteScope::Link;
}
rt_msg.attributes.push(RouteAttribute::Oif(index));
}
let mut req = NetlinkMessage::new(
nl_hdr,
NetlinkPayload::from(RouteNetlinkMessage::NewRoute(rt_msg)),
);
req.finalize();
let mut buf = [0u8; 4096];
req.serialize(&mut buf[..req.buffer_len()]);
self.write_all(&buf[..req.buffer_len()])?;
self.recv_ack()
}
fn delete(&mut self, route: &Route) -> io::Result<()> {
route.validate()?;
let mut nl_hdr = NetlinkHeader::default();
nl_hdr.flags = NLM_F_REQUEST | NLM_F_ACK;
nl_hdr.sequence_number = 1;
let mut rt_msg = route_change_message(route);
rt_msg.header.table = RouteHeader::RT_TABLE_MAIN;
rt_msg.header.scope = RouteScope::NoWhere;
let mut req = NetlinkMessage::new(
nl_hdr,
NetlinkPayload::from(RouteNetlinkMessage::DelRoute(rt_msg)),
);
req.finalize();
let mut buf = [0u8; 4096];
req.serialize(&mut buf[..req.buffer_len()]);
self.write_all(&buf[..req.buffer_len()])?;
self.recv_ack()
}
fn get(&mut self, route: &Route) -> io::Result<Route> {
route.validate()?;
let mut nl_hdr = NetlinkHeader::default();
nl_hdr.flags = NLM_F_REQUEST;
nl_hdr.sequence_number = 1;
let rt_msg = route_lookup_message(route);
let mut req = NetlinkMessage::new(
nl_hdr,
NetlinkPayload::from(RouteNetlinkMessage::GetRoute(rt_msg)),
);
req.finalize();
let mut buf = [0u8; 4096];
req.serialize(&mut buf[..req.buffer_len()]);
self.write_all(&buf[..req.buffer_len()])?;
self.recv_route_response()
}
fn monitor(&mut self, buf: &mut [u8]) -> io::Result<(RouteChange, Route)> {
let n = self.read(buf)?;
let nlmsg = parse_nlmsg(&buf[..n])?;
if let NetlinkPayload::InnerMessage(rtnl_msg) = nlmsg.payload {
match rtnl_msg {
RouteNetlinkMessage::NewRoute(rtmsg) => Ok((
RouteChange::ADD,
route_from_message(&rtmsg)?
.ok_or_else(|| io::Error::other("unsupported route address"))?,
)),
RouteNetlinkMessage::DelRoute(rtmsg) => Ok((
RouteChange::DELETE,
route_from_message(&rtmsg)?
.ok_or_else(|| io::Error::other("unsupported route address"))?,
)),
_ => Err(io::Error::other(format!(
"Unexpected rtnl message: {:?}",
rtnl_msg
))),
}
} else {
Err(io::Error::other(format!("not rtnl message: {:?}", nlmsg)))
}
}
}
impl RouteSock {
    /// Reads replies to a `GetRoute` request until a route can be returned.
    ///
    /// Each datagram may hold several netlink messages, which are walked in
    /// order. The first `NewRoute` message that `route_from_message` converts
    /// is returned. `NewRoute` messages it cannot convert are skipped.
    ///
    /// # Errors
    ///
    /// - `NotFound` if a `Done` message, or an error message without a code,
    ///   arrives before a usable route.
    /// - The OS error carried by a netlink error message.
    /// - Read, parse and zero-length-message errors.
    fn recv_route_response(&mut self) -> io::Result<Route> {
        // One buffer reused across reads; only `..n` of each read is inspected.
        let mut rbuf = [0u8; 16384];
        loop {
            let n = self.read(&mut rbuf)?;
            let mut offset = 0;
            while offset < n {
                let nlmsg = parse_nlmsg(&rbuf[offset..n])?;
                let length = nlmsg.header.length as usize;
                if length == 0 {
                    return Err(io::Error::other("zero-length netlink message"));
                }
                match nlmsg.payload {
                    NetlinkPayload::Done(_) => {
                        return Err(io::Error::new(io::ErrorKind::NotFound, "route not found"));
                    }
                    NetlinkPayload::Error(e) => {
                        if let Some(code) = e.code {
                            return Err(netlink_error(code));
                        }
                        return Err(io::Error::new(io::ErrorKind::NotFound, "route not found"));
                    }
                    NetlinkPayload::InnerMessage(RouteNetlinkMessage::NewRoute(rt_msg)) => {
                        if let Some(candidate) = route_from_message(&rt_msg)? {
                            return Ok(candidate);
                        }
                    }
                    _ => {}
                }
                // Consecutive messages in a datagram start on 4-byte boundaries.
                offset += nlmsg_align(length);
            }
        }
    }

    /// Waits for the kernel's acknowledgement of a request sent with
    /// `NLM_F_ACK`.
    ///
    /// Reads and walks datagrams until a netlink error message arrives.
    /// - An error message without a code is the acknowledgement, giving
    ///   `Ok(())`.
    /// - An error message with a code becomes the matching OS error.
    /// - Any other message read in the meantime is skipped, for example a
    ///   multicast route notification when the socket has been subscribed.
    ///   Such a message must not be mistaken for the acknowledgement.
    ///
    /// The read blocks until the acknowledgement arrives, so this must only
    /// follow a request that asked for one.
    fn recv_ack(&mut self) -> io::Result<()> {
        // Same size as the route-response buffer, so a large notification
        // preceding the ack is not truncated into a parse error.
        let mut rbuf = [0u8; 16384];
        loop {
            let n = self.read(&mut rbuf)?;
            let mut offset = 0;
            while offset < n {
                let nlmsg = parse_nlmsg(&rbuf[offset..n])?;
                let length = nlmsg.header.length as usize;
                if length == 0 {
                    return Err(io::Error::other("zero-length netlink message"));
                }
                if let NetlinkPayload::Error(e) = nlmsg.payload {
                    return match e.code {
                        Some(code) => Err(netlink_error(code)),
                        None => Ok(()),
                    };
                }
                offset += nlmsg_align(length);
            }
        }
    }
}
/// Converts a netlink error code into an [`io::Error`] carrying the matching
/// OS error number.
///
/// Netlink reports errors as negated errno values, so the sign is dropped:
/// `-17` becomes OS error 17. A positive code is used as is.
fn netlink_error(code: NonZeroI32) -> io::Error {
    // `saturating_abs` rather than negation: `-i32::MIN` overflows and would
    // panic in debug builds on a malformed message.
    io::Error::from_raw_os_error(code.get().saturating_abs())
}
/// Builds the common skeleton of an add/delete route request.
///
/// The result is a [`RouteMessage`] whose header carries the destination's
/// address family and prefix length, plus a single `Destination` attribute.
/// Callers fill in the remaining header fields and attributes.
fn route_change_message(route: &Route) -> RouteMessage {
    let mut msg = RouteMessage::default();
    let header = &mut msg.header;
    header.address_family = address_family(route.destination);
    header.destination_prefix_length = route.prefix;
    let dst = route_address(route.destination);
    msg.attributes.push(RouteAttribute::Destination(dst));
    msg
}
fn route_lookup_message(route: &Route) -> RouteMessage {
let mut rt_msg = RouteMessage::default();
rt_msg.header.address_family = address_family(route.destination);
rt_msg.header.destination_prefix_length = route.prefix;
rt_msg
.attributes
.push(RouteAttribute::Destination(route_address(
route.destination,
)));
if let Some(index) = route.ifindex {
rt_msg.attributes.push(RouteAttribute::Oif(index));
}
rt_msg
}
fn route_from_message(rt_msg: &RouteMessage) -> io::Result<Option<Route>> {
let mut route = Route::new(
unspecified(rt_msg.header.address_family),
rt_msg.header.destination_prefix_length,
);
for attr in &rt_msg.attributes {
match attr {
RouteAttribute::Destination(addr) => {
let Some(destination) = ip_addr(addr) else {
return Ok(None);
};
route.destination = destination;
}
RouteAttribute::Gateway(addr) => {
let Some(gateway) = ip_addr(addr) else {
return Ok(None);
};
route.gateway = Some(gateway);
}
RouteAttribute::Oif(index) => route.ifindex = Some(*index),
_ => {}
}
}
Ok(Some(route))
}
fn parse_nlmsg(bytes: &[u8]) -> io::Result<NetlinkMessage<RouteNetlinkMessage>> {
<NetlinkMessage<RouteNetlinkMessage>>::deserialize(bytes)
.map_err(|e| io::Error::other(format!("{e:?}")))
}
/// Rounds `len` up to the next multiple of 4. This is the boundary on which
/// consecutive netlink messages in one datagram start.
fn nlmsg_align(len: usize) -> usize {
    const NLMSG_ALIGNTO: usize = 4;
    len.next_multiple_of(NLMSG_ALIGNTO)
}
/// Maps an IP address to the netlink address family of its version:
/// IPv4 gives `AddressFamily::Inet` and IPv6 gives `AddressFamily::Inet6`.
fn address_family(addr: IpAddr) -> AddressFamily {
    if addr.is_ipv4() {
        AddressFamily::Inet
    } else {
        AddressFamily::Inet6
    }
}
/// Returns the unspecified ("any") address for `family`: `::` for
/// `AddressFamily::Inet6`, and `0.0.0.0` for every other family.
fn unspecified(family: AddressFamily) -> IpAddr {
    if matches!(family, AddressFamily::Inet6) {
        IpAddr::from(std::net::Ipv6Addr::UNSPECIFIED)
    } else {
        IpAddr::from(std::net::Ipv4Addr::UNSPECIFIED)
    }
}
/// Wraps an IP address in the netlink route address variant of the same
/// version.
fn route_address(addr: IpAddr) -> RouteAddress {
    use std::net::IpAddr::{V4, V6};
    match addr {
        V4(v4) => RouteAddress::Inet(v4),
        V6(v6) => RouteAddress::Inet6(v6),
    }
}
/// Extracts an [`IpAddr`] from a netlink route address. Returns `None` for
/// every address kind other than IPv4/IPv6.
fn ip_addr(addr: &RouteAddress) -> Option<IpAddr> {
    let ip = match *addr {
        RouteAddress::Inet(v4) => IpAddr::V4(v4),
        RouteAddress::Inet6(v6) => IpAddr::V6(v6),
        _ => return None,
    };
    Some(ip)
}
/// Converts a netlink multicast group number into its `sockaddr_nl.nl_groups`
/// bit. Group `g` (1-based) maps to bit `g - 1`; group 0 maps to no bits.
///
/// # Panics
///
/// Panics for groups above 32. Those do not fit in the 32-bit mask and must
/// be joined with `setsockopt(NETLINK_ADD_MEMBERSHIP)` instead.
const fn nl_mgrp(group: u32) -> u32 {
    // Group 32 is bit 31, which still fits in the u32 mask.
    if group > 32 {
        panic!("Use setsockopt NETLINK_ADD_MEMBERSHIP for this group");
    }
    if group == 0 { 0 } else { 1 << (group - 1) }
}