1use std::collections::{BTreeSet, HashMap, HashSet};
2
3use crate::ast::{
4 Expr, FnDef, MatchArm, Stmt, StrPart, TopLevel, VerifyBlock, VerifyKind, VerifyLaw,
5};
6use crate::types::Type;
7
8pub type FnSigMap = HashMap<String, (Vec<Type>, Type, Vec<String>)>;
9
/// A function found in the signature map under the same name as a verify law.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct NamedLawFunction {
    /// Name shared by the law and the function.
    pub name: String,
    /// `true` when the function's signature lists no effects.
    pub is_pure: bool,
}
15
/// Reference from a verify law to the function acting as its specification.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct VerifyLawSpecRef {
    /// Name of the specification function.
    pub spec_fn_name: String,
}
20
/// Hint reported when a verify law relies on helper functions that have no law of their own.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct MissingHelperLawHint {
    /// Line of the verify block the hint refers to.
    pub line: usize,
    /// Function name taken from the verify block.
    pub fn_name: String,
    /// Name of the law inside that block.
    pub law_name: String,
    /// Names of the helper functions that lack coverage.
    pub missing_helpers: Vec<String>,
}
28
/// Hint reported when a roundtrip-style verify law has contextual helpers without
/// analogous law coverage.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ContextualHelperLawHint {
    /// Line of the verify block the hint refers to.
    pub line: usize,
    /// Function name taken from the verify block.
    pub fn_name: String,
    /// Name of the law inside that block.
    pub law_name: String,
    /// Names of the helper functions that lack coverage.
    pub missing_helpers: Vec<String>,
}
36
37pub fn named_law_function(law: &VerifyLaw, fn_sigs: &FnSigMap) -> Option<NamedLawFunction> {
38 let (_, _, effects) = fn_sigs.get(&law.name)?;
39 Some(NamedLawFunction {
40 name: law.name.clone(),
41 is_pure: effects.is_empty(),
42 })
43}
44
45pub fn declared_spec_ref(law: &VerifyLaw, fn_sigs: &FnSigMap) -> Option<VerifyLawSpecRef> {
46 let named = named_law_function(law, fn_sigs)?;
47 named.is_pure.then_some(VerifyLawSpecRef {
48 spec_fn_name: named.name,
49 })
50}
51
52pub fn law_spec_ref(law: &VerifyLaw, fn_sigs: &FnSigMap) -> Option<VerifyLawSpecRef> {
53 let spec = declared_spec_ref(law, fn_sigs)?;
54 law_calls_function(law, &spec.spec_fn_name).then_some(spec)
55}
56
57pub fn canonical_spec_ref(
58 fn_name: &str,
59 law: &VerifyLaw,
60 fn_sigs: &FnSigMap,
61) -> Option<VerifyLawSpecRef> {
62 let spec = law_spec_ref(law, fn_sigs)?;
63 canonical_spec_shape(fn_name, law, &spec.spec_fn_name).then_some(spec)
64}
65
66pub fn law_calls_function(law: &VerifyLaw, fn_name: &str) -> bool {
67 expr_calls_function(&law.lhs, fn_name) || expr_calls_function(&law.rhs, fn_name)
68}
69
70pub fn canonical_spec_shape(fn_name: &str, law: &VerifyLaw, spec_fn_name: &str) -> bool {
71 let try_side = |impl_side: &Expr, spec_side: &Expr| -> bool {
72 let Some((impl_callee, impl_args)) = direct_call(impl_side) else {
73 return false;
74 };
75 let Some((spec_callee, spec_args)) = direct_call(spec_side) else {
76 return false;
77 };
78 impl_callee == fn_name && spec_callee == spec_fn_name && impl_args == spec_args
79 };
80
81 try_side(&law.lhs, &law.rhs) || try_side(&law.rhs, &law.lhs)
82}
83
84pub fn collect_missing_helper_law_hints(
85 items: &[TopLevel],
86 fn_sigs: &FnSigMap,
87) -> Vec<MissingHelperLawHint> {
88 let fn_defs = items
89 .iter()
90 .filter_map(|item| {
91 if let TopLevel::FnDef(fd) = item {
92 Some((fd.name.clone(), fd))
93 } else {
94 None
95 }
96 })
97 .collect::<HashMap<_, _>>();
98 let verified_law_functions = items
99 .iter()
100 .filter_map(|item| {
101 let TopLevel::Verify(vb) = item else {
102 return None;
103 };
104 let VerifyKind::Law(law) = &vb.kind else {
105 return None;
106 };
107 let mut covered = BTreeSet::new();
108 covered.insert(vb.fn_name.clone());
109 collect_direct_pure_user_calls(&law.lhs, &fn_defs, fn_sigs, &mut covered);
110 collect_direct_pure_user_calls(&law.rhs, &fn_defs, fn_sigs, &mut covered);
111 Some(covered)
112 })
113 .flatten()
114 .collect::<HashSet<_>>();
115
116 items
117 .iter()
118 .filter_map(|item| {
119 let TopLevel::Verify(vb) = item else {
120 return None;
121 };
122 let VerifyKind::Law(law) = &vb.kind else {
123 return None;
124 };
125 missing_helper_law_hint_for_block(vb, law, &fn_defs, &verified_law_functions, fn_sigs)
126 })
127 .collect()
128}
129
130pub fn missing_helper_law_message(hint: &MissingHelperLawHint) -> String {
131 format!(
132 "verify law '{}.{}' uses helper functions without their own verify law: {}; add layered `verify ... law ...` blocks for those helpers before expecting a universal auto-proof",
133 hint.fn_name,
134 hint.law_name,
135 hint.missing_helpers.join(", ")
136 )
137}
138
139pub fn collect_contextual_helper_law_hints(
140 items: &[TopLevel],
141 fn_sigs: &FnSigMap,
142) -> Vec<ContextualHelperLawHint> {
143 let fn_defs = items
144 .iter()
145 .filter_map(|item| {
146 if let TopLevel::FnDef(fd) = item {
147 Some((fd.name.clone(), fd))
148 } else {
149 None
150 }
151 })
152 .collect::<HashMap<_, _>>();
153 let contextual_law_targets = items
154 .iter()
155 .filter_map(|item| {
156 let TopLevel::Verify(vb) = item else {
157 return None;
158 };
159 let VerifyKind::Law(law) = &vb.kind else {
160 return None;
161 };
162 top_level_direct_pure_call_in_law(law, &fn_defs, fn_sigs)
163 })
164 .collect::<HashSet<_>>();
165
166 items
167 .iter()
168 .filter_map(|item| {
169 let TopLevel::Verify(vb) = item else {
170 return None;
171 };
172 let VerifyKind::Law(law) = &vb.kind else {
173 return None;
174 };
175 contextual_helper_law_hint_for_block(
176 vb,
177 law,
178 &fn_defs,
179 &contextual_law_targets,
180 fn_sigs,
181 )
182 })
183 .collect()
184}
185
186pub fn contextual_helper_law_message(hint: &ContextualHelperLawHint) -> String {
187 format!(
188 "verify law '{}.{}' still lacks analogous `verify ... law ...` coverage for contextual helpers: {}; universal auto-proof will likely stop at those helper boundaries",
189 hint.fn_name,
190 hint.law_name,
191 hint.missing_helpers.join(", ")
192 )
193}
194
195fn missing_helper_law_hint_for_block(
196 vb: &VerifyBlock,
197 law: &VerifyLaw,
198 fn_defs: &HashMap<String, &FnDef>,
199 verified_law_functions: &HashSet<String>,
200 fn_sigs: &FnSigMap,
201) -> Option<MissingHelperLawHint> {
202 if law.when.is_none() || law.givens.len() != 1 {
203 return None;
204 }
205
206 let root_calls = direct_pure_user_calls_in_law(law, fn_defs, fn_sigs);
207 if root_calls.is_empty() {
208 return None;
209 }
210
211 let mut missing_helpers = BTreeSet::new();
212 for root in root_calls {
213 for helper in frontier_helper_calls(&root, fn_defs, fn_sigs) {
214 if helper != vb.fn_name && !verified_law_functions.contains(&helper) {
215 missing_helpers.insert(helper);
216 }
217 }
218 }
219
220 if missing_helpers.is_empty() {
221 return None;
222 }
223
224 Some(MissingHelperLawHint {
225 line: vb.line,
226 fn_name: vb.fn_name.clone(),
227 law_name: law.name.clone(),
228 missing_helpers: missing_helpers.into_iter().collect(),
229 })
230}
231
232fn contextual_helper_law_hint_for_block(
233 vb: &VerifyBlock,
234 law: &VerifyLaw,
235 fn_defs: &HashMap<String, &FnDef>,
236 contextual_law_targets: &HashSet<String>,
237 fn_sigs: &FnSigMap,
238) -> Option<ContextualHelperLawHint> {
239 let parser_name = contextual_roundtrip_parser_name(law, fn_defs, fn_sigs)?;
240 let root_parser_name = wrapper_dispatch_root(&parser_name, fn_defs, fn_sigs)
241 .unwrap_or_else(|| parser_name.clone());
242 if root_parser_name != parser_name {
243 return None;
244 }
245
246 let missing_helpers = frontier_helper_calls(&root_parser_name, fn_defs, fn_sigs)
247 .into_iter()
248 .filter(|helper| helper != &vb.fn_name && !contextual_law_targets.contains(helper))
249 .collect::<BTreeSet<_>>();
250
251 if missing_helpers.is_empty() {
252 return None;
253 }
254
255 Some(ContextualHelperLawHint {
256 line: vb.line,
257 fn_name: vb.fn_name.clone(),
258 law_name: law.name.clone(),
259 missing_helpers: missing_helpers.into_iter().collect(),
260 })
261}
262
263fn direct_pure_user_calls_in_law(
264 law: &VerifyLaw,
265 fn_defs: &HashMap<String, &FnDef>,
266 fn_sigs: &FnSigMap,
267) -> BTreeSet<String> {
268 let mut out = BTreeSet::new();
269 collect_direct_pure_user_calls(&law.lhs, fn_defs, fn_sigs, &mut out);
270 collect_direct_pure_user_calls(&law.rhs, fn_defs, fn_sigs, &mut out);
271 out
272}
273
274fn top_level_direct_pure_call_in_law(
275 law: &VerifyLaw,
276 fn_defs: &HashMap<String, &FnDef>,
277 fn_sigs: &FnSigMap,
278) -> Option<String> {
279 direct_pure_user_call_name(&law.lhs, fn_defs, fn_sigs)
280 .or_else(|| direct_pure_user_call_name(&law.rhs, fn_defs, fn_sigs))
281}
282
283fn contextual_roundtrip_parser_name(
284 law: &VerifyLaw,
285 fn_defs: &HashMap<String, &FnDef>,
286 fn_sigs: &FnSigMap,
287) -> Option<String> {
288 let given = law.givens.first()?;
289 detect_roundtrip_layers(law, &given.name, fn_defs, fn_sigs).map(|(parser_name, _)| parser_name)
290}
291
292fn frontier_helper_calls(
293 root_name: &str,
294 fn_defs: &HashMap<String, &FnDef>,
295 fn_sigs: &FnSigMap,
296) -> BTreeSet<String> {
297 let mut current =
298 wrapper_dispatch_root(root_name, fn_defs, fn_sigs).unwrap_or_else(|| root_name.to_string());
299 let mut visited = BTreeSet::new();
300
301 for _ in 0..2 {
302 if !visited.insert(current.clone()) {
303 break;
304 }
305 let direct = direct_pure_fn_callees_matching_return(¤t, fn_defs, fn_sigs);
306 if direct.is_empty() {
307 return BTreeSet::new();
308 }
309 if direct.len() == 1 {
310 current = direct.iter().next().cloned().unwrap_or_default();
311 continue;
312 }
313 return direct;
314 }
315
316 direct_pure_fn_callees_matching_return(¤t, fn_defs, fn_sigs)
317}
318
319fn wrapper_dispatch_root(
320 fn_name: &str,
321 fn_defs: &HashMap<String, &FnDef>,
322 fn_sigs: &FnSigMap,
323) -> Option<String> {
324 let fd = fn_defs.get(fn_name)?;
325 let tail = fd.body.tail_expr()?;
326 match tail {
327 Expr::Match { subject, .. } => direct_pure_user_call_name(subject, fn_defs, fn_sigs),
328 Expr::FnCall(_, _) => direct_pure_user_call_name(tail, fn_defs, fn_sigs),
329 _ => None,
330 }
331}
332
333fn direct_pure_fn_callees_matching_return(
334 fn_name: &str,
335 fn_defs: &HashMap<String, &FnDef>,
336 fn_sigs: &FnSigMap,
337) -> BTreeSet<String> {
338 let Some((_, return_type, _)) = fn_sigs.get(fn_name) else {
339 return BTreeSet::new();
340 };
341 let Some(fd) = fn_defs.get(fn_name) else {
342 return BTreeSet::new();
343 };
344
345 let mut direct = BTreeSet::new();
346 for stmt in fd.body.stmts() {
347 match stmt {
348 Stmt::Expr(expr) | Stmt::Binding(_, _, expr) => {
349 collect_direct_pure_user_calls(expr, fn_defs, fn_sigs, &mut direct);
350 }
351 }
352 }
353 direct
354 .into_iter()
355 .filter(|callee| {
356 callee != fn_name
357 && fn_sigs.get(callee).is_some_and(|(_, callee_ret, effects)| {
358 effects.is_empty() && callee_ret == return_type
359 })
360 })
361 .collect()
362}
363
364fn collect_direct_pure_user_calls(
365 expr: &Expr,
366 fn_defs: &HashMap<String, &FnDef>,
367 fn_sigs: &FnSigMap,
368 out: &mut BTreeSet<String>,
369) {
370 match expr {
371 Expr::FnCall(callee, args) => {
372 if let Some(name) = direct_pure_user_call_name(expr, fn_defs, fn_sigs) {
373 out.insert(name);
374 }
375 collect_direct_pure_user_calls(callee, fn_defs, fn_sigs, out);
376 for arg in args {
377 collect_direct_pure_user_calls(arg, fn_defs, fn_sigs, out);
378 }
379 }
380 Expr::Attr(obj, _) => collect_direct_pure_user_calls(obj, fn_defs, fn_sigs, out),
381 Expr::BinOp(_, left, right) => {
382 collect_direct_pure_user_calls(left, fn_defs, fn_sigs, out);
383 collect_direct_pure_user_calls(right, fn_defs, fn_sigs, out);
384 }
385 Expr::Match { subject, arms, .. } => {
386 collect_direct_pure_user_calls(subject, fn_defs, fn_sigs, out);
387 for arm in arms {
388 collect_direct_pure_user_calls(&arm.body, fn_defs, fn_sigs, out);
389 }
390 }
391 Expr::Constructor(_, Some(inner)) | Expr::ErrorProp(inner) => {
392 collect_direct_pure_user_calls(inner, fn_defs, fn_sigs, out);
393 }
394 Expr::InterpolatedStr(parts) => {
395 for part in parts {
396 if let StrPart::Parsed(inner) = part {
397 collect_direct_pure_user_calls(inner, fn_defs, fn_sigs, out);
398 }
399 }
400 }
401 Expr::List(items) | Expr::Tuple(items) => {
402 for item in items {
403 collect_direct_pure_user_calls(item, fn_defs, fn_sigs, out);
404 }
405 }
406 Expr::MapLiteral(entries) => {
407 for (key, value) in entries {
408 collect_direct_pure_user_calls(key, fn_defs, fn_sigs, out);
409 collect_direct_pure_user_calls(value, fn_defs, fn_sigs, out);
410 }
411 }
412 Expr::RecordCreate { fields, .. } => {
413 for (_, value) in fields {
414 collect_direct_pure_user_calls(value, fn_defs, fn_sigs, out);
415 }
416 }
417 Expr::RecordUpdate { base, updates, .. } => {
418 collect_direct_pure_user_calls(base, fn_defs, fn_sigs, out);
419 for (_, value) in updates {
420 collect_direct_pure_user_calls(value, fn_defs, fn_sigs, out);
421 }
422 }
423 Expr::TailCall(boxed) => {
424 let (target, args) = boxed.as_ref();
425 if fn_defs.contains_key(target)
426 && fn_sigs
427 .get(target)
428 .is_some_and(|(_, _, effects)| effects.is_empty())
429 {
430 out.insert(target.clone());
431 }
432 for arg in args {
433 collect_direct_pure_user_calls(arg, fn_defs, fn_sigs, out);
434 }
435 }
436 Expr::Literal(_) | Expr::Ident(_) | Expr::Resolved(_) | Expr::Constructor(_, None) => {}
437 }
438}
439
440fn direct_pure_user_call_name(
441 expr: &Expr,
442 fn_defs: &HashMap<String, &FnDef>,
443 fn_sigs: &FnSigMap,
444) -> Option<String> {
445 let Expr::FnCall(callee, _) = expr else {
446 return None;
447 };
448 let name = dotted_name(callee)?;
449 if !fn_defs.contains_key(&name) {
450 return None;
451 }
452 fn_sigs
453 .get(&name)
454 .is_some_and(|(_, _, effects)| effects.is_empty())
455 .then_some(name)
456}
457
458fn dotted_name(expr: &Expr) -> Option<String> {
459 match expr {
460 Expr::Ident(name) => Some(name.clone()),
461 Expr::Attr(base, field) => {
462 let mut prefix = dotted_name(base)?;
463 prefix.push('.');
464 prefix.push_str(field);
465 Some(prefix)
466 }
467 _ => None,
468 }
469}
470
471fn detect_roundtrip_layers(
472 law: &VerifyLaw,
473 given_name: &str,
474 fn_defs: &HashMap<String, &FnDef>,
475 fn_sigs: &FnSigMap,
476) -> Option<(String, String)> {
477 if law.givens.len() != 1 {
478 return None;
479 }
480
481 fn detect_roundtrip_side(
482 expr: &Expr,
483 given_name: &str,
484 fn_defs: &HashMap<String, &FnDef>,
485 fn_sigs: &FnSigMap,
486 ) -> Option<(String, String)> {
487 let Expr::FnCall(parser_callee, parser_args) = expr else {
488 return None;
489 };
490 if parser_args.is_empty() {
491 return None;
492 }
493 let (serializer_callee, serializer_args) =
494 extract_roundtrip_serializer_call(&parser_args[0], given_name)?;
495 if !serializer_args
496 .iter()
497 .any(|arg| matches_ident(arg, given_name))
498 {
499 return None;
500 }
501 if serializer_args
502 .iter()
503 .filter(|arg| expr_mentions_ident(arg, given_name))
504 .any(|arg| !matches_ident(arg, given_name))
505 {
506 return None;
507 }
508 if parser_args[1..]
509 .iter()
510 .any(|arg| expr_mentions_ident(arg, given_name))
511 {
512 return None;
513 }
514
515 let parser_name = dotted_name(parser_callee)?;
516 let serializer_name = dotted_name(serializer_callee)?;
517 if !fn_defs.contains_key(&parser_name) || !fn_defs.contains_key(&serializer_name) {
518 return None;
519 }
520 if !fn_sigs
521 .get(&parser_name)
522 .is_some_and(|(_, _, effects)| effects.is_empty())
523 {
524 return None;
525 }
526 if !fn_sigs
527 .get(&serializer_name)
528 .is_some_and(|(_, _, effects)| effects.is_empty())
529 {
530 return None;
531 }
532 Some((parser_name, serializer_name))
533 }
534
535 detect_roundtrip_side(&law.lhs, given_name, fn_defs, fn_sigs)
536 .or_else(|| detect_roundtrip_side(&law.rhs, given_name, fn_defs, fn_sigs))
537}
538
539fn extract_roundtrip_serializer_call<'a>(
540 expr: &'a Expr,
541 given_name: &str,
542) -> Option<(&'a Expr, &'a [Expr])> {
543 let mut candidates = Vec::new();
544 collect_roundtrip_serializer_calls(expr, given_name, &mut candidates);
545 if candidates.len() != 1 {
546 return None;
547 }
548 let (callee, args) = candidates.pop()?;
549 if expr_mentions_ident(expr, given_name)
550 && args
551 .iter()
552 .filter(|arg| expr_mentions_ident(arg, given_name))
553 .all(|arg| matches_ident(arg, given_name))
554 {
555 Some((callee, args))
556 } else {
557 None
558 }
559}
560
561fn collect_roundtrip_serializer_calls<'a>(
562 expr: &'a Expr,
563 given_name: &str,
564 out: &mut Vec<(&'a Expr, &'a [Expr])>,
565) {
566 match expr {
567 Expr::FnCall(callee, args) => {
568 if args.iter().any(|arg| matches_ident(arg, given_name))
569 && args
570 .iter()
571 .filter(|arg| expr_mentions_ident(arg, given_name))
572 .all(|arg| matches_ident(arg, given_name))
573 {
574 out.push((callee.as_ref(), args.as_slice()));
575 }
576 collect_roundtrip_serializer_calls(callee, given_name, out);
577 for arg in args {
578 collect_roundtrip_serializer_calls(arg, given_name, out);
579 }
580 }
581 Expr::Attr(base, _) => collect_roundtrip_serializer_calls(base, given_name, out),
582 Expr::BinOp(_, left, right) => {
583 collect_roundtrip_serializer_calls(left, given_name, out);
584 collect_roundtrip_serializer_calls(right, given_name, out);
585 }
586 Expr::Match { subject, arms, .. } => {
587 collect_roundtrip_serializer_calls(subject, given_name, out);
588 for arm in arms {
589 collect_roundtrip_serializer_calls(&arm.body, given_name, out);
590 }
591 }
592 Expr::Constructor(_, inner) => {
593 if let Some(inner) = inner {
594 collect_roundtrip_serializer_calls(inner, given_name, out);
595 }
596 }
597 Expr::ErrorProp(inner) => collect_roundtrip_serializer_calls(inner, given_name, out),
598 Expr::InterpolatedStr(parts) => {
599 for part in parts {
600 if let StrPart::Parsed(inner) = part {
601 collect_roundtrip_serializer_calls(inner, given_name, out);
602 }
603 }
604 }
605 Expr::List(items) | Expr::Tuple(items) => {
606 for item in items {
607 collect_roundtrip_serializer_calls(item, given_name, out);
608 }
609 }
610 Expr::MapLiteral(entries) => {
611 for (key, value) in entries {
612 collect_roundtrip_serializer_calls(key, given_name, out);
613 collect_roundtrip_serializer_calls(value, given_name, out);
614 }
615 }
616 Expr::RecordCreate { fields, .. } => {
617 for (_, value) in fields {
618 collect_roundtrip_serializer_calls(value, given_name, out);
619 }
620 }
621 Expr::RecordUpdate { base, updates, .. } => {
622 collect_roundtrip_serializer_calls(base, given_name, out);
623 for (_, value) in updates {
624 collect_roundtrip_serializer_calls(value, given_name, out);
625 }
626 }
627 Expr::TailCall(call) => {
628 for arg in &call.1 {
629 collect_roundtrip_serializer_calls(arg, given_name, out);
630 }
631 }
632 Expr::Literal(_) | Expr::Ident(_) | Expr::Resolved(_) => {}
633 }
634}
635
636fn matches_ident(expr: &Expr, name: &str) -> bool {
637 matches!(expr, Expr::Ident(current) if current == name)
638}
639
640fn expr_mentions_ident(expr: &Expr, name: &str) -> bool {
641 match expr {
642 Expr::Ident(current) => current == name,
643 Expr::Attr(base, _) => expr_mentions_ident(base, name),
644 Expr::FnCall(callee, args) => {
645 expr_mentions_ident(callee, name)
646 || args.iter().any(|arg| expr_mentions_ident(arg, name))
647 }
648 Expr::BinOp(_, left, right) => {
649 expr_mentions_ident(left, name) || expr_mentions_ident(right, name)
650 }
651 Expr::Match { subject, arms, .. } => {
652 expr_mentions_ident(subject, name)
653 || arms.iter().any(|arm| expr_mentions_ident(&arm.body, name))
654 }
655 Expr::Constructor(_, inner) => inner
656 .as_deref()
657 .is_some_and(|inner| expr_mentions_ident(inner, name)),
658 Expr::ErrorProp(inner) => expr_mentions_ident(inner, name),
659 Expr::InterpolatedStr(parts) => parts.iter().any(|part| match part {
660 StrPart::Literal(_) => false,
661 StrPart::Parsed(inner) => expr_mentions_ident(inner, name),
662 }),
663 Expr::List(items) | Expr::Tuple(items) => {
664 items.iter().any(|item| expr_mentions_ident(item, name))
665 }
666 Expr::MapLiteral(entries) => entries
667 .iter()
668 .any(|(key, value)| expr_mentions_ident(key, name) || expr_mentions_ident(value, name)),
669 Expr::RecordCreate { fields, .. } => fields
670 .iter()
671 .any(|(_, value)| expr_mentions_ident(value, name)),
672 Expr::RecordUpdate { base, updates, .. } => {
673 expr_mentions_ident(base, name)
674 || updates
675 .iter()
676 .any(|(_, value)| expr_mentions_ident(value, name))
677 }
678 Expr::TailCall(call) => call.1.iter().any(|arg| expr_mentions_ident(arg, name)),
679 Expr::Literal(_) | Expr::Resolved(_) => false,
680 }
681}
682
683fn expr_calls_function(expr: &Expr, fn_name: &str) -> bool {
684 match expr {
685 Expr::FnCall(callee, args) => {
686 expr_is_function_name(callee, fn_name)
687 || expr_calls_function(callee, fn_name)
688 || args.iter().any(|arg| expr_calls_function(arg, fn_name))
689 }
690 Expr::Attr(obj, _) => expr_calls_function(obj, fn_name),
691 Expr::BinOp(_, left, right) => {
692 expr_calls_function(left, fn_name) || expr_calls_function(right, fn_name)
693 }
694 Expr::Match { subject, arms, .. } => {
695 expr_calls_function(subject, fn_name)
696 || arms
697 .iter()
698 .any(|arm| match_arm_calls_function(arm, fn_name))
699 }
700 Expr::Constructor(_, Some(inner)) => expr_calls_function(inner, fn_name),
701 Expr::ErrorProp(inner) => expr_calls_function(inner, fn_name),
702 Expr::InterpolatedStr(parts) => parts.iter().any(|part| match part {
703 StrPart::Literal(_) => false,
704 StrPart::Parsed(expr) => expr_calls_function(expr, fn_name),
705 }),
706 Expr::List(items) | Expr::Tuple(items) => {
707 items.iter().any(|item| expr_calls_function(item, fn_name))
708 }
709 Expr::MapLiteral(entries) => entries.iter().any(|(key, value)| {
710 expr_calls_function(key, fn_name) || expr_calls_function(value, fn_name)
711 }),
712 Expr::RecordCreate { fields, .. } => fields
713 .iter()
714 .any(|(_, expr)| expr_calls_function(expr, fn_name)),
715 Expr::RecordUpdate { base, updates, .. } => {
716 expr_calls_function(base, fn_name)
717 || updates
718 .iter()
719 .any(|(_, expr)| expr_calls_function(expr, fn_name))
720 }
721 Expr::TailCall(boxed) => {
722 boxed.0 == fn_name || boxed.1.iter().any(|arg| expr_calls_function(arg, fn_name))
723 }
724 Expr::Literal(_) | Expr::Ident(_) | Expr::Resolved(_) | Expr::Constructor(_, None) => false,
725 }
726}
727
728fn match_arm_calls_function(arm: &MatchArm, fn_name: &str) -> bool {
729 expr_calls_function(&arm.body, fn_name)
730}
731
732fn expr_is_function_name(expr: &Expr, fn_name: &str) -> bool {
733 matches!(expr, Expr::Ident(name) if name == fn_name)
734}
735
736fn direct_call(expr: &Expr) -> Option<(&str, &[Expr])> {
737 let Expr::FnCall(callee, args) = expr else {
738 return None;
739 };
740 let Expr::Ident(name) = callee.as_ref() else {
741 return None;
742 };
743 Some((name.as_str(), args.as_slice()))
744}
745
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ast::{Literal, VerifyGiven, VerifyGivenDomain};

    /// Signature of a pure `Int -> Int` function.
    fn int_sig() -> (Vec<Type>, Type, Vec<String>) {
        (vec![Type::Int], Type::Int, Vec::new())
    }

    /// Signature map containing only `fooSpec` with the given signature.
    fn sigs_with_foo_spec(sig: (Vec<Type>, Type, Vec<String>)) -> FnSigMap {
        FnSigMap::from([("fooSpec".to_string(), sig)])
    }

    /// Identifier expression.
    fn ident(name: &str) -> Expr {
        Expr::Ident(name.to_string())
    }

    /// Call expression with a bare-identifier callee.
    fn call(callee: &str, args: Vec<Expr>) -> Expr {
        Expr::FnCall(Box::new(ident(callee)), args)
    }

    /// Law over a single given `x: Int` with the explicit domain `[1]` and no `when` clause.
    fn law(lhs: Expr, rhs: Expr, name: &str) -> VerifyLaw {
        let given = VerifyGiven {
            name: "x".to_string(),
            type_name: "Int".to_string(),
            domain: VerifyGivenDomain::Explicit(vec![Expr::Literal(Literal::Int(1))]),
        };
        VerifyLaw {
            name: name.to_string(),
            givens: vec![given],
            when: None,
            lhs,
            rhs,
            sample_guards: vec![],
        }
    }

    #[test]
    fn pure_named_law_function_becomes_declared_spec_ref() {
        let fn_sigs = sigs_with_foo_spec(int_sig());
        let verify_law = law(
            call("foo", vec![ident("x")]),
            call("fooSpec", vec![ident("x")]),
            "fooSpec",
        );

        assert_eq!(
            declared_spec_ref(&verify_law, &fn_sigs),
            Some(VerifyLawSpecRef {
                spec_fn_name: "fooSpec".to_string()
            })
        );
        assert_eq!(
            law_spec_ref(&verify_law, &fn_sigs),
            declared_spec_ref(&verify_law, &fn_sigs)
        );
        assert_eq!(
            canonical_spec_ref("foo", &verify_law, &fn_sigs),
            declared_spec_ref(&verify_law, &fn_sigs)
        );
    }

    #[test]
    fn effectful_named_law_function_is_not_a_spec_ref() {
        let effectful_sig = (
            vec![Type::Int],
            Type::Int,
            vec!["Console.print".to_string()],
        );
        let fn_sigs = sigs_with_foo_spec(effectful_sig);
        let verify_law = law(ident("x"), ident("x"), "fooSpec");

        assert!(declared_spec_ref(&verify_law, &fn_sigs).is_none());
        assert_eq!(
            named_law_function(&verify_law, &fn_sigs),
            Some(NamedLawFunction {
                name: "fooSpec".to_string(),
                is_pure: false
            })
        );
    }

    #[test]
    fn canonical_spec_ref_requires_call_to_named_function() {
        let fn_sigs = sigs_with_foo_spec(int_sig());
        let verify_law = law(ident("x"), ident("x"), "fooSpec");

        assert!(declared_spec_ref(&verify_law, &fn_sigs).is_some());
        assert!(law_spec_ref(&verify_law, &fn_sigs).is_none());
        assert!(!law_calls_function(&verify_law, "fooSpec"));
    }

    #[test]
    fn canonical_spec_ref_requires_same_arguments_on_both_sides() {
        let fn_sigs = sigs_with_foo_spec(int_sig());
        let verify_law = law(
            call("foo", vec![ident("x")]),
            call(
                "fooSpec",
                vec![Expr::Literal(Literal::Int(5)), ident("x")],
            ),
            "fooSpec",
        );

        assert!(law_spec_ref(&verify_law, &fn_sigs).is_some());
        assert!(canonical_spec_ref("foo", &verify_law, &fn_sigs).is_none());
        assert!(!canonical_spec_shape("foo", &verify_law, "fooSpec"));
    }
}