#![doc = include_str!("../README.md")]
pub mod util;
use std::{
any::Any,
cell::Cell,
future::Future,
mem::{ManuallyDrop, MaybeUninit},
panic::{self, AssertUnwindSafe},
pin::{pin, Pin},
ptr::NonNull,
task::{Context, Poll, Waker},
};
use async_task::Runnable;
use util::{Window, WindowType};
use windows::Win32::{
Foundation::{LPARAM, LRESULT, WPARAM},
UI::WindowsAndMessaging::*,
};
use crate::util::MsgFilterHook;
/// Private message posted to the executor window to run a task; its `lparam` carries the raw
/// pointer produced by `Runnable::into_raw`.
const MSG_ID_WAKE: u32 = WM_USER;
thread_local! {
    /// Panic payload caught on this thread (by the executor window or the message filter hook)
    /// and waiting for the running message loop to re-raise it with `resume_unwind`.
    static PANIC_PAYLOAD: Cell<Option<Box<dyn Any + Send + 'static>>> = const { Cell::new(None) };

    /// Hidden message-only window that runs tasks when it receives `MSG_ID_WAKE`.
    ///
    /// The `lparam` of each wake message is turned back into a `Runnable` and run here. A panic
    /// raised by the task is stored in `PANIC_PAYLOAD` instead of unwinding through the window
    /// procedure.
    static EXECUTOR_WINDOW: Window<()> = Window::new(WindowType::MessageOnly, (), |_, msg| {
        if msg.msg != MSG_ID_WAKE {
            // Not a wake message; presumably `Window` falls back to its default handling.
            return None;
        }
        // SAFETY: assumes every `MSG_ID_WAKE` reaching this window was posted by the executor's
        // schedule function, so `lparam` is a live, non-null pointer from `Runnable::into_raw`
        // that is consumed exactly once here.
        // NOTE(review): nothing stops other code from posting `WM_USER` to this window.
        let runnable =
            unsafe { Runnable::<()>::from_raw(NonNull::new_unchecked(msg.lparam.0 as *mut _)) };
        let outcome = panic::catch_unwind(|| runnable.run());
        if let Err(panic_payload) = outcome {
            PANIC_PAYLOAD.set(Some(panic_payload));
        }
        Some(LRESULT(0))
    })
    .unwrap();
}
/// Handle to a task spawned on the current thread's executor.
///
/// Awaiting the handle yields the task's output. Dropping the handle detaches the task
/// (`async_task::Task::detach`) instead of cancelling it.
pub struct JoinHandle<T> {
    // `ManuallyDrop` lets `Drop` move the task out so it can call `detach` (by value).
    task: ManuallyDrop<async_task::Task<T>>,
}
impl<T> Drop for JoinHandle<T> {
    /// Detaches the task so it keeps running in the background instead of being cancelled.
    fn drop(&mut self) {
        // SAFETY: `self.task` is taken exactly once, here, and never accessed afterwards.
        unsafe { ManuallyDrop::take(&mut self.task) }.detach();
    }
}
impl<T> Future for JoinHandle<T> {
type Output = T;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
pin!(&mut *self.task).poll(cx)
}
}
unsafe fn spawn_unchecked_lifetime<T>(future: impl Future<Output = T>) -> JoinHandle<T> {
let hwnd = EXECUTOR_WINDOW.with(|w| w.hwnd());
let (runnable, task) = unsafe {
async_task::spawn_unchecked(future, move |runnable: Runnable| {
let _ = PostMessageW(
Some(hwnd),
MSG_ID_WAKE,
WPARAM(0),
LPARAM(runnable.into_raw().as_ptr() as _),
);
})
};
runnable.schedule();
JoinHandle {
task: ManuallyDrop::new(task),
}
}
/// Spawns `future` as a task on the current thread's executor and returns a handle to it.
///
/// The task first runs when a message pump on this thread (such as [`block_on`] or
/// [`MessageLoop::run`]) dispatches its wake message. Dropping the returned [`JoinHandle`]
/// detaches the task rather than cancelling it.
pub fn spawn_local<T>(future: impl Future<Output = T> + 'static) -> JoinHandle<T>
where
    T: 'static,
{
    // SAFETY: the future is `'static`, so no borrow it holds can end while the task is alive.
    unsafe { spawn_unchecked_lifetime(future) }
}
/// Runs `future` to completion on the current thread, pumping Win32 messages while waiting.
///
/// The future is spawned as a task on this thread's executor. A fresh [`MessageLoop`] then runs
/// (forwarding every message) until the task completes and quits that loop. Other tasks and
/// window messages are dispatched in the meantime, and calls may be nested.
///
/// # Panics
///
/// Panics with "received unexpected quit message" if the loop stops because `WM_QUIT` was
/// received before the future finished. Also re-raises panics recorded by tasks or by the
/// message filter hook while the loop runs.
///
/// NOTE(review): on either early-exit path the `JoinHandle` is dropped, which *detaches* the
/// still-pending task instead of dropping its future. That task borrows `msg_loop` and data of
/// lifetime `'a`, both of which end when this function returns or unwinds. If the detached task
/// is later polled or dropped by the executor, it would access freed memory, violating
/// `spawn_unchecked_lifetime`'s safety contract. This likely needs either a cancellation that
/// completes (future dropped) before returning, or an abort. TODO confirm and fix.
pub fn block_on<'a, T: 'a>(future: impl Future<Output = T> + 'a) -> T {
    // Borrowed by the task below; the task calls `quit` on it once `future` has resolved.
    let msg_loop = &MessageLoop::new();
    let task = unsafe {
        spawn_unchecked_lifetime(async move {
            let result = future.await;
            msg_loop.quit();
            result
        })
    };
    msg_loop.run_loop(|_| FilterResult::Forward);
    // In the normal path the task has already completed, so a single no-op poll is Ready.
    poll_ready(task).expect("received unexpected quit message")
}
/// Polls `future` exactly once with a no-op waker.
///
/// Returns `Ok` with the output if the future is already complete. Returns `Err(())` if it is
/// still pending; the future is then dropped when this function returns.
fn poll_ready<T>(future: impl Future<Output = T>) -> Result<T, ()> {
    let mut cx = Context::from_waker(Waker::noop());
    let pinned = pin!(future);
    if let Poll::Ready(output) = pinned.poll(&mut cx) {
        Ok(output)
    } else {
        Err(())
    }
}
/// Verdict a message filter returns for each message it is shown.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FilterResult {
    /// Let the message through. The main loop translates and dispatches it, and the filter
    /// hook reports it as not handled (`false`).
    Forward,
    /// Swallow the message. The main loop does not dispatch it, and the filter hook reports it
    /// as handled (`true`), presumably telling the modal loop to skip it. Confirm against
    /// `MsgFilterHook`.
    Drop,
}
/// A Win32 message pump for the current thread that also drives tasks spawned on its executor.
///
/// Instances are created internally by [`block_on`] and [`MessageLoop::run`]. Filters receive a
/// reference so they can request shutdown.
pub struct MessageLoop {
    // Set by `quit`; `run_loop` checks it before fetching each message.
    quit: Cell<bool>,
}
impl MessageLoop {
fn new() -> Self {
Self {
quit: Cell::new(false),
}
}
fn run_loop(&self, filter: impl Fn(&MSG) -> FilterResult) {
let executor_hwnd = EXECUTOR_WINDOW.with(|ew| ew.hwnd());
while !self.quit.get() {
unsafe {
let mut msg = MaybeUninit::uninit();
if GetMessageW(msg.as_mut_ptr(), None, 0, 0).0 == 0 {
return;
}
let msg = msg.assume_init();
let is_wake_message = msg.hwnd == executor_hwnd && msg.message == MSG_ID_WAKE;
if is_wake_message || filter(&msg) == FilterResult::Forward {
let _ = TranslateMessage(&msg);
DispatchMessageW(&msg);
}
if let Some(panic_payload) = PANIC_PAYLOAD.take() {
panic::resume_unwind(panic_payload)
}
}
}
}
pub fn run(filter: impl Fn(&MessageLoop, &MSG) -> FilterResult) {
let msg_loop = MessageLoop::new();
let _hook = unsafe {
MsgFilterHook::register(|msg| {
panic::catch_unwind(AssertUnwindSafe(|| {
let filter_result = filter(&msg_loop, msg);
if msg_loop.quit.get() {
let _ = PostMessageW(Some(msg.hwnd), WM_QUIT, WPARAM(0), LPARAM(0));
}
filter_result == FilterResult::Drop
}))
.unwrap_or_else(|payload| {
PANIC_PAYLOAD.with(|panic_payload| {
panic_payload.set(Some(payload));
});
let _ = PostMessageW(Some(msg.hwnd), WM_QUIT, WPARAM(0), LPARAM(0));
false
})
})
};
msg_loop.run_loop(|msg| filter(&msg_loop, msg));
}
pub fn quit(&self) {
self.quit.set(true);
}
pub fn quit_when_idle(&self) {
unsafe { PostQuitMessage(0) };
}
}
#[cfg(test)]
mod test {
    use std::future::poll_fn;
    use windows::core::{w, PCWSTR};
    use windows::Win32::Foundation::HWND;
    use super::*;

