use std::collections::HashMap;

use intent_ir::{CmpOp, Function, IrExpr, Module, Postcondition};
use serde::{Deserialize, Serialize};
use serde_json::Value;

use crate::error::RuntimeError;
use crate::eval::evaluate;
use crate::value::EvalContext;
10
11#[derive(Debug, Clone, Deserialize)]
13pub struct ActionRequest {
14 #[serde(default)]
16 pub action: String,
17 pub params: HashMap<String, Value>,
19 pub state: HashMap<String, Vec<Value>>,
21}
22
/// Outcome of [`execute_action`].
#[derive(Debug, Clone, Serialize)]
pub struct ActionResult {
    /// `true` exactly when `violations` is empty.
    pub ok: bool,
    /// Parameters after postcondition assignments have been applied. When the
    /// request is rejected by a precondition or an edge guard, these are the
    /// request's parameters unchanged.
    pub new_params: HashMap<String, Value>,
    /// Every violation found; empty on success.
    pub violations: Vec<Violation>,
}
33
/// A single contract violation reported by [`execute_action`].
#[derive(Debug, Clone, Serialize, PartialEq)]
pub struct Violation {
    /// Which kind of check produced this violation.
    pub kind: ViolationKind,
    /// Human-readable description of the failure, e.g. the rendered expression,
    /// the invariant name, or the evaluation error.
    pub message: String,
}
40
/// Category of a [`Violation`]. Serialized in `snake_case`
/// (e.g. `"precondition_failed"`).
#[derive(Debug, Clone, Copy, Serialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum ViolationKind {
    /// A precondition was `false`, did not evaluate to a bool, or failed to evaluate.
    PreconditionFailed,
    /// An applicable postcondition evaluated to `false` after execution.
    PostconditionFailed,
    /// A module invariant evaluated to `false` after execution.
    InvariantViolated,
    /// A module edge guard condition evaluated to `true`.
    EdgeGuardTriggered,
}
49
50pub fn execute_action(
58 module: &Module,
59 request: &ActionRequest,
60) -> Result<ActionResult, RuntimeError> {
61 let func = module
62 .functions
63 .iter()
64 .find(|f| f.name == request.action)
65 .ok_or_else(|| RuntimeError::UnknownFunction(request.action.clone()))?;
66
67 let ctx = EvalContext {
69 bindings: request.params.clone(),
70 old_bindings: None,
71 instances: request.state.clone(),
72 };
73
74 let mut violations = Vec::new();
75
76 for pre in &func.preconditions {
78 match evaluate(&pre.expr, &ctx) {
79 Ok(Value::Bool(true)) => {}
80 Ok(Value::Bool(false)) => {
81 violations.push(Violation {
82 kind: ViolationKind::PreconditionFailed,
83 message: format!("precondition failed: {}", fmt_expr(&pre.expr)),
84 });
85 }
86 Ok(_) => {
87 violations.push(Violation {
88 kind: ViolationKind::PreconditionFailed,
89 message: format!(
90 "precondition did not evaluate to bool: {}",
91 fmt_expr(&pre.expr)
92 ),
93 });
94 }
95 Err(e) => {
96 violations.push(Violation {
97 kind: ViolationKind::PreconditionFailed,
98 message: format!("precondition error: {e}"),
99 });
100 }
101 }
102 }
103
104 if !violations.is_empty() {
106 return Ok(ActionResult {
107 ok: false,
108 new_params: request.params.clone(),
109 violations,
110 });
111 }
112
113 for guard in &module.edge_guards {
115 match evaluate(&guard.condition, &ctx) {
116 Ok(Value::Bool(true)) => {
117 violations.push(Violation {
118 kind: ViolationKind::EdgeGuardTriggered,
119 message: format!(
120 "edge case triggered: {} => {}",
121 fmt_expr(&guard.condition),
122 guard.action
123 ),
124 });
125 }
126 Ok(Value::Bool(false)) | Err(_) => {}
127 Ok(_) => {}
128 }
129 }
130
131 if !violations.is_empty() {
132 return Ok(ActionResult {
133 ok: false,
134 new_params: request.params.clone(),
135 violations,
136 });
137 }
138
139 let old_params = request.params.clone();
141 let mut new_params = request.params.clone();
142
143 for post in &func.postconditions {
145 let expr = match post {
146 Postcondition::Always { expr, .. } => expr,
147 Postcondition::When { guard, expr, .. } => {
148 match evaluate(guard, &ctx) {
150 Ok(Value::Bool(true)) => expr,
151 _ => continue,
152 }
153 }
154 };
155 extract_and_apply_assignments(expr, &old_params, &request.state, &mut new_params)?;
156 }
157
158 let post_ctx = EvalContext {
160 bindings: new_params.clone(),
161 old_bindings: Some(old_params),
162 instances: request.state.clone(),
163 };
164
165 for post in &func.postconditions {
166 let (expr, should_check) = match post {
167 Postcondition::Always { expr, .. } => (expr, true),
168 Postcondition::When { guard, expr, .. } => {
169 let guard_result = evaluate(guard, &post_ctx).unwrap_or(Value::Bool(false));
170 (expr, guard_result == Value::Bool(true))
171 }
172 };
173 if !should_check {
174 continue;
175 }
176 match evaluate(expr, &post_ctx) {
177 Ok(Value::Bool(true)) => {}
178 Ok(Value::Bool(false)) => {
179 violations.push(Violation {
180 kind: ViolationKind::PostconditionFailed,
181 message: format!("postcondition failed: {}", fmt_expr(expr)),
182 });
183 }
184 Ok(_) => {}
185 Err(_) => {
186 }
189 }
190 }
191
192 let inv_ctx = EvalContext {
194 bindings: new_params.clone(),
195 old_bindings: None,
196 instances: request.state.clone(),
197 };
198
199 for inv in &module.invariants {
200 match evaluate(&inv.expr, &inv_ctx) {
201 Ok(Value::Bool(true)) => {}
202 Ok(Value::Bool(false)) => {
203 violations.push(Violation {
204 kind: ViolationKind::InvariantViolated,
205 message: format!("invariant '{}' violated", inv.name),
206 });
207 }
208 Ok(_) | Err(_) => {}
209 }
210 }
211
212 Ok(ActionResult {
213 ok: violations.is_empty(),
214 new_params,
215 violations,
216 })
217}
218
219fn extract_and_apply_assignments(
225 expr: &IrExpr,
226 old_params: &HashMap<String, Value>,
227 instances: &HashMap<String, Vec<Value>>,
228 new_params: &mut HashMap<String, Value>,
229) -> Result<(), RuntimeError> {
230 match expr {
231 IrExpr::Compare {
233 left,
234 op: CmpOp::Eq,
235 right,
236 } => {
237 if let Some((var, field)) = extract_field_path(left)
238 && contains_old(right)
239 {
240 let old_ctx = EvalContext {
241 bindings: old_params.clone(),
242 old_bindings: Some(old_params.clone()),
243 instances: instances.clone(),
244 };
245 let value = evaluate(right, &old_ctx)?;
246 set_field(new_params, &var, &field, value);
247 }
248 if let Some((var, field)) = extract_field_path(right)
250 && contains_old(left)
251 {
252 let old_ctx = EvalContext {
253 bindings: old_params.clone(),
254 old_bindings: Some(old_params.clone()),
255 instances: instances.clone(),
256 };
257 let value = evaluate(left, &old_ctx)?;
258 set_field(new_params, &var, &field, value);
259 }
260 }
261 IrExpr::And(left, right) => {
263 extract_and_apply_assignments(left, old_params, instances, new_params)?;
264 extract_and_apply_assignments(right, old_params, instances, new_params)?;
265 }
266 _ => {}
268 }
269 Ok(())
270}
271
272fn extract_field_path(expr: &IrExpr) -> Option<(String, String)> {
274 if let IrExpr::FieldAccess { root, field } = expr
275 && let IrExpr::Var(var) = root.as_ref()
276 {
277 return Some((var.clone(), field.clone()));
278 }
279 None
280}
281
282fn contains_old(expr: &IrExpr) -> bool {
284 match expr {
285 IrExpr::Old(_) => true,
286 IrExpr::Compare { left, right, .. }
287 | IrExpr::Arithmetic { left, right, .. }
288 | IrExpr::And(left, right)
289 | IrExpr::Or(left, right)
290 | IrExpr::Implies(left, right) => contains_old(left) || contains_old(right),
291 IrExpr::Not(inner) => contains_old(inner),
292 IrExpr::FieldAccess { root, .. } => contains_old(root),
293 IrExpr::Forall { body, .. } | IrExpr::Exists { body, .. } => contains_old(body),
294 IrExpr::Call { args, .. } | IrExpr::List(args) => args.iter().any(contains_old),
295 IrExpr::Var(_) | IrExpr::Literal(_) => false,
296 }
297}
298
299fn set_field(params: &mut HashMap<String, Value>, var: &str, field: &str, value: Value) {
301 if let Some(Value::Object(map)) = params.get_mut(var) {
302 map.insert(field.to_string(), value);
303 }
304}
305
306fn fmt_expr(expr: &IrExpr) -> String {
308 match expr {
309 IrExpr::Var(name) => name.clone(),
310 IrExpr::Literal(lit) => format!("{lit:?}"),
311 IrExpr::FieldAccess { root, field } => format!("{}.{field}", fmt_expr(root)),
312 IrExpr::Compare { left, op, right } => {
313 let op_str = match op {
314 CmpOp::Eq => "==",
315 CmpOp::Ne => "!=",
316 CmpOp::Lt => "<",
317 CmpOp::Gt => ">",
318 CmpOp::Le => "<=",
319 CmpOp::Ge => ">=",
320 };
321 format!("{} {op_str} {}", fmt_expr(left), fmt_expr(right))
322 }
323 IrExpr::And(l, r) => format!("{} && {}", fmt_expr(l), fmt_expr(r)),
324 IrExpr::Or(l, r) => format!("{} || {}", fmt_expr(l), fmt_expr(r)),
325 IrExpr::Not(inner) => format!("!{}", fmt_expr(inner)),
326 IrExpr::Old(inner) => format!("old({})", fmt_expr(inner)),
327 _ => "...".into(),
328 }
329}
330
#[cfg(test)]
mod tests {
    use super::*;
    use intent_ir::*;
    use intent_parser::ast::Span;
    use serde_json::json;

