Skip to main content

telltale_runtime/effects/
interpreter.rs

// Program interpreter for free-algebra choreographic effects.
//
// This module provides the interpreter that executes effect programs
// by translating them into concrete handler operations.
5
6use async_recursion::async_recursion;
7use async_trait::async_trait;
8use cfg_if::cfg_if;
9use serde::{de::DeserializeOwned, Serialize};
10
11use crate::effects::algebra::{Effect, InterpretResult, InterpreterState, Program, ProgramMessage};
12use crate::effects::registry::ExtensibleHandler;
13use crate::effects::{ChoreoHandler, ChoreoResult, ChoreographyError, RoleId};
14
15/// Interpret a choreographic program using a concrete handler
16pub async fn interpret<H, R, M>(
17    handler: &mut H,
18    endpoint: &mut H::Endpoint,
19    program: Program<R, M>,
20) -> ChoreoResult<InterpretResult<M>>
21where
22    H: ChoreoHandler<Role = R> + Send,
23    R: RoleId,
24    M: ProgramMessage + Serialize + DeserializeOwned + 'static,
25{
26    let mut interpreter: Interpreter<M, R> = Interpreter::new();
27    interpreter.run(handler, endpoint, &program).await
28}
29
30/// Interpret a choreographic program using an extensible handler
31///
32/// This version supports handlers that implement `ExtensibleHandler` and can
33/// dispatch extension effects to registered handlers.
34pub async fn interpret_extensible<H, R, M>(
35    handler: &mut H,
36    endpoint: &mut H::Endpoint,
37    program: Program<R, M>,
38) -> ChoreoResult<InterpretResult<M>>
39where
40    H: ExtensibleHandler<Role = R> + Send,
41    R: RoleId,
42    M: ProgramMessage + Serialize + DeserializeOwned + 'static,
43{
44    let mut interpreter: ExtensibleInterpreter<M, R> = ExtensibleInterpreter::new();
45    interpreter.run(handler, endpoint, &program).await
46}
47
48/// Internal interpreter state
49struct Interpreter<M, R: RoleId> {
50    received_values: Vec<M>,
51    /// Track the last received label from an Offer effect
52    last_label: Option<<R as RoleId>::Label>,
53}
54
55enum ControlFrame<'a, R: RoleId, M> {
56    Effects {
57        effects: &'a [Effect<R, M>],
58        index: usize,
59    },
60    SequentialPrograms {
61        programs: &'a [Program<R, M>],
62        index: usize,
63    },
64    Repeat {
65        body: &'a Program<R, M>,
66        remaining: usize,
67    },
68    ClearLastLabel,
69}
70
71impl<M, R: RoleId> Interpreter<M, R> {
72    fn new() -> Self {
73        Self {
74            received_values: Vec::new(),
75            last_label: None,
76        }
77    }
78
79    #[async_recursion]
80    async fn run<H>(
81        &mut self,
82        handler: &mut H,
83        endpoint: &mut H::Endpoint,
84        program: &Program<R, M>,
85    ) -> ChoreoResult<InterpretResult<M>>
86    where
87        H: ChoreoHandler<Role = R> + Send,
88        M: ProgramMessage + Serialize + DeserializeOwned + 'static,
89    {
90        let start_len = self.received_values.len();
91        let final_state = match self.execute_program(handler, endpoint, program).await {
92            Ok(()) => InterpreterState::Completed,
93            Err(ChoreographyError::Timeout(_)) => InterpreterState::Timeout,
94            Err(e) => InterpreterState::Failed(e.to_string()),
95        };
96
97        Ok(InterpretResult {
98            received_values: self.received_values_since(start_len),
99            final_state,
100        })
101    }
102
103    async fn execute_program<H>(
104        &mut self,
105        handler: &mut H,
106        endpoint: &mut H::Endpoint,
107        program: &Program<R, M>,
108    ) -> ChoreoResult<()>
109    where
110        H: ChoreoHandler<Role = R> + Send,
111        M: ProgramMessage + Serialize + DeserializeOwned + 'static,
112    {
113        let mut stack = vec![ControlFrame::Effects {
114            effects: program.effects(),
115            index: 0,
116        }];
117
118        while let Some(frame) = stack.pop() {
119            match frame {
120                ControlFrame::Effects { effects, index } => {
121                    if index >= effects.len() {
122                        continue;
123                    }
124
125                    stack.push(ControlFrame::Effects {
126                        effects,
127                        index: index + 1,
128                    });
129                    self.execute_base_effect(handler, endpoint, &effects[index], &mut stack)
130                        .await?;
131                }
132                ControlFrame::SequentialPrograms { programs, index } => {
133                    if let Some(program) = programs.get(index) {
134                        stack.push(ControlFrame::SequentialPrograms {
135                            programs,
136                            index: index + 1,
137                        });
138                        stack.push(ControlFrame::Effects {
139                            effects: program.effects(),
140                            index: 0,
141                        });
142                    }
143                }
144                ControlFrame::Repeat { body, remaining } => {
145                    if remaining > 0 {
146                        stack.push(ControlFrame::Repeat {
147                            body,
148                            remaining: remaining - 1,
149                        });
150                        stack.push(ControlFrame::Effects {
151                            effects: body.effects(),
152                            index: 0,
153                        });
154                    }
155                }
156                ControlFrame::ClearLastLabel => {
157                    self.last_label = None;
158                }
159            }
160        }
161
162        Ok(())
163    }
164
165    #[allow(clippy::too_many_lines)]
166    async fn execute_base_effect<'a, H>(
167        &mut self,
168        handler: &mut H,
169        endpoint: &mut H::Endpoint,
170        effect: &'a Effect<R, M>,
171        stack: &mut Vec<ControlFrame<'a, R, M>>,
172    ) -> ChoreoResult<()>
173    where
174        H: ChoreoHandler<Role = R> + Send,
175        M: ProgramMessage + Serialize + DeserializeOwned + 'static,
176    {
177        match effect {
178            Effect::Send { to, msg } => {
179                handler.send(endpoint, *to, msg).await?;
180            }
181            Effect::Recv { from, msg_tag } => {
182                tracing::debug!(
183                    ?from,
184                    msg_type = msg_tag.type_name(),
185                    "recv effect - type casting required"
186                );
187                let value = self
188                    .try_recv_as_type::<H, M>(handler, endpoint, *from)
189                    .await?;
190                self.received_values.push(value);
191            }
192            Effect::Choose { at, label } => {
193                handler.choose(endpoint, *at, *label).await?;
194                self.last_label = Some(*label);
195            }
196            Effect::Offer { from } => {
197                let label = handler.offer(endpoint, *from).await?;
198                tracing::debug!(?from, ?label, "Received offer label");
199                self.last_label = Some(label);
200            }
201            Effect::Branch {
202                choosing_role,
203                branches,
204            } => {
205                tracing::debug!(
206                    ?choosing_role,
207                    branch_count = branches.len(),
208                    "Executing branch effect"
209                );
210
211                let label = self.last_label.ok_or_else(|| {
212                    ChoreographyError::ProtocolViolation(
213                        "Branch effect requires a preceding Choose or Offer effect".to_string(),
214                    )
215                })?;
216
217                let (_, selected_branch) = branches
218                    .iter()
219                    .find(|(branch_label, _)| branch_label == &label)
220                    .ok_or_else(|| {
221                        ChoreographyError::ProtocolViolation(format!(
222                            "No branch found for label {label:?}"
223                        ))
224                    })?;
225
226                tracing::debug!(selected_label = ?label, "Executing selected branch");
227                stack.push(ControlFrame::ClearLastLabel);
228                stack.push(ControlFrame::Effects {
229                    effects: selected_branch.effects(),
230                    index: 0,
231                });
232            }
233            Effect::Loop { iterations, body } => {
234                tracing::debug!(?iterations, "Executing loop effect");
235                stack.push(ControlFrame::Repeat {
236                    body,
237                    remaining: iterations.unwrap_or(1),
238                });
239            }
240            Effect::Timeout {
241                at,
242                dur,
243                body,
244                on_timeout,
245            } => {
246                tracing::debug!(
247                    ?at,
248                    ?dur,
249                    has_fallback = on_timeout.is_some(),
250                    "Executing timeout effect"
251                );
252
253                cfg_if! {
254                    if #[cfg(target_arch = "wasm32")] {
255                        let timeout_result = {
256                            use futures::future::{select, Either};
257                            use futures::pin_mut;
258                            use wasm_timer::Delay;
259
260                            let body_future = self.run(handler, endpoint, body);
261                            let timeout = Delay::new(*dur);
262                            pin_mut!(body_future);
263                            pin_mut!(timeout);
264
265                            match select(body_future, timeout).await {
266                                Either::Left((result, _)) => Ok(result),
267                                Either::Right(_) => Err(()),
268                            }
269                        };
270                    } else {
271                        let timeout_result =
272                            tokio::time::timeout(*dur, self.run(handler, endpoint, body)).await;
273                    }
274                }
275
276                match timeout_result {
277                    Ok(Ok(result)) => self.propagate_nested_result(result, *dur)?,
278                    Ok(Err(err)) => return Err(err),
279                    Err(_) => {
280                        if let Some(timeout_body) = on_timeout {
281                            tracing::debug!("Timeout fired, executing fallback program");
282                            let result = self.run(handler, endpoint, timeout_body).await?;
283                            self.propagate_nested_result(result, *dur)?;
284                        } else {
285                            return Err(ChoreographyError::Timeout(*dur));
286                        }
287                    }
288                }
289            }
290            Effect::Parallel { programs } => {
291                tracing::debug!(program_count = programs.len(), "Executing parallel effect");
292                stack.push(ControlFrame::SequentialPrograms { programs, index: 0 });
293            }
294            Effect::Extension(ext) => {
295                tracing::debug!(
296                    type_name = ext.type_name(),
297                    type_id = ?ext.type_id(),
298                    "Executing extension effect"
299                );
300                tracing::warn!(
301                    "Extension effect encountered but handler does not support extensions. \
302                     Use interpret_extensible() for handlers with extension support."
303                );
304            }
305            Effect::End => {}
306        }
307
308        Ok(())
309    }
310
311    fn received_values_since(&self, start_len: usize) -> Vec<M>
312    where
313        M: ProgramMessage,
314    {
315        self.received_values[start_len..].to_vec()
316    }
317
318    async fn try_recv_as_type<H, T>(
319        &mut self,
320        handler: &mut H,
321        endpoint: &mut H::Endpoint,
322        from: R,
323    ) -> ChoreoResult<T>
324    where
325        H: ChoreoHandler<Role = R>,
326        T: DeserializeOwned + Send,
327    {
328        handler.recv(endpoint, from).await
329    }
330
331    fn propagate_nested_result(
332        &self,
333        result: InterpretResult<M>,
334        timeout: std::time::Duration,
335    ) -> ChoreoResult<()>
336    where
337        M: ProgramMessage,
338    {
339        match result.final_state {
340            InterpreterState::Completed => Ok(()),
341            InterpreterState::Failed(msg) => Err(ChoreographyError::Transport(msg)),
342            InterpreterState::Timeout => Err(ChoreographyError::Timeout(timeout)),
343        }
344    }
345}
346
347/// Extensible interpreter that supports extension effects
348struct ExtensibleInterpreter<M, R: RoleId> {
349    base: Interpreter<M, R>,
350}
351
352impl<M, R: RoleId> ExtensibleInterpreter<M, R> {
353    fn new() -> Self {
354        Self {
355            base: Interpreter::new(),
356        }
357    }
358
359    #[async_recursion]
360    async fn run<H>(
361        &mut self,
362        handler: &mut H,
363        endpoint: &mut H::Endpoint,
364        program: &Program<R, M>,
365    ) -> ChoreoResult<InterpretResult<M>>
366    where
367        H: ExtensibleHandler<Role = R> + Send,
368        M: ProgramMessage + Serialize + DeserializeOwned + 'static,
369    {
370        let start_len = self.base.received_values.len();
371        let final_state = match self.execute_program(handler, endpoint, program).await {
372            Ok(()) => InterpreterState::Completed,
373            Err(ChoreographyError::Timeout(_)) => InterpreterState::Timeout,
374            Err(e) => InterpreterState::Failed(e.to_string()),
375        };
376
377        Ok(InterpretResult {
378            received_values: self.base.received_values_since(start_len),
379            final_state,
380        })
381    }
382
383    async fn execute_program<H>(
384        &mut self,
385        handler: &mut H,
386        endpoint: &mut H::Endpoint,
387        program: &Program<R, M>,
388    ) -> ChoreoResult<()>
389    where
390        H: ExtensibleHandler<Role = R> + Send,
391        M: ProgramMessage + Serialize + DeserializeOwned + 'static,
392    {
393        let mut stack = vec![ControlFrame::Effects {
394            effects: program.effects(),
395            index: 0,
396        }];
397
398        while let Some(frame) = stack.pop() {
399            match frame {
400                ControlFrame::Effects { effects, index } => {
401                    if index >= effects.len() {
402                        continue;
403                    }
404
405                    stack.push(ControlFrame::Effects {
406                        effects,
407                        index: index + 1,
408                    });
409                    match &effects[index] {
410                        Effect::Extension(ext) => {
411                            tracing::debug!(
412                                type_name = ext.type_name(),
413                                type_id = ?ext.type_id(),
414                                "Dispatching extension effect to handler"
415                            );
416                            handler
417                                .extension_registry()
418                                .handle(endpoint, ext.as_ref())
419                                .await
420                                .map_err(|e| ChoreographyError::Transport(e.to_string()))?;
421                        }
422                        effect => {
423                            self.base
424                                .execute_base_effect(handler, endpoint, effect, &mut stack)
425                                .await?;
426                        }
427                    }
428                }
429                ControlFrame::SequentialPrograms { programs, index } => {
430                    if let Some(program) = programs.get(index) {
431                        stack.push(ControlFrame::SequentialPrograms {
432                            programs,
433                            index: index + 1,
434                        });
435                        stack.push(ControlFrame::Effects {
436                            effects: program.effects(),
437                            index: 0,
438                        });
439                    }
440                }
441                ControlFrame::Repeat { body, remaining } => {
442                    if remaining > 0 {
443                        stack.push(ControlFrame::Repeat {
444                            body,
445                            remaining: remaining - 1,
446                        });
447                        stack.push(ControlFrame::Effects {
448                            effects: body.effects(),
449                            index: 0,
450                        });
451                    }
452                }
453                ControlFrame::ClearLastLabel => {
454                    self.base.last_label = None;
455                }
456            }
457        }
458
459        Ok(())
460    }
461}
462
463/// Extension trait to add program interpretation to handlers
464#[async_trait]
465pub trait ChoreoHandlerExt: ChoreoHandler + Sized {
466    /// Run a choreographic program using this handler
467    async fn run_program<M>(
468        &mut self,
469        endpoint: &mut Self::Endpoint,
470        program: Program<Self::Role, M>,
471    ) -> ChoreoResult<InterpretResult<M>>
472    where
473        M: ProgramMessage + Serialize + DeserializeOwned + 'static,
474        Self: Send,
475    {
476        interpret(self, endpoint, program).await
477    }
478}
479
480// Blanket implementation for all ChoreoHandlers
481impl<T: ChoreoHandler> ChoreoHandlerExt for T {}
482/// Utilities for testing and simulation
483#[path = "interpreter_testing.rs"]
484pub mod testing;
485
// Unit tests live in a separate file and are textually included into this
// child module, so they can reach this module's private items via `super::`.
// They are compiled only for non-wasm32 test builds.
#[cfg(all(not(target_arch = "wasm32"), test))]
mod tests {
    include!("../../tests/unit/effects/interpreter_tests.rs");
}