use std::any::{Any, TypeId};
use std::cell::{Cell, RefCell, UnsafeCell};
use std::collections::{HashMap, VecDeque};
use std::{future::Future, io, os::fd, sync::Arc, thread, time::Duration};
use async_task::{Runnable, Task};
use crossbeam_queue::SegQueue;
use swap_buffer_queue::{Queue, buffer::ArrayBuffer, error::TryEnqueueError};
use crate::driver::{Driver, DriverApi, DriverType, Handler, NotifyHandle, PollResult};
use crate::pool::ThreadPool;
// Scoped thread-local referring to the `Runtime` currently driving this
// thread. It is set for the duration of `Runtime::block_on` and of
// `Runtime`'s `Drop`, and read through `Runtime::with_current`.
scoped_tls::scoped_thread_local!(static CURRENT_RUNTIME: Runtime);
pub type JoinHandle<T> = oneshot::Receiver<Result<T, Box<dyn Any + Send>>>;
/// Async runtime that polls its spawned tasks on the thread calling
/// [`Runtime::block_on`], together with its I/O driver, a thread pool and a
/// per-runtime, type-keyed value store.
///
/// Fields are dropped in declaration order after `Drop::drop` has run, so
/// their order should not be changed casually.
pub struct Runtime {
    // I/O driver polled by `block_on`.
    driver: Driver,
    // Run queue shared with every `Handle` and every task's schedule closure.
    queue: Arc<RunnableQueue>,
    // Thread pool built from the builder's pool limit / receive timeout.
    pub(crate) pool: ThreadPool,
    // Cache backing `Runtime::value`, keyed by the stored value's `TypeId`.
    values: RefCell<HashMap<TypeId, Box<dyn Any>, foldhash::fast::RandomState>>,
}
impl Runtime {
    /// Creates a runtime with the default [`RuntimeBuilder`] settings.
    ///
    /// # Errors
    ///
    /// Returns the I/O error raised if the driver cannot be created.
    pub fn new() -> io::Result<Self> {
        Self::builder().build()
    }

    /// Returns a [`RuntimeBuilder`] for configuring a new runtime.
    pub fn builder() -> RuntimeBuilder {
        RuntimeBuilder::new()
    }

    // Builds the runtime from `builder`: creates the driver, a run queue bound
    // to the calling thread and wired to the driver's notify handle, and the
    // thread pool.
    // NOTE(review): `RunnableQueue` has manual `Send`/`Sync` impls in this
    // file, so this clippy allow may be stale — confirm before removing it.
    #[allow(clippy::arc_with_non_send_sync)]
    fn with_builder(builder: &RuntimeBuilder) -> io::Result<Self> {
        let driver = builder.build_driver()?;
        let queue = Arc::new(RunnableQueue::new(builder.event_interval, driver.handle()));
        Ok(Self {
            queue,
            driver,
            pool: ThreadPool::new(builder.pool_limit, builder.pool_recv_timeout),
            values: RefCell::new(HashMap::default()),
        })
    }

    /// Runs `f` with a reference to the runtime that is current on this thread.
    ///
    /// # Panics
    ///
    /// Panics with "not in a neon runtime" when no runtime is current on this
    /// thread (one is made current inside `block_on` and while it is dropped).
    pub fn with_current<T, F: FnOnce(&Self) -> T>(f: F) -> T {
        // Out-of-line, cold panic path keeps the in-runtime path small.
        #[cold]
        fn not_in_neon_runtime() -> ! {
            panic!("not in a neon runtime")
        }
        if CURRENT_RUNTIME.is_set() {
            CURRENT_RUNTIME.with(f)
        } else {
            not_in_neon_runtime()
        }
    }

    /// Returns a [`Handle`] sharing this runtime's run queue. Handles can be
    /// used from other threads to spawn `Send` futures.
    #[inline]
    pub fn handle(&self) -> Handle {
        Handle {
            queue: self.queue.clone(),
        }
    }

