use std::collections::VecDeque;
use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering};
use std::sync::{Arc, Condvar, Mutex};
use std::time::{Duration, Instant};
/// Lifecycle state of a task, stored in `Task::state` as a raw integer.
///
/// The explicit discriminants matter: `Task::get_state` decodes the atomic
/// back into this enum, and any raw value outside 0..=4 reads as `Panicked`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum TaskState {
    /// Queued and eligible to be polled.
    Ready = 0,
    /// Currently being polled (set by `Executor::run_task`).
    Running = 1,
    /// Returned `Pending`; a waker must move it back to `Ready`.
    Waiting = 2,
    /// Polled to completion.
    Completed = 3,
    /// Cancelled via `Executor::cancel`.
    Cancelled = 4,
    /// Also the catch-all decoding for unknown raw values.
    Panicked = 5,
}
/// Outcome of polling a future: either a finished value or "not yet".
///
/// Mirrors `std::task::Poll` for this runtime's own [`Future`] trait.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Poll<T> {
    Ready(T),
    Pending,
}

impl<T> Poll<T> {
    /// `true` when this poll produced a value.
    pub fn is_ready(&self) -> bool {
        match self {
            Poll::Ready(_) => true,
            Poll::Pending => false,
        }
    }

    /// `true` when the future has not finished yet.
    pub fn is_pending(&self) -> bool {
        !self.is_ready()
    }

    /// Transforms the ready value with `f`; `Pending` passes through.
    pub fn map<U, F: FnOnce(T) -> U>(self, f: F) -> Poll<U> {
        if let Poll::Ready(value) = self {
            Poll::Ready(f(value))
        } else {
            Poll::Pending
        }
    }
}
/// Cloneable handle that re-schedules one task on its executor.
#[derive(Clone)]
pub struct Waker {
    task_id: u64,
    executor: Arc<ExecutorInner>,
}

impl Waker {
    /// Binds a waker to `task_id` on `executor`.
    pub fn new(task_id: u64, executor: Arc<ExecutorInner>) -> Self {
        Waker { task_id, executor }
    }

    /// Asks the executor to move the task back to the ready queue.
    pub fn wake(&self) {
        self.executor.wake_task(self.task_id);
    }

    /// Identical to [`Waker::wake`]; kept for std-`Waker` API parity.
    pub fn wake_by_ref(&self) {
        self.executor.wake_task(self.task_id);
    }
}
/// Polling context handed to [`Future::poll`]; currently it only carries the
/// borrowed waker for the task being polled.
pub struct Context<'a> {
    waker: &'a Waker,
}

impl<'a> Context<'a> {
    /// Wraps a borrowed waker in a context.
    pub fn new(waker: &'a Waker) -> Self {
        Context { waker }
    }

    /// Waker for the task currently being polled.
    pub fn waker(&self) -> &Waker {
        self.waker
    }
}
/// Minimal future abstraction mirroring `std::future::Future`, except that
/// `poll` takes `&mut self` directly rather than `Pin<&mut Self>`.
pub trait Future {
    /// Value produced when the future completes.
    type Output;
    /// Attempts to make progress; returns [`Poll::Pending`] until done.
    fn poll(&mut self, cx: &mut Context<'_>) -> Poll<Self::Output>;
}
/// Identifier for a spawned task, handed out by `ExecutorInner::alloc_task_id`.
pub type TaskId = u64;
/// Runtime record for one spawned task.
///
/// The future is stored type-erased: `future_ptr` is an opaque pointer that
/// `poll_fn` polls and `drop_fn` destroys (see `Task::poll` and the `Drop`
/// impl). A freshly `Task::new`-ed task has all three unset.
pub struct Task {
    pub id: TaskId,
    // Raw TaskState discriminant; read/written via get_state()/set_state().
    pub state: AtomicUsize,
    pub priority: u32,
    // Type-erased future; null until a future is attached elsewhere.
    pub future_ptr: *mut (),
    // Destructor for `future_ptr`, invoked on Drop when the pointer is non-null.
    pub drop_fn: Option<fn(*mut ())>,
    // Poll thunk for `future_ptr`; when None, Task::poll reports Ready.
    pub poll_fn: Option<fn(*mut (), &mut Context<'_>) -> Poll<()>>,
    // Callback fired by Executor::run_task when the task completes.
    pub on_complete: Option<fn(TaskId)>,
    pub parent: Option<TaskId>,
    // Child ids registered through add_child (see Executor::spawn_child).
    pub children: Mutex<Vec<TaskId>>,
    pub created_at: Instant,
    // Wakers notified (and drained) when the task finishes or is cancelled.
    pub waiters: Mutex<Vec<Waker>>,
}
impl Task {
    /// Creates an empty task shell: state `Ready`, no future attached yet.
    pub fn new(id: TaskId, priority: u32) -> Self {
        Self {
            id,
            priority,
            state: AtomicUsize::new(TaskState::Ready as usize),
            future_ptr: std::ptr::null_mut(),
            drop_fn: None,
            poll_fn: None,
            on_complete: None,
            parent: None,
            children: Mutex::new(Vec::new()),
            created_at: Instant::now(),
            waiters: Mutex::new(Vec::new()),
        }
    }

