use crate::runtime::execution::ExecutionState;
use crate::runtime::task::TaskId;
use crate::runtime::thread;
use std::error::Error;
use std::fmt::{Display, Formatter};
use std::future::Future;
use std::pin::Pin;
use std::result::Result;
use std::sync::Arc;
use std::task::{Context, Poll, Waker};
pub mod batch_semaphore;
/// Shared implementation of [`spawn`] and [`spawn_local`].
///
/// Wraps `fut` so its output is published into a result slot shared with the returned
/// [`JoinHandle`]. It then registers the wrapped future with the execution, using the
/// configured stack size, and calls into the scheduler via `thread::switch`.
fn spawn_inner<F>(fut: F) -> JoinHandle<F::Output>
where
    F: Future + 'static,
    F::Output: 'static,
{
    let stack_size = ExecutionState::with(|state| state.config.stack_size);
    let shared = Arc::new(std::sync::Mutex::new(JoinHandleInner::default()));
    let wrapped = Wrapper::new(fut, Arc::clone(&shared));
    let task_id = ExecutionState::spawn_future(wrapped, stack_size, None);
    // Switch right after spawning, presumably so the scheduler may pick the new task
    // immediately — confirm against `thread::switch`.
    thread::switch();
    JoinHandle {
        task_id,
        inner: shared,
    }
}
/// Spawns `fut` as a new task and returns a [`JoinHandle`] to its output.
///
/// Both the future and its output must be `Send`. This presumably mirrors a
/// multi-threaded runtime's `spawn` signature. The actual work is done by
/// `spawn_inner`, which has no `Send` requirement.
pub fn spawn<F>(fut: F) -> JoinHandle<F::Output>
where
    F: Future + Send + 'static,
    F::Output: Send + 'static,
{
    spawn_inner::<F>(fut)
}
/// Spawns `fut`, which need not be `Send`, as a new task and returns a
/// [`JoinHandle`] to its output.
///
/// This shares its implementation with [`spawn`]. The only difference is the
/// absence of `Send` bounds.
pub fn spawn_local<F>(fut: F) -> JoinHandle<F::Output>
where
    F: Future + 'static,
    F::Output: 'static,
{
    spawn_inner::<F>(fut)
}
/// A cloneable handle that can abort, or query the completion of, a spawned task
/// without owning its result.
#[derive(Debug, Clone)]
pub struct AbortHandle {
    /// Identifier of the task this handle refers to.
    task_id: TaskId,
}

impl AbortHandle {
    /// Aborts the referenced task.
    ///
    /// This uses `ExecutionState::try_with`, so it is presumably a no-op when called outside an
    /// execution. It also does nothing once the execution itself reports that it has finished.
    pub fn abort(&self) {
        ExecutionState::try_with(|state| {
            // Nothing to do once the execution as a whole has finished.
            if state.is_finished() {
                return;
            }
            state.get_mut(self.task_id).abort();
        });
    }

    /// Returns whether the referenced task has finished.
    ///
    /// This uses `ExecutionState::with` rather than `try_with`. It presumably panics when called
    /// outside a running execution — confirm.
    pub fn is_finished(&self) -> bool {
        ExecutionState::with(|state| state.get(self.task_id).finished())
    }
}

// SAFETY: `AbortHandle` holds only a `TaskId`, and every operation on it goes
// through `ExecutionState`.
// NOTE(review): if `TaskId` is already `Send + Sync`, these manual impls are redundant — confirm.
unsafe impl Send for AbortHandle {}
unsafe impl Sync for AbortHandle {}
#[derive(Debug)]
pub struct JoinHandle<T> {
task_id: TaskId,
inner: std::sync::Arc<std::sync::Mutex<JoinHandleInner<T>>>,
}
/// State shared between a [`JoinHandle`] and the wrapper future that produces its result.
#[derive(Debug)]
struct JoinHandleInner<T> {
    /// The task's outcome once published. The first poll of the handle that sees it takes it.
    result: Option<Result<T, JoinError>>,
    /// Waker stored by the most recent poll of the handle that returned `Pending`.
    waker: Option<Waker>,
}

// Written by hand instead of `#[derive(Default)]`, because deriving would add a
// `T: Default` bound that the result type does not need.
impl<T> Default for JoinHandleInner<T> {
    fn default() -> Self {
        Self {
            waker: None,
            result: None,
        }
    }
}
impl<T> JoinHandle<T> {
    /// Aborts the task associated with this handle.
    ///
    /// This delegates to [`AbortHandle::abort`] so the two abort paths cannot drift apart.
    /// That method uses `ExecutionState::try_with` and skips the abort once the execution
    /// has finished.
    pub fn abort(&self) {
        self.abort_handle().abort();
    }