    /// Posts `msg` (with zero `wparam`/`lparam`) to the current thread's message queue.
    fn post_thread_message(msg: u32) {
        let _ = unsafe { PostMessageW(None, msg, WPARAM(0), LPARAM(0)) };
    }

    // A panic raised by the filter must propagate out of `MessageLoop::run`.
    #[test]
    #[should_panic]
    fn panic_in_dispatcher() {
        post_thread_message(WM_USER);
        MessageLoop::run(|_, _| panic!());
    }

    // `quit` stops the loop right after the first message; the filter must never see the others.
    #[test]
    fn message_loop_quit() {
        for i in 0..10 {
            post_thread_message(WM_USER + i);
        }
        MessageLoop::run(|msg_loop, msg| {
            assert_eq!(msg.message, WM_USER);
            msg_loop.quit();
            FilterResult::Drop
        });
    }

    // `quit_when_idle` still lets every already-posted message through, in order, before stopping.
    #[test]
    fn message_loop_quit_when_idle() {
        for i in 0..10 {
            post_thread_message(WM_USER + i);
        }
        let expected_msg = Cell::new(0);
        MessageLoop::run(|msg_loop, msg| {
            assert_eq!(msg.message, WM_USER + expected_msg.get());
            expected_msg.set(expected_msg.get() + 1);
            msg_loop.quit_when_idle();
            FilterResult::Drop
        });
        assert_eq!(expected_msg.get(), 10);
    }

