#![allow(private_interfaces)]
pub mod raw_task;
use std::future::Future;
use std::sync::Arc;
use std::sync::atomic::{AtomicU8, AtomicUsize, Ordering};
use std::task::{Context, Poll, RawWaker, RawWakerVTable, Waker};
use crate::scheduler::{RawTask, SchedulerHandle};
pub use crate::scheduler::TaskId;
pub use crate::scheduler::gen_task_id;
/// Lifecycle state of a task, stored in `TaskInner::state` as its `u8` discriminant.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[repr(u8)]
enum TaskState {
    Running = 0,
    Waiting = 1,
    Completed = 2,
    Cancelled = 3,
    Panicked = 4,
}

impl TaskState {
    /// Decodes a discriminant previously produced by `state as u8`.
    ///
    /// Returns `None` for any byte that does not name a variant.
    fn from_u8(value: u8) -> Option<Self> {
        let state = match value {
            0 => Self::Running,
            1 => Self::Waiting,
            2 => Self::Completed,
            3 => Self::Cancelled,
            4 => Self::Panicked,
            _ => return None,
        };
        Some(state)
    }

    /// Reports whether the state is terminal (completed, cancelled or panicked).
    fn is_finished(self) -> bool {
        // Exhaustive on purpose: adding a variant forces a decision here.
        match self {
            Self::Running | Self::Waiting => false,
            Self::Completed | Self::Cancelled | Self::Panicked => true,
        }
    }
}
/// Shared, `Arc`-held state of a task. It is referenced by `Task`, `JoinHandle`,
/// `WaitForTask`, the raw waker functions, and the fallback thread in `spawn`.
#[allow(dead_code)]
struct TaskInner<T> {
    /// Task identifier reported by `Task::id` and `JoinHandle::id`.
    id: TaskId,
    /// Current `TaskState`, stored as its `u8` discriminant.
    state: AtomicU8,
    /// Manual reference count; initialised to 2 in `Task::new` and to 1 in `spawn`.
    // NOTE(review): no visible code frees the task when this reaches zero — confirm whether it is still needed.
    ref_count: AtomicUsize,
    /// Scheduler that `raw_waker_wake_by_ref` resubmits the task to.
    scheduler: SchedulerHandle,
    /// Address of the raw task pointer published by `Task::new`, or 0 when unset
    /// (`spawn` never sets it; `Drop for Task` clears it back to 0).
    raw_task: AtomicUsize,
    /// Task output, written by the fallback thread in `spawn` and read by `WaitForTask`.
    output: lock::OptionalCell<T>,
    /// Waker registered by `WaitForTask::poll`; woken when the fallback thread finishes.
    waiter: futures::task::AtomicWaker,
}
mod lock {
    use std::sync::{Mutex, MutexGuard, PoisonError};

    /// A mutex-guarded slot holding at most one value of `T`.
    ///
    /// A producer stores a value with [`OptionalCell::set`] and a consumer moves it out with
    /// [`OptionalCell::get`]. Taking the value empties the slot, so every stored value is
    /// dropped exactly once: by whoever took it, by a later `set` that replaces it, or by
    /// the cell's own drop if it was never taken.
    ///
    /// `Send`/`Sync` are derived automatically from `Mutex<Option<T>>`, which requires
    /// `T: Send` for both, matching the bounds of the previous manual impls.
    pub(super) struct OptionalCell<T> {
        inner: Mutex<Option<T>>,
    }

    impl<T> OptionalCell<T> {
        /// Creates an empty cell.
        #[allow(dead_code)]
        pub(super) fn new() -> Self {
            Self {
                inner: Mutex::new(None),
            }
        }

        /// Stores `value`, dropping any value that was stored but never taken.
        #[allow(dead_code)]
        pub(super) fn set(&self, value: T) {
            *self.slot() = Some(value);
        }

        /// Moves the stored value out, leaving the cell empty.
        ///
        /// Returns `None` if nothing was stored or the value was already taken.
        ///
        /// # Safety
        ///
        /// This function is kept `unsafe` only so that existing `unsafe { cell.get() }` call
        /// sites keep compiling. It imposes no requirement on the caller: the value is moved
        /// out of an `Option` while the mutex is held, so it can never be read twice.
        #[allow(dead_code)]
        pub(super) unsafe fn get(&self) -> Option<T> {
            self.slot().take()
        }

