#![doc = include_str!("../README.md")]
pub mod util;
use std::{
any::Any,
cell::Cell,
future::Future,
mem::{ManuallyDrop, MaybeUninit},
panic::{self, AssertUnwindSafe},
pin::{pin, Pin},
ptr::{self, NonNull},
task::{Context, Poll, Waker},
};
use async_task::Runnable;
use util::{Window, WindowType};
use windows_sys::Win32::UI::WindowsAndMessaging::*;
use crate::util::MsgFilterHook;
/// Message ID used to hand a scheduled [`Runnable`] to the executor window.
///
/// The raw runnable pointer travels in `lParam`; `wParam` is unused (always 0).
const MSG_ID_WAKE: u32 = WM_USER;

thread_local! {
    /// Panic captured on this thread while a task (or a modal-loop filter) ran,
    /// parked here until the message loop re-raises it with `resume_unwind`.
    static PANIC_PAYLOAD: Cell<Option<Box<dyn Any + Send + 'static>>> = const { Cell::new(None) };

    /// Per-thread message-only window whose window procedure runs scheduled tasks.
    ///
    /// Because tasks are driven from a window procedure, they make progress inside
    /// any loop that dispatches this thread's messages, modal loops included.
    static EXECUTOR_WINDOW: Window<()> = Window::new(WindowType::MessageOnly, (), |_, msg| {
        if msg.msg != MSG_ID_WAKE {
            // Not a wake-up; `None` presumably defers to default handling
            // (see `util::Window`).
            return None;
        }
        // SAFETY: in this file only the scheduler in `spawn_unchecked_lifetime`
        // posts `MSG_ID_WAKE` to this window, always with a non-null pointer from
        // `Runnable::into_raw` in `lParam`; ownership is reclaimed exactly once here.
        let runnable = unsafe {
            Runnable::<()>::from_raw(NonNull::new_unchecked(msg.lparam as *mut ()))
        };
        // Do not let a task's panic unwind through the Win32 dispatch code; park it
        // so the message loop can re-raise it on its own stack.
        match panic::catch_unwind(|| runnable.run()) {
            Ok(_) => {}
            Err(payload) => PANIC_PAYLOAD.set(Some(payload)),
        }
        Some(0)
    })
    .unwrap();
}
pub struct JoinHandle<T> {
task: ManuallyDrop<async_task::Task<T>>,
}
impl<T> Drop for JoinHandle<T> {
fn drop(&mut self) {
let task = unsafe { ManuallyDrop::take(&mut self.task) };
task.detach();
}
}
impl<T> Future for JoinHandle<T> {
type Output = T;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
pin!(&mut *self.task).poll(cx)
}
}
/// Spawns `future` on the current thread's executor without requiring it to be
/// `'static`, and schedules its first poll.
///
/// Every wake-up posts the task's [`Runnable`] to this thread's executor window as
/// a `MSG_ID_WAKE` message. The task is therefore only polled by the window
/// procedure on this thread, even when its waker is invoked from another thread.
///
/// # Safety
///
/// Everything `future` borrows must stay valid for as long as the executor may
/// still poll or drop the future. Dropping the returned [`JoinHandle`] only
/// *detaches* the task, so the caller must either keep the borrowed data alive
/// until the task has completed, or pass a future that borrows nothing.
unsafe fn spawn_unchecked_lifetime<T>(future: impl Future<Output = T>) -> JoinHandle<T> {
    let executor_hwnd = EXECUTOR_WINDOW.with(|window| window.hwnd());
    let schedule = move |runnable: Runnable| {
        let raw_runnable = runnable.into_raw();
        // NOTE(review): the `PostMessageA` result is ignored. If posting fails, the
        // runnable is leaked rather than dropped here, which at least avoids
        // dropping a possibly `!Send` future on a foreign thread. Confirm intended.
        unsafe { PostMessageA(executor_hwnd, MSG_ID_WAKE, 0, raw_runnable.as_ptr() as _) };
    };
    // SAFETY: the lifetime obligation is forwarded to our caller (see `# Safety`).
    // The future need not be `Send` because `schedule` only posts a message to this
    // thread's window, which runs the task on this thread.
    let (runnable, task) = unsafe { async_task::spawn_unchecked(future, schedule) };
    runnable.schedule();
    JoinHandle {
        task: ManuallyDrop::new(task),
    }
}
/// Spawns a `'static` future onto the current thread's executor.
///
/// The future does not need to be `Send`. It is only polled on this thread, from
/// whichever message loop is dispatching the thread's messages (for example
/// [`MessageLoop::run`] or [`block_on`]). Dropping the returned [`JoinHandle`]
/// detaches the task instead of cancelling it.
pub fn spawn_local<T>(future: impl Future<Output = T> + 'static) -> JoinHandle<T>
where
    T: 'static,
{
    // SAFETY: both the future and its output are `'static`, so nothing they borrow
    // can be freed while the (possibly detached) task is still alive.
    unsafe { spawn_unchecked_lifetime(future) }
}
/// Runs `future` to completion on the current thread, pumping the thread's message
/// queue until it finishes, and returns its output.
///
/// The future is spawned on this thread's executor, so tasks created with
/// [`spawn_local`] and other window messages keep being processed while it runs.
/// Every message retrieved meanwhile is dispatched; none are filtered. Calls may be
/// nested.
///
/// # Panics
///
/// - Panics with "received unexpected quit message" if a `WM_QUIT` message (for
///   example from [`MessageLoop::quit_when_idle`]) ends the loop before the future
///   completes.
/// - Re-raises any panic from a task that ran while this call was pumping messages.
///
/// NOTE(review): on both panic paths the `JoinHandle` is dropped, which *detaches*
/// the still-pending task. Its future borrows `msg_loop` (a local of this frame) and
/// possibly `'a` data from the caller. If the executor later polls or drops that
/// future, it touches dangling memory. Consider cancelling and draining the task, or
/// aborting, on these paths — TODO confirm intended behavior.
pub fn block_on<'a, T: 'a>(future: impl Future<Output = T> + 'a) -> T {
    // Temporary lifetime extension keeps this loop alive for the whole call, so the
    // task below may borrow it.
    let msg_loop = &MessageLoop::new();
    // SAFETY (see NOTE above): the task borrows `msg_loop` and `'a` data. This is
    // only sound if the task has completed before this frame is left.
    let task = unsafe {
        spawn_unchecked_lifetime(async move {
            let result = future.await;
            // Stop the loop below now that the output is ready.
            msg_loop.quit();
            result
        })
    };
    msg_loop.run_loop(|_| FilterResult::Forward);
    // `quit` is only called after `result` is produced, inside the same task run.
    // So if the loop ended through `quit`, a single poll yields the output. Any
    // other exit means `GetMessageA` retrieved `WM_QUIT`.
    poll_ready(task).expect("received unexpected quit message")
}
/// Polls `future` exactly once with a no-op waker.
///
/// Returns `Ok(output)` if the future completed on that single poll, or `Err(())`
/// if it is still pending. The future is dropped before returning in either case.
fn poll_ready<T>(future: impl Future<Output = T>) -> Result<T, ()> {
    let mut cx = Context::from_waker(Waker::noop());
    let mut pinned = pin!(future);
    match pinned.as_mut().poll(&mut cx) {
        Poll::Ready(output) => Ok(output),
        Poll::Pending => Err(()),
    }
}
/// What a message-loop filter wants done with the message it was shown.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum FilterResult {
    /// Let the message through: it is translated and dispatched as usual.
    Forward,
    /// Swallow the message: it is not dispatched.
    Drop,
}
/// A Win32 message loop running on the current thread, with a caller-supplied
/// filter that can inspect messages and drop them before they are dispatched.
pub struct MessageLoop {
    /// Set by [`MessageLoop::quit`]; checked before each message is fetched.
    quit: Cell<bool>,
}

