use std::{future::Future, time::Duration};
use super::handle::async_handle::{channel::RecvError, message::Message};
/// Non-blocking completion query for a spawned task's handle.
pub trait IsFinished {
    /// Returns `true` if the task behind this handle has already completed.
    ///
    /// This only inspects state; it does not wait for the task.
    #[must_use]
    fn is_finished(&self) -> bool;
}
/// Abstraction over an async runtime that can block on a future and spawn local
/// (not necessarily `Send`) tasks.
///
/// `N` is a const parameter that must be non-zero (see [`Executor::VALID`]);
/// presumably it is the number of tasks the executor must support — confirm against callers.
pub trait Executor<const N: usize>: Send + Sync + 'static {
    /// Error produced when awaiting an [`Executor::JoinHandle`] fails.
    type JoinError;

    /// Handle returned by [`Executor::spawn_local`]: a future yielding `Ok(())` or an
    /// [`Executor::JoinError`], whose completion can also be polled via [`IsFinished`].
    type JoinHandle: Future<Output = Result<(), Self::JoinError>> + IsFinished;

    /// Compile-time check that `N > 0`.
    ///
    /// Associated constants are only evaluated when referenced, so this assertion has
    /// no effect unless code names it (e.g. `let () = E::VALID;` in a generic context).
    const VALID: () = assert!(N > 0, "executor must support at least 1 task");

    /// Drives `loop_fn` to completion and returns its output.
    fn block_on<T>(&self, loop_fn: impl Future<Output = T>) -> T;

    /// Spawns `future` as a task that is not required to be `Send`, returning its handle.
    fn spawn_local(future: impl Future<Output = ()> + 'static) -> Self::JoinHandle;

    /// Returns a future that yields control back to the executor.
    fn yield_now() -> impl Future<Output = ()>;

    /// Awaits `future` for at most `duration`.
    ///
    /// Implementations are expected to return `None` when the deadline elapses first and
    /// `Some(result)` otherwise.
    fn timeout(
        duration: Duration,
        future: impl Future<Output = Result<Message, RecvError>>,
    ) -> impl Future<Output = Option<Result<Message, RecvError>>>;
}
/// Tokio-backed [`Executor`] implementation, compiled only with the `tokio-rt` feature.
#[cfg(feature = "tokio-rt")]
pub mod tokio_exec {
    use std::sync::Arc;

    use tokio::{
        runtime::Builder,
        task::{JoinError, JoinHandle, LocalSet},
        time::timeout,
    };

    use super::*;

    /// Shared, thread-safe callback installed on the Tokio runtime by
    /// [`Tokio::on_thread_park`] / [`Tokio::on_thread_unpark`].
    pub type TokioCallback = Arc<dyn Fn() + Send + Sync>;

    /// [`Executor`] that runs everything on a single-threaded Tokio runtime.
    ///
    /// Every [`Executor::block_on`] call builds a fresh current-thread runtime (timer
    /// always enabled) and drives the given future inside a new [`LocalSet`], so tasks
    /// spawned with [`Executor::spawn_local`] do not need to be `Send`.
    ///
    /// `N` is the const parameter required by [`Executor`]; this type does not read it.
    #[derive(Clone)]
    pub struct Tokio<const N: usize> {
        /// Requests Tokio's I/O driver. Only read when the `tokio-net` feature is enabled,
        /// so the dead-code lint is silenced only for builds without that feature.
        #[cfg_attr(not(feature = "tokio-net"), allow(dead_code))]
        enable_io: bool,
        /// Forwarded to `Builder::on_thread_park` when set.
        on_thread_park: Option<TokioCallback>,
        /// Forwarded to `Builder::on_thread_unpark` when set.
        on_thread_unpark: Option<TokioCallback>,
    }

    impl<const N: usize> Tokio<N> {
        /// Creates an executor with no park/unpark callbacks.
        ///
        /// `enable_io` only takes effect when the `tokio-net` feature is enabled.
        pub const fn new(enable_io: bool) -> Self {
            Tokio {
                enable_io,
                on_thread_park: None,
                on_thread_unpark: None,
            }
        }

        /// Sets the callback handed to Tokio's `Builder::on_thread_park` by subsequent
        /// [`Executor::block_on`] calls, replacing any previously registered one.
        pub fn on_thread_park(
            &mut self,
            on_thread_park: impl Fn() + Send + Sync + 'static,
        ) -> &mut Self {
            self.on_thread_park = Some(Arc::new(on_thread_park));
            self
        }

        /// Sets the callback handed to Tokio's `Builder::on_thread_unpark` by subsequent
        /// [`Executor::block_on`] calls, replacing any previously registered one.
        pub fn on_thread_unpark(
            &mut self,
            on_thread_unpark: impl Fn() + Send + Sync + 'static,
        ) -> &mut Self {
            self.on_thread_unpark = Some(Arc::new(on_thread_unpark));
            self
        }
    }

    impl IsFinished for JoinHandle<()> {
        /// Delegates to Tokio's inherent `JoinHandle::is_finished`.
        fn is_finished(&self) -> bool {
            // Path-qualified so it is obvious this resolves to the inherent method and is
            // not a recursive call to `IsFinished::is_finished`.
            JoinHandle::is_finished(self)
        }
    }

    impl<const N: usize> Executor<N> for Tokio<N> {
        type JoinError = JoinError;
        type JoinHandle = JoinHandle<()>;

        /// Builds a current-thread runtime, installs any park/unpark callbacks, enables
        /// the I/O driver when `tokio-net` is on and `enable_io` is set, then drives
        /// `loop_fn` to completion inside a fresh [`LocalSet`].
        ///
        /// # Panics
        ///
        /// Panics if the Tokio runtime cannot be built.
        #[inline]
        fn block_on<T>(&self, loop_fn: impl Future<Output = T>) -> T {
            let mut builder = Builder::new_current_thread();
            builder.enable_time();
            if let Some(callback) = self.on_thread_park.clone() {
                builder.on_thread_park(move || (*callback)());
            }
            if let Some(callback) = self.on_thread_unpark.clone() {
                builder.on_thread_unpark(move || (*callback)());
            }
            #[cfg(feature = "tokio-net")]
            if self.enable_io {
                builder.enable_io();
            }
            let runtime = builder.build().expect("unable to build tokio runtime");
            // Named local (not a temporary) so the LocalSet is dropped before the runtime.
            let local_set = LocalSet::new();
            local_set.block_on(&runtime, loop_fn)
        }

        /// Spawns `future` onto the current [`LocalSet`].
        ///
        /// # Panics
        ///
        /// Per Tokio's documentation, panics when called outside a `LocalSet` context,
        /// i.e. from anywhere other than a future driven by [`Executor::block_on`].
        #[inline]
        fn spawn_local(future: impl Future<Output = ()> + 'static) -> Self::JoinHandle {
            tokio::task::spawn_local(future)
        }

        /// Yields once to the Tokio scheduler.
        #[inline]
        fn yield_now() -> impl Future<Output = ()> {
            tokio::task::yield_now()
        }

        /// Awaits `future` for at most `duration`; returns `None` if the deadline elapses
        /// first (Tokio's `Elapsed` error is discarded).
        async fn timeout(
            duration: Duration,
            future: impl Future<Output = Result<Message, RecvError>>,
        ) -> Option<Result<Message, RecvError>> {
            timeout(duration, future).await.ok()
        }
    }
}