use std::{io, net::Ipv4Addr, ptr, slice::from_raw_parts};
use winapi::ctypes::*;
use winapi::{
shared::{
basetsd::SIZE_T,
minwindef::{DWORD, LPVOID, ULONG},
winerror::{ERROR_BUFFER_OVERFLOW, ERROR_SUCCESS},
ws2def::{AF_INET, SOCKADDR},
},
um::heapapi::{GetProcessHeap, HeapAlloc, HeapFree},
};
use crate::upnp::Network;
const MAX_ADAPTER_ADDRESS_LENGTH: usize = 8;
const WORKING_BUFFER_SIZEL: SIZE_T = 15000;
#[repr(C)]
struct LengthIfIndex {
length: ULONG,
ifindex: DWORD,
}
#[repr(C)]
struct LengthFlags {
length: ULONG,
flags: DWORD,
}
#[repr(C)]
struct SoketAddress {
lp_sockaddr: *mut SOCKADDR,
i_sockaddr_length: c_int,
}
#[repr(C)]
struct IpAdapterPrefix {
aol: LengthIfIndex,
next: *mut IpAdapterPrefix,
address: SoketAddress,
prefix_length: ULONG,
}
#[repr(C)]
struct IpAdapterUnicastAddress {
aol: LengthFlags,
next: *mut IpAdapterUnicastAddress,
address: SoketAddress,
prefix_origin: c_int,
suffix_origin: c_int,
dad_state: c_int,
valid_lifetime: ULONG,
preferred_lifetime: ULONG,
lease_lifetime: ULONG,
on_link_prefix_length: __uint8,
}
#[repr(C)]
struct IpAdapterAddresses {
aol: LengthIfIndex,
next: *mut IpAdapterAddresses,
adapter_name: *mut c_char,
first_unicass_address: *mut IpAdapterUnicastAddress,
first_anycass_address: *const c_void,
first_multicass_address: *const c_void,
first_dns_server_address: *const c_void,
dns_suffix: *mut wchar_t,
description: *mut wchar_t,
friendly_name: *mut wchar_t,
physical_address: [u8; MAX_ADAPTER_ADDRESS_LENGTH],
physical_address_length: DWORD,
flags: DWORD,
mtu: DWORD,
if_type: DWORD,
oper_status: c_int,
ipv6_if_index: DWORD,
zone_indices: [DWORD; 16],
first_prefix: *mut IpAdapterPrefix,
}
#[link(name = "iphlpapi")]
unsafe extern "system" {
fn GetAdaptersAddresses(
family: ULONG,
flags: ULONG,
reserved: *const c_void,
addresses: *mut IpAdapterAddresses,
size: *mut ULONG,
) -> ULONG;
}
pub fn get_local_net_state() -> io::Result<Vec<Network>> {
let mut new_size: ULONG = WORKING_BUFFER_SIZEL as ULONG;
let mut p_adapter: *mut IpAdapterAddresses;
loop {
unsafe {
p_adapter =
HeapAlloc(GetProcessHeap(), 0, WORKING_BUFFER_SIZEL) as *mut IpAdapterAddresses;
if p_adapter.is_null() {
return Err(io::Error::new(io::ErrorKind::Other, "Failed: malloc!"));
}
let res_code = GetAdaptersAddresses(
2, 0x0002 as ULONG, ptr::null(),
p_adapter,
&mut new_size as *mut ULONG,
);
match res_code {
ERROR_SUCCESS => break,
ERROR_BUFFER_OVERFLOW => {
new_size *= 2;
let res = HeapFree(GetProcessHeap(), 0, p_adapter as LPVOID);
if res == 0 {
return Err(io::Error::new(io::ErrorKind::Other, "Failed: HeapFree!"));
}
continue;
}
_ => {
return Err(io::Error::last_os_error());
}
}
}
}
let mut result = Vec::new();
unsafe {
let mut cur_p_adapter = p_adapter;
while !cur_p_adapter.is_null() {
let friendly_name = u16_array_to_string((*cur_p_adapter).friendly_name).to_lowercase();
if friendly_name.starts_with("loopback") || friendly_name.starts_with("docker") {
cur_p_adapter = (*cur_p_adapter).next;
continue;
}
let mut cur_p_addr = (*cur_p_adapter).first_unicass_address;
while !cur_p_addr.is_null() {
if let Some(address) = parse_addr((*cur_p_addr).address.lp_sockaddr) {
result.push(Network {
address,
net_mask: netmask_v4((*cur_p_addr).on_link_prefix_length)
.expect("Invalid subnet mask"),
});
}
cur_p_addr = (*cur_p_addr).next;
}
cur_p_adapter = (*cur_p_adapter).next;
}
}
unsafe {
let res = HeapFree(GetProcessHeap(), 0, p_adapter as LPVOID);
if res == 0 {
return Err(io::Error::new(io::ErrorKind::Other, "Failed: HeapFree!"));
}
}
Ok(result)
}
fn parse_addr(p_sock: *const SOCKADDR) -> Option<Ipv4Addr> {
if p_sock.is_null() {
return None;
}
let addr = unsafe { *p_sock };
match i32::from(addr.sa_family) {
AF_INET => Some(Ipv4Addr::new(
addr.sa_data[2] as u8,
addr.sa_data[3] as u8,
addr.sa_data[4] as u8,
addr.sa_data[5] as u8,
)),
_ => None,
}
}
fn u16_array_to_string(p_array: *const u16) -> String {
use std::char::{REPLACEMENT_CHARACTER, decode_utf16};
unsafe {
if p_array.is_null() {
return String::new();
}
let mut cur = 0usize;
while !p_array.add(cur).is_null() && *p_array.add(cur) != 0u16 {
cur += 1;
}
let u16s = from_raw_parts(p_array, cur);
decode_utf16(u16s.iter().cloned())
.map(|r| r.unwrap_or(REPLACEMENT_CHARACTER))
.collect::<String>()
}
}
fn netmask_v4(bits: u8) -> Option<Ipv4Addr> {
if bits <= 32 {
let mut i = (0..4u8).map(|idx| {
let idx8 = idx << 3;
match (bits > idx8, bits > idx8 + 8) {
(true, true) => 255,
(true, false) => 255u8.wrapping_shl(u32::from(8 - bits % 8)),
_ => 0,
}
});
Some(Ipv4Addr::new(
i.next().unwrap(),
i.next().unwrap(),
i.next().unwrap(),
i.next().unwrap(),
))
} else {
None
}
}
#[cfg(test)]
mod test {
use super::{Ipv4Addr, netmask_v4};
#[test]
fn netmask_v4_test() {
vec![
(0, "0.0.0.0"),
(1, "128.0.0.0"),
(2, "192.0.0.0"),
(3, "224.0.0.0"),
(4, "240.0.0.0"),
(5, "248.0.0.0"),
(6, "252.0.0.0"),
(7, "254.0.0.0"),
(8, "255.0.0.0"),
(9, "255.128.0.0"),
(10, "255.192.0.0"),
(11, "255.224.0.0"),
(12, "255.240.0.0"),
(13, "255.248.0.0"),
(14, "255.252.0.0"),
(15, "255.254.0.0"),
(16, "255.255.0.0"),
(17, "255.255.128.0"),
(18, "255.255.192.0"),
(19, "255.255.224.0"),
(20, "255.255.240.0"),
(21, "255.255.248.0"),
(22, "255.255.252.0"),
(23, "255.255.254.0"),
(24, "255.255.255.0"),
(25, "255.255.255.128"),
(26, "255.255.255.192"),
(27, "255.255.255.224"),
(28, "255.255.255.240"),
(29, "255.255.255.248"),
(30, "255.255.255.252"),
(31, "255.255.255.254"),
(32, "255.255.255.255"),
]
.into_iter()
.for_each(|(i, addr)| assert_eq!(netmask_v4(i), addr.parse::<Ipv4Addr>().ok()))
}
}