1use std::collections::HashMap;
2
3use intent_ir::{CmpOp, IrExpr, Module, Postcondition};
4use serde::{Deserialize, Serialize};
5use serde_json::Value;
6
7use crate::error::RuntimeError;
8use crate::eval::evaluate;
9use crate::value::EvalContext;
10
11#[derive(Debug, Clone, Deserialize)]
13pub struct ActionRequest {
14 #[serde(default)]
16 pub action: String,
17 pub params: HashMap<String, Value>,
19 pub state: HashMap<String, Vec<Value>>,
21}
22
23#[derive(Debug, Clone, Serialize)]
25pub struct ActionResult {
26 pub ok: bool,
28 pub new_params: HashMap<String, Value>,
30 pub violations: Vec<Violation>,
32}
33
34#[derive(Debug, Clone, Serialize, PartialEq)]
36pub struct Violation {
37 pub kind: ViolationKind,
38 pub message: String,
39}
40
41#[derive(Debug, Clone, Copy, Serialize, PartialEq, Eq)]
42#[serde(rename_all = "snake_case")]
43pub enum ViolationKind {
44 PreconditionFailed,
45 PostconditionFailed,
46 InvariantViolated,
47 EdgeGuardTriggered,
48}
49
50pub fn execute_action(
58 module: &Module,
59 request: &ActionRequest,
60) -> Result<ActionResult, RuntimeError> {
61 let func = module
62 .functions
63 .iter()
64 .find(|f| f.name == request.action)
65 .ok_or_else(|| RuntimeError::UnknownFunction(request.action.clone()))?;
66
67 let ctx = EvalContext {
69 bindings: request.params.clone(),
70 old_bindings: None,
71 instances: request.state.clone(),
72 };
73
74 let mut violations = Vec::new();
75
76 for pre in &func.preconditions {
78 match evaluate(&pre.expr, &ctx) {
79 Ok(Value::Bool(true)) => {}
80 Ok(Value::Bool(false)) => {
81 violations.push(Violation {
82 kind: ViolationKind::PreconditionFailed,
83 message: format!("precondition failed: {}", fmt_expr(&pre.expr)),
84 });
85 }
86 Ok(_) => {
87 violations.push(Violation {
88 kind: ViolationKind::PreconditionFailed,
89 message: format!(
90 "precondition did not evaluate to bool: {}",
91 fmt_expr(&pre.expr)
92 ),
93 });
94 }
95 Err(e) => {
96 violations.push(Violation {
97 kind: ViolationKind::PreconditionFailed,
98 message: format!("precondition error: {e}"),
99 });
100 }
101 }
102 }
103
104 if !violations.is_empty() {
106 return Ok(ActionResult {
107 ok: false,
108 new_params: request.params.clone(),
109 violations,
110 });
111 }
112
113 for guard in &module.edge_guards {
115 match evaluate(&guard.condition, &ctx) {
116 Ok(Value::Bool(true)) => {
117 violations.push(Violation {
118 kind: ViolationKind::EdgeGuardTriggered,
119 message: format!(
120 "edge case triggered: {} => {}",
121 fmt_expr(&guard.condition),
122 guard.action
123 ),
124 });
125 }
126 Ok(Value::Bool(false)) | Err(_) => {}
127 Ok(_) => {}
128 }
129 }
130
131 if !violations.is_empty() {
132 return Ok(ActionResult {
133 ok: false,
134 new_params: request.params.clone(),
135 violations,
136 });
137 }
138
139 let old_params = request.params.clone();
141 let mut new_params = request.params.clone();
142
143 for post in &func.postconditions {
145 let expr = match post {
146 Postcondition::Always { expr, .. } => expr,
147 Postcondition::When { guard, expr, .. } => {
148 match evaluate(guard, &ctx) {
150 Ok(Value::Bool(true)) => expr,
151 _ => continue,
152 }
153 }
154 };
155 extract_and_apply_assignments(expr, &old_params, &request.state, &mut new_params)?;
156 }
157
158 let post_ctx = EvalContext {
160 bindings: new_params.clone(),
161 old_bindings: Some(old_params),
162 instances: request.state.clone(),
163 };
164
165 for post in &func.postconditions {
166 let (expr, should_check) = match post {
167 Postcondition::Always { expr, .. } => (expr, true),
168 Postcondition::When { guard, expr, .. } => {
169 let guard_result = evaluate(guard, &post_ctx).unwrap_or(Value::Bool(false));
170 (expr, guard_result == Value::Bool(true))
171 }
172 };
173 if !should_check {
174 continue;
175 }
176 match evaluate(expr, &post_ctx) {
177 Ok(Value::Bool(true)) => {}
178 Ok(Value::Bool(false)) => {
179 violations.push(Violation {
180 kind: ViolationKind::PostconditionFailed,
181 message: format!("postcondition failed: {}", fmt_expr(expr)),
182 });
183 }
184 Ok(_) => {}
185 Err(_) => {
186 }
189 }
190 }
191
192 let inv_ctx = EvalContext {
194 bindings: new_params.clone(),
195 old_bindings: None,
196 instances: request.state.clone(),
197 };
198
199 for inv in &module.invariants {
200 match evaluate(&inv.expr, &inv_ctx) {
201 Ok(Value::Bool(true)) => {}
202 Ok(Value::Bool(false)) => {
203 violations.push(Violation {
204 kind: ViolationKind::InvariantViolated,
205 message: format!("invariant '{}' violated", inv.name),
206 });
207 }
208 Ok(_) | Err(_) => {}
209 }
210 }
211
212 Ok(ActionResult {
213 ok: violations.is_empty(),
214 new_params,
215 violations,
216 })
217}
218
219fn extract_and_apply_assignments(
228 expr: &IrExpr,
229 old_params: &HashMap<String, Value>,
230 instances: &HashMap<String, Vec<Value>>,
231 new_params: &mut HashMap<String, Value>,
232) -> Result<(), RuntimeError> {
233 match expr {
234 IrExpr::Compare {
236 left,
237 op: CmpOp::Eq,
238 right,
239 } => {
240 if let Some((var, field)) = extract_field_path(left)
241 && is_assignable(right)
242 {
243 let old_ctx = EvalContext {
244 bindings: old_params.clone(),
245 old_bindings: Some(old_params.clone()),
246 instances: instances.clone(),
247 };
248 let value = evaluate(right, &old_ctx)?;
249 set_field(new_params, &var, &field, value);
250 }
251 if let Some((var, field)) = extract_field_path(right)
253 && is_assignable(left)
254 {
255 let old_ctx = EvalContext {
256 bindings: old_params.clone(),
257 old_bindings: Some(old_params.clone()),
258 instances: instances.clone(),
259 };
260 let value = evaluate(left, &old_ctx)?;
261 set_field(new_params, &var, &field, value);
262 }
263 }
264 IrExpr::And(left, right) => {
266 extract_and_apply_assignments(left, old_params, instances, new_params)?;
267 extract_and_apply_assignments(right, old_params, instances, new_params)?;
268 }
269 _ => {}
271 }
272 Ok(())
273}
274
275fn extract_field_path(expr: &IrExpr) -> Option<(String, String)> {
277 if let IrExpr::FieldAccess { root, field } = expr
278 && let IrExpr::Var(var) = root.as_ref()
279 {
280 return Some((var.clone(), field.clone()));
281 }
282 None
283}
284
285fn is_assignable(expr: &IrExpr) -> bool {
295 match expr {
296 IrExpr::Literal(_) | IrExpr::Var(_) | IrExpr::Old(_) => true,
297 IrExpr::Arithmetic { .. } => true,
298 IrExpr::FieldAccess { root, .. } => is_assignable(root),
299 IrExpr::Call { .. } => true,
300 IrExpr::Compare { .. }
301 | IrExpr::And(_, _)
302 | IrExpr::Or(_, _)
303 | IrExpr::Implies(_, _)
304 | IrExpr::Not(_)
305 | IrExpr::Forall { .. }
306 | IrExpr::Exists { .. }
307 | IrExpr::List(_) => false,
308 }
309}
310
311fn set_field(params: &mut HashMap<String, Value>, var: &str, field: &str, value: Value) {
313 if let Some(Value::Object(map)) = params.get_mut(var) {
314 map.insert(field.to_string(), value);
315 }
316}
317
318fn fmt_expr(expr: &IrExpr) -> String {
320 match expr {
321 IrExpr::Var(name) => name.clone(),
322 IrExpr::Literal(lit) => format!("{lit:?}"),
323 IrExpr::FieldAccess { root, field } => format!("{}.{field}", fmt_expr(root)),
324 IrExpr::Compare { left, op, right } => {
325 let op_str = match op {
326 CmpOp::Eq => "==",
327 CmpOp::Ne => "!=",
328 CmpOp::Lt => "<",
329 CmpOp::Gt => ">",
330 CmpOp::Le => "<=",
331 CmpOp::Ge => ">=",
332 };
333 format!("{} {op_str} {}", fmt_expr(left), fmt_expr(right))
334 }
335 IrExpr::And(l, r) => format!("{} && {}", fmt_expr(l), fmt_expr(r)),
336 IrExpr::Or(l, r) => format!("{} || {}", fmt_expr(l), fmt_expr(r)),
337 IrExpr::Not(inner) => format!("!{}", fmt_expr(inner)),
338 IrExpr::Old(inner) => format!("old({})", fmt_expr(inner)),
339 _ => "...".into(),
340 }
341}
342
#[cfg(test)]
mod tests {
    use super::*;
    use intent_ir::*;
    use intent_parser::ast::Span;
    use serde_json::json;

