Skip to main content

telltale_machine/
nested.rs

1//! Nested ProtocolMachine handler for distributed simulation.
2//!
3//! The outer ProtocolMachine schedules site coroutines; each site handler advances an
4//! inner ProtocolMachine that runs site-local protocols.
5
6use std::collections::BTreeMap;
7use std::sync::Mutex;
8
9use crate::coroutine::Value;
10use crate::effect::{EffectFailure, EffectHandler, EffectResult};
11use crate::engine::{ObsEvent, ProtocolMachine, ProtocolMachineError, StepResult};
12use crate::semantic_objects::ProtocolMachineSemanticObjects;
13
/// One registered site: an inner ProtocolMachine together with the effect
/// handler used whenever that machine is stepped.
struct SiteRunner {
    /// Inner machine for this site. It sits behind a `Mutex` so it can be
    /// stepped mutably from the `&self` effect-handler callbacks.
    machine: Mutex<ProtocolMachine>,
    /// Handler passed to every `step_round` call on `machine`.
    handler: Box<dyn EffectHandler>,
}
18
19/// Effect handler that dispatches to inner ProtocolMachines keyed by outer role name.
20pub struct NestedProtocolMachineHandler {
21    sites: BTreeMap<String, SiteRunner>,
22    max_rounds_per_step: usize,
23}
24
25impl NestedProtocolMachineHandler {
26    /// Create an empty nested handler.
27    #[must_use]
28    pub fn new() -> Self {
29        Self {
30            sites: BTreeMap::new(),
31            max_rounds_per_step: 1,
32        }
33    }
34
35    /// Set how many inner ProtocolMachine rounds to advance per outer handler call.
36    #[must_use]
37    pub fn with_rounds_per_step(mut self, rounds: usize) -> Self {
38        self.max_rounds_per_step = rounds.max(1);
39        self
40    }
41
42    /// Number of inner ProtocolMachine rounds attempted per outer handler call.
43    #[must_use]
44    pub fn rounds_per_step(&self) -> usize {
45        self.max_rounds_per_step
46    }
47
48    /// Register a site by name with its inner ProtocolMachine and handler.
49    pub fn add_site(
50        &mut self,
51        name: impl Into<String>,
52        machine: ProtocolMachine,
53        handler: Box<dyn EffectHandler>,
54    ) {
55        self.sites.insert(
56            name.into(),
57            SiteRunner {
58                machine: Mutex::new(machine),
59                handler,
60            },
61        );
62    }
63
64    /// Get a copy of the inner ProtocolMachine trace for a site.
65    ///
66    /// # Panics
67    ///
68    /// Panics if the site ProtocolMachine mutex is poisoned.
69    #[must_use]
70    pub fn site_trace(&self, name: &str) -> Option<Vec<ObsEvent>> {
71        self.sites.get(name).map(|site| {
72            site.machine
73                .lock()
74                .unwrap_or_else(|poisoned| poisoned.into_inner())
75                .trace()
76                .to_vec()
77        })
78    }
79
80    /// Check whether all coroutines in a site ProtocolMachine are terminal.
81    ///
82    /// # Panics
83    ///
84    /// Panics if the site ProtocolMachine mutex is poisoned.
85    #[must_use]
86    pub fn site_all_done(&self, name: &str) -> Option<bool> {
87        self.sites.get(name).map(|site| {
88            site.machine
89                .lock()
90                .unwrap_or_else(|poisoned| poisoned.into_inner())
91                .all_done()
92        })
93    }
94
95    /// Get a copy of the canonical semantic-object bundle for a site.
96    ///
97    /// # Panics
98    ///
99    /// Panics if the site ProtocolMachine mutex is poisoned.
100    #[must_use]
101    pub fn site_semantic_objects(&self, name: &str) -> Option<ProtocolMachineSemanticObjects> {
102        self.sites.get(name).map(|site| {
103            site.machine
104                .lock()
105                .unwrap_or_else(|poisoned| poisoned.into_inner())
106                .semantic_objects()
107        })
108    }
109
110    fn step_site(&self, name: &str) -> Result<(), String> {
111        let site = self
112            .sites
113            .get(name)
114            .ok_or_else(|| format!("unknown site: {name}"))?;
115
116        let mut machine = site
117            .machine
118            .lock()
119            .unwrap_or_else(|poisoned| poisoned.into_inner());
120        let handler = site.handler.as_ref();
121
122        for _ in 0..self.max_rounds_per_step {
123            match machine.step_round(handler, 1) {
124                Ok(StepResult::Continue) => {}
125                Ok(StepResult::AllDone | StepResult::Stuck) => break,
126                Err(ProtocolMachineError::Fault { fault, .. }) => {
127                    return Err(format!("inner machine fault: {fault}"));
128                }
129                Err(e) => return Err(e.to_string()),
130            }
131        }
132
133        Ok(())
134    }
135}
136
137impl Default for NestedProtocolMachineHandler {
138    fn default() -> Self {
139        Self::new()
140    }
141}
142
143impl EffectHandler for NestedProtocolMachineHandler {
144    fn handle_send(
145        &self,
146        role: &str,
147        _partner: &str,
148        _label: &str,
149        _state: &[Value],
150    ) -> EffectResult<Value> {
151        match self.step_site(role) {
152            Ok(()) => EffectResult::success(Value::Unit),
153            Err(message) => EffectResult::failure(EffectFailure::contract_violation(message)),
154        }
155    }
156
157    fn handle_recv(
158        &self,
159        role: &str,
160        _partner: &str,
161        _label: &str,
162        _state: &mut Vec<Value>,
163        _payload: &Value,
164    ) -> EffectResult<()> {
165        match self.step_site(role) {
166            Ok(()) => EffectResult::success(()),
167            Err(message) => EffectResult::failure(EffectFailure::contract_violation(message)),
168        }
169    }
170
171    fn handle_choose(
172        &self,
173        _role: &str,
174        _partner: &str,
175        labels: &[String],
176        _state: &[Value],
177    ) -> EffectResult<String> {
178        match labels.first().cloned() {
179            Some(label) => EffectResult::success(label),
180            None => EffectResult::failure(EffectFailure::invalid_input("no labels available")),
181        }
182    }
183
184    fn step(&self, role: &str, _state: &mut Vec<Value>) -> EffectResult<()> {
185        match self.step_site(role) {
186            Ok(()) => EffectResult::success(()),
187            Err(message) => EffectResult::failure(EffectFailure::contract_violation(message)),
188        }
189    }
190}