impl MessageLoop {
    /// Creates a loop whose quit flag is not set.
    fn new() -> Self {
        Self {
            quit: Cell::new(false),
        }
    }

    /// Pumps the calling thread's message queue until the quit flag is set or a
    /// `WM_QUIT` message is retrieved.
    ///
    /// Wake messages addressed to the executor window are always dispatched, so
    /// `filter` cannot starve spawned tasks. Every other message is translated and
    /// dispatched only if `filter` returns [`FilterResult::Forward`].
    ///
    /// # Panics
    ///
    /// Re-raises any panic parked in `PANIC_PAYLOAD` while a message was handled.
    /// Also panics if `GetMessageA` reports an error.
    fn run_loop(&self, filter: impl Fn(&MSG) -> FilterResult) {
        let executor_hwnd = EXECUTOR_WINDOW.with(|ew| ew.hwnd());
        while !self.quit.get() {
            let mut msg = MaybeUninit::uninit();
            // SAFETY: `msg` is a writable buffer large enough for a `MSG`.
            let status = unsafe { GetMessageA(msg.as_mut_ptr(), ptr::null_mut(), 0, 0) };
            match status {
                // `WM_QUIT` was retrieved.
                0 => return,
                // On failure `GetMessageA` returns -1 and leaves `msg` unwritten, so it
                // must not be read. The error code is captured before anything else
                // can overwrite it.
                -1 => panic!("GetMessageA failed: {}", std::io::Error::last_os_error()),
                _ => {}
            }
            // SAFETY: any other return value means `GetMessageA` filled in `msg`.
            let msg = unsafe { msg.assume_init() };
            let is_wake_message = msg.hwnd == executor_hwnd && msg.message == MSG_ID_WAKE;
            if is_wake_message || filter(&msg) == FilterResult::Forward {
                // SAFETY: `msg` is a fully initialized message retrieved above.
                unsafe {
                    TranslateMessage(&msg);
                    DispatchMessageA(&msg);
                }
            }
            // A task, or a filter invoked from a modal loop, may have panicked while
            // this message was handled; surface that panic on this loop's stack.
            if let Some(panic_payload) = PANIC_PAYLOAD.take() {
                panic::resume_unwind(panic_payload)
            }
        }
    }