    // ── IR fixture builders ─────────────────────────────────────────────

    /// A blank trace for IR nodes that were not produced from real source.
    fn empty_trace() -> SourceTrace {
        let blank = String::new;
        SourceTrace {
            module: blank(),
            item: blank(),
            part: blank(),
            span: Span { start: 0, end: 0 },
        }
    }

    /// `name` as a variable reference.
    fn var(name: &'static str) -> IrExpr {
        IrExpr::Var(name.into())
    }

    /// `root.field` where `root` is a variable.
    fn field_of(root: &'static str, field: &'static str) -> IrExpr {
        IrExpr::FieldAccess {
            root: Box::new(var(root)),
            field: field.into(),
        }
    }

    /// `left <op> right`.
    fn compare(left: IrExpr, op: CmpOp, right: IrExpr) -> IrExpr {
        IrExpr::Compare {
            left: Box::new(left),
            op,
            right: Box::new(right),
        }
    }

    /// A condition wrapping `expr` with a blank trace.
    fn condition(expr: IrExpr) -> Condition {
        Condition {
            expr,
            trace: empty_trace(),
        }
    }

    /// A function parameter with a blank trace.
    fn param(name: &'static str, ty: IrType) -> Param {
        Param {
            name: name.into(),
            ty,
            trace: empty_trace(),
        }
    }

