use futures_util::Future;
use std::{
mem::MaybeUninit,
pin::Pin,
sync::{Arc, atomic::Ordering},
task::{Context, Poll, Waker},
};
use std::sync::atomic::AtomicPtr;
#[cfg(feature = "runtoken-id")]
use std::sync::atomic::AtomicU64;
#[cfg(feature = "ordered-locks")]
use ordered_locks::{L0, LockToken};
/// Process-wide counter used to hand out token ids: each newly constructed
/// token takes the current value as its `id` and increments the counter.
#[cfg(feature = "runtoken-id")]
static IDC: AtomicU64 = AtomicU64::new(0);
/// Circular doubly linked list whose nodes are supplied by the caller as raw
/// pointers rather than being allocated by the list.
///
/// The list only stores a pointer to its first node; the last node is reached
/// through `first.prev`.
pub struct IntrusiveList<T> {
/// First node of the ring, or null when the list is empty.
first: *mut ListNode<T>,
}
impl<T> Default for IntrusiveList<T> {
    /// Creates an empty list, i.e. one whose head pointer is null.
    fn default() -> Self {
        let first = std::ptr::null_mut();
        Self { first }
    }
}
impl<T> IntrusiveList<T> {
    /// Links `node` at the back of the list and stores `v` in it.
    ///
    /// # Safety
    /// `node` must point to a valid `ListNode<T>` that is not linked into any
    /// other list, and the node must neither move nor be freed while it is
    /// linked.
    ///
    /// # Panics
    /// Panics if the node is already linked (its `next` pointer is non-null).
    unsafe fn push_back(&mut self, node: *mut ListNode<T>, v: T) {
        // SAFETY: the caller guarantees `node` is valid and exclusively ours.
        let n = unsafe { &mut *node };
        assert!(n.next.is_null());
        n.data.write(v);
        if self.first.is_null() {
            // A single node forms a ring with itself.
            n.next = node;
            n.prev = node;
            self.first = node;
        } else {
            let head = self.first;
            // SAFETY: `head` is non-null, hence a linked and valid node.
            let tail = unsafe { (*head).prev };
            n.prev = tail;
            n.next = head;
            // Neighbours are only touched through raw pointers because `tail`
            // and `head` are the same node when the list has one element.
            // SAFETY: `tail` and `head` are linked, valid nodes distinct from `node`.
            unsafe {
                (*tail).next = node;
                (*head).prev = node;
            }
        }
    }

    /// Unlinks `node` from the list and returns the value stored in it.
    ///
    /// # Safety
    /// `node` must point to a valid node that is currently linked into *this*
    /// list.
    ///
    /// # Panics
    /// Panics if the node is not linked (its `next` pointer is null).
    unsafe fn remove(&mut self, node: *mut ListNode<T>) -> T {
        // SAFETY: the caller guarantees `node` is valid.
        let n = unsafe { &mut *node };
        assert!(!n.next.is_null());
        // SAFETY: a linked node always carries an initialised payload, written
        // by `push_back`; the node is unlinked below so it is never read twice.
        let value = unsafe { n.data.assume_init_read() };
        let (prev, next) = (n.prev, n.next);
        if next == node {
            // `node` was the only element.
            self.first = std::ptr::null_mut();
        } else {
            if self.first == node {
                self.first = next;
            }
            // SAFETY: `prev` and `next` are linked, valid nodes other than `node`.
            unsafe {
                (*next).prev = prev;
                (*prev).next = next;
            }
        }
        n.next = std::ptr::null_mut();
        n.prev = std::ptr::null_mut();
        value
    }

    /// Empties the list front to back, passing each stored value to `v`.
    ///
    /// Every node is fully unlinked *before* its value is handed to `v`, so the
    /// list and all nodes stay consistent even if `v` panics: no node is left
    /// linked with a payload that has already been moved out.
    fn drain(&mut self, mut v: impl FnMut(T)) {
        while !self.first.is_null() {
            let head = self.first;
            // SAFETY: `head` is non-null and therefore a valid node linked into
            // this list.
            let value = unsafe { self.remove(head) };
            v(value);
        }
    }