    /// Returns the underlying driver (internal API).
    #[doc(hidden)]
    pub fn driver(&self) -> &Driver {
        &self.driver
    }

    /// Returns the type of the driver backing this runtime.
    #[inline]
    pub fn driver_type(&self) -> DriverType {
        self.driver.tp()
    }

    /// Forwards `f` to the driver's `register`; `f` receives a [`DriverApi`]
    /// and builds the boxed [`Handler`].
    #[inline]
    pub fn register_handler<F>(&self, f: F)
    where
        F: FnOnce(DriverApi) -> Box<dyn Handler>,
    {
        self.driver.register(f)
    }

    /// Runs `future` to completion on the current thread, polling the driver
    /// and running other tasks queued on this runtime while it waits.
    ///
    /// This runtime is the thread's current runtime for the duration of the
    /// call, so [`spawn`] and [`Runtime::with_current`] work inside `future`.
    ///
    /// # Panics
    ///
    /// Panics with "Failed to poll driver" if polling the driver returns an
    /// error.
    pub fn block_on<F: Future>(&self, future: F) -> F::Output {
        CURRENT_RUNTIME.set(self, || {
            // Slot the wrapper task stores the output into.
            let mut result = None;
            // The wrapper future owns `future` but mutably borrows `result`
            // from this frame; `spawn_unchecked` does not tie that borrow to
            // the detached task, which is why the poll closure below can
            // read `result` afterwards.
            // NOTE(review): the queued task and the poll closure both hold
            // `&mut result`; this relies on both only ever running on this
            // thread and on `poll` not returning (except via the `expect`
            // panic) before `result` is filled — confirm against `Driver::poll`.
            unsafe { self.spawn_unchecked(async { result = Some(future.await) }) }.detach();
            self.driver
                .poll(|| {
                    // Output ready -> finish; otherwise run one batch of tasks
                    // and report whether more work is still queued.
                    if let Some(result) = result.take() {
                        PollResult::Ready(result)
                    } else if self.queue.run() {
                        PollResult::HasTasks
                    } else {
                        PollResult::Pending
                    }
                })
                .expect("Failed to poll driver")
        })
    }

