tla_connect/driver.rs
1//! Core abstractions for connecting Rust implementations to TLA+ specs.
2//!
3//! Mirrors quint-connect's `Driver`/`State`/`Step` pattern, adapted for
4//! TLA+ ITF traces produced by Apalache.
5//!
6//! # Example
7//!
8//! ```
9//! use tla_connect::{Driver, State, ExtractState, Step, DriverError, switch};
10//! use serde::Deserialize;
11//!
12//! #[derive(Debug, PartialEq, Deserialize)]
13//! struct CounterState {
14//! counter: i64,
15//! }
16//!
17//! struct CounterDriver {
18//! value: i64,
19//! }
20//!
21//! impl State for CounterState {}
22//!
23//! impl ExtractState<CounterDriver> for CounterState {
24//! fn from_driver(driver: &CounterDriver) -> Result<Self, DriverError> {
25//! Ok(CounterState { counter: driver.value })
26//! }
27//! }
28//!
29//! impl Driver for CounterDriver {
30//! type State = CounterState;
31//!
32//! fn step(&mut self, step: &Step) -> Result<(), DriverError> {
33//! switch!(step {
34//! "init" => { self.value = 0; Ok(()) },
35//! "increment" => { self.value += 1; Ok(()) },
36//! })
37//! }
38//! }
39//! ```
40
41use crate::error::DriverError;
42use serde::de::DeserializeOwned;
43use similar::{ChangeTag, TextDiff};
44use std::fmt::Debug;
45
46/// A single step from an Apalache-generated ITF trace.
47///
48/// Each ITF state record contains the TLA+ variables plus auxiliary MBT
49/// variables (`action_taken`, `nondet_picks`) that identify which action
50/// was taken and any nondeterministic choices.
51#[derive(Debug, Clone)]
52#[non_exhaustive]
53pub struct Step {
54 /// The TLA+ action that was taken (e.g., "request_success", "tick").
55 pub action_taken: String,
56
57 /// Nondeterministic picks made by this step (ITF Value for proper type handling).
58 pub nondet_picks: itf::Value,
59
60 /// Full TLA+ state after this step – an `itf::Value::Record` containing
61 /// all state variables. Used for state comparison via `State::from_spec`.
62 pub state: itf::Value,
63}
64
65/// Core trait for connecting Rust implementations to TLA+ specs.
66///
67/// Implementors hold the Rust type under test and map TLA+ actions
68/// to Rust method calls via `step()`.
69///
70/// # Parallel replay
71///
72/// When using [`replay_traces_parallel`](crate::replay_traces_parallel)
73/// (requires the `parallel` feature), your `Driver` must also implement
74/// `Send`. Each trace is replayed in its own thread, so the driver
75/// factory closure must be `Sync` and the resulting driver must be `Send`.
76pub trait Driver: Sized {
77 /// The state type used for comparing TLA+ spec state with Rust state.
78 type State: State + ExtractState<Self>;
79
80 /// Execute a single step from the TLA+ trace on the Rust implementation.
81 ///
82 /// Use the `switch!` macro to dispatch on `step.action_taken`.
83 fn step(&mut self, step: &Step) -> Result<(), DriverError>;
84}
85
86/// State comparison between TLA+ spec and Rust implementation.
87///
88/// Deserializes from ITF `Value` (spec side). Only include fields that should
89/// be compared – intentionally exclude fields where spec and implementation
90/// have valid semantic differences.
91pub trait State: PartialEq + DeserializeOwned + Debug {
92 /// Deserialize the spec state from an ITF Value.
93 ///
94 /// The default implementation uses serde deserialization via `itf::Value`,
95 /// which transparently handles ITF-specific encodings (`#bigint`, `#set`, etc.).
96 ///
97 /// Note: The default implementation clones the `itf::Value` because serde's
98 /// `DeserializeOwned` trait requires ownership. This clone happens once per
99 /// state per trace and may be significant for large state records. Override
100 /// this method if you need to avoid the clone (e.g., by deserializing
101 /// specific fields manually).
102 fn from_spec(value: &itf::Value) -> Result<Self, DriverError> {
103 Self::deserialize(value.clone()).map_err(|e| DriverError::StateExtraction(e.to_string()))
104 }
105
106 /// Generate a human-readable diff between two states.
107 ///
108 /// The default implementation uses Debug formatting with unified diff.
109 /// Override this for custom diff output (e.g., field-by-field comparison).
110 fn diff(&self, other: &Self) -> String {
111 use std::fmt::Write;
112 let self_str = format!("{self:#?}");
113 let other_str = format!("{other:#?}");
114
115 let mut output = String::new();
116 let self_lines: Vec<&str> = self_str.lines().collect();
117 let other_lines: Vec<&str> = other_str.lines().collect();
118
119 for (i, (a, b)) in self_lines.iter().zip(other_lines.iter()).enumerate() {
120 if a != b {
121 let _ = writeln!(output, " line {}: {} -> {}", i + 1, a.trim(), b.trim());
122 }
123 }
124
125 if self_lines.len() != other_lines.len() {
126 let _ = writeln!(
127 output,
128 " (line count differs: {} vs {})",
129 self_lines.len(),
130 other_lines.len()
131 );
132 }
133
134 if output.is_empty() {
135 output = "(states appear equal but PartialEq returned false)".to_string();
136 }
137
138 output
139 }
140}
141
142/// Extract the comparable state from the Rust driver.
143///
144/// Separated from [`State`] so that `State` does not require a generic
145/// parameter for the driver type, making it easier to use in contexts
146/// that only need deserialization and comparison.
147pub trait ExtractState<D>: State {
148 /// Extract the comparable state from the Rust driver.
149 fn from_driver(driver: &D) -> Result<Self, DriverError>;
150}
151
152/// Produce a unified diff between two strings.
153pub fn unified_diff(left: &str, right: &str) -> String {
154 let diff = TextDiff::from_lines(left, right);
155 let mut output = String::new();
156
157 for change in diff.iter_all_changes() {
158 let sign = match change.tag() {
159 ChangeTag::Delete => "-",
160 ChangeTag::Insert => "+",
161 ChangeTag::Equal => " ",
162 };
163 output.push_str(sign);
164 output.push_str(change.value());
165 if !change.value().ends_with('\n') {
166 output.push('\n');
167 }
168 }
169
170 output
171}
172
173/// Format a state mismatch between spec and driver states for error reporting.
174pub fn format_state_mismatch<S: State>(spec: &S, driver: &S) -> String {
175 let summary = spec.diff(driver);
176 let full = unified_diff(&format!("{spec:#?}"), &format!("{driver:#?}"));
177 format!("State differences:\n{summary}\n--- spec (TLA+)\n+++ driver (Rust)\n{full}")
178}
179
180/// Helper to create a unified diff between two Debug-formatted values.
181///
182/// Useful for implementing custom `State::diff` methods.
183pub fn debug_diff<T: Debug, U: Debug>(left: &T, right: &U) -> String {
184 let left_str = format!("{left:#?}");
185 let right_str = format!("{right:#?}");
186 unified_diff(&left_str, &right_str)
187}
188
/// Dispatch a TLA+ action to the corresponding Rust code.
///
/// Expands to a single flat `match` on `step.action_taken`, mapping each
/// TLA+ action name to the corresponding Rust code block. An action name that
/// matches no arm evaluates to `Err(DriverError::UnknownAction(name))`.
///
/// # Usage
///
/// The first argument must be a variable name (identifier) bound to a `&Step`.
/// Each arm body must evaluate to `Result<(), DriverError>`. Several action
/// names that share one implementation can be combined with `|`.
///
/// ```ignore
/// tla_connect::switch!(step {
///     "init" => { /* initialization */ Ok(()) },
///     "request_success" => { self.cb.record_success(); Ok(()) },
///     "tick" | "idle" => { let _ = self.cb.allows_request(); Ok(()) },
/// })
/// ```
#[macro_export]
macro_rules! switch {
    ($step:ident { $( $( $action:literal )|+ => $body:expr ),+ $(,)? }) => {{
        // The explicit annotation documents the `&Step` requirement and
        // surfaces a type mismatch at the call site.
        let __tla_step: &$crate::Step = $step;
        match __tla_step.action_taken.as_str() {
            $( $( $action )|+ => { $body }, )+
            other => Err($crate::DriverError::UnknownAction(other.to_string())),
        }
    }};
}