    // Nested `block_on` calls on the same thread run to completion in program order.
    #[test]
    fn nested_block_on() {
        let count: Cell<usize> = Cell::new(0);
        block_on(async {
            assert_eq!(count.get(), 0);
            count.set(count.get() + 1);
            block_on(async {
                assert_eq!(count.get(), 1);
                count.set(count.get() + 1);
            });
            assert_eq!(count.get(), 2);
            count.set(count.get() + 1);
        });
        assert_eq!(count.get(), 3);
    }

    // Starting `MessageLoop::run` from inside another `run`'s filter is expected to panic
    // (presumably when registering a second `MsgFilterHook`; confirm in `util`).
    #[test]
    #[should_panic]
    fn nested_message_loop() {
        post_thread_message(WM_USER);
        MessageLoop::run(|_, _| {
            MessageLoop::run(|_, _| FilterResult::Drop);
            FilterResult::Drop
        });
    }

    /// Returns `Pending` once after waking itself, forcing a trip through the executor.
    async fn yield_now() {
        let mut yielded = false;
        poll_fn(|cx| {
            if yielded {
                Poll::Ready(())
            } else {
                yielded = true;
                cx.waker().wake_by_ref();
                Poll::Pending
            }
        })
        .await;
    }

    // `block_on` called from inside a `MessageLoop::run` filter drives its future to completion.
    #[test]
    fn nested_message_loop_block_on() {
        let inner_executed = Cell::new(false);
        post_thread_message(WM_USER);
        MessageLoop::run(|msg_loop, _| {
            block_on(async {
                inner_executed.set(true);
            });
            msg_loop.quit();
            FilterResult::Forward
        });
        assert!(inner_executed.get());
    }

    // Calling the outer loop's `quit` from inside a nested `block_on` still lets `run` return.
    #[test]
    fn nested_message_loop_block_on_quit() {
        post_thread_message(WM_USER);
        MessageLoop::run(|msg_loop, _| {
            block_on(async {
                msg_loop.quit();
            });
            FilterResult::Forward
        });
    }

    /// Finds a top-level window by title, returning a null `HWND` if there is none.
    fn window_by_name(name: PCWSTR) -> HWND {
        unsafe { FindWindowW(None, name) }.unwrap_or_default()
    }

    // A task spawned before `block_on` keeps running while the `block_on` future is blocked in a
    // modal message box, and it is that task which closes the message box.
    #[test]
    fn running_spawned_with_modal_dialog() {
        let window_name = w!("running_spawned_with_modal_dialog");
        let task = spawn_local(async move {
            while window_by_name(window_name).0.is_null() {
                yield_now().await;
            }
            for _ in 0..10 {
                yield_now().await;
            }
            unsafe {
                SendMessageW(window_by_name(window_name), WM_CLOSE, Some(WPARAM(0)), Some(LPARAM(0)));
            }
        });
        block_on(async {
            unsafe {
                MessageBoxW(
                    None,
                    PCWSTR::null(),
                    window_name,
                    MESSAGEBOX_STYLE(0),
                );
            }
            task.await;
        });
    }