    /// Decodes the raw atomic state; unknown values read as `Panicked`.
    pub fn get_state(&self) -> TaskState {
        match self.state.load(Ordering::Acquire) {
            0 => TaskState::Ready,
            1 => TaskState::Running,
            2 => TaskState::Waiting,
            3 => TaskState::Completed,
            4 => TaskState::Cancelled,
            _ => TaskState::Panicked,
        }
    }

    /// Publishes a new state with release ordering.
    pub fn set_state(&self, state: TaskState) {
        self.state.store(state as usize, Ordering::Release);
    }

    /// Polls the attached future; a task without one is trivially `Ready`.
    pub fn poll(&self, cx: &mut Context<'_>) -> Poll<()> {
        match self.poll_fn {
            Some(poll_fn) => poll_fn(self.future_ptr, cx),
            None => Poll::Ready(()),
        }
    }

    /// Records `child_id` under this task.
    pub fn add_child(&self, child_id: TaskId) {
        self.children.lock().unwrap().push(child_id);
    }

    /// Takes the waiter list and wakes everyone on it.
    pub fn notify_waiters(&self) {
        let pending = std::mem::take(&mut *self.waiters.lock().unwrap());
        for waker in pending {
            waker.wake();
        }
    }
}
impl Drop for Task {
    fn drop(&mut self) {
        // Run the type-erased destructor for the attached future, if any.
        if let (Some(drop_fn), false) = (self.drop_fn, self.future_ptr.is_null()) {
            drop_fn(self.future_ptr);
        }
    }
}
// SAFETY(review): Task is shared across threads via Arc, so these impls are
// required by the design, but Task holds a raw `*mut ()` to a type-erased
// future — the impls are only sound if every future installed behind
// `future_ptr` is itself Send + Sync. Nothing here enforces that; TODO
// confirm at whatever call sites populate `future_ptr`/`poll_fn`.
unsafe impl Send for Task {}
unsafe impl Sync for Task {}
/// Mutex-backed double-ended task queue: the owner pops LIFO from the back
/// (`pop`), thieves take FIFO from the front (`steal`/`steal_batch`).
pub struct WorkStealingDeque {
    queue: Mutex<VecDeque<Arc<Task>>>,
    // Cached element count so len()/is_empty() avoid taking the lock.
    len: AtomicUsize,
}
impl WorkStealingDeque {
    /// Creates an empty deque.
    pub fn new() -> Self {
        Self {
            queue: Mutex::new(VecDeque::new()),
            len: AtomicUsize::new(0),
        }
    }

    /// Enqueues a task at the back (the owner's end).
    pub fn push(&self, task: Arc<Task>) {
        let mut queue = self.queue.lock().unwrap();
        queue.push_back(task);
        self.len.fetch_add(1, Ordering::Release);
    }

    /// Owner-side pop: newest task first (LIFO, good cache locality).
    pub fn pop(&self) -> Option<Arc<Task>> {
        let mut queue = self.queue.lock().unwrap();
        let task = queue.pop_back();
        if task.is_some() {
            self.len.fetch_sub(1, Ordering::Release);
        }
        task
    }

    /// Thief-side pop: oldest task first (FIFO).
    pub fn steal(&self) -> Option<Arc<Task>> {
        let mut queue = self.queue.lock().unwrap();
        let task = queue.pop_front();
        if task.is_some() {
            self.len.fetch_sub(1, Ordering::Release);
        }
        task
    }

    /// Steals up to `max` tasks, never more than half of what is queued.
    pub fn steal_batch(&self, max: usize) -> Vec<Arc<Task>> {
        let mut queue = self.queue.lock().unwrap();
        let take = (queue.len() / 2).min(max);
        let stolen: Vec<Arc<Task>> = queue.drain(..take).collect();
        self.len.fetch_sub(stolen.len(), Ordering::Release);
        stolen
    }

    /// Lock-free emptiness check against the cached counter.
    pub fn is_empty(&self) -> bool {
        self.len.load(Ordering::Acquire) == 0
    }

