use async_trait::async_trait;
use cfg_if::cfg_if;
use futures::channel::mpsc::{unbounded, UnboundedReceiver, UnboundedSender};
use futures::StreamExt;
use serde::{de::DeserializeOwned, Serialize};
use std::collections::BTreeMap;
use std::time::Duration;
use crate::effects::contract::{
DeliveryModel, DocumentedHandlerContract, ExtensionDispatchContract, ExtensionDispatchMode,
HandlerContractProfile, HandlerContractTier, ProtocolSemanticContract, RetryPolicy,
TimeoutPolicy, TransportPolicyContract,
};
use crate::effects::{ChoreoHandler, ChoreoResult, ChoreographyError, RoleId};
use crate::RoleName;
/// Both halves of one serialized-message channel for a single `(from, to)` role pair.
type MessageChannelPair = (UnboundedSender<Vec<u8>>, UnboundedReceiver<Vec<u8>>);
/// Both halves of one choice-label channel for a single `(from, to)` role pair.
type ChoiceChannelPair<L> = (UnboundedSender<L>, UnboundedReceiver<L>);
/// Ordered table from a directed `(from, to)` role pair to the channel pair serving it.
type RolePairTable<P> = BTreeMap<(RoleKey, RoleKey), P>;
/// Shared, mutex-protected table of message channels.
type MessageChannelMap = std::sync::Arc<std::sync::Mutex<RolePairTable<MessageChannelPair>>>;
/// Shared, mutex-protected table of choice channels carrying labels of type `L`.
type ChoiceChannelMap<L> = std::sync::Arc<std::sync::Mutex<RolePairTable<ChoiceChannelPair<L>>>>;
/// Ordered, hashable identity of a role: its name plus an optional index.
///
/// Two of these, `(from, to)`, form the key that selects a role-pair channel.
#[doc(hidden)]
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct RoleKey {
    name: RoleName,
    index: Option<u32>,
}
impl RoleKey {
    /// Snapshots the name and index reported by `role`.
    fn from_role<R: RoleId>(role: R) -> Self {
        let name = role.role_name();
        let index = role.role_index();
        RoleKey { name, index }
    }
}
/// A [`ChoreoHandler`] that moves messages and choice labels between roles over
/// in-process `futures` unbounded channels.
///
/// Channels are keyed by the directed `(from, to)` role pair and live in maps held
/// behind `Arc<Mutex<..>>`; handlers can only reach each other when they share
/// those maps (see `with_channels`).
pub struct InMemoryHandler<R: RoleId> {
    // The role this handler acts as: `from` when sending, `to` when receiving.
    role: R,
    // Channels carrying bincode-serialized message bytes, keyed by `(from, to)`.
    channels: MessageChannelMap,
    // Channels carrying choice labels, keyed by `(from, to)`.
    choice_channels: ChoiceChannelMap<R::Label>,
}
impl<R: RoleId> InMemoryHandler<R> {
pub fn new(role: R) -> Self {
Self {
role,
channels: std::sync::Arc::new(std::sync::Mutex::new(BTreeMap::new())),
choice_channels: std::sync::Arc::new(std::sync::Mutex::new(BTreeMap::new())),
}
}
pub fn with_channels(
role: R,
channels: MessageChannelMap,
choice_channels: ChoiceChannelMap<R::Label>,
) -> Self {
Self {
role,
channels,
choice_channels,
}
}
fn get_or_create_channel(&self, from: R, to: R) -> UnboundedSender<Vec<u8>> {
let mut channels = self
.channels
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
let key = (RoleKey::from_role(from), RoleKey::from_role(to));
channels.entry(key).or_insert_with(unbounded).0.clone()
}
fn get_receiver(&self, from: R, to: R) -> Option<UnboundedReceiver<Vec<u8>>> {
let mut channels = self
.channels
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
let key = (RoleKey::from_role(from), RoleKey::from_role(to));
channels.remove(&key).map(|(_, rx)| rx)
}
fn get_or_create_choice_channel(&self, from: R, to: R) -> UnboundedSender<R::Label> {
let mut channels = self
.choice_channels
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
let key = (RoleKey::from_role(from), RoleKey::from_role(to));
channels.entry(key).or_insert_with(unbounded).0.clone()
}
fn get_choice_receiver(&self, from: R, to: R) -> Option<UnboundedReceiver<R::Label>> {
let mut channels = self
.choice_channels
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
let key = (RoleKey::from_role(from), RoleKey::from_role(to));
channels.remove(&key).map(|(_, rx)| rx)
}
}
impl<R: RoleId> DocumentedHandlerContract for InMemoryHandler<R> {
    /// Reports the guarantees of [`InMemoryHandler`]: full-protocol tier over
    /// in-memory channels, no retries, timeouts applied only by the enforcing role,
    /// and no extension dispatch.
    fn contract_profile() -> HandlerContractProfile {
        let semantics = ProtocolSemanticContract {
            typed_send_recv_roundtrip: true,
            exact_choice_label_preservation: true,
            fail_closed_transport_errors: true,
            timeouts_scoped_to_enforcing_role: true,
            deterministic_for_regression: true,
            can_materialize_values: true,
        };
        let transport = TransportPolicyContract {
            delivery_model: DeliveryModel::InMemoryChannels,
            retry_policy: RetryPolicy::None,
            timeout_policy: TimeoutPolicy::EnforcingRoleOnly,
        };
        // Extensions are not supported, so neither dispatch guarantee applies.
        let extension_dispatch = ExtensionDispatchContract {
            mode: ExtensionDispatchMode::Unsupported,
            fail_closed_when_unregistered: false,
            type_exact_before_side_effects: false,
        };
        let notes = vec![
            "intended for deterministic local testing rather than remote transport",
            "role-pair channels are reinserted after each recv/offer operation",
        ];
        HandlerContractProfile {
            handler_name: std::any::type_name::<Self>(),
            tier: HandlerContractTier::FullProtocol,
            semantics,
            transport,
            extension_dispatch,
            notes,
        }
    }
}
#[async_trait]
impl<R: RoleId + 'static> ChoreoHandler for InMemoryHandler<R> {
    type Role = R;
    type Endpoint = ();

    /// Serializes `msg` with bincode and queues it on the `self.role -> to` channel.
    async fn send<M: Serialize + Send + Sync>(
        &mut self,
        _ep: &mut Self::Endpoint,
        to: Self::Role,
        msg: &M,
    ) -> ChoreoResult<()> {
        let payload = match bincode::serialize(msg) {
            Ok(encoded) => encoded,
            Err(err) => return Err(ChoreographyError::Serialization(err.to_string())),
        };
        let me = self.role;
        let outbox = self.get_or_create_channel(me, to);
        if outbox.unbounded_send(payload).is_err() {
            return Err(ChoreographyError::Transport(format!(
                "Failed to send message from {:?} to {:?}",
                self.role, to
            )));
        }
        tracing::trace!(?to, "InMemoryHandler: send success");
        Ok(())
    }

    /// Waits for the next payload on the `from -> self.role` channel, returns the
    /// receiver to the shared map, then bincode-decodes the payload into `M`.
    async fn recv<M: DeserializeOwned + Send>(
        &mut self,
        _ep: &mut Self::Endpoint,
        from: Self::Role,
    ) -> ChoreoResult<M> {
        tracing::trace!(?from, "InMemoryHandler: recv start");
        let me = self.role;
        let mut inbox = match self.get_receiver(from, me) {
            Some(rx) => rx,
            None => {
                return Err(ChoreographyError::Transport(format!(
                    "No channel from {:?} to {:?}",
                    from, self.role
                )))
            }
        };
        let frame = match inbox.next().await {
            Some(frame) => frame,
            None => {
                return Err(ChoreographyError::Transport(
                    "Channel closed while waiting for message".into(),
                ))
            }
        };
        // Hand the receiver back; the guard lives only in this scope, never across
        // an await.
        {
            let key = (RoleKey::from_role(from), RoleKey::from_role(me));
            let mut table = self
                .channels
                .lock()
                .unwrap_or_else(std::sync::PoisonError::into_inner);
            if let Some((tx, _)) = table.remove(&key) {
                table.insert(key, (tx, inbox));
            }
        }
        let decoded: M = match bincode::deserialize(&frame) {
            Ok(value) => value,
            Err(err) => return Err(ChoreographyError::Serialization(err.to_string())),
        };
        tracing::trace!(?from, "InMemoryHandler: recv success");
        Ok(decoded)
    }

    /// Queues `label` on the `self.role -> who` choice channel.
    async fn choose(
        &mut self,
        _ep: &mut Self::Endpoint,
        who: Self::Role,
        label: <Self::Role as RoleId>::Label,
    ) -> ChoreoResult<()> {
        let me = self.role;
        let outbox = self.get_or_create_choice_channel(me, who);
        if outbox.unbounded_send(label).is_err() {
            return Err(ChoreographyError::Transport(format!(
                "Failed to send choice from {:?} to {:?}",
                self.role, who
            )));
        }
        tracing::trace!(?who, ?label, "InMemoryHandler: sent choice");
        Ok(())
    }

    /// Waits for the next label on the `from -> self.role` choice channel and
    /// returns the receiver to the shared map before yielding the label.
    async fn offer(
        &mut self,
        _ep: &mut Self::Endpoint,
        from: Self::Role,
    ) -> ChoreoResult<<Self::Role as RoleId>::Label> {
        tracing::trace!(?from, "InMemoryHandler: waiting for choice");
        let me = self.role;
        let mut inbox = match self.get_choice_receiver(from, me) {
            Some(rx) => rx,
            None => {
                return Err(ChoreographyError::Transport(format!(
                    "No choice channel from {:?} to {:?}",
                    from, self.role
                )))
            }
        };
        let label = match inbox.next().await {
            Some(label) => label,
            None => {
                return Err(ChoreographyError::Transport(
                    "Choice channel closed while waiting for label".into(),
                ))
            }
        };
        {
            let key = (RoleKey::from_role(from), RoleKey::from_role(me));
            let mut table = self
                .choice_channels
                .lock()
                .unwrap_or_else(std::sync::PoisonError::into_inner);
            if let Some((tx, _)) = table.remove(&key) {
                table.insert(key, (tx, inbox));
            }
        }
        tracing::trace!(?from, ?label, "InMemoryHandler: received choice");
        Ok(label)
    }

    /// Runs `body`, bounding it by `dur` only when this handler is the enforcing
    /// role `at`; other roles run `body` without a deadline.
    async fn with_timeout<F, T>(
        &mut self,
        _ep: &mut Self::Endpoint,
        at: Self::Role,
        dur: Duration,
        body: F,
    ) -> ChoreoResult<T>
    where
        F: std::future::Future<Output = ChoreoResult<T>> + Send,
    {
        if at != self.role {
            body.await
        } else {
            cfg_if! {
                if #[cfg(target_arch = "wasm32")] {
                    use futures::future::{select, Either};
                    use futures::pin_mut;
                    use wasm_timer::Delay;
                    let deadline = Delay::new(dur);
                    pin_mut!(body);
                    pin_mut!(deadline);
                    match select(body, deadline).await {
                        Either::Left((outcome, _)) => outcome,
                        Either::Right(_) => Err(ChoreographyError::Timeout(dur)),
                    }
                } else {
                    match tokio::time::timeout(dur, body).await {
                        Ok(outcome) => outcome,
                        Err(_elapsed) => Err(ChoreographyError::Timeout(dur)),
                    }
                }
            }
        }
    }
}