    /// Spawns `future` onto this runtime's run queue and returns its [`Task`].
    ///
    /// The future is polled by the run queue on the runtime's thread, so it
    /// does not need to be `Send`.
    pub fn spawn<F: Future + 'static>(&self, future: F) -> Task<F::Output> {
        // The `'static` bound means the future borrows nothing that can be
        // freed before its task finishes, which is what `spawn_unchecked`
        // requires of its caller.
        unsafe { self.spawn_unchecked(future) }
    }

    /// Spawns `future` onto this runtime's run queue without requiring it to
    /// be `'static`, and returns its [`Task`].
    ///
    /// # Safety
    ///
    /// Everything `future` borrows must stay valid until the task's
    /// `Runnable` is gone, i.e. until the future has completed or been
    /// dropped (for example when the queue is cleared as the runtime is
    /// dropped). See `async_task::spawn_unchecked` for the full contract.
    pub unsafe fn spawn_unchecked<F: Future>(&self, future: F) -> Task<F::Output> {
        let queue = self.queue.clone();
        // Every wake-up routes the runnable back into this runtime's queue.
        let schedule = move |runnable| {
            queue.schedule(runnable);
        };
        let (runnable, task) = unsafe { async_task::spawn_unchecked(future, schedule) };
        // Queue the first poll.
        runnable.schedule();
        task
    }

    /// Hands the blocking closure `f` to `crate::driver::spawn_blocking`
    /// (presumably run on the runtime's thread pool — confirm) and returns a
    /// [`JoinHandle`] that receives its result.
    ///
    /// A panic inside `f` is caught and delivered as `Err(payload)`.
    pub fn spawn_blocking<F, R>(&self, f: F) -> JoinHandle<R>
    where
        F: FnOnce() -> R + Send + 'static,
        R: Send + 'static,
    {
        let (tx, rx) = oneshot::channel();
        crate::driver::spawn_blocking(
            self,
            &self.driver,
            Box::new(move || {
                let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(f));
                // The receiver may already be gone; the result is then discarded.
                let _ = tx.send(result);
            }),
        );
        rx
    }

    /// Returns the current runtime's value of type `T`, creating it with `f`
    /// on first use and caching a clone for later calls.
    ///
    /// # Panics
    ///
    /// Panics if no runtime is current (see [`Runtime::with_current`]).
    pub fn value<T, F>(f: F) -> T
    where
        T: Clone + 'static,
        F: FnOnce(&Runtime) -> T,
    {
        Runtime::with_current(|rt| {
            let val = rt
                .values
                .borrow()
                .get(&TypeId::of::<T>())
                .and_then(|boxed| boxed.downcast_ref().cloned());
            if let Some(val) = val {
                val
            } else {
                // The shared borrow above ended with its statement, so `f`
                // may itself call `Runtime::value` without a `RefCell` panic.
                let val = f(rt);
                rt.values
                    .borrow_mut()
                    .insert(TypeId::of::<T>(), Box::new(val.clone()));
                val
            }
        })
    }
}
impl fd::AsRawFd for Runtime {
    /// Returns the raw file descriptor of the runtime's driver.
    fn as_raw_fd(&self) -> fd::RawFd {
        let driver = &self.driver;
        driver.as_raw_fd()
    }
}
impl Drop for Runtime {
    /// Drops every queued task, then clears the driver.
    fn drop(&mut self) {
        // Make this runtime current while tasks and driver state are torn
        // down, so code running during those drops can still reach it via
        // `Runtime::with_current`.
        CURRENT_RUNTIME.set(self, || {
            // Queued runnables go first. Each one presumably keeps its
            // schedule closure, and with it an `Arc<RunnableQueue>`, alive
            // until dropped.
            self.queue.clear();
            self.driver.clear();
        })
    }
}
/// Cloneable handle to a runtime's run queue.
///
/// `Runtime` itself is not `Send` (its value store holds `Box<dyn Any>`). A
/// `Handle` only holds an `Arc<RunnableQueue>`, so it can be used from other
/// threads to spawn `Send` futures or to notify the driver.
#[derive(Debug)]
pub struct Handle {
    // Shared with the owning runtime.
    queue: Arc<RunnableQueue>,
}
impl Handle {
pub fn current() -> Handle {
Runtime::with_current(|rt| rt.handle())
}
pub fn notify(&self) -> io::Result<()> {
self.queue.driver.notify()
}
pub fn spawn<F: Future + Send + 'static>(&self, future: F) -> Task<F::Output> {
let queue = self.queue.clone();
let schedule = move |runnable| {
queue.schedule(runnable);
};
let (runnable, task) = unsafe { async_task::spawn_unchecked(future, schedule) };
runnable.schedule();
task
}
}
impl Clone for Handle {
fn clone(&self) -> Self {
Self {
queue: self.queue.clone(),
}
}
}
/// Run queue holding a runtime's scheduled tasks.
///
/// Runnables scheduled from the thread recorded in `id` go into
/// `local_queue`. Runnables scheduled from any other thread go into
/// `sync_fixed_queue`, overflowing into `sync_queue` when it is full.
#[derive(Debug)]
struct RunnableQueue {
    // Thread that created the queue; `schedule` only touches `idle` and
    // `local_queue` when called from this thread.
    id: thread::ThreadId,
    // Starts `true`. Cleared when `run` starts or when a local schedule has
    // just notified the driver; set back to `true` when `run` finishes. A
    // local schedule that sees `true` notifies the driver.
    idle: Cell<bool>,
    // Used to wake the driver when new work is queued.
    driver: NotifyHandle,
    // Upper bound on runnables taken from `local_queue` and from `sync_queue`
    // in one `run` call.
    event_interval: usize,
    // Owner-thread queue, accessed through `UnsafeCell` without locking.
    local_queue: UnsafeCell<VecDeque<Runnable>>,
    // Cross-thread queue backed by fixed 128-slot array buffers.
    sync_fixed_queue: Queue<ArrayBuffer<Runnable, 128>>,
    // Cross-thread overflow used when `sync_fixed_queue` lacks capacity.
    sync_queue: SegQueue<Runnable>,
}
// SAFETY: the non-thread-safe fields are `idle` (a `Cell`) and `local_queue`
// (an `UnsafeCell`). `schedule` only touches them after checking that it runs
// on the owning thread `id`. In the code visible here, `run` and `clear` are
// reached only through `Runtime` (`block_on` / `Drop`). `Runtime` is not `Send`
// (its value store holds `Box<dyn Any>`), so they also run on the thread that
// created the queue. The remaining fields are concurrent queues and the
// driver's notify handle.
// NOTE(review): this relies on `Runtime` staying `!Send` and on
// `NotifyHandle` being safe to use from any thread — confirm both.
unsafe impl Send for RunnableQueue {}
unsafe impl Sync for RunnableQueue {}
impl RunnableQueue {
    /// Creates a queue bound to the calling thread.
    ///
    /// `event_interval` bounds how many runnables a single `run` call takes
    /// from the local queue and from the overflow queue. It is clamped to at
    /// least 1 so local tasks always make progress.
    fn new(event_interval: usize, driver: NotifyHandle) -> Self {
        Self {
            driver,
            // With 0, `run` would never poll a local task while still
            // reporting queued work, so `block_on` would loop forever without
            // making progress.
            event_interval: event_interval.max(1),
            id: thread::current().id(),
            idle: Cell::new(true),
            local_queue: UnsafeCell::new(VecDeque::new()),
            sync_fixed_queue: Queue::default(),
            sync_queue: SegQueue::new(),
        }
    }

