use pin_project_lite::pin_project;
use std::fmt;
use std::future::Future;
use std::pin::Pin;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::task::{Context, Poll};
use tokio::sync::{futures::Notified, Notify};
#[cfg(feature = "rt")]
use tokio::{
runtime::Handle,
task::{JoinHandle, LocalSet},
};
/// A handle used to track a set of tasks and wait until they have all finished.
///
/// The tracker keeps a count of outstanding tasks together with a "closed"
/// flag. The future returned by `wait` completes once the tracker is closed
/// and the count is zero.
///
/// Cloning a `TaskTracker` yields another handle to the same shared state
/// (the `Arc` is cloned, not the state behind it).
pub struct TaskTracker {
// Shared state: the packed task-count / closed-flag word plus the notifier.
inner: Arc<TaskTrackerInner>,
}
/// A guard that counts as one tracked task for as long as it is alive.
///
/// Tokens are created by `TaskTracker::token`, which increments the tracker's
/// task count. Dropping the token decrements it again.
#[must_use]
#[derive(Debug)]
pub struct TaskTrackerToken {
// The tracker whose task count this token contributes to.
task_tracker: TaskTracker,
}
/// State shared by every clone of a `TaskTracker`.
struct TaskTrackerInner {
// Packed state word: bit 0 is the "closed" flag, the remaining bits hold the
// number of tracked tasks (each task adds 2 to the word).
state: AtomicUsize,
// Notified when the tracker becomes closed and empty.
on_last_exit: Notify,
}
pin_project! {
/// A future wrapper that counts as a tracked task of a `TaskTracker`.
///
/// Polling is forwarded to the wrapped future. The embedded token keeps the
/// tracker's task count incremented until this wrapper is dropped.
#[must_use = "futures do nothing unless polled"]
pub struct TrackedFuture<F> {
// The wrapped future; structurally pinned so it can be polled in place.
#[pin]
future: F,
// Keeps the owning tracker's task count incremented while this value lives.
token: TaskTrackerToken,
}
}
pin_project! {
/// A future that completes once the associated tracker is closed and empty.
///
/// Returned by `TaskTracker::wait`.
#[must_use = "futures do nothing unless polled"]
pub struct TaskTrackerWaitFuture<'a> {
// Notification future obtained from the tracker's `on_last_exit` notifier.
#[pin]
future: Notified<'a>,
// `Some` while still waiting; `None` once the wait is known to be complete,
// so later polls return `Ready` immediately.
inner: Option<&'a TaskTrackerInner>,
}
}
impl TaskTrackerInner {
    /// Creates the shared state of a tracker that is open and tracks no tasks.
    #[inline]
    fn new() -> Self {
        let state = AtomicUsize::new(0);
        let on_last_exit = Notify::new();
        Self {
            state,
            on_last_exit,
        }
    }

    /// Returns `true` when the closed bit is set and the task count is zero.
    ///
    /// The state word is exactly `1` in that situation: bit 0 set, no task
    /// increments present.
    #[inline]
    fn is_closed_and_empty(&self) -> bool {
        let snapshot = self.state.load(Ordering::Acquire);
        snapshot == 1
    }

    /// Sets the closed bit.
    ///
    /// Returns `true` if the tracker was open before this call. If the tracker
    /// was open *and* empty, waiters are woken, because it has just become
    /// closed and empty.
    #[inline]
    fn set_closed(&self) -> bool {
        let previous = self.state.fetch_or(1, Ordering::AcqRel);
        let was_open = (previous & 1) == 0;

        // A previous value of 0 means "open with no tasks": this call completed
        // the closed-and-empty condition.
        if previous == 0 {
            self.notify_now();
        }

        was_open
    }

    /// Clears the closed bit.
    ///
    /// Returns `true` if the tracker was closed before this call.
    #[inline]
    fn set_open(&self) -> bool {
        let previous = self.state.fetch_and(!1, Ordering::AcqRel);
        (previous & 1) != 0
    }

    /// Registers one more task by adding 2 to the state word (bit 0 is the
    /// closed flag, so the count lives in the bits above it).
    #[inline]
    fn add_task(&self) {
        let _previous = self.state.fetch_add(2, Ordering::Relaxed);
    }

