use std::collections::HashSet;
use std::io;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
use ipnet::{Ipv4Net, Ipv6Net};
use smallvec_wrapper::SmallVec;
use windows_sys::Win32::NetworkManagement::IpHelper::*;
use windows_sys::Win32::Networking::WinSock::*;
use super::{sockaddr_to_ipaddr, IpRoute, Ipv4Route, Ipv6Route, NO_ERROR};
/// Win32 `ERROR_NOT_FOUND` (1168); compared against `io::Error::raw_os_error()`.
const ERROR_NOT_FOUND: i32 = 0x490;
/// Win32 `ERROR_NOT_SUPPORTED` (50); compared against `io::Error::raw_os_error()`.
const ERROR_NOT_SUPPORTED: i32 = 0x32;
/// Owning handle for a routing table allocated by `GetIpForwardTable2`.
///
/// The table memory is released with `FreeMibTable` when the handle is dropped.
struct ForwardTable {
    ptr: *const MIB_IPFORWARD_TABLE2,
}

impl ForwardTable {
    /// Retrieves the routing table for the given address `family`.
    ///
    /// # Errors
    ///
    /// Any status other than `NO_ERROR` is turned into an `io::Error` carrying the
    /// raw Win32 error code.
    fn fetch(family: u16) -> io::Result<Self> {
        let mut table: *mut MIB_IPFORWARD_TABLE2 = std::ptr::null_mut();
        // SAFETY: `table` is a valid out-pointer for the whole duration of the call.
        let status = unsafe { GetIpForwardTable2(family, &mut table) };
        if status == NO_ERROR {
            Ok(Self { ptr: table })
        } else {
            Err(io::Error::from_raw_os_error(status as i32))
        }
    }

    /// Borrows the table rows; yields an empty slice when no table is held.
    fn rows(&self) -> &[MIB_IPFORWARD_ROW2] {
        // SAFETY: a non-null `ptr` came from a successful `GetIpForwardTable2` call
        // and stays allocated until `self` is dropped.
        match unsafe { self.ptr.as_ref() } {
            None => &[],
            // SAFETY: the trailing `Table` array of the returned structure holds
            // `NumEntries` rows, and the slice borrows from `self`.
            Some(table) => unsafe {
                core::slice::from_raw_parts(table.Table.as_ptr(), table.NumEntries as usize)
            },
        }
    }
}

