use std::{
thread::{ThreadId, current},
time::Duration,
};
use async_task::Runnable;
use flume::Sender;
use util::ResultExt;
use windows::{
System::Threading::{
ThreadPool, ThreadPoolTimer, TimerElapsedHandler, WorkItemHandler, WorkItemPriority,
},
Win32::{
Foundation::{LPARAM, WPARAM},
UI::WindowsAndMessaging::PostMessageW,
},
};
use crate::{
HWND, PlatformDispatcher, SafeHwnd, TaskLabel, WM_GPUI_TASK_DISPATCHED_ON_MAIN_THREAD,
};
/// Task dispatcher for Windows.
///
/// Holds what is needed to hand a `Runnable` to the thread that constructed
/// the dispatcher: a channel sender plus a window handle to which a
/// notification message is posted.
pub(crate) struct WindowsDispatcher {
/// Sending half of the channel that carries runnables to the main thread.
main_sender: Sender<Runnable>,
/// Id of the thread that constructed this dispatcher.
main_thread_id: ThreadId,
/// Window that receives the `WM_GPUI_TASK_DISPATCHED_ON_MAIN_THREAD` message.
platform_window_handle: SafeHwnd,
/// Value sent as the `WPARAM` of the posted message.
/// NOTE(review): presumably lets the receiver validate the message origin — confirm.
validation_number: usize,
}
impl WindowsDispatcher {
    /// Builds a dispatcher, recording the calling thread as the main thread.
    ///
    /// * `main_sender` - channel used to forward runnables to the main thread.
    /// * `platform_window_handle` - window that is notified when a runnable
    ///   has been queued for the main thread.
    /// * `validation_number` - value attached to every notification message.
    pub(crate) fn new(
        main_sender: Sender<Runnable>,
        platform_window_handle: HWND,
        validation_number: usize,
    ) -> Self {
        Self {
            main_sender,
            // Whichever thread calls `new` is the one remembered here.
            main_thread_id: current().id(),
            platform_window_handle: platform_window_handle.into(),
            validation_number,
        }
    }

    /// Submits `runnable` to the system thread pool as a high-priority work item.
    ///
    /// A failure to submit the work item is logged, not returned.
    fn dispatch_on_threadpool(&self, runnable: Runnable) {
        // The runnable is parked in an `Option` so the handler closure can
        // move it out by value when it fires.
        let mut pending = Some(runnable);
        let work_item = WorkItemHandler::new(move |_| {
            let task = pending.take().unwrap();
            task.run();
            Ok(())
        });
        ThreadPool::RunWithPriorityAsync(&work_item, WorkItemPriority::High).log_err();
    }

    /// Arms a thread-pool timer that runs `runnable` once `duration` has elapsed.
    ///
    /// A failure to create the timer is logged, not returned.
    fn dispatch_on_threadpool_after(&self, runnable: Runnable, duration: Duration) {
        // Same `Option` trick as above: the closure takes the runnable out
        // by value when the timer elapses.
        let mut pending = Some(runnable);
        let on_elapsed = TimerElapsedHandler::new(move |_| {
            let task = pending.take().unwrap();
            task.run();
            Ok(())
        });
        ThreadPoolTimer::CreateTimer(&on_elapsed, duration.into()).log_err();
    }
}
impl PlatformDispatcher for WindowsDispatcher {
    /// Returns `true` when called from the thread that created the dispatcher.
    fn is_main_thread(&self) -> bool {
        self.main_thread_id == current().id()
    }

    /// Runs `runnable` on the thread pool; a supplied `label` is only logged.
    fn dispatch(&self, runnable: Runnable, label: Option<TaskLabel>) {
        self.dispatch_on_threadpool(runnable);
        let Some(label) = label else {
            return;
        };
        log::debug!("TaskLabel: {label:?}");
    }

    /// Queues `runnable` for the main thread and posts a message to the
    /// platform window to signal that work is waiting.
    fn dispatch_on_main_thread(&self, runnable: Runnable) {
        if let Err(unsent) = self.main_sender.send(runnable) {
            // The runnable is deliberately leaked instead of dropped.
            // NOTE(review): presumably because it may wrap state that must not
            // be dropped on a non-main thread once the receiver is gone — confirm.
            std::mem::forget(unsent);
            return;
        }

        // SAFETY: NOTE(review): raw Win32 call; presumably sound for whatever
        // handle was supplied at construction — confirm.
        let posted = unsafe {
            PostMessageW(
                Some(self.platform_window_handle.as_raw()),
                WM_GPUI_TASK_DISPATCHED_ON_MAIN_THREAD,
                WPARAM(self.validation_number),
                LPARAM(0),
            )
        };
        posted.log_err();
    }

    /// Runs `runnable` on the thread pool after `duration` has elapsed.
    fn dispatch_after(&self, duration: Duration, runnable: Runnable) {
        self.dispatch_on_threadpool_after(runnable, duration);
    }
}