    // ---- IR builders ------------------------------------------------------

    /// A trace pointing nowhere; `execute_action` does not read traces.
    fn no_trace() -> SourceTrace {
        SourceTrace {
            module: String::new(),
            item: String::new(),
            part: String::new(),
            span: Span { start: 0, end: 0 },
        }
    }

    fn var(name: &str) -> IrExpr {
        IrExpr::Var(name.into())
    }

    /// `root.name`
    fn field(root: &str, name: &str) -> IrExpr {
        IrExpr::FieldAccess {
            root: Box::new(var(root)),
            field: name.into(),
        }
    }

    fn old(expr: IrExpr) -> IrExpr {
        IrExpr::Old(Box::new(expr))
    }

    fn compare(left: IrExpr, op: CmpOp, right: IrExpr) -> IrExpr {
        IrExpr::Compare {
            left: Box::new(left),
            op,
            right: Box::new(right),
        }
    }

    fn arith(left: IrExpr, op: ArithOp, right: IrExpr) -> IrExpr {
        IrExpr::Arithmetic {
            left: Box::new(left),
            op,
            right: Box::new(right),
        }
    }

    fn condition(expr: IrExpr) -> Condition {
        Condition {
            expr,
            trace: no_trace(),
        }
    }

    fn param(name: &str, ty: IrType) -> Param {
        Param {
            name: name.into(),
            ty,
            trace: no_trace(),
        }
    }