    /// Lock-free length read (advisory; may race with push/pop).
    pub fn len(&self) -> usize {
        self.len.load(Ordering::Acquire)
    }
}
impl Default for WorkStealingDeque {
fn default() -> Self {
Self::new()
}
}
/// Per-worker bookkeeping: a local work-stealing deque plus counters and a
/// small PRNG state used to pick steal victims.
pub struct Worker {
    pub id: usize,
    pub local_queue: WorkStealingDeque,
    pub tasks_executed: AtomicU64,
    pub tasks_stolen: AtomicU64,
    pub active: AtomicBool,
    // xorshift64 state; seeded from `id` in Worker::new, advanced by
    // next_random(). Relaxed atomics: each worker is the sole intended writer.
    rng_state: AtomicU64,
}
impl Worker {
pub fn new(id: usize) -> Self {
Self {
id,
local_queue: WorkStealingDeque::new(),
tasks_executed: AtomicU64::new(0),
tasks_stolen: AtomicU64::new(0),
active: AtomicBool::new(true),
rng_state: AtomicU64::new(id as u64 ^ 0x517cc1b727220a95),
}
}
fn next_random(&self) -> u64 {
let mut x = self.rng_state.load(Ordering::Relaxed);
x ^= x << 13;
x ^= x >> 7;
x ^= x << 17;
self.rng_state.store(x, Ordering::Relaxed);
x
}
pub fn select_victim(&self, num_workers: usize) -> usize {
let mut victim = (self.next_random() as usize) % num_workers;
if victim == self.id {
victim = (victim + 1) % num_workers;
}
victim
}
}
/// State shared between the executor handle, wakers, and (by design) worker
/// threads. Always used behind an `Arc`.
pub struct ExecutorInner {
    pub workers: Vec<Arc<Worker>>,
    // Overflow/injection queue; wake_task and spawn push here.
    pub global_queue: Mutex<VecDeque<Arc<Task>>>,
    // Registry of every spawned task by id (entries are never removed in
    // the code visible here — NOTE(review): possible unbounded growth).
    pub tasks: Mutex<std::collections::HashMap<TaskId, Arc<Task>>>,
    pub next_task_id: AtomicU64,
    pub running: AtomicBool,
    pub shutdown: AtomicBool,
    // Condvar/mutex pair for parking idle workers; notified on new work.
    pub park_condvar: Condvar,
    pub park_mutex: Mutex<()>,
    pub active_workers: AtomicUsize,
}
impl ExecutorInner {
    /// Moves a `Waiting` task back to `Ready`, re-enqueues it on the global
    /// queue, and unparks one worker. Wakes for tasks in any other state
    /// (or unknown ids) are ignored.
    ///
    /// The Waiting -> Ready transition is a single `compare_exchange`: the
    /// previous `get_state()`/`set_state()` pair let two wakers racing on
    /// the same task both observe `Waiting` and enqueue it twice.
    pub fn wake_task(&self, task_id: TaskId) {
        let task = {
            let tasks = self.tasks.lock().unwrap();
            tasks.get(&task_id).and_then(|task| {
                task.state
                    .compare_exchange(
                        TaskState::Waiting as usize,
                        TaskState::Ready as usize,
                        Ordering::AcqRel,
                        Ordering::Acquire,
                    )
                    .is_ok()
                    .then(|| task.clone())
            })
        };
        // Enqueue outside the `tasks` lock to keep its critical section short.
        if let Some(task) = task {
            self.global_queue.lock().unwrap().push_back(task);
            self.park_condvar.notify_one();
        }
    }

    /// Hands out the next task id (monotonically increasing; first id is 1).
    pub fn alloc_task_id(&self) -> TaskId {
        self.next_task_id.fetch_add(1, Ordering::Relaxed)
    }

    /// Looks up a registered task by id.
    pub fn get_task(&self, task_id: TaskId) -> Option<Arc<Task>> {
        self.tasks.lock().unwrap().get(&task_id).cloned()
    }

