use log::warn;
use parking_lot::Mutex;
use std::{pin::Pin, sync::Arc};
use tokio::sync::oneshot::Receiver;
use futures::Future;
/// Priority value applied via [`set_current_thread_priority`] to every worker
/// thread of a [`DedicatedExecutor`] runtime.
const WORKER_PRIORITY: i32 = 10;

/// Type-erased unit of work handed from [`DedicatedExecutor::spawn`] to the
/// dispatcher thread.
type Task = Pin<Box<dyn Future<Output = ()> + Send + 'static>>;

/// Bookkeeping shared by every clone of a [`DedicatedExecutor`].
struct State {
    /// Worker thread count the runtime was configured with.
    num_threads: usize,
    /// Name given to the runtime's worker threads.
    thread_name: String,
    /// Sending half of the dispatcher channel; `None` once shut down.
    requests: Option<std::sync::mpsc::Sender<Task>>,
    /// Handle to the dispatcher thread; `None` once it has been joined.
    thread: Option<std::thread::JoinHandle<()>>,
}

/// Runs futures on a dedicated tokio runtime that lives on its own thread.
///
/// Cloning is cheap: all clones share one underlying state, so shutting down
/// or joining through any clone affects all of them.
#[derive(Clone)]
pub struct DedicatedExecutor {
    state: Arc<Mutex<State>>,
}
impl std::fmt::Debug for DedicatedExecutor {
    /// Formats the executor's configuration.
    ///
    /// The request channel and thread handle are reported only by presence
    /// (`"Some(...)"` / `"None"`), not by their contents.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        fn presence<T>(slot: &Option<T>) -> &'static str {
            match slot {
                Some(_) => "Some(...)",
                None => "None",
            }
        }

        let guard = self.state.lock();
        f.debug_struct("DedicatedExecutor")
            .field("num_threads", &guard.num_threads)
            .field("thread_name", &guard.thread_name)
            .field("requests", &presence(&guard.requests))
            .field("thread", &presence(&guard.thread))
            .finish()
    }
}
impl DedicatedExecutor {
pub fn new(thread_name: impl Into<String>, num_threads: usize) -> Self {
let thread_name = thread_name.into();
let name_copy = thread_name.to_string();
let (tx, rx) = std::sync::mpsc::channel();
let thread = std::thread::spawn(move || {
let runtime = tokio::runtime::Builder::new_multi_thread()
.enable_all()
.thread_name(&name_copy)
.worker_threads(num_threads)
.on_thread_start(move || set_current_thread_priority(WORKER_PRIORITY))
.build()
.expect("Creating tokio runtime");
let _guard = runtime.enter();
while let Ok(request) = rx.recv() {
tokio::task::spawn(request);
}
});
let state = State {
num_threads,
thread_name,
requests: Some(tx),
thread: Some(thread),
};
Self {
state: Arc::new(Mutex::new(state)),
}
}
pub fn spawn<T>(&self, task: T) -> Receiver<T::Output>
where
T: Future + Send + 'static,
T::Output: Send + 'static,
{
let (tx, rx) = tokio::sync::oneshot::channel();
let job = Box::pin(async move {
let task_output = task.await;
if tx.send(task_output).is_err() {
warn!("Spawned task output ignored: receiver dropped");
}
});
let mut state = self.state.lock();
if let Some(requests) = &mut state.requests {
requests.send(job).ok();
} else {
warn!("tried to schedule task on an executor that was shutdown");
}
rx
}
#[allow(dead_code)]
pub fn shutdown(&self) {
let mut state = self.state.lock();
state.requests = None;
}
#[allow(dead_code)]
pub fn join(&self) {
self.shutdown();
let thread = {
let mut state = self.state.lock();
state.thread.take()
};
if let Some(thread) = thread {
thread.join().ok();
}
}
}
#[cfg(unix)]
fn set_current_thread_priority(prio: i32) {
unsafe { libc::setpriority(0, 0, prio) };
}
/// Fallback for platforms without `setpriority`.
///
/// Leaves the thread's priority untouched and only logs that the request was
/// ignored.
#[cfg(not(unix))]
fn set_current_thread_priority(_requested: i32) {
    warn!("Setting worker thread priority not supported on this platform")
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Arc, Barrier};

    /// Returns the nice value seen by the calling thread. It is queried the same
    /// way `set_current_thread_priority` sets it (`which = 0`, `who = 0`).
    #[cfg(unix)]
    fn get_current_thread_priority() -> i32 {
        // SAFETY: `getpriority` takes only integer arguments and touches no
        // memory owned by Rust.
        unsafe { libc::getpriority(0, 0) }
    }

    #[tokio::test]
    async fn basic_test_in_diff_thread() {
        let barrier = Arc::new(Barrier::new(2));
        let exec = DedicatedExecutor::new("Test DedicatedExecutor", 1);
        let dedicated_task = exec.spawn(do_work(42, Arc::clone(&barrier)));
        barrier.wait();
        assert_eq!(dedicated_task.await.unwrap(), 42);
    }

    #[tokio::test]
    async fn basic_clone() {
        let barrier = Arc::new(Barrier::new(2));
        let exec = DedicatedExecutor::new("Test DedicatedExecutor", 1);
        let dedicated_task = exec.clone().spawn(do_work(42, Arc::clone(&barrier)));
        barrier.wait();
        assert_eq!(dedicated_task.await.unwrap(), 42);
    }

    #[tokio::test]
    async fn multi_task() {
        let barrier = Arc::new(Barrier::new(3));
        let exec = DedicatedExecutor::new("Test DedicatedExecutor", 2);
        let dedicated_task1 = exec.spawn(do_work(11, Arc::clone(&barrier)));
        let dedicated_task2 = exec.spawn(do_work(42, Arc::clone(&barrier)));
        barrier.wait();
        assert_eq!(dedicated_task1.await.unwrap(), 11);
        assert_eq!(dedicated_task2.await.unwrap(), 42);
        exec.join();
    }

    // Only meaningful where the priority can actually be set and read back.
    // Elsewhere `set_current_thread_priority` just logs, and a faked reading
    // would make this test pass vacuously.
    #[cfg(unix)]
    #[tokio::test]
    async fn worker_priority() {
        let exec = DedicatedExecutor::new("Test DedicatedExecutor", 2);
        let dedicated_task = exec.spawn(async move { get_current_thread_priority() });
        assert_eq!(dedicated_task.await.unwrap(), WORKER_PRIORITY);
    }

    #[tokio::test]
    async fn tokio_spawn() {
        let exec = DedicatedExecutor::new("Test DedicatedExecutor", 2);
        let dedicated_task = exec.spawn(async move {
            let t1 = tokio::task::spawn(async {
                assert_eq!(
                    std::thread::current().name(),
                    Some("Test DedicatedExecutor")
                );
                25usize
            });
            t1.await.unwrap()
        });
        assert_eq!(dedicated_task.await.unwrap(), 25);
    }

    #[tokio::test]
    async fn panic_on_executor() {
        let exec = DedicatedExecutor::new("Test DedicatedExecutor", 1);
        let dedicated_task = exec.spawn(async move {
            panic!("At the disco, on the dedicated task scheduler");
        });
        dedicated_task.await.unwrap_err();
    }

    // Ignored because it can hang. `shutdown` ends the dispatcher loop, after
    // which the runtime is dropped. If that happens before `do_work` is first
    // polled, the task is dropped without ever reaching its `barrier.wait()`,
    // so the `barrier.wait()` below never returns.
    #[tokio::test]
    #[ignore]
    async fn executor_shutdown_while_task_running() {
        let barrier = Arc::new(Barrier::new(2));
        let exec = DedicatedExecutor::new("Test DedicatedExecutor", 1);
        let dedicated_task = exec.spawn(do_work(42, Arc::clone(&barrier)));
        exec.shutdown();
        barrier.wait();
        assert_eq!(dedicated_task.await.unwrap(), 42);
    }

    #[tokio::test]
    async fn executor_submit_task_after_shutdown() {
        let exec = DedicatedExecutor::new("Test DedicatedExecutor", 1);
        exec.shutdown();
        let dedicated_task = exec.spawn(async { 11 });
        dedicated_task.await.unwrap_err();
    }

    #[tokio::test]
    async fn executor_submit_task_after_clone_shutdown() {
        let exec = DedicatedExecutor::new("Test DedicatedExecutor", 1);
        exec.clone().join();
        let dedicated_task = exec.spawn(async { 11 });
        dedicated_task.await.unwrap_err();
    }

    #[tokio::test]
    async fn executor_join() {
        let exec = DedicatedExecutor::new("Test DedicatedExecutor", 1);
        exec.join()
    }

    #[tokio::test]
    #[allow(clippy::redundant_clone)]
    async fn executor_clone_join() {
        let exec = DedicatedExecutor::new("Test DedicatedExecutor", 1);
        exec.clone().join();
        exec.clone().join();
        exec.join();
    }

    #[test]
    fn debug_format_reports_state() {
        let exec = DedicatedExecutor::new("Test DedicatedExecutor", 1);
        let running = format!("{:?}", exec);
        assert!(running.contains("num_threads: 1"), "{}", running);
        assert!(running.contains(r#"requests: "Some(...)""#), "{}", running);
        assert!(running.contains(r#"thread: "Some(...)""#), "{}", running);

        exec.join();
        let joined = format!("{:?}", exec);
        assert!(joined.contains(r#"requests: "None""#), "{}", joined);
        assert!(joined.contains(r#"thread: "None""#), "{}", joined);
    }

    /// Blocks on `barrier` (a blocking wait inside the future's first poll),
    /// then returns `result`.
    async fn do_work(result: usize, barrier: Arc<Barrier>) -> usize {
        barrier.wait();
        result
    }
}