use crate::{IfEvent, IpNet, Ipv4Net, Ipv6Net};
use fnv::FnvHashSet;
use futures::stream::{FusedStream, Stream};
use futures::task::AtomicWaker;
use if_addrs::IfAddr;
use std::collections::VecDeque;
use std::ffi::c_void;
use std::io::{Error, ErrorKind, Result};
use std::pin::Pin;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::task::{Context, Poll};
use windows::Win32::Foundation::HANDLE;
use windows::Win32::NetworkManagement::IpHelper::{
CancelMibChangeNotify2, NotifyIpInterfaceChange, MIB_IPINTERFACE_ROW, MIB_NOTIFICATION_TYPE,
};
use windows::Win32::Networking::WinSock::AF_UNSPEC;
/// Exposes the watcher under a `tokio` path; only compiled when the `tokio` feature is enabled.
#[cfg(feature = "tokio")]
pub mod tokio {
/// Alias for the parent module's `IfWatcher`; it is the same type, not a wrapper.
pub type IfWatcher = super::IfWatcher;
}
/// Exposes the watcher under a `smol` path; only compiled when the `smol` feature is enabled.
#[cfg(feature = "smol")]
pub mod smol {
/// Alias for the parent module's `IfWatcher`; it is the same type, not a wrapper.
pub type IfWatcher = super::IfWatcher;
}
/// Watches the host's interface addresses and reports changes as [`IfEvent`]s.
///
/// An OS change notification only sets a flag and wakes the polling task; the
/// actual address enumeration and diffing happens inside `poll_if_event`.
#[derive(Debug)]
pub struct IfWatcher {
/// Networks currently considered up, as of the last resync.
addrs: FnvHashSet<IpNet>,
/// Events produced by a resync that have not yet been handed to the caller.
queue: VecDeque<IfEvent>,
// Never read: held only so that its `Drop` impl cancels the OS notification
// when the watcher goes away.
#[allow(unused)]
notif: IpChangeNotification,
/// Waker of the task that last polled; shared with the notification callback.
waker: Arc<AtomicWaker>,
/// Set by the notification callback to request a resync on the next poll.
resync: Arc<AtomicBool>,
}
impl IfWatcher {
pub fn new() -> Result<Self> {
let resync = Arc::new(AtomicBool::new(true));
let waker = Arc::new(AtomicWaker::new());
Ok(Self {
addrs: Default::default(),
queue: Default::default(),
waker: waker.clone(),
resync: resync.clone(),
notif: IpChangeNotification::new(Box::new(move |_, _| {
resync.store(true, Ordering::Relaxed);
waker.wake();
}))?,
})
}
fn resync(&mut self) -> Result<()> {
let addrs = if_addrs::get_if_addrs()?;
for old_addr in self.addrs.clone() {
if addrs
.iter()
.find(|addr| addr.ip() == old_addr.addr())
.is_none()
{
self.addrs.remove(&old_addr);
self.queue.push_back(IfEvent::Down(old_addr));
}
}
for new_addr in addrs {
let ipnet = ifaddr_to_ipnet(new_addr.addr);
if self.addrs.insert(ipnet) {
self.queue.push_back(IfEvent::Up(ipnet));
}
}
Ok(())
}
pub fn iter(&self) -> impl Iterator<Item = &IpNet> {
self.addrs.iter()
}
pub fn poll_if_event(&mut self, cx: &mut Context) -> Poll<Result<IfEvent>> {
loop {
if let Some(event) = self.queue.pop_front() {
return Poll::Ready(Ok(event));
}
if !self.resync.swap(false, Ordering::Relaxed) {
self.waker.register(cx.waker());
return Poll::Pending;
}
if let Err(error) = self.resync() {
return Poll::Ready(Err(error));
}
}
}
}
/// Stream adapter over [`IfWatcher::poll_if_event`]; every poll result is
/// wrapped in `Some`, so the stream never yields `None`.
impl Stream for IfWatcher {
    type Item = Result<IfEvent>;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        match self.get_mut().poll_if_event(cx) {
            Poll::Ready(item) => Poll::Ready(Some(item)),
            Poll::Pending => Poll::Pending,
        }
    }
}
impl FusedStream for IfWatcher {
/// Always `false`: `poll_next` never produces `None`, so the stream is never
/// reported as terminated.
fn is_terminated(&self) -> bool {
false
}
}
/// Converts an address/netmask pair reported by `if_addrs` into an [`IpNet`].
///
/// The prefix length is the number of leading one-bits in the netmask, read
/// as a big-endian integer.
fn ifaddr_to_ipnet(addr: IfAddr) -> IpNet {
    match addr {
        IfAddr::V4(v4) => {
            let mask = u32::from_be_bytes(v4.netmask.octets());
            // At most 32, so the narrowing cast cannot truncate.
            let prefix = mask.leading_ones() as u8;
            let net = Ipv4Net::new(v4.ip, prefix).expect("if_addrs returned a valid prefix");
            IpNet::V4(net)
        }
        IfAddr::V6(v6) => {
            let mask = u128::from_be_bytes(v6.netmask.octets());
            // At most 128, so the narrowing cast cannot truncate.
            let prefix = mask.leading_ones() as u8;
            let net = Ipv6Net::new(v6.ip, prefix).expect("if_addrs returned a valid prefix");
            IpNet::V6(net)
        }
    }
}
/// An active interface-change registration; dropping it cancels the
/// notification and frees the callback.
struct IpChangeNotification {
/// Handle filled in by `NotifyIpInterfaceChange`, later passed to
/// `CancelMibChangeNotify2`.
handle: HANDLE,
/// Heap-allocated callback obtained from `Box::into_raw`; handed to the OS as
/// the caller context and reclaimed with `Box::from_raw` on drop.
callback: *mut IpChangeCallback,
}
/// Opaque `Debug` output: only the type name is printed, none of the fields.
impl std::fmt::Debug for IpChangeNotification {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        f.write_str("IpChangeNotification")
    }
}
/// Boxed closure invoked for each interface-change notification with the
/// affected interface row and the notification type.
// NOTE(review): the bound is `Send` but not `Sync`; if the OS can deliver
// notifications concurrently on several threads the closure is shared between
// them — confirm against the IP Helper documentation.
type IpChangeCallback = Box<dyn Fn(&MIB_IPINTERFACE_ROW, MIB_NOTIFICATION_TYPE) + Send>;
impl IpChangeNotification {
    /// Registers `cb` to be invoked on interface changes for all address
    /// families (`AF_UNSPEC`).
    ///
    /// The closure is moved to the heap and its raw pointer is passed to the OS
    /// as the caller context; the returned value owns that allocation.
    ///
    /// # Errors
    ///
    /// Returns an `ErrorKind::Other` error carrying the OS error text if
    /// `NotifyIpInterfaceChange` fails. The closure is freed in that case.
    fn new(cb: IpChangeCallback) -> Result<Self> {
        /// Trampoline with the ABI the OS expects; forwards to the boxed
        /// closure stored behind `caller_context`.
        ///
        /// `row` is dereferenced unconditionally. Per the Win32 documentation it
        /// is only null for the initial notification, which is not requested
        /// here (`false` is passed for that argument below).
        unsafe extern "system" fn global_callback(
            caller_context: *const c_void,
            row: *const MIB_IPINTERFACE_ROW,
            notification_type: MIB_NOTIFICATION_TYPE,
        ) {
            (**(caller_context as *const IpChangeCallback))(&*row, notification_type)
        }

        let mut handle = HANDLE::default();
        let callback = Box::into_raw(Box::new(cb));
        // SAFETY: `callback` points to a live heap allocation that stays valid
        // until `Drop` (or the error path below) reclaims it, and `handle` is a
        // valid out-location for the duration of the call.
        let registration = unsafe {
            NotifyIpInterfaceChange(
                AF_UNSPEC,
                Some(global_callback),
                Some(callback as _),
                false,
                &mut handle as _,
            )
            .ok()
        };
        if let Err(err) = registration {
            // SAFETY: `callback` came from `Box::into_raw` just above and has
            // not been freed. Registration failed, so nothing else is expected
            // to use the pointer; reclaiming it here avoids leaking the closure.
            drop(unsafe { Box::from_raw(callback) });
            return Err(Error::new(ErrorKind::Other, err.to_string()));
        }
        Ok(Self { callback, handle })
    }
}
/// Cancels the OS notification, then releases the boxed callback.
impl Drop for IpChangeNotification {
    fn drop(&mut self) {
        // Deregister first so the callback is not freed while still registered.
        // SAFETY: `self.handle` is the handle written by
        // `NotifyIpInterfaceChange` when this value was created.
        let cancelled = unsafe { CancelMibChangeNotify2(self.handle).ok() };
        if let Err(err) = cancelled {
            log::error!("error deregistering notification: {}", err);
        }
        // SAFETY: `self.callback` was produced by `Box::into_raw` in `new` and
        // is reclaimed exactly once, here.
        // NOTE(review): this also runs when cancellation failed — confirm the
        // OS cannot still invoke the callback in that case.
        unsafe {
            drop(Box::from_raw(self.callback));
        }
    }
}
// SAFETY: the raw `callback` pointer is what suppresses the auto `Send` impl.
// It is an owning pointer created by `Box::into_raw` in `new` and freed only in
// `Drop`, and the closure it points to is itself bounded by `Send`.
// NOTE(review): assumes the notification handle may be cancelled from a thread
// other than the one that registered it — confirm.
unsafe impl Send for IpChangeNotification {}