    /// Wakes every parked worker (used by `Executor::shutdown`).
    pub fn unpark_workers(&self) {
        self.park_condvar.notify_all();
    }
}
/// Tunables for building an [`Executor`].
#[derive(Debug, Clone)]
pub struct ExecutorConfig {
    // Number of Worker slots created by Executor::with_config.
    pub num_workers: usize,
    // Batch size for global-queue transfers (not consumed in visible code —
    // NOTE(review): confirm the run loop that uses it).
    pub global_queue_batch: usize,
    pub work_stealing: bool,
    // Intended worker stack size in bytes (defaults to 2 MiB).
    pub stack_size: usize,
}
impl Default for ExecutorConfig {
fn default() -> Self {
Self {
num_workers: num_cpus(),
global_queue_batch: 32,
work_stealing: true,
stack_size: 2 * 1024 * 1024, }
}
}
/// Hardware-thread count, falling back to 4 when the OS cannot report it.
fn num_cpus() -> usize {
    match std::thread::available_parallelism() {
        Ok(parallelism) => parallelism.get(),
        Err(_) => 4,
    }
}
/// Public handle over the shared executor state plus the config it was
/// built with.
pub struct Executor {
    inner: Arc<ExecutorInner>,
    // Retained for introspection; only num_workers is consumed at build time.
    config: ExecutorConfig,
}
impl Executor {
    /// Builds an executor with [`ExecutorConfig::default`].
    pub fn new() -> Self {
        Self::with_config(ExecutorConfig::default())
    }
    /// Builds the worker set and shared state described by `config`.
    ///
    /// NOTE(review): this only allocates bookkeeping — no worker threads are
    /// started here and `running` stays `false`; confirm where the actual
    /// run loop is spawned.
    pub fn with_config(config: ExecutorConfig) -> Self {
        let workers: Vec<Arc<Worker>> = (0..config.num_workers)
            .map(|id| Arc::new(Worker::new(id)))
            .collect();
        let inner = Arc::new(ExecutorInner {
            workers,
            global_queue: Mutex::new(VecDeque::new()),
            tasks: Mutex::new(std::collections::HashMap::new()),
            // First allocated task id is 1.
            next_task_id: AtomicU64::new(1),
            running: AtomicBool::new(false),
            shutdown: AtomicBool::new(false),
            park_condvar: Condvar::new(),
            park_mutex: Mutex::new(()),
            active_workers: AtomicUsize::new(0),
        });
        Self { inner, config }
    }
    /// Registers a new task, pushes it on the global queue, and unparks one
    /// worker. Returns the new task's id.
    ///
    /// The task starts with no future attached (`poll_fn` is `None`), so a
    /// poll reports `Ready` immediately — presumably callers install
    /// `future_ptr`/`poll_fn` elsewhere; TODO confirm.
    pub fn spawn(&self, priority: u32) -> TaskId {
        let task_id = self.inner.alloc_task_id();
        let task = Arc::new(Task::new(task_id, priority));
        self.inner
            .tasks
            .lock()
            .unwrap()
            .insert(task_id, task.clone());
        self.inner.global_queue.lock().unwrap().push_back(task);
        self.inner.park_condvar.notify_one();
        task_id
    }
    /// Spawns a task and records it as a child of `parent_id`.
    /// Returns `None` when the parent is unknown.
    pub fn spawn_child(&self, parent_id: TaskId, priority: u32) -> Option<TaskId> {
        let parent = self.inner.get_task(parent_id)?;
        let task_id = self.spawn(priority);
        if let Some(_task) = self.inner.get_task(task_id) {
            parent.add_child(task_id);
        }
        Some(task_id)
    }
    /// Marks the task and (recursively) its children `Cancelled` and wakes
    /// their completion waiters. Returns `false` for unknown ids.
    ///
    /// Cancelled entries stay in the task map; `run_task` skips anything not
    /// in the `Ready` state.
    pub fn cancel(&self, task_id: TaskId) -> bool {
        if let Some(task) = self.inner.get_task(task_id) {
            task.set_state(TaskState::Cancelled);
            // Snapshot the child list so the lock is not held across recursion.
            let children = task.children.lock().unwrap().clone();
            for child_id in children {
                self.cancel(child_id);
            }
            task.notify_waiters();
            true
        } else {
            false
        }
    }
    /// `true` once the task is terminal (or the id is unknown, which is
    /// treated as already complete).
    pub fn is_complete(&self, task_id: TaskId) -> bool {
        self.inner
            .get_task(task_id)
            .map(|t| {
                matches!(
                    t.get_state(),
                    TaskState::Completed | TaskState::Cancelled | TaskState::Panicked
                )
            })
            .unwrap_or(true)
    }
    /// Drives tasks from the global queue on the calling thread until
    /// `task_id` reaches a terminal state.
    ///
    /// NOTE(review): this busy-spins (`yield_now`) when the queue is empty
    /// and never drains worker-local queues — fine for a single-threaded
    /// harness, but confirm intent for multi-worker use.
    pub fn block_on(&self, task_id: TaskId) {
        while !self.is_complete(task_id) {
            if let Some(task) = self.inner.global_queue.lock().unwrap().pop_front() {
                self.run_task(&task);
            } else {
                std::thread::yield_now();
            }
        }
    }
    /// Polls a `Ready` task once, transitioning it to `Completed` or
    /// `Waiting`. Tasks in any other state are skipped.
    ///
    /// NOTE(review): a wake that fires while the task is still `Running`
    /// (i.e. during the poll, before `Waiting` is stored) is dropped by
    /// `wake_task`, then the state becomes `Waiting` with no wake pending —
    /// a classic lost-wakeup window. Confirm whether poll_fn implementations
    /// guard against this.
    fn run_task(&self, task: &Arc<Task>) {
        let state = task.get_state();
        if state != TaskState::Ready {
            return;
        }
        task.set_state(TaskState::Running);
        let waker = Waker::new(task.id, self.inner.clone());
        let mut cx = Context::new(&waker);
        match task.poll(&mut cx) {
            Poll::Ready(()) => {
                task.set_state(TaskState::Completed);
                if let Some(on_complete) = task.on_complete {
                    on_complete(task.id);
                }
                task.notify_waiters();
            }
            Poll::Pending => {
                task.set_state(TaskState::Waiting);
            }
        }
    }
    /// Snapshot of executor counters and queue lengths (advisory: values
    /// are read without a global lock and may be mutually inconsistent).
    pub fn stats(&self) -> ExecutorStats {
        let mut total_executed = 0;
        let mut total_stolen = 0;
        let mut queue_lengths = Vec::new();
        for worker in &self.inner.workers {
            total_executed += worker.tasks_executed.load(Ordering::Relaxed);
            total_stolen += worker.tasks_stolen.load(Ordering::Relaxed);
            queue_lengths.push(worker.local_queue.len());
        }
        ExecutorStats {
            total_tasks_executed: total_executed,
            total_tasks_stolen: total_stolen,
            global_queue_len: self.inner.global_queue.lock().unwrap().len(),
            worker_queue_lengths: queue_lengths,
            active_workers: self.inner.active_workers.load(Ordering::Relaxed),
        }
    }
    /// Flags shutdown and wakes every parked worker so they can observe it.
    pub fn shutdown(&self) {
        self.inner.shutdown.store(true, Ordering::Release);
        self.inner.running.store(false, Ordering::Release);
        self.inner.unpark_workers();
    }
    /// Shared state handle, e.g. for constructing `Waker`s or a `TimerWheel`.
    pub fn inner(&self) -> &Arc<ExecutorInner> {
        &self.inner
    }
}
impl Default for Executor {
fn default() -> Self {
Self::new()
}
}
/// Point-in-time counters reported by `Executor::stats`.
#[derive(Debug, Clone)]
pub struct ExecutorStats {
    pub total_tasks_executed: u64,
    pub total_tasks_stolen: u64,
    pub global_queue_len: usize,
    // One entry per worker, in worker-id order.
    pub worker_queue_lengths: Vec<usize>,
    pub active_workers: usize,
}
/// One pending timer: wake `task_id` once `deadline` passes, unless the
/// `cancelled` flag was set first (see `TimerWheel::cancel`/`process`).
#[derive(Debug)]
pub struct TimerEntry {
    pub deadline: Instant,
    pub task_id: TaskId,
    pub cancelled: AtomicBool,
}
/// Deadline-sorted timer list ("wheel" in name only): `register` inserts in
/// sorted order and `process` drains everything that has expired.
pub struct TimerWheel {
    // Ascending by deadline; maintained by binary-search insertion.
    entries: Mutex<Vec<Arc<TimerEntry>>>,
    executor: Arc<ExecutorInner>,
}

impl TimerWheel {
    /// Creates an empty timer bound to `executor`.
    pub fn new(executor: Arc<ExecutorInner>) -> Self {
        Self {
            entries: Mutex::new(Vec::new()),
            executor,
        }
    }