    /// Unregisters one task.
    ///
    /// If the word was `3` beforehand (closed bit set, exactly one task), this
    /// was the last task of a closed tracker and waiters are woken.
    #[inline]
    fn drop_task(&self) {
        let previous = self.state.fetch_sub(2, Ordering::Release);
        let last_task_of_closed_tracker = previous == 3;

        if last_task_of_closed_tracker {
            self.notify_now();
        }
    }

    /// Wakes every task currently waiting on `on_last_exit`.
    #[cold]
    fn notify_now(&self) {
        // The loaded value is intentionally discarded; the load is performed
        // for its `Acquire` ordering only. `drop_task` decrements with
        // `Release` alone, so this load is what orders earlier decrements (and
        // the writes preceding them) before the wake-up below.
        let _ = self.state.load(Ordering::Acquire);

        self.on_last_exit.notify_waiters();
    }
}
impl TaskTracker {
#[must_use]
pub fn new() -> Self {
Self {
inner: Arc::new(TaskTrackerInner::new()),
}
}
#[inline]
pub fn wait(&self) -> TaskTrackerWaitFuture<'_> {
TaskTrackerWaitFuture {
future: self.inner.on_last_exit.notified(),
inner: if self.inner.is_closed_and_empty() {
None
} else {
Some(&self.inner)
},
}
}
#[inline]
pub fn close(&self) -> bool {
self.inner.set_closed()
}
#[inline]
pub fn reopen(&self) -> bool {
self.inner.set_open()
}
#[inline]
#[must_use]
pub fn is_closed(&self) -> bool {
(self.inner.state.load(Ordering::Acquire) & 1) != 0
}
#[inline]
#[must_use]
pub fn len(&self) -> usize {
self.inner.state.load(Ordering::Acquire) >> 1
}
#[inline]
#[must_use]
pub fn is_empty(&self) -> bool {
self.inner.state.load(Ordering::Acquire) <= 1
}
#[inline]
#[track_caller]
#[cfg(feature = "rt")]
#[cfg_attr(docsrs, doc(cfg(feature = "rt")))]
pub fn spawn<F>(&self, task: F) -> JoinHandle<F::Output>
where
F: Future + Send + 'static,
F::Output: Send + 'static,
{
tokio::task::spawn(self.track_future(task))
}
#[inline]
#[track_caller]
#[cfg(feature = "rt")]
#[cfg_attr(docsrs, doc(cfg(feature = "rt")))]
pub fn spawn_on<F>(&self, task: F, handle: &Handle) -> JoinHandle<F::Output>
where
F: Future + Send + 'static,
F::Output: Send + 'static,
{
handle.spawn(self.track_future(task))
}
#[inline]
#[track_caller]
#[cfg(feature = "rt")]
#[cfg_attr(docsrs, doc(cfg(feature = "rt")))]
pub fn spawn_local<F>(&self, task: F) -> JoinHandle<F::Output>
where
F: Future + 'static,
F::Output: 'static,
{
tokio::task::spawn_local(self.track_future(task))
}
#[inline]
#[track_caller]
#[cfg(feature = "rt")]
#[cfg_attr(docsrs, doc(cfg(feature = "rt")))]
pub fn spawn_local_on<F>(&self, task: F, local_set: &LocalSet) -> JoinHandle<F::Output>
where
F: Future + 'static,
F::Output: 'static,
{
local_set.spawn_local(self.track_future(task))
}
#[inline]
#[track_caller]
#[cfg(feature = "rt")]
#[cfg(not(target_family = "wasm"))]
#[cfg_attr(docsrs, doc(cfg(feature = "rt")))]
pub fn spawn_blocking<F, T>(&self, task: F) -> JoinHandle<T>
where
F: FnOnce() -> T,
F: Send + 'static,
T: Send + 'static,
{
let token = self.token();
tokio::task::spawn_blocking(move || {
let res = task();
drop(token);
res
})
}
#[inline]
#[track_caller]
#[cfg(feature = "rt")]
#[cfg(not(target_family = "wasm"))]
#[cfg_attr(docsrs, doc(cfg(feature = "rt")))]
pub fn spawn_blocking_on<F, T>(&self, task: F, handle: &Handle) -> JoinHandle<T>
where
F: FnOnce() -> T,
F: Send + 'static,
T: Send + 'static,
{
let token = self.token();
handle.spawn_blocking(move || {
let res = task();
drop(token);
res
})
}
#[inline]
pub fn track_future<F: Future>(&self, future: F) -> TrackedFuture<F> {
TrackedFuture {
future,
token: self.token(),
}
}
#[inline]
pub fn token(&self) -> TaskTrackerToken {
self.inner.add_task();
TaskTrackerToken {
task_tracker: self.clone(),
}
}
#[inline]
#[must_use]
pub fn ptr_eq(left: &TaskTracker, right: &TaskTracker) -> bool {
Arc::ptr_eq(&left.inner, &right.inner)
}
}
impl Default for TaskTracker {
#[inline]
fn default() -> TaskTracker {
TaskTracker::new()
}
}
impl Clone for TaskTracker {
#[inline]
fn clone(&self) -> TaskTracker {
Self {
inner: self.inner.clone(),
}
}
}
/// Writes the `Debug` representation of a tracker's shared state.
///
/// The output is a struct named `TaskTracker` with the current task count,
/// the closed flag, and the address of the shared state.
fn debug_inner(inner: &TaskTrackerInner, f: &mut fmt::Formatter<'_>) -> fmt::Result {
    // Read the word once so `len` and `is_closed` describe the same snapshot.
    let snapshot = inner.state.load(Ordering::Acquire);
    let len = snapshot >> 1;
    let is_closed = snapshot & 1 == 1;
    let address: *const TaskTrackerInner = inner;

    let mut out = f.debug_struct("TaskTracker");
    out.field("len", &len);
    out.field("is_closed", &is_closed);
    out.field("inner", &address);
    out.finish()
}
impl fmt::Debug for TaskTracker {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
debug_inner(&self.inner, f)
}
}
impl TaskTrackerToken {
    /// Returns the tracker this token belongs to.
    #[inline]
    #[must_use]
    pub fn task_tracker(&self) -> &TaskTracker {
        // Destructuring through the reference only borrows the field.
        let Self { task_tracker } = self;
        task_tracker
    }
}
impl Clone for TaskTrackerToken {
#[inline]
fn clone(&self) -> TaskTrackerToken {
self.task_tracker.token()
}
}
impl Drop for TaskTrackerToken {
/// Removes this token's contribution from the tracker's task count.
///
/// `drop_task` wakes waiters if this was the last task of a closed tracker.
#[inline]
fn drop(&mut self) {
// Runs before the `task_tracker` field itself is dropped, so the shared
// state is still alive here.
self.task_tracker.inner.drop_task();
}
}
impl<F: Future> Future for TrackedFuture<F> {
    type Output = F::Output;

