Skip to main content

aver/
verify_law.rs

1use std::collections::{BTreeSet, HashMap, HashSet};
2
3use crate::ast::{
4    Expr, FnDef, MatchArm, Stmt, StrPart, TopLevel, VerifyBlock, VerifyKind, VerifyLaw,
5};
6use crate::types::Type;
7
8pub type FnSigMap = HashMap<String, (Vec<Type>, Type, Vec<String>)>;
9
/// A verify-law name that matches an entry in the function signature table.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct NamedLawFunction {
    /// The law's name, which is also the name of the function it refers to.
    pub name: String,
    /// `true` when that function's signature declares no effects.
    pub is_pure: bool,
}
15
/// A verify law that names a pure function acting as a specification.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct VerifyLawSpecRef {
    /// Name of the pure spec function (identical to the law's name).
    pub spec_fn_name: String,
}
20
/// Diagnostic for a law whose helper functions have no verify law of their own.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct MissingHelperLawHint {
    /// Source line of the originating `verify` block.
    pub line: usize,
    /// Function the `verify` block is attached to.
    pub fn_name: String,
    /// Name of the law inside that block.
    pub law_name: String,
    /// Helper functions lacking a law, in sorted order as produced by this module.
    pub missing_helpers: Vec<String>,
}
28
/// Diagnostic for a round-trip law whose parser helpers lack analogous laws.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ContextualHelperLawHint {
    /// Source line of the originating `verify` block.
    pub line: usize,
    /// Function the `verify` block is attached to.
    pub fn_name: String,
    /// Name of the law inside that block.
    pub law_name: String,
    /// Contextual helpers lacking law coverage, in sorted order as produced by this module.
    pub missing_helpers: Vec<String>,
}
36
37pub fn named_law_function(law: &VerifyLaw, fn_sigs: &FnSigMap) -> Option<NamedLawFunction> {
38    let (_, _, effects) = fn_sigs.get(&law.name)?;
39    Some(NamedLawFunction {
40        name: law.name.clone(),
41        is_pure: effects.is_empty(),
42    })
43}
44
45pub fn declared_spec_ref(law: &VerifyLaw, fn_sigs: &FnSigMap) -> Option<VerifyLawSpecRef> {
46    let named = named_law_function(law, fn_sigs)?;
47    named.is_pure.then_some(VerifyLawSpecRef {
48        spec_fn_name: named.name,
49    })
50}
51
52pub fn law_spec_ref(law: &VerifyLaw, fn_sigs: &FnSigMap) -> Option<VerifyLawSpecRef> {
53    let spec = declared_spec_ref(law, fn_sigs)?;
54    law_calls_function(law, &spec.spec_fn_name).then_some(spec)
55}
56
57pub fn canonical_spec_ref(
58    fn_name: &str,
59    law: &VerifyLaw,
60    fn_sigs: &FnSigMap,
61) -> Option<VerifyLawSpecRef> {
62    let spec = law_spec_ref(law, fn_sigs)?;
63    canonical_spec_shape(fn_name, law, &spec.spec_fn_name).then_some(spec)
64}
65
66pub fn law_calls_function(law: &VerifyLaw, fn_name: &str) -> bool {
67    expr_calls_function(&law.lhs, fn_name) || expr_calls_function(&law.rhs, fn_name)
68}
69
70pub fn canonical_spec_shape(fn_name: &str, law: &VerifyLaw, spec_fn_name: &str) -> bool {
71    let try_side = |impl_side: &Expr, spec_side: &Expr| -> bool {
72        let Some((impl_callee, impl_args)) = direct_call(impl_side) else {
73            return false;
74        };
75        let Some((spec_callee, spec_args)) = direct_call(spec_side) else {
76            return false;
77        };
78        impl_callee == fn_name && spec_callee == spec_fn_name && impl_args == spec_args
79    };
80
81    try_side(&law.lhs, &law.rhs) || try_side(&law.rhs, &law.lhs)
82}
83
84pub fn collect_missing_helper_law_hints(
85    items: &[TopLevel],
86    fn_sigs: &FnSigMap,
87) -> Vec<MissingHelperLawHint> {
88    let fn_defs = items
89        .iter()
90        .filter_map(|item| {
91            if let TopLevel::FnDef(fd) = item {
92                Some((fd.name.clone(), fd))
93            } else {
94                None
95            }
96        })
97        .collect::<HashMap<_, _>>();
98    let verified_law_functions = items
99        .iter()
100        .filter_map(|item| {
101            let TopLevel::Verify(vb) = item else {
102                return None;
103            };
104            let VerifyKind::Law(law) = &vb.kind else {
105                return None;
106            };
107            let mut covered = BTreeSet::new();
108            covered.insert(vb.fn_name.clone());
109            collect_direct_pure_user_calls(&law.lhs, &fn_defs, fn_sigs, &mut covered);
110            collect_direct_pure_user_calls(&law.rhs, &fn_defs, fn_sigs, &mut covered);
111            Some(covered)
112        })
113        .flatten()
114        .collect::<HashSet<_>>();
115
116    items
117        .iter()
118        .filter_map(|item| {
119            let TopLevel::Verify(vb) = item else {
120                return None;
121            };
122            let VerifyKind::Law(law) = &vb.kind else {
123                return None;
124            };
125            missing_helper_law_hint_for_block(vb, law, &fn_defs, &verified_law_functions, fn_sigs)
126        })
127        .collect()
128}
129
130pub fn missing_helper_law_message(hint: &MissingHelperLawHint) -> String {
131    format!(
132        "verify law '{}.{}' uses helper functions without their own verify law: {}; add layered `verify ... law ...` blocks for those helpers before expecting a universal auto-proof",
133        hint.fn_name,
134        hint.law_name,
135        hint.missing_helpers.join(", ")
136    )
137}
138
139pub fn collect_contextual_helper_law_hints(
140    items: &[TopLevel],
141    fn_sigs: &FnSigMap,
142) -> Vec<ContextualHelperLawHint> {
143    let fn_defs = items
144        .iter()
145        .filter_map(|item| {
146            if let TopLevel::FnDef(fd) = item {
147                Some((fd.name.clone(), fd))
148            } else {
149                None
150            }
151        })
152        .collect::<HashMap<_, _>>();
153    let contextual_law_targets = items
154        .iter()
155        .filter_map(|item| {
156            let TopLevel::Verify(vb) = item else {
157                return None;
158            };
159            let VerifyKind::Law(law) = &vb.kind else {
160                return None;
161            };
162            top_level_direct_pure_call_in_law(law, &fn_defs, fn_sigs)
163        })
164        .collect::<HashSet<_>>();
165
166    items
167        .iter()
168        .filter_map(|item| {
169            let TopLevel::Verify(vb) = item else {
170                return None;
171            };
172            let VerifyKind::Law(law) = &vb.kind else {
173                return None;
174            };
175            contextual_helper_law_hint_for_block(
176                vb,
177                law,
178                &fn_defs,
179                &contextual_law_targets,
180                fn_sigs,
181            )
182        })
183        .collect()
184}
185
186pub fn contextual_helper_law_message(hint: &ContextualHelperLawHint) -> String {
187    format!(
188        "verify law '{}.{}' still lacks analogous `verify ... law ...` coverage for contextual helpers: {}; universal auto-proof will likely stop at those helper boundaries",
189        hint.fn_name,
190        hint.law_name,
191        hint.missing_helpers.join(", ")
192    )
193}
194
195fn missing_helper_law_hint_for_block(
196    vb: &VerifyBlock,
197    law: &VerifyLaw,
198    fn_defs: &HashMap<String, &FnDef>,
199    verified_law_functions: &HashSet<String>,
200    fn_sigs: &FnSigMap,
201) -> Option<MissingHelperLawHint> {
202    if law.when.is_none() || law.givens.len() != 1 {
203        return None;
204    }
205
206    let root_calls = direct_pure_user_calls_in_law(law, fn_defs, fn_sigs);
207    if root_calls.is_empty() {
208        return None;
209    }
210
211    let mut missing_helpers = BTreeSet::new();
212    for root in root_calls {
213        for helper in frontier_helper_calls(&root, fn_defs, fn_sigs) {
214            if helper != vb.fn_name && !verified_law_functions.contains(&helper) {
215                missing_helpers.insert(helper);
216            }
217        }
218    }
219
220    if missing_helpers.is_empty() {
221        return None;
222    }
223
224    Some(MissingHelperLawHint {
225        line: vb.line,
226        fn_name: vb.fn_name.clone(),
227        law_name: law.name.clone(),
228        missing_helpers: missing_helpers.into_iter().collect(),
229    })
230}
231
232fn contextual_helper_law_hint_for_block(
233    vb: &VerifyBlock,
234    law: &VerifyLaw,
235    fn_defs: &HashMap<String, &FnDef>,
236    contextual_law_targets: &HashSet<String>,
237    fn_sigs: &FnSigMap,
238) -> Option<ContextualHelperLawHint> {
239    let parser_name = contextual_roundtrip_parser_name(law, fn_defs, fn_sigs)?;
240    let root_parser_name = wrapper_dispatch_root(&parser_name, fn_defs, fn_sigs)
241        .unwrap_or_else(|| parser_name.clone());
242    if root_parser_name != parser_name {
243        return None;
244    }
245
246    let missing_helpers = frontier_helper_calls(&root_parser_name, fn_defs, fn_sigs)
247        .into_iter()
248        .filter(|helper| helper != &vb.fn_name && !contextual_law_targets.contains(helper))
249        .collect::<BTreeSet<_>>();
250
251    if missing_helpers.is_empty() {
252        return None;
253    }
254
255    Some(ContextualHelperLawHint {
256        line: vb.line,
257        fn_name: vb.fn_name.clone(),
258        law_name: law.name.clone(),
259        missing_helpers: missing_helpers.into_iter().collect(),
260    })
261}
262
263fn direct_pure_user_calls_in_law(
264    law: &VerifyLaw,
265    fn_defs: &HashMap<String, &FnDef>,
266    fn_sigs: &FnSigMap,
267) -> BTreeSet<String> {
268    let mut out = BTreeSet::new();
269    collect_direct_pure_user_calls(&law.lhs, fn_defs, fn_sigs, &mut out);
270    collect_direct_pure_user_calls(&law.rhs, fn_defs, fn_sigs, &mut out);
271    out
272}
273
274fn top_level_direct_pure_call_in_law(
275    law: &VerifyLaw,
276    fn_defs: &HashMap<String, &FnDef>,
277    fn_sigs: &FnSigMap,
278) -> Option<String> {
279    direct_pure_user_call_name(&law.lhs, fn_defs, fn_sigs)
280        .or_else(|| direct_pure_user_call_name(&law.rhs, fn_defs, fn_sigs))
281}
282
283fn contextual_roundtrip_parser_name(
284    law: &VerifyLaw,
285    fn_defs: &HashMap<String, &FnDef>,
286    fn_sigs: &FnSigMap,
287) -> Option<String> {
288    let given = law.givens.first()?;
289    detect_roundtrip_layers(law, &given.name, fn_defs, fn_sigs).map(|(parser_name, _)| parser_name)
290}
291
292fn frontier_helper_calls(
293    root_name: &str,
294    fn_defs: &HashMap<String, &FnDef>,
295    fn_sigs: &FnSigMap,
296) -> BTreeSet<String> {
297    let mut current =
298        wrapper_dispatch_root(root_name, fn_defs, fn_sigs).unwrap_or_else(|| root_name.to_string());
299    let mut visited = BTreeSet::new();
300
301    for _ in 0..2 {
302        if !visited.insert(current.clone()) {
303            break;
304        }
305        let direct = direct_pure_fn_callees_matching_return(&current, fn_defs, fn_sigs);
306        if direct.is_empty() {
307            return BTreeSet::new();
308        }
309        if direct.len() == 1 {
310            current = direct.iter().next().cloned().unwrap_or_default();
311            continue;
312        }
313        return direct;
314    }
315
316    direct_pure_fn_callees_matching_return(&current, fn_defs, fn_sigs)
317}
318
319fn wrapper_dispatch_root(
320    fn_name: &str,
321    fn_defs: &HashMap<String, &FnDef>,
322    fn_sigs: &FnSigMap,
323) -> Option<String> {
324    let fd = fn_defs.get(fn_name)?;
325    let tail = fd.body.tail_expr()?;
326    match tail {
327        Expr::Match { subject, .. } => direct_pure_user_call_name(subject, fn_defs, fn_sigs),
328        Expr::FnCall(_, _) => direct_pure_user_call_name(tail, fn_defs, fn_sigs),
329        _ => None,
330    }
331}
332
333fn direct_pure_fn_callees_matching_return(
334    fn_name: &str,
335    fn_defs: &HashMap<String, &FnDef>,
336    fn_sigs: &FnSigMap,
337) -> BTreeSet<String> {
338    let Some((_, return_type, _)) = fn_sigs.get(fn_name) else {
339        return BTreeSet::new();
340    };
341    let Some(fd) = fn_defs.get(fn_name) else {
342        return BTreeSet::new();
343    };
344
345    let mut direct = BTreeSet::new();
346    for stmt in fd.body.stmts() {
347        match stmt {
348            Stmt::Expr(expr) | Stmt::Binding(_, _, expr) => {
349                collect_direct_pure_user_calls(expr, fn_defs, fn_sigs, &mut direct);
350            }
351        }
352    }
353    direct
354        .into_iter()
355        .filter(|callee| {
356            callee != fn_name
357                && fn_sigs.get(callee).is_some_and(|(_, callee_ret, effects)| {
358                    effects.is_empty() && callee_ret == return_type
359                })
360        })
361        .collect()
362}
363
364fn collect_direct_pure_user_calls(
365    expr: &Expr,
366    fn_defs: &HashMap<String, &FnDef>,
367    fn_sigs: &FnSigMap,
368    out: &mut BTreeSet<String>,
369) {
370    match expr {
371        Expr::FnCall(callee, args) => {
372            if let Some(name) = direct_pure_user_call_name(expr, fn_defs, fn_sigs) {
373                out.insert(name);
374            }
375            collect_direct_pure_user_calls(callee, fn_defs, fn_sigs, out);
376            for arg in args {
377                collect_direct_pure_user_calls(arg, fn_defs, fn_sigs, out);
378            }
379        }
380        Expr::Attr(obj, _) => collect_direct_pure_user_calls(obj, fn_defs, fn_sigs, out),
381        Expr::BinOp(_, left, right) => {
382            collect_direct_pure_user_calls(left, fn_defs, fn_sigs, out);
383            collect_direct_pure_user_calls(right, fn_defs, fn_sigs, out);
384        }
385        Expr::Match { subject, arms, .. } => {
386            collect_direct_pure_user_calls(subject, fn_defs, fn_sigs, out);
387            for arm in arms {
388                collect_direct_pure_user_calls(&arm.body, fn_defs, fn_sigs, out);
389            }
390        }
391        Expr::Constructor(_, Some(inner)) | Expr::ErrorProp(inner) => {
392            collect_direct_pure_user_calls(inner, fn_defs, fn_sigs, out);
393        }
394        Expr::InterpolatedStr(parts) => {
395            for part in parts {
396                if let StrPart::Parsed(inner) = part {
397                    collect_direct_pure_user_calls(inner, fn_defs, fn_sigs, out);
398                }
399            }
400        }
401        Expr::List(items) | Expr::Tuple(items) => {
402            for item in items {
403                collect_direct_pure_user_calls(item, fn_defs, fn_sigs, out);
404            }
405        }
406        Expr::MapLiteral(entries) => {
407            for (key, value) in entries {
408                collect_direct_pure_user_calls(key, fn_defs, fn_sigs, out);
409                collect_direct_pure_user_calls(value, fn_defs, fn_sigs, out);
410            }
411        }
412        Expr::RecordCreate { fields, .. } => {
413            for (_, value) in fields {
414                collect_direct_pure_user_calls(value, fn_defs, fn_sigs, out);
415            }
416        }
417        Expr::RecordUpdate { base, updates, .. } => {
418            collect_direct_pure_user_calls(base, fn_defs, fn_sigs, out);
419            for (_, value) in updates {
420                collect_direct_pure_user_calls(value, fn_defs, fn_sigs, out);
421            }
422        }
423        Expr::TailCall(boxed) => {
424            let (target, args) = boxed.as_ref();
425            if fn_defs.contains_key(target)
426                && fn_sigs
427                    .get(target)
428                    .is_some_and(|(_, _, effects)| effects.is_empty())
429            {
430                out.insert(target.clone());
431            }
432            for arg in args {
433                collect_direct_pure_user_calls(arg, fn_defs, fn_sigs, out);
434            }
435        }
436        Expr::Literal(_) | Expr::Ident(_) | Expr::Resolved(_) | Expr::Constructor(_, None) => {}
437    }
438}
439
440fn direct_pure_user_call_name(
441    expr: &Expr,
442    fn_defs: &HashMap<String, &FnDef>,
443    fn_sigs: &FnSigMap,
444) -> Option<String> {
445    let Expr::FnCall(callee, _) = expr else {
446        return None;
447    };
448    let name = dotted_name(callee)?;
449    if !fn_defs.contains_key(&name) {
450        return None;
451    }
452    fn_sigs
453        .get(&name)
454        .is_some_and(|(_, _, effects)| effects.is_empty())
455        .then_some(name)
456}
457
458fn dotted_name(expr: &Expr) -> Option<String> {
459    match expr {
460        Expr::Ident(name) => Some(name.clone()),
461        Expr::Attr(base, field) => {
462            let mut prefix = dotted_name(base)?;
463            prefix.push('.');
464            prefix.push_str(field);
465            Some(prefix)
466        }
467        _ => None,
468    }
469}
470
471fn detect_roundtrip_layers(
472    law: &VerifyLaw,
473    given_name: &str,
474    fn_defs: &HashMap<String, &FnDef>,
475    fn_sigs: &FnSigMap,
476) -> Option<(String, String)> {
477    if law.givens.len() != 1 {
478        return None;
479    }
480
481    fn detect_roundtrip_side(
482        expr: &Expr,
483        given_name: &str,
484        fn_defs: &HashMap<String, &FnDef>,
485        fn_sigs: &FnSigMap,
486    ) -> Option<(String, String)> {
487        let Expr::FnCall(parser_callee, parser_args) = expr else {
488            return None;
489        };
490        if parser_args.is_empty() {
491            return None;
492        }
493        let (serializer_callee, serializer_args) =
494            extract_roundtrip_serializer_call(&parser_args[0], given_name)?;
495        if !serializer_args
496            .iter()
497            .any(|arg| matches_ident(arg, given_name))
498        {
499            return None;
500        }
501        if serializer_args
502            .iter()
503            .filter(|arg| expr_mentions_ident(arg, given_name))
504            .any(|arg| !matches_ident(arg, given_name))
505        {
506            return None;
507        }
508        if parser_args[1..]
509            .iter()
510            .any(|arg| expr_mentions_ident(arg, given_name))
511        {
512            return None;
513        }
514
515        let parser_name = dotted_name(parser_callee)?;
516        let serializer_name = dotted_name(serializer_callee)?;
517        if !fn_defs.contains_key(&parser_name) || !fn_defs.contains_key(&serializer_name) {
518            return None;
519        }
520        if !fn_sigs
521            .get(&parser_name)
522            .is_some_and(|(_, _, effects)| effects.is_empty())
523        {
524            return None;
525        }
526        if !fn_sigs
527            .get(&serializer_name)
528            .is_some_and(|(_, _, effects)| effects.is_empty())
529        {
530            return None;
531        }
532        Some((parser_name, serializer_name))
533    }
534
535    detect_roundtrip_side(&law.lhs, given_name, fn_defs, fn_sigs)
536        .or_else(|| detect_roundtrip_side(&law.rhs, given_name, fn_defs, fn_sigs))
537}
538
539fn extract_roundtrip_serializer_call<'a>(
540    expr: &'a Expr,
541    given_name: &str,
542) -> Option<(&'a Expr, &'a [Expr])> {
543    let mut candidates = Vec::new();
544    collect_roundtrip_serializer_calls(expr, given_name, &mut candidates);
545    if candidates.len() != 1 {
546        return None;
547    }
548    let (callee, args) = candidates.pop()?;
549    if expr_mentions_ident(expr, given_name)
550        && args
551            .iter()
552            .filter(|arg| expr_mentions_ident(arg, given_name))
553            .all(|arg| matches_ident(arg, given_name))
554    {
555        Some((callee, args))
556    } else {
557        None
558    }
559}
560
561fn collect_roundtrip_serializer_calls<'a>(
562    expr: &'a Expr,
563    given_name: &str,
564    out: &mut Vec<(&'a Expr, &'a [Expr])>,
565) {
566    match expr {
567        Expr::FnCall(callee, args) => {
568            if args.iter().any(|arg| matches_ident(arg, given_name))
569                && args
570                    .iter()
571                    .filter(|arg| expr_mentions_ident(arg, given_name))
572                    .all(|arg| matches_ident(arg, given_name))
573            {
574                out.push((callee.as_ref(), args.as_slice()));
575            }
576            collect_roundtrip_serializer_calls(callee, given_name, out);
577            for arg in args {
578                collect_roundtrip_serializer_calls(arg, given_name, out);
579            }
580        }
581        Expr::Attr(base, _) => collect_roundtrip_serializer_calls(base, given_name, out),
582        Expr::BinOp(_, left, right) => {
583            collect_roundtrip_serializer_calls(left, given_name, out);
584            collect_roundtrip_serializer_calls(right, given_name, out);
585        }
586        Expr::Match { subject, arms, .. } => {
587            collect_roundtrip_serializer_calls(subject, given_name, out);
588            for arm in arms {
589                collect_roundtrip_serializer_calls(&arm.body, given_name, out);
590            }
591        }
592        Expr::Constructor(_, inner) => {
593            if let Some(inner) = inner {
594                collect_roundtrip_serializer_calls(inner, given_name, out);
595            }
596        }
597        Expr::ErrorProp(inner) => collect_roundtrip_serializer_calls(inner, given_name, out),
598        Expr::InterpolatedStr(parts) => {
599            for part in parts {
600                if let StrPart::Parsed(inner) = part {
601                    collect_roundtrip_serializer_calls(inner, given_name, out);
602                }
603            }
604        }
605        Expr::List(items) | Expr::Tuple(items) => {
606            for item in items {
607                collect_roundtrip_serializer_calls(item, given_name, out);
608            }
609        }
610        Expr::MapLiteral(entries) => {
611            for (key, value) in entries {
612                collect_roundtrip_serializer_calls(key, given_name, out);
613                collect_roundtrip_serializer_calls(value, given_name, out);
614            }
615        }
616        Expr::RecordCreate { fields, .. } => {
617            for (_, value) in fields {
618                collect_roundtrip_serializer_calls(value, given_name, out);
619            }
620        }
621        Expr::RecordUpdate { base, updates, .. } => {
622            collect_roundtrip_serializer_calls(base, given_name, out);
623            for (_, value) in updates {
624                collect_roundtrip_serializer_calls(value, given_name, out);
625            }
626        }
627        Expr::TailCall(call) => {
628            for arg in &call.1 {
629                collect_roundtrip_serializer_calls(arg, given_name, out);
630            }
631        }
632        Expr::Literal(_) | Expr::Ident(_) | Expr::Resolved(_) => {}
633    }
634}
635
636fn matches_ident(expr: &Expr, name: &str) -> bool {
637    matches!(expr, Expr::Ident(current) if current == name)
638}
639
640fn expr_mentions_ident(expr: &Expr, name: &str) -> bool {
641    match expr {
642        Expr::Ident(current) => current == name,
643        Expr::Attr(base, _) => expr_mentions_ident(base, name),
644        Expr::FnCall(callee, args) => {
645            expr_mentions_ident(callee, name)
646                || args.iter().any(|arg| expr_mentions_ident(arg, name))
647        }
648        Expr::BinOp(_, left, right) => {
649            expr_mentions_ident(left, name) || expr_mentions_ident(right, name)
650        }
651        Expr::Match { subject, arms, .. } => {
652            expr_mentions_ident(subject, name)
653                || arms.iter().any(|arm| expr_mentions_ident(&arm.body, name))
654        }
655        Expr::Constructor(_, inner) => inner
656            .as_deref()
657            .is_some_and(|inner| expr_mentions_ident(inner, name)),
658        Expr::ErrorProp(inner) => expr_mentions_ident(inner, name),
659        Expr::InterpolatedStr(parts) => parts.iter().any(|part| match part {
660            StrPart::Literal(_) => false,
661            StrPart::Parsed(inner) => expr_mentions_ident(inner, name),
662        }),
663        Expr::List(items) | Expr::Tuple(items) => {
664            items.iter().any(|item| expr_mentions_ident(item, name))
665        }
666        Expr::MapLiteral(entries) => entries
667            .iter()
668            .any(|(key, value)| expr_mentions_ident(key, name) || expr_mentions_ident(value, name)),
669        Expr::RecordCreate { fields, .. } => fields
670            .iter()
671            .any(|(_, value)| expr_mentions_ident(value, name)),
672        Expr::RecordUpdate { base, updates, .. } => {
673            expr_mentions_ident(base, name)
674                || updates
675                    .iter()
676                    .any(|(_, value)| expr_mentions_ident(value, name))
677        }
678        Expr::TailCall(call) => call.1.iter().any(|arg| expr_mentions_ident(arg, name)),
679        Expr::Literal(_) | Expr::Resolved(_) => false,
680    }
681}
682
683fn expr_calls_function(expr: &Expr, fn_name: &str) -> bool {
684    match expr {
685        Expr::FnCall(callee, args) => {
686            expr_is_function_name(callee, fn_name)
687                || expr_calls_function(callee, fn_name)
688                || args.iter().any(|arg| expr_calls_function(arg, fn_name))
689        }
690        Expr::Attr(obj, _) => expr_calls_function(obj, fn_name),
691        Expr::BinOp(_, left, right) => {
692            expr_calls_function(left, fn_name) || expr_calls_function(right, fn_name)
693        }
694        Expr::Match { subject, arms, .. } => {
695            expr_calls_function(subject, fn_name)
696                || arms
697                    .iter()
698                    .any(|arm| match_arm_calls_function(arm, fn_name))
699        }
700        Expr::Constructor(_, Some(inner)) => expr_calls_function(inner, fn_name),
701        Expr::ErrorProp(inner) => expr_calls_function(inner, fn_name),
702        Expr::InterpolatedStr(parts) => parts.iter().any(|part| match part {
703            StrPart::Literal(_) => false,
704            StrPart::Parsed(expr) => expr_calls_function(expr, fn_name),
705        }),
706        Expr::List(items) | Expr::Tuple(items) => {
707            items.iter().any(|item| expr_calls_function(item, fn_name))
708        }
709        Expr::MapLiteral(entries) => entries.iter().any(|(key, value)| {
710            expr_calls_function(key, fn_name) || expr_calls_function(value, fn_name)
711        }),
712        Expr::RecordCreate { fields, .. } => fields
713            .iter()
714            .any(|(_, expr)| expr_calls_function(expr, fn_name)),
715        Expr::RecordUpdate { base, updates, .. } => {
716            expr_calls_function(base, fn_name)
717                || updates
718                    .iter()
719                    .any(|(_, expr)| expr_calls_function(expr, fn_name))
720        }
721        Expr::TailCall(boxed) => {
722            boxed.0 == fn_name || boxed.1.iter().any(|arg| expr_calls_function(arg, fn_name))
723        }
724        Expr::Literal(_) | Expr::Ident(_) | Expr::Resolved(_) | Expr::Constructor(_, None) => false,
725    }
726}
727
728fn match_arm_calls_function(arm: &MatchArm, fn_name: &str) -> bool {
729    expr_calls_function(&arm.body, fn_name)
730}
731
732fn expr_is_function_name(expr: &Expr, fn_name: &str) -> bool {
733    matches!(expr, Expr::Ident(name) if name == fn_name)
734}
735
736fn direct_call(expr: &Expr) -> Option<(&str, &[Expr])> {
737    let Expr::FnCall(callee, args) = expr else {
738        return None;
739    };
740    let Expr::Ident(name) = callee.as_ref() else {
741        return None;
742    };
743    Some((name.as_str(), args.as_slice()))
744}
745
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ast::{Literal, VerifyGiven, VerifyGivenDomain};

    /// Signature of a pure `Int -> Int` function.
    fn int_sig() -> (Vec<Type>, Type, Vec<String>) {
        (vec![Type::Int], Type::Int, vec![])
    }

    /// Plain identifier expression.
    fn ident(name: &str) -> Expr {
        Expr::Ident(name.to_string())
    }

    /// Direct call `callee(args..)` with an identifier callee.
    fn call(callee: &str, args: Vec<Expr>) -> Expr {
        Expr::FnCall(Box::new(ident(callee)), args)
    }

    /// Signature map holding a single entry.
    fn sigs_with(name: &str, sig: (Vec<Type>, Type, Vec<String>)) -> FnSigMap {
        let mut fn_sigs = FnSigMap::new();
        fn_sigs.insert(name.to_string(), sig);
        fn_sigs
    }

    /// Law with a single `x: Int` given (explicit domain `[1]`) and no `when` guard.
    fn law(lhs: Expr, rhs: Expr, name: &str) -> VerifyLaw {
        VerifyLaw {
            name: name.to_string(),
            givens: vec![VerifyGiven {
                name: "x".to_string(),
                type_name: "Int".to_string(),
                domain: VerifyGivenDomain::Explicit(vec![Expr::Literal(Literal::Int(1))]),
            }],
            when: None,
            lhs,
            rhs,
            sample_guards: vec![],
        }
    }

    #[test]
    fn pure_named_law_function_becomes_declared_spec_ref() {
        let fn_sigs = sigs_with("fooSpec", int_sig());
        let verify_law = law(
            call("foo", vec![ident("x")]),
            call("fooSpec", vec![ident("x")]),
            "fooSpec",
        );

        let declared = declared_spec_ref(&verify_law, &fn_sigs);
        assert_eq!(
            declared,
            Some(VerifyLawSpecRef {
                spec_fn_name: "fooSpec".to_string()
            })
        );
        assert_eq!(law_spec_ref(&verify_law, &fn_sigs), declared);
        assert_eq!(canonical_spec_ref("foo", &verify_law, &fn_sigs), declared);
    }

    #[test]
    fn effectful_named_law_function_is_not_a_spec_ref() {
        let fn_sigs = sigs_with(
            "fooSpec",
            (
                vec![Type::Int],
                Type::Int,
                vec!["Console.print".to_string()],
            ),
        );
        let verify_law = law(ident("x"), ident("x"), "fooSpec");

        assert!(declared_spec_ref(&verify_law, &fn_sigs).is_none());
        assert_eq!(
            named_law_function(&verify_law, &fn_sigs),
            Some(NamedLawFunction {
                name: "fooSpec".to_string(),
                is_pure: false
            })
        );
    }

    #[test]
    fn canonical_spec_ref_requires_call_to_named_function() {
        let fn_sigs = sigs_with("fooSpec", int_sig());
        let verify_law = law(ident("x"), ident("x"), "fooSpec");

        assert!(declared_spec_ref(&verify_law, &fn_sigs).is_some());
        assert!(law_spec_ref(&verify_law, &fn_sigs).is_none());
        assert!(!law_calls_function(&verify_law, "fooSpec"));
    }

    #[test]
    fn canonical_spec_ref_requires_same_arguments_on_both_sides() {
        let fn_sigs = sigs_with("fooSpec", int_sig());
        let verify_law = law(
            call("foo", vec![ident("x")]),
            call(
                "fooSpec",
                vec![Expr::Literal(Literal::Int(5)), ident("x")],
            ),
            "fooSpec",
        );

        assert!(law_spec_ref(&verify_law, &fn_sigs).is_some());
        assert!(canonical_spec_ref("foo", &verify_law, &fn_sigs).is_none());
        assert!(!canonical_spec_shape("foo", &verify_law, "fooSpec"));
    }
}