    /// Runs a message loop on the current thread until [`quit`](Self::quit) is
    /// called or a `WM_QUIT` message ends it (for example after
    /// [`quit_when_idle`](Self::quit_when_idle)).
    ///
    /// `filter` is called with this loop and each message. Returning
    /// [`FilterResult::Drop`] discards the message; [`FilterResult::Forward`]
    /// dispatches it. While the loop runs, a `util::MsgFilterHook` is registered so
    /// that `filter` also sees messages pumped by modal loops (for example message
    /// boxes) entered from inside this loop. Whenever `filter` is invoked that way
    /// while the quit flag is set, `WM_QUIT` is posted to the message's window to end
    /// the modal loop.
    ///
    /// # Panics
    ///
    /// Propagates panics from `filter` and from tasks run by the loop. A panic in
    /// `filter` during a modal loop is captured and the modal loop is ended with
    /// `WM_QUIT`. The panic is then re-raised once control is back in this loop.
    pub fn run(filter: impl Fn(&MessageLoop, &MSG) -> FilterResult) {
        let msg_loop = MessageLoop::new();
        // NOTE(review): presumably `MsgFilterHook::register` requires the hook to be
        // unregistered (`_hook` dropped) before the closure's borrows end, which
        // holds here because `_hook` lives exactly as long as this call. Confirm in
        // `util`.
        let _hook = unsafe {
            MsgFilterHook::register(|msg| {
                // The hook is presumably invoked from a Win32 hook procedure, which
                // must not be unwound through, so panics are caught here.
                panic::catch_unwind(AssertUnwindSafe(|| {
                    let filter_result = filter(&msg_loop, msg);
                    if msg_loop.quit.get() {
                        // End the modal loop that invoked the hook.
                        PostMessageA(msg.hwnd, WM_QUIT, 0, 0);
                    }
                    // `true` tells the hook to swallow the message.
                    filter_result == FilterResult::Drop
                }))
                .unwrap_or_else(|payload| {
                    // Park the panic for `run_loop` and end the modal loop so that
                    // control gets back there.
                    PANIC_PAYLOAD.set(Some(payload));
                    PostMessageA(msg.hwnd, WM_QUIT, 0, 0);
                    false
                })
            })
        };
        msg_loop.run_loop(|msg| filter(&msg_loop, msg));
    }