        /// Locks the slot, recovering the guard if a previous holder panicked. The slot is
        /// an `Option`, so it is always in a valid state.
        fn slot(&self) -> MutexGuard<'_, Option<T>> {
            self.inner.lock().unwrap_or_else(PoisonError::into_inner)
        }
    }
}
/// Holds a strong reference to a task's shared `TaskInner`.
///
/// Created by `Task::new` alongside the `RawTask` pointer; dropping it clears the
/// published `raw_task` address.
#[allow(dead_code)]
pub struct Task<T> {
    inner: Arc<TaskInner<T>>,
}
impl<T> Task<T> {
    /// Builds a task's shared state and the raw pointer handed to the scheduler.
    ///
    /// The returned `RawTask` comes from an extra `Arc` reference converted with
    /// `Arc::into_raw`. The same address is published in `raw_task` so that wakers can
    /// resubmit the task.
    ///
    /// NOTE(review): `_future` is currently discarded, and nothing visible here reclaims the
    /// `Arc` turned into the returned `RawTask` — confirm ownership with the scheduler.
    #[allow(dead_code)]
    fn new<F>(_future: F, id: TaskId, scheduler: SchedulerHandle) -> (Self, RawTask)
    where
        F: Future<Output = T> + Send + 'static,
        T: Send + 'static,
    {
        let shared = Arc::new(TaskInner {
            id,
            state: AtomicU8::new(TaskState::Running as u8),
            ref_count: AtomicUsize::new(2),
            scheduler,
            raw_task: AtomicUsize::new(0),
            output: lock::OptionalCell::new(),
            waiter: futures::task::AtomicWaker::new(),
        });
        let raw = Arc::into_raw(Arc::clone(&shared)) as RawTask;
        shared.raw_task.store(raw as usize, Ordering::Release);
        (Task { inner: shared }, raw)
    }

    /// Returns this task's identifier.
    #[must_use]
    pub fn id(&self) -> TaskId {
        let Self { inner } = self;
        inner.id
    }

    /// Placeholder poll: always reports `Pending` and never produces an output.
    #[allow(dead_code)]
    fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<T> {
        Poll::Pending
    }
}
use std::pin::Pin;
impl<T> Drop for Task<T> {
    /// Resets the published raw-task address to 0. `raw_waker_wake_by_ref` skips
    /// resubmission when it reads 0.
    fn drop(&mut self) {
        let raw_task_slot = &self.inner.raw_task;
        raw_task_slot.store(0, Ordering::Release);
    }
}
/// Creates a [`Waker`] that resubmits the unit task behind `inner` when woken.
///
/// The waker's data pointer is an `Arc<TaskInner<()>>` converted with [`Arc::into_raw`], so
/// every live waker owns exactly one strong reference. The vtable functions below keep that
/// invariant:
/// - cloning adds a strong reference;
/// - dropping (or the consuming `wake`) releases one.
///
/// As a result, the task state is freed once the last waker and handle are gone.
#[allow(dead_code)]
fn task_waker(inner: &Arc<TaskInner<()>>) -> Waker {
    let data = Arc::into_raw(Arc::clone(inner)).cast::<()>();
    // SAFETY: `data` is an owned `Arc<TaskInner<()>>` pointer, and every function in
    // `RAW_WAKER_VTABLE` treats it as exactly that, upholding the `RawWaker` contract.
    unsafe { Waker::from_raw(RawWaker::new(data, &RAW_WAKER_VTABLE)) }
}

/// Vtable shared by every waker produced by [`task_waker`].
#[allow(dead_code)]
static RAW_WAKER_VTABLE: RawWakerVTable =
    RawWakerVTable::new(raw_waker_clone, raw_waker_wake, raw_waker_wake_by_ref, raw_waker_drop);

/// Clones a task waker by taking one more strong reference on the task.
///
/// # Safety
///
/// `data` must come from [`task_waker`] (or a clone of it) and still own its reference.
#[allow(dead_code)]
unsafe fn raw_waker_clone(data: *const ()) -> RawWaker {
    // SAFETY: the caller guarantees `data` is a live `Arc<TaskInner<()>>` pointer, so its
    // strong count is at least one and may be incremented.
    unsafe { Arc::increment_strong_count(data.cast::<TaskInner<()>>()) };
    RawWaker::new(data, &RAW_WAKER_VTABLE)
}

/// Consuming wake: wakes the task, then releases the reference this waker owned.
///
/// # Safety
///
/// Same contract as [`raw_waker_clone`]; `data` must not be used afterwards.
#[allow(dead_code)]
unsafe fn raw_waker_wake(data: *const ()) {
    // SAFETY: the reference is still owned while waking and is released exactly once after.
    unsafe {
        raw_waker_wake_by_ref(data);
        raw_waker_drop(data);
    }
}

