Skip to main content

intent_runtime/
test_runner.rs

//! Test runner for IntentLang spec-level tests.
//!
//! Converts `test` blocks from the AST into runtime `ActionRequest`s,
//! executes them, and checks each test's `then` clause (an expected failure
//! or a list of assertions) against the results.
5
6use std::collections::HashMap;
7
8use intent_ir::Module;
9use intent_parser::ast::{self, ExprKind, GivenValue, Literal, ThenClause};
10use serde_json::Value;
11
12use crate::contract::{ActionRequest, ViolationKind, execute_action};
13use crate::error::RuntimeError;
14use crate::eval::evaluate;
15use crate::value::EvalContext;
16
17/// Result of running a single test.
18#[derive(Debug)]
19pub struct TestResult {
20    pub name: String,
21    pub passed: bool,
22    pub message: Option<String>,
23}
24
25/// Run all test declarations against a compiled IR module.
26pub fn run_tests(module: &Module, tests: &[&ast::TestDecl]) -> Vec<TestResult> {
27    tests.iter().map(|t| run_single_test(module, t)).collect()
28}
29
30fn run_single_test(module: &Module, test: &ast::TestDecl) -> TestResult {
31    match run_single_test_inner(module, test) {
32        Ok(result) => result,
33        Err(e) => TestResult {
34            name: test.name.clone(),
35            passed: false,
36            message: Some(format!("runtime error: {e}")),
37        },
38    }
39}
40
41fn run_single_test_inner(
42    module: &Module,
43    test: &ast::TestDecl,
44) -> Result<TestResult, RuntimeError> {
45    // 1. Evaluate given bindings to concrete values.
46    let mut bindings: HashMap<String, Value> = HashMap::new();
47    let mut state: HashMap<String, Vec<Value>> = HashMap::new();
48
49    for binding in &test.given {
50        match &binding.value {
51            GivenValue::EntityConstructor { type_name, fields } => {
52                let obj = fields_to_json(fields, &bindings)?;
53                bindings.insert(binding.name.clone(), obj.clone());
54                state.entry(type_name.clone()).or_default().push(obj);
55            }
56            GivenValue::Expr(expr) => {
57                let val = ast_expr_to_value(expr, &bindings)?;
58                bindings.insert(binding.name.clone(), val);
59            }
60        }
61    }
62
63    // 2. Build ActionRequest from when block.
64    let mut params: HashMap<String, Value> = HashMap::new();
65    for arg in &test.when_action.args {
66        let val = ast_expr_to_value(&arg.value, &bindings)?;
67        params.insert(arg.name.clone(), val);
68    }
69
70    let request = ActionRequest {
71        action: test.when_action.action_name.clone(),
72        params,
73        state,
74    };
75
76    // 3. Execute.
77    let result = execute_action(module, &request)?;
78
79    // 4. Check then clause.
80    match &test.then {
81        ThenClause::Fails(kind_filter, _) => {
82            if result.ok {
83                return Ok(TestResult {
84                    name: test.name.clone(),
85                    passed: false,
86                    message: Some("expected action to fail, but it succeeded".into()),
87                });
88            }
89            // If a specific kind is requested, check for it.
90            if let Some(kind_str) = kind_filter {
91                let expected_kind = match kind_str.as_str() {
92                    "precondition" => Some(ViolationKind::PreconditionFailed),
93                    "postcondition" => Some(ViolationKind::PostconditionFailed),
94                    "invariant" => Some(ViolationKind::InvariantViolated),
95                    "edge_guard" => Some(ViolationKind::EdgeGuardTriggered),
96                    _ => None,
97                };
98                if let Some(kind) = expected_kind
99                    && !result.violations.iter().any(|v| v.kind == kind)
100                {
101                    return Ok(TestResult {
102                        name: test.name.clone(),
103                        passed: false,
104                        message: Some(format!(
105                            "expected {kind_str} violation, got: {}",
106                            result
107                                .violations
108                                .iter()
109                                .map(|v| format!("{:?}", v.kind))
110                                .collect::<Vec<_>>()
111                                .join(", ")
112                        )),
113                    });
114                }
115            }
116            Ok(TestResult {
117                name: test.name.clone(),
118                passed: true,
119                message: None,
120            })
121        }
122        ThenClause::Asserts(exprs, _) => {
123            if !result.ok {
124                let msgs: Vec<_> = result
125                    .violations
126                    .iter()
127                    .map(|v| v.message.clone())
128                    .collect();
129                return Ok(TestResult {
130                    name: test.name.clone(),
131                    passed: false,
132                    message: Some(format!("action failed: {}", msgs.join("; "))),
133                });
134            }
135
136            // Merge new_params over given bindings for assertion context.
137            let mut assert_bindings = bindings;
138            for (k, v) in &result.new_params {
139                assert_bindings.insert(k.clone(), v.clone());
140            }
141
142            // Lower assertion expressions to IR and evaluate.
143            for expr in exprs {
144                let ir_expr = intent_ir::lower_expr(expr);
145                let ctx = EvalContext {
146                    bindings: assert_bindings.clone(),
147                    old_bindings: None,
148                    instances: request.state.clone(),
149                };
150                match evaluate(&ir_expr, &ctx) {
151                    Ok(Value::Bool(true)) => {}
152                    Ok(Value::Bool(false)) => {
153                        return Ok(TestResult {
154                            name: test.name.clone(),
155                            passed: false,
156                            message: Some(format!("assertion failed: {}", fmt_ast_expr(expr))),
157                        });
158                    }
159                    Ok(other) => {
160                        return Ok(TestResult {
161                            name: test.name.clone(),
162                            passed: false,
163                            message: Some(format!(
164                                "assertion did not evaluate to bool: {} (got {other:?})",
165                                fmt_ast_expr(expr)
166                            )),
167                        });
168                    }
169                    Err(e) => {
170                        return Ok(TestResult {
171                            name: test.name.clone(),
172                            passed: false,
173                            message: Some(format!(
174                                "assertion error: {} ({})",
175                                fmt_ast_expr(expr),
176                                e
177                            )),
178                        });
179                    }
180                }
181            }
182
183            Ok(TestResult {
184                name: test.name.clone(),
185                passed: true,
186                message: None,
187            })
188        }
189    }
190}
191
192/// Convert entity constructor fields to a JSON object.
193fn fields_to_json(
194    fields: &[ast::ConstructorField],
195    bindings: &HashMap<String, Value>,
196) -> Result<Value, RuntimeError> {
197    let mut map = serde_json::Map::new();
198    for field in fields {
199        let val = ast_expr_to_value(&field.value, bindings)?;
200        map.insert(field.name.clone(), val);
201    }
202    Ok(Value::Object(map))
203}
204
205/// Convert a concrete AST expression to a JSON value.
206///
207/// Only handles literal values and identifier references — no complex
208/// expressions (quantifiers, old(), etc.) are valid in test given/when blocks.
209fn ast_expr_to_value(
210    expr: &ast::Expr,
211    bindings: &HashMap<String, Value>,
212) -> Result<Value, RuntimeError> {
213    match &expr.kind {
214        ExprKind::Literal(lit) => match lit {
215            Literal::Null => Ok(Value::Null),
216            Literal::Bool(b) => Ok(Value::Bool(*b)),
217            Literal::Int(n) => Ok(serde_json::json!(*n as f64)),
218            Literal::Decimal(s) => {
219                let n: f64 = s
220                    .parse()
221                    .map_err(|_| RuntimeError::DecimalError(s.clone()))?;
222                Ok(serde_json::json!(n))
223            }
224            Literal::String(s) => Ok(Value::String(s.clone())),
225        },
226        ExprKind::Ident(name) => {
227            // Check bindings first (references to given variables).
228            if let Some(val) = bindings.get(name) {
229                return Ok(val.clone());
230            }
231            // Uppercase identifiers are union variant labels.
232            if name.starts_with(|c: char| c.is_uppercase()) {
233                return Ok(Value::String(name.clone()));
234            }
235            Err(RuntimeError::UnboundVariable(name.clone()))
236        }
237        ExprKind::List(items) => {
238            let vals: Result<Vec<Value>, _> = items
239                .iter()
240                .map(|e| ast_expr_to_value(e, bindings))
241                .collect();
242            Ok(Value::Array(vals?))
243        }
244        ExprKind::Arithmetic { left, op, right } => {
245            let l = ast_expr_to_value(left, bindings)?;
246            let r = ast_expr_to_value(right, bindings)?;
247            let lf = as_f64(&l)?;
248            let rf = as_f64(&r)?;
249            let result = match op {
250                ast::ArithOp::Add => lf + rf,
251                ast::ArithOp::Sub => lf - rf,
252            };
253            Ok(serde_json::json!(result))
254        }
255        _ => Err(RuntimeError::UnboundVariable(
256            "<unsupported expression in test>".into(),
257        )),
258    }
259}
260
261fn as_f64(val: &Value) -> Result<f64, RuntimeError> {
262    val.as_f64().ok_or(RuntimeError::TypeError {
263        expected: "number".into(),
264        got: format!("{val:?}"),
265    })
266}
267
268/// Simple AST expression formatter for error messages.
269fn fmt_ast_expr(expr: &ast::Expr) -> String {
270    match &expr.kind {
271        ExprKind::Ident(name) => name.clone(),
272        ExprKind::Literal(lit) => match lit {
273            Literal::Null => "null".into(),
274            Literal::Bool(b) => b.to_string(),
275            Literal::Int(n) => n.to_string(),
276            Literal::Decimal(s) => s.clone(),
277            Literal::String(s) => format!("\"{s}\""),
278        },
279        ExprKind::FieldAccess { root, fields } => {
280            format!("{}.{}", fmt_ast_expr(root), fields.join("."))
281        }
282        ExprKind::Compare { left, op, right } => {
283            let op_str = match op {
284                ast::CmpOp::Eq => "==",
285                ast::CmpOp::Ne => "!=",
286                ast::CmpOp::Lt => "<",
287                ast::CmpOp::Gt => ">",
288                ast::CmpOp::Le => "<=",
289                ast::CmpOp::Ge => ">=",
290            };
291            format!("{} {op_str} {}", fmt_ast_expr(left), fmt_ast_expr(right))
292        }
293        _ => "...".into(),
294    }
295}
296
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `src`, lowers it to IR, and runs every `test` declaration it contains.
    fn run_spec(src: &str) -> Vec<TestResult> {
        let file = intent_parser::parse_file(src).unwrap();
        let ir = intent_ir::lower_file(&file);
        let decls: Vec<&ast::TestDecl> = file
            .items
            .iter()
            .filter_map(|item| match item {
                ast::TopLevelItem::Test(decl) => Some(decl),
                _ => None,
            })
            .collect();
        run_tests(&ir, &decls)
    }

    /// Asserts that every result passed, reporting the failing index and message.
    fn assert_all_passed(results: &[TestResult]) {
        for (idx, result) in results.iter().enumerate() {
            assert!(result.passed, "test {idx} failed: {:?}", result.message);
        }
    }

    #[test]
    fn test_runner_pass_via_parse() {
        let spec = r#"module Test

entity Account {
  id: UUID
  balance: Decimal(precision: 2)
  status: Active | Frozen
}

action Transfer {
  from: Account
  to: Account
  amount: Decimal(precision: 2)

  requires {
    from.status == Active
    amount > 0
    from.balance >= amount
  }

  ensures {
    from.balance == old(from.balance) - amount
    to.balance == old(to.balance) + amount
  }
}

test "successful transfer" {
  given {
    from = Account { id: "1", balance: 1000.0, status: Active }
    to = Account { id: "2", balance: 500.0, status: Active }
  }
  when Transfer {
    from: from,
    to: to,
    amount: 200.0
  }
  then {
    from.balance == 800.0
    to.balance == 700.0
  }
}

test "insufficient funds" {
  given {
    from = Account { id: "1", balance: 50.0, status: Active }
    to = Account { id: "2", balance: 500.0, status: Active }
  }
  when Transfer {
    from: from,
    to: to,
    amount: 200.0
  }
  then fails
}
"#;
        let results = run_spec(spec);
        assert_eq!(results.len(), 2);
        assert_all_passed(&results);
    }

    #[test]
    fn test_runner_then_fails_precondition() {
        let spec = r#"module Test

entity Account {
  id: UUID
  balance: Decimal(precision: 2)
  status: Active | Frozen
}

action Transfer {
  from: Account
  to: Account
  amount: Decimal(precision: 2)

  requires {
    from.status == Active
    amount > 0
    from.balance >= amount
  }

  ensures {
    from.balance == old(from.balance) - amount
    to.balance == old(to.balance) + amount
  }
}

test "frozen account" {
  given {
    from = Account { id: "1", balance: 1000.0, status: Frozen }
    to = Account { id: "2", balance: 500.0, status: Active }
  }
  when Transfer {
    from: from,
    to: to,
    amount: 200.0
  }
  then fails precondition
}
"#;
        let results = run_spec(spec);
        assert_eq!(results.len(), 1);
        assert_all_passed(&results);
    }
}