    /// Reports whether `node` is currently linked into a list.
    ///
    /// # Safety
    /// `node` must point to a valid `ListNode<T>`.
    unsafe fn in_list(&self, node: *mut ListNode<T>) -> bool {
        // SAFETY: the caller guarantees `node` is valid; a null `next` marks an
        // unlinked node.
        unsafe { !(*node).next.is_null() }
    }
}
/// Node storage for an [`IntrusiveList`].
///
/// A node whose `next` pointer is null is not linked into any list. The
/// `PhantomPinned` marker makes the type `!Unpin`.
pub struct ListNode<T> {
/// Previous node in the ring, or null while unlinked.
prev: *mut ListNode<T>,
/// Next node in the ring, or null while unlinked.
next: *mut ListNode<T>,
/// Payload; written when the node is linked and read out when it is unlinked.
data: std::mem::MaybeUninit<T>,
/// Opts the node out of `Unpin`.
_pin: std::marker::PhantomPinned,
}
impl<T> Default for ListNode<T> {
    /// Creates an unlinked node: both neighbour pointers are null and the
    /// payload slot is left uninitialised.
    fn default() -> Self {
        let unlinked = std::ptr::null_mut();
        Self {
            prev: unlinked,
            next: unlinked,
            data: MaybeUninit::uninit(),
            _pin: std::marker::PhantomPinned,
        }
    }
}
/// Lifecycle state shared by all clones of a `RunToken`.
enum State {
/// Work may proceed.
Run,
/// The token has been cancelled; no method in this file leaves this state.
Cancel,
/// Work is suspended until the token is resumed or cancelled.
#[cfg(feature = "pause")]
Pause,
}
/// Mutable state of a token, stored behind the mutex in `Inner`.
struct Content {
/// Current lifecycle state.
state: State,
/// Wakers of pending cancellation futures; drained and woken on cancel.
cancel_wakers: IntrusiveList<Waker>,
/// Wakers of pending pause futures; drained and woken on resume and on cancel.
run_wakers: IntrusiveList<Waker>,
/// Optional free-form string attached to the token by its users.
#[cfg(feature = "runtoken-user-data")]
user_data: Option<String>,
}
// `Content` holds raw pointers through its intrusive lists, so `Send` is not
// derived automatically and is asserted manually here.
// NOTE(review): soundness relies on the listed nodes only ever being touched
// while the mutex wrapping this `Content` is held, and on every node being
// unlinked before its owning future is freed — confirm when changing callers.
unsafe impl Send for Content {}
impl Content {
unsafe fn add_cancel_waker(&mut self, node: *mut ListNode<Waker>, waker: &Waker) {
let in_list = unsafe { self.cancel_wakers.in_list(node) };
if !in_list {
unsafe { self.cancel_wakers.push_back(node, waker.clone()) }
}
}
#[cfg(feature = "pause")]
unsafe fn add_run_waker(&mut self, node: *mut ListNode<Waker>, waker: &Waker) {
let in_list = unsafe { self.run_wakers.in_list(node) };
if !in_list {
unsafe { self.run_wakers.push_back(node, waker.clone()) }
}
}
unsafe fn remove_cancel_waker(&mut self, node: *mut ListNode<Waker>) {
let in_list = unsafe { self.cancel_wakers.in_list(node) };
if in_list {
unsafe { self.cancel_wakers.remove(node) };
}
}
#[cfg(feature = "pause")]
unsafe fn remove_run_waker(&mut self, node: *mut ListNode<Waker>) {
let in_list = unsafe { self.run_wakers.in_list(node) };
if in_list {
unsafe { self.run_wakers.remove(node) };
}
}
}
/// State shared between all clones of a `RunToken`.
struct Inner {
/// Notified on cancel and resume; used by the blocking wait on pause.
cond: std::sync::Condvar,
/// Lifecycle state and registered wakers.
content: std::sync::Mutex<Content>,
/// Identifier taken from the global `IDC` counter at construction.
#[cfg(feature = "runtoken-id")]
id: u64,
/// Pointer to a NUL-terminated `'static` "file:line" string, or null if no
/// location has been recorded.
location_file_line: AtomicPtr<u8>,
}
/// Handle used to cancel (and, with the `pause` feature, pause and resume)
/// work.
///
/// Cloning is cheap: all clones share the same underlying state through an
/// `Arc`, so cancelling one clone cancels them all.
#[derive(Clone)]
pub struct RunToken(Arc<Inner>);
impl RunToken {
#[cfg(feature = "pause")]
pub fn new_paused() -> Self {
Self(Arc::new(Inner {
cond: std::sync::Condvar::new(),
content: std::sync::Mutex::new(Content {
state: State::Pause,
cancel_wakers: Default::default(),
run_wakers: Default::default(),
#[cfg(feature = "runtoken-user-data")]
user_data: None,
}),
location_file_line: Default::default(),
#[cfg(feature = "runtoken-id")]
id: IDC.fetch_add(1, std::sync::atomic::Ordering::SeqCst),
}))
}
pub fn new() -> Self {
Self(Arc::new(Inner {
cond: std::sync::Condvar::new(),
content: std::sync::Mutex::new(Content {
state: State::Run,
cancel_wakers: Default::default(),
run_wakers: Default::default(),
#[cfg(feature = "runtoken-user-data")]
user_data: None,
}),
location_file_line: Default::default(),
#[cfg(feature = "runtoken-id")]
id: IDC.fetch_add(1, std::sync::atomic::Ordering::SeqCst),
}))
}
pub fn cancel(&self) {
let mut content = self.0.content.lock().unwrap();
if matches!(content.state, State::Cancel) {
return;
}
content.state = State::Cancel;
content.run_wakers.drain(|w| w.wake());
content.cancel_wakers.drain(|w| w.wake());
self.0.cond.notify_all();
}
#[cfg(feature = "pause")]
pub fn pause(&self) {
let mut content = self.0.content.lock().unwrap();
if !matches!(content.state, State::Run) {
return;
}
content.state = State::Pause;
}
#[cfg(feature = "pause")]
pub fn resume(&self) {
let mut content = self.0.content.lock().unwrap();
if !matches!(content.state, State::Pause) {
return;
}
content.state = State::Run;
content.run_wakers.drain(|w| w.wake());
self.0.cond.notify_all();
}
pub fn is_cancelled(&self) -> bool {
matches!(self.0.content.lock().unwrap().state, State::Cancel)
}
#[cfg(feature = "pause")]
pub fn is_paused(&self) -> bool {
matches!(self.0.content.lock().unwrap().state, State::Pause)
}
#[cfg(feature = "pause")]
pub fn is_running(&self) -> bool {
matches!(self.0.content.lock().unwrap().state, State::Run)
}
#[cfg(feature = "pause")]
pub fn wait_paused_check_cancelled_sync(&self) -> bool {
let mut content = self.0.content.lock().unwrap();
loop {
match &content.state {
State::Run => return false,
State::Cancel => return true,
State::Pause => {
content = self.0.cond.wait(content).unwrap();
}
}
}
}
#[cfg(feature = "pause")]
pub fn wait_paused_check_cancelled(&self) -> WaitForPauseFuture<'_> {
WaitForPauseFuture {
token: self,
waker: Default::default(),
}
}
pub fn cancelled(&self) -> WaitForCancellationFuture<'_> {
WaitForCancellationFuture {
token: self,
waker: Default::default(),
}
}
#[cfg(feature = "ordered-locks")]
pub fn cancelled_checked(
&self,
_lock_token: LockToken<'_, L0>,
) -> WaitForCancellationFuture<'_> {
WaitForCancellationFuture {
token: self,
waker: Default::default(),
}
}
#[inline]
pub fn set_location_file_line(&self, file_line_str: &'static str) {
assert!(file_line_str.ends_with('\0'));
self.0
.location_file_line
.store(file_line_str.as_ptr() as *mut u8, Ordering::Relaxed);
}
pub fn location(&self) -> Option<(&'static str, u32)> {
let location_file_line = self.0.location_file_line.load(Ordering::Relaxed) as *const u8;
if location_file_line.is_null() {
return None;
}
let mut len = 0;
loop {
let l = unsafe { location_file_line.add(len) };
let c = unsafe { *l };
if c == b'\0' {
break;
}
len += 1;
}
let location_file_line = unsafe { std::slice::from_raw_parts(location_file_line, len) };
let location_file_line = unsafe { std::str::from_utf8_unchecked(location_file_line) };
match location_file_line.rsplit_once(":") {
Some((file, line)) => match line.parse() {
Ok(v) => Some((file, v)),
Err(_) => Some((location_file_line, 0)),
},
None => Some((location_file_line, 0)),
}
}
#[cfg(feature = "runtoken-id")]
#[inline]
pub fn id(&self) -> u64 {
self.0.id
}
#[cfg(feature = "runtoken-user-data")]
pub fn set_user_data(&self, data: Option<String>) {
self.0.content.lock().unwrap().user_data = data;
}
#[cfg(feature = "runtoken-user-data")]
pub fn user_data(&self) -> Option<String> {
self.0.content.lock().unwrap().user_data.clone()
}
}
/// Records the source location of the macro invocation on a run token.
///
/// Expands to a call to `set_location_file_line` on `$run_token`, passing the
/// compile-time string `"<file>:<line>\0"` built from `file!()` and `line!()`.
#[macro_export]
macro_rules! set_location {
($run_token: expr) => {
$run_token.set_location_file_line(concat!(file!(), ":", line!(), "\0"));
};
}
impl Default for RunToken {
fn default() -> Self {
Self::new()
}
}
impl core::fmt::Debug for RunToken {
    /// Formats the token as `RunToken("<state>")`, where the state label
    /// reflects the state at the time of the call.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let label = match self.0.content.lock().unwrap().state {
            State::Run => "Running",
            State::Cancel => "Canceled",
            #[cfg(feature = "pause")]
            State::Pause => "Paused",
        };
        f.debug_tuple("RunToken").field(&label).finish()
    }
}
/// Future that resolves once the referenced token has been cancelled.
#[must_use = "futures do nothing unless polled"]
pub struct WaitForCancellationFuture<'a> {
/// Token being observed.
token: &'a RunToken,
/// Intrusive list node holding this future's waker while it is registered
/// with the token.
waker: ListNode<Waker>,
}
impl core::fmt::Debug for WaitForCancellationFuture<'_> {
    /// Prints only the type name; no fields are exposed.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let mut builder = f.debug_struct("WaitForCancellationFuture");
        builder.finish()
    }
}
impl<'a> Future for WaitForCancellationFuture<'a> {
    type Output = ();