    /// Requests that the loop stop.
    ///
    /// The loop finishes handling the current message and exits before fetching
    /// another one. See [`run`](Self::run) for how this interacts with modal loops.
    pub fn quit(&self) {
        self.quit.set(true);
    }

    /// Requests that the loop stop once the thread's already-queued messages have
    /// been handled, by posting `WM_QUIT` with `PostQuitMessage(0)`.
    ///
    /// This acts on the whole thread: whichever message loop on the thread retrieves
    /// `WM_QUIT` first (a modal loop, a [`block_on`] call, ...) sees it.
    pub fn quit_when_idle(&self) {
        // SAFETY: `PostQuitMessage` has no pointer arguments or preconditions.
        unsafe { PostQuitMessage(0) };
    }
}
// Tests for the executor and `MessageLoop`. They post real Win32 messages to the
// test thread's queue and open real message boxes, so they need Windows (and
// presumably an interactive desktop for `MessageBoxA` — TODO confirm on CI).
#[cfg(test)]
mod test {
    use std::{ffi::CStr, future::poll_fn};

    use windows_sys::Win32::Foundation::HWND;

    use super::*;

    /// Posts `msg` (zero `wParam`/`lParam`) to the calling thread's queue as a thread
    /// message, i.e. with a null `HWND`.
    fn post_thread_message(msg: u32) {
        unsafe { PostMessageA(ptr::null_mut(), msg, 0, 0) };
    }

    /// A panic raised by the filter propagates out of `MessageLoop::run`.
    #[test]
    #[should_panic]
    fn panic_in_dispatcher() {
        post_thread_message(WM_USER);
        MessageLoop::run(|_, _| panic!());
    }

    /// `quit` stops the loop after the current message: of ten queued messages, only
    /// the first reaches the filter.
    #[test]
    fn message_loop_quit() {
        for i in 0..10 {
            post_thread_message(WM_USER + i);
        }
        MessageLoop::run(|msg_loop, msg| {
            assert_eq!(msg.message, WM_USER);
            msg_loop.quit();
            FilterResult::Drop
        });
    }

    /// `quit_when_idle` lets every already-queued message reach the filter, in order,
    /// before the loop exits.
    #[test]
    fn message_loop_quit_when_idle() {
        for i in 0..10 {
            post_thread_message(WM_USER + i);
        }
        let expected_msg = Cell::new(0);
        MessageLoop::run(|msg_loop, msg| {
            assert_eq!(msg.message, WM_USER + expected_msg.get());
            expected_msg.set(expected_msg.get() + 1);
            msg_loop.quit_when_idle();
            FilterResult::Drop
        });
        assert_eq!(expected_msg.get(), 10);
    }

    /// `block_on` can be nested; the inner future runs to completion before the
    /// outer one resumes.
    #[test]
    fn nested_block_on() {
        let count: Cell<usize> = Cell::new(0);
        block_on(async {
            assert_eq!(count.get(), 0);
            count.set(count.get() + 1);
            block_on(async {
                assert_eq!(count.get(), 1);
                count.set(count.get() + 1);
            });
            assert_eq!(count.get(), 2);
            count.set(count.get() + 1);
        });
        assert_eq!(count.get(), 3);
    }

    /// Nesting `MessageLoop::run` inside another `run`'s filter is expected to panic.
    /// NOTE(review): presumably the panic comes from registering a second
    /// `MsgFilterHook` on the same thread — confirm in `util`.
    #[test]
    #[should_panic]
    fn nested_message_loop() {
        post_thread_message(WM_USER);
        MessageLoop::run(|_, _| {
            MessageLoop::run(|_, _| FilterResult::Drop);
            FilterResult::Drop
        });
    }