    /// Polls the wrapped future and returns its result unchanged.
    #[inline]
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<F::Output> {
        let this = self.project();
        this.future.poll(cx)
    }
}
impl<F: fmt::Debug> fmt::Debug for TrackedFuture<F> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("TrackedFuture")
.field("future", &self.future)
.field("task_tracker", self.token.task_tracker())
.finish()
}
}
impl<'a> Future for TaskTrackerWaitFuture<'a> {
    type Output = ();

    /// Completes once the tracker is closed and empty.
    ///
    /// After returning `Ready` once, every later poll returns `Ready` again
    /// without touching the tracker.
    #[inline]
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
        let this = self.project();

        // `None` marks a wait that has already finished (or was finished at
        // creation time).
        let tracker = match *this.inner {
            Some(tracker) => tracker,
            None => return Poll::Ready(()),
        };

        // Check the state first; the notification future is polled only when
        // the tracker is not yet closed and empty.
        if tracker.is_closed_and_empty() || this.future.poll(cx).is_ready() {
            *this.inner = None;
            Poll::Ready(())
        } else {
            Poll::Pending
        }
    }
}
impl<'a> fmt::Debug for TaskTrackerWaitFuture<'a> {
    /// Formats the notification future and, while still waiting, the tracker
    /// state (`None` once the wait has completed).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Adapter that renders a `TaskTrackerInner` through `debug_inner`.
        struct InnerDebug<'t>(&'t TaskTrackerInner);

        impl fmt::Debug for InnerDebug<'_> {
            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
                let InnerDebug(inner) = self;
                debug_inner(inner, f)
            }
        }

        let tracker = self.inner.map(InnerDebug);

        let mut out = f.debug_struct("TaskTrackerWaitFuture");
        out.field("future", &self.future);
        out.field("task_tracker", &tracker);
        out.finish()
    }
}