    /// A struct field declaration with a blank trace.
    fn struct_field(name: &'static str, ty: IrType) -> Field {
        Field {
            name: name.into(),
            ty,
            trace: empty_trace(),
        }
    }

    /// A function with no preconditions and no properties.
    fn bare_function(
        name: &'static str,
        params: Vec<Param>,
        postconditions: Vec<Postcondition>,
    ) -> Function {
        Function {
            name: name.into(),
            params,
            preconditions: vec![],
            postconditions,
            properties: vec![],
            trace: empty_trace(),
        }
    }

    /// A struct-less module called "Test" containing a single function.
    fn test_module(func: Function, invariants: Vec<Invariant>, edge_guards: Vec<EdgeGuard>) -> Module {
        Module {
            name: "Test".into(),
            structs: vec![],
            functions: vec![func],
            invariants,
            edge_guards,
        }
    }

    /// `forall a: Account. a.balance >= <floor>`.
    fn balance_floor_invariant(name: &'static str, floor: IrLiteral) -> Invariant {
        Invariant {
            name: name.into(),
            expr: IrExpr::Forall {
                binding: "a".into(),
                ty: "Account".into(),
                body: Box::new(compare(
                    field_of("a", "balance"),
                    CmpOp::Ge,
                    IrExpr::Literal(floor),
                )),
            },
            trace: empty_trace(),
        }
    }

