use crate::channel::{from_receiver, Builder, Receiver, Sender};
use crate::loc::WakeMsg;
use crate::msg::Message;
use crate::runtime::execution::ExecutionState;
use crate::runtime::task::TaskId;
use crate::runtime::thread::{self, switch};
use crate::thread::Thread;
use crate::CommunicationModel::LocalOrder;
use crate::TJoin;
use std::error::Error;
use std::fmt::{Display, Formatter};
use std::future::Future;
use std::pin::Pin;
use std::result::Result;
use std::task::{Context, Poll, Waker};
/// Lets a `Sender<WakeMsg>` serve as a [`Waker`]: every wake-up sends one
/// `WakeMsg` on the channel, which the owning poll loop blocks on.
impl std::task::Wake for Sender<WakeMsg> {
    fn wake(self: std::sync::Arc<Self>) {
        <Self as std::task::Wake>::wake_by_ref(&self);
    }

    /// Sends the wake-up message directly. The default `wake_by_ref` would
    /// first clone the `Arc` just to call `wake`.
    fn wake_by_ref(self: &std::sync::Arc<Self>) {
        self.send_msg(WakeMsg);
    }
}
/// Builds a pair of cross-connected [`TwoWayCom`] endpoints.
///
/// Two `LocalOrder` channels are created. Each endpoint sends on one of them
/// and receives on the other, so whatever one side sends, the peer receives.
fn get_bidir_handles() -> (TwoWayCom, TwoWayCom) {
    let (to_second, from_first) = Builder::new().with_comm(LocalOrder).build();
    let (to_first, from_second) = Builder::new().with_comm(LocalOrder).build();
    let first = TwoWayCom {
        sender: to_second,
        receiver: from_second,
    };
    let second = TwoWayCom {
        sender: to_first,
        receiver: from_first,
    };
    (first, second)
}
/// Spawns `fut` as a new non-daemon task with the default thread name.
///
/// Shorthand for `spawn_with_attributes(false, None, fut)`.
pub fn spawn<T, F>(fut: F) -> JoinHandle<T>
where
    T: Message + 'static,
    F: Future<Output = T> + Send + 'static,
{
    let is_daemon = false;
    spawn_with_attributes(is_daemon, None, fut)
}
pub fn spawn_with_attributes<T, F>(is_daemon: bool, name: Option<String>, fut: F) -> JoinHandle<T>
where
F: Future<Output = T> + Send + 'static,
T: Message + 'static,
{
thread::switch();
let stack_size = ExecutionState::with(|s| s.must.borrow().config.stack_size);
let (fut_handles, join_handles) = get_bidir_handles();
let task_id = ExecutionState::spawn_thread(
move || {
let (sender, fut_recv) = Builder::<WakeMsg>::new().build();
let fut_waker = Waker::from(std::sync::Arc::new(sender.clone()));
let mut fut = Box::pin(fut);
let mut res = fut.as_mut().poll(&mut Context::from_waker(&fut_waker));
let mut join_waker: Option<Waker> = None;
let res = loop {
match res {
Poll::Ready(res) => {
if let Some(waker) = join_waker {
waker.wake();
}
break Some(res);
}
Poll::Pending => { }
}
let (msg, ind) = crate::select_val_block(&fut_handles.receiver, &fut_recv);
if ind == 0 {
match msg.as_any().downcast::<PollerMsg>() {
Ok(waker) => match *waker {
PollerMsg::Waker(waker) => {
assert!(ind == 0);
join_waker = Some(waker.clone());
fut_handles.sender.send_msg(PollerMsg::Pending);
}
PollerMsg::Cancel => break None,
_ => unreachable!(),
},
_ => unreachable!(),
}
} else {
assert!(ind == 1);
assert!(msg.as_any().downcast::<WakeMsg>().is_ok());
res = fut.as_mut().poll(&mut Context::from_waker(&fut_waker));
}
};
let val = match res {
Some(result) => {
match fut_handles.receiver.recv_msg_block() {
PollerMsg::Waker(_) => {
fut_handles.sender.send_msg(PollerMsg::Ready);
crate::Val::new(result)
}
PollerMsg::Cancel => {
drop(fut);
crate::Val::new(())
}
_ => unreachable!(),
}
}
None => {
drop(fut);
crate::Val::new(())
}
};
fut_handles.sender.send_msg(PollerMsg::Done);
ExecutionState::with(|state| {
let pos = state.next_pos();
state
.must
.borrow_mut()
.handle_tend(crate::End::new(pos, val));
crate::must::Must::unstuck_joiners(state, pos.thread);
});
},
stack_size,
None,
);
let (thread_id, name) = ExecutionState::with(|state| {
let pos = state.next_pos();
let tid = state.must.borrow().next_thread_id(&pos);
let name = match name {
None => format!("<future-{}>", tid.to_number()),
Some(x) => x,
};
state.must.borrow_mut().handle_tcreate(
tid,
task_id,
None,
pos,
Some(name.clone()),
is_daemon,
);
(tid, Some(name))
});
let thread = Thread {
id: thread_id,
name,
};
thread::switch();
JoinHandle {
task_id,
thread,
com: join_handles,
_p: std::marker::PhantomData,
}
}
/// Spawns a helper thread that waits for one message on `recv` and exposes it
/// through the returned [`JoinHandle`].
///
/// The helper thread listens on a clone of `recv` and on the handle's request
/// channel:
///
/// * `Waker` requests that arrive before a message are answered with
///   `Pending`, and the latest waker is remembered.
/// * When a message arrives, the remembered waker is woken, and the next
///   request decides the outcome:
///   * `Waker` is answered with `Ready`, and the message becomes the thread's
///     end value.
///   * `Cancel` re-sends the message on the same channel (via
///     `from_receiver`).
/// * A `Cancel` before any message arrives simply ends the thread.
///
/// The thread always finishes by sending `Done` and recording its end. Its
/// name is `"<async_recv-N>"`, and it is never a daemon.
pub(crate) fn spawn_receive<T>(recv: &Receiver<T>) -> JoinHandle<T>
where
    T: Message + Clone + 'static,
{
    thread::switch();
    let stack_size = ExecutionState::with(|s| s.must.borrow().config.stack_size);
    let (fut_handles, join_handles) = get_bidir_handles();
    let recv = recv.clone();
    let task_id = ExecutionState::spawn_thread(
        move || {
            // Waker of the most recent poller that was told `Pending`.
            let mut join_waker: Option<Waker> = None;
            let res = loop {
                // Index 0: request from the JoinHandle; index 1: message on `recv`.
                let (msg, ind) = crate::select_val_block(&fut_handles.receiver, &recv);
                if ind == 0 {
                    match msg.as_any().downcast::<PollerMsg>() {
                        Ok(request) => match *request {
                            PollerMsg::Waker(waker) => {
                                join_waker = Some(waker);
                                fut_handles.sender.send_msg(PollerMsg::Pending);
                            }
                            PollerMsg::Cancel => break None,
                            _ => unreachable!(),
                        },
                        _ => unreachable!(),
                    }
                } else {
                    assert!(ind == 1);
                    match msg.as_any().downcast::<T>() {
                        Ok(result) => {
                            // Tell a waiting poller that it should poll again.
                            if let Some(waker) = join_waker {
                                waker.wake();
                            }
                            break Some(*result);
                        }
                        _ => unreachable!(),
                    }
                }
            };
            let val = match res {
                Some(result) => match fut_handles.receiver.recv_msg_block() {
                    // The poller asks again: hand over the message.
                    PollerMsg::Waker(_) => {
                        fut_handles.sender.send_msg(PollerMsg::Ready);
                        crate::Val::new(result)
                    }
                    // Cancelled after the message was taken: send it back on the channel.
                    PollerMsg::Cancel => {
                        from_receiver(recv).send_msg(result);
                        crate::Val::new(())
                    }
                    _ => unreachable!(),
                },
                None => crate::Val::new(()),
            };
            // Final acknowledgement; `JoinHandle::abort` blocks until it arrives.
            fut_handles.sender.send_msg(PollerMsg::Done);
            ExecutionState::with(|state| {
                let pos = state.next_pos();
                state
                    .must
                    .borrow_mut()
                    .handle_tend(crate::End::new(pos, val));
                crate::must::Must::unstuck_joiners(state, pos.thread);
            });
        },
        stack_size,
        None,
    );
    // Register the new thread with the model checker.
    let (thread_id, name) = ExecutionState::with(|state| {
        let pos = state.next_pos();
        let tid = state.must.borrow().next_thread_id(&pos);
        let name = format!("<async_recv-{}>", tid.to_number());
        state.must.borrow_mut().handle_tcreate(
            tid,
            task_id,
            None,
            pos,
            Some(name.clone()),
            false,
        );
        (tid, Some(name))
    });
    let thread = Thread {
        id: thread_id,
        name,
    };
    thread::switch();
    JoinHandle {
        task_id,
        thread,
        com: join_handles,
        _p: std::marker::PhantomData,
    }
}
/// Handle to a task created by [`spawn`], [`spawn_with_attributes`] or
/// `spawn_receive`.
///
/// Polling it as a [`Future`] yields the task's output. Dropping it (outside
/// a panic) calls `abort`.
#[derive(Debug)]
pub struct JoinHandle<T> {
    /// Scheduler task that runs the spawned thread.
    task_id: TaskId,
    /// Model-checker thread metadata (id and name).
    thread: Thread,
    /// This side of the request/response channel pair shared with the task.
    com: TwoWayCom,
    /// Records the output type `T` without storing a value.
    _p: std::marker::PhantomData<T>,
}
/// Messages exchanged between a [`JoinHandle`] and the thread driving its task.
///
/// The handle sends `Waker` (a poll request) or `Cancel`. The task thread
/// replies with `Pending` or `Ready`, and sends `Done` right before it ends.
#[derive(Clone, Debug)]
pub enum PollerMsg {
    /// Poll request carrying the poller's waker.
    Waker(Waker),
    /// The task has not completed yet.
    Pending,
    /// Request to abort the task.
    Cancel,
    /// The task thread is about to finish.
    Done,
    /// The task has completed; its output can now be joined.
    Ready,
}