    // The filter is reentered from the message box's modal loop (via the filter hook). The
    // reentrancy assertion fails there, and that panic must surface from `run`.
    #[test]
    #[should_panic]
    fn reenter_filter_closure_panic() {
        let window_name = w!("reenter_filter_closure");
        post_thread_message(WM_USER);
        let running_filter_closure = Cell::new(false);
        MessageLoop::run(|_, msg| {
            assert!(
                !running_filter_closure.replace(true),
                "Filter closure reentered"
            );
            if msg.hwnd.0.is_null() && msg.message == WM_USER {
                unsafe {
                    MessageBoxW(
                        None,
                        PCWSTR::null(),
                        window_name,
                        MESSAGEBOX_STYLE(0),
                    );
                }
            }
            running_filter_closure.set(false);
            FilterResult::Forward
        });
    }

    // On reentry from the modal loop the filter calls `quit`; the message box and `run` must
    // then both exit.
    #[test]
    fn reenter_filter_closure_quit() {
        let window_name = w!("reenter_filter_closure");
        post_thread_message(WM_USER);
        let running_filter_closure = Cell::new(false);
        MessageLoop::run(|msg_loop, msg| {
            if running_filter_closure.replace(true) {
                msg_loop.quit();
            }
            if msg.hwnd.0.is_null() && msg.message == WM_USER {
                unsafe {
                    MessageBoxW(
                        None,
                        PCWSTR::null(),
                        window_name,
                        MESSAGEBOX_STYLE(0),
                    );
                }
            }
            running_filter_closure.set(false);
            FilterResult::Forward
        });
    }

    // One task blocks in a message box. A second task, running inside that modal loop, posts
    // thread messages that the filter must see in order, and then it closes the box.
    #[test]
    fn message_loop_with_modal_dialog() {
        let window_name = w!("message_loop_with_modal_dialog");
        spawn_local(async move {
            unsafe {
                MessageBoxW(
                    None,
                    PCWSTR::null(),
                    window_name,
                    MESSAGEBOX_STYLE(0),
                );
            }
        });
        spawn_local(async move {
            assert!(!window_by_name(window_name).0.is_null());
            for i in 0..10 {
                post_thread_message(WM_USER + i);
                yield_now().await;
            }
            unsafe { SendMessageW(window_by_name(window_name), WM_CLOSE, Some(WPARAM(0)), Some(LPARAM(0))) };
        });
        let expected_msg = Cell::new(0);
        MessageLoop::run(|msg_loop, msg| {
            if msg.hwnd.0.is_null() && msg.message >= WM_USER {
                assert_eq!(msg.message, WM_USER + expected_msg.get());
                expected_msg.set(expected_msg.get() + 1);
                msg_loop.quit_when_idle();
                FilterResult::Drop
            } else {
                FilterResult::Forward
            }
        });
        assert_eq!(expected_msg.get(), 10);
    }

    // Same as `reenter_filter_closure_quit`, but stopping via `quit_when_idle`.
    #[test]
    fn reenter_filter_closure_quit_when_idle() {
        let window_name = w!("reenter_filter_closure");
        post_thread_message(WM_USER);
        let running_filter_closure = Cell::new(false);
        MessageLoop::run(|msg_loop, msg| {
            if running_filter_closure.replace(true) {
                msg_loop.quit_when_idle();
            }
            if msg.hwnd.0.is_null() && msg.message == WM_USER {
                unsafe {
                    MessageBoxW(
                        None,
                        PCWSTR::null(),
                        window_name,
                        MESSAGEBOX_STYLE(0),
                    );
                }
            }
            running_filter_closure.set(false);
            FilterResult::Forward
        });
    }

    // Wake messages aimed at the executor window bypass the filter. `MSG_ID_WAKE` posted to any
    // other window is still shown to the filter (and dropped here, so `custom_wnd` never sees it).
    #[test]
    fn disallow_wake_message_filtering() {
        let msg_loop = MessageLoop::new();
        // Leaked so the `'static` task spawned below can borrow it.
        let msg_loop = Box::leak(Box::new(msg_loop));
        let custom_wnd = Window::new(WindowType::MessageOnly, (), |_, msg| {
            assert_ne!(msg.msg, MSG_ID_WAKE);
            None
        })
        .unwrap();
        unsafe {
            let _ = PostMessageW(Some(custom_wnd.hwnd()), MSG_ID_WAKE, WPARAM(0), LPARAM(0));
        }
        spawn_local(async {
            yield_now().await;
            yield_now().await;
            yield_now().await;
            msg_loop.quit();
        });
        msg_loop.run_loop(|msg| {
            if msg.message == MSG_ID_WAKE {
                assert_ne!(msg.hwnd, EXECUTOR_WINDOW.with(|ew| ew.hwnd()));
                FilterResult::Drop
            } else {
                FilterResult::Forward
            }
        });
    }
}