use crate::monitor::{ApiError, ApiListener};
use std::pin::Pin;
use std::sync::mpsc;
use std::task::{Context, Poll};
use tokio::sync::mpsc as tokio_mpsc;
use tokio_stream::Stream;
use windows::Win32::Foundation::{HANDLE, NO_ERROR, WIN32_ERROR};
use windows::Win32::NetworkManagement::IpHelper::{
CancelMibChangeNotify2, MIB_IPINTERFACE_ROW, MIB_NOTIFICATION_TYPE, NotifyIpInterfaceChange,
};
use windows::Win32::Networking::WinSock::AF_UNSPEC;
/// Listener for network-interface changes backed by the Win32 IP Helper API.
///
/// Carries no state. The private unit field prevents construction through a struct literal
/// outside this module; use [`WindowsApiListener::new`] or `Default`.
#[derive(Debug, Default)]
pub struct WindowsApiListener {
    _private: (),
}
impl WindowsApiListener {
    /// Creates a new, stateless listener.
    ///
    /// # Errors
    ///
    /// This constructor currently always returns `Ok`. The `Result` keeps the signature
    /// aligned with the other `ApiListener` constructors.
    pub const fn new() -> Result<Self, ApiError> {
        let listener = Self { _private: () };
        Ok(listener)
    }
}
impl ApiListener for WindowsApiListener {
    type Stream = WindowsApiStream;

    /// Consumes the listener and starts receiving IP-interface change notifications.
    fn into_stream(self) -> Self::Stream {
        Self::Stream::new()
    }
}
/// Stream of IP-interface change events produced by [`WindowsApiListener`].
///
/// Yields `Ok(())` for every notification forwarded from the OS callback. After an `Err`
/// item has been yielded, the stream returns `None` on every subsequent poll.
pub struct WindowsApiStream {
    /// Receives items forwarded by the bridge thread (or a queued registration error).
    receiver: tokio_mpsc::UnboundedReceiver<Result<(), ApiError>>,
    // Held for its `Drop` side effect: dropping the stream drops the `NotificationHandle`,
    // which tears down the OS registration. Fields drop in declaration order, so `receiver`
    // is dropped before this.
    // NOTE(review): the manual `Debug` impl reads this field, so the allow may be redundant.
    #[allow(dead_code)]
    handle: Option<NotificationHandle>,
    /// When `true`, `poll_next` returns `Ready(None)` without touching the channel.
    terminated: bool,
}
impl std::fmt::Debug for WindowsApiStream {
    // Reports only the observable state. The channel receiver and raw OS handle are
    // omitted, which `finish_non_exhaustive` marks with `..`.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let has_handle = self.handle.is_some();
        let mut builder = formatter.debug_struct("WindowsApiStream");
        builder.field("terminated", &self.terminated);
        builder.field("has_handle", &has_handle);
        builder.finish_non_exhaustive()
    }
}
/// Owns an active `NotifyIpInterfaceChange` registration together with the boxed
/// `CallbackContext` that was handed to the OS for it.
///
/// Its `Drop` impl is responsible for cancelling the registration and releasing the context.
struct NotificationHandle {
    /// Registration handle returned by `NotifyIpInterfaceChange`; passed to
    /// `CancelMibChangeNotify2` when dropped.
    handle: HANDLE,
    /// Pointer produced by `Box::into_raw` in `register_notification`. It is released by the
    /// `Drop` impl once the registration has been cancelled.
    context_ptr: *mut CallbackContext,
}
impl Drop for NotificationHandle {
fn drop(&mut self) {
let _ = unsafe { CancelMibChangeNotify2(self.handle) };
drop(unsafe { Box::from_raw(self.context_ptr) });
}
}
// SAFETY: the fields are an OS registration handle and a pointer to a heap-allocated
// `CallbackContext`. Within this module they are only touched when the value is dropped,
// to cancel the registration and reclaim the box. That can happen on whichever thread owns
// the value. The pointee holds an `mpsc::Sender<()>`, which is itself `Send`.
unsafe impl Send for NotificationHandle {}
/// State passed to the OS as the caller context of `NotifyIpInterfaceChange`.
///
/// The callback only borrows it immutably and signals through `sender`. It presumably runs
/// on OS-owned threads, which is why events go through a channel.
struct CallbackContext {
    /// Sends one `()` per notification to the bridge thread spawned in
    /// `WindowsApiStream::new`.
    sender: mpsc::Sender<()>,
}
impl WindowsApiStream {
    /// Registers for IP-interface change notifications and builds the stream.
    ///
    /// The OS callback pushes into a `std::sync::mpsc` channel. A dedicated bridge thread
    /// forwards each event as `Ok(())` into the tokio channel polled by
    /// [`Stream::poll_next`].
    ///
    /// If registration fails, the error is queued on the channel and `terminated` is left
    /// `false`. The stream therefore yields that error once and then ends. (Previously
    /// `terminated` was set to `true`, so the error was silently dropped.)
    fn new() -> Self {
        let (sync_tx, sync_rx) = mpsc::channel::<()>();
        let (async_tx, async_rx) = tokio_mpsc::unbounded_channel();

        let handle = match register_notification(sync_tx) {
            Ok((os_handle, context_ptr)) => {
                // The bridge thread exits when the callback's sender is dropped (the
                // `NotificationHandle` releases the context) or when the stream's receiver
                // has been dropped.
                std::thread::spawn(move || {
                    while sync_rx.recv().is_ok() {
                        if async_tx.send(Ok(())).is_err() {
                            break;
                        }
                    }
                });
                Some(NotificationHandle {
                    handle: os_handle,
                    context_ptr,
                })
            }
            Err(e) => {
                // Queue the error for `poll_next`. `async_tx` is dropped when `new` returns,
                // which closes the channel after the error has been read.
                let _ = async_tx.send(Err(e));
                None
            }
        };

        Self {
            receiver: async_rx,
            handle,
            terminated: false,
        }
    }
}
impl Stream for WindowsApiStream {
    type Item = Result<(), ApiError>;