/// Equality compares variants only; the `Waker` payload is ignored.
impl PartialEq for PollerMsg {
    fn eq(&self, other: &Self) -> bool {
        std::mem::discriminant(self) == std::mem::discriminant(other)
    }
}
/// One endpoint of a bidirectional pair of [`PollerMsg`] channels.
#[derive(Clone, Debug)]
pub struct TwoWayCom {
    /// Channel on which this endpoint sends to its peer.
    pub sender: Sender<PollerMsg>,
    /// Channel on which this endpoint receives from its peer.
    pub receiver: Receiver<PollerMsg>,
}
impl<T> JoinHandle<T> {
    /// Returns whether the scheduler reports the backing task as finished.
    pub fn is_finished(&self) -> bool {
        ExecutionState::with(|state| state.get(self.task_id).finished())
    }

    /// Returns the model-checker thread that runs this task.
    pub fn thread(&self) -> &Thread {
        &self.thread
    }

    /// Sends `Cancel` to the task and blocks until it acknowledges with `Done`.
    ///
    /// Does nothing when the execution state reports it is not running.
    ///
    /// NOTE(review): the task thread sends `Done` only once. A second call
    /// (e.g. an explicit `abort` followed by the one in `Drop`) therefore
    /// appears to block forever in `recv_msg_block`. Confirm.
    pub fn abort(&self) {
        let running = ExecutionState::with(|state| state.is_running());
        if !running {
            return;
        }
        self.com.sender.send_msg(PollerMsg::Cancel);
        let ack = self.com.receiver.recv_msg_block();
        assert!(matches!(ack, PollerMsg::Done));
    }
}
/// Error type in the output of an awaited [`JoinHandle`].
#[derive(Debug)]
pub enum JoinError {
    /// The task was cancelled.
    Cancelled,
}

