use core::task::Poll;
/// Reference-counted, interior-mutable handle (`Rc<RefCell<_>>`) to one `SharedState`.
///
/// Cloning the handle produces another reference to the same underlying state.
/// Method names carry a `protocol_` or `executor_` prefix, which presumably
/// indicates the side expected to call them (confirm against the callers).
pub struct SharedStateRef<M>(alloc::rc::Rc<core::cell::RefCell<SharedState<M>>>);
/// State behind a `SharedStateRef`: two single-message slots and three flags.
struct SharedState<M> {
/// Message stored by `executor_received_msg` and not yet taken by
/// `protocol_needs_one_more_msg`.
incoming_msg: Option<crate::Incoming<M>>,
/// Message stored by `protocol_saves_msg_to_be_sent` and not yet taken by
/// `executor_takes_outgoing_msg`.
outgoing_msg: Option<crate::Outgoing<M>>,
/// Set by `protocol_needs_one_more_msg` when no incoming message is stored;
/// cleared by `executor_received_msg`.
wants_recv_msg: bool,
/// Set by `protocol_flushes_outgoing_msg` when an outgoing message is stored;
/// cleared by `executor_takes_outgoing_msg`.
wants_send_msg: bool,
/// Set by `protocol_yields`; read and cleared by
/// `executor_reads_and_resets_yielded_flag`.
yielded: bool,
}
impl<M> SharedState<M> {
    /// Builds the initial state: both message slots are empty and none of the
    /// `wants_recv_msg` / `wants_send_msg` / `yielded` flags is raised.
    pub fn new() -> Self {
        Self {
            // No request is pending on a fresh state.
            yielded: false,
            wants_send_msg: false,
            wants_recv_msg: false,
            // Nothing has been received or queued for sending yet.
            outgoing_msg: None,
            incoming_msg: None,
        }
    }
}
impl<M> SharedStateRef<M> {
    /// Creates a handle to a freshly initialised state (see `SharedState::new`).
    pub fn new() -> Self {
        let state = core::cell::RefCell::new(SharedState::new());
        Self(alloc::rc::Rc::new(state))
    }

    /// Checks whether a new operation may be started on this state.
    ///
    /// Returns `Poll::Ready` carrying a [`CanSchedule`] token when none of the
    /// `wants_recv_msg`, `wants_send_msg` and `yielded` flags is set, and
    /// `Poll::Pending` otherwise.
    pub fn can_schedule(&self) -> core::task::Poll<CanSchedule<&Self>> {
        let state = self.0.borrow();
        let busy = state.wants_recv_msg || state.wants_send_msg || state.yielded;
        if busy {
            return core::task::Poll::Pending;
        }
        core::task::Poll::Ready(CanSchedule(self))
    }

    /// Stores `msg` in the outgoing slot.
    ///
    /// # Errors
    /// Hands `msg` back in `Err` when the slot already holds a message; the
    /// stored message is left untouched in that case.
    pub fn protocol_saves_msg_to_be_sent(
        &self,
        msg: crate::Outgoing<M>,
    ) -> Result<(), crate::Outgoing<M>> {
        let mut state = self.0.borrow_mut();
        if state.outgoing_msg.is_none() {
            state.outgoing_msg = Some(msg);
            Ok(())
        } else {
            Err(msg)
        }
    }

    /// Takes the outgoing message, but only if a flush was requested
    /// (`wants_send_msg` is set); the flag is cleared in the process.
    ///
    /// Returns `None` when no flush is pending, even if a message is stored.
    pub fn executor_takes_outgoing_msg(&self) -> Option<crate::Outgoing<M>> {
        let mut state = self.0.borrow_mut();
        if !state.wants_send_msg {
            return None;
        }
        // The flag is only raised while a message is stored.
        debug_assert!(state.outgoing_msg.is_some());
        state.wants_send_msg = false;
        state.outgoing_msg.take()
    }

    /// Reports the current value of the `wants_recv_msg` flag.
    pub fn protocol_wants_more_messages(&self) -> bool {
        self.0.borrow().wants_recv_msg
    }

    /// Returns the current value of the `yielded` flag and clears it.
    pub fn executor_reads_and_resets_yielded_flag(&self) -> bool {
        let mut state = self.0.borrow_mut();
        core::mem::replace(&mut state.yielded, false)
    }

