use std::sync::Arc;
use std::sync::mpsc;
use crate::task::TaskPriority;
/// Number of tasks the bounded cross-thread channel can buffer (1024 slots).
///
/// Passed to `mpsc::sync_channel` by `cross_thread_channel`.
const CHANNEL_CAPACITY: usize = 1 << 10;
/// A one-shot unit of work, tagged with a priority, that can be moved to
/// another thread and executed there.
pub struct CrossThreadTask {
    /// The work to perform; consumed by [`CrossThreadTask::run`].
    callback: Box<dyn FnOnce() + Send>,
    /// Priority attached at construction and reported by
    /// [`CrossThreadTask::priority`].
    priority: TaskPriority,
}

impl CrossThreadTask {
    /// Boxes `callback` into a task carrying `priority`.
    ///
    /// The callback must be `Send + 'static` because the task may be run on a
    /// different thread than the one that created it.
    #[inline]
    pub fn new(priority: TaskPriority, callback: impl FnOnce() + Send + 'static) -> Self {
        Self {
            callback: Box::new(callback),
            priority,
        }
    }

    /// Returns the priority this task was created with.
    #[inline]
    #[must_use]
    pub fn priority(&self) -> TaskPriority {
        self.priority
    }

    /// Consumes the task and invokes its callback.
    #[inline]
    pub fn run(self) {
        (self.callback)();
    }
}

// The boxed closure is opaque, so `Debug` cannot be derived; only the
// priority is printed and the callback is elided.
impl std::fmt::Debug for CrossThreadTask {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("CrossThreadTask")
            .field("priority", &self.priority)
            .finish_non_exhaustive()
    }
}
/// Sending half of the cross-thread task channel.
///
/// After every successful send the optional `notify` hook is invoked on the
/// sending thread. Cloning yields another handle to the same channel that
/// shares whatever hook was installed at the time of the clone.
#[derive(Clone)]
pub struct WakeSender {
    /// Bounded channel endpoint the tasks are pushed into.
    sender: mpsc::SyncSender<CrossThreadTask>,
    /// Hook called after a task has been queued; `None` until
    /// [`WakeSender::set_notify`] is called.
    notify: Option<Arc<dyn Fn() + Send + Sync>>,
}

impl WakeSender {
    /// Installs (or replaces) the hook that is called after each successful
    /// send.
    ///
    /// Only this handle and clones made from it afterwards see the new hook;
    /// clones created earlier keep whatever hook they already had.
    pub fn set_notify(&mut self, notify: Arc<dyn Fn() + Send + Sync>) {
        self.notify = Some(notify);
    }

    /// Queues `task`, waiting for a free slot if the channel is full, then
    /// calls the notify hook.
    ///
    /// # Errors
    ///
    /// Returns [`SendError::Disconnected`] if the receiver has been dropped;
    /// the hook is not called in that case.
    #[inline]
    pub fn send(&self, task: CrossThreadTask) -> Result<(), SendError> {
        self.sender
            .send(task)
            .map_err(|_| SendError::Disconnected)?;
        self.wake();
        Ok(())
    }

    /// Queues `task` without blocking, then calls the notify hook.
    ///
    /// # Errors
    ///
    /// Returns [`SendError::Full`] if the channel has no free slot and
    /// [`SendError::Disconnected`] if the receiver has been dropped; the hook
    /// is not called in either case.
    #[inline]
    pub fn try_send(&self, task: CrossThreadTask) -> Result<(), SendError> {
        self.sender.try_send(task).map_err(|err| match err {
            mpsc::TrySendError::Full(_) => SendError::Full,
            mpsc::TrySendError::Disconnected(_) => SendError::Disconnected,
        })?;
        self.wake();
        Ok(())
    }

    /// Sends `callback` as a task with [`TaskPriority::Normal`] using the
    /// blocking [`WakeSender::send`].
    ///
    /// # Errors
    ///
    /// Same as [`WakeSender::send`].
    #[inline]
    pub fn post(&self, callback: impl FnOnce() + Send + 'static) -> Result<(), SendError> {
        self.send(CrossThreadTask::new(TaskPriority::Normal, callback))
    }

