use crate::network::{AdapterKind, AdapterSnapshot, AddressFetcher, FetchError};
use std::net::{Ipv4Addr, Ipv6Addr};
use windows::Win32::Foundation::WIN32_ERROR;
use windows::Win32::NetworkManagement::IpHelper::{
GAA_FLAG_SKIP_ANYCAST, GAA_FLAG_SKIP_DNS_SERVER, GAA_FLAG_SKIP_MULTICAST, GetAdaptersAddresses,
IF_TYPE_ETHERNET_CSMACD, IF_TYPE_IEEE80211, IF_TYPE_SOFTWARE_LOOPBACK, IP_ADAPTER_ADDRESSES_LH,
};
use windows::Win32::Networking::WinSock::{
AF_INET, AF_INET6, AF_UNSPEC, SOCKADDR_IN, SOCKADDR_IN6,
};
/// `IfType` code for PPP interfaces; classified as `AdapterKind::Virtual`.
const IF_TYPE_PPP: u32 = 23;
/// `IfType` code for tunnel interfaces; classified as `AdapterKind::Virtual`.
const IF_TYPE_TUNNEL: u32 = 131;
/// Byte size of the first buffer handed to `GetAdaptersAddresses` (16 KiB).
const INITIAL_BUFFER_SIZE: u32 = 16 * 1024;
/// Adapter/address fetcher for Windows hosts.
///
/// The type carries no state. Its single private unit field only prevents
/// construction via a struct literal outside this module, so callers must use
/// [`WindowsFetcher::new`] or `Default`.
#[derive(Clone, Debug, Default)]
pub struct WindowsFetcher {
    _private: (),
}