    /// Stores `msg` in the incoming slot and clears `wants_recv_msg`.
    ///
    /// # Errors
    /// Hands `msg` back in `Err` when the slot already holds a message; neither
    /// the slot nor the flag is modified in that case.
    pub fn executor_received_msg(&self, msg: crate::Incoming<M>) -> Result<(), crate::Incoming<M>> {
        let mut state = self.0.borrow_mut();
        if state.incoming_msg.is_none() {
            state.incoming_msg = Some(msg);
            state.wants_recv_msg = false;
            Ok(())
        } else {
            Err(msg)
        }
    }
}
impl<M> Clone for SharedStateRef<M> {
    /// Returns another handle to the same underlying state; only the `Rc`
    /// is cloned, so no `M: Clone` bound is needed.
    fn clone(&self) -> Self {
        let shared = alloc::rc::Rc::clone(&self.0);
        Self(shared)
    }
}
impl<M> core::fmt::Debug for SharedState<M> {
    /// Formats the state without requiring `M: Debug`: the two message slots
    /// are reported only as "present" booleans, followed by the three flags.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.debug_struct("SharedState")
            .field("incoming_msg_present", &self.incoming_msg.is_some())
            .field("outgoing_msg_present", &self.outgoing_msg.is_some())
            // Each flag is reported exactly once, under its own name.
            .field("wants_recv_msg", &self.wants_recv_msg)
            .field("wants_send_msg", &self.wants_send_msg)
            .field("yielded", &self.yielded)
            .finish()
    }
}
impl<M> core::fmt::Debug for SharedStateRef<M> {
    /// Delegates to the `Debug` implementation of the wrapped
    /// `Rc<RefCell<SharedState<M>>>`.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // Fully-qualified call: method syntax (`self.0.fmt(f)`) needs a formatting
        // trait to be in scope, and `Rc` implements more than one trait with a
        // `fmt` method (e.g. `Debug` and `Pointer`), so naming the trait keeps
        // this independent of the surrounding imports.
        core::fmt::Debug::fmt(&self.0, f)
    }
}
/// Token wrapping a state handle, returned by `SharedStateRef::can_schedule`
/// when none of the `wants_recv_msg` / `wants_send_msg` / `yielded` flags is set.
///
/// The `protocol_*` methods that may set one of those flags are defined on this
/// type and take it by value.
pub struct CanSchedule<T>(T);
impl<M> CanSchedule<&SharedStateRef<M>> {
    /// Mutably borrows the state behind the wrapped handle.
    fn borrow_mut(&self) -> core::cell::RefMut<'_, SharedState<M>> {
        self.0 .0.borrow_mut()
    }

    /// Requests that a stored outgoing message be sent.
    ///
    /// Returns `Poll::Ready(())` when the outgoing slot is empty. Otherwise sets
    /// `wants_send_msg` and returns `Poll::Pending`.
    pub fn protocol_flushes_outgoing_msg(self) -> Poll<()> {
        let mut state = self.borrow_mut();
        if state.outgoing_msg.is_none() {
            return Poll::Ready(());
        }
        state.wants_send_msg = true;
        Poll::Pending
    }

    /// Takes the stored incoming message, if any.
    ///
    /// Returns `Poll::Ready(msg)` when a message is stored (the slot is emptied).
    /// Otherwise sets `wants_recv_msg` and returns `Poll::Pending`.
    pub fn protocol_needs_one_more_msg(self) -> Poll<crate::Incoming<M>> {
        let mut state = self.borrow_mut();
        if let Some(msg) = state.incoming_msg.take() {
            return Poll::Ready(msg);
        }
        state.wants_recv_msg = true;
        Poll::Pending
    }

    /// Sets the `yielded` flag.
    pub fn protocol_yields(self) {
        self.borrow_mut().yielded = true;
    }
}
#[cfg(test)]
mod test {
    use core::task::Poll;

    use crate::{Incoming, MessageDestination, Outgoing};

    use super::SharedStateRef;

