use crate::{queue::QueueKey, task::TaskHeader};
use futures::task::AtomicWaker;
use std::any::Any;
use std::{
fmt,
future::Future,
pin::Pin,
sync::{atomic::Ordering, Arc},
task::{Context, Poll},
};
/// Error describing a task that panicked while running.
///
/// Wraps a human-readable message. [`PanicError::from_panic_payload`] builds
/// one from the boxed payload of a caught panic.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PanicError {
    message: String,
}

impl PanicError {
    /// Creates a panic error whose message is stored verbatim.
    pub fn new(message: impl Into<String>) -> Self {
        Self {
            message: message.into(),
        }
    }

    /// Returns the stored message.
    pub fn message(&self) -> &str {
        &self.message
    }

    /// Builds a `PanicError` from a panic payload.
    ///
    /// `String` and `&'static str` payloads (the types `panic!` usually
    /// produces) are rendered as `"Task panicked: <text>"`. Any other payload
    /// type yields `"Task panicked with unknown payload"`.
    pub fn from_panic_payload(panic_payload: Box<dyn Any + Send>) -> Self {
        // Borrow the payload instead of consuming it with `Box::downcast`.
        let payload: &(dyn Any + Send) = &*panic_payload;
        let detail = payload
            .downcast_ref::<String>()
            .map(String::as_str)
            .or_else(|| payload.downcast_ref::<&'static str>().copied());
        match detail {
            Some(detail) => Self::new(format!("Task panicked: {}", detail)),
            None => Self::new("Task panicked with unknown payload"),
        }
    }
}

impl fmt::Display for PanicError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(&self.message)
    }
}

impl std::error::Error for PanicError {}

/// Reasons why awaiting a task does not yield its value.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum JoinError {
    /// The task was cancelled before any other outcome was recorded.
    Cancelled,
    /// The outcome had already been moved out by an earlier take.
    ResultTaken,
    /// The task panicked; the inner error carries the panic message.
    Panic(PanicError),
}

impl fmt::Display for JoinError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            JoinError::Cancelled => f.write_str("Task was cancelled"),
            JoinError::ResultTaken => f.write_str("Task result was already taken"),
            // The panic message is shown inline, so no separate `source()` is exposed.
            JoinError::Panic(err) => fmt::Display::fmt(err, f),
        }
    }
}

impl std::error::Error for JoinError {}
/// Shared completion slot between a task and the handles awaiting it.
///
/// The first successful `try_complete_*` call records the task's outcome,
/// publishes it through the `done` flag and wakes the registered waker. Every
/// later completion attempt is rejected, including after the outcome has
/// been taken out of the slot.
#[derive(Debug)]
pub struct JoinState<T> {
    /// Set to `true` exactly once, while `result` is locked and already filled.
    done: std::sync::atomic::AtomicBool,
    /// The recorded outcome; `None` before completion and after it was taken.
    result: std::sync::Mutex<Option<Result<T, JoinError>>>,
    /// Waker registered by the most recent poll that found the task unfinished.
    waker: AtomicWaker,
}

impl<T> JoinState<T> {
    /// Creates an empty, not-yet-completed join state.
    pub fn new() -> Self {
        Self {
            done: std::sync::atomic::AtomicBool::new(false),
            result: std::sync::Mutex::new(None),
            waker: AtomicWaker::new(),
        }
    }

    /// Returns `true` once an outcome has been recorded.
    ///
    /// The `Acquire` load pairs with the `Release` store in `try_complete`.
    #[inline]
    pub fn is_done(&self) -> bool {
        self.done.load(Ordering::Acquire)
    }

    /// Locks the result slot, recovering the guard if the mutex was poisoned.
    ///
    /// The critical sections in this impl only move an `Option` in or out of
    /// the slot, so its contents stay consistent even after a poisoning panic.
    fn lock_result(&self) -> std::sync::MutexGuard<'_, Option<Result<T, JoinError>>> {
        self.result
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Records `result` unless an outcome was already recorded.
    ///
    /// Returns `true` if this call won the race, `false` otherwise.
    fn try_complete(&self, result: Result<T, JoinError>) -> bool {
        let mut slot = self.lock_result();
        // Test `done`, not `slot.is_some()`: once the outcome has been taken,
        // the slot is `None` again, and a late completer must not be able to
        // store a second outcome. `done` is only written under this lock, so
        // the mutex already orders this load with that store.
        if self.done.load(Ordering::Relaxed) {
            return false;
        }
        *slot = Some(result);
        self.done.store(true, Ordering::Release);
        // Wake outside the lock so waker code never runs while the slot is held.
        drop(slot);
        self.waker.wake();
        true
    }

    /// Completes the state with `Ok(val)`; returns whether this call won.
    pub fn try_complete_ok(&self, val: T) -> bool {
        self.try_complete(Ok(val))
    }

    /// Completes the state with `Err(JoinError::Cancelled)`; returns whether this call won.
    pub fn try_complete_cancelled(&self) -> bool {
        self.try_complete(Err(JoinError::Cancelled))
    }

    /// Completes the state with `Err(err)`; returns whether this call won.
    pub fn try_complete_err(&self, err: JoinError) -> bool {
        self.try_complete(Err(err))
    }