    /// The Account/Transfer module shared by most tests.
    fn simple_module() -> Module {
        // `<account>.balance == old(<account>.balance) <op> amount`
        let balance_update = |account: &'static str, op: ArithOp| {
            compare(
                field_of(account, "balance"),
                CmpOp::Eq,
                IrExpr::Arithmetic {
                    left: Box::new(IrExpr::Old(Box::new(field_of(account, "balance")))),
                    op,
                    right: Box::new(var("amount")),
                },
            )
        };

        let account = Struct {
            name: "Account".into(),
            fields: vec![
                struct_field("id", IrType::Named("UUID".into())),
                struct_field("balance", IrType::Decimal(2)),
                struct_field(
                    "status",
                    IrType::Union(vec!["Active".into(), "Frozen".into()]),
                ),
            ],
            trace: empty_trace(),
        };

        let transfer = Function {
            name: "Transfer".into(),
            params: vec![
                param("from", IrType::Struct("Account".into())),
                param("to", IrType::Struct("Account".into())),
                param("amount", IrType::Decimal(2)),
            ],
            preconditions: vec![
                condition(compare(
                    field_of("from", "status"),
                    CmpOp::Eq,
                    var("Active"),
                )),
                condition(compare(
                    var("amount"),
                    CmpOp::Gt,
                    IrExpr::Literal(IrLiteral::Int(0)),
                )),
                condition(compare(
                    field_of("from", "balance"),
                    CmpOp::Ge,
                    var("amount"),
                )),
            ],
            postconditions: vec![Postcondition::Always {
                expr: IrExpr::And(
                    Box::new(balance_update("from", ArithOp::Sub)),
                    Box::new(balance_update("to", ArithOp::Add)),
                ),
                trace: empty_trace(),
            }],
            properties: vec![],
            trace: empty_trace(),
        };

