use std::any::Any;
use std::marker::PhantomData;
use std::sync::mpsc::{Receiver, Sender};
/// Message a finished task sends back to [`AsyncTasks`]: the task's id paired with its
/// type-erased output.
type ResultMsg = (u64, Box<dyn Any + Send + 'static>);
/// Typed handle to a future spawned through `AsyncTasks::spawn`.
///
/// `T` is the future's output type and is tracked only at the type level: the handle never
/// stores a `T`. Dropping the handle reports its id over `cancel` so the task can be aborted.
#[must_use = "dropping a TaskHandle cancels the spawned task; store it to poll the result"]
pub struct TaskHandle<T> {
    /// Task identifier; keys both the join handle and the stored result.
    pub(crate) id: u64,
    /// Cancellation sender. It is `Some` from construction until the handle is dropped.
    cancel: Option<std::sync::mpsc::Sender<u64>>,
    /// `fn() -> T` ties the handle to `T` without owning one; this marker is `Send + Sync`
    /// for every `T`.
    _marker: std::marker::PhantomData<fn() -> T>,
}
impl<T> TaskHandle<T> {
pub(crate) fn new(id: u64, cancel: Sender<u64>) -> Self {
Self {
id,
cancel: Some(cancel),
_marker: PhantomData,
}
}
pub(crate) fn id(&self) -> u64 {
self.id
}
}
impl<T> std::fmt::Debug for TaskHandle<T> {
    /// Prints only the task id. The cancel sender and type marker carry no useful diagnostics.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = formatter.debug_struct("TaskHandle");
        builder.field("id", &self.id);
        builder.finish()
    }
}
impl<T> Drop for TaskHandle<T> {
    /// Reports this handle's id on the cancellation channel exactly once.
    fn drop(&mut self) {
        let id = self.id;
        if let Some(sender) = self.cancel.take() {
            // `send` fails only when the receiving end has already been dropped. In that case
            // nobody is left to act on the cancellation, so the error is deliberately ignored.
            sender.send(id).ok();
        }
    }
}
pub(crate) struct AsyncTasks {
runtime: Option<tokio::runtime::Handle>,
next_id: u64,
joins: std::collections::HashMap<u64, tokio::task::JoinHandle<()>>,
results: std::collections::HashMap<u64, Box<dyn Any + Send>>,
result_tx: Option<Sender<ResultMsg>>,
result_rx: Option<Receiver<ResultMsg>>,
cancel_tx: Sender<u64>,
cancel_rx: Receiver<u64>,
}
impl Default for AsyncTasks {
fn default() -> Self {
let (cancel_tx, cancel_rx) = std::sync::mpsc::channel();
Self {
runtime: None,
next_id: 0,
joins: std::collections::HashMap::new(),
results: std::collections::HashMap::new(),
result_tx: None,
result_rx: None,
cancel_tx,
cancel_rx,
}
}
}
impl AsyncTasks {
    /// Installs the Tokio runtime handle that `spawn` submits futures to.
    ///
    /// Replaces any previously installed handle.
    pub(crate) fn set_runtime(&mut self, handle: tokio::runtime::Handle) {
        self.runtime = Some(handle);
    }

    /// Spawns `fut` on the installed runtime and returns a typed handle to its output.
    ///
    /// The output is boxed and sent back over an internal channel. It becomes available to
    /// `poll` once a later `poll`/`maintain` call drains that channel. Dropping the returned
    /// handle requests cancellation, which is applied on the next drain.
    ///
    /// # Panics
    ///
    /// Panics if no runtime handle has been installed with `set_runtime`.
    pub(crate) fn spawn<T: Send + 'static>(
        &mut self,
        fut: impl std::future::Future<Output = T> + Send + 'static,
    ) -> TaskHandle<T> {
        let runtime = self.runtime.clone().unwrap_or_else(|| {
            panic!(
                "Context::spawn requires an active Tokio runtime; call it inside \
                 run_async() / run_async_with()"
            )
        });
        // The result channel is created on first use and shared by every later task.
        if self.result_tx.is_none() {
            let (tx, rx) = std::sync::mpsc::channel();
            self.result_tx = Some(tx);
            self.result_rx = Some(rx);
        }
        let result_tx = self
            .result_tx
            .clone()
            .expect("result_tx wired immediately above");
        let id = self.next_id;
        self.next_id = self.next_id.wrapping_add(1);
        let join = runtime.spawn(async move {
            let out = fut.await;
            // Ignored failure: the receiver only disappears when this task set is dropped,
            // and then nobody can collect the output anyway.
            let _ = result_tx.send((id, Box::new(out) as Box<dyn Any + Send>));
        });
        // `&mut self` prevents `drain` from running before this insert. So every result that
        // `drain` later receives has a matching `joins` entry unless the task was cancelled.
        self.joins.insert(id, join);
        TaskHandle::new(id, self.cancel_tx.clone())
    }

    /// Moves delivered outputs into `results`, then applies pending cancellations.
    fn drain(&mut self) {
        if let Some(rx) = self.result_rx.as_ref() {
            while let Ok((id, value)) = rx.try_recv() {
                // Keep an output only while its task is still tracked. `cancel` removes the
                // join entry when the handle is dropped. `abort()` cannot stop a task that is
                // already finishing, so its value may still arrive afterwards. Storing it
                // would leak it forever, because no handle remains to poll it.
                if self.joins.remove(&id).is_some() {
                    self.results.insert(id, value);
                }
            }
        }
        while let Ok(id) = self.cancel_rx.try_recv() {
            self.cancel(id);
        }
    }

    /// Processes finished outputs and cancellations from dropped handles without polling a
    /// particular task, so cancelled tasks are cleaned up even when nothing is polled.
    pub(crate) fn maintain(&mut self) {
        self.drain();
    }

    /// Takes the output of task `id` if it has finished.
    ///
    /// Returns `None` in three cases: the task has not delivered its output yet, the output
    /// was already taken, or the stored value is not a `T`. In the last case the value is
    /// left in place.
    pub(crate) fn poll<T: 'static>(&mut self, id: u64) -> Option<T> {
        self.drain();
        let boxed = self.results.remove(&id)?;
        match boxed.downcast::<T>() {
            Ok(value) => Some(*value),
            Err(boxed) => {
                self.results.insert(id, boxed);
                None
            }
        }
    }

    /// Aborts task `id` if it is still tracked and discards any stored output for it.
    ///
    /// A task that is already completing may still send its output after this call; `drain`
    /// drops such late outputs because the join entry is gone.
    fn cancel(&mut self, id: u64) {
        if let Some(join) = self.joins.remove(&id) {
            join.abort();
        }
        self.results.remove(&id);
    }
}