#![forbid(unsafe_code)]
use core::cell::Cell;
use core::future::Future;
use core::pin::Pin;
use core::task::Poll;
pub use safina_sync::Receiver;
pub use safina_threadpool::NewThreadPoolError;
use safina_threadpool::ThreadPool;
use std::sync::mpsc::SyncSender;
use std::sync::{Arc, Mutex, Weak};
thread_local! {
    // The executor associated with the current thread. Written by `set_thread_executor`,
    // cleared by dropping `ThreadExecutorGuard`, and read by `get_thread_executor` (which
    // `spawn` / `schedule_blocking` use). It is a `Weak` so this thread-local reference never
    // keeps the executor alive on its own.
    static EXECUTOR: Cell<Weak<Executor>> = Cell::new(Weak::new());
}
#[must_use]
pub fn get_thread_executor() -> Option<Arc<Executor>> {
EXECUTOR.with(|cell| {
let weak = cell.take();
let result = weak.upgrade();
cell.set(weak);
result
})
}
pub struct ThreadExecutorGuard;
impl Drop for ThreadExecutorGuard {
fn drop(&mut self) {
EXECUTOR.with(std::cell::Cell::take);
}
}
#[must_use]
pub fn set_thread_executor(executor: Weak<Executor>) -> ThreadExecutorGuard {
EXECUTOR.with(|cell| cell.set(executor));
ThreadExecutorGuard {}
}
/// An executor backed by two thread pools: one polls spawned futures and the other runs
/// blocking closures.
///
/// Construct it with [`Executor::new`], [`Executor::with_name`], or [`Executor::default`];
/// all of them return it inside an `Arc`.
pub struct Executor {
    // Polls futures passed to `spawn` / `spawn_unpin` (via `poll_task`).
    async_pool: ThreadPool,
    // Runs closures passed to `schedule_blocking`.
    blocking_pool: ThreadPool,
}
impl Executor {
#[must_use]
pub fn default() -> Arc<Self> {
Self::new(4, 4).unwrap()
}
pub fn new(
num_async_threads: usize,
num_blocking_threads: usize,
) -> Result<Arc<Self>, NewThreadPoolError> {
Self::with_name("async", num_async_threads, "blocking", num_blocking_threads)
}
pub fn with_name(
async_threads_name: &'static str,
num_async_threads: usize,
blocking_threads_name: &'static str,
num_blocking_threads: usize,
) -> Result<Arc<Self>, NewThreadPoolError> {
Ok(Arc::new(Self {
async_pool: ThreadPool::new(async_threads_name, num_async_threads)?,
blocking_pool: ThreadPool::new(blocking_threads_name, num_blocking_threads)?,
}))
}
pub fn schedule_blocking<T, F>(self: &Arc<Self>, func: F) -> Receiver<T>
where
T: Send + 'static,
F: (FnOnce() -> T) + Send + 'static,
{
let (sender, receiver) = safina_sync::oneshot();
let weak_self = Arc::downgrade(self);
self.blocking_pool.schedule(move || {
let _guard = set_thread_executor(weak_self);
let _result = sender.send(func());
});
receiver
}
pub fn spawn(self: &Arc<Self>, fut: impl (Future<Output = ()>) + Send + 'static) {
self.spawn_unpin(Box::pin(fut));
}
pub fn spawn_unpin(self: &Arc<Self>, fut: impl (Future<Output = ()>) + Send + Unpin + 'static) {
let task: Arc<Mutex<Option<Box<dyn Future<Output = ()> + Send + Unpin>>>> =
Arc::new(Mutex::new(Some(Box::new(fut))));
let weak_self = Arc::downgrade(self);
self.async_pool.schedule(move || poll_task(task, weak_self));
}
pub fn block_on<R>(self: &Arc<Self>, fut: impl (Future<Output = R>) + 'static) -> R {
self.block_on_unpin(Box::pin(fut))
}
pub fn block_on_unpin<R>(
self: &Arc<Self>,
fut: impl (Future<Output = R>) + Unpin + 'static,
) -> R {
let _guard = set_thread_executor(Arc::downgrade(self));
block_on_unpin(fut)
}
}
impl Default for Executor {
    /// Builds an executor with the same settings as the inherent [`Executor::default`],
    /// returned by value instead of inside an `Arc`.
    ///
    /// # Panics
    ///
    /// Panics if either thread pool cannot be created.
    fn default() -> Self {
        // `Executor::default()` resolves to the inherent constructor (inherent associated
        // functions take precedence over trait methods). Its `Arc` was just created and never
        // cloned or downgraded, so unwrapping it cannot fail.
        match Arc::try_unwrap(Executor::default()) {
            Ok(executor) => executor,
            Err(_) => unreachable!(),
        }
    }
}
/// Runs `func` on a blocking thread of the current thread's executor and returns a
/// [`Receiver`] for its result.
///
/// # Panics
///
/// Panics if no live executor is set for the current thread (see [`get_thread_executor`]).
pub fn schedule_blocking<T, F>(func: F) -> Receiver<T>
where
    T: Send + 'static,
    F: (FnOnce() -> T) + Send + 'static,
{
    match get_thread_executor() {
        Some(executor) => executor.schedule_blocking(func),
        None => panic!(
            "called from outside a task; check for duplicate safina-executor crate: cargo tree -d"
        ),
    }
}
/// Polls a spawned task's future once, on the current thread.
///
/// Returns without polling if the executor has been dropped or the future has already
/// completed.
///
/// While the future is polled, `executor` is set as the thread's executor, so the future can
/// call [`spawn`] and [`schedule_blocking`]. The waker handed to the future schedules
/// another `poll_task` on the executor's async pool. A future that returns `Ready` is
/// dropped.
///
/// If an earlier poll panicked, the task's mutex is poisoned and the future's state is
/// unreliable. Rather than panicking again on every later wake, this drops the future so the
/// resources it holds are released.
// Both arguments are moved in by the thread-pool closures that schedule this function;
// `task` itself is only borrowed here, hence the lint allowance.
#[allow(clippy::needless_pass_by_value)]
fn poll_task(
    task: Arc<Mutex<Option<Box<dyn Future<Output = ()> + Send + Unpin>>>>,
    executor: Weak<Executor>,
) {
    if executor.strong_count() == 0 {
        return;
    }
    let mut opt_fut_guard = match task.lock() {
        Ok(guard) => guard,
        Err(poisoned) => {
            // Drop the broken future with the executor set, matching the `Ready` path, in
            // case its destructor spawns work.
            let _guard = set_thread_executor(executor);
            drop(poisoned.into_inner().take());
            return;
        }
    };
    if let Some(fut) = opt_fut_guard.as_mut() {
        // Only allocate a waker when there is actually a future to poll.
        let waker =
            std::task::Waker::from(Arc::new(TaskWaker::new(task.clone(), executor.clone())));
        let mut cx = std::task::Context::from_waker(&waker);
        let _guard = set_thread_executor(executor);
        if let Poll::Ready(()) = Pin::new(&mut *fut).poll(&mut cx) {
            // Leave `None` behind so later (spurious) wakes don't poll a finished future.
            opt_fut_guard.take();
        }
    }
}
/// Waker payload for a spawned task; waking it schedules another `poll_task` for the task on
/// the executor's async pool.
struct TaskWaker {
    // The spawned future, shared with `poll_task`; `None` once the future has completed.
    task: Arc<Mutex<Option<Box<dyn Future<Output = ()> + Send + Unpin>>>>,
    // Weak so that outstanding wakers do not keep the executor alive.
    executor: Weak<Executor>,
}
impl TaskWaker {
pub fn new(
task: Arc<Mutex<Option<Box<dyn Future<Output = ()> + Send + Unpin>>>>,
executor: Weak<Executor>,
) -> Self {
Self { task, executor }
}
}
impl std::task::Wake for TaskWaker {
fn wake(self: Arc<Self>) {
if let Some(ref executor) = self.executor.upgrade() {
let task_clone = self.task.clone();
let executor_weak = Arc::downgrade(executor);
executor
.async_pool
.schedule(move || poll_task(task_clone, executor_weak));
}
}
}
/// Spawns `fut` on the current thread's executor, boxing and pinning it first.
///
/// # Panics
///
/// Panics if no live executor is set for the current thread (see [`spawn_unpin`]).
pub fn spawn(fut: impl (Future<Output = ()>) + Send + 'static) {
    let pinned = Box::pin(fut);
    spawn_unpin(pinned);
}
pub fn spawn_unpin(fut: impl (Future<Output = ()>) + Send + Unpin + 'static) {
if let Some(executor) = get_thread_executor() {
let task: Arc<Mutex<Option<Box<dyn Future<Output = ()> + Send + Unpin>>>> =
Arc::new(Mutex::new(Some(Box::new(fut))));
let executor_weak = Arc::downgrade(&executor);
executor
.async_pool
.schedule(move || poll_task(task, executor_weak));
} else {
panic!(
"called from outside a task; check for duplicate safina-executor crate: cargo tree -d"
);
}
}
/// Polls `fut` on the current thread until it completes, then returns its output.
///
/// Between polls the thread blocks (without spinning) until the future's waker is woken.
///
/// Unlike [`Executor::block_on`], this does not set the thread's executor. `fut` can
/// therefore call [`spawn`] or [`schedule_blocking`] only if the thread already has one.
///
/// `fut` may borrow from the caller's stack: it is polled only on this thread and is dropped
/// before this function returns, so no `'static` bound is needed.
pub fn block_on<R>(fut: impl (Future<Output = R>)) -> R {
    block_on_unpin(Box::pin(fut))
}

