use std::{
cell::Cell,
future::Future,
rc::Rc,
sync::{
Arc,
atomic::{AtomicUsize, Ordering},
},
};
use crate::{DroppableFuture, TaskIdentifier};
#[derive(Debug)]
pub enum TaskState {
Spawn(TaskIdentifier),
Wake(TaskIdentifier),
Tick(TaskIdentifier, f64),
Drop(TaskIdentifier),
}
pub type Task<T> = async_task::Task<T>;
type Payload = (TaskIdentifier, async_task::Runnable);
pub struct SplitTickedAsyncExecutor;
impl SplitTickedAsyncExecutor {
pub fn default() -> (
TickedAsyncExecutorSpawner<fn(TaskState)>,
TickedAsyncExecutorTicker<fn(TaskState)>,
) {
Self::new(|_state| {})
}
pub fn new<O>(observer: O) -> (TickedAsyncExecutorSpawner<O>, TickedAsyncExecutorTicker<O>)
where
O: Fn(TaskState) + Clone + Send + Sync + 'static,
{
let (task_tx, task_rx) = flume::unbounded();
let num_spawned_tasks = Arc::new(AtomicUsize::new(0));
#[cfg(feature = "tick_event")]
let (tick_event_tx, tick_event_rx) = tokio::sync::watch::channel(1.0);
#[cfg(feature = "timer_registration")]
let (timer_registration_tx, timer_registration_rx) = flume::unbounded();
let spawner = TickedAsyncExecutorSpawner {
task_tx,
num_spawned_tasks: num_spawned_tasks.clone(),
observer: observer.clone(),
#[cfg(feature = "tick_event")]
tick_event_rx,
#[cfg(feature = "timer_registration")]
timer_registration_tx,
_not_send: std::marker::PhantomData,
};
let ticker = TickedAsyncExecutorTicker {
task_rx,
num_spawned_tasks,
observer,
delta: Rc::new(0.0.into()),
#[cfg(feature = "tick_event")]
tick_event_tx,
#[cfg(feature = "timer_registration")]
timer_registration_rx,
#[cfg(feature = "timer_registration")]
timers: Vec::new(),
};
(spawner, ticker)
}
}
pub struct TickedAsyncExecutorSpawner<O> {
task_tx: flume::Sender<Payload>,
num_spawned_tasks: Arc<AtomicUsize>,
observer: O,
#[cfg(feature = "tick_event")]
tick_event_rx: tokio::sync::watch::Receiver<f64>,
#[cfg(feature = "timer_registration")]
timer_registration_tx: flume::Sender<(f64, std::task::Waker)>,
_not_send: std::marker::PhantomData<*const ()>,
}
impl<O: Clone> Clone for TickedAsyncExecutorSpawner<O> {
fn clone(&self) -> Self {
Self {
task_tx: self.task_tx.clone(),
num_spawned_tasks: self.num_spawned_tasks.clone(),
observer: self.observer.clone(),
#[cfg(feature = "tick_event")]
tick_event_rx: self.tick_event_rx.clone(),
#[cfg(feature = "timer_registration")]
timer_registration_tx: self.timer_registration_tx.clone(),
_not_send: self._not_send,
}
}
}
impl<O> TickedAsyncExecutorSpawner<O>
where
O: Fn(TaskState) + Clone + Send + Sync + 'static,
{
pub fn spawn_local<T>(
&self,
identifier: impl Into<TaskIdentifier>,
future: impl Future<Output = T> + 'static,
) -> Task<T>
where
T: 'static,
{
let identifier = identifier.into();
let future = self.droppable_future(identifier.clone(), future);
let schedule = self.runnable_schedule_cb(identifier);
let (runnable, task) = async_task::spawn_local(future, schedule);
runnable.schedule();
task
}
#[cfg(feature = "tick_event")]
pub fn create_timer_from_tick_event(&self) -> crate::TickedTimerFromTickEvent {
crate::TickedTimerFromTickEvent::new(self.tick_event_rx.clone())
}
#[cfg(feature = "tick_event")]
pub fn tick_channel(&self) -> tokio::sync::watch::Receiver<f64> {
self.tick_event_rx.clone()
}
#[cfg(feature = "timer_registration")]
pub fn create_timer_from_timer_registration(&self) -> crate::TickedTimerFromTimerRegistration {
crate::TickedTimerFromTimerRegistration::new(self.timer_registration_tx.clone())
}
pub fn num_tasks(&self) -> usize {
self.num_spawned_tasks.load(Ordering::Relaxed)
}
fn droppable_future<F>(
&self,
identifier: TaskIdentifier,
future: F,
) -> DroppableFuture<F, impl Fn() + use<F, O>>
where
F: Future,
{
let observer = self.observer.clone();
self.num_spawned_tasks.fetch_add(1, Ordering::Relaxed);
observer(TaskState::Spawn(identifier.clone()));
let num_spawned_tasks = self.num_spawned_tasks.clone();
DroppableFuture::new(future, move || {
num_spawned_tasks.fetch_sub(1, Ordering::Relaxed);
observer(TaskState::Drop(identifier.clone()));
})
}
fn runnable_schedule_cb(
&self,
identifier: TaskIdentifier,
) -> impl Fn(async_task::Runnable) + use<O> {
let task_tx = self.task_tx.clone();
let observer = self.observer.clone();
move |runnable| {
task_tx.send((identifier.clone(), runnable)).unwrap_or(());
observer(TaskState::Wake(identifier.clone()));
}
}
}
#[derive(Clone)]
pub struct TickedAsyncExecutorDelta(Rc<Cell<f64>>);
impl TickedAsyncExecutorDelta {
pub fn get(&self) -> f64 {
self.0.get()
}
pub fn inner(self) -> Rc<Cell<f64>> {
self.0
}
}
pub struct TickedAsyncExecutorTicker<O> {
task_rx: flume::Receiver<Payload>,
num_spawned_tasks: Arc<AtomicUsize>,
observer: O,
delta: Rc<Cell<f64>>,
#[cfg(feature = "tick_event")]
tick_event_tx: tokio::sync::watch::Sender<f64>,
#[cfg(feature = "timer_registration")]
timer_registration_rx: flume::Receiver<(f64, std::task::Waker)>,
#[cfg(feature = "timer_registration")]
timers: Vec<(f64, std::task::Waker)>,
}
impl<O> TickedAsyncExecutorTicker<O>
where
O: Fn(TaskState),
{
pub fn delta(&self) -> TickedAsyncExecutorDelta {
TickedAsyncExecutorDelta(self.delta.clone())
}
pub fn tick(&mut self, delta: f64, limit: Option<usize>) {
self.delta.replace(delta);
#[cfg(feature = "tick_event")]
let _r = self.tick_event_tx.send(delta);
#[cfg(feature = "timer_registration")]
self.timer_registration_tick(delta);
let mut num_woken_tasks = self.task_rx.len();
if let Some(limit) = limit {
num_woken_tasks = num_woken_tasks.min(limit);
}
self.task_rx
.try_iter()
.take(num_woken_tasks)
.for_each(|(identifier, runnable)| {
(self.observer)(TaskState::Tick(identifier, delta));
runnable.run();
});
}
pub fn wait_till_completed(&mut self, constant_delta: f64) {
while self.num_spawned_tasks.load(Ordering::Relaxed) != 0 {
self.tick(constant_delta, None);
}
}
#[cfg(feature = "timer_registration")]
fn timer_registration_tick(&mut self, delta: f64) {
self.timer_registration_rx.try_iter().for_each(|timer| {
self.timers.push(timer);
});
if self.timers.is_empty() {
return;
}
self.timers
.extract_if(.., |(elapsed, _)| {
*elapsed -= delta;
*elapsed <= 0.0
})
.for_each(|(_, waker)| {
waker.wake();
});
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_split_ticked_async_executor_spawner_clone() {
let (spawner, _ticker) = SplitTickedAsyncExecutor::default();
let _spawner_clone = spawner.clone();
}
}