    /// Enqueues `runnable` and wakes the driver when needed.
    ///
    /// On the owning thread the runnable goes to the local queue, and the
    /// driver is notified only if the queue was idle. From any other thread it
    /// goes to the fixed-capacity queue (or the overflow queue when that is
    /// full), and the driver is always notified. Notify errors are ignored
    /// (best effort).
    fn schedule(&self, runnable: Runnable) {
        if self.id == thread::current().id() {
            // SAFETY: we are on the owning thread, and `push_back` does not
            // re-enter `schedule`, so this `&mut` is unique for its lifetime.
            unsafe { (*self.local_queue.get()).push_back(runnable) };
            if self.idle.get() {
                self.idle.set(false);
                self.driver.notify().ok();
            }
        } else {
            let result = self.sync_fixed_queue.try_enqueue([runnable]);
            if let Err(TryEnqueueError::InsufficientCapacity([runnable])) = result {
                self.sync_queue.push(runnable);
            }
            self.driver.notify().ok();
        }
    }

    /// Runs one batch of queued tasks; must be called on the owning thread.
    ///
    /// It runs, in order:
    /// - up to `event_interval` local runnables;
    /// - every runnable in the currently dequeued fixed-capacity buffer;
    /// - up to `event_interval` overflow runnables.
    ///
    /// Returns `true` if any queue still holds work afterwards.
    fn run(&self) -> bool {
        self.idle.set(false);
        for _ in 0..self.event_interval {
            // SAFETY: owning thread only. The `&mut` ends with `pop_front`,
            // before the task runs and possibly re-enters `schedule`.
            let task = unsafe { (*self.local_queue.get()).pop_front() };
            match task {
                Some(task) => task.run(),
                None => break,
            }
        }
        if let Ok(buf) = self.sync_fixed_queue.try_dequeue() {
            for task in buf {
                task.run();
            }
        }
        for _ in 0..self.event_interval {
            match self.sync_queue.pop() {
                Some(task) => task.run(),
                None => break,
            }
        }
        self.idle.set(true);
        // SAFETY: owning thread only; no task is running at this point.
        let has_local = unsafe { !(*self.local_queue.get()).is_empty() };
        has_local || !self.sync_fixed_queue.is_empty() || !self.sync_queue.is_empty()
    }