    /// Sends `callback` as a task with the given `priority` using the
    /// blocking [`WakeSender::send`].
    ///
    /// # Errors
    ///
    /// Same as [`WakeSender::send`].
    #[inline]
    pub fn post_with_priority(
        &self,
        priority: TaskPriority,
        callback: impl FnOnce() + Send + 'static,
    ) -> Result<(), SendError> {
        self.send(CrossThreadTask::new(priority, callback))
    }

    /// Invokes the notify hook, if one is installed.
    #[inline]
    fn wake(&self) {
        if let Some(notify) = &self.notify {
            notify();
        }
    }
}

// The hook is an opaque closure, so `Debug` cannot be derived; report only
// whether one is installed.
impl std::fmt::Debug for WakeSender {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("WakeSender")
            .field("has_notify", &self.notify.is_some())
            .finish_non_exhaustive()
    }
}
/// Receiving half of the cross-thread task channel.
///
/// The raw-pointer `PhantomData` marker makes this type neither `Send` nor
/// `Sync`, so it stays on the thread that created it.
#[derive(Debug)]
pub struct WakeReceiver {
    /// Channel endpoint the queued tasks are read from.
    receiver: mpsc::Receiver<CrossThreadTask>,
    /// Zero-sized marker that opts the type out of `Send`/`Sync`.
    _not_send: std::marker::PhantomData<*const ()>,
}

impl WakeReceiver {
    /// Takes the next queued task without blocking.
    ///
    /// Returns `None` both when the channel is currently empty and when every
    /// sender has been dropped; the two cases are not distinguished.
    #[inline]
    #[must_use]
    pub fn try_recv(&self) -> Option<CrossThreadTask> {
        self.receiver.try_recv().ok()
    }

    /// Passes every currently queued task to `f`, in arrival order, without
    /// blocking, and returns how many tasks were handed over.
    ///
    /// Draining stops at the first moment the channel is empty (or
    /// disconnected); tasks sent while draining is in progress are included.
    pub fn drain_into(&self, mut f: impl FnMut(CrossThreadTask)) -> usize {
        let mut drained = 0;
        // `try_iter` yields until a non-blocking receive fails, which is
        // exactly the stop condition of repeated `try_recv` calls.
        for task in self.receiver.try_iter() {
            f(task);
            drained += 1;
        }
        drained
    }
}
/// Reasons a task could not be handed over to the receiving side.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SendError {
    /// The receiving half of the channel has been dropped.
    Disconnected,
    /// The bounded channel had no free slot for a non-blocking send.
    Full,
}

impl std::fmt::Display for SendError {
    /// Writes the fixed human-readable message for the variant.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match *self {
            Self::Disconnected => "window thread disconnected",
            Self::Full => "cross-thread channel full",
        };
        f.write_str(message)
    }
}

