#![allow(unsafe_code, clippy::borrow_as_ptr)]
use std::net::IpAddr;
use async_trait::async_trait;
use tokio::task;
use windows::core::PCWSTR;
use windows::Win32::Foundation::{ERROR_OBJECT_ALREADY_EXISTS, NO_ERROR};
use windows::Win32::NetworkManagement::IpHelper::{
ConvertInterfaceAliasToLuid, CreateIpForwardEntry2, CreateUnicastIpAddressEntry,
InitializeIpForwardEntry, InitializeUnicastIpAddressEntry, MIB_IPFORWARD_ROW2,
MIB_UNICASTIPADDRESS_ROW,
};
use windows::Win32::NetworkManagement::Ndis::NET_LUID_LH;
use windows::Win32::Networking::WinSock::{
ADDRESS_FAMILY, AF_INET, AF_INET6, IN6_ADDR, IN6_ADDR_0, IN_ADDR, IN_ADDR_0, SOCKADDR_IN,
SOCKADDR_IN6, SOCKADDR_IN6_0, SOCKADDR_INET,
};
use crate::interface::InterfaceOps;
use crate::OverlayError;
/// Numeric value of `MIB_IPPROTO_NETMGMT`, written into the `Protocol` field of
/// routes created by this backend.
const MIB_IPPROTO_NETMGMT: i32 = 3;

/// Interface backend that configures adapters through the Windows IP Helper
/// API. The type carries no state of its own.
pub(crate) struct WindowsIpHelperOps;

impl WindowsIpHelperOps {
    /// Constructs the stateless IP Helper backend.
    pub(crate) fn new() -> Self {
        WindowsIpHelperOps
    }
}
#[async_trait]
impl InterfaceOps for WindowsIpHelperOps {
    /// Reports whether an adapter with this alias currently resolves to a LUID.
    ///
    /// Any lookup failure is reported as `Ok(false)` rather than an error.
    async fn link_exists(&self, name: &str) -> Result<bool, OverlayError> {
        let alias = name.to_owned();
        run_blocking(move || Ok(luid_for_name(&alias).is_ok())).await
    }

    /// No-op on this backend: always returns `Ok(())` without touching the
    /// adapter. NOTE(review): presumably adapter lifetime is managed elsewhere
    /// (e.g. by the Wintun driver) — confirm.
    async fn delete_link(&self, _name: &str) -> Result<(), OverlayError> {
        Ok(())
    }

    /// No-op on this backend: always returns `Ok(())` without changing the
    /// adapter's state.
    async fn set_link_up(&self, _name: &str) -> Result<(), OverlayError> {
        Ok(())
    }

    /// Assigns `addr/prefix_len` to the adapter named `name` on the blocking pool.
    async fn add_address(
        &self,
        name: &str,
        addr: IpAddr,
        prefix_len: u8,
    ) -> Result<(), OverlayError> {
        let alias = name.to_owned();
        run_blocking(move || add_address_blocking(&alias, addr, prefix_len)).await
    }

    /// Adds a route for `dest/prefix_len` through the adapter named `name` on
    /// the blocking pool.
    async fn add_route_via_dev(
        &self,
        dest: IpAddr,
        prefix_len: u8,
        name: &str,
    ) -> Result<(), OverlayError> {
        let alias = name.to_owned();
        run_blocking(move || add_route_blocking(&alias, dest, prefix_len)).await
    }
}