    /// Schedules `task_id` to be woken at `deadline`; keeps the list sorted.
    /// Returns the entry so the caller can cancel it later.
    pub fn register(&self, deadline: Instant, task_id: TaskId) -> Arc<TimerEntry> {
        let entry = Arc::new(TimerEntry {
            deadline,
            task_id,
            cancelled: AtomicBool::new(false),
        });
        let mut entries = self.entries.lock().unwrap();
        // Equal deadlines may land in any order among their ties; that is
        // acceptable since process() fires all of them together anyway.
        let pos = entries
            .binary_search_by(|e| e.deadline.cmp(&deadline))
            .unwrap_or_else(|e| e);
        entries.insert(pos, entry.clone());
        entry
    }

    /// Convenience wrapper: deadline = now + `delay`.
    pub fn register_delay(&self, delay: Duration, task_id: TaskId) -> Arc<TimerEntry> {
        self.register(Instant::now() + delay, task_id)
    }

    /// Marks an entry cancelled; `process` will discard it without waking.
    pub fn cancel(&self, entry: &TimerEntry) {
        entry.cancelled.store(true, Ordering::Release);
    }

    /// Fires every expired, non-cancelled entry. Returns the time until the
    /// next pending deadline, or `None` when no timers remain.
    ///
    /// Fixes two issues in the original loop: repeated `Vec::remove(0)` was
    /// O(n^2) per batch (now one `partition_point` + `drain`), and
    /// `wake_task` ran while the entries lock was held (now all wakes happen
    /// after the lock is released, so the executor's locks never nest under
    /// this one).
    pub fn process(&self) -> Option<Duration> {
        let now = Instant::now();
        let (due, next_deadline) = {
            let mut entries = self.entries.lock().unwrap();
            // The list is sorted, so expired entries form a prefix.
            let split = entries.partition_point(|e| e.deadline <= now);
            let due: Vec<Arc<TimerEntry>> = entries.drain(..split).collect();
            (due, entries.first().map(|e| e.deadline))
        };
        for entry in due {
            if !entry.cancelled.load(Ordering::Acquire) {
                self.executor.wake_task(entry.task_id);
            }
        }
        // Remaining deadlines are strictly after `now`, so this cannot panic.
        next_deadline.map(|deadline| deadline - now)
    }
}
/// Bounded MPMC channel with explicit waker registration for senders and
/// receivers (no blocking API; see `try_send`/`try_recv`).
pub struct Channel<T> {
    buffer: Mutex<VecDeque<T>>,
    // Maximum number of buffered values; try_send fails with Full beyond it.
    capacity: usize,
    // Wakers parked by register_send_wait, popped one-per-recv.
    send_waiters: Mutex<Vec<Waker>>,
    // Wakers parked by register_recv_wait, popped one-per-send.
    recv_waiters: Mutex<Vec<Waker>>,
    closed: AtomicBool,
}
impl<T> Channel<T> {
    /// Creates a channel buffering at most `capacity` values.
    pub fn new(capacity: usize) -> Arc<Self> {
        Arc::new(Self {
            buffer: Mutex::new(VecDeque::with_capacity(capacity)),
            capacity,
            send_waiters: Mutex::new(Vec::new()),
            recv_waiters: Mutex::new(Vec::new()),
            closed: AtomicBool::new(false),
        })
    }