/// Non-consuming wake: moves a `Waiting` task back to `Running` and resubmits it.
///
/// # Safety
///
/// Same contract as [`raw_waker_clone`].
#[allow(dead_code)]
unsafe fn raw_waker_wake_by_ref(data: *const ()) {
    // SAFETY: the calling waker owns a strong reference, keeping the task alive for this call.
    let inner = unsafe { &*data.cast::<TaskInner<()>>() };
    // Only a task parked in `Waiting` is resubmitted; any other state ignores the wake-up.
    if inner
        .state
        .compare_exchange(
            TaskState::Waiting as u8,
            TaskState::Running as u8,
            Ordering::Release,
            Ordering::Relaxed,
        )
        .is_err()
    {
        return;
    }
    let raw_task = inner.raw_task.load(Ordering::Acquire) as RawTask;
    // `Drop for Task` resets the address to 0; nothing is submitted after that.
    if raw_task as usize != 0 {
        let _ = inner.scheduler.submit(raw_task);
    }
}

/// Drops a task waker, releasing its strong reference (freeing the task if it was the last).
///
/// # Safety
///
/// Same contract as [`raw_waker_clone`]; `data` must not be used afterwards.
#[allow(dead_code)]
unsafe fn raw_waker_drop(data: *const ()) {
    // SAFETY: the caller guarantees this waker owns one strong reference, released here.
    unsafe { Arc::decrement_strong_count(data.cast::<TaskInner<()>>()) };
}
/// Handle returned by [`spawn`] for querying and awaiting a task's result.
///
/// The handle's methods consult `raw_core` first and fall back to `inner` when no raw core
/// is available.
pub struct JoinHandle<T> {
    /// Shared task state; the only tracking source for tasks spawned outside a runtime.
    inner: Option<Arc<TaskInner<T>>>,
    /// Reference to the raw task core; `Some` only for tasks spawned inside a runtime.
    raw_core: Option<raw_task::TaskRef>,
}
impl<T> JoinHandle<T> {
    /// Returns the spawned task's identifier.
    ///
    /// Prefers the raw task core's id, then the `TaskInner` id, and finally `0` when the
    /// handle tracks neither.
    #[must_use]
    pub fn id(&self) -> TaskId {
        match self.raw_core.as_ref().and_then(|refs| refs.core()) {
            Some(task_core) => task_core.id(),
            None => self.inner.as_ref().map_or(0, |i| i.id),
        }
    }

    /// Reports whether the task has reached a terminal state.
    ///
    /// Uses the raw core's completion flag when available; otherwise decodes the
    /// `TaskInner` state. An unknown state, or a handle with no state at all, reports `false`.
    #[must_use]
    pub fn is_finished(&self) -> bool {
        match self.raw_core.as_ref().and_then(|refs| refs.core()) {
            Some(task_core) => task_core.is_completed(),
            None => self.inner.as_ref().is_some_and(|inner| {
                TaskState::from_u8(inner.state.load(Ordering::Acquire))
                    .is_some_and(|state| state.is_finished())
            }),
        }
    }

    /// Waits for the task and returns its output.
    ///
    /// # Errors
    ///
    /// Returns [`JoinError::TaskCancelled`] when no output is available. Returns
    /// [`JoinError::TaskPanic`] when the tracked `TaskInner` reports a panic.
    pub async fn wait(self) -> Result<T, JoinError> {
        if let Some(task_core) = self.raw_core.as_ref().and_then(|refs| refs.core()) {
            // No completion notification is wired up on this path, so the waiter re-wakes
            // itself on every poll until the core reports completion.
            std::future::poll_fn(|cx| {
                if !task_core.is_completed() {
                    cx.waker().wake_by_ref();
                    return Poll::Pending;
                }
                Poll::Ready(())
            })
            .await;
            // SAFETY: the core has reported completion before the output is read; assumes
            // that is the precondition of `read_output` — TODO confirm against `raw_task`.
            let output = unsafe { raw_task::read_output::<T>(task_core) };
            return output.ok_or(JoinError::TaskCancelled);
        }
        match self.inner {
            Some(inner) => WaitForTask::new(inner).await,
            None => Err(JoinError::TaskCancelled),
        }
    }
}
/// Future returned by `JoinHandle::wait` for tasks tracked only through `TaskInner`.
struct WaitForTask<T> {
    /// Shared state of the awaited task; `None` once the future has released it.
    inner: Option<Arc<TaskInner<T>>>,
}
impl<T> WaitForTask<T> {
    /// Wraps `inner` in a future that resolves once the task reaches a terminal state.
    fn new(inner: Arc<TaskInner<T>>) -> Self {
        let inner = Some(inner);
        Self { inner }
    }
}
impl<T> Future for WaitForTask<T> {
    type Output = Result<T, JoinError>;