/// Runs a blocking IP Helper job on Tokio's blocking thread pool.
///
/// A failure to join the spawned task becomes `OverlayError::NetworkConfig`;
/// otherwise the job's own `Result` is returned unchanged.
async fn run_blocking<T, F>(job: F) -> Result<T, OverlayError>
where
    F: FnOnce() -> Result<T, OverlayError> + Send + 'static,
    T: Send + 'static,
{
    task::spawn_blocking(job)
        .await
        .map_err(|e| OverlayError::NetworkConfig(format!("join error: {e}")))?
}
/// Resolves an interface alias to its `NET_LUID` via `ConvertInterfaceAliasToLuid`.
///
/// # Errors
///
/// Returns `OverlayError::InterfaceNotFound` when:
/// - `name` contains an interior NUL. The Win32 call takes a NUL-terminated
///   wide string, so such a name would silently be truncated and could match
///   a different adapter.
/// - `ConvertInterfaceAliasToLuid` reports any non-`NO_ERROR` status.
fn luid_for_name(name: &str) -> Result<NET_LUID_LH, OverlayError> {
    if name.contains('\0') {
        return Err(OverlayError::InterfaceNotFound(name.to_string()));
    }
    let wide: Vec<u16> = name.encode_utf16().chain(std::iter::once(0)).collect();
    let mut luid = NET_LUID_LH::default();
    // SAFETY: `wide` is a NUL-terminated UTF-16 buffer that stays alive for the
    // whole call, and `luid` is a valid, exclusively borrowed output location.
    let rc = unsafe { ConvertInterfaceAliasToLuid(PCWSTR::from_raw(wide.as_ptr()), &mut luid) };
    if rc == NO_ERROR {
        Ok(luid)
    } else {
        Err(OverlayError::InterfaceNotFound(name.to_string()))
    }
}
/// Encodes `addr` as a Winsock `SOCKADDR_INET` with port 0.
///
/// The IPv6 form also has flow info 0 and scope id 0.
fn sockaddr_inet_from(addr: IpAddr) -> SOCKADDR_INET {
    // Start from the zeroed default so the bytes of the union that the active
    // member does not cover remain zero.
    let mut storage = SOCKADDR_INET::default();
    match addr {
        IpAddr::V4(ip) => {
            // Building the u32 with native byte order makes its in-memory bytes
            // exactly `ip.octets()`, i.e. the address in network order.
            let s_addr = u32::from_ne_bytes(ip.octets());
            let sin_addr = IN_ADDR {
                S_un: IN_ADDR_0 { S_addr: s_addr },
            };
            storage.Ipv4 = SOCKADDR_IN {
                sin_family: AF_INET,
                sin_port: 0,
                sin_addr,
                sin_zero: [0; 8],
            };
        }
        IpAddr::V6(ip) => {
            let sin6_addr = IN6_ADDR {
                u: IN6_ADDR_0 { Byte: ip.octets() },
            };
            storage.Ipv6 = SOCKADDR_IN6 {
                sin6_family: AF_INET6,
                sin6_port: 0,
                sin6_flowinfo: 0,
                sin6_addr,
                Anonymous: SOCKADDR_IN6_0 { sin6_scope_id: 0 },
            };
        }
    }
    storage
}
/// Returns the Winsock address family for `addr`: `AF_INET` for IPv4,
/// `AF_INET6` for IPv6.
fn address_family(addr: IpAddr) -> ADDRESS_FAMILY {
    if addr.is_ipv4() {
        AF_INET
    } else {
        AF_INET6
    }
}
/// Synchronously assigns `addr/prefix_len` as a unicast address on the
/// interface whose alias is `name`.
///
/// An `ERROR_OBJECT_ALREADY_EXISTS` result counts as success, so repeating the
/// call is harmless.
///
/// # Errors
///
/// - `InterfaceNotFound` if the alias cannot be resolved.
/// - `NetworkConfig` carrying the WIN32 status code if
///   `CreateUnicastIpAddressEntry` fails for any other reason.
fn add_address_blocking(name: &str, addr: IpAddr, prefix_len: u8) -> Result<(), OverlayError> {
    let luid = luid_for_name(name)?;

    let mut entry = MIB_UNICASTIPADDRESS_ROW::default();
    // SAFETY: `entry` is a valid, exclusively borrowed row for the initializer
    // to fill with defaults.
    unsafe { InitializeUnicastIpAddressEntry(&mut entry) };
    entry.InterfaceLuid = luid;
    entry.Address = sockaddr_inet_from(addr);
    entry.OnLinkPrefixLength = prefix_len;

    // SAFETY: `entry` is fully initialised and outlives the call.
    let status = unsafe { CreateUnicastIpAddressEntry(&entry) };
    let accepted = status == NO_ERROR || status == ERROR_OBJECT_ALREADY_EXISTS;
    if !accepted {
        return Err(OverlayError::NetworkConfig(format!(
            "CreateUnicastIpAddressEntry failed for {addr}/{prefix_len} on {name}: WIN32 error {}",
            status.0
        )));
    }
    Ok(())
}
/// Synchronously adds a route for `dest/prefix_len` through the interface whose
/// alias is `name`, with an unspecified (all-zero) next hop of the same family.
///
/// Host bits below `prefix_len` are cleared first, so `10.200.0.5/16` installs
/// the canonical `10.200.0.0/16` entry. An `ERROR_OBJECT_ALREADY_EXISTS` result
/// counts as success.
///
/// # Errors
///
/// - `InterfaceNotFound` if the alias cannot be resolved.
/// - `NetworkConfig` if `prefix_len` exceeds the address width (32 for IPv4,
///   128 for IPv6), or if `CreateIpForwardEntry2` fails for any other reason
///   (the message includes the WIN32 status code).
fn add_route_blocking(name: &str, dest: IpAddr, prefix_len: u8) -> Result<(), OverlayError> {
    let luid = luid_for_name(name)?;
    let network = network_prefix(dest, prefix_len).ok_or_else(|| {
        OverlayError::NetworkConfig(format!("invalid prefix length {prefix_len} for {dest}"))
    })?;

    let mut row = MIB_IPFORWARD_ROW2::default();
    // SAFETY: `row` is a valid, exclusively borrowed row for the initializer to
    // fill with defaults.
    unsafe { InitializeIpForwardEntry(&mut row) };
    row.InterfaceLuid = luid;
    row.DestinationPrefix.Prefix = sockaddr_inet_from(network);
    row.DestinationPrefix.PrefixLength = prefix_len;
    // Zeroed next hop carrying only the address family: no gateway address.
    row.NextHop = SOCKADDR_INET::default();
    row.NextHop.si_family = address_family(dest);
    row.Metric = 256;
    row.Protocol.0 = MIB_IPPROTO_NETMGMT;

    // SAFETY: `row` is fully initialised and outlives the call.
    let rc = unsafe { CreateIpForwardEntry2(&row) };
    if rc == NO_ERROR || rc == ERROR_OBJECT_ALREADY_EXISTS {
        Ok(())
    } else {
        Err(OverlayError::NetworkConfig(format!(
            "CreateIpForwardEntry2 failed for {dest}/{prefix_len} via {name}: WIN32 error {}",
            rc.0
        )))
    }
}