    fn function(
        name: &str,
        params: Vec<Param>,
        preconditions: Vec<Condition>,
        postconditions: Vec<Postcondition>,
    ) -> Function {
        Function {
            name: name.into(),
            params,
            preconditions,
            postconditions,
            properties: vec![],
            trace: no_trace(),
        }
    }

    /// A struct-less module named `Test`.
    fn test_module(
        functions: Vec<Function>,
        invariants: Vec<Invariant>,
        edge_guards: Vec<EdgeGuard>,
    ) -> Module {
        Module {
            name: "Test".into(),
            structs: vec![],
            functions,
            invariants,
            edge_guards,
        }
    }

    /// `forall a: Account => a.balance >= min`
    fn balance_floor(name: &str, min: IrExpr) -> Invariant {
        Invariant {
            name: name.into(),
            expr: IrExpr::Forall {
                binding: "a".into(),
                ty: "Account".into(),
                body: Box::new(compare(field("a", "balance"), CmpOp::Ge, min)),
            },
            trace: no_trace(),
        }
    }

    // ---- request builders -------------------------------------------------

    fn request(
        action: &str,
        params: Vec<(&str, Value)>,
        state: HashMap<String, Vec<Value>>,
    ) -> ActionRequest {
        ActionRequest {
            action: action.into(),
            params: params
                .into_iter()
                .map(|(name, value)| (name.to_string(), value))
                .collect(),
            state,
        }
    }

