use futures::{channel::mpsc::UnboundedReceiver, select_biased, Future, FutureExt, StreamExt};
use std::{sync::OnceLock, task::Poll, time::Duration};
pub use async_std::net::TcpStream;
pub use async_task::Runnable;
pub use futures::Stream;
pub use std::time::Instant;
/// Policy tag stored on an [`Interval`] describing how missed ticks should be treated.
///
/// The default is [`MissedTickBehavior::Burst`].
///
/// NOTE(review): the variant names presumably mirror tokio's `MissedTickBehavior`;
/// the code in this file only stores the value and carries it across
/// `Interval::reset` — confirm whether anything else consults it.
#[derive(Debug, Default, Copy, Clone, PartialEq, Eq)]
pub enum MissedTickBehavior {
    /// Default policy.
    #[default]
    Burst,
    /// Alternative policy (see the review note on the enum).
    Delay,
    /// Alternative policy (see the review note on the enum).
    Skip,
}
/// Process-wide dispatcher slot: filled by `set_dispatcher`, read by `get_dispatcher`.
/// It stores a `'static` reference, so the value can be copied into task schedulers.
static DISPATCHER: OnceLock<&'static dyn Dispatcher> = OnceLock::new();
/// Executor hook through which this module runs its tasks.
///
/// `spawn` passes every runnable to `dispatch`; `sleep` passes its runnable to
/// `dispatch_after` along with the requested delay.
pub trait Dispatcher: 'static + Send + Sync {
/// Accepts a runnable that is ready to be run.
/// NOTE(review): presumably the implementation should call `runnable.run()`
/// soon afterwards — confirm against the concrete dispatcher.
fn dispatch(&self, runnable: Runnable);
/// Accepts a runnable that should not be run until `duration` has passed.
/// `sleep` builds its future on this: the `Sleep` task completes on its first
/// run, so the length of the delay is entirely up to the implementation.
fn dispatch_after(&self, duration: Duration, runnable: Runnable);
}
/// Installs the process-wide dispatcher used by `spawn`, `sleep` and everything
/// built on top of them.
///
/// Only the first call has any effect. The installed dispatcher is leaked on
/// purpose so that it can be handed out as a `'static` reference. Later calls
/// are ignored and their argument is dropped normally instead of being leaked
/// as well.
pub fn set_dispatcher(dispatcher: impl Dispatcher) {
    // `get_or_init` only runs the closure when the slot is still empty, so the
    // allocation (and the leak) happens at most once per process.
    let _ = DISPATCHER.get_or_init(move || {
        let installed: &'static dyn Dispatcher = Box::leak(Box::new(dispatcher));
        installed
    });
}
/// Returns the dispatcher installed through `set_dispatcher`.
///
/// # Panics
///
/// Panics if no dispatcher has been installed yet.
fn get_dispatcher() -> &'static dyn Dispatcher {
    DISPATCHER
        .get()
        .copied()
        .expect("The livekit dispatcher requires a call to set_dispatcher()")
}
/// Handle to a task started with `spawn`.
///
/// Awaiting the handle yields the task's output. Dropping the handle calls
/// `detach()` on the underlying `async_task::Task` (see the `Drop` impl).
#[derive(Debug)]
pub struct JoinHandle<T> {
// `Some` from construction in `spawn` until `Drop` takes it out.
task: Option<async_task::Task<T>>,
}
pub fn spawn<F>(future: F) -> JoinHandle<F::Output>
where
F: Future + 'static + Send,
F::Output: 'static + Send,
{
let dispatcher = get_dispatcher();
let (runnable, task) = async_task::spawn(future, |runnable| dispatcher.dispatch(runnable));
runnable.schedule();
JoinHandle { task: Some(task) }
}
/// Awaiting a `JoinHandle` resolves to the output of the spawned task.
impl<T> Future for JoinHandle<T> {
    type Output = T;

    fn poll(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Self::Output> {
        // The inner task is only removed in `Drop`, so it is present here.
        let task = self
            .task
            .as_mut()
            .expect("poll should not be called after drop");
        task.poll_unpin(cx)
    }
}
/// Dropping the handle detaches the task instead of dropping the inner
/// `async_task::Task` directly.
impl<T> Drop for JoinHandle<T> {
    fn drop(&mut self) {
        let task = self
            .task
            .take()
            .expect("This is the only place the option is mutated");
        task.detach();
    }
}
/// Future returned by `sleep`; it completes when its inner task has run.
pub struct Sleep {
// Empty task whose runnable is handed to `Dispatcher::dispatch_after`.
task: async_task::Task<()>,
}
pub fn sleep(time: Duration) -> Sleep {
let dispatcher = get_dispatcher();
let (runnable, task) =
async_task::spawn(async {}, move |runnable| dispatcher.dispatch_after(time, runnable));
runnable.schedule();
Sleep { task }
}
impl Sleep {
pub fn reset(&mut self, deadline: Instant) {
let duration = deadline.saturating_duration_since(Instant::now());
self.task = sleep(duration).task
}
}
/// `Sleep` resolves to `()` once its inner delayed task has completed.
impl Future for Sleep {
    type Output = ();