/// Clears the host bits of `addr`, keeping only its leading `prefix_len` bits.
///
/// Returns `None` when `prefix_len` is wider than the address (32 bits for
/// IPv4, 128 bits for IPv6).
fn network_prefix(addr: IpAddr, prefix_len: u8) -> Option<IpAddr> {
    let prefix = u32::from(prefix_len);
    match addr {
        IpAddr::V4(v4) => {
            let host_bits = 32u32.checked_sub(prefix)?;
            // A shift by the full width is `None`; a /0 prefix keeps no bits.
            let mask = u32::MAX.checked_shl(host_bits).unwrap_or(0);
            Some(IpAddr::V4(std::net::Ipv4Addr::from(u32::from(v4) & mask)))
        }
        IpAddr::V6(v6) => {
            let host_bits = 128u32.checked_sub(prefix)?;
            let mask = u128::MAX.checked_shl(host_bits).unwrap_or(0);
            Some(IpAddr::V6(std::net::Ipv6Addr::from(u128::from(v6) & mask)))
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{Ipv4Addr, Ipv6Addr};

    /// An IPv4 address is tagged `AF_INET` and its octets land in `S_addr` in order.
    #[test]
    fn sockaddr_inet_round_trips_ipv4() {
        let encoded = sockaddr_inet_from(IpAddr::from(Ipv4Addr::new(10, 200, 0, 5)));
        // SAFETY: `sockaddr_inet_from` zero-initialises the union and then writes
        // the IPv4 member, so both reads observe initialised bytes.
        let (family, s_addr) =
            unsafe { (encoded.si_family, encoded.Ipv4.sin_addr.S_un.S_addr) };
        assert_eq!(family, AF_INET);
        assert_eq!(s_addr.to_ne_bytes(), [10, 200, 0, 5]);
    }

    /// An IPv6 address is tagged `AF_INET6` and its 16 bytes are copied verbatim.
    #[test]
    fn sockaddr_inet_round_trips_ipv6() {
        let ip = Ipv6Addr::new(0xfd00, 0, 0, 0, 0, 0, 0, 1);
        let encoded = sockaddr_inet_from(IpAddr::from(ip));
        // SAFETY: the union was zero-initialised and the IPv6 member written.
        let (family, bytes) = unsafe { (encoded.si_family, encoded.Ipv6.sin6_addr.u.Byte) };
        assert_eq!(family, AF_INET6);
        assert_eq!(bytes, ip.octets());
    }

    #[test]
    fn address_family_matches() {
        let cases = [
            (IpAddr::from(Ipv4Addr::LOCALHOST), AF_INET),
            (IpAddr::from(Ipv6Addr::LOCALHOST), AF_INET6),
        ];
        for (addr, expected) in cases {
            assert_eq!(address_family(addr), expected);
        }
    }

    #[tokio::test]
    #[ignore = "Requires Administrator + Wintun adapter"]
    async fn add_address_requires_existing_adapter() {
        let missing = "this-adapter-definitely-does-not-exist";
        let target = IpAddr::from(Ipv4Addr::new(10, 200, 0, 1));
        let outcome = WindowsIpHelperOps::new()
            .add_address(missing, target, 24)
            .await;
        assert!(outcome.is_err());
    }
}