impl std::error::Error for SendError {}
#[must_use]
pub fn cross_thread_channel() -> (WakeSender, WakeReceiver) {
let (sender, receiver) = mpsc::sync_channel(CHANNEL_CAPACITY);
(
WakeSender {
sender,
notify: None,
},
WakeReceiver {
receiver,
_not_send: std::marker::PhantomData,
},
)
}
// Unit tests for the cross-thread channel: delivery, draining, cloning,
// error reporting, the notify hook and the capacity bound.
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicBool, AtomicU32, Ordering};

    // A posted closure arrives at the receiver and runs when asked to.
    #[test]
    fn send_and_receive_task() {
        let (sender, receiver) = cross_thread_channel();
        let called = Arc::new(AtomicBool::new(false));
        let c = called.clone();
        sender
            .post(move || c.store(true, Ordering::SeqCst))
            .unwrap();
        let task = receiver.try_recv().unwrap();
        task.run();
        assert!(called.load(Ordering::SeqCst));
    }

    #[test]
    fn try_recv_empty_returns_none() {
        let (_sender, receiver) = cross_thread_channel();
        assert!(receiver.try_recv().is_none());
    }

    // `drain_into` hands over every queued task and reports the count.
    #[test]
    fn drain_into_runs_all() {
        let (sender, receiver) = cross_thread_channel();
        let counter = Arc::new(AtomicU32::new(0));
        for _ in 0..5 {
            let c = counter.clone();
            sender
                .post(move || {
                    c.fetch_add(1, Ordering::SeqCst);
                })
                .unwrap();
        }
        let executed = receiver.drain_into(|task| task.run());
        assert_eq!(executed, 5);
        assert_eq!(counter.load(Ordering::SeqCst), 5);
    }

    // Draining an empty channel is a no-op that reports zero tasks.
    #[test]
    fn drain_into_empty_returns_zero() {
        let (_sender, receiver) = cross_thread_channel();
        assert_eq!(receiver.drain_into(|task| task.run()), 0);
    }

    // Clones feed the same receiver.
    #[test]
    fn sender_clone_works() {
        let (sender, receiver) = cross_thread_channel();
        let sender2 = sender.clone();
        sender.post(|| {}).unwrap();
        sender2.post(|| {}).unwrap();
        assert!(receiver.try_recv().is_some());
        assert!(receiver.try_recv().is_some());
        assert!(receiver.try_recv().is_none());
    }

    // The sender can be moved to another thread; the receiver stays put.
    #[test]
    fn send_from_another_thread() {
        let (sender, receiver) = cross_thread_channel();
        let handle = std::thread::spawn(move || {
            sender.post(|| {}).unwrap();
        });
        handle.join().unwrap();
        assert!(receiver.try_recv().is_some());
    }

    #[test]
    fn send_to_dropped_receiver_errors() {
        let (sender, receiver) = cross_thread_channel();
        drop(receiver);
        let result = sender.post(|| {});
        assert_eq!(result, Err(SendError::Disconnected));
    }

    // Filling every slot makes the next non-blocking send report `Full`.
    #[test]
    fn try_send_reports_full_channel() {
        let (sender, _receiver) = cross_thread_channel();
        for _ in 0..CHANNEL_CAPACITY {
            sender
                .try_send(CrossThreadTask::new(TaskPriority::Normal, || {}))
                .unwrap();
        }
        let overflow = sender.try_send(CrossThreadTask::new(TaskPriority::Normal, || {}));
        assert_eq!(overflow, Err(SendError::Full));
    }

    // The notify hook fires once per delivered task and never for a failed
    // send.
    #[test]
    fn notify_fires_only_on_successful_send() {
        let (mut sender, receiver) = cross_thread_channel();
        let wakes = Arc::new(AtomicU32::new(0));
        let w = wakes.clone();
        sender.set_notify(Arc::new(move || {
            w.fetch_add(1, Ordering::SeqCst);
        }));

        sender.post(|| {}).unwrap();
        sender
            .try_send(CrossThreadTask::new(TaskPriority::Normal, || {}))
            .unwrap();
        assert_eq!(wakes.load(Ordering::SeqCst), 2);

        drop(receiver);
        assert_eq!(sender.post(|| {}), Err(SendError::Disconnected));
        assert_eq!(wakes.load(Ordering::SeqCst), 2);
    }

    #[test]
    fn send_error_display() {
        assert_eq!(
            SendError::Disconnected.to_string(),
            "window thread disconnected"
        );
        assert_eq!(SendError::Full.to_string(), "cross-thread channel full");
    }

    #[test]
    fn cross_thread_task_carries_priority() {
        let task = CrossThreadTask::new(TaskPriority::Input, || {});
        assert_eq!(task.priority(), TaskPriority::Input);
    }

    // The priority passed to `post_with_priority` survives the channel.
    #[test]
    fn post_with_priority() {
        let (sender, receiver) = cross_thread_channel();
        sender
            .post_with_priority(TaskPriority::Input, || {})
            .unwrap();
        let task = receiver.try_recv().unwrap();
        assert_eq!(task.priority(), TaskPriority::Input);
    }
}