        Module {
            name: "Test".into(),
            structs: vec![account],
            functions: vec![transfer],
            invariants: vec![balance_floor_invariant(
                "NoNegativeBalances",
                IrLiteral::Int(0),
            )],
            edge_guards: vec![],
        }
    }

    // ── Request fixture builders ────────────────────────────────────────

    /// JSON for an account object.
    fn account(id: &str, balance: f64, status: &str) -> Value {
        json!({"id": id, "balance": balance, "status": status})
    }

    /// A `Transfer` request from `from` to the fixed recipient account "2".
    fn transfer_request(
        from: Value,
        amount: Value,
        state: HashMap<String, Vec<Value>>,
    ) -> ActionRequest {
        ActionRequest {
            action: "Transfer".into(),
            params: HashMap::from([
                ("from".into(), from),
                ("to".into(), account("2", 500.0, "Active")),
                ("amount".into(), amount),
            ]),
            state,
        }
    }

    // ── Tests ───────────────────────────────────────────────────────────

    #[test]
    fn execute_valid_transfer() {
        let module = simple_module();
        let request = transfer_request(
            account("1", 1000.0, "Active"),
            json!(200.0),
            HashMap::from([(
                "Account".into(),
                vec![
                    account("1", 1000.0, "Active"),
                    account("2", 500.0, "Active"),
                ],
            )]),
        );

        let result = execute_action(&module, &request).unwrap();
        assert!(result.ok, "violations: {:?}", result.violations);
        assert_eq!(result.new_params["from"]["balance"], json!(800.0));
        assert_eq!(result.new_params["to"]["balance"], json!(700.0));
    }

    #[test]
    fn precondition_fails_frozen_account() {
        let module = simple_module();
        let request = transfer_request(
            account("1", 1000.0, "Frozen"),
            json!(200.0),
            HashMap::new(),
        );

        let result = execute_action(&module, &request).unwrap();
        assert!(!result.ok);
        assert_eq!(result.violations.len(), 1);
        assert_eq!(result.violations[0].kind, ViolationKind::PreconditionFailed);
        assert!(result.violations[0].message.contains("from.status"));
    }

    #[test]
    fn precondition_fails_insufficient_balance() {
        let module = simple_module();
        let request = transfer_request(
            account("1", 50.0, "Active"),
            json!(200.0),
            HashMap::new(),
        );

        let result = execute_action(&module, &request).unwrap();
        assert!(!result.ok);
        assert!(
            result
                .violations
                .iter()
                .any(|v| v.kind == ViolationKind::PreconditionFailed)
        );
    }

    #[test]
    fn precondition_fails_zero_amount() {
        let module = simple_module();
        // Integer zero on purpose, unlike the float amounts used elsewhere.
        let request = transfer_request(
            account("1", 1000.0, "Active"),
            json!(0),
            HashMap::new(),
        );

        let result = execute_action(&module, &request).unwrap();
        assert!(!result.ok);
        assert!(
            result
                .violations
                .iter()
                .any(|v| v.kind == ViolationKind::PreconditionFailed)
        );
    }

    #[test]
    fn invariant_violation_detected() {
        let withdraw = bare_function(
            "Withdraw",
            vec![param("account", IrType::Struct("Account".into()))],
            vec![],
        );
        let module = test_module(
            withdraw,
            vec![balance_floor_invariant("MinBalance", IrLiteral::Int(100))],
            vec![],
        );

        let request = ActionRequest {
            action: "Withdraw".into(),
            params: HashMap::from([("account".into(), json!({"balance": 50}))]),
            state: HashMap::from([("Account".into(), vec![json!({"balance": 50})])]),
        };

        let result = execute_action(&module, &request).unwrap();
        assert!(!result.ok);
        assert!(result.violations.iter().any(|v| {
            v.kind == ViolationKind::InvariantViolated && v.message.contains("MinBalance")
        }));
    }

    #[test]
    fn unknown_action_error() {
        let module = simple_module();
        let request = ActionRequest {
            action: "NonExistent".into(),
            params: HashMap::new(),
            state: HashMap::new(),
        };

        assert!(matches!(
            execute_action(&module, &request),
            Err(RuntimeError::UnknownFunction(_))
        ));
    }

    #[test]
    fn edge_guard_blocks_execution() {
        let large_amount_guard = EdgeGuard {
            condition: compare(
                var("amount"),
                CmpOp::Gt,
                IrExpr::Literal(IrLiteral::Int(10000)),
            ),
            action: "require_approval".into(),
            args: vec![],
            trace: empty_trace(),
        };
        let transfer = bare_function(
            "Transfer",
            vec![param("amount", IrType::Decimal(2))],
            vec![],
        );
        let module = test_module(transfer, vec![], vec![large_amount_guard]);

        let request = ActionRequest {
            action: "Transfer".into(),
            params: HashMap::from([("amount".into(), json!(50000))]),
            state: HashMap::new(),
        };

        let result = execute_action(&module, &request).unwrap();
        assert!(!result.ok);
        assert_eq!(result.violations[0].kind, ViolationKind::EdgeGuardTriggered);
    }

    #[test]
    fn when_postcondition_guarded() {
        let set_status = bare_function(
            "SetStatus",
            vec![
                param("account", IrType::Struct("Account".into())),
                param("freeze", IrType::Named("Bool".into())),
            ],
            vec![Postcondition::When {
                guard: var("freeze"),
                expr: compare(
                    field_of("account", "status"),
                    CmpOp::Eq,
                    var("Frozen"),
                ),
                trace: empty_trace(),
            }],
        );
        let module = test_module(set_status, vec![], vec![]);

        let request_with_freeze = |freeze: bool| ActionRequest {
            action: "SetStatus".into(),
            params: HashMap::from([
                ("account".into(), json!({"status": "Active"})),
                ("freeze".into(), json!(freeze)),
            ]),
            state: HashMap::new(),
        };

        // Guard false: the postcondition does not apply.
        let unguarded = execute_action(&module, &request_with_freeze(false)).unwrap();
        assert!(unguarded.ok);

        // Guard true: the status is set to satisfy the postcondition.
        let guarded = execute_action(&module, &request_with_freeze(true)).unwrap();
        assert!(guarded.ok, "violations: {:?}", guarded.violations);
        assert_eq!(guarded.new_params["account"]["status"], json!("Frozen"));
    }
}