    /// Drops every queued runnable, cancelling the corresponding tasks.
    ///
    /// Dropping a runnable drops its future, which can wake other tasks and
    /// re-enter `schedule` (pushing onto `local_queue`). Local runnables are
    /// therefore popped one at a time and dropped while no borrow of the
    /// deque is held. This also drains anything scheduled during the
    /// drain. Calling `VecDeque::clear` here would let such a re-entrant
    /// `push_back` alias the `&mut` held by `clear`.
    fn clear(&self) {
        while self.sync_queue.pop().is_some() {}
        while self.sync_fixed_queue.try_dequeue().is_ok() {}
        loop {
            // SAFETY: owning thread only. The `&mut` ends with `pop_front`,
            // so a `schedule` triggered by dropping `task` does not alias it.
            let task = unsafe { (*self.local_queue.get()).pop_front() };
            match task {
                Some(task) => drop(task),
                None => break,
            }
        }
    }
}
/// Configuration used to construct a [`Runtime`].
///
/// Defaults (set by `RuntimeBuilder::new`):
/// - event interval: 61
/// - thread-pool limit: 256
/// - pool receive timeout: 60 s
/// - I/O queue capacity: 2048
#[derive(Debug, Clone)]
pub struct RuntimeBuilder {
    // Batch size handed to the run queue.
    event_interval: usize,
    // Limit passed to `ThreadPool::new`.
    pool_limit: usize,
    // Receive timeout passed to `ThreadPool::new`.
    pool_recv_timeout: Duration,
    // Capacity passed to `Driver::new`.
    io_queue_capacity: u32,
}
impl Default for RuntimeBuilder {
fn default() -> Self {
Self::new()
}
}
impl RuntimeBuilder {
pub fn new() -> Self {
Self {
event_interval: 61,
pool_limit: 256,
pool_recv_timeout: Duration::from_secs(60),
io_queue_capacity: 2048,
}
}
pub fn event_interval(&mut self, val: usize) -> &mut Self {
self.event_interval = val;
self
}
pub fn io_queue_capacity(&mut self, capacity: u32) -> &mut Self {
self.io_queue_capacity = capacity;
self
}
pub fn thread_pool_limit(&mut self, value: usize) -> &mut Self {
self.pool_limit = value;
self
}
pub fn thread_pool_recv_timeout(&mut self, timeout: Duration) -> &mut Self {
self.pool_recv_timeout = timeout;
self
}
pub fn build(&self) -> io::Result<Runtime> {
Runtime::with_builder(self)
}
fn build_driver(&self) -> io::Result<Driver> {
Driver::new(self.io_queue_capacity)
}
}
/// Spawns `future` on the runtime that is current on this thread and returns
/// its [`Task`].
///
/// # Panics
///
/// Panics if no runtime is current (see [`Runtime::with_current`]).
pub fn spawn<F: Future + 'static>(future: F) -> Task<F::Output> {
    Runtime::with_current(move |runtime| runtime.spawn(future))
}
/// Runs the blocking closure `f` through the current runtime's
/// [`Runtime::spawn_blocking`] and returns the [`JoinHandle`] for its result.
///
/// # Panics
///
/// Panics if no runtime is current (see [`Runtime::with_current`]).
pub fn spawn_blocking<T: Send + 'static>(
    f: impl (FnOnce() -> T) + Send + 'static,
) -> JoinHandle<T> {
    Runtime::with_current(move |runtime| runtime.spawn_blocking(f))
}