    /// Yields `Ok(())` for every forwarded notification.
    ///
    /// The first terminal item is yielded once, after which every poll returns
    /// `Ready(None)`. A terminal item is either an error read from the channel or
    /// `ApiError::Stopped` when the channel closes.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        if self.terminated {
            return Poll::Ready(None);
        }
        let final_item = match std::task::ready!(self.receiver.poll_recv(cx)) {
            Some(Ok(())) => return Poll::Ready(Some(Ok(()))),
            Some(Err(err)) => Err(err),
            None => Err(ApiError::Stopped),
        };
        self.terminated = true;
        Poll::Ready(Some(final_item))
    }
}
#[cfg(not(tarpaulin_include))]
fn register_notification(
sender: mpsc::Sender<()>,
) -> Result<(HANDLE, *mut CallbackContext), ApiError> {
let context_ptr = Box::into_raw(Box::new(CallbackContext { sender }));
let void_ptr = context_ptr.cast::<std::ffi::c_void>();
let mut handle = HANDLE::default();
let result = unsafe {
NotifyIpInterfaceChange(
AF_UNSPEC,
Some(ip_interface_change_callback),
Some(void_ptr),
false, &raw mut handle,
)
};
if result != NO_ERROR {
drop(unsafe { Box::from_raw(context_ptr) });
return Err(windows::core::Error::from(WIN32_ERROR(result.0)).into());
}
Ok((handle, context_ptr))
}
/// OS callback for IP-interface changes: signals one `()` through the context's sender.
///
/// The row and notification type are ignored. A null context is tolerated and ignored,
/// and send failures (the receiver has gone away) are discarded.
#[cfg(not(tarpaulin_include))]
unsafe extern "system" fn ip_interface_change_callback(
    caller_context: *const std::ffi::c_void,
    _row: *const MIB_IPINTERFACE_ROW,
    _notification_type: MIB_NOTIFICATION_TYPE,
) {
    // SAFETY: a non-null `caller_context` is the `CallbackContext` pointer registered in
    // `register_notification`. It stays alive until the registration is cancelled, and
    // it is only borrowed immutably here.
    let Some(context) = (unsafe { caller_context.cast::<CallbackContext>().as_ref() }) else {
        return;
    };
    let _ = context.sender.send(());
}