    /// Resolves when the task reaches a terminal state.
    ///
    /// The waker is registered before the state is read, so a completion that races with this
    /// poll still wakes the caller. On every terminal outcome the shared task state is
    /// released.
    ///
    /// Outcomes:
    /// - `Completed` with a stored output yields `Ok(output)`.
    /// - `Completed` without an output, `Cancelled`, or an unknown state yields
    ///   `Err(TaskCancelled)`.
    /// - `Panicked` yields `Err(TaskPanic)`.
    ///
    /// # Panics
    ///
    /// Panics if polled again after it has returned `Poll::Ready`.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let inner = self
            .inner
            .as_ref()
            .expect("WaitForTask polled after it already returned Ready");
        inner.waiter.register(cx.waker());
        let outcome = match TaskState::from_u8(inner.state.load(Ordering::Acquire)) {
            Some(TaskState::Running | TaskState::Waiting) => return Poll::Pending,
            // SAFETY: `Completed` was observed with an Acquire load. The writer in `spawn`
            // stores the output before storing `Completed` with Release ordering.
            Some(TaskState::Completed) => {
                unsafe { inner.output.get() }.ok_or(JoinError::TaskCancelled)
            },
            Some(TaskState::Cancelled) | None => Err(JoinError::TaskCancelled),
            Some(TaskState::Panicked) => Err(JoinError::TaskPanic),
        };
        // Release the shared task state on every terminal outcome, not just some of them.
        self.inner = None;
        Poll::Ready(outcome)
    }
}
impl<T> Drop for WaitForTask<T> {
    /// Releases the awaited task's shared state, if this future still holds it.
    fn drop(&mut self) {
        drop(self.inner.take());
    }
}
/// Reason a spawned task did not yield an output through its [`JoinHandle`].
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum JoinError {
    /// The task was cancelled, or no output was available when it was joined.
    TaskCancelled,
    /// The task panicked while running.
    TaskPanic,
}

impl std::fmt::Display for JoinError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            Self::TaskCancelled => "Task was cancelled",
            Self::TaskPanic => "Task panicked",
        };
        f.write_str(message)
    }
}

impl std::error::Error for JoinError {}
pub fn spawn<F, T>(future: F) -> JoinHandle<T>
where
F: Future<Output = T> + Send + 'static,
T: Send + 'static,
{
if let Some(handle) = crate::runtime::Handle::try_current() {
let (raw_task, task_ref) = raw_task::allocate_task(future, handle.scheduler().clone());
let id = task_ref.core().map_or(0, raw_task::TaskCore::id);
let _ = handle.scheduler().submit(raw_task);
return JoinHandle {
inner: Some(Arc::new(TaskInner {
id,
state: AtomicU8::new(TaskState::Running as u8),
ref_count: AtomicUsize::new(1),
scheduler: handle.scheduler().clone(),
raw_task: AtomicUsize::new(0),
output: lock::OptionalCell::new(),
waiter: futures::task::AtomicWaker::new(),
})),
raw_core: Some(task_ref),
};
}
let id = gen_task_id();
let inner = Arc::new(TaskInner {
id,
state: AtomicU8::new(TaskState::Running as u8),
ref_count: AtomicUsize::new(1),
scheduler: SchedulerHandle::new_default(),
raw_task: AtomicUsize::new(0),
output: lock::OptionalCell::new(),
waiter: futures::task::AtomicWaker::new(),
});
let inner_clone = inner.clone();
std::thread::spawn(move || {
let mut future = Box::pin(future);
let waker = Waker::noop();
let mut context = Context::from_waker(waker);
let result = loop {
match Pin::new(&mut future).poll(&mut context) {
Poll::Ready(value) => break value,
Poll::Pending => {
std::thread::sleep(std::time::Duration::from_millis(1));
},
}
};
inner_clone.output.set(result);
inner_clone
.state
.store(TaskState::Completed as u8, Ordering::Release);
inner_clone.waiter.wake();
});
JoinHandle {
inner: Some(inner),
raw_core: None,
}
}
/// Runs `future` to completion on a dedicated executor thread and returns its output.
///
/// The executor polls with a no-op waker and sleeps 1 ms after every `Pending`. A future
/// therefore makes progress at polling granularity even if nothing ever wakes it.
///
/// # Panics
///
/// If the future panics, the panic is re-raised on the calling thread with its original
/// payload (via [`std::panic::resume_unwind`]), so callers can inspect or catch it.
pub fn block_on<F, T>(future: F) -> T
where
    F: Future<Output = T> + Send + 'static,
    T: Send + 'static,
{
    use std::time::Duration;
    use std::{ptr, thread};

    let executor = thread::spawn(move || {
        // SAFETY: every entry of `NOOP_RAW_WAKER_VTABLE` ignores its data pointer, so a null
        // pointer is valid for the whole lifetime of this waker.
        let waker =
            unsafe { Waker::from_raw(RawWaker::new(ptr::null(), &NOOP_RAW_WAKER_VTABLE)) };
        let mut cx = Context::from_waker(&waker);
        let mut future = Box::pin(future);
        loop {
            if let Poll::Ready(result) = future.as_mut().poll(&mut cx) {
                return result;
            }
            thread::sleep(Duration::from_millis(1));
        }
    });
    match executor.join() {
        Ok(result) => result,
        // Propagate the future's own panic instead of replacing it with a generic message.
        Err(payload) => std::panic::resume_unwind(payload),
    }
}