    /// Attempts to enqueue `value` without blocking; the value is handed
    /// back inside `Closed`/`Full` on failure.
    ///
    /// A pending receiver's waker is popped while the buffer lock is held
    /// (so the wakeup cannot be lost) but invoked only after the lock is
    /// released — the original woke under the lock, running the executor's
    /// whole wake path inside this channel's critical section.
    pub fn try_send(&self, value: T) -> Result<(), ChannelError<T>> {
        if self.closed.load(Ordering::Acquire) {
            return Err(ChannelError::Closed(value));
        }
        let waker = {
            let mut buffer = self.buffer.lock().unwrap();
            if buffer.len() >= self.capacity {
                return Err(ChannelError::Full(value));
            }
            buffer.push_back(value);
            self.recv_waiters.lock().unwrap().pop()
        };
        if let Some(waker) = waker {
            waker.wake();
        }
        Ok(())
    }

    /// Attempts to dequeue a value without blocking.
    ///
    /// Buffered values are still delivered after `close`; `Closed` is only
    /// reported once the buffer is drained. As in `try_send`, the popped
    /// sender waker is invoked after the buffer lock is dropped.
    pub fn try_recv(&self) -> Result<T, ChannelError<()>> {
        let (result, waker) = {
            let mut buffer = self.buffer.lock().unwrap();
            match buffer.pop_front() {
                Some(value) => (Ok(value), self.send_waiters.lock().unwrap().pop()),
                None if self.closed.load(Ordering::Acquire) => {
                    (Err(ChannelError::Closed(())), None)
                }
                None => (Err(ChannelError::Empty(())), None),
            }
        };
        if let Some(waker) = waker {
            waker.wake();
        }
        result
    }

    /// Parks a would-be sender until a slot (or close) wakes it.
    pub fn register_send_wait(&self, waker: Waker) {
        self.send_waiters.lock().unwrap().push(waker);
    }

    /// Parks a would-be receiver until a value (or close) wakes it.
    pub fn register_recv_wait(&self, waker: Waker) {
        self.recv_waiters.lock().unwrap().push(waker);
    }

    /// Closes the channel and wakes every parked sender and receiver.
    /// Waiter lists are taken first so no lock is held while waking.
    pub fn close(&self) {
        self.closed.store(true, Ordering::Release);
        let senders = std::mem::take(&mut *self.send_waiters.lock().unwrap());
        for waker in senders {
            waker.wake();
        }
        let receivers = std::mem::take(&mut *self.recv_waiters.lock().unwrap());
        for waker in receivers {
            waker.wake();
        }
    }

    pub fn is_closed(&self) -> bool {
        self.closed.load(Ordering::Acquire)
    }

    /// Number of currently buffered values.
    pub fn len(&self) -> usize {
        self.buffer.lock().unwrap().len()
    }

    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
/// Failure modes for the non-blocking channel operations; each variant hands
/// the caller's payload back (`()` on the receive side).
#[derive(Debug)]
pub enum ChannelError<T> {
    /// Buffer is at capacity (send side).
    Full(T),
    /// Nothing buffered and the channel is still open (receive side).
    Empty(T),
    /// Channel has been closed.
    Closed(T),
}
/// Counting semaphore with explicit waker registration (no blocking acquire).
pub struct Semaphore {
    // Currently available permits.
    permits: AtomicUsize,
    // Initial permit count; release() debug-asserts against exceeding it.
    max_permits: usize,
    // Wakers parked via register_wait; release() pops one per call.
    waiters: Mutex<Vec<Waker>>,
}
impl Semaphore {
    /// Creates a semaphore with `permits` permits available.
    pub fn new(permits: usize) -> Self {
        Self {
            permits: AtomicUsize::new(permits),
            max_permits: permits,
            waiters: Mutex::new(Vec::new()),
        }
    }

    /// Takes one permit if any are available; never blocks.
    ///
    /// `fetch_update` with `checked_sub` replaces the hand-rolled
    /// compare-exchange loop: it retries on contention and fails cleanly
    /// (returning `false`) when the count is already zero.
    pub fn try_acquire(&self) -> bool {
        self.permits
            .fetch_update(Ordering::AcqRel, Ordering::Acquire, |p| p.checked_sub(1))
            .is_ok()
    }

