1use async_recursion::async_recursion;
7use async_trait::async_trait;
8use cfg_if::cfg_if;
9use serde::{de::DeserializeOwned, Serialize};
10
11use crate::effects::algebra::{Effect, InterpretResult, InterpreterState, Program, ProgramMessage};
12use crate::effects::registry::ExtensibleHandler;
13use crate::effects::{ChoreoHandler, ChoreoResult, ChoreographyError, RoleId};
14
15pub async fn interpret<H, R, M>(
17 handler: &mut H,
18 endpoint: &mut H::Endpoint,
19 program: Program<R, M>,
20) -> ChoreoResult<InterpretResult<M>>
21where
22 H: ChoreoHandler<Role = R> + Send,
23 R: RoleId,
24 M: ProgramMessage + Serialize + DeserializeOwned + 'static,
25{
26 let mut interpreter: Interpreter<M, R> = Interpreter::new();
27 interpreter.run(handler, endpoint, &program).await
28}
29
30pub async fn interpret_extensible<H, R, M>(
35 handler: &mut H,
36 endpoint: &mut H::Endpoint,
37 program: Program<R, M>,
38) -> ChoreoResult<InterpretResult<M>>
39where
40 H: ExtensibleHandler<Role = R> + Send,
41 R: RoleId,
42 M: ProgramMessage + Serialize + DeserializeOwned + 'static,
43{
44 let mut interpreter: ExtensibleInterpreter<M, R> = ExtensibleInterpreter::new();
45 interpreter.run(handler, endpoint, &program).await
46}
47
/// Mutable state of the base effect interpreter.
///
/// The same instance is reused when a `Timeout` effect runs its body (or
/// fallback) through a nested `run`, so both fields are shared with those
/// nested runs.
struct Interpreter<M, R: RoleId> {
    /// Every value produced by a `Recv` effect, in the order received.
    received_values: Vec<M>,
    /// Label recorded by the most recent `Choose` or `Offer`. `Branch` reads it
    /// to select a branch, and it is reset to `None` once the selected branch
    /// body has finished.
    last_label: Option<<R as RoleId>::Label>,
}
54
/// A pending unit of work on the interpreter's explicit control stack.
///
/// Branches, loops and parallel blocks are executed by pushing frames rather
/// than by recursing. The frame on top of the stack is resumed next. Frames
/// borrow the effects and programs they walk for `'a`.
enum ControlFrame<'a, R: RoleId, M> {
    /// Execute `effects[index..]` in order.
    Effects {
        effects: &'a [Effect<R, M>],
        index: usize,
    },
    /// Execute `programs[index..]` one after another. This is how `Parallel`
    /// is handled: its programs run sequentially, not concurrently.
    SequentialPrograms {
        programs: &'a [Program<R, M>],
        index: usize,
    },
    /// Execute `body` `remaining` more times.
    Repeat {
        body: &'a Program<R, M>,
        remaining: usize,
    },
    /// Reset the interpreter's recorded label. It is pushed beneath a selected
    /// branch body, so it runs once that body has finished.
    ClearLastLabel,
}
70
71impl<M, R: RoleId> Interpreter<M, R> {
72 fn new() -> Self {
73 Self {
74 received_values: Vec::new(),
75 last_label: None,
76 }
77 }
78
79 #[async_recursion]
80 async fn run<H>(
81 &mut self,
82 handler: &mut H,
83 endpoint: &mut H::Endpoint,
84 program: &Program<R, M>,
85 ) -> ChoreoResult<InterpretResult<M>>
86 where
87 H: ChoreoHandler<Role = R> + Send,
88 M: ProgramMessage + Serialize + DeserializeOwned + 'static,
89 {
90 let start_len = self.received_values.len();
91 let final_state = match self.execute_program(handler, endpoint, program).await {
92 Ok(()) => InterpreterState::Completed,
93 Err(ChoreographyError::Timeout(_)) => InterpreterState::Timeout,
94 Err(e) => InterpreterState::Failed(e.to_string()),
95 };
96
97 Ok(InterpretResult {
98 received_values: self.received_values_since(start_len),
99 final_state,
100 })
101 }
102
103 async fn execute_program<H>(
104 &mut self,
105 handler: &mut H,
106 endpoint: &mut H::Endpoint,
107 program: &Program<R, M>,
108 ) -> ChoreoResult<()>
109 where
110 H: ChoreoHandler<Role = R> + Send,
111 M: ProgramMessage + Serialize + DeserializeOwned + 'static,
112 {
113 let mut stack = vec![ControlFrame::Effects {
114 effects: program.effects(),
115 index: 0,
116 }];
117
118 while let Some(frame) = stack.pop() {
119 match frame {
120 ControlFrame::Effects { effects, index } => {
121 if index >= effects.len() {
122 continue;
123 }
124
125 stack.push(ControlFrame::Effects {
126 effects,
127 index: index + 1,
128 });
129 self.execute_base_effect(handler, endpoint, &effects[index], &mut stack)
130 .await?;
131 }
132 ControlFrame::SequentialPrograms { programs, index } => {
133 if let Some(program) = programs.get(index) {
134 stack.push(ControlFrame::SequentialPrograms {
135 programs,
136 index: index + 1,
137 });
138 stack.push(ControlFrame::Effects {
139 effects: program.effects(),
140 index: 0,
141 });
142 }
143 }
144 ControlFrame::Repeat { body, remaining } => {
145 if remaining > 0 {
146 stack.push(ControlFrame::Repeat {
147 body,
148 remaining: remaining - 1,
149 });
150 stack.push(ControlFrame::Effects {
151 effects: body.effects(),
152 index: 0,
153 });
154 }
155 }
156 ControlFrame::ClearLastLabel => {
157 self.last_label = None;
158 }
159 }
160 }
161
162 Ok(())
163 }
164
165 #[allow(clippy::too_many_lines)]
166 async fn execute_base_effect<'a, H>(
167 &mut self,
168 handler: &mut H,
169 endpoint: &mut H::Endpoint,
170 effect: &'a Effect<R, M>,
171 stack: &mut Vec<ControlFrame<'a, R, M>>,
172 ) -> ChoreoResult<()>
173 where
174 H: ChoreoHandler<Role = R> + Send,
175 M: ProgramMessage + Serialize + DeserializeOwned + 'static,
176 {
177 match effect {
178 Effect::Send { to, msg } => {
179 handler.send(endpoint, *to, msg).await?;
180 }
181 Effect::Recv { from, msg_tag } => {
182 tracing::debug!(
183 ?from,
184 msg_type = msg_tag.type_name(),
185 "recv effect - type casting required"
186 );
187 let value = self
188 .try_recv_as_type::<H, M>(handler, endpoint, *from)
189 .await?;
190 self.received_values.push(value);
191 }
192 Effect::Choose { at, label } => {
193 handler.choose(endpoint, *at, *label).await?;
194 self.last_label = Some(*label);
195 }
196 Effect::Offer { from } => {
197 let label = handler.offer(endpoint, *from).await?;
198 tracing::debug!(?from, ?label, "Received offer label");
199 self.last_label = Some(label);
200 }
201 Effect::Branch {
202 choosing_role,
203 branches,
204 } => {
205 tracing::debug!(
206 ?choosing_role,
207 branch_count = branches.len(),
208 "Executing branch effect"
209 );
210
211 let label = self.last_label.ok_or_else(|| {
212 ChoreographyError::ProtocolViolation(
213 "Branch effect requires a preceding Choose or Offer effect".to_string(),
214 )
215 })?;
216
217 let (_, selected_branch) = branches
218 .iter()
219 .find(|(branch_label, _)| branch_label == &label)
220 .ok_or_else(|| {
221 ChoreographyError::ProtocolViolation(format!(
222 "No branch found for label {label:?}"
223 ))
224 })?;
225
226 tracing::debug!(selected_label = ?label, "Executing selected branch");
227 stack.push(ControlFrame::ClearLastLabel);
228 stack.push(ControlFrame::Effects {
229 effects: selected_branch.effects(),
230 index: 0,
231 });
232 }
233 Effect::Loop { iterations, body } => {
234 tracing::debug!(?iterations, "Executing loop effect");
235 stack.push(ControlFrame::Repeat {
236 body,
237 remaining: iterations.unwrap_or(1),
238 });
239 }
240 Effect::Timeout {
241 at,
242 dur,
243 body,
244 on_timeout,
245 } => {
246 tracing::debug!(
247 ?at,
248 ?dur,
249 has_fallback = on_timeout.is_some(),
250 "Executing timeout effect"
251 );
252
253 cfg_if! {
254 if #[cfg(target_arch = "wasm32")] {
255 let timeout_result = {
256 use futures::future::{select, Either};
257 use futures::pin_mut;
258 use wasm_timer::Delay;
259
260 let body_future = self.run(handler, endpoint, body);
261 let timeout = Delay::new(*dur);
262 pin_mut!(body_future);
263 pin_mut!(timeout);
264
265 match select(body_future, timeout).await {
266 Either::Left((result, _)) => Ok(result),
267 Either::Right(_) => Err(()),
268 }
269 };
270 } else {
271 let timeout_result =
272 tokio::time::timeout(*dur, self.run(handler, endpoint, body)).await;
273 }
274 }
275
276 match timeout_result {
277 Ok(Ok(result)) => self.propagate_nested_result(result, *dur)?,
278 Ok(Err(err)) => return Err(err),
279 Err(_) => {
280 if let Some(timeout_body) = on_timeout {
281 tracing::debug!("Timeout fired, executing fallback program");
282 let result = self.run(handler, endpoint, timeout_body).await?;
283 self.propagate_nested_result(result, *dur)?;
284 } else {
285 return Err(ChoreographyError::Timeout(*dur));
286 }
287 }
288 }
289 }
290 Effect::Parallel { programs } => {
291 tracing::debug!(program_count = programs.len(), "Executing parallel effect");
292 stack.push(ControlFrame::SequentialPrograms { programs, index: 0 });
293 }
294 Effect::Extension(ext) => {
295 tracing::debug!(
296 type_name = ext.type_name(),
297 type_id = ?ext.type_id(),
298 "Executing extension effect"
299 );
300 tracing::warn!(
301 "Extension effect encountered but handler does not support extensions. \
302 Use interpret_extensible() for handlers with extension support."
303 );
304 }
305 Effect::End => {}
306 }
307
308 Ok(())
309 }
310
311 fn received_values_since(&self, start_len: usize) -> Vec<M>
312 where
313 M: ProgramMessage,
314 {
315 self.received_values[start_len..].to_vec()
316 }
317
318 async fn try_recv_as_type<H, T>(
319 &mut self,
320 handler: &mut H,
321 endpoint: &mut H::Endpoint,
322 from: R,
323 ) -> ChoreoResult<T>
324 where
325 H: ChoreoHandler<Role = R>,
326 T: DeserializeOwned + Send,
327 {
328 handler.recv(endpoint, from).await
329 }
330
331 fn propagate_nested_result(
332 &self,
333 result: InterpretResult<M>,
334 timeout: std::time::Duration,
335 ) -> ChoreoResult<()>
336 where
337 M: ProgramMessage,
338 {
339 match result.final_state {
340 InterpreterState::Completed => Ok(()),
341 InterpreterState::Failed(msg) => Err(ChoreographyError::Transport(msg)),
342 InterpreterState::Timeout => Err(ChoreographyError::Timeout(timeout)),
343 }
344 }
345}
346
/// Interpreter that dispatches `Extension` effects to the handler's extension
/// registry and delegates other effects to a wrapped base [`Interpreter`].
struct ExtensibleInterpreter<M, R: RoleId> {
    /// Base interpreter. It holds the received values and the recorded label,
    /// and executes all non-extension effects.
    base: Interpreter<M, R>,
}
351
352impl<M, R: RoleId> ExtensibleInterpreter<M, R> {
353 fn new() -> Self {
354 Self {
355 base: Interpreter::new(),
356 }
357 }
358
359 #[async_recursion]
360 async fn run<H>(
361 &mut self,
362 handler: &mut H,
363 endpoint: &mut H::Endpoint,
364 program: &Program<R, M>,
365 ) -> ChoreoResult<InterpretResult<M>>
366 where
367 H: ExtensibleHandler<Role = R> + Send,
368 M: ProgramMessage + Serialize + DeserializeOwned + 'static,
369 {
370 let start_len = self.base.received_values.len();
371 let final_state = match self.execute_program(handler, endpoint, program).await {
372 Ok(()) => InterpreterState::Completed,
373 Err(ChoreographyError::Timeout(_)) => InterpreterState::Timeout,
374 Err(e) => InterpreterState::Failed(e.to_string()),
375 };
376
377 Ok(InterpretResult {
378 received_values: self.base.received_values_since(start_len),
379 final_state,
380 })
381 }
382
383 async fn execute_program<H>(
384 &mut self,
385 handler: &mut H,
386 endpoint: &mut H::Endpoint,
387 program: &Program<R, M>,
388 ) -> ChoreoResult<()>
389 where
390 H: ExtensibleHandler<Role = R> + Send,
391 M: ProgramMessage + Serialize + DeserializeOwned + 'static,
392 {
393 let mut stack = vec![ControlFrame::Effects {
394 effects: program.effects(),
395 index: 0,
396 }];
397
398 while let Some(frame) = stack.pop() {
399 match frame {
400 ControlFrame::Effects { effects, index } => {
401 if index >= effects.len() {
402 continue;
403 }
404
405 stack.push(ControlFrame::Effects {
406 effects,
407 index: index + 1,
408 });
409 match &effects[index] {
410 Effect::Extension(ext) => {
411 tracing::debug!(
412 type_name = ext.type_name(),
413 type_id = ?ext.type_id(),
414 "Dispatching extension effect to handler"
415 );
416 handler
417 .extension_registry()
418 .handle(endpoint, ext.as_ref())
419 .await
420 .map_err(|e| ChoreographyError::Transport(e.to_string()))?;
421 }
422 effect => {
423 self.base
424 .execute_base_effect(handler, endpoint, effect, &mut stack)
425 .await?;
426 }
427 }
428 }
429 ControlFrame::SequentialPrograms { programs, index } => {
430 if let Some(program) = programs.get(index) {
431 stack.push(ControlFrame::SequentialPrograms {
432 programs,
433 index: index + 1,
434 });
435 stack.push(ControlFrame::Effects {
436 effects: program.effects(),
437 index: 0,
438 });
439 }
440 }
441 ControlFrame::Repeat { body, remaining } => {
442 if remaining > 0 {
443 stack.push(ControlFrame::Repeat {
444 body,
445 remaining: remaining - 1,
446 });
447 stack.push(ControlFrame::Effects {
448 effects: body.effects(),
449 index: 0,
450 });
451 }
452 }
453 ControlFrame::ClearLastLabel => {
454 self.base.last_label = None;
455 }
456 }
457 }
458
459 Ok(())
460 }
461}
462
463#[async_trait]
465pub trait ChoreoHandlerExt: ChoreoHandler + Sized {
466 async fn run_program<M>(
468 &mut self,
469 endpoint: &mut Self::Endpoint,
470 program: Program<Self::Role, M>,
471 ) -> ChoreoResult<InterpretResult<M>>
472 where
473 M: ProgramMessage + Serialize + DeserializeOwned + 'static,
474 Self: Send,
475 {
476 interpret(self, endpoint, program).await
477 }
478}
479
480impl<T: ChoreoHandler> ChoreoHandlerExt for T {}
// `testing` submodule, loaded from `interpreter_testing.rs` instead of the
// default module path. Its contents are not visible in this file.
#[path = "interpreter_testing.rs"]
pub mod testing;
485
// Unit tests live in a separate file and are textually included here, so they
// can reach this module's private items. They are compiled only for test
// builds on non-wasm targets.
#[cfg(all(test, not(target_arch = "wasm32")))]
mod tests {
    include!("../../tests/unit/effects/interpreter_tests.rs");
}