    /// Completes once the token is cancelled; in any other state the task's
    /// waker is registered with the token and the future stays pending.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
        let mut content = self.token.0.content.lock().unwrap();
        if matches!(content.state, State::Cancel) {
            return Poll::Ready(());
        }
        // Every non-cancelled state is handled the same way: register and wait.
        // SAFETY: the node is only linked into the token's list, never moved
        // out of the pinned future.
        let this = unsafe { self.get_unchecked_mut() };
        // SAFETY: the node lives inside this pinned future and is unlinked
        // again in `Drop`.
        unsafe { content.add_cancel_waker(&mut this.waker, cx.waker()) };
        Poll::Pending
    }
}
impl<'a> Drop for WaitForCancellationFuture<'a> {
    /// Unregisters this future's node from the token so the token never keeps
    /// a pointer to freed memory.
    fn drop(&mut self) {
        let mut content = self.token.0.content.lock().unwrap();
        // SAFETY: the node belongs to this future and, if linked at all, is
        // linked into this token's cancel list.
        unsafe { content.remove_cancel_waker(&mut self.waker) };
    }
}
// The embedded list node contains raw pointers, so `Send` is not derived
// automatically and is asserted manually here.
// NOTE(review): this relies on the node's links only being read or written
// while the token's mutex is held — confirm when changing the list code.
unsafe impl<'a> Send for WaitForCancellationFuture<'a> {}
/// Future that stays pending while the referenced token is paused and then
/// resolves to `true` if the token is cancelled or `false` if it is running.
#[cfg(feature = "pause")]
#[must_use = "futures do nothing unless polled"]
pub struct WaitForPauseFuture<'a> {
/// Token being observed.
token: &'a RunToken,
/// Intrusive list node holding this future's waker while it is registered
/// with the token.
waker: ListNode<Waker>,
}
#[cfg(feature = "pause")]
impl core::fmt::Debug for WaitForPauseFuture<'_> {
    /// Prints only the type name; no fields are exposed.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let mut builder = f.debug_struct("WaitForPauseFuture");
        builder.finish()
    }
}
#[cfg(feature = "pause")]
impl<'a> Future for WaitForPauseFuture<'a> {
    type Output = bool;

