1use std::collections::HashMap;
8
9use intent_parser::ast::{self, CmpOp, ExprKind, Literal, TypeKind};
10
11use crate::test_harness::slugify;
12use crate::to_snake_case;
13
/// Per-field metadata gathered from an entity declaration, used to render
/// `given` values and assertion operands with the field's Rust type.
#[derive(Debug, Clone)]
struct FieldMeta {
    /// Declared type name for simple or parameterized types (e.g. `Int`,
    /// `Decimal`); `None` for unions and any other type kind.
    simple_type: Option<String>,
    /// For union-typed fields: the generated enum name (`{Entity}{Field}`)
    /// and the names of its simple variants.
    union_info: Option<(String, Vec<String>)>,
}

/// Metadata for one entity declaration, keyed by field name.
#[derive(Debug, Clone)]
struct EntityInfo {
    fields: HashMap<String, FieldMeta>,
}
28
29pub fn generate(file: &ast::File) -> String {
33 let tests: Vec<_> = file
34 .items
35 .iter()
36 .filter_map(|i| match i {
37 ast::TopLevelItem::Test(t) => Some(t),
38 _ => None,
39 })
40 .collect();
41
42 if tests.is_empty() {
43 return String::new();
44 }
45
46 let entities = collect_entities(file);
47 let mut out = String::new();
48
49 out.push_str("#[cfg(test)]\n");
50 out.push_str("mod contract_tests {\n");
51 out.push_str(" use super::*;\n");
52 if uses_type(file, "Decimal") {
53 out.push_str(" use std::str::FromStr;\n");
54 }
55 out.push('\n');
56
57 for test in &tests {
58 generate_spec_test(&mut out, test, file, &entities);
59 }
60
61 out.push_str("}\n");
62 out
63}
64
65fn collect_entities(file: &ast::File) -> HashMap<String, EntityInfo> {
68 let mut map = HashMap::new();
69
70 for item in &file.items {
71 if let ast::TopLevelItem::Entity(entity) = item {
72 let mut fields = HashMap::new();
73 for field in &entity.fields {
74 let simple_type = simple_type_name(&field.ty.ty);
75 let union_info = if let TypeKind::Union(variants) = &field.ty.ty {
76 let enum_name = format!("{}{}", entity.name, capitalize(&field.name));
77 let names: Vec<String> = variants
78 .iter()
79 .filter_map(|v| {
80 if let TypeKind::Simple(n) = v {
81 Some(n.clone())
82 } else {
83 None
84 }
85 })
86 .collect();
87 Some((enum_name, names))
88 } else {
89 None
90 };
91 fields.insert(
92 field.name.clone(),
93 FieldMeta {
94 simple_type,
95 union_info,
96 },
97 );
98 }
99 map.insert(entity.name.clone(), EntityInfo { fields });
100 }
101 }
102
103 map
104}
105
106fn simple_type_name(kind: &TypeKind) -> Option<String> {
107 match kind {
108 TypeKind::Simple(name) => Some(name.clone()),
109 TypeKind::Parameterized { name, .. } => Some(name.clone()),
110 _ => None,
111 }
112}
113
114fn uses_type(file: &ast::File, target: &str) -> bool {
115 for item in &file.items {
116 match item {
117 ast::TopLevelItem::Entity(e) => {
118 for f in &e.fields {
119 if type_matches(&f.ty.ty, target) {
120 return true;
121 }
122 }
123 }
124 ast::TopLevelItem::Action(a) => {
125 for p in &a.params {
126 if type_matches(&p.ty.ty, target) {
127 return true;
128 }
129 }
130 }
131 _ => {}
132 }
133 }
134 false
135}
136
137fn type_matches(kind: &TypeKind, target: &str) -> bool {
138 match kind {
139 TypeKind::Simple(n) | TypeKind::Parameterized { name: n, .. } => n == target,
140 _ => false,
141 }
142}
143
144fn generate_spec_test(
147 out: &mut String,
148 test: &ast::TestDecl,
149 file: &ast::File,
150 entities: &HashMap<String, EntityInfo>,
151) {
152 let test_name = slugify(&test.name);
153 let mut given_types: HashMap<String, String> = HashMap::new();
154
155 out.push_str(&format!(" /// Spec test: \"{}\"\n", test.name));
156 out.push_str(" #[test]\n");
157 out.push_str(&format!(" fn test_{test_name}() {{\n"));
158
159 for binding in &test.given {
161 generate_binding(out, binding, entities, &mut given_types);
162 }
163 out.push('\n');
164
165 let action = file.items.iter().find_map(|i| match i {
167 ast::TopLevelItem::Action(a) if a.name == test.when_action.action_name => Some(a),
168 _ => None,
169 });
170
171 generate_call(out, &test.when_action, action, &given_types, entities);
172
173 match &test.then {
175 ast::ThenClause::Asserts(exprs, _) => {
176 out.push_str(" assert!(result.is_ok(), \"expected action to succeed\");\n");
177 for expr in exprs {
178 generate_assertion(out, expr, &given_types, entities);
179 }
180 }
181 ast::ThenClause::Fails(kind, _) => {
182 let msg = match kind {
183 Some(k) => format!("expected action to fail: {k}"),
184 None => "expected action to fail".to_string(),
185 };
186 out.push_str(&format!(" assert!(result.is_err(), \"{msg}\");\n"));
187 }
188 }
189
190 out.push_str(" }\n\n");
191}
192
193fn generate_binding(
196 out: &mut String,
197 binding: &ast::GivenBinding,
198 entities: &HashMap<String, EntityInfo>,
199 given_types: &mut HashMap<String, String>,
200) {
201 match &binding.value {
202 ast::GivenValue::EntityConstructor { type_name, fields } => {
203 given_types.insert(binding.name.clone(), type_name.clone());
204 let entity = entities.get(type_name);
205
206 out.push_str(&format!(
207 " let mut {} = {} {{\n",
208 binding.name, type_name
209 ));
210 for field in fields {
211 let field_meta = entity.and_then(|e| e.fields.get(&field.name));
212 let value = field_value_to_rust(&field.value, field_meta);
213 out.push_str(&format!(
214 " {}: {},\n",
215 safe_field(&field.name),
216 value
217 ));
218 }
219 out.push_str(" };\n");
220 }
221 ast::GivenValue::Expr(expr) => {
222 let value = expr_to_rust(expr);
223 out.push_str(&format!(" let {} = {};\n", binding.name, value));
224 }
225 }
226}
227
228fn field_value_to_rust(expr: &ast::Expr, meta: Option<&FieldMeta>) -> String {
230 if let ExprKind::Ident(name) = &expr.kind
232 && let Some(meta) = meta
233 && let Some((enum_name, variants)) = &meta.union_info
234 && variants.contains(name)
235 {
236 return format!("{enum_name}::{name}");
237 }
238
239 let type_hint = meta.and_then(|m| m.simple_type.as_deref());
240 match &expr.kind {
241 ExprKind::Literal(lit) => literal_to_rust(lit, type_hint),
242 ExprKind::Ident(name) => name.clone(),
243 _ => expr_to_rust(expr),
244 }
245}
246
247fn generate_call(
250 out: &mut String,
251 when: &ast::WhenAction,
252 action: Option<&ast::ActionDecl>,
253 given_types: &HashMap<String, String>,
254 entities: &HashMap<String, EntityInfo>,
255) {
256 let fn_name = to_snake_case(&when.action_name);
257 let mut args = Vec::new();
258
259 if let Some(action) = action {
260 for param in &action.params {
262 let when_arg = when.args.iter().find(|a| a.name == param.name);
263 if let Some(arg) = when_arg {
264 let param_type = simple_type_name(¶m.ty.ty).unwrap_or_default();
265 let is_entity = entities.contains_key(¶m_type);
266
267 match &arg.value.kind {
268 ExprKind::Ident(name) if given_types.contains_key(name) => {
269 if is_entity {
270 args.push(format!("&mut {name}"));
271 } else {
272 args.push(name.clone());
273 }
274 }
275 _ => {
276 let hint = simple_type_name(¶m.ty.ty);
277 let value = param_value_to_rust(&arg.value, hint.as_deref());
278 args.push(value);
279 }
280 }
281 }
282 }
283 } else {
284 for arg in &when.args {
286 args.push(expr_to_rust(&arg.value));
287 }
288 }
289
290 out.push_str(&format!(
291 " let result = {}({});\n",
292 fn_name,
293 args.join(", ")
294 ));
295}
296
297fn param_value_to_rust(expr: &ast::Expr, type_hint: Option<&str>) -> String {
299 match &expr.kind {
300 ExprKind::Literal(lit) => literal_to_rust(lit, type_hint),
301 _ => expr_to_rust(expr),
302 }
303}
304
305fn generate_assertion(
308 out: &mut String,
309 expr: &ast::Expr,
310 given_types: &HashMap<String, String>,
311 entities: &HashMap<String, EntityInfo>,
312) {
313 if let ExprKind::Compare { left, op, right } = &expr.kind {
314 let lhs = expr_to_rust(left);
315
316 let type_hint = resolve_field_type(left, given_types, entities);
318 let union_ctx = resolve_union_context(left, given_types, entities);
319
320 let rhs = match &right.kind {
321 ExprKind::Ident(name) if union_ctx.is_some() => {
322 format!("{}::{name}", union_ctx.unwrap())
323 }
324 ExprKind::Literal(lit) => literal_to_rust(lit, type_hint.as_deref()),
325 _ => expr_to_rust(right),
326 };
327
328 match op {
329 CmpOp::Eq => out.push_str(&format!(" assert_eq!({lhs}, {rhs});\n")),
330 CmpOp::Ne => out.push_str(&format!(" assert_ne!({lhs}, {rhs});\n")),
331 _ => {
332 let op_str = match op {
333 CmpOp::Lt => "<",
334 CmpOp::Gt => ">",
335 CmpOp::Le => "<=",
336 CmpOp::Ge => ">=",
337 _ => unreachable!(),
338 };
339 out.push_str(&format!(" assert!({lhs} {op_str} {rhs});\n"));
340 }
341 }
342 } else {
343 out.push_str(&format!(" assert!({});\n", expr_to_rust(expr)));
344 }
345}
346
347fn resolve_field_type(
349 expr: &ast::Expr,
350 given_types: &HashMap<String, String>,
351 entities: &HashMap<String, EntityInfo>,
352) -> Option<String> {
353 if let ExprKind::FieldAccess { root, fields } = &expr.kind
354 && let ExprKind::Ident(var) = &root.kind
355 {
356 let entity_name = given_types.get(var)?;
357 let entity = entities.get(entity_name)?;
358 let field_name = fields.first()?;
359 let meta = entity.fields.get(field_name)?;
360 return meta.simple_type.clone();
361 }
362 None
363}
364
365fn resolve_union_context(
367 expr: &ast::Expr,
368 given_types: &HashMap<String, String>,
369 entities: &HashMap<String, EntityInfo>,
370) -> Option<String> {
371 if let ExprKind::FieldAccess { root, fields } = &expr.kind
372 && let ExprKind::Ident(var) = &root.kind
373 {
374 let entity_name = given_types.get(var)?;
375 let entity = entities.get(entity_name)?;
376 let field_name = fields.first()?;
377 let meta = entity.fields.get(field_name)?;
378 return meta.union_info.as_ref().map(|(name, _)| name.clone());
379 }
380 None
381}
382
383fn literal_to_rust(lit: &Literal, type_hint: Option<&str>) -> String {
387 match (lit, type_hint) {
388 (Literal::String(_), Some("UUID")) => "Uuid::new_v4()".to_string(),
390 (Literal::String(_), Some("DateTime")) => "Utc::now()".to_string(),
392 (Literal::String(s), _) => format!("\"{s}\".to_string()"),
394 (Literal::Decimal(s), _) => format!("Decimal::from_str(\"{s}\").unwrap()"),
396 (Literal::Int(n), Some("Decimal")) => format!("Decimal::from({n}_i64)"),
398 (Literal::Int(n), _) => format!("{n}"),
399 (Literal::Bool(b), _) => format!("{b}"),
400 (Literal::Null, _) => "None".to_string(),
401 }
402}
403
404fn expr_to_rust(expr: &ast::Expr) -> String {
406 match &expr.kind {
407 ExprKind::Literal(lit) => literal_to_rust(lit, None),
408 ExprKind::Ident(name) => name.clone(),
409 ExprKind::FieldAccess { root, fields } => {
410 format!("{}.{}", expr_to_rust(root), fields.join("."))
411 }
412 ExprKind::Compare { left, op, right } => {
413 let op_str = match op {
414 CmpOp::Eq => "==",
415 CmpOp::Ne => "!=",
416 CmpOp::Lt => "<",
417 CmpOp::Gt => ">",
418 CmpOp::Le => "<=",
419 CmpOp::Ge => ">=",
420 };
421 format!("{} {op_str} {}", expr_to_rust(left), expr_to_rust(right))
422 }
423 ExprKind::Arithmetic { left, op, right } => {
424 let op_str = match op {
425 ast::ArithOp::Add => "+",
426 ast::ArithOp::Sub => "-",
427 };
428 format!("{} {op_str} {}", expr_to_rust(left), expr_to_rust(right))
429 }
430 ExprKind::And(l, r) => format!("{} && {}", expr_to_rust(l), expr_to_rust(r)),
431 ExprKind::Or(l, r) => format!("{} || {}", expr_to_rust(l), expr_to_rust(r)),
432 ExprKind::Not(e) => format!("!{}", expr_to_rust(e)),
433 ExprKind::Old(e) => format!("/* old */ {}", expr_to_rust(e)),
434 ExprKind::Call { name, args } => {
435 let args_str: Vec<String> = args
436 .iter()
437 .map(|a| match a {
438 ast::CallArg::Named { value, .. } => expr_to_rust(value),
439 ast::CallArg::Positional(e) => expr_to_rust(e),
440 })
441 .collect();
442 format!("{name}({})", args_str.join(", "))
443 }
444 ExprKind::Implies(l, r) => {
445 format!("!({}) || ({})", expr_to_rust(l), expr_to_rust(r))
446 }
447 ExprKind::Quantifier { .. } => "true /* quantifier */".to_string(),
448 ExprKind::List(items) => {
449 let inner: Vec<String> = items.iter().map(expr_to_rust).collect();
450 format!("vec![{}]", inner.join(", "))
451 }
452 }
453}
454
455fn safe_field(name: &str) -> String {
458 let snake = to_snake_case(name);
459 const KEYWORDS: &[&str] = &[
460 "as", "async", "await", "break", "const", "continue", "crate", "dyn", "else", "enum",
461 "extern", "false", "fn", "for", "gen", "if", "impl", "in", "let", "loop", "match", "mod",
462 "move", "mut", "pub", "ref", "return", "self", "static", "struct", "super", "trait",
463 "true", "type", "unsafe", "use", "where", "while", "yield",
464 ];
465 if KEYWORDS.contains(&snake.as_str()) {
466 format!("r#{snake}")
467 } else {
468 snake
469 }
470}
471
/// Upper-cases the first character of `s` and leaves the rest untouched
/// (`"status"` -> `"Status"`). An empty input yields an empty string.
fn capitalize(s: &str) -> String {
    let Some(first) = s.chars().next() else {
        return String::new();
    };
    // `to_uppercase` may expand to several chars (e.g. 'ß' -> "SS").
    let mut capitalized: String = first.to_uppercase().collect();
    capitalized.push_str(&s[first.len_utf8()..]);
    capitalized
}
479
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses spec source, failing the test on any parse error.
    fn parse(src: &str) -> ast::File {
        intent_parser::parse_file(src).expect("parse failed")
    }

    /// Generates the harness for `src` and checks that every snippet appears in it.
    fn assert_harness_contains(src: &str, snippets: &[&str]) {
        let harness = generate(&parse(src));
        for snippet in snippets {
            assert!(
                harness.contains(snippet),
                "generated harness is missing {snippet:?}"
            );
        }
    }

    #[test]
    fn test_generate_empty_for_no_tests() {
        const SPEC: &str = "module Test\n\nentity Foo { id: UUID }\n";
        assert!(generate(&parse(SPEC)).is_empty());
    }

    #[test]
    fn test_generate_failure_test() {
        const SPEC: &str = r#"module Test

entity Acc {
 id: UUID
 balance: Int
}

action Withdraw {
 account: Acc
 amount: Int

 requires {
 amount > 0
 account.balance >= amount
 }
}

test "overdraft rejected" {
 given {
 acc = Acc { id: "x", balance: 50 }
 }
 when Withdraw { account: acc, amount: 100 }
 then fails precondition
}
"#;
        assert_harness_contains(
            SPEC,
            &[
                "#[cfg(test)]",
                "mod contract_tests",
                "fn test_overdraft_rejected()",
                "assert!(result.is_err()",
                "&mut acc",
            ],
        );
    }

    #[test]
    fn test_generate_success_test_with_assertions() {
        const SPEC: &str = r#"module Test

entity Counter {
 id: UUID
 value: Int
}

action Increment {
 counter: Counter

 ensures {
 counter.value == old(counter.value) + 1
 }
}

test "increment works" {
 given {
 c = Counter { id: "c1", value: 5 }
 }
 when Increment { counter: c }
 then {
 c.value == 6
 }
}
"#;
        assert_harness_contains(
            SPEC,
            &[
                "fn test_increment_works()",
                "assert!(result.is_ok()",
                "assert_eq!(c.value, 6)",
                "&mut c",
            ],
        );
    }

    #[test]
    fn test_union_enum_in_given() {
        const SPEC: &str = r#"module Test

entity Acc {
 id: UUID
 status: Active | Frozen
}

action Freeze {
 account: Acc
}

test "freeze active" {
 given {
 a = Acc { id: "a1", status: Active }
 }
 when Freeze { account: a }
 then {
 a.status == Frozen
 }
}
"#;
        assert_harness_contains(SPEC, &["AccStatus::Active", "AccStatus::Frozen"]);
    }

    #[test]
    fn test_decimal_values() {
        const SPEC: &str = r#"module Test

entity Acc {
 id: UUID
 balance: Decimal(precision: 2)
}

action Deposit {
 account: Acc
 amount: Decimal(precision: 2)
}

test "deposit adds" {
 given {
 a = Acc { id: "a1", balance: 100.00 }
 }
 when Deposit { account: a, amount: 50.00 }
 then {
 a.balance == 150.00
 }
}
"#;
        assert_harness_contains(
            SPEC,
            &[
                "use std::str::FromStr",
                "Decimal::from_str(\"100.00\").unwrap()",
                "Decimal::from_str(\"50.00\").unwrap()",
                "Decimal::from_str(\"150.00\").unwrap()",
            ],
        );
    }

    #[test]
    fn test_non_entity_params() {
        const SPEC: &str = r#"module Test

entity Item { id: UUID }

action SetPrice {
 item: Item
 price: Int
}

test "set price" {
 given {
 i = Item { id: "i1" }
 p = 42
 }
 when SetPrice { item: i, price: p }
 then {
 p == 42
 }
}
"#;
        assert_harness_contains(SPEC, &["&mut i", ", p)"]);
    }

    #[test]
    fn test_multiple_tests() {
        const SPEC: &str = r#"module Test

entity X { id: UUID }

action DoIt { x: X }

test "first" {
 given { x = X { id: "1" } }
 when DoIt { x: x }
 then fails
}

test "second" {
 given { x = X { id: "2" } }
 when DoIt { x: x }
 then fails
}
"#;
        assert_harness_contains(SPEC, &["fn test_first()", "fn test_second()"]);
    }

    #[test]
    fn test_transfer_example() {
        const SPEC: &str = r#"module TransferFunds

entity Account {
 id: UUID
 owner: String
 balance: Decimal(precision: 2)
 currency: CurrencyCode
 status: Active | Frozen | Closed
 created_at: DateTime
}

action Transfer {
 from: Account
 to: Account
 amount: Decimal(precision: 2)
 request_id: UUID

 requires {
 from.status == Active
 to.status == Active
 amount > 0
 from.balance >= amount
 }

 ensures {
 from.balance == old(from.balance) - amount
 to.balance == old(to.balance) + amount
 }
}

test "successful transfer" {
 given {
 from = Account { id: "acc-1", owner: "Alice", balance: 1000.00, currency: "USD", status: Active, created_at: "2024-01-01" }
 to = Account { id: "acc-2", owner: "Bob", balance: 500.00, currency: "USD", status: Active, created_at: "2024-01-01" }
 }
 when Transfer {
 from: from,
 to: to,
 amount: 200.00,
 request_id: "req-1"
 }
 then {
 from.balance == 800.00
 to.balance == 700.00
 }
}

test "insufficient funds" {
 given {
 from = Account { id: "acc-1", owner: "Alice", balance: 50.00, currency: "USD", status: Active, created_at: "2024-01-01" }
 to = Account { id: "acc-2", owner: "Bob", balance: 500.00, currency: "USD", status: Active, created_at: "2024-01-01" }
 }
 when Transfer {
 from: from,
 to: to,
 amount: 200.00,
 request_id: "req-2"
 }
 then fails precondition
}

test "frozen account rejected" {
 given {
 from = Account { id: "acc-1", owner: "Alice", balance: 1000.00, currency: "USD", status: Frozen, created_at: "2024-01-01" }
 to = Account { id: "acc-2", owner: "Bob", balance: 500.00, currency: "USD", status: Active, created_at: "2024-01-01" }
 }
 when Transfer {
 from: from,
 to: to,
 amount: 100.00,
 request_id: "req-3"
 }
 then fails precondition
}
"#;
        assert_harness_contains(
            SPEC,
            &[
                // One test function per spec test.
                "fn test_successful_transfer()",
                "fn test_insufficient_funds()",
                "fn test_frozen_account_rejected()",
                // Entity arguments are borrowed mutably.
                "&mut from",
                "&mut to",
                // Decimal literals in givens and call arguments.
                "Decimal::from_str(\"1000.00\").unwrap()",
                "Decimal::from_str(\"200.00\").unwrap()",
                // UUID / DateTime strings become fresh values.
                "Uuid::new_v4()",
                "Utc::now()",
                // Union-typed fields use the generated enum.
                "AccountStatus::Active",
                "AccountStatus::Frozen",
                // Other strings become owned `String`s.
                "\"Alice\".to_string()",
                "\"USD\".to_string()",
                // `then` comparisons pick up the field's Decimal type.
                "assert_eq!(from.balance, Decimal::from_str(\"800.00\").unwrap())",
                "assert_eq!(to.balance, Decimal::from_str(\"700.00\").unwrap())",
                // Failure tests check for an error result.
                "assert!(result.is_err()",
            ],
        );
    }
}