    /// Registers a waker to be woken by a future `release`.
    pub fn register_wait(&self, waker: Waker) {
        self.waiters.lock().unwrap().push(waker);
    }

    /// Returns one permit and wakes a single parked waiter, if any.
    ///
    /// Debug builds assert that releases never outnumber acquires.
    pub fn release(&self) {
        let prev = self.permits.fetch_add(1, Ordering::AcqRel);
        debug_assert!(prev < self.max_permits, "released more than acquired");
        if let Some(waker) = self.waiters.lock().unwrap().pop() {
            waker.wake();
        }
    }

    /// Currently free permits (advisory: may change immediately after).
    pub fn available(&self) -> usize {
        self.permits.load(Ordering::Acquire)
    }
}
/// Single-value, single-use channel: one successful `send`, and the value can
/// be taken out exactly once via `try_recv`.
pub struct Oneshot<T> {
    value: Mutex<Option<T>>,
    // Set once a send has succeeded; stays true even after the value is taken.
    completed: AtomicBool,
    // At most one receiver waker; overwritten by register_wait.
    waiter: Mutex<Option<Waker>>,
}
impl<T> Oneshot<T> {
    /// Creates an empty, unsent oneshot behind an `Arc`.
    pub fn new() -> Arc<Self> {
        Arc::new(Self {
            value: Mutex::new(None),
            completed: AtomicBool::new(false),
            waiter: Mutex::new(None),
        })
    }

    /// Stores `value` and wakes the registered receiver, if any.
    /// Returns `Err(value)` when a send already happened.
    ///
    /// The completion check-and-set now runs under the value lock: the
    /// original's bare load-then-store pair let two concurrent senders both
    /// observe `completed == false` and both report success.
    pub fn send(&self, value: T) -> Result<(), T> {
        {
            let mut slot = self.value.lock().unwrap();
            if self.completed.load(Ordering::Acquire) {
                return Err(value);
            }
            *slot = Some(value);
            // Published while the lock is held, after the value is in place.
            self.completed.store(true, Ordering::Release);
        }
        if let Some(waker) = self.waiter.lock().unwrap().take() {
            waker.wake();
        }
        Ok(())
    }

    /// Takes the value if a send has completed; subsequent calls (and calls
    /// before any send) return `None`.
    pub fn try_recv(&self) -> Option<T> {
        if self.completed.load(Ordering::Acquire) {
            self.value.lock().unwrap().take()
        } else {
            None
        }
    }

    /// Registers (replacing any previous) the receiver's waker.
    pub fn register_wait(&self, waker: Waker) {
        *self.waiter.lock().unwrap() = Some(waker);
    }

    /// `true` once a send has succeeded (even after the value was taken).
    pub fn is_completed(&self) -> bool {
        self.completed.load(Ordering::Acquire)
    }
}
impl<T> Default for Oneshot<T> {
fn default() -> Self {
Self {
value: Mutex::new(None),
completed: AtomicBool::new(false),
waiter: Mutex::new(None),
}
}
}
/// Handle for observing a spawned task's completion (no result payload —
/// this runtime's tasks yield `()`).
pub struct JoinHandle {
    task_id: TaskId,
    executor: Arc<ExecutorInner>,
}
impl JoinHandle {
pub fn new(task_id: TaskId, executor: Arc<ExecutorInner>) -> Self {
Self { task_id, executor }
}
pub fn is_finished(&self) -> bool {
self.executor
.get_task(self.task_id)
.map(|t| {
matches!(
t.get_state(),
TaskState::Completed | TaskState::Cancelled | TaskState::Panicked
)
})
.unwrap_or(true)
}
pub fn register_wait(&self, waker: Waker) {
if let Some(task) = self.executor.get_task(self.task_id) {
task.waiters.lock().unwrap().push(waker);
}
}
pub fn task_id(&self) -> TaskId {
self.task_id
}
}
/// Descriptive metadata about a compiled async function's state machine
/// (informational only; nothing in the visible code consumes it).
#[derive(Debug, Clone)]
pub struct AsyncFnMetadata {
    pub name: String,
    pub await_points: usize,
    pub state_size: usize,
    pub captures: Vec<String>,
}
/// Discriminant of an async state machine, encoded in a `u32`.
///
/// Two raw values are reserved: `0` always decodes as `Start` and
/// `u32::MAX` as `Complete`, so `AwaitPoint(0)` / `AwaitPoint(u32::MAX)`
/// cannot round-trip — await-point indices are expected in between.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AsyncState {
    Start,
    AwaitPoint(u32),
    Complete,
}

impl AsyncState {
    pub const START_VALUE: u32 = 0;
    pub const COMPLETE_VALUE: u32 = u32::MAX;

