use crate::Error;
use std::net::{IpAddr, Ipv6Addr, SocketAddr};
use windows_sys::{
Win32::{
Foundation::{
ERROR_ADDRESS_NOT_ASSOCIATED, ERROR_BUFFER_OVERFLOW, ERROR_INSUFFICIENT_BUFFER, ERROR_INVALID_PARAMETER,
ERROR_NO_DATA, ERROR_NOT_ENOUGH_MEMORY, ERROR_SUCCESS, GetLastError, LocalFree, NO_ERROR, WIN32_ERROR,
},
NetworkManagement::{
IpHelper::{
DNS_INTERFACE_SETTINGS, DNS_INTERFACE_SETTINGS_VERSION1, DNS_SETTING_NAMESERVER,
GAA_FLAG_INCLUDE_GATEWAYS, GAA_FLAG_INCLUDE_PREFIX, GetAdaptersAddresses, GetInterfaceInfo,
GetIpInterfaceEntry, IF_TYPE_ETHERNET_CSMACD, IF_TYPE_IEEE80211, IP_ADAPTER_ADDRESSES_LH,
IP_ADAPTER_INDEX_MAP, IP_INTERFACE_INFO, MIB_IPINTERFACE_ROW, SetIpInterfaceEntry,
},
Ndis::{IfOperStatusUp, NET_LUID_LH},
},
Networking::WinSock::{AF_INET, AF_INET6, AF_UNSPEC, SOCKADDR, SOCKADDR_IN, SOCKADDR_IN6, SOCKET_ADDRESS},
System::{
Com::StringFromGUID2,
Diagnostics::Debug::{FORMAT_MESSAGE_ALLOCATE_BUFFER, FORMAT_MESSAGE_FROM_SYSTEM, FormatMessageW},
SystemServices::{LANG_NEUTRAL, SUBLANG_DEFAULT},
},
},
core::GUID,
};
pub fn get_wintun_bin_pattern_path() -> std::io::Result<std::path::PathBuf> {
let dll_path = if cfg!(target_arch = "x86") {
"wintun/bin/x86/wintun.dll"
} else if cfg!(target_arch = "x86_64") {
"wintun/bin/amd64/wintun.dll"
} else if cfg!(target_arch = "arm") {
"wintun/bin/arm/wintun.dll"
} else if cfg!(target_arch = "aarch64") {
"wintun/bin/arm64/wintun.dll"
} else {
return Err(std::io::Error::other("Unsupported architecture"));
};
Ok(dll_path.into())
}
crate::define_fn_dynamic_load!(
RtlGetNtVersionNumbersDeclare,
unsafe extern "system" fn(*mut u32, *mut u32, *mut u32),
RTL_GET_NT_VERSION_NUMBERS,
RtlGetNtVersionNumbers,
"ntdll",
"RtlGetNtVersionNumbers"
);
pub(crate) fn get_windows_version() -> Result<(u32, u32, u32), Error> {
let func = RtlGetNtVersionNumbers().ok_or("Failed to load function RtlGetNtVersionNumbers")?;
let (mut major, mut minor, mut build) = (0, 0, 0);
unsafe { func(&mut major, &mut minor, &mut build) };
Ok((major, minor, build))
}
pub(crate) const fn win_guid_to_u128(guid: &GUID) -> u128 {
let data4_u64 = u64::from_be_bytes(guid.data4);
((guid.data1 as u128) << 96) | ((guid.data2 as u128) << 80) | ((guid.data3 as u128) << 64) | (data4_u64 as u128)
}
pub(crate) unsafe fn win_pstr_to_string(pstr: ::windows_sys::core::PSTR) -> Result<String, Error> {
Ok(unsafe {
std::ffi::CStr::from_ptr(pstr as *const std::ffi::c_char)
.to_str()
.map_err(|e| format!("Invalid UTF-8 sequence: {e}"))
}?
.to_owned())
}
pub(crate) unsafe fn win_pwstr_to_string(pwstr: ::windows_sys::core::PWSTR) -> Result<String, Error> {
if pwstr.is_null() {
return Err("Null pointer received".into());
}
let mut len = 0;
while unsafe { *pwstr.add(len) } != 0 {
len += 1;
}
let slice = unsafe { std::slice::from_raw_parts(pwstr, len) };
use std::os::windows::ffi::OsStringExt;
let os_string = std::ffi::OsString::from_wide(slice);
os_string
.into_string()
.map_err(|e| format!("Invalid UTF-8 sequence: {:?}", e).into())
}
pub(crate) fn guid_to_win_style_string(guid: &GUID) -> Result<String, Error> {
let mut buffer = [0u16; 40];
unsafe { StringFromGUID2(guid, &mut buffer as *mut u16, buffer.len() as i32) };
let guid = unsafe { win_pwstr_to_string(buffer.as_ptr() as _)? };
Ok(guid)
}
pub(crate) fn ipv6_netmask_for_prefix(prefix: u8) -> Result<Ipv6Addr, &'static str> {
if prefix > 128 {
return Err("Prefix value must be between 0 and 128.");
}
let mut mask: [u16; 8] = [0; 8];
let mut i = 0;
let mut remaining = prefix;
while remaining >= 16 {
mask[i] = 0xffff;
remaining -= 16;
i += 1;
}
if remaining > 0 {
mask[i] = 0xffff << (16 - remaining);
}
Ok(Ipv6Addr::new(
mask[0], mask[1], mask[2], mask[3], mask[4], mask[5], mask[6], mask[7],
))
}
pub fn get_active_network_interface_gateways() -> std::io::Result<Vec<IpAddr>> {
let mut addrs = vec![];
get_adapters_addresses(|adapter| {
if adapter.OperStatus == IfOperStatusUp
&& [IF_TYPE_IEEE80211, IF_TYPE_ETHERNET_CSMACD].contains(&adapter.IfType)
{
let mut current_gateway = adapter.FirstGatewayAddress;
while !current_gateway.is_null() {
let gateway = unsafe { &*current_gateway };
{
let sockaddr_ptr = gateway.Address.lpSockaddr;
let sockaddr = unsafe { &*(sockaddr_ptr as *const SOCKADDR) };
match unsafe { sockaddr_to_socket_addr(sockaddr) } {
Ok(a) => addrs.push(a.ip()),
Err(e) => {
log::error!("Failed to convert sockaddr to socket address: {}", e);
return false;
}
}
}
current_gateway = gateway.Next;
}
}
true
})?;
Ok(addrs)
}
crate::define_fn_dynamic_load!(
SetInterfaceDnsSettingsDeclare,
unsafe extern "system" fn(GUID, *const DNS_INTERFACE_SETTINGS) -> WIN32_ERROR,
SET_INTERFACE_DNS_SETTINGS,
SetInterfaceDnsSettings,
"iphlpapi.dll",
"SetInterfaceDnsSettings"
);
pub(crate) fn set_interface_dns_servers(interface: GUID, dns: &[IpAddr]) -> crate::Result<()> {
let func = SetInterfaceDnsSettings().ok_or("Failed to load function SetInterfaceDnsSettings")?;
let dns = dns.iter().map(|ip| ip.to_string()).collect::<Vec<_>>().join(",");
let dns = dns.encode_utf16().chain(std::iter::once(0)).collect::<Vec<_>>();
let settings = DNS_INTERFACE_SETTINGS {
Version: DNS_INTERFACE_SETTINGS_VERSION1,
Flags: DNS_SETTING_NAMESERVER as _,
NameServer: dns.as_ptr() as _,
Domain: std::ptr::null_mut(),
SearchList: std::ptr::null_mut(),
RegistrationEnabled: 0,
RegisterAdapterName: 0,
EnableLLMNR: 0,
QueryAdapterName: 0,
ProfileNameServer: std::ptr::null_mut(),
};
match unsafe { func(interface, &settings as *const _) } {
0 => Ok(()),
e => Err(std::io::Error::from_raw_os_error(e as i32).into()),
}
}
pub(crate) fn set_interface_dns_servers_via_cmd(adapter: &str, dns: &[IpAddr]) -> crate::Result<()> {
if dns.is_empty() {
return Ok(());
}
let ip_str = if dns[0].is_ipv4() { "ipv4" } else { "ipv6" };
let name = format!("name=\"{}\"", adapter);
let addr = format!("address=\"{}\"", dns[0]);
let args = vec!["interface", ip_str, "set", "dns", &name, "source=\"static\"", &addr];
run_command("netsh", &args)?;
for (index, dns) in (2..).zip(dns.iter().skip(1)) {
let ip_str = if dns.is_ipv4() { "ipv4" } else { "ipv6" };
let addr = format!("address=\"{}\"", dns);
let idx = format!("index={}", index);
let args = vec!["interface", ip_str, "add", "dns", &name, &idx, &addr];
run_command("netsh", &args)?;
}
Ok(())
}
pub(crate) fn retrieve_ipaddr_from_socket_address(address: &SOCKET_ADDRESS) -> Result<IpAddr, Error> {
unsafe { Ok(sockaddr_to_socket_addr(address.lpSockaddr)?.ip()) }
}
pub(crate) unsafe fn sockaddr_to_socket_addr(sock_addr: *const SOCKADDR) -> std::io::Result<SocketAddr> {
let address = match (unsafe { *sock_addr }).sa_family {
AF_INET => unsafe { sockaddr_in_to_socket_addr(&*(sock_addr as *const SOCKADDR_IN)) },
AF_INET6 => unsafe { sockaddr_in6_to_socket_addr(&*(sock_addr as *const SOCKADDR_IN6)) },
_ => return Err(std::io::Error::other("Unsupported address type")),
};
Ok(address)
}
pub(crate) unsafe fn sockaddr_in_to_socket_addr(sockaddr_in: &SOCKADDR_IN) -> SocketAddr {
let ip_bytes = unsafe { sockaddr_in.sin_addr.S_un.S_addr.to_ne_bytes() };
let ip = std::net::IpAddr::from(ip_bytes);
let port = u16::from_be(sockaddr_in.sin_port);
SocketAddr::new(ip, port)
}
pub(crate) unsafe fn sockaddr_in6_to_socket_addr(sockaddr_in6: &SOCKADDR_IN6) -> SocketAddr {
let ip = std::net::IpAddr::from(unsafe { sockaddr_in6.sin6_addr.u.Byte });
let port = u16::from_be(sockaddr_in6.sin6_port);
SocketAddr::new(ip, port)
}
pub(crate) fn get_adapters_addresses<F>(mut callback: F) -> Result<(), Error>
where
F: FnMut(IP_ADAPTER_ADDRESSES_LH) -> bool,
{
let mut size = 0;
let flags = GAA_FLAG_INCLUDE_PREFIX | GAA_FLAG_INCLUDE_GATEWAYS;
let family = AF_UNSPEC as u32;
let result = unsafe { GetAdaptersAddresses(family, flags, std::ptr::null_mut(), std::ptr::null_mut(), &mut size) };
if result != ERROR_BUFFER_OVERFLOW {
return Err(format!("GetAdaptersAddresses first attemp failed: {}", format_message(result)?).into());
}
let mut addresses: Vec<u8> = vec![0; (size + 4) as usize];
let result = unsafe {
let addr = addresses.as_mut_ptr() as *mut IP_ADAPTER_ADDRESSES_LH;
GetAdaptersAddresses(family, flags, std::ptr::null_mut(), addr, &mut size)
};
if ERROR_SUCCESS != result {
let err_msg = match result {
ERROR_ADDRESS_NOT_ASSOCIATED => "ERROR_ADDRESS_NOT_ASSOCIATED".into(),
ERROR_BUFFER_OVERFLOW => "ERROR_BUFFER_OVERFLOW".into(),
ERROR_INVALID_PARAMETER => "ERROR_INVALID_PARAMETER".into(),
ERROR_NOT_ENOUGH_MEMORY => "ERROR_NOT_ENOUGH_MEMORY".into(),
ERROR_NO_DATA => "ERROR_NO_DATA".into(),
_ => format_message(result)?,
};
return Err(format!("GetAdaptersAddresses second attemp failed: {err_msg}").into());
}
let mut current_addresses = addresses.as_ptr() as *const IP_ADAPTER_ADDRESSES_LH;
while !current_addresses.is_null() {
unsafe {
if !callback(*current_addresses) {
break;
}
current_addresses = (*current_addresses).Next;
}
}
Ok(())
}
fn get_interface_info_sys<F>(mut callback: F) -> Result<(), Error>
where
F: FnMut(IP_ADAPTER_INDEX_MAP) -> bool,
{
let mut buf_len: u32 = 0;
let result = unsafe { GetInterfaceInfo(std::ptr::null_mut(), &mut buf_len as *mut u32) };
if result != NO_ERROR && result != ERROR_INSUFFICIENT_BUFFER {
let err_msg = format_message(result).map_err(Error::from)?;
log::error!("Failed to get interface info: {}", err_msg);
return Err(format!("GetInterfaceInfo failed: {}", err_msg).into());
}
let buf_elements = buf_len as usize / std::mem::size_of::<u32>() + 1;
let mut buf: Vec<u32> = vec![0; buf_elements];
let buf_bytes = buf.len() * std::mem::size_of::<u32>();
assert!(buf_bytes >= buf_len as usize);
let mut final_buf_len: u32 = buf_len;
let result = unsafe {
GetInterfaceInfo(
buf.as_mut_ptr() as *mut IP_INTERFACE_INFO,
&mut final_buf_len as *mut u32,
)
};
if result != NO_ERROR {
let err_msg = format_message(result).map_err(Error::from)?;
log::error!(
"Failed to get interface info a second time: {}. Original len: {}, final len: {}",
err_msg,
buf_len,
final_buf_len
);
return Err(format!("GetInterfaceInfo failed a second time: {}", err_msg).into());
}
let info = buf.as_mut_ptr() as *const IP_INTERFACE_INFO;
let adapter_base = unsafe { &*info };
let adapter_count = adapter_base.NumAdapters;
let first_adapter = &adapter_base.Adapter as *const IP_ADAPTER_INDEX_MAP;
let interfaces = unsafe { std::slice::from_raw_parts(first_adapter, adapter_count as usize) };
for interface in interfaces {
if !callback(*interface) {
break;
}
}
Ok(())
}
#[allow(dead_code)]
pub(crate) fn get_interface_info() -> Result<Vec<(u32, String)>, Error> {
let mut v = vec![];
get_interface_info_sys(|mut interface| {
let name = match unsafe { win_pwstr_to_string(&mut interface.Name as _) } {
Ok(name) => name,
Err(e) => {
log::error!("Failed to convert interface name: {}", e);
return false;
}
};
match name.split('{').nth(1).and_then(|s| s.split('}').next()) {
Some(guid) => v.push((interface.Index, guid.to_string())),
None => {
log::error!("Failed to extract GUID from interface name: {}", name);
return false;
}
}
true
})?;
Ok(v)
}
#[allow(non_snake_case)]
#[inline]
fn MAKELANGID(p: u32, s: u32) -> u32 {
((s & 0x0000ffff) << 10) | (p & 0x0000ffff)
}
pub fn format_message(error_code: u32) -> std::io::Result<String> {
let buf: *mut u16 = std::ptr::null_mut();
let chars_written = unsafe {
FormatMessageW(
FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_ALLOCATE_BUFFER,
std::ptr::null_mut(),
error_code,
MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT),
&buf as *const windows_sys::core::PWSTR as *mut u16,
0,
std::ptr::null_mut(),
)
};
if chars_written == 0 {
return get_last_error();
}
let result = unsafe { win_pwstr_to_string(buf)? };
if unsafe { !LocalFree(buf as *mut _).is_null() } {
log::trace!("LocalFree failed: {:?}", get_last_error());
}
Ok(result)
}
pub(crate) fn get_last_error() -> std::io::Result<String> {
get_os_error_from_id(unsafe { GetLastError() as _ })?;
Ok("No error".to_string())
}
pub(crate) fn get_os_error_from_id(id: i32) -> std::io::Result<()> {
match id {
0 => Ok(()),
e => Err(std::io::Error::from_raw_os_error(e)),
}
}
pub fn set_adapter_mtu(luid: &NET_LUID_LH, mtu: usize, is_ipv6: bool) -> std::io::Result<()> {
let mut row = get_ip_interface(luid, is_ipv6)?;
row.SitePrefixLength = 0;
row.NlMtu = mtu as u32;
let status = unsafe { SetIpInterfaceEntry(&mut row) };
if status != NO_ERROR {
return Err(std::io::Error::from_raw_os_error(status as i32));
}
Ok(())
}
pub(crate) fn get_adapter_mtu(luid: &NET_LUID_LH, is_ipv6: bool) -> std::io::Result<u32> {
get_ip_interface(luid, is_ipv6).map(|row| row.NlMtu)
}
fn get_ip_interface(luid: &NET_LUID_LH, is_ipv6: bool) -> std::io::Result<MIB_IPINTERFACE_ROW> {
let mut row = MIB_IPINTERFACE_ROW {
Family: if is_ipv6 { AF_INET6 } else { AF_INET },
InterfaceLuid: *luid,
..Default::default()
};
let status = unsafe { GetIpInterfaceEntry(&mut row) };
if status != NO_ERROR {
return Err(std::io::Error::from_raw_os_error(status as i32));
}
Ok(row)
}
pub fn run_command(command: &str, args: &[&str]) -> std::io::Result<Vec<u8>> {
let full_cmd = format!("{} {}", command, args.join(" "));
log::debug!("Running command: \"{full_cmd}\"...");
let out = match std::process::Command::new(command).args(args).output() {
Ok(out) => out,
Err(e) => {
let e2 = e.to_string().trim().to_string();
log::error!("Run command: \"{full_cmd}\" failed with: \"{e2}\"");
return Err(e);
}
};
if !out.status.success() {
let err = String::from_utf8_lossy(if out.stderr.is_empty() {
&out.stdout
} else {
&out.stderr
});
let info = format!("Run command: \"{full_cmd}\" not success with \"{}\"", err.trim());
log::error!("{}", info);
return Err(std::io::Error::other(info));
}
Ok(out.stdout)
}
pub fn decode_utf16(string: &[u16]) -> String {
let end = string.iter().position(|b| *b == 0).unwrap_or(string.len());
String::from_utf16_lossy(&string[..end])
}
#[repr(C, align(1))]
#[derive(c2rust_bitfields::BitfieldStruct)]
#[allow(non_snake_case)]
#[allow(non_camel_case_types)]
struct _NET_LUID_LH_INFO {
#[bitfield(name = "Reserved", ty = "u64", bits = "0..=23")]
#[bitfield(name = "NetLuidIndex", ty = "u64", bits = "24..=47")]
#[bitfield(name = "IfType", ty = "u64", bits = "48..=63")]
_Value: [u8; 8],
}