use std::fmt::Debug;
use std::future::Future;
use std::pin::Pin;
use std::sync::atomic::AtomicBool;
use std::sync::atomic::Ordering;
use std::sync::Arc;
use std::task::Context;
use std::task::Poll;
use futures::future::BoxFuture;
use futures::stream::FuturesUnordered;
use futures::FutureExt;
use futures::StreamExt;
/// Handle to a task started through this module's `spawn*` functions.
///
/// Pairs the underlying `async_std::task::JoinHandle` with a shared flag that
/// the spawned wrapper future raises once the user future has completed.
pub struct JoinHandle<T> {
// Underlying async-std handle; becomes `None` after `abort` or once `poll` has yielded the output.
handle: Option<async_std::task::JoinHandle<T>>,
// Shared with the spawned wrapper future, which sets it to `true` after the user future returns.
is_done: Arc<AtomicBool>,
}
impl<T> Debug for JoinHandle<T> {
    /// Formats the handle as `JoinHandle { is_done: <bool>, handle: <bool> }`.
    ///
    /// `is_done` is the current value of the completion flag and `handle`
    /// reports whether the underlying async-std handle is still held.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("JoinHandle")
            // The value printed here is the completion flag, so it is labelled
            // `is_done` (it is not a task name).
            .field("is_done", &self.is_done.load(Ordering::Relaxed))
            .field("handle", &self.handle.is_some())
            .finish()
    }
}
impl<T> JoinHandle<T> {
    /// Returns `true` when the task's future has completed, or when this
    /// handle no longer holds the underlying task (after [`abort`](Self::abort)
    /// or after the output was obtained by awaiting the handle).
    pub fn is_finished(&self) -> bool {
        self.handle.is_none() || self.is_done.load(Ordering::Relaxed)
    }

    /// Requests cancellation of the task and releases the underlying handle.
    ///
    /// Calling this again, or calling it after the output was already taken,
    /// does nothing. Awaiting the handle afterwards yields `Err(())`.
    pub fn abort(&mut self) {
        if let Some(handle) = self.handle.take() {
            // `cancel()` is an `async fn`, so nothing happens until the
            // returned future is polled. Dropping it un-polled merely drops
            // the inner handle, which detaches the task and lets it keep
            // running. Polling exactly once issues the cancellation request;
            // we do not wait for the task to acknowledge it, and any output
            // of an already-finished task is discarded.
            let _ = handle.cancel().now_or_never();
        }
    }
}
impl<T> async_std::future::Future for JoinHandle<T> {
    /// `Ok(output)` when the task finishes; `Err(())` when there is no
    /// underlying handle left to poll (aborted, or output already taken).
    type Output = Result<T, ()>;

