use std::io::{self, Read, Write};
use std::net::{IpAddr, Ipv4Addr};
use std::sync::Arc;
use crate::Layer;
use crate::configuration::Configuration;
use crate::device::AbstractDevice;
use crate::error::{Error, Result};
use crate::windows::AbstractDeviceExt;
use windows_sys::Win32::NetworkManagement::IpHelper::SetIpInterfaceEntry;
use wintun_bindings::{Adapter, MAX_RING_CAPACITY, Session, load_from_path};
/// Layer-3 TUN device backed by a Wintun adapter session.
pub struct Device {
    /// Packet I/O handle that owns the Wintun session.
    pub(crate) tun: Tun,
    /// MTU read from the adapter at construction time and updated by `set_mtu`.
    mtu: u16,
}
impl Device {
pub fn new(config: &Configuration) -> Result<Self> {
let layer = config.layer.unwrap_or(Layer::L3);
if layer == Layer::L3 {
let wintun_file = &config.platform_config.wintun_file;
let wintun = unsafe { load_from_path(wintun_file)? };
let tun_name = config.tun_name.as_deref().unwrap_or("wintun");
let guid = config.platform_config.device_guid;
let adapter = match Adapter::open(&wintun, tun_name) {
Ok(a) => a,
Err(e) => {
log::debug!("failed to open adapter: {e}");
Adapter::create(&wintun, tun_name, tun_name, guid)?
}
};
if let (Some(address), Some(mask)) = (config.address, config.netmask) {
let gateway = config.destination;
let luid_value = unsafe { adapter.get_luid().Value };
match (address, mask) {
(IpAddr::V4(addr), IpAddr::V4(mask_v4)) => {
set_unicast_address(luid_value, addr, mask_v4)?;
if let Some(IpAddr::V4(gw)) = gateway {
set_default_route(luid_value, gw)?;
}
}
_ => {
adapter.set_network_addresses_tuple(address, mask, gateway)?;
}
}
}
if let Some(metric) = config.metric {
let luid = unsafe { adapter.get_luid().Value };
set_interface_metric(luid, metric.into(), false)?;
set_interface_metric(luid, metric.into(), true)?;
}
if let Some(dns_servers) = &config.platform_config.dns_servers {
adapter.set_dns_servers(dns_servers)?;
}
if let Some(mtu) = config.mtu {
adapter.set_mtu(mtu as _)?;
}
let capacity = config.ring_capacity.unwrap_or(MAX_RING_CAPACITY);
let session = adapter.start_session(capacity)?;
let device = Device {
tun: Tun { session },
mtu: adapter.get_mtu()? as u16,
};
Ok(device)
} else if layer == Layer::L2 {
todo!()
} else {
panic!("unknow layer {layer:?}");
}
}
pub fn split(self) -> (Reader, Writer) {
let tun = Arc::new(self.tun);
(Reader(tun.clone()), Writer(tun))
}
pub fn recv(&self, buf: &mut [u8]) -> std::io::Result<usize> {
self.tun.recv(buf)
}
pub fn send(&self, buf: &[u8]) -> std::io::Result<usize> {
self.tun.send(buf)
}
}
impl Read for Device {
    /// Delegates to the inner `Tun`, blocking until a packet arrives.
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        let inner = &mut self.tun;
        Read::read(inner, buf)
    }
}
impl Write for Device {
    /// Delegates to the inner `Tun`, sending `buf` as a single packet.
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        let inner = &mut self.tun;
        Write::write(inner, buf)
    }

    /// Delegates to the inner `Tun`'s flush.
    fn flush(&mut self) -> std::io::Result<()> {
        let inner = &mut self.tun;
        Write::flush(inner)
    }
}
impl AbstractDevice for Device {
fn tun_index(&self) -> Result<i32> {
Ok(self.tun.session.get_adapter().get_adapter_index()? as i32)
}
fn tun_name(&self) -> Result<String> {
Ok(self.tun.session.get_adapter().get_name()?)
}
fn set_tun_name(&mut self, value: &str) -> Result<()> {
Ok(self.tun.session.get_adapter().set_name(value)?)
}
fn enabled(&mut self, _value: bool) -> Result<()> {
Ok(())
}
fn address(&self) -> Result<IpAddr> {
let addresses = self.tun.session.get_adapter().get_addresses()?;
addresses
.iter()
.find_map(|a| match a {
std::net::IpAddr::V4(a) => Some(std::net::IpAddr::V4(*a)),
_ => None,
})
.ok_or(Error::InvalidConfig)
}
fn set_address(&mut self, value: IpAddr) -> Result<()> {
let IpAddr::V4(value) = value else {
unimplemented!("do not support IPv6 yet")
};
Ok(self.tun.session.get_adapter().set_address(value)?)
}
fn destination(&self) -> Result<IpAddr> {
self.tun
.session
.get_adapter()
.get_gateways()?
.iter()
.find_map(|a| match a {
std::net::IpAddr::V4(a) => Some(std::net::IpAddr::V4(*a)),
_ => None,
})
.ok_or(Error::InvalidConfig)
}
fn set_destination(&mut self, value: IpAddr) -> Result<()> {
let IpAddr::V4(value) = value else {
unimplemented!("do not support IPv6 yet")
};
Ok(self.tun.session.get_adapter().set_gateway(Some(value))?)
}
fn broadcast(&self) -> Result<IpAddr> {
Err(Error::NotImplemented)
}
fn set_broadcast(&mut self, value: IpAddr) -> Result<()> {
log::debug!("set_broadcast {value} is not need");
Ok(())
}
fn netmask(&self) -> Result<IpAddr> {
let current_addr = self.address()?;
self.tun
.session
.get_adapter()
.get_netmask_of_address(¤t_addr)
.map_err(Error::WintunError)
}
fn set_netmask(&mut self, value: IpAddr) -> Result<()> {
let IpAddr::V4(value) = value else {
unimplemented!("do not support IPv6 yet")
};
Ok(self.tun.session.get_adapter().set_netmask(value)?)
}
fn mtu(&self) -> Result<u16> {
Ok(self.mtu)
}
fn set_mtu(&mut self, mtu: u16) -> Result<()> {
self.tun.session.get_adapter().set_mtu(mtu as _)?;
self.mtu = mtu;
Ok(())
}
fn packet_information(&self) -> bool {
false
}
}
impl AbstractDeviceExt for Device {
    /// Raw 64-bit LUID of the underlying Wintun adapter.
    fn tun_luid(&self) -> u64 {
        let adapter = self.tun.session.get_adapter();
        // SAFETY: NET_LUID_LH is a union whose `Value` member covers all 64 bits, so any
        // LUID returned by the adapter is a valid u64.
        unsafe { adapter.get_luid().Value }
    }
}
/// Packet I/O wrapper around a Wintun session.
pub struct Tun {
    /// Reference-counted so callers can obtain their own handle via `get_session`.
    session: Arc<Session>,
}
impl Tun {
    /// Returns another reference-counted handle to the Wintun session.
    pub fn get_session(&self) -> Arc<Session> {
        self.session.clone()
    }