    /// A saved message stays with the protocol until a flush is requested;
    /// once the executor takes it, flushing completes.
    #[test]
    fn send_msg() {
        let state = SharedStateRef::<u32>::new();
        let protocol_side = state.clone();
        let executor_side = state;

        let msg = Outgoing {
            recipient: MessageDestination::AllParties,
            msg: 1,
        };
        protocol_side
            .protocol_saves_msg_to_be_sent(msg)
            .expect("msg slot isn't empty");

        // First flush: the message is still stored, so flushing must not resolve.
        let scheduler = match protocol_side.can_schedule() {
            Poll::Ready(scheduler) => scheduler,
            Poll::Pending => panic!("can't schedule"),
        };
        if scheduler.protocol_flushes_outgoing_msg().is_ready() {
            panic!("flushing resolved too early");
        }

        let msg_actual = executor_side.executor_takes_outgoing_msg().unwrap();
        assert_eq!(msg, msg_actual);

        // Second flush: the slot is empty now, so flushing resolves.
        let scheduler = match protocol_side.can_schedule() {
            Poll::Ready(scheduler) => scheduler,
            Poll::Pending => panic!("can't schedule"),
        };
        if scheduler.protocol_flushes_outgoing_msg().is_pending() {
            panic!("flushing must be done at this point");
        }
    }

    /// Asking for a message with an empty slot blocks scheduling until the
    /// executor delivers one, which the protocol then receives.
    #[test]
    fn recv_msg() {
        let state = SharedStateRef::<&'static str>::new();
        let protocol_side = state.clone();
        let executor_side = state;

        {
            let scheduler = match protocol_side.can_schedule() {
                Poll::Ready(scheduler) => scheduler,
                Poll::Pending => panic!("can't schedule"),
            };
            if scheduler.protocol_needs_one_more_msg().is_ready() {
                panic!("unexpected incoming msg");
            }
        }

        assert!(protocol_side.can_schedule().is_pending());

        let incoming_msg = Incoming {
            id: 0,
            sender: 1,
            msg_type: crate::MessageType::Broadcast,
            msg: "hello",
        };
        executor_side.executor_received_msg(incoming_msg).unwrap();

        {
            let scheduler = match protocol_side.can_schedule() {
                Poll::Ready(scheduler) => scheduler,
                Poll::Pending => panic!("can't schedule"),
            };
            match scheduler.protocol_needs_one_more_msg() {
                Poll::Ready(received) => assert_eq!(received, incoming_msg),
                Poll::Pending => panic!("no incoming msg"),
            }
        }
    }

    /// Yielding blocks scheduling until the executor reads (and thereby
    /// resets) the flag.
    #[test]
    fn yielding() {
        let state = SharedStateRef::<()>::new();
        let runtime_side = state.clone();
        let executor_side = state;

        {
            let scheduler = match runtime_side.can_schedule() {
                Poll::Ready(scheduler) => scheduler,
                Poll::Pending => panic!("can't schedule"),
            };
            scheduler.protocol_yields();
        }

        assert!(runtime_side.can_schedule().is_pending());

        {
            let was_yielded = executor_side.executor_reads_and_resets_yielded_flag();
            assert!(was_yielded);
        }

        assert!(executor_side.can_schedule().is_ready());
    }

    /// While a flush, a receive or a yield is outstanding, `can_schedule`
    /// must keep returning `Poll::Pending`.
    #[test]
    fn task_cannot_be_scheduled_when_another_task_is_scheduled() {
        fn try_obtain_lock_and_fail(shared_state: &SharedStateRef<u32>) {
            if shared_state.can_schedule().is_ready() {
                panic!("lock must not be obtained");
            }
        }

        // Outstanding flush.
        {
            let state = SharedStateRef::new();
            let outgoing = Outgoing {
                recipient: MessageDestination::AllParties,
                msg: 1,
            };
            state
                .protocol_saves_msg_to_be_sent(outgoing)
                .expect("msg slot isn't empty");

            let scheduler = match state.can_schedule() {
                Poll::Ready(scheduler) => scheduler,
                Poll::Pending => panic!("can't schedule"),
            };
            if scheduler.protocol_flushes_outgoing_msg().is_ready() {
                panic!("flushing resolved too early")
            }

            try_obtain_lock_and_fail(&state);
        }

        // Outstanding receive.
        {
            let state = SharedStateRef::new();

            let scheduler = match state.can_schedule() {
                Poll::Ready(scheduler) => scheduler,
                Poll::Pending => panic!("can't schedule"),
            };
            if scheduler.protocol_needs_one_more_msg().is_ready() {
                panic!("receiving resolved too early")
            }

            try_obtain_lock_and_fail(&state);
        }

        // Outstanding yield.
        {
            let state = SharedStateRef::new();

            let scheduler = match state.can_schedule() {
                Poll::Ready(scheduler) => scheduler,
                Poll::Pending => panic!("can't schedule"),
            };
            scheduler.protocol_yields();

            try_obtain_lock_and_fail(&state);
        }
    }
}