telltale-machine 17.0.0

Protocol machine for choreographic session type protocols
Documentation
//! WebAssembly bindings for the protocol machine guest runtime.

use std::collections::BTreeMap;

use serde::Deserialize;
use wasm_bindgen::prelude::*;

use telltale_types::{GlobalType, LocalTypeR};

use crate::coroutine::Value;
use crate::driver::WasmCooperativeDriver;
use crate::effect::{EffectFailure, EffectHandler, EffectResult};
use crate::engine::{ObsEvent, ProtocolMachineConfig, RunStatus, StepResult};
use crate::loader::CodeImage;
use crate::trace::{normalize_trace, strict_trace};

/// JSON payload accepted by `WasmProtocolMachine::load_choreography_json`.
///
/// Expected shape: `{ "local_types": { ... }, "global_type": { ... } }`.
#[derive(Deserialize, Debug)]
struct WasmChoreoSpec {
    /// Local session types keyed by string (presumably role names — confirm
    /// against `CodeImage::from_local_types`).
    local_types: BTreeMap<String, LocalTypeR>,
    /// Global type describing the whole choreography.
    global_type: GlobalType,
}

/// Effect handler that performs no external effects.
///
/// Sends produce `Value::Unit`, receives and local steps succeed without
/// touching the supplied state, and choices select the first offered label.
struct NoOpHandler;

impl EffectHandler for NoOpHandler {
    /// Always succeeds, yielding `Value::Unit` as the send result.
    fn handle_send(
        &self,
        _role: &str,
        _partner: &str,
        _label: &str,
        _state: &[Value],
    ) -> EffectResult<Value> {
        let unit = Value::Unit;
        EffectResult::success(unit)
    }

    /// Always succeeds; neither the state nor the payload is inspected.
    fn handle_recv(
        &self,
        _role: &str,
        _partner: &str,
        _label: &str,
        _state: &mut Vec<Value>,
        _payload: &Value,
    ) -> EffectResult<()> {
        EffectResult::success(())
    }

    /// Selects the first label in `labels`.
    ///
    /// An empty label list is reported as an invalid-input failure.
    fn handle_choose(
        &self,
        _role: &str,
        _partner: &str,
        labels: &[String],
        _state: &[Value],
    ) -> EffectResult<String> {
        if let Some(chosen) = labels.first() {
            EffectResult::success(chosen.clone())
        } else {
            EffectResult::failure(EffectFailure::invalid_input("no labels available"))
        }
    }

    /// Always succeeds without modifying the state.
    fn step(&self, _role: &str, _state: &mut Vec<Value>) -> EffectResult<()> {
        EffectResult::success(())
    }
}

/// Wasm wrapper for the protocol machine guest runtime.
#[wasm_bindgen]
pub struct WasmProtocolMachine {
    inner: WasmCooperativeDriver,
}

#[wasm_bindgen]
impl WasmProtocolMachine {
    /// Create a new guest runtime with default configuration.
    #[wasm_bindgen(constructor)]
    pub fn new() -> WasmProtocolMachine {
        WasmProtocolMachine {
            inner: WasmCooperativeDriver::new(ProtocolMachineConfig::default()),
        }
    }

    /// Load a choreography from JSON.
    ///
    /// The JSON format is `{ "local_types": { ... }, "global_type": { ... } }`.
    pub fn load_choreography_json(&mut self, json: &str) -> Result<usize, JsValue> {
        let spec: WasmChoreoSpec = serde_json::from_str(json)
            .map_err(|e| JsValue::from_str(&format!("invalid json: {e}")))?;
        let image = CodeImage::from_local_types(&spec.local_types, &spec.global_type);
        let owned = self
            .inner
            .load_choreography_owned(&image, "wasm/host")
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
        Ok(owned.session_id())
    }

    /// Execute one scheduler round with concurrency `n`.
    pub fn step_round(&mut self, n: usize) -> Result<String, JsValue> {
        let handler = NoOpHandler;
        let result = self
            .inner
            .step_round(&handler, n)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
        let label = match result {
            StepResult::Continue => "continue",
            StepResult::Stuck => "stuck",
            StepResult::AllDone => "all_done",
        };
        Ok(label.to_string())
    }

    /// Run the guest runtime for at most `max_rounds` with concurrency `n`.
    pub fn run(&mut self, max_rounds: usize, n: usize) -> Result<String, JsValue> {
        let handler = NoOpHandler;
        let status = self
            .inner
            .run(&handler, max_rounds, n)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
        let label = match status {
            RunStatus::AllDone => "all_done",
            RunStatus::Stuck => "stuck",
            RunStatus::MaxRoundsExceeded => "max_rounds_exceeded",
        };
        Ok(label.to_string())
    }

    /// Get the raw observable trace as JSON.
    pub fn trace_json(&self) -> Result<String, JsValue> {
        let trace: Vec<ObsEvent> = strict_trace(self.inner.trace());
        serde_json::to_string(&trace).map_err(|e| JsValue::from_str(&e.to_string()))
    }

    /// Get the session-local normalized trace as JSON.
    pub fn trace_normalized_json(&self) -> Result<String, JsValue> {
        let trace = normalize_trace(self.inner.trace());
        serde_json::to_string(&trace).map_err(|e| JsValue::from_str(&e.to_string()))
    }

    /// Get canonical semantic objects as JSON.
    pub fn semantic_objects_json(&self) -> Result<String, JsValue> {
        serde_json::to_string(&self.inner.machine().semantic_objects())
            .map_err(|e| JsValue::from_str(&e.to_string()))
    }
}