    /// Blocks until the session yields a packet, then copies it into `buf`.
    ///
    /// Returns the number of bytes copied.
    ///
    /// # Errors
    /// * `ConnectionAborted` wrapping the Wintun error if the blocking receive fails.
    /// * Any error from `std::io::copy`. A packet longer than `buf` makes the copy fail once
    ///   `buf` is full, and the packet is not retried.
    fn read_by_ref(&self, mut buf: &mut [u8]) -> std::io::Result<usize> {
        use std::io::{Error, ErrorKind::ConnectionAborted};
        let pkt = self
            .session
            .receive_blocking()
            .map_err(|e| Error::new(ConnectionAborted, e))?;
        std::io::copy(&mut pkt.bytes(), &mut buf).map(|n| n as usize)
    }

    /// Sends `buf` as a single packet and returns its length.
    ///
    /// # Errors
    /// * `InvalidInput` if `buf` is longer than 65535 bytes, the largest size a Wintun send
    ///   packet can be allocated with (`allocate_send_packet` takes a `u16`).
    /// * `OutOfMemory` wrapping the Wintun error if no send packet could be allocated.
    fn write_by_ref(&self, buf: &[u8]) -> std::io::Result<usize> {
        use std::io::{Error, ErrorKind};
        // Reject oversized buffers up front. A plain `as u16` would wrap and allocate a
        // wrong-sized packet that is then dropped half-written.
        if buf.len() > usize::from(u16::MAX) {
            return Err(Error::new(
                ErrorKind::InvalidInput,
                format!(
                    "packet of {} bytes exceeds the maximum Wintun packet size of {} bytes",
                    buf.len(),
                    u16::MAX
                ),
            ));
        }
        let mut packet = self
            .session
            .allocate_send_packet(buf.len() as u16)
            .map_err(|e| Error::new(ErrorKind::OutOfMemory, e))?;
        // The packet was allocated with exactly `buf.len()` bytes.
        packet.bytes_mut().copy_from_slice(buf);
        self.session.send_packet(packet);
        Ok(buf.len())
    }