impl Drop for ForwardTable {
    fn drop(&mut self) {
        if self.ptr.is_null() {
            return;
        }
        // SAFETY: the pointer was allocated by `GetIpForwardTable2` and is freed
        // exactly once, here.
        unsafe { FreeMibTable(self.ptr as *mut _) };
    }
}
/// Builds the set of IPv4 directed-broadcast addresses configured on this host.
///
/// Every IPv4 row of `GetUnicastIpAddressTable` contributes one entry:
/// - the key is `(InterfaceIndex, broadcast)`;
/// - `broadcast` is the address with all host bits of its on-link prefix set.
///
/// Prefix lengths of 0 and of 31 or more are skipped; the latter have no directed broadcast.
/// `build_routev4` uses this set to drop gateway-less `/32` routes to these addresses.
///
/// This is best-effort: if the address table cannot be read, an empty set is returned.
///
/// # Safety
///
/// There are no caller-side preconditions. The function is `unsafe` only because it
/// wraps Win32 FFI calls and reads `SOCKADDR_INET` / `IN_ADDR` unions. Each such
/// operation is in its own `unsafe` block with its justification.
unsafe fn directed_broadcast_set() -> HashSet<(u32, Ipv4Addr)> {
    /// Owns a table returned by `GetUnicastIpAddressTable` and frees it on drop,
    /// so the memory is released on every exit path.
    struct UnicastTable(*mut MIB_UNICASTIPADDRESS_TABLE);

    impl Drop for UnicastTable {
        fn drop(&mut self) {
            // SAFETY: the pointer is non-null, was produced by a successful
            // `GetUnicastIpAddressTable` call, and is freed exactly once, here.
            unsafe { FreeMibTable(self.0 as *mut _) };
        }
    }

    /// Broadcast address of `addr` within its `/prefix` network.
    ///
    /// Returns `None` for `/0` (which would also overflow the shift) and for `/31`
    /// and longer prefixes, which have no directed broadcast.
    fn broadcast_of(addr: Ipv4Addr, prefix: u8) -> Option<Ipv4Addr> {
        if prefix == 0 || prefix >= 31 {
            return None;
        }
        // The host part is the low `32 - prefix` bits.
        let host_mask = u32::MAX >> prefix;
        Some(Ipv4Addr::from(u32::from(addr) | host_mask))
    }

    let mut out: HashSet<(u32, Ipv4Addr)> = HashSet::new();

    let mut ptr: *mut MIB_UNICASTIPADDRESS_TABLE = std::ptr::null_mut();
    // SAFETY: `ptr` is a valid out-pointer for the duration of the call.
    if unsafe { GetUnicastIpAddressTable(AF_INET, &mut ptr) } != NO_ERROR || ptr.is_null() {
        return out;
    }
    // From here on the table is freed when `table` goes out of scope.
    let table = UnicastTable(ptr);

    // SAFETY: on success the API returns a table whose trailing `Table` array holds
    // `NumEntries` rows; the slice is only used while `table` is alive.
    let rows = unsafe {
        let t = &*table.0;
        core::slice::from_raw_parts(
            &t.Table as *const _ as *const MIB_UNICASTIPADDRESS_ROW,
            t.NumEntries as usize,
        )
    };

    for row in rows {
        // SAFETY: `si_family` overlays the family field of every `SOCKADDR_INET` variant.
        if unsafe { row.Address.si_family } != AF_INET {
            continue;
        }
        // SAFETY: the family check above makes `Ipv4` the active variant.
        // `S_addr` stores the address in network byte order, so its in-memory
        // (native-endian) bytes are the octets in order.
        let octets = unsafe { row.Address.Ipv4.sin_addr.S_un.S_addr }.to_ne_bytes();
        if let Some(broadcast) = broadcast_of(Ipv4Addr::from(octets), row.OnLinkPrefixLength) {
            out.insert((row.InterfaceIndex, broadcast));
        }
    }
    out
}
/// Converts one IPv4 forwarding-table row into an `Ipv4Route`.
///
/// Returns `None` when any of these hold:
/// - the row's `ValidLifetime` is zero;
/// - the destination is not IPv4, or is multicast or the limited broadcast address;
/// - the row is a gateway-less `/32` route whose `(interface, destination)` pair is in
///   `broadcasts`;
/// - the prefix length is rejected by `Ipv4Net::new`.
///
/// An unspecified (`0.0.0.0`) next hop is reported as "no gateway".
#[inline]
fn build_routev4(
    row: &MIB_IPFORWARD_ROW2,
    broadcasts: &HashSet<(u32, Ipv4Addr)>,
) -> Option<Ipv4Route> {
    // Entries with a zero lifetime are not reported.
    if row.ValidLifetime == 0 {
        return None;
    }

    let dest_sockaddr = row.DestinationPrefix.Prefix;
    let destination =
        match sockaddr_to_ipaddr(AF_UNSPEC, &dest_sockaddr as *const _ as *const SOCKADDR)? {
            IpAddr::V4(addr) if !addr.is_multicast() && !addr.is_broadcast() => addr,
            _ => return None,
        };

    let gateway = sockaddr_to_ipaddr(AF_UNSPEC, &row.NextHop as *const _ as *const SOCKADDR)
        .and_then(|hop| match hop {
            IpAddr::V4(g) if !g.is_unspecified() => Some(g),
            _ => None,
        });

    let prefix_len = row.DestinationPrefix.PrefixLength;
    let to_directed_broadcast = prefix_len == 32
        && gateway.is_none()
        && broadcasts.contains(&(row.InterfaceIndex, destination));
    if to_directed_broadcast {
        return None;
    }

    Ipv4Net::new(destination, prefix_len)
        .ok()
        .map(|net| Ipv4Route::new(row.InterfaceIndex, net, gateway))
}
#[inline]
fn build_routev6(row: &MIB_IPFORWARD_ROW2) -> Option<Ipv6Route> {
if row.ValidLifetime == 0 {
return None;
}
let prefix = row.DestinationPrefix.Prefix;
let dst_ip = sockaddr_to_ipaddr(AF_UNSPEC, &prefix as *const _ as *const SOCKADDR)?;
let dst_v6 = match dst_ip {
IpAddr::V6(ip) => ip,
_ => return None,
};
if dst_v6.is_multicast() {
return None;
}
let net = Ipv6Net::new(dst_v6, row.DestinationPrefix.PrefixLength).ok()?;
let gw_ip = sockaddr_to_ipaddr(AF_UNSPEC, &row.NextHop as *const _ as *const SOCKADDR);
let gw = match gw_ip {
Some(IpAddr::V6(g)) if g != Ipv6Addr::UNSPECIFIED => Some(g),
_ => None,
};
Some(Ipv6Route::new(row.InterfaceIndex, net, gw))
}
/// Fetches the routing table for `family`.
///
/// Errors whose raw OS code is `ERROR_NOT_FOUND` or `ERROR_NOT_SUPPORTED` yield
/// `Ok(None)`, so callers simply skip that family.
///
/// # Errors
///
/// Any other error from `ForwardTable::fetch` is propagated unchanged.
fn fetch_family(family: u16) -> io::Result<Option<ForwardTable>> {
    ForwardTable::fetch(family)
        .map(Some)
        .or_else(|err| match err.raw_os_error() {
            Some(ERROR_NOT_FOUND | ERROR_NOT_SUPPORTED) => Ok(None),
            _ => Err(err),
        })
}
pub(crate) fn route_table_by_filter<F>(mut f: F) -> io::Result<SmallVec<IpRoute>>
where
F: FnMut(&IpRoute) -> bool,
{
let mut out: SmallVec<IpRoute> = SmallVec::new();
if let Some(table_v4) = fetch_family(AF_INET)? {
let broadcasts = unsafe { directed_broadcast_set() };
for row in table_v4.rows() {
if let Some(r) = build_routev4(row, &broadcasts) {
let r = IpRoute::V4(r);
if f(&r) {
out.push(r);
}
}
}
}
if let Some(table_v6) = fetch_family(AF_INET6)? {
for row in table_v6.rows() {
if let Some(r) = build_routev6(row) {
let r = IpRoute::V6(r);
if f(&r) {
out.push(r);
}
}
}
}
Ok(out)
}
/// Collects the IPv4 routes accepted by the predicate `f`.
///
/// Returns an empty collection when `fetch_family(AF_INET)` yields no table.
///
/// # Errors
///
/// Propagates any error `fetch_family` reports for `AF_INET`.
pub(crate) fn route_ipv4_table_by_filter<F>(mut f: F) -> io::Result<SmallVec<Ipv4Route>>
where
    F: FnMut(&Ipv4Route) -> bool,
{
    let mut routes: SmallVec<Ipv4Route> = SmallVec::new();
    let table = match fetch_family(AF_INET)? {
        Some(table) => table,
        None => return Ok(routes),
    };

    // SAFETY: `directed_broadcast_set` takes no arguments and has no caller
    // preconditions; it is `unsafe` only because of its internal FFI.
    let broadcasts = unsafe { directed_broadcast_set() };
    for route in table
        .rows()
        .iter()
        .filter_map(|row| build_routev4(row, &broadcasts))
    {
        if f(&route) {
            routes.push(route);
        }
    }
    Ok(routes)
}
pub(crate) fn route_ipv6_table_by_filter<F>(mut f: F) -> io::Result<SmallVec<Ipv6Route>>
where
F: FnMut(&Ipv6Route) -> bool,
{
let mut out: SmallVec<Ipv6Route> = SmallVec::new();
if let Some(table) = fetch_family(AF_INET6)? {
for row in table.rows() {
if let Some(r) = build_routev6(row) {
if f(&r) {
out.push(r);
}
}
}
}
Ok(out)
}