    fn poll(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Self::Output> {
        let task = &mut self.task;
        std::pin::Pin::new(task).poll(cx)
    }
}
/// Error returned by [`timeout`] when the delay elapses before the wrapped
/// future completes.
///
/// The type carries no data. It implements `Debug`, `Display` and
/// `std::error::Error` so that results from `timeout` can be unwrapped, logged
/// and converted into boxed errors.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TimeoutError {}

impl std::fmt::Display for TimeoutError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("deadline has elapsed")
    }
}

impl std::error::Error for TimeoutError {}
pub fn timeout<T>(
duration: Duration,
future: T,
) -> impl Future<Output = Result<T::Output, TimeoutError>>
where
T: Future,
{
async move {
select_biased! {
res = future.fuse() => Ok(res),
_ = sleep(duration).fuse() => Err(TimeoutError {}),
}
}
}
/// Repeating timer created by `interval`.
///
/// A background task sends an `Instant` over an unbounded channel after each
/// sleep of `duration`; `tick` and the `Future` impl read from that channel.
pub struct Interval {
// Period between ticks; reused by `reset` to build the replacement timer.
duration: Duration,
// Handle of the background task that produces the ticks.
_timer_loop: JoinHandle<()>,
// Stored and carried across `reset`; not otherwise read in the code shown here.
missed_tick_behavior: MissedTickBehavior,
// Receiving end of the tick channel.
rx: UnboundedReceiver<Instant>,
}
/// Creates a timer that produces a tick every `duration`.
///
/// A background task sleeps for `duration`, then sends the current `Instant`
/// over an unbounded channel, and repeats. The first tick therefore arrives
/// after one full period, and ticks that are not consumed pile up in the channel.
///
/// The background task stops once the returned `Interval` (and with it the
/// receiving end of the channel) has been dropped, which it notices on its
/// next send.
///
/// # Panics
///
/// Panics if `set_dispatcher` has not been called.
pub fn interval(duration: Duration) -> Interval {
    let (tx, rx) = futures::channel::mpsc::unbounded();
    let timer_loop = spawn(async move {
        loop {
            sleep(duration).await;
            // Sending on an unbounded channel only fails when the receiver is
            // gone. `JoinHandle` detaches on drop, so without this exit the
            // task would keep sleeping and sending forever.
            if tx.unbounded_send(Instant::now()).is_err() {
                break;
            }
        }
    });
    Interval {
        duration,
        rx,
        _timer_loop: timer_loop,
        missed_tick_behavior: MissedTickBehavior::default(),
    }
}
impl Interval {
    /// Restarts the timer with the same period.
    ///
    /// `self` is replaced by a fresh interval, so the previous channel (with
    /// any queued ticks) and task handle are dropped. The configured
    /// missed-tick behavior is carried over.
    pub fn reset(&mut self) {
        let behavior = self.missed_tick_behavior;
        let mut restarted = interval(self.duration);
        restarted.set_missed_tick_behavior(behavior);
        *self = restarted;
    }

    /// Waits for the next tick and returns the `Instant` recorded for it.
    ///
    /// # Panics
    ///
    /// Panics if the tick channel has been closed.
    pub async fn tick(&mut self) -> Instant {
        let next_tick = self.rx.next().await;
        next_tick.expect("timer loop should always be running")
    }

    /// Stores `behavior` as this interval's missed-tick policy.
    pub fn set_missed_tick_behavior(&mut self, behavior: MissedTickBehavior) {
        self.missed_tick_behavior = behavior;
    }
}
/// Polling an `Interval` as a future resolves to the next tick's `Instant`.
///
/// # Panics
///
/// Panics if the tick channel has been closed.
impl Future for Interval {
    type Output = Instant;

    fn poll(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Self::Output> {
        match self.rx.poll_next_unpin(cx) {
            Poll::Ready(tick) => Poll::Ready(
                tick.expect("join loop should be running for as long as the interval exists"),
            ),
            Poll::Pending => Poll::Pending,
        }
    }
}