/// Like [`block_on`], but takes a future that is already `Unpin`, avoiding the `Box::pin`
/// allocation.
///
/// Each poll gets a fresh single-use waker backed by a one-slot channel. When the future
/// returns `Pending`, this thread waits on that channel until the waker fires (possibly
/// already, during the poll), then polls again.
pub fn block_on_unpin<R>(mut fut: impl (Future<Output = R>) + Unpin) -> R {
    /// Wakes the waiting `block_on_unpin` loop by sending on its channel, at most once.
    struct BlockOnTaskWaker(Mutex<Option<SyncSender<()>>>);
    impl std::task::Wake for BlockOnTaskWaker {
        fn wake(self: Arc<Self>) {
            // Only the first wake sends; later wakes of this (by then stale) waker are no-ops.
            if let Some(sender) = self.0.lock().unwrap().take() {
                // Capacity 1 with at most one send, so this never blocks. It fails only if
                // the loop already returned and dropped the receiver, which is harmless.
                let _ = sender.send(());
            }
        }
    }
    loop {
        let (sender, receiver) = std::sync::mpsc::sync_channel(1);
        let waker = std::task::Waker::from(Arc::new(BlockOnTaskWaker(Mutex::new(Some(sender)))));
        let mut cx = std::task::Context::from_waker(&waker);
        if let Poll::Ready(result) = Pin::new(&mut fut).poll(&mut cx) {
            return result;
        }
        // `waker` is still alive here and owns the sender until it fires (sending first), so
        // the channel cannot disconnect with an empty buffer.
        receiver
            .recv()
            .expect("block_on waker channel disconnected while its waker was alive");
    }
}