    /// Resolves to `true` when the token is cancelled and to `false` when it is
    /// running; while it is paused the task's waker is registered with the
    /// token and the future stays pending.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<bool> {
        let mut content = self.token.0.content.lock().unwrap();
        let cancelled = match content.state {
            State::Cancel => true,
            State::Run => false,
            State::Pause => {
                // SAFETY: the node is only linked into the token's list, never
                // moved out of the pinned future.
                let this = unsafe { self.get_unchecked_mut() };
                // SAFETY: the node lives inside this pinned future and is
                // unlinked again in `Drop`.
                unsafe { content.add_run_waker(&mut this.waker, cx.waker()) };
                return Poll::Pending;
            }
        };
        Poll::Ready(cancelled)
    }
}
#[cfg(feature = "pause")]
impl<'a> Drop for WaitForPauseFuture<'a> {
    /// Unregisters this future's node from the token so the token never keeps
    /// a pointer to freed memory.
    fn drop(&mut self) {
        let mut content = self.token.0.content.lock().unwrap();
        // SAFETY: the node belongs to this future and, if linked at all, is
        // linked into this token's run list.
        unsafe { content.remove_run_waker(&mut self.waker) };
    }
}
// The embedded list node contains raw pointers, so `Send` is not derived
// automatically and is asserted manually here.
// NOTE(review): this relies on the node's links only being read or written
// while the token's mutex is held — confirm when changing the list code.
#[cfg(feature = "pause")]
unsafe impl<'a> Send for WaitForPauseFuture<'a> {}