    /// Moves the recorded outcome out of the slot.
    ///
    /// Returns `Err(JoinError::ResultTaken)` if the slot is empty, i.e. the
    /// outcome was already taken (or nothing has been recorded yet).
    fn take_result(&self) -> Result<T, JoinError> {
        self.lock_result()
            .take()
            .unwrap_or(Err(JoinError::ResultTaken))
    }
}

impl<T> Default for JoinState<T> {
    /// Equivalent to [`JoinState::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// Handle to a spawned task, used to await its outcome or abort it.
///
/// Awaiting the handle resolves to the outcome recorded in the shared
/// [`JoinState`]. The first poll that observes completion moves the outcome
/// out, so awaiting another clone of the same handle afterwards normally
/// yields `Err(JoinError::ResultTaken)`.
#[derive(Debug)]
pub struct JoinHandle<T, K: QueueKey> {
    header: Arc<TaskHeader<K>>,
    join: Arc<JoinState<T>>,
}

// Written by hand because `#[derive(Clone)]` would require `T: Clone` and
// `K: Clone`, although cloning a handle only clones two `Arc`s.
impl<T, K: QueueKey> Clone for JoinHandle<T, K> {
    fn clone(&self) -> Self {
        Self {
            header: Arc::clone(&self.header),
            join: Arc::clone(&self.join),
        }
    }
}

impl<T, K: QueueKey> JoinHandle<T, K> {
    /// Creates a handle from a task header and the join state it will observe.
    pub fn new(header: Arc<TaskHeader<K>>, join: Arc<JoinState<T>>) -> Self {
        Self { header, join }
    }

    /// Requests cancellation of the task.
    ///
    /// This does three things in order:
    /// 1. Marks the task header as cancelled.
    /// 2. Completes the join state with `JoinError::Cancelled`, unless an
    ///    outcome was already recorded.
    /// 3. Enqueues the header. Presumably this lets the executor observe the
    ///    cancellation; confirm against `TaskHeader::enqueue`.
    pub fn abort(&self) {
        self.header.cancel();
        self.join.try_complete_cancelled();
        self.header.enqueue();
    }
}

impl<T, K: QueueKey> Future for JoinHandle<T, K> {
    type Output = Result<T, JoinError>;

    /// Resolves once the task's outcome has been recorded in the join state.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        if self.join.is_done() {
            return Poll::Ready(self.join.take_result());
        }
        // Register first, then re-check. A completion that lands between the
        // first check and `register` would otherwise wake nobody.
        self.join.waker.register(cx.waker());
        if self.join.is_done() {
            return Poll::Ready(self.join.take_result());
        }
        Poll::Pending
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    /// Completing with a value flips `done` and hands the value back.
    #[test]
    fn test_complete_ok() {
        let state = JoinState::<i32>::new();
        assert!(!state.is_done());

        let won = state.try_complete_ok(42);
        assert!(won);
        assert!(state.is_done());

        assert_eq!(state.take_result().unwrap(), 42);
    }

    /// Cancellation is reported as `JoinError::Cancelled`.
    #[test]
    fn test_complete_cancelled() {
        let state = JoinState::<i32>::new();

        let won = state.try_complete_cancelled();
        assert!(won);
        assert!(state.is_done());

        assert!(matches!(state.take_result(), Err(JoinError::Cancelled)));
    }

    /// A panic error round-trips with its message intact.
    #[test]
    fn test_complete_panic() {
        let state = JoinState::<i32>::new();

        let won = state.try_complete_err(JoinError::Panic(PanicError::new("test panic")));
        assert!(won);
        assert!(state.is_done());

        match state.take_result() {
            Err(JoinError::Panic(e)) => assert_eq!(e.message(), "test panic"),
            other => panic!("Expected Panic error, got {:?}", other),
        }
    }

    /// After the first completion every further attempt is rejected.
    #[test]
    fn test_only_one_completer_wins() {
        let state = JoinState::<i32>::new();
        assert!(state.try_complete_ok(1));

        let late_attempts = [
            state.try_complete_ok(2),
            state.try_complete_cancelled(),
            state.try_complete_err(JoinError::Cancelled),
        ];
        assert_eq!(late_attempts, [false; 3]);

        assert_eq!(state.take_result().unwrap(), 1);
    }

    /// The outcome can be taken once; the next take reports `ResultTaken`.
    #[test]
    fn test_result_taken_on_second_take() {
        let state = JoinState::<i32>::new();
        let _ = state.try_complete_ok(42);

        let first = state.take_result();
        assert_eq!(first.unwrap(), 42);

        let second = state.take_result();
        assert!(matches!(second, Err(JoinError::ResultTaken)));
    }

    /// A reader that observes `done` must also observe the stored value.
    #[test]
    fn test_race_condition_result_visible_when_done() {
        const ROUNDS: usize = 1000;
        for _ in 0..ROUNDS {
            let shared = Arc::new(JoinState::<i32>::new());

            let writer = {
                let shared = Arc::clone(&shared);
                thread::spawn(move || {
                    shared.try_complete_ok(42);
                })
            };

            let reader = {
                let shared = Arc::clone(&shared);
                thread::spawn(move || {
                    while !shared.is_done() {
                        std::hint::spin_loop();
                    }
                    let result = shared.take_result();
                    assert!(
                        result.is_ok(),
                        "Result should be Ok(42), got {:?}. Race condition bug!",
                        result
                    );
                    assert_eq!(result.unwrap(), 42);
                })
            };

            writer.join().unwrap();
            reader.join().unwrap();
        }
    }
}