    fn account(id: &str, balance: f64, status: &str) -> Value {
        json!({"id": id, "balance": balance, "status": status})
    }

    /// Parameters for `Transfer` into the standard active account `2` holding 500.
    fn transfer_params(from: Value, amount: Value) -> Vec<(&'static str, Value)> {
        vec![
            ("from", from),
            ("to", account("2", 500.0, "Active")),
            ("amount", amount),
        ]
    }

    // ---- fixtures ---------------------------------------------------------

    fn simple_module() -> Module {
        let account_struct = Struct {
            name: "Account".into(),
            fields: vec![
                Field {
                    name: "id".into(),
                    ty: IrType::Named("UUID".into()),
                    trace: no_trace(),
                },
                Field {
                    name: "balance".into(),
                    ty: IrType::Decimal(2),
                    trace: no_trace(),
                },
                Field {
                    name: "status".into(),
                    ty: IrType::Union(vec!["Active".into(), "Frozen".into()]),
                    trace: no_trace(),
                },
            ],
            trace: no_trace(),
        };

        let transfer = function(
            "Transfer",
            vec![
                param("from", IrType::Struct("Account".into())),
                param("to", IrType::Struct("Account".into())),
                param("amount", IrType::Decimal(2)),
            ],
            vec![
                condition(compare(field("from", "status"), CmpOp::Eq, var("Active"))),
                condition(compare(
                    var("amount"),
                    CmpOp::Gt,
                    IrExpr::Literal(IrLiteral::Int(0)),
                )),
                condition(compare(field("from", "balance"), CmpOp::Ge, var("amount"))),
            ],
            vec![Postcondition::Always {
                expr: IrExpr::And(
                    Box::new(compare(
                        field("from", "balance"),
                        CmpOp::Eq,
                        arith(old(field("from", "balance")), ArithOp::Sub, var("amount")),
                    )),
                    Box::new(compare(
                        field("to", "balance"),
                        CmpOp::Eq,
                        arith(old(field("to", "balance")), ArithOp::Add, var("amount")),
                    )),
                ),
                trace: no_trace(),
            }],
        );

        Module {
            name: "Test".into(),
            structs: vec![account_struct],
            functions: vec![transfer],
            invariants: vec![balance_floor(
                "NoNegativeBalances",
                IrExpr::Literal(IrLiteral::Int(0)),
            )],
            edge_guards: vec![],
        }
    }

    /// `Transfer(amount)` guarded by `amount > 10000 => require_approval`.
    fn guarded_transfer_module() -> Module {
        test_module(
            vec![function(
                "Transfer",
                vec![param("amount", IrType::Decimal(2))],
                vec![],
                vec![],
            )],
            vec![],
            vec![EdgeGuard {
                condition: compare(
                    var("amount"),
                    CmpOp::Gt,
                    IrExpr::Literal(IrLiteral::Int(10000)),
                ),
                action: "require_approval".into(),
                args: vec![],
                trace: no_trace(),
            }],
        )
    }

    // ---- tests ------------------------------------------------------------

    #[test]
    fn execute_valid_transfer() {
        let state = HashMap::from([(
            "Account".to_string(),
            vec![
                account("1", 1000.0, "Active"),
                account("2", 500.0, "Active"),
            ],
        )]);
        let req = request(
            "Transfer",
            transfer_params(account("1", 1000.0, "Active"), json!(200.0)),
            state,
        );

        let result = execute_action(&simple_module(), &req).unwrap();
        assert!(result.ok, "violations: {:?}", result.violations);
        assert_eq!(result.new_params["from"]["balance"], json!(800.0));
        assert_eq!(result.new_params["to"]["balance"], json!(700.0));
    }

    #[test]
    fn precondition_fails_frozen_account() {
        let req = request(
            "Transfer",
            transfer_params(account("1", 1000.0, "Frozen"), json!(200.0)),
            HashMap::new(),
        );

        let result = execute_action(&simple_module(), &req).unwrap();
        assert!(!result.ok);
        assert_eq!(result.violations.len(), 1);
        assert_eq!(result.violations[0].kind, ViolationKind::PreconditionFailed);
        assert!(result.violations[0].message.contains("from.status"));
    }