impl WindowsFetcher {
    /// Creates a fetcher. Being `const`, it can also initialise statics.
    #[must_use]
    pub const fn new() -> Self {
        Self { _private: () }
    }
}
impl AddressFetcher for WindowsFetcher {
fn fetch(&self) -> Result<Vec<AdapterSnapshot>, FetchError> {
fetch_adapters()
}
}
fn fetch_adapters() -> Result<Vec<AdapterSnapshot>, FetchError> {
let raw_adapters = get_adapter_addresses()?;
let mut adapters = Vec::new();
#[allow(clippy::cast_ptr_alignment)]
let mut current = raw_adapters.as_ptr().cast::<IP_ADAPTER_ADDRESSES_LH>();
while !current.is_null() {
let adapter = unsafe { &*current };
if let Some(snapshot) = parse_adapter(adapter) {
adapters.push(snapshot);
}
current = adapter.Next;
}
Ok(adapters)
}
fn get_adapter_addresses() -> Result<Vec<u8>, FetchError> {
let flags = GAA_FLAG_SKIP_ANYCAST | GAA_FLAG_SKIP_MULTICAST | GAA_FLAG_SKIP_DNS_SERVER;
let family = u32::from(AF_UNSPEC.0);
let mut buffer: Vec<u8> = vec![0u8; INITIAL_BUFFER_SIZE as usize];
let mut size = INITIAL_BUFFER_SIZE;
let result = unsafe {
GetAdaptersAddresses(
family,
flags,
None,
Some(buffer.as_mut_ptr().cast()),
&raw mut size,
)
};
handle_api_result(result, &mut buffer, &mut size, flags, family)?;
Ok(buffer)
}
#[cfg(not(tarpaulin_include))]
fn handle_api_result(
result: u32,
buffer: &mut Vec<u8>,
size: &mut u32,
flags: windows::Win32::NetworkManagement::IpHelper::GET_ADAPTERS_ADDRESSES_FLAGS,
family: u32,
) -> Result<(), FetchError> {
use windows::Win32::Foundation::{ERROR_BUFFER_OVERFLOW, NO_ERROR};
if result == ERROR_BUFFER_OVERFLOW.0 {
buffer.resize(*size as usize, 0);
let result = unsafe {
GetAdaptersAddresses(
family,
flags,
None,
Some(buffer.as_mut_ptr().cast()),
&raw mut *size,
)
};
if result != NO_ERROR.0 {
return Err(windows::core::Error::from(WIN32_ERROR(result)).into());
}
} else if result != NO_ERROR.0 {
return Err(windows::core::Error::from(WIN32_ERROR(result)).into());
}
Ok(())
}
/// Converts one `IP_ADAPTER_ADDRESSES_LH` record into an `AdapterSnapshot`.
///
/// The record's `FriendlyName` becomes the snapshot name, `IfType` is mapped via
/// `map_adapter_type`, and the unicast addresses come from `collect_addresses`.
///
/// Returns `None` only when `FriendlyName` is null. A name that is not valid
/// UTF-16 (e.g. has unpaired surrogates) is decoded lossily: offending code
/// units become U+FFFD. Such an adapter is kept rather than silently dropped
/// together with all of its addresses.
fn parse_adapter(adapter: &IP_ADAPTER_ADDRESSES_LH) -> Option<AdapterSnapshot> {
    // Reading a null PWSTR would run `wcslen` on a null pointer.
    if adapter.FriendlyName.is_null() {
        return None;
    }
    // SAFETY: checked non-null above. Presumably this is the NUL-terminated
    // wide string the API stores alongside the record, valid while `adapter`
    // is borrowed.
    let name = String::from_utf16_lossy(unsafe { adapter.FriendlyName.as_wide() });
    let kind = map_adapter_type(adapter.IfType);
    let (ipv4_addresses, ipv6_addresses) = collect_addresses(adapter);
    Some(AdapterSnapshot::new(
        name,
        kind,
        ipv4_addresses,
        ipv6_addresses,
    ))
}
/// Classifies a raw `IfType` interface-type code into an `AdapterKind`.
///
/// Loopback, Ethernet (CSMA/CD) and IEEE 802.11 map to their own kinds. PPP and
/// tunnel interfaces are grouped as `Virtual`. Every other code is preserved
/// verbatim in `AdapterKind::Other`.
const fn map_adapter_type(if_type: u32) -> AdapterKind {
    match if_type {
        IF_TYPE_SOFTWARE_LOOPBACK => AdapterKind::Loopback,
        IF_TYPE_ETHERNET_CSMACD => AdapterKind::Ethernet,
        IF_TYPE_IEEE80211 => AdapterKind::Wireless,
        IF_TYPE_PPP | IF_TYPE_TUNNEL => AdapterKind::Virtual,
        unrecognised => AdapterKind::Other(unrecognised),
    }
}
#[allow(clippy::cast_ptr_alignment)]
fn collect_addresses(adapter: &IP_ADAPTER_ADDRESSES_LH) -> (Vec<Ipv4Addr>, Vec<Ipv6Addr>) {
let mut ipv4_addresses = Vec::new();
let mut ipv6_addresses = Vec::new();
let mut unicast = adapter.FirstUnicastAddress;
while !unicast.is_null() {
let addr_entry = unsafe { &*unicast };
if let Some(sockaddr) = unsafe { addr_entry.Address.lpSockaddr.as_ref() } {
match sockaddr.sa_family {
f if f == AF_INET => {
let sockaddr_in =
unsafe { &*(std::ptr::from_ref(sockaddr).cast::<SOCKADDR_IN>()) };
let octets = unsafe { sockaddr_in.sin_addr.S_un.S_un_b };
let addr = Ipv4Addr::new(octets.s_b1, octets.s_b2, octets.s_b3, octets.s_b4);
ipv4_addresses.push(addr);
}
f if f == AF_INET6 => {
let sockaddr_in6 =
unsafe { &*(std::ptr::from_ref(sockaddr).cast::<SOCKADDR_IN6>()) };
let octets = unsafe { sockaddr_in6.sin6_addr.u.Byte };
let addr = Ipv6Addr::from(octets);
ipv6_addresses.push(addr);
}
_ => {}
}
}
unicast = unsafe { (*unicast).Next };
}
(ipv4_addresses, ipv6_addresses)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn map_adapter_type_ethernet() {
        assert_eq!(
            map_adapter_type(IF_TYPE_ETHERNET_CSMACD),
            AdapterKind::Ethernet
        );
    }

    #[test]
    fn map_adapter_type_wireless() {
        assert_eq!(map_adapter_type(IF_TYPE_IEEE80211), AdapterKind::Wireless);
    }

    #[test]
    fn map_adapter_type_loopback() {
        assert_eq!(
            map_adapter_type(IF_TYPE_SOFTWARE_LOOPBACK),
            AdapterKind::Loopback
        );
    }

    #[test]
    fn map_adapter_type_tunnel_is_virtual() {
        assert_eq!(map_adapter_type(IF_TYPE_TUNNEL), AdapterKind::Virtual);
    }

    #[test]
    fn map_adapter_type_ppp_is_virtual() {
        assert_eq!(map_adapter_type(IF_TYPE_PPP), AdapterKind::Virtual);
    }

    #[test]
    fn map_adapter_type_unknown_preserves_code() {
        assert_eq!(map_adapter_type(999), AdapterKind::Other(999));
    }

    #[test]
    fn windows_fetcher_new_creates_instance() {
        let _fetcher = WindowsFetcher::new();
    }

    #[test]
    fn windows_fetcher_default_creates_instance() {
        let _fetcher = WindowsFetcher::default();
    }

    #[test]
    fn windows_fetcher_new_matches_default() {
        // The fetcher holds no state, so both constructors must produce the same value.
        assert_eq!(
            format!("{:?}", WindowsFetcher::new()),
            format!("{:?}", WindowsFetcher::default())
        );
    }

    // Live-system test: enumerates the real adapters of the machine running it.
    #[test]
    fn fetch_adapters_returns_at_least_loopback() {
        let adapters = WindowsFetcher::new()
            .fetch()
            .unwrap_or_else(|err| panic!("fetch() failed: {err:?}"));
        let has_loopback_addr = adapters.iter().any(|a| {
            a.ipv4_addresses.contains(&Ipv4Addr::LOCALHOST)
                || a.ipv6_addresses.contains(&Ipv6Addr::LOCALHOST)
        });
        assert!(
            has_loopback_addr,
            "Expected at least loopback address, got adapters: {adapters:?}"
        );
    }

    // Live-system test: enumerates the real adapters of the machine running it.
    #[test]
    fn fetch_adapters_names_are_not_empty() {
        let adapters = WindowsFetcher::new().fetch().expect("fetch() failed");
        for adapter in &adapters {
            assert!(
                !adapter.name.is_empty(),
                "Adapter name should not be empty: {adapter:?}"
            );
        }
    }
}