    /// Returns `Pending` exactly once after waking itself, so the calling task is
    /// rescheduled through the executor window before it continues.
    async fn yield_now() {
        let mut yielded = false;
        poll_fn(|cx| {
            if yielded {
                Poll::Ready(())
            } else {
                yielded = true;
                cx.waker().wake_by_ref();
                Poll::Pending
            }
        })
        .await;
    }

    /// `block_on` can be called from inside a `MessageLoop::run` filter.
    #[test]
    fn nested_message_loop_block_on() {
        let inner_executed = Cell::new(false);
        post_thread_message(WM_USER);
        MessageLoop::run(|msg_loop, _| {
            block_on(async {
                inner_executed.set(true);
            });
            msg_loop.quit();
            FilterResult::Forward
        });
        assert!(inner_executed.get());
    }

    /// Quitting the outer loop from inside a `block_on` future does not disturb the
    /// `block_on` call, and ends the outer loop once the filter returns.
    #[test]
    fn nested_message_loop_block_on_quit() {
        post_thread_message(WM_USER);
        MessageLoop::run(|msg_loop, _| {
            block_on(async {
                msg_loop.quit();
            });
            FilterResult::Forward
        });
    }

    /// Finds a top-level window by its exact title (any class); null if there is none.
    fn window_by_name(name: &CStr) -> HWND {
        unsafe { FindWindowA(ptr::null_mut(), name.as_ptr() as _) }
    }

    /// A task spawned with `spawn_local` keeps running while `block_on` is stuck in
    /// a modal `MessageBoxA`. The task closes the box, which lets `block_on` finish.
    #[test]
    fn running_spawned_with_modal_dialog() {
        let window_name = c"running_spawned_with_modal_dialog";
        let task = spawn_local(async {
            // Wait until the message box below has been created.
            while window_by_name(window_name).is_null() {
                yield_now().await;
            }
            // Keep running for a while inside the modal loop before closing it.
            for _ in 0..10 {
                yield_now().await;
            }
            unsafe {
                SendMessageA(window_by_name(window_name), WM_CLOSE, 0, 0);
            }
        });
        block_on(async {
            // Blocks in a modal loop until the task above closes the box.
            unsafe {
                MessageBoxA(
                    ptr::null_mut(),
                    ptr::null_mut(),
                    window_name.as_ptr() as _,
                    0,
                );
            }
            task.await;
        });
    }

    /// While the filter is blocked in a modal `MessageBoxA`, the hook re-enters it.
    /// The re-entrancy assertion panics there; the hook captures the panic and ends
    /// the modal loop, and `run` then re-raises the panic.
    #[test]
    #[should_panic]
    fn reenter_filter_closure_panic() {
        let window_name = c"reenter_filter_closure";
        post_thread_message(WM_USER);
        let running_filter_closure = Cell::new(false);
        MessageLoop::run(|_, msg| {
            assert!(
                !running_filter_closure.replace(true),
                "Filter closure reentered"
            );
            // Only the posted thread message opens the (modal) message box.
            if msg.hwnd.is_null() && msg.message == WM_USER {
                unsafe {
                    MessageBoxA(
                        ptr::null_mut(),
                        ptr::null_mut(),
                        window_name.as_ptr() as _,
                        0,
                    );
                }
            }
            running_filter_closure.set(false);
            FilterResult::Forward
        });
    }