/// Vtable for a waker that does nothing.
///
/// Cloning yields another no-op waker, and `wake`, `wake_by_ref` and `drop` are all no-ops.
/// The data pointer is never read.
const NOOP_RAW_WAKER_VTABLE: RawWakerVTable = RawWakerVTable::new(
    |_| RawWaker::new(std::ptr::null(), &NOOP_RAW_WAKER_VTABLE),
    |_| {},
    |_| {},
    |_| {},
);
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_task_id_generation() {
        let (earlier, later) = (gen_task_id(), gen_task_id());
        assert!(later > earlier);
    }

    #[test]
    fn test_task_state() {
        assert_eq!(0, TaskState::Running as u8);
        assert_eq!(2, TaskState::Completed as u8);
        assert!(TaskState::Completed.is_finished());
        assert!(!TaskState::Running.is_finished());
    }

    #[test]
    fn test_join_error_display() {
        assert_eq!(JoinError::TaskCancelled.to_string(), "Task was cancelled");
        assert_eq!(JoinError::TaskPanic.to_string(), "Task panicked");
    }

    #[test]
    fn test_join_error_equality() {
        let cancelled = JoinError::TaskCancelled;
        let panicked = JoinError::TaskPanic;
        assert_eq!(cancelled, JoinError::TaskCancelled);
        assert_eq!(panicked, JoinError::TaskPanic);
        assert_ne!(cancelled, panicked);
    }

    #[test]
    fn test_join_error_is_std_error() {
        let cases = [
            (JoinError::TaskCancelled, "Task was cancelled"),
            (JoinError::TaskPanic, "Task panicked"),
        ];
        for (error, expected) in cases {
            let boxed: Box<dyn std::error::Error> = Box::new(error);
            assert_eq!(boxed.to_string(), expected);
        }
    }

    #[test]
    fn test_block_on_free_function() {
        assert_eq!(block_on(async { 42i32 }), 42);
    }

    #[test]
    fn test_block_on_free_function_string() {
        let greeting = block_on(async { String::from("hiver") });
        assert_eq!(greeting, "hiver");
    }

    #[test]
    fn test_block_on_free_function_unit() {
        let () = block_on(async {});
    }

    #[test]
    fn test_block_on_free_function_complex() {
        let sum = block_on(async {
            let (a, b) = (10, 20);
            a + b
        });
        assert_eq!(sum, 30);
    }

    #[test]
    fn test_task_id_uniqueness() {
        use std::collections::HashSet;
        let ids = std::iter::repeat_with(gen_task_id)
            .take(100)
            .collect::<HashSet<_>>();
        assert_eq!(ids.len(), 100, "all generated task IDs should be unique");
    }

    #[test]
    fn test_task_state_is_finished() {
        for terminal in [TaskState::Completed, TaskState::Cancelled, TaskState::Panicked] {
            assert!(terminal.is_finished());
        }
        for active in [TaskState::Running, TaskState::Waiting] {
            assert!(!active.is_finished());
        }
    }

    #[test]
    fn test_task_state_from_u8_roundtrip() {
        let every_state = [
            TaskState::Running,
            TaskState::Waiting,
            TaskState::Completed,
            TaskState::Cancelled,
            TaskState::Panicked,
        ];
        assert!(
            every_state
                .iter()
                .all(|&state| TaskState::from_u8(state as u8) == Some(state))
        );
        assert!(TaskState::from_u8(255).is_none());
    }
}