    /// Receives one packet through a shared reference; see `read_by_ref`.
    pub fn recv(&self, buf: &mut [u8]) -> std::io::Result<usize> {
        self.read_by_ref(buf)
    }

    /// Sends one packet through a shared reference; see `write_by_ref`.
    pub fn send(&self, buf: &[u8]) -> std::io::Result<usize> {
        self.write_by_ref(buf)
    }
}
impl Read for Tun {
fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
self.read_by_ref(buf)
}
}
impl Write for Tun {
fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
self.write_by_ref(buf)
}
fn flush(&mut self) -> std::io::Result<()> {
Ok(())
}
}
#[repr(transparent)]
pub struct Reader(Arc<Tun>);
impl Read for Reader {
fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
self.0.read_by_ref(buf)
}
}
#[repr(transparent)]
pub struct Writer(Arc<Tun>);
impl Write for Writer {
fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
self.0.write_by_ref(buf)
}
fn flush(&mut self) -> std::io::Result<()> {
Ok(())
}
}
/// Converts a dotted IPv4 netmask into its CIDR prefix length (e.g. 255.255.255.0 -> 24).
///
/// The result is the number of leading 1-bits. In debug builds a non-contiguous mask
/// (1-bits after the first 0-bit) triggers an assertion; in release builds the trailing
/// bits are ignored.
fn netmask_to_prefix_len(mask: Ipv4Addr) -> u8 {
    let raw = u32::from(mask);
    let ones = raw.leading_ones();
    // The mask a contiguous /ones prefix would have. Shifting by 32 overflows, so /0 is
    // handled separately.
    let contiguous = match ones {
        0 => 0,
        n => u32::MAX << (32 - n),
    };
    debug_assert_eq!(raw, contiguous, "non-contiguous netmask");
    ones as u8
}
/// Assigns `address` with the prefix length derived from `mask` as a manually configured
/// IPv4 unicast address on the interface identified by `luid`.
///
/// An existing entry for the same address is deleted first. A failed delete is only
/// logged; "not found" statuses are expected and not logged at all.
///
/// # Errors
/// Returns the raw OS error if `CreateUnicastIpAddressEntry` reports a non-zero status.
fn set_unicast_address(luid: u64, address: Ipv4Addr, mask: Ipv4Addr) -> io::Result<()> {
    use windows_sys::Win32::NetworkManagement::IpHelper::{
        CreateUnicastIpAddressEntry, DeleteUnicastIpAddressEntry, MIB_UNICASTIPADDRESS_ROW,
    };
    use windows_sys::Win32::NetworkManagement::Ndis::NET_LUID_LH;
    use windows_sys::Win32::Networking::WinSock::AF_INET;
    // SAFETY: MIB_UNICASTIPADDRESS_ROW is a plain C struct for which all-zero bytes are a
    // valid value. `row` is fully populated before it is passed by const pointer to the IP
    // Helper calls, and it outlives both calls.
    unsafe {
        let mut row: MIB_UNICASTIPADDRESS_ROW = std::mem::zeroed();
        row.InterfaceLuid = NET_LUID_LH { Value: luid };
        row.Address.si_family = AF_INET;
        row.Address.Ipv4.sin_family = AF_INET;
        // `octets()` is already network byte order; from_ne_bytes keeps that memory layout.
        row.Address.Ipv4.sin_addr.S_un.S_addr = u32::from_ne_bytes(address.octets());
        row.OnLinkPrefixLength = netmask_to_prefix_len(mask);
        row.DadState = 4; // IpDadStatePreferred
        row.ValidLifetime = u32::MAX; // infinite
        row.PreferredLifetime = u32::MAX; // infinite
        row.PrefixOrigin = 1; // IpPrefixOriginManual
        row.SuffixOrigin = 1; // IpSuffixOriginManual

        // 0 = deleted. 2 (ERROR_FILE_NOT_FOUND) and 1168 (ERROR_NOT_FOUND) mean there was
        // nothing to delete, which is the normal case on a fresh adapter.
        let del_status = DeleteUnicastIpAddressEntry(&row);
        if !matches!(del_status, 0 | 2 | 1168) {
            log::warn!("DeleteUnicastIpAddressEntry failed: {del_status}");
        }

        let status = CreateUnicastIpAddressEntry(&row);
        if status != 0 {
            log::error!("CreateUnicastIpAddressEntry failed: {status}");
            return Err(io::Error::from_raw_os_error(status as i32));
        }
        Ok(())
    }
}
/// Installs an IPv4 default route (0.0.0.0/0) via `gateway` on the interface identified by
/// `luid`.
///
/// An identical existing route is deleted first. A failed delete is only logged;
/// "not found" statuses are expected and not logged at all.
///
/// # Errors
/// Returns the raw OS error if `CreateIpForwardEntry2` reports a non-zero status.
fn set_default_route(luid: u64, gateway: Ipv4Addr) -> io::Result<()> {
    use windows_sys::Win32::NetworkManagement::IpHelper::{
        CreateIpForwardEntry2, DeleteIpForwardEntry2, MIB_IPFORWARD_ROW2,
    };
    use windows_sys::Win32::NetworkManagement::Ndis::NET_LUID_LH;
    use windows_sys::Win32::Networking::WinSock::AF_INET;
    // SAFETY: MIB_IPFORWARD_ROW2 is a plain C struct for which all-zero bytes are a valid
    // value (this also leaves the destination prefix address at 0.0.0.0). `row` is fully
    // populated before being passed by const pointer and outlives both calls.
    unsafe {
        let mut row: MIB_IPFORWARD_ROW2 = std::mem::zeroed();
        row.InterfaceLuid = NET_LUID_LH { Value: luid };
        row.DestinationPrefix.Prefix.si_family = AF_INET;
        row.DestinationPrefix.Prefix.Ipv4.sin_family = AF_INET;
        row.DestinationPrefix.PrefixLength = 0; // /0: matches every destination
        row.NextHop.si_family = AF_INET;
        row.NextHop.Ipv4.sin_family = AF_INET;
        // `octets()` is already network byte order; from_ne_bytes keeps that memory layout.
        row.NextHop.Ipv4.sin_addr.S_un.S_addr = u32::from_ne_bytes(gateway.octets());
        row.Metric = 0;
        row.Protocol = 3; // MIB_IPPROTO_NETMGMT: a statically managed route
        row.ValidLifetime = u32::MAX; // infinite
        row.PreferredLifetime = u32::MAX; // infinite

        // 0 = deleted. 2 (ERROR_FILE_NOT_FOUND) and 1168 (ERROR_NOT_FOUND) mean there was
        // no previous route to remove.
        let del_status = DeleteIpForwardEntry2(&row);
        if !matches!(del_status, 0 | 2 | 1168) {
            log::warn!("DeleteIpForwardEntry2 failed: {del_status}");
        }

        let status = CreateIpForwardEntry2(&row);
        if status != 0 {
            log::error!("CreateIpForwardEntry2 failed: {status}");
            return Err(io::Error::from_raw_os_error(status as i32));
        }
        Ok(())
    }
}
/// Pins the routing metric of one address family on the interface identified by `luid`.
///
/// The current `MIB_IPINTERFACE_ROW` is read for IPv6 when `ipv6` is true and IPv4
/// otherwise. Automatic metric selection is then disabled, `metric` is stored, and the
/// row is written back.
///
/// # Errors
/// Returns the raw OS error, after logging it, when `GetIpInterfaceEntry` or
/// `SetIpInterfaceEntry` reports a non-zero status.
fn set_interface_metric(luid: u64, metric: u32, ipv6: bool) -> io::Result<()> {
    use windows_sys::Win32::NetworkManagement::IpHelper::{
        GetIpInterfaceEntry, MIB_IPINTERFACE_ROW,
    };
    use windows_sys::Win32::NetworkManagement::Ndis::NET_LUID_LH;
    use windows_sys::Win32::Networking::WinSock::{AF_INET, AF_INET6};

    let mut row = MIB_IPINTERFACE_ROW {
        InterfaceLuid: NET_LUID_LH { Value: luid },
        Family: if ipv6 { AF_INET6 } else { AF_INET },
        ..Default::default()
    };

    // SAFETY: `row` is an initialised MIB_IPINTERFACE_ROW with LUID and family set, and it
    // lives across the call.
    let status = unsafe { GetIpInterfaceEntry(&mut row) };
    if status != 0 {
        log::error!("GetIpInterfaceEntry failed with error: {status}");
        return Err(io::Error::from_raw_os_error(status as i32));
    }

    row.UseAutomaticMetric = false;
    row.Metric = metric;
    // Cleared before writing back. NOTE(review): presumably because SetIpInterfaceEntry
    // rejects the SitePrefixLength that GetIpInterfaceEntry reports; confirm.
    row.SitePrefixLength = 0;

    // SAFETY: `row` was filled in by GetIpInterfaceEntry above and is still live.
    let status = unsafe { SetIpInterfaceEntry(&mut row) };
    if status != 0 {
        log::error!("SetIpInterfaceEntry failed with error: {status}");
        return Err(io::Error::from_raw_os_error(status as i32));
    }
    Ok(())
}