use crate::runtime::execution::ExecutionState;
use crate::runtime::task::TaskId;
use crate::runtime::thread;
use std::future::Future;
use std::pin::Pin;
use std::result::Result;
use std::task::{Context, Poll};
/// Spawn `fut` as a new asynchronous task and return a handle to it.
///
/// The future is wrapped in a [`Wrapper`] that stores its output in a slot shared with the
/// returned [`JoinHandle`]. The wrapper is registered through `ExecutionState::spawn_future`
/// using the stack size from the execution's configuration, and `thread::switch()` is called
/// before the handle is returned.
///
/// Awaiting the returned handle yields `Ok(output)` once the task has published its result.
pub fn spawn<T, F>(fut: F) -> JoinHandle<T>
where
    F: Future<Output = T> + Send + 'static,
    T: Send + 'static,
{
    // The standard-library `Arc`/`Mutex`, scoped to this function only.
    use std::sync::{Arc, Mutex};

    // Slot shared between the task (writer) and the join handle (reader); empty until the
    // task completes.
    let result = Arc::new(Mutex::new(None));

    let stack_size = ExecutionState::with(|state| state.config.stack_size);
    let wrapped = Wrapper::new(fut, Arc::clone(&result));
    let task_id = ExecutionState::spawn_future(wrapped, stack_size, None);

    // NOTE(review): presumably a scheduling point that lets the new task run first — confirm
    // against `thread::switch`.
    thread::switch();

    JoinHandle { task_id, result }
}
/// Handle to a task created by [`spawn`].
///
/// The handle is itself a `Future`: polling it yields the task's `Result<T, JoinError>` once
/// the task has stored its output in the shared `result` slot. Dropping the handle calls
/// [`JoinHandle::abort`].
#[derive(Debug)]
pub struct JoinHandle<T> {
/// Identifier of the spawned task, as returned by `ExecutionState::spawn_future`.
task_id: TaskId,
/// Slot shared with the task's wrapper future; `None` until the task publishes its result,
/// and `None` again after the result has been taken by a poll.
result: std::sync::Arc<std::sync::Mutex<Option<Result<T, JoinError>>>>,
}
impl<T> JoinHandle<T> {
    /// Abort the task associated with this handle.
    ///
    /// As implemented here, aborting detaches the task: if the execution has not finished,
    /// the task identified by this handle has `detach()` called on it. When the execution is
    /// already finished this is a no-op.
    pub fn abort(&self) {
        // NOTE(review): `try_with` presumably skips the closure when no execution state is
        // reachable (e.g. during teardown) — confirm against `ExecutionState::try_with`.
        ExecutionState::try_with(|state| {
            if state.is_finished() {
                // Nothing left to detach once the whole execution is over.
                return;
            }
            state.get_mut(self.task_id).detach();
        });
    }
}
/// Error returned when awaiting a [`JoinHandle`] does not produce the task's output.
///
/// Implements [`std::fmt::Display`] and [`std::error::Error`] so it can be propagated with `?`
/// into `Box<dyn Error>`-style results and printed in user-facing messages.
#[derive(Debug)]
pub enum JoinError {
    /// The task was cancelled.
    ///
    /// NOTE(review): presumably reported when a task is aborted before it completes; nothing
    /// visible alongside this definition constructs the variant yet — confirm.
    Cancelled,
}

impl std::fmt::Display for JoinError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Exhaustive on purpose: adding a variant must force a message to be written for it.
        match self {
            JoinError::Cancelled => f.write_str("task was cancelled"),
        }
    }
}

impl std::error::Error for JoinError {}
/// Dropping a handle aborts its task, exactly as an explicit call to
/// [`JoinHandle::abort`] would.
impl<T> Drop for JoinHandle<T> {
    fn drop(&mut self) {
        Self::abort(self);
    }
}
impl<T> Future for JoinHandle<T> {
    type Output = Result<T, JoinError>;