    /// Decodes a raw discriminant (reserved values map to Start/Complete).
    pub fn from_u32(v: u32) -> Self {
        if v == Self::START_VALUE {
            AsyncState::Start
        } else if v == Self::COMPLETE_VALUE {
            AsyncState::Complete
        } else {
            AsyncState::AwaitPoint(v)
        }
    }

    /// Encodes this state back into its raw discriminant.
    pub fn to_u32(self) -> u32 {
        match self {
            AsyncState::Start => Self::START_VALUE,
            AsyncState::AwaitPoint(n) => n,
            AsyncState::Complete => Self::COMPLETE_VALUE,
        }
    }
}
/// Renders a pseudo-Rust struct declaration describing an async fn's state:
/// a `__state: u32` discriminant followed by one line per captured field.
/// Purely textual output for diagnostics; nothing parses it.
pub fn async_state_machine_layout(
    name: &str,
    state_fields: &[(String, String)],
) -> String {
    let mut layout = String::new();
    layout.push_str(&format!("struct {}State {{\n", name));
    layout.push_str(" __state: u32,\n");
    for (field_name, field_type) in state_fields {
        layout.push_str(&format!(" {}: {},\n", field_name, field_type));
    }
    layout.push_str("}\n");
    layout
}
use std::cell::Cell;
thread_local! {
    // Task id currently associated with this thread, if any.
    //
    // `Option<TaskId>` is `Copy`, so `Cell` replaces the original `RefCell`:
    // same behavior, but without runtime borrow flags or the (unreachable
    // here, yet real) borrow-panic path.
    static CURRENT_TASK: Cell<Option<TaskId>> = const { Cell::new(None) };
}
/// Task id recorded for the calling thread, if one was set.
pub fn current_task_id() -> Option<TaskId> {
    CURRENT_TASK.with(|t| t.get())
}
/// Records (or clears, with `None`) the calling thread's current task id.
pub fn set_current_task_id(id: Option<TaskId>) {
    CURRENT_TASK.with(|t| t.set(id));
}
#[cfg(test)]
mod tests {
    use super::*;
    // Owner pops LIFO from the back; thieves steal FIFO from the front.
    #[test]
    fn test_work_stealing_deque() {
        let deque = WorkStealingDeque::new();
        let task1 = Arc::new(Task::new(1, 0));
        let task2 = Arc::new(Task::new(2, 0));
        let task3 = Arc::new(Task::new(3, 0));
        deque.push(task1.clone());
        deque.push(task2.clone());
        deque.push(task3.clone());
        assert_eq!(deque.len(), 3);
        // pop() takes the newest (id 3), steal() the oldest (id 1).
        assert_eq!(deque.pop().unwrap().id, 3);
        assert_eq!(deque.steal().unwrap().id, 1);
        assert_eq!(deque.len(), 1);
    }
    // Ids start at 1, and a freshly spawned task is not yet complete.
    #[test]
    fn test_executor_spawn() {
        let executor = Executor::new();
        let task_id = executor.spawn(0);
        assert!(task_id > 0);
        assert!(!executor.is_complete(task_id));
    }
    // Bounded channel: Full beyond capacity, FIFO delivery, Empty when drained.
    #[test]
    fn test_channel() {
        let channel: Arc<Channel<i32>> = Channel::new(2);
        assert!(channel.try_send(1).is_ok());
        assert!(channel.try_send(2).is_ok());
        // The rejected value is handed back inside the error.
        assert!(matches!(channel.try_send(3), Err(ChannelError::Full(3))));
        assert_eq!(channel.try_recv().unwrap(), 1);
        assert_eq!(channel.try_recv().unwrap(), 2);
        assert!(matches!(channel.try_recv(), Err(ChannelError::Empty(()))));
    }
    // Permits are exhausted at capacity and restored by release().
    #[test]
    fn test_semaphore() {
        let sem = Semaphore::new(2);
        assert!(sem.try_acquire());
        assert!(sem.try_acquire());
        assert!(!sem.try_acquire());
        sem.release();
        assert!(sem.try_acquire());
    }
    // Oneshot delivers exactly once; a second try_recv yields None.
    #[test]
    fn test_oneshot() {
        let oneshot: Arc<Oneshot<i32>> = Oneshot::new();
        assert!(!oneshot.is_completed());
        assert!(oneshot.try_recv().is_none());
        assert!(oneshot.send(42).is_ok());
        assert!(oneshot.is_completed());
        assert_eq!(oneshot.try_recv(), Some(42));
        assert!(oneshot.try_recv().is_none());
    }
    // Poll predicates and map() behave like std::task::Poll.
    #[test]
    fn test_poll() {
        let ready: Poll<i32> = Poll::Ready(42);
        let pending: Poll<i32> = Poll::Pending;
        assert!(ready.is_ready());
        assert!(!ready.is_pending());
        assert!(!pending.is_ready());
        assert!(pending.is_pending());
        let mapped = Poll::Ready(10).map(|x| x * 2);
        assert!(matches!(mapped, Poll::Ready(20)));
    }
}