use std::collections::BTreeMap;
use std::sync::Mutex;
use crate::coroutine::Value;
use crate::effect::{EffectFailure, EffectHandler, EffectResult};
use crate::engine::{ObsEvent, ProtocolMachine, ProtocolMachineError, StepResult};
use crate::semantic_objects::ProtocolMachineSemanticObjects;
/// A single nested site: an inner protocol machine paired with the effect
/// handler used to drive it.
struct SiteRunner {
/// The inner machine. It sits behind a `Mutex` so it can be locked and
/// advanced through a shared `&self` reference to the owning handler.
machine: Mutex<ProtocolMachine>,
/// Effect handler passed to the inner machine each time it is stepped.
handler: Box<dyn EffectHandler>,
}
/// Effect handler that owns a set of named inner protocol machines ("sites")
/// and advances the site whose name matches the acting role whenever an
/// effect is handled for that role.
pub struct NestedProtocolMachineHandler {
/// Registered sites, keyed by site name.
sites: BTreeMap<String, SiteRunner>,
/// Upper bound on how many rounds a site is advanced per handled effect.
/// Starts at 1 and is clamped to at least 1 by `with_rounds_per_step`.
max_rounds_per_step: usize,
}
impl NestedProtocolMachineHandler {
#[must_use]
pub fn new() -> Self {
Self {
sites: BTreeMap::new(),
max_rounds_per_step: 1,
}
}
#[must_use]
pub fn with_rounds_per_step(mut self, rounds: usize) -> Self {
self.max_rounds_per_step = rounds.max(1);
self
}
#[must_use]
pub fn rounds_per_step(&self) -> usize {
self.max_rounds_per_step
}
pub fn add_site(
&mut self,
name: impl Into<String>,
machine: ProtocolMachine,
handler: Box<dyn EffectHandler>,
) {
self.sites.insert(
name.into(),
SiteRunner {
machine: Mutex::new(machine),
handler,
},
);
}
#[must_use]
pub fn site_trace(&self, name: &str) -> Option<Vec<ObsEvent>> {
self.sites.get(name).map(|site| {
site.machine
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner())
.trace()
.to_vec()
})
}
#[must_use]
pub fn site_all_done(&self, name: &str) -> Option<bool> {
self.sites.get(name).map(|site| {
site.machine
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner())
.all_done()
})
}
#[must_use]
pub fn site_semantic_objects(&self, name: &str) -> Option<ProtocolMachineSemanticObjects> {
self.sites.get(name).map(|site| {
site.machine
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner())
.semantic_objects()
})
}
fn step_site(&self, name: &str) -> Result<(), String> {
let site = self
.sites
.get(name)
.ok_or_else(|| format!("unknown site: {name}"))?;
let mut machine = site
.machine
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner());
let handler = site.handler.as_ref();
for _ in 0..self.max_rounds_per_step {
match machine.step_round(handler, 1) {
Ok(StepResult::Continue) => {}
Ok(StepResult::AllDone | StepResult::Stuck) => break,
Err(ProtocolMachineError::Fault { fault, .. }) => {
return Err(format!("inner machine fault: {fault}"));
}
Err(e) => return Err(e.to_string()),
}
}
Ok(())
}
}
impl Default for NestedProtocolMachineHandler {
fn default() -> Self {
Self::new()
}
}
impl EffectHandler for NestedProtocolMachineHandler {
    /// Advances the site named by `role` and, on success, yields
    /// `Value::Unit`. The partner, label and state arguments are not used.
    fn handle_send(
        &self,
        role: &str,
        _partner: &str,
        _label: &str,
        _state: &[Value],
    ) -> EffectResult<Value> {
        if let Err(message) = self.step_site(role) {
            return EffectResult::failure(EffectFailure::contract_violation(message));
        }
        EffectResult::success(Value::Unit)
    }

    /// Advances the site named by `role`. The partner, label, state and
    /// payload arguments are not used.
    fn handle_recv(
        &self,
        role: &str,
        _partner: &str,
        _label: &str,
        _state: &mut Vec<Value>,
        _payload: &Value,
    ) -> EffectResult<()> {
        if let Err(message) = self.step_site(role) {
            return EffectResult::failure(EffectFailure::contract_violation(message));
        }
        EffectResult::success(())
    }

    /// Picks the first offered label; fails with an invalid-input error when
    /// `labels` is empty. No site is advanced by a choice.
    fn handle_choose(
        &self,
        _role: &str,
        _partner: &str,
        labels: &[String],
        _state: &[Value],
    ) -> EffectResult<String> {
        if let Some(first) = labels.first() {
            EffectResult::success(first.clone())
        } else {
            EffectResult::failure(EffectFailure::invalid_input("no labels available"))
        }
    }

    /// Advances the site named by `role`; the state argument is not used.
    fn step(&self, role: &str, _state: &mut Vec<Value>) -> EffectResult<()> {
        if let Err(message) = self.step_site(role) {
            return EffectResult::failure(EffectFailure::contract_violation(message));
        }
        EffectResult::success(())
    }
}