    /// Return the task's result if it has been published, otherwise register the current
    /// task as the waiter on the spawned task and return `Poll::Pending`.
    ///
    /// The supplied `Context` is not used: wake-up happens through the waiter registered with
    /// `set_waiter`, which the spawned task's wrapper unblocks after storing its result.
    ///
    /// # Panics
    ///
    /// Panics if the result mutex is poisoned, or if `set_waiter` reports failure while no
    /// result is present.
    fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Self::Output> {
        // Take the value in its own statement so the `MutexGuard` temporary is dropped at the
        // end of this line. Written as an `if let` scrutinee, the guard would stay alive
        // through the `else` branch (editions up to 2021), holding the lock across the
        // scheduler call below and poisoning the mutex if the assertion fails.
        let published = self.result.lock().unwrap().take();

        if let Some(result) = published {
            return Poll::Ready(result);
        }

        ExecutionState::with(|state| {
            let me = state.current().id();
            let registered = state.get_mut(self.task_id).set_waiter(me);
            assert!(registered, "task shouldn't be finished if no result is present");
        });
        Poll::Pending
    }
}
/// Adapter that turns a future producing `T` into a future producing `()`.
///
/// When the inner future completes, its output is stored as `Some(Ok(output))` in `result`,
/// the slot shared with the corresponding `JoinHandle`.
struct Wrapper<T, F> {
/// The user's future, boxed and pinned so it can be polled through `as_mut()`.
future: Pin<Box<F>>,
/// Shared slot the output is written to on completion; `None` until then.
result: std::sync::Arc<std::sync::Mutex<Option<Result<T, JoinError>>>>,
}
impl<T, F> Wrapper<T, F>
where
F: Future<Output = T> + Send + 'static,
{
fn new(future: F, result: std::sync::Arc<std::sync::Mutex<Option<Result<T, JoinError>>>>) -> Self {
Self {
future: Box::pin(future),
result,
}
}
}
impl<T, F> Future for Wrapper<T, F>
where
    F: Future<Output = T> + Send + 'static,
    T: Send + 'static,
{
    type Output = ();

    /// Poll the inner future; on completion, tear down the task's locals, publish the output
    /// to the shared result slot, and unblock the waiting task if there is one.
    ///
    /// # Panics
    ///
    /// Panics if the result mutex is poisoned.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let output = match self.future.as_mut().poll(cx) {
            Poll::Ready(output) => output,
            Poll::Pending => return Poll::Pending,
        };

        // Drain the current task's locals one at a time, before the result is published.
        // Each value is dropped only after `ExecutionState::with` has returned.
        // NOTE(review): presumably because a destructor may itself need the execution
        // state — see `pop_local` to confirm.
        loop {
            let next = ExecutionState::with(|state| state.current_mut().pop_local());
            match next {
                Some(local) => drop(local),
                None => break,
            }
        }

        *self.result.lock().unwrap() = Some(Ok(output));

        // Wake the task registered as our waiter, unless it has already finished.
        ExecutionState::with(|state| {
            let waiter = match state.current_mut().take_waiter() {
                Some(id) => id,
                None => return,
            };
            let task = state.get_mut(waiter);
            if !task.finished() {
                task.unblock();
            }
        });

        Poll::Ready(())
    }
}
/// Drive `future` to completion from the current task and return its output.
///
/// The future is polled with the current task's waker. `thread::switch()` is called before
/// every poll. Whenever a poll returns `Poll::Pending`, the current task calls
/// `sleep_unless_woken()` before the next switch.
pub fn block_on<F: Future>(future: F) -> F::Output {
    let mut pinned = Box::pin(future);
    let waker = ExecutionState::with(|state| state.current_mut().waker());
    let mut cx = Context::from_waker(&waker);

    loop {
        thread::switch();

        if let Poll::Ready(output) = pinned.as_mut().poll(&mut cx) {
            return output;
        }

        // Not ready yet.
        // NOTE(review): presumably this blocks the task only if no wake-up arrived during
        // the poll — confirm against `sleep_unless_woken`.
        ExecutionState::with(|state| state.current_mut().sleep_unless_woken());
    }
}
/// Yield once: the returned future is `Pending` on its first poll and `Ready` on every later
/// poll.
///
/// On the first poll it wakes its own waker, so the task will be polled again, and calls
/// `ExecutionState::request_yield()`.
pub async fn yield_now() {
    /// Single-shot future backing `yield_now`.
    struct YieldOnce {
        /// Set on the first poll; later polls complete immediately.
        yielded: bool,
    }

    impl Future for YieldOnce {
        type Output = ();

        fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
            // Mark as yielded and learn whether this is a repeat poll in one step.
            let already_yielded = std::mem::replace(&mut self.yielded, true);
            if already_yielded {
                Poll::Ready(())
            } else {
                cx.waker().wake_by_ref();
                ExecutionState::request_yield();
                Poll::Pending
            }
        }
    }

    YieldOnce { yielded: false }.await
}