use futures::task::{self, Task};
use std::fmt;
use std::cell::UnsafeCell;
use std::sync::atomic::AtomicUsize;
use std::sync::atomic::Ordering::{Acquire, Release};
/// A slot holding at most one `Task`. `register_task` replaces the task and
/// `notify` notifies it. The two are coordinated through a small lock that is
/// encoded in `state`.
pub struct AtomicTask {
/// Lock state. It is one of:
/// - `WAITING` (idle);
/// - `LOCKED_WRITE` / `LOCKED_WRITE_NOTIFIED` (a `register_task` is updating
///   `task`);
/// - a value above `WAITING` (one or more `notify` calls are reading `task`).
state: AtomicUsize,
/// The registered task, if any. Accesses are meant to be guarded by the lock
/// encoded in `state`.
task: UnsafeCell<Option<Task>>,
}
// Encoding of `AtomicTask::state`. `WAITING` sits between the two write-lock
// values and the read-lock counter. Readers can therefore take the lock by
// adding one and release it with `fetch_sub(1)`.

/// Write lock held by `register_task`; no `notify` has been observed yet.
const LOCKED_WRITE: usize = 0;

/// Write lock held by `register_task`, and a `notify` arrived meanwhile. The
/// registering thread is then responsible for notifying the task.
const LOCKED_WRITE_NOTIFIED: usize = 1;

/// Idle: nobody holds a lock on the task cell.
const WAITING: usize = 2;

/// Smallest read-locked value. The state is `WAITING` plus the number of
/// `notify` calls currently reading the task cell.
// `allow(dead_code)`: the read-locked range is handled arithmetically
// (`curr + 1` / `fetch_sub(1)`), so this name may go unreferenced.
#[allow(dead_code)]
const LOCKED_READ: usize = WAITING + 1;
impl AtomicTask {
pub fn new() -> AtomicTask {
trait AssertSync: Sync {}
impl AssertSync for Task {}
AtomicTask {
state: AtomicUsize::new(WAITING),
task: UnsafeCell::new(None),
}
}
pub fn register(&self) {
self.register_task(task::current());
}
pub fn register_task(&self, task: Task) {
match self.state.compare_and_swap(WAITING, LOCKED_WRITE, Acquire) {
WAITING => {
unsafe {
*self.task.get() = Some(task);
if LOCKED_WRITE_NOTIFIED == self.state.swap(WAITING, Release) {
(*self.task.get()).as_ref().unwrap().notify();
}
}
}
LOCKED_WRITE | LOCKED_WRITE_NOTIFIED => {
}
state => {
debug_assert!(state != LOCKED_WRITE, "unexpected state LOCKED_WRITE");
debug_assert!(state != LOCKED_WRITE_NOTIFIED, "unexpected state LOCKED_WRITE_NOTIFIED");
task.notify();
}
}
}
pub fn notify(&self) {
let mut curr = WAITING;
loop {
if curr == LOCKED_WRITE {
let actual = self.state.compare_and_swap(LOCKED_WRITE, LOCKED_WRITE_NOTIFIED, Release);
if curr == actual {
return;
}
curr = actual;
} else if curr == LOCKED_WRITE_NOTIFIED {
return;
} else {
let actual = self.state.compare_and_swap(curr, curr + 1, Acquire);
if actual == curr {
unsafe {
if let Some(ref task) = *self.task.get() {
task.notify();
}
}
self.state.fetch_sub(1, Release);
return;
}
curr = actual;
}
}
}
}
impl Default for AtomicTask {
fn default() -> Self {
AtomicTask::new()
}
}
impl fmt::Debug for AtomicTask {
    /// Writes only the type name. The lock state and the stored task are not
    /// shown.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str("AtomicTask")
    }
}
// SAFETY: `AtomicTask` owns an `Option<Task>` inside an `UnsafeCell`, so moving
// it to another thread moves that `Task` with it.
// NOTE(review): assumes `Task` is itself `Send`; confirm against the futures
// version in use.
unsafe impl Send for AtomicTask {}
// SAFETY: shared `&AtomicTask` access from several threads relies on two things:
// - every read/write of the `task` cell is guarded by the lock protocol on
//   `state` (exclusive for `register_task`, shared for `notify`);
// - `Task: Sync`, which `AtomicTask::new` checks at compile time.
unsafe impl Sync for AtomicTask {}