    /// Returns whether the task associated with this handle has finished.
    ///
    /// This delegates to [`AbortHandle::is_finished`]. That method uses `ExecutionState::with`,
    /// so it presumably panics outside a running execution — confirm.
    pub fn is_finished(&self) -> bool {
        self.abort_handle().is_finished()
    }

    /// Returns an [`AbortHandle`] for this task.
    ///
    /// The handle can abort the task or query its status without owning the result.
    pub fn abort_handle(&self) -> AbortHandle {
        AbortHandle {
            task_id: self.task_id,
        }
    }
}
/// Error reported for a task that did not run to completion.
#[derive(Debug)]
pub enum JoinError {
    /// The task was cancelled.
    Cancelled,
}

impl Display for JoinError {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            JoinError::Cancelled => "task was cancelled",
        };
        f.write_str(message)
    }
}

impl Error for JoinError {}
impl<T> Drop for JoinHandle<T> {
    /// Dropping the handle aborts the associated task via [`JoinHandle::abort`].
    ///
    /// NOTE(review): tokio's `JoinHandle` detaches the task on drop rather than cancelling it.
    /// Confirm that `Task::abort` has the intended semantics for a dropped handle whose task
    /// is still running.
    fn drop(&mut self) {
        Self::abort(self);
    }
}
impl<T> Future for JoinHandle<T> {
    type Output = Result<T, JoinError>;

    /// Resolves to the task's result once the wrapper has published it.
    ///
    /// Until then, it stores the caller's waker so the wrapper can wake it on completion.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let mut guard = self.inner.lock().unwrap();
        match guard.result.take() {
            Some(outcome) => Poll::Ready(outcome),
            None => {
                guard.waker = Some(cx.waker().clone());
                Poll::Pending
            }
        }
    }
}
struct Wrapper<F: Future> {
future: Pin<Box<F>>,
inner: std::sync::Arc<std::sync::Mutex<JoinHandleInner<F::Output>>>,
}
impl<F> Wrapper<F>
where
F: Future + 'static,
F::Output: 'static,
{
fn new(future: F, inner: std::sync::Arc<std::sync::Mutex<JoinHandleInner<F::Output>>>) -> Self {
Self {
future: Box::pin(future),
inner,
}
}
}
impl<F> Future for Wrapper<F>
where
    F: Future + 'static,
    F::Output: 'static,
{
    type Output = ();

    /// Polls the wrapped future. When it completes, this publishes the result to the
    /// shared slot and wakes any task waiting on the `JoinHandle`.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let result = match self.future.as_mut().poll(cx) {
            Poll::Ready(result) => result,
            Poll::Pending => return Poll::Pending,
        };

        // If there is no execution state, or the execution has already finished, skip
        // all cleanup and publication.
        if ExecutionState::try_with(|state| state.is_finished()).unwrap_or(true) {
            return Poll::Ready(());
        }

        // Drain and drop every entry that `pop_local` returns for the current task
        // before publishing the result.
        while let Some(local) = ExecutionState::with(|state| state.current_mut().pop_local()) {
            drop(local);
        }

        // Publish the result and take the waker in one critical section. Then release the
        // guard BEFORE waking. This way a waker that re-enters (for example by polling the
        // `JoinHandle`) or switches to another task cannot deadlock on this mutex.
        let waker = {
            let mut slot = self.inner.lock().unwrap();
            slot.result = Some(Ok(result));
            slot.waker.take()
        };
        if let Some(waker) = waker {
            waker.wake();
        }

        Poll::Ready(())
    }
}
/// Runs `future` to completion on the current task and returns its output.
///
/// The future is polled with the current task's waker. After each `Pending` poll, the task
/// calls `sleep_unless_woken` and then switches via `thread::switch` before polling again.
pub fn block_on<F: Future>(future: F) -> F::Output {
    let mut pinned = Box::pin(future);
    let waker = ExecutionState::with(|state| state.current_mut().waker());
    let mut cx = Context::from_waker(&waker);

    loop {
        if let Poll::Ready(output) = pinned.as_mut().poll(&mut cx) {
            return output;
        }
        ExecutionState::with(|state| state.current_mut().sleep_unless_woken());
        thread::switch();
    }
}
/// Yields once, then completes.
///
/// On the first poll this wakes itself, requests a yield from the execution, and returns
/// `Pending`. The next poll completes.
pub async fn yield_now() {
    struct YieldOnce {
        polled_once: bool,
    }

    impl Future for YieldOnce {
        type Output = ();

        fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
            if !self.polled_once {
                self.polled_once = true;
                // Wake ourselves first so the task stays runnable after yielding.
                cx.waker().wake_by_ref();
                ExecutionState::request_yield();
                Poll::Pending
            } else {
                Poll::Ready(())
            }
        }
    }

    YieldOnce { polled_once: false }.await
}