#![cfg(target_os = "windows")]
use windows_sys::{
Win32::{
Foundation::{ERROR_BUFFER_OVERFLOW, ERROR_SUCCESS, NO_ERROR, WIN32_ERROR},
NetworkManagement::{
IpHelper::{
ConvertInterfaceAliasToLuid, ConvertInterfaceLuidToGuid, DNS_INTERFACE_SETTINGS, DNS_INTERFACE_SETTINGS_VERSION1,
DNS_SETTING_NAMESERVER, GAA_FLAG_INCLUDE_GATEWAYS, GAA_FLAG_INCLUDE_PREFIX, GetAdaptersAddresses, IF_TYPE_ETHERNET_CSMACD,
IF_TYPE_IEEE80211, IP_ADAPTER_ADDRESSES_LH,
},
Ndis::{IfOperStatusUp, NET_LUID_LH},
},
Networking::WinSock::{AF_INET, AF_INET6, AF_UNSPEC, SOCKADDR, SOCKADDR_IN, SOCKADDR_IN6},
},
core::GUID,
};
use crate::{TproxyArgs, TproxyStateInner, run_command};
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
pub(crate) async fn _tproxy_setup(tproxy_args: &TproxyArgs) -> std::io::Result<TproxyStateInner> {
log::trace!("Setting up transparent proxy...");
flush_dns_cache()?;
log::trace!("Route all traffic to the gateway of adapter \"{}\"...", tproxy_args.tun_name);
let unspecified = if tproxy_args.tun_gateway.is_ipv4() {
Ipv4Addr::UNSPECIFIED.to_string()
} else {
std::net::Ipv6Addr::UNSPECIFIED.to_string()
};
let gateway = tproxy_args.tun_gateway.to_string();
let args = &["add", &unspecified, "mask", &unspecified, &gateway, "metric", "6"];
run_command("route", args)?;
log::trace!("Get default gateway...");
let original_gateway = get_default_gateway_ip()?;
log::trace!("Setting bypass IPs...");
for bypass_ip in tproxy_args.bypass_ips.iter() {
do_bypass_ip(*bypass_ip, original_gateway)?;
}
if tproxy_args.bypass_ips.is_empty() && !crate::is_private_ip(tproxy_args.proxy_addr.ip()) {
let cidr = cidr::IpCidr::new_host(tproxy_args.proxy_addr.ip());
do_bypass_ip(cidr, original_gateway)?;
}
log::trace!("Setting \"{}\"'s DNS to {}...", tproxy_args.tun_name, tproxy_args.tun_gateway);
set_dns_server(&tproxy_args.tun_name, tproxy_args.tun_gateway)?;
log::trace!("Transparent proxy setup done");
let state = TproxyStateInner {
tproxy_args: Some(tproxy_args.clone()),
original_dns_servers: None,
gateway: Some(original_gateway),
gw_scope: None,
umount_resolvconf: false,
restore_resolvconf_content: None,
tproxy_removed_done: false,
};
#[cfg(all(feature = "unsafe-state-file", any(target_os = "macos", target_os = "windows")))]
crate::store_intermediate_state(&state)?;
Ok(state)
}
fn do_bypass_ip(bypass_ip: cidr::IpCidr, original_gateway: IpAddr) -> std::io::Result<()> {
let args = &["add", &bypass_ip.to_string(), &original_gateway.to_string(), "metric", "1"];
run_command("route", args)?;
Ok(())
}
pub(crate) async fn _tproxy_remove(state: &mut TproxyStateInner) -> std::io::Result<()> {
if state.tproxy_removed_done {
return Ok(());
}
state.tproxy_removed_done = true;
let err = std::io::Error::new(std::io::ErrorKind::InvalidData, "tproxy_args is None");
let tproxy_args = state.tproxy_args.as_ref().ok_or(err)?;
let err = std::io::Error::other("No default gateway found");
let original_gateway = state.gateway.take().ok_or(err)?;
let unspecified = Ipv4Addr::UNSPECIFIED.to_string();
let gateway = tproxy_args.tun_gateway.to_string();
let args = &["-p", "delete", &unspecified, "mask", &unspecified, &gateway];
if let Err(_err) = run_command("route", args) {
log::debug!("command \"route {args:?}\" error: {_err}");
}
for bypass_ip in tproxy_args.bypass_ips.iter() {
let args = &["delete", &bypass_ip.to_string()];
if let Err(_err) = run_command("route", args) {
log::debug!("command \"route {args:?}\" error: {_err}");
}
}
if tproxy_args.bypass_ips.is_empty() && !crate::is_private_ip(tproxy_args.proxy_addr.ip()) {
let bypass_ip = cidr::IpCidr::new_host(tproxy_args.proxy_addr.ip());
let args = &["delete", &bypass_ip.to_string()];
if let Err(_err) = run_command("route", args) {
log::debug!("command \"route {args:?}\" error: {_err}");
}
}
let args = &["delete", &unspecified, "mask", &unspecified];
if let Err(_err) = run_command("route", args) {
log::debug!("command \"route {args:?}\" error: {_err}");
}
let original_gateway = original_gateway.to_string();
let args = &["add", &unspecified, "mask", &unspecified, &original_gateway, "metric", "200"];
if let Err(_err) = run_command("route", args) {
log::debug!("command \"route {args:?}\" error: {_err}");
}
#[cfg(all(feature = "unsafe-state-file", any(target_os = "macos", target_os = "windows")))]
let _ = std::fs::remove_file(crate::get_state_file_path());
flush_dns_cache()?;
Ok(())
}
crate::define_fn_dynamic_load!(
SetInterfaceDnsSettingsDeclare,
unsafe extern "system" fn(interface: GUID, settings: *const DNS_INTERFACE_SETTINGS) -> WIN32_ERROR,
SET_INTERFACE_DNS_SETTINGS,
SetInterfaceDnsSettings,
"iphlpapi.dll",
"SetInterfaceDnsSettings"
);
pub(crate) fn set_dns_server(iface: &str, dns_server: IpAddr) -> std::io::Result<()> {
let Some(set_dns_fn) = SetInterfaceDnsSettings() else {
let tun_name = format!("\"{iface}\"");
let args = &["interface", "ip", "set", "dns", &tun_name, "static", &dns_server.to_string()];
run_command("netsh", args)?;
return Ok(());
};
let svr: Vec<u16> = dns_server.to_string().encode_utf16().chain(std::iter::once(0)).collect();
let settings = DNS_INTERFACE_SETTINGS {
Version: DNS_INTERFACE_SETTINGS_VERSION1,
Flags: DNS_SETTING_NAMESERVER as _,
NameServer: svr.as_ptr() as _,
Domain: std::ptr::null_mut(),
SearchList: std::ptr::null_mut(),
RegistrationEnabled: 0,
RegisterAdapterName: 0,
EnableLLMNR: 0,
QueryAdapterName: 0,
ProfileNameServer: std::ptr::null_mut(),
};
let luid = alias_to_luid(iface)?;
let guid = luid_to_guid(&luid)?;
let ret = unsafe { set_dns_fn(guid, &settings) };
if ret != NO_ERROR {
let err = std::io::Error::from_raw_os_error(ret as _);
return Err(err);
}
Ok(())
}
#[allow(dead_code)]
pub(crate) fn get_default_gateway() -> std::io::Result<(IpAddr, String)> {
let addr = get_default_gateway_ip()?;
let iface = get_default_gateway_interface()?;
Ok((addr, iface))
}
pub(crate) fn get_default_gateway_ip() -> std::io::Result<IpAddr> {
match get_active_network_interface_gateways().map(|gateways| gateways[0]) {
Ok(gateway) => Ok(gateway),
Err(e) => {
log::debug!("Failed to get default gateway by GetAdaptersAddresses: {e}");
get_default_gateway_ip_by_cmd()
}
}
}
pub(crate) fn get_default_gateway_ip_by_cmd() -> std::io::Result<IpAddr> {
let cmd = "Get-WmiObject -Class Win32_NetworkAdapterConfiguration -Filter IPEnabled=TRUE | ForEach-Object { $_.DefaultIPGateway }";
let gateways = match run_command("powershell", &["-Command", cmd]) {
Ok(gateways) => gateways,
Err(e) => {
let str = format!("Command \"powershell -Command {cmd}\" error: {e}");
let err = std::io::Error::other(str);
return Err(err);
}
};
let stdout = String::from_utf8_lossy(&gateways).into_owned();
let lines: Vec<&str> = stdout.lines().collect();
let mut ipv4_gateway = None;
let mut ipv6_gateway = None;
for line in lines {
if let Ok(ip) = <IpAddr as std::str::FromStr>::from_str(line) {
match ip {
IpAddr::V4(_) => {
ipv4_gateway = Some(ip);
break;
}
IpAddr::V6(_) => {
ipv6_gateway = Some(ip);
}
}
}
}
let err = std::io::Error::other("No default gateway found");
ipv4_gateway.or(ipv6_gateway).ok_or(err)
}
#[allow(dead_code)]
pub(crate) fn get_default_gateway_interface() -> std::io::Result<String> {
let cmd = "Get-WmiObject -Class Win32_NetworkAdapter | Where-Object { $_.NetConnectionStatus -eq 2 } | Select-Object -First 1 -ExpandProperty NetConnectionID";
let iface = match run_command("powershell", &["-Command", cmd]) {
Ok(iface) => iface,
Err(e) => {
let str = format!("Command \"powershell -Command {cmd}\" error: {e}");
let err = std::io::Error::other(str);
return Err(err);
}
};
let stdout = String::from_utf8_lossy(&iface).into_owned();
let iface = stdout.trim().to_string();
Ok(iface)
}
pub(crate) fn flush_dns_cache() -> std::io::Result<()> {
run_command("ipconfig", &["/flushdns"])?;
Ok(())
}
pub fn alias_to_luid(alias: &str) -> std::io::Result<NET_LUID_LH> {
let alias = alias.encode_utf16().chain(std::iter::once(0)).collect::<Vec<_>>();
let mut luid = unsafe { std::mem::zeroed() };
match unsafe { ConvertInterfaceAliasToLuid(alias.as_ptr(), &mut luid) } {
0 => Ok(luid),
err => Err(std::io::Error::from_raw_os_error(err as _)),
}
}
pub fn luid_to_guid(luid: &NET_LUID_LH) -> std::io::Result<GUID> {
let mut guid = unsafe { std::mem::zeroed() };
match unsafe { ConvertInterfaceLuidToGuid(luid, &mut guid) } {
0 => Ok(guid),
err => Err(std::io::Error::from_raw_os_error(err as _)),
}
}
pub fn get_active_network_interface_gateways() -> std::io::Result<Vec<IpAddr>> {
let mut addrs = vec![];
get_adapters_addresses(|adapter| {
if adapter.OperStatus == IfOperStatusUp && [IF_TYPE_IEEE80211, IF_TYPE_ETHERNET_CSMACD].contains(&adapter.IfType) {
let mut current_gateway = adapter.FirstGatewayAddress;
while !current_gateway.is_null() {
let gateway = unsafe { &*current_gateway };
{
let sockaddr_ptr = gateway.Address.lpSockaddr;
let sockaddr = unsafe { &*(sockaddr_ptr as *const SOCKADDR) };
let a = unsafe { sockaddr_to_socket_addr(sockaddr) }?;
addrs.insert(0, a.ip());
}
current_gateway = gateway.Next;
}
}
Ok(())
})?;
if addrs.is_empty() {
Err(std::io::Error::other("No gateway found"))
} else {
Ok(addrs)
}
}
pub(crate) fn get_adapters_addresses<F>(mut callback: F) -> std::io::Result<()>
where
F: FnMut(IP_ADAPTER_ADDRESSES_LH) -> std::io::Result<()>,
{
let mut size = 0;
let flags = GAA_FLAG_INCLUDE_PREFIX | GAA_FLAG_INCLUDE_GATEWAYS;
let family = AF_UNSPEC as u32;
let result = unsafe { GetAdaptersAddresses(family, flags, std::ptr::null_mut(), std::ptr::null_mut(), &mut size) };
if result != ERROR_BUFFER_OVERFLOW {
return Err(std::io::Error::from_raw_os_error(result as _));
}
let mut addresses: Vec<u8> = vec![0; (size + 4) as usize];
let addrs = addresses.as_mut_ptr() as *mut IP_ADAPTER_ADDRESSES_LH;
let result = unsafe { GetAdaptersAddresses(family, flags, std::ptr::null_mut(), addrs, &mut size) };
if ERROR_SUCCESS != result {
return Err(std::io::Error::from_raw_os_error(result as _));
}
let mut current_addresses = addresses.as_ptr() as *const IP_ADAPTER_ADDRESSES_LH;
while !current_addresses.is_null() {
unsafe {
callback(*current_addresses)?;
current_addresses = (*current_addresses).Next;
}
}
Ok(())
}
pub(crate) unsafe fn sockaddr_to_socket_addr(sock_addr: *const SOCKADDR) -> std::io::Result<SocketAddr> {
let address = match (unsafe { *sock_addr }).sa_family {
AF_INET => unsafe { sockaddr_in_to_socket_addr(&*(sock_addr as *const SOCKADDR_IN)) },
AF_INET6 => unsafe { sockaddr_in6_to_socket_addr(&*(sock_addr as *const SOCKADDR_IN6)) },
_ => return Err(std::io::Error::other("Unsupported address type")),
};
Ok(address)
}
pub(crate) unsafe fn sockaddr_in_to_socket_addr(sockaddr_in: &SOCKADDR_IN) -> SocketAddr {
let ip_bytes = unsafe { sockaddr_in.sin_addr.S_un.S_addr.to_ne_bytes() };
let ip = std::net::IpAddr::from(ip_bytes);
let port = u16::from_be(sockaddr_in.sin_port);
SocketAddr::new(ip, port)
}
pub(crate) unsafe fn sockaddr_in6_to_socket_addr(sockaddr_in6: &SOCKADDR_IN6) -> SocketAddr {
let ip = std::net::IpAddr::from(unsafe { sockaddr_in6.sin6_addr.u.Byte });
let port = u16::from_be(sockaddr_in6.sin6_port);
SocketAddr::new(ip, port)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_get_default_gateway() {
let (addr, iface) = get_default_gateway().unwrap();
println!("addr: {addr:?}, iface: {iface}");
let gw = get_active_network_interface_gateways().unwrap();
println!("gateways: {gw:?}");
}
}