    #[test]
    fn precondition_fails_insufficient_balance() {
        let req = request(
            "Transfer",
            transfer_params(account("1", 50.0, "Active"), json!(200.0)),
            HashMap::new(),
        );

        let result = execute_action(&simple_module(), &req).unwrap();
        assert!(!result.ok);
        assert!(result
            .violations
            .iter()
            .any(|v| v.kind == ViolationKind::PreconditionFailed));
    }

    #[test]
    fn precondition_fails_zero_amount() {
        let req = request(
            "Transfer",
            transfer_params(account("1", 1000.0, "Active"), json!(0)),
            HashMap::new(),
        );

        let result = execute_action(&simple_module(), &req).unwrap();
        assert!(!result.ok);
        assert!(result
            .violations
            .iter()
            .any(|v| v.kind == ViolationKind::PreconditionFailed));
    }

    #[test]
    fn rejected_request_keeps_original_params() {
        let req = request(
            "Transfer",
            transfer_params(account("1", 1000.0, "Frozen"), json!(200.0)),
            HashMap::new(),
        );

        let result = execute_action(&simple_module(), &req).unwrap();
        assert!(!result.ok);
        assert_eq!(result.new_params, req.params);
    }

    #[test]
    fn invariant_violation_detected() {
        let module = test_module(
            vec![function(
                "Withdraw",
                vec![param("account", IrType::Struct("Account".into()))],
                vec![],
                vec![],
            )],
            vec![balance_floor(
                "MinBalance",
                IrExpr::Literal(IrLiteral::Int(100)),
            )],
            vec![],
        );
        let req = request(
            "Withdraw",
            vec![("account", json!({"balance": 50}))],
            HashMap::from([("Account".to_string(), vec![json!({"balance": 50})])]),
        );

        let result = execute_action(&module, &req).unwrap();
        assert!(!result.ok);
        assert!(result.violations.iter().any(|v| {
            v.kind == ViolationKind::InvariantViolated && v.message.contains("MinBalance")
        }));
    }

    #[test]
    fn unknown_action_error() {
        let req = request("NonExistent", vec![], HashMap::new());

        assert!(matches!(
            execute_action(&simple_module(), &req),
            Err(RuntimeError::UnknownFunction(_))
        ));
    }

    #[test]
    fn edge_guard_blocks_execution() {
        let req = request("Transfer", vec![("amount", json!(50000))], HashMap::new());

        let result = execute_action(&guarded_transfer_module(), &req).unwrap();
        assert!(!result.ok);
        assert_eq!(result.violations[0].kind, ViolationKind::EdgeGuardTriggered);
    }

    #[test]
    fn edge_guard_allows_amount_below_threshold() {
        let req = request("Transfer", vec![("amount", json!(500))], HashMap::new());

        let result = execute_action(&guarded_transfer_module(), &req).unwrap();
        assert!(result.ok, "violations: {:?}", result.violations);
        assert!(result.violations.is_empty());
    }

    #[test]
    fn when_postcondition_guarded() {
        let module = test_module(
            vec![function(
                "SetStatus",
                vec![
                    param("account", IrType::Struct("Account".into())),
                    param("freeze", IrType::Named("Bool".into())),
                ],
                vec![],
                vec![Postcondition::When {
                    guard: var("freeze"),
                    expr: compare(field("account", "status"), CmpOp::Eq, var("Frozen")),
                    trace: no_trace(),
                }],
            )],
            vec![],
            vec![],
        );
        let set_status = |freeze: bool| {
            request(
                "SetStatus",
                vec![
                    ("account", json!({"status": "Active"})),
                    ("freeze", json!(freeze)),
                ],
                HashMap::new(),
            )
        };

        // Guard false: the postcondition does not apply.
        let unfrozen = execute_action(&module, &set_status(false)).unwrap();
        assert!(unfrozen.ok);

        // Guard true: the postcondition applies and the status was never changed.
        let frozen = execute_action(&module, &set_status(true)).unwrap();
        assert!(!frozen.ok);
        assert_eq!(
            frozen.violations[0].kind,
            ViolationKind::PostconditionFailed
        );
    }
}