#![cfg(windows)]
use std::env;
use std::ffi::OsString;
use std::os::windows::ffi::OsStringExt;
use windows::core::PWSTR;
use windows::Win32::Foundation::{ERROR_BUFFER_OVERFLOW, ERROR_NO_DATA, NO_ERROR, WIN32_ERROR};
use windows::Win32::NetworkManagement::IpHelper::{
GetAdaptersAddresses, GAA_FLAG_INCLUDE_GATEWAYS, GET_ADAPTERS_ADDRESSES_FLAGS,
IF_TYPE_SOFTWARE_LOOPBACK, IP_ADAPTER_ADDRESSES_LH,
};
use windows::Win32::NetworkManagement::Ndis::IfOperStatusUp;
use windows::Win32::Networking::WinSock::AF_INET;
use crate::error::{HnsError, HnsResult};
/// Name of the environment variable that overrides uplink adapter auto-detection.
///
/// When this variable holds a value that is non-empty after trimming surrounding
/// whitespace, [`find_primary_adapter`] returns that trimmed value directly instead
/// of querying the OS adapter list.
pub const ZLAYER_UPLINK_ENV: &str = "ZLAYER_HCN_UPLINK_ADAPTER";
/// Returns the friendly name of the network adapter to use as the HCN uplink.
///
/// Resolution order:
/// 1. If [`ZLAYER_UPLINK_ENV`] is set to a value that is non-blank after trimming,
///    the trimmed value is returned as-is and the OS is never queried.
/// 2. Otherwise the IPv4 adapter list from `GetAdaptersAddresses` is walked in the
///    order the OS returns it. The first adapter that passes `adapter_is_candidate`
///    and has a non-empty friendly name is returned.
///
/// # Errors
///
/// Propagates any error from the adapter query. Returns `HnsError::Other` with
/// `hresult: 0` when the query produced no adapters at all, or when no adapter
/// qualified.
pub fn find_primary_adapter() -> HnsResult<String> {
    let override_value = env::var(ZLAYER_UPLINK_ENV).ok();
    if let Some(name) = override_value
        .as_deref()
        .map(str::trim)
        .filter(|candidate| !candidate.is_empty())
    {
        return Ok(name.to_string());
    }

    // `adapters` owns every record and string the raw pointers below refer to;
    // it must stay alive until the walk is finished.
    let adapters = query_adapters_ipv4_with_gateways()?;
    let mut current: *const IP_ADAPTER_ADDRESSES_LH =
        adapters.first_adapter().ok_or_else(|| HnsError::Other {
            hresult: 0,
            message: format!(
                "GetAdaptersAddresses returned no adapters; set {ZLAYER_UPLINK_ENV} to override"
            ),
        })?;

    // SAFETY: `current` is either null (end of list) or points at an adapter record
    // inside `adapters`, which outlives this loop.
    while let Some(adapter) = unsafe { current.as_ref() } {
        if adapter_is_candidate(adapter) {
            // SAFETY: `FriendlyName` is null or a NUL-terminated wide string stored
            // inside `adapters`.
            let friendly = unsafe { pwstr_to_string(adapter.FriendlyName) };
            if !friendly.is_empty() {
                return Ok(friendly);
            }
        }
        current = adapter.Next;
    }

    Err(HnsError::Other {
        hresult: 0,
        message: format!(
            "no up physical adapter with a default IPv4 gateway found; set {ZLAYER_UPLINK_ENV} \
             to the adapter friendly name (e.g. \"Ethernet\") to override"
        ),
    })
}
/// Reports whether an adapter record is eligible to be chosen as the uplink.
///
/// An adapter qualifies only when all of the following hold:
/// - its operational status is up;
/// - it is not a software loopback interface;
/// - its `FirstGatewayAddress` pointer is non-null.
fn adapter_is_candidate(adapter: &IP_ADAPTER_ADDRESSES_LH) -> bool {
    let is_up = adapter.OperStatus == IfOperStatusUp;
    let is_loopback = adapter.IfType == IF_TYPE_SOFTWARE_LOOPBACK;
    let has_gateway = !adapter.FirstGatewayAddress.is_null();
    is_up && !is_loopback && has_gateway
}
struct AdapterBuffer {
buf: Vec<u8>,
#[allow(dead_code)]
used: usize,
}
impl AdapterBuffer {
fn first_adapter(&self) -> Option<*const IP_ADAPTER_ADDRESSES_LH> {
if self.buf.is_empty() {
return None;
}
Some(self.buf.as_ptr() as *const IP_ADAPTER_ADDRESSES_LH)
}
}
fn query_adapters_ipv4_with_gateways() -> HnsResult<AdapterBuffer> {
let flags: GET_ADAPTERS_ADDRESSES_FLAGS = GAA_FLAG_INCLUDE_GATEWAYS;
let family = u32::from(AF_INET.0);
let mut buf_len: u32 = 15 * 1024;
let mut buffer: Vec<u8> = Vec::new();
for _attempt in 0..3 {
buffer.resize(buf_len as usize, 0);
let ret = unsafe {
GetAdaptersAddresses(
family,
flags,
None,
Some(buffer.as_mut_ptr().cast::<IP_ADAPTER_ADDRESSES_LH>()),
&mut buf_len,
)
};
let code = WIN32_ERROR(ret);
if code == NO_ERROR {
let used = buf_len as usize;
buffer.truncate(used);
return Ok(AdapterBuffer { buf: buffer, used });
}
if code == ERROR_BUFFER_OVERFLOW {
continue;
}
return Err(HnsError::Other {
hresult: ret as i32,
message: format!("GetAdaptersAddresses failed: WIN32_ERROR {ret:#x}"),
});
}
Err(HnsError::Other {
hresult: 0,
message: "GetAdaptersAddresses kept asking for more buffer space".to_string(),
})
}
/// Converts a NUL-terminated UTF-16 string pointer into an owned `String`.
///
/// A null pointer yields an empty string. The conversion is lossy: invalid
/// UTF-16 (e.g. unpaired surrogates) is replaced with U+FFFD.
///
/// # Safety
///
/// `p` must be null, or point to a readable, NUL-terminated sequence of `u16`
/// code units that stays valid for the duration of this call.
unsafe fn pwstr_to_string(p: PWSTR) -> String {
    if p.0.is_null() {
        return String::new();
    }
    // SAFETY: per this function's contract, `p` points at a NUL-terminated wide
    // string. Every offset up to and including the terminator is therefore
    // readable, and the `len` code units before it form a valid slice.
    let wide = unsafe {
        let mut len = 0usize;
        while *p.0.add(len) != 0 {
            len += 1;
        }
        std::slice::from_raw_parts(p.0, len)
    };
    OsString::from_wide(wide).to_string_lossy().into_owned()
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Mutex, MutexGuard, PoisonError};

    /// Serializes the tests in this module that read or write `ZLAYER_UPLINK_ENV`.
    ///
    /// The test harness runs tests on parallel threads. Without this lock, one test
    /// could observe another test's override, and `env::set_var` would race with
    /// concurrent environment access.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// RAII guard that sets (or clears) `ZLAYER_UPLINK_ENV` while holding `ENV_LOCK`.
    ///
    /// On drop it restores the previous value, even if the test panicked.
    struct EnvOverride {
        previous: Option<String>,
        // Declared last so it is dropped (unlocked) only after `Drop::drop` has
        // restored the variable.
        _lock: MutexGuard<'static, ()>,
    }

    impl EnvOverride {
        /// Sets the variable to `value`, or removes it when `value` is `None`.
        fn set(value: Option<&str>) -> Self {
            // A poisoned lock only means an earlier test panicked; the guard below
            // still restored the variable, so it is safe to continue.
            let lock = ENV_LOCK.lock().unwrap_or_else(PoisonError::into_inner);
            let previous = env::var(ZLAYER_UPLINK_ENV).ok();
            // SAFETY: `ENV_LOCK` is held, so no other test in this module touches
            // the environment concurrently.
            unsafe {
                match value {
                    Some(v) => env::set_var(ZLAYER_UPLINK_ENV, v),
                    None => env::remove_var(ZLAYER_UPLINK_ENV),
                }
            }
            Self {
                previous,
                _lock: lock,
            }
        }
    }

    impl Drop for EnvOverride {
        fn drop(&mut self) {
            // SAFETY: `_lock` is still held here (fields drop after this body runs).
            unsafe {
                match self.previous.take() {
                    Some(v) => env::set_var(ZLAYER_UPLINK_ENV, v),
                    None => env::remove_var(ZLAYER_UPLINK_ENV),
                }
            }
        }
    }

    #[test]
    fn env_override_wins_when_set() {
        let marker = "zlayer-test-uplink-marker";
        let _env = EnvOverride::set(Some(marker));
        let got = find_primary_adapter().unwrap();
        assert_eq!(got, marker);
    }

    #[test]
    fn env_override_empty_falls_through() {
        let _env = EnvOverride::set(Some(" "));
        // Auto-detection may legitimately fail on a host without a gateway. The only
        // thing pinned here is that a blank override is never returned verbatim.
        if let Ok(name) = find_primary_adapter() {
            assert_ne!(name, " ");
            assert!(!name.is_empty());
        }
    }

    #[test]
    #[ignore = "requires a real Windows host with a default gateway"]
    fn auto_detect_returns_adapter_name() {
        let _env = EnvOverride::set(None);
        let name = find_primary_adapter().unwrap();
        assert!(!name.is_empty());
    }
}