    /// Calling `quit` from a re-entered filter ends the modal `MessageBoxA` (the
    /// hook posts `WM_QUIT`) and then the outer loop.
    #[test]
    fn reenter_filter_closure_quit() {
        let window_name = c"reenter_filter_closure";
        post_thread_message(WM_USER);
        let running_filter_closure = Cell::new(false);
        MessageLoop::run(|msg_loop, msg| {
            // Re-entered from the modal loop: ask everything to stop.
            if running_filter_closure.replace(true) {
                msg_loop.quit();
            }
            if msg.hwnd.is_null() && msg.message == WM_USER {
                unsafe {
                    MessageBoxA(
                        ptr::null_mut(),
                        ptr::null_mut(),
                        window_name.as_ptr() as _,
                        0,
                    );
                }
            }
            running_filter_closure.set(false);
            FilterResult::Forward
        });
    }

    /// The filter keeps seeing thread messages, in order, while a spawned task holds
    /// a modal `MessageBoxA` open. A second task feeds those messages and finally
    /// closes the box.
    #[test]
    fn message_loop_with_modal_dialog() {
        let window_name = c"message_loop_with_modal_dialog";
        spawn_local(async {
            unsafe {
                MessageBoxA(
                    ptr::null_mut(),
                    ptr::null_mut(),
                    window_name.as_ptr() as _,
                    0,
                );
            }
        });
        spawn_local(async {
            // This task only runs from inside the first task's modal loop, so the
            // box must already exist.
            assert!(!window_by_name(window_name).is_null());
            for i in 0..10 {
                post_thread_message(WM_USER + i);
                yield_now().await;
            }
            unsafe { SendMessageA(window_by_name(window_name), WM_CLOSE, 0, 0) };
        });
        let expected_msg = Cell::new(0);
        MessageLoop::run(|msg_loop, msg| {
            if msg.hwnd.is_null() && msg.message >= WM_USER {
                assert_eq!(msg.message, WM_USER + expected_msg.get());
                expected_msg.set(expected_msg.get() + 1);
                msg_loop.quit_when_idle();
                FilterResult::Drop
            } else {
                FilterResult::Forward
            }
        });
        assert_eq!(expected_msg.get(), 10);
    }

    /// Like `reenter_filter_closure_quit`, but using `quit_when_idle`.
    /// NOTE(review): the outer loop presumably exits because the modal loop re-posts
    /// the `WM_QUIT` it consumed — confirm.
    #[test]
    fn reenter_filter_closure_quit_when_idle() {
        let window_name = c"reenter_filter_closure";
        post_thread_message(WM_USER);
        let running_filter_closure = Cell::new(false);
        MessageLoop::run(|msg_loop, msg| {
            if running_filter_closure.replace(true) {
                msg_loop.quit_when_idle();
            }
            if msg.hwnd.is_null() && msg.message == WM_USER {
                unsafe {
                    MessageBoxA(
                        ptr::null_mut(),
                        ptr::null_mut(),
                        window_name.as_ptr() as _,
                        0,
                    );
                }
            }
            running_filter_closure.set(false);
            FilterResult::Forward
        });
    }

    /// Wake messages addressed to the executor window bypass the filter. A
    /// `MSG_ID_WAKE` sent to any other window is filtered like a normal message
    /// (dropped here, so `custom_wnd` never receives it).
    #[test]
    fn disallow_wake_message_filtering() {
        let msg_loop = MessageLoop::new();
        // Leaked so that the `'static` task spawned below can borrow it.
        let msg_loop = Box::leak(Box::new(msg_loop));
        let custom_wnd = Window::new(WindowType::MessageOnly, (), |_, msg| {
            assert_ne!(msg.msg, MSG_ID_WAKE);
            None
        })
        .unwrap();
        unsafe {
            PostMessageA(custom_wnd.hwnd(), MSG_ID_WAKE, 0, 0);
        }
        spawn_local(async {
            // Each yield sends a wake message to the executor window through the loop.
            yield_now().await;
            yield_now().await;
            yield_now().await;
            msg_loop.quit();
        });
        msg_loop.run_loop(|msg| {
            if msg.message == MSG_ID_WAKE {
                // Executor wake messages never reach the filter.
                assert_ne!(msg.hwnd, EXECUTOR_WINDOW.with(|ew| ew.hwnd()));
                FilterResult::Drop
            } else {
                FilterResult::Forward
            }
        });
    }
}