impl Display for JoinError {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let text = match self {
            Self::Cancelled => "task was cancelled",
        };
        f.write_str(text)
    }
}

impl Error for JoinError {}
/// Dropping a handle aborts its task, unless the current thread is already
/// panicking, in which case nothing is done.
impl<T> Drop for JoinHandle<T> {
    fn drop(&mut self) {
        if !std::thread::panicking() {
            self.abort();
        }
    }
}
/// Polling a `JoinHandle` sends the current waker to the task thread and
/// blocks for its reply.
///
/// * On `Pending`, the task thread keeps the waker and wakes it once the task
///   completes.
/// * On `Ready`, the handle joins the task's thread through the model checker
///   (`handle_tjoin`) and returns the thread's end value downcast to `T`.
///
/// `JoinError::Cancelled` is never produced by this implementation.
impl<T: Message + 'static> Future for JoinHandle<T> {
    type Output = Result<T, JoinError>;
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // Poll request: hand the current waker to the task thread.
        self.com
            .sender
            .send_msg(PollerMsg::Waker(cx.waker().clone()));
        match self.com.receiver.recv_msg_block() {
            PollerMsg::Ready => {
                // The task has a result: join its thread to retrieve the end value.
                loop {
                    switch();
                    let val = ExecutionState::with(|s| {
                        let target_task_id = s.get(self.task_id).id();
                        let target_id = s.must.borrow().to_thread_id(target_task_id);
                        let pos = s.next_pos();
                        s.must.borrow_mut().handle_tjoin(TJoin::new(pos, target_id))
                    });
                    if let Some(val) = val {
                        if val.is_pending() {
                            // Target has not ended yet: mark the current task as stuck.
                            ExecutionState::with(|s| s.current_mut().stuck());
                        } else {
                            return Poll::Ready(Ok(*val.as_any().downcast().unwrap()));
                        }
                    }
                    // NOTE(review): presumably rewinds the position taken by `next_pos`
                    // so the join is retried at the same position. Confirm.
                    ExecutionState::with(|s| s.prev_pos());
                }
            }
            PollerMsg::Pending => Poll::Pending,
            _ => unreachable!(),
        }
    }
}
/// Runs `future` to completion on the current model-checked thread.
///
/// The future's waker sends a `WakeMsg` on a private channel. Whenever the
/// future returns `Pending`, this function blocks on that channel before
/// polling again. `thread::switch()` is called once before the first poll and
/// again after every poll that returns `Pending`.
pub fn block_on<F: Future>(future: F) -> F::Output {
    let mut pinned = Box::pin(future);
    let (wake_tx, wake_rx) = Builder::<WakeMsg>::new().build();
    let waker = Waker::from(std::sync::Arc::new(wake_tx.clone()));
    let mut cx = Context::from_waker(&waker);
    thread::switch();
    loop {
        if let Poll::Ready(output) = pinned.as_mut().poll(&mut cx) {
            return output;
        }
        wake_rx.recv_msg_block();
        thread::switch();
    }
}
#[cfg(test)]
mod test {
    use super::block_on;
    use crate::{recv_msg_block, send_msg, thread, verify, Config};

    /// A spawned future echoes a value back to its parent via message passing.
    /// Its handle is then joined with `block_on` and must yield `3`.
    #[test]
    fn test_thread() {
        verify(Config::builder().build(), || {
            let parent_id = thread::current().id();
            let handle = crate::future::spawn(async move {
                let received: i32 = recv_msg_block();
                send_msg(parent_id, received);
                3
            });
            let fut_tid = handle.thread().id();
            println!("Future's thread id is {}", handle.thread().id());
            send_msg(fut_tid, 4);
            let echoed: i32 = recv_msg_block();
            assert_eq!(echoed, 4);
            let res = block_on(handle);
            println!("Retrieved {:?} from future", &res);
            assert_eq!(res.unwrap(), 3);
        });
    }
}