    /// Forwards the poll to the underlying async-std handle, clearing it once
    /// the output has been produced.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.get_mut();
        match this.handle.as_mut() {
            // Nothing to wait for any more.
            None => Poll::Ready(Err(())),
            Some(task) => match task.poll_unpin(cx) {
                Poll::Pending => Poll::Pending,
                Poll::Ready(output) => {
                    // Drop the finished handle so later polls report `Err(())`.
                    this.handle = None;
                    Poll::Ready(Ok(output))
                }
            },
        }
    }
}
/// Duration type exposed by this module; an alias for `std::time::Duration`.
pub type Duration = std::time::Duration;
/// Instant type exposed by this module; an alias for `std::time::Instant`.
pub type Instant = std::time::Instant;
/// Periodic timer state: a period plus the instant at which the next tick is due.
///
/// Created by `interval` and advanced by `Interval::tick`.
#[derive(Debug, Clone)]
pub struct Interval {
// Period added to `next_tick` after every tick.
dur: Duration,
// Instant the next `tick` call waits for before returning.
next_tick: Instant,
}
impl Interval {
    /// Waits until the next scheduled tick, then schedules the following one.
    ///
    /// Returns without sleeping when the scheduled instant is not in the
    /// future. The schedule always advances by exactly one period, regardless
    /// of how late the call was.
    pub async fn tick(&mut self) {
        let deadline = self.next_tick;
        let current = Instant::now();
        if current < deadline {
            let remaining = deadline - current;
            sleep(remaining).await;
        }
        self.next_tick = deadline + self.dur;
    }
}
/// Creates an [`Interval`] with period `dur` whose first tick is due at the
/// moment of creation.
pub fn interval(dur: Duration) -> Interval {
    let next_tick = Instant::now();
    Interval { dur, next_tick }
}
#[derive(Default)]
pub struct JoinSet<T> {
set: FuturesUnordered<BoxFuture<'static, T>>,
}
impl<T> Debug for JoinSet<T> {
    /// Formats the set as `JoinSet { size: <number of stored futures> }`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let size = self.set.len();
        let mut out = f.debug_struct("JoinSet");
        out.field("size", &size);
        out.finish()
    }
}
impl<T> JoinSet<T> {
pub fn new() -> JoinSet<T> {
Self {
set: FuturesUnordered::new(),
}
}
pub fn spawn<F: Future<Output = T> + Send + 'static>(&mut self, f: F) {
self.set.push(f.boxed());
}
pub async fn join_next(&mut self) -> Option<Result<T, ()>> {
self.set.next().await.map(|item| Ok(item))
}
pub fn len(&self) -> usize {
self.set.len()
}
pub fn is_empty(&self) -> bool {
self.set.is_empty()
}
}
/// Suspends the current task for `dur` by delegating to `async_std::task::sleep`.
pub async fn sleep(dur: super::Duration) {
    let delay = async_std::task::sleep(dur);
    delay.await
}
/// Spawns `future` as an unnamed task and returns a handle to it.
///
/// Shorthand for `spawn_named(None, future)`.
pub fn spawn<F>(future: F) -> JoinHandle<F::Output>
where
    F: Future + Send + 'static,
    F::Output: Send + 'static,
{
    let no_name: Option<&str> = None;
    spawn_named(no_name, future)
}
/// Spawns `future` through `async_std::task::spawn_local` and returns a handle to it.
///
/// Unlike `spawn`, the future does not have to be `Send`. The future is
/// wrapped so that the handle's completion flag is raised as soon as it returns.
pub fn spawn_local<F>(future: F) -> JoinHandle<F::Output>
where
    F: Future + 'static,
{
    let is_done = Arc::new(AtomicBool::new(false));
    let tracked = {
        let done_flag = Arc::clone(&is_done);
        async move {
            let output = future.await;
            // Publish completion for `JoinHandle::is_finished`.
            done_flag.store(true, Ordering::Relaxed);
            output
        }
    };
    JoinHandle {
        handle: Some(async_std::task::spawn_local(tracked)),
        is_done,
    }
}
/// Spawns `future` as a task, optionally giving the task a name.
///
/// The future is wrapped so that the returned handle's completion flag is
/// raised as soon as the future returns.
///
/// * `name` - when `Some`, the task is created through
///   `async_std::task::Builder` with that name; when `None`, it is created
///   with `async_std::task::spawn`.
/// * `future` - the work to run.
///
/// # Panics
///
/// Panics if `async_std::task::Builder::spawn` reports an error for a named task.
pub fn spawn_named<F>(name: Option<&str>, future: F) -> JoinHandle<F::Output>
where
    F: Future + Send + 'static,
    F::Output: Send + 'static,
{
    let is_done = Arc::new(AtomicBool::new(false));
    let done_flag = Arc::clone(&is_done);
    // Built once and shared by both spawn paths, so the completion tracking
    // cannot drift between the named and the unnamed case.
    let tracked = async move {
        let output = future.await;
        // Publish completion for `JoinHandle::is_finished`.
        done_flag.store(true, Ordering::Relaxed);
        output
    };
    let handle = match name {
        Some(name) => async_std::task::Builder::new()
            .name(name.to_string())
            .spawn(tracked)
            .expect("failed to spawn named task"),
        None => async_std::task::spawn(tracked),
    };
    JoinHandle {
        handle: Some(handle),
        is_done,
    }
}
/// Awaits `future` for at most `dur`.
///
/// Returns `Ok(output)` when the future finishes in time. Any error reported
/// by `async_std::future::timeout` is mapped to `super::Timeout`.
pub async fn timeout<F, T>(dur: super::Duration, future: F) -> Result<T, super::Timeout>
where
    F: Future<Output = T>,
{
    match async_std::future::timeout(dur, future).await {
        Ok(value) => Ok(value),
        Err(_elapsed) => Err(super::Timeout),
    }
}
// Re-export of async-std's `test` attribute; the tests below apply it as `#[super::test]`.
pub use async_std::test;
// Exposes `futures::select_biased!` under the name `select`.
pub use futures::select_biased as select;
#[cfg(test)]
mod async_std_primitive_tests {
    use super::*;
    use crate::common_test::periodic_check;

    /// After `abort`, the handle reports itself as finished immediately.
    #[super::test]
    async fn join_handle_aborts() {
        let long_sleep = async {
            sleep(Duration::from_millis(1000)).await;
        };
        let mut handle = spawn(long_sleep);
        handle.abort();
        assert!(handle.is_finished());
    }

    /// A short unnamed task is eventually reported as finished.
    #[super::test]
    async fn join_handle_finishes() {
        let handle = spawn(async {
            sleep(Duration::from_millis(5)).await;
            println!("done.");
        });
        let limit = Duration::from_millis(1000);
        periodic_check(|| handle.is_finished(), limit).await;
    }

    /// A short named task is eventually reported as finished.
    #[super::test]
    async fn test_spawn_named() {
        let task_name = Some("something");
        let handle = spawn_named(task_name, async {
            sleep(Duration::from_millis(5)).await;
            println!("done.");
        });
        let limit = Duration::from_millis(1000);
        periodic_check(|| handle.is_finished(), limit).await;
    }
}