1use std::collections::HashMap;
2
3use crate::ast::{Expr, MatchArm, StrPart, VerifyLaw};
4use crate::types::Type;
5
/// Function signatures keyed by name.
///
/// Each value is a `(Vec<Type>, Type, Vec<String>)` tuple. This module only
/// inspects the third component, which it treats as the function's effect list:
/// an empty list marks the function as pure. The first two components are
/// presumably the parameter types and the return type — confirm against the
/// code that populates the map.
pub type FnSigMap = HashMap<String, (Vec<Type>, Type, Vec<String>)>;
7
/// A function found in the signature map under the same name as a verify law.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct NamedLawFunction {
    /// Name shared by the law and the function.
    pub name: String,
    /// `true` when the function's signature lists no effects.
    pub is_pure: bool,
}
13
/// Reference to the pure function, named by a verify law, that may serve as the
/// law's spec function.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct VerifyLawSpecRef {
    /// Name of the spec function (identical to the law's name).
    pub spec_fn_name: String,
}
18
19pub fn named_law_function(law: &VerifyLaw, fn_sigs: &FnSigMap) -> Option<NamedLawFunction> {
20 let (_, _, effects) = fn_sigs.get(&law.name)?;
21 Some(NamedLawFunction {
22 name: law.name.clone(),
23 is_pure: effects.is_empty(),
24 })
25}
26
27pub fn declared_spec_ref(law: &VerifyLaw, fn_sigs: &FnSigMap) -> Option<VerifyLawSpecRef> {
28 let named = named_law_function(law, fn_sigs)?;
29 named.is_pure.then_some(VerifyLawSpecRef {
30 spec_fn_name: named.name,
31 })
32}
33
34pub fn law_spec_ref(law: &VerifyLaw, fn_sigs: &FnSigMap) -> Option<VerifyLawSpecRef> {
35 let spec = declared_spec_ref(law, fn_sigs)?;
36 law_calls_function(law, &spec.spec_fn_name).then_some(spec)
37}
38
39pub fn canonical_spec_ref(
40 fn_name: &str,
41 law: &VerifyLaw,
42 fn_sigs: &FnSigMap,
43) -> Option<VerifyLawSpecRef> {
44 let spec = law_spec_ref(law, fn_sigs)?;
45 canonical_spec_shape(fn_name, law, &spec.spec_fn_name).then_some(spec)
46}
47
48pub fn law_calls_function(law: &VerifyLaw, fn_name: &str) -> bool {
49 expr_calls_function(&law.lhs, fn_name) || expr_calls_function(&law.rhs, fn_name)
50}
51
52pub fn canonical_spec_shape(fn_name: &str, law: &VerifyLaw, spec_fn_name: &str) -> bool {
53 let try_side = |impl_side: &Expr, spec_side: &Expr| -> bool {
54 let Some((impl_callee, impl_args)) = direct_call(impl_side) else {
55 return false;
56 };
57 let Some((spec_callee, spec_args)) = direct_call(spec_side) else {
58 return false;
59 };
60 impl_callee == fn_name && spec_callee == spec_fn_name && impl_args == spec_args
61 };
62
63 try_side(&law.lhs, &law.rhs) || try_side(&law.rhs, &law.lhs)
64}
65
66fn expr_calls_function(expr: &Expr, fn_name: &str) -> bool {
67 match expr {
68 Expr::FnCall(callee, args) => {
69 expr_is_function_name(callee, fn_name)
70 || expr_calls_function(callee, fn_name)
71 || args.iter().any(|arg| expr_calls_function(arg, fn_name))
72 }
73 Expr::Attr(obj, _) => expr_calls_function(obj, fn_name),
74 Expr::BinOp(_, left, right) => {
75 expr_calls_function(left, fn_name) || expr_calls_function(right, fn_name)
76 }
77 Expr::Match { subject, arms, .. } => {
78 expr_calls_function(subject, fn_name)
79 || arms
80 .iter()
81 .any(|arm| match_arm_calls_function(arm, fn_name))
82 }
83 Expr::Constructor(_, Some(inner)) => expr_calls_function(inner, fn_name),
84 Expr::ErrorProp(inner) => expr_calls_function(inner, fn_name),
85 Expr::InterpolatedStr(parts) => parts.iter().any(|part| match part {
86 StrPart::Literal(_) => false,
87 StrPart::Parsed(expr) => expr_calls_function(expr, fn_name),
88 }),
89 Expr::List(items) | Expr::Tuple(items) => {
90 items.iter().any(|item| expr_calls_function(item, fn_name))
91 }
92 Expr::MapLiteral(entries) => entries.iter().any(|(key, value)| {
93 expr_calls_function(key, fn_name) || expr_calls_function(value, fn_name)
94 }),
95 Expr::RecordCreate { fields, .. } => fields
96 .iter()
97 .any(|(_, expr)| expr_calls_function(expr, fn_name)),
98 Expr::RecordUpdate { base, updates, .. } => {
99 expr_calls_function(base, fn_name)
100 || updates
101 .iter()
102 .any(|(_, expr)| expr_calls_function(expr, fn_name))
103 }
104 Expr::TailCall(boxed) => {
105 boxed.0 == fn_name || boxed.1.iter().any(|arg| expr_calls_function(arg, fn_name))
106 }
107 Expr::Literal(_) | Expr::Ident(_) | Expr::Resolved(_) | Expr::Constructor(_, None) => false,
108 }
109}
110
111fn match_arm_calls_function(arm: &MatchArm, fn_name: &str) -> bool {
112 expr_calls_function(&arm.body, fn_name)
113}
114
115fn expr_is_function_name(expr: &Expr, fn_name: &str) -> bool {
116 matches!(expr, Expr::Ident(name) if name == fn_name)
117}
118
119fn direct_call(expr: &Expr) -> Option<(&str, &[Expr])> {
120 let Expr::FnCall(callee, args) = expr else {
121 return None;
122 };
123 let Expr::Ident(name) = callee.as_ref() else {
124 return None;
125 };
126 Some((name.as_str(), args.as_slice()))
127}
128
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ast::{Literal, VerifyGiven, VerifyGivenDomain};

    /// Signature of a pure `Int -> Int` function.
    fn int_sig() -> (Vec<Type>, Type, Vec<String>) {
        (vec![Type::Int], Type::Int, vec![])
    }

    /// Signature map containing a single pure `Int -> Int` function called `name`.
    fn pure_sigs(name: &str) -> FnSigMap {
        let mut fn_sigs = FnSigMap::new();
        fn_sigs.insert(name.to_string(), int_sig());
        fn_sigs
    }

    /// Bare identifier expression.
    fn ident(name: &str) -> Expr {
        Expr::Ident(name.to_string())
    }

    /// Direct call `callee(args...)`.
    fn call(callee: &str, args: Vec<Expr>) -> Expr {
        Expr::FnCall(Box::new(ident(callee)), args)
    }

    /// Law called `name` stating `lhs = rhs` for a single given `x: Int`.
    fn law(lhs: Expr, rhs: Expr, name: &str) -> VerifyLaw {
        VerifyLaw {
            name: name.to_string(),
            givens: vec![VerifyGiven {
                name: "x".to_string(),
                type_name: "Int".to_string(),
                domain: VerifyGivenDomain::Explicit(vec![Expr::Literal(Literal::Int(1))]),
            }],
            when: None,
            lhs,
            rhs,
            sample_guards: vec![],
        }
    }

    #[test]
    fn pure_named_law_function_becomes_declared_spec_ref() {
        let fn_sigs = pure_sigs("fooSpec");
        let verify_law = law(
            call("foo", vec![ident("x")]),
            call("fooSpec", vec![ident("x")]),
            "fooSpec",
        );

        assert_eq!(
            declared_spec_ref(&verify_law, &fn_sigs),
            Some(VerifyLawSpecRef {
                spec_fn_name: "fooSpec".to_string()
            })
        );
        assert_eq!(
            law_spec_ref(&verify_law, &fn_sigs),
            declared_spec_ref(&verify_law, &fn_sigs)
        );
        assert_eq!(
            canonical_spec_ref("foo", &verify_law, &fn_sigs),
            declared_spec_ref(&verify_law, &fn_sigs)
        );
    }

    #[test]
    fn effectful_named_law_function_is_not_a_spec_ref() {
        let mut fn_sigs = FnSigMap::new();
        fn_sigs.insert(
            "fooSpec".to_string(),
            (
                vec![Type::Int],
                Type::Int,
                vec!["Console.print".to_string()],
            ),
        );

        let verify_law = law(ident("x"), ident("x"), "fooSpec");

        assert!(declared_spec_ref(&verify_law, &fn_sigs).is_none());
        assert_eq!(
            named_law_function(&verify_law, &fn_sigs),
            Some(NamedLawFunction {
                name: "fooSpec".to_string(),
                is_pure: false
            })
        );
    }

    #[test]
    fn unknown_law_name_has_no_named_function() {
        let fn_sigs = FnSigMap::new();
        let verify_law = law(ident("x"), ident("x"), "fooSpec");

        assert!(named_law_function(&verify_law, &fn_sigs).is_none());
        assert!(declared_spec_ref(&verify_law, &fn_sigs).is_none());
        assert!(law_spec_ref(&verify_law, &fn_sigs).is_none());
        assert!(canonical_spec_ref("foo", &verify_law, &fn_sigs).is_none());
    }

    #[test]
    fn canonical_spec_ref_requires_call_to_named_function() {
        let fn_sigs = pure_sigs("fooSpec");
        let verify_law = law(ident("x"), ident("x"), "fooSpec");

        assert!(declared_spec_ref(&verify_law, &fn_sigs).is_some());
        assert!(law_spec_ref(&verify_law, &fn_sigs).is_none());
        // The behaviour this test is named after: no call, no canonical ref.
        assert!(canonical_spec_ref("foo", &verify_law, &fn_sigs).is_none());
        assert!(!law_calls_function(&verify_law, "fooSpec"));
    }

    #[test]
    fn canonical_spec_ref_requires_same_arguments_on_both_sides() {
        let fn_sigs = pure_sigs("fooSpec");
        let verify_law = law(
            call("foo", vec![ident("x")]),
            call(
                "fooSpec",
                vec![Expr::Literal(Literal::Int(5)), ident("x")],
            ),
            "fooSpec",
        );

        assert!(law_spec_ref(&verify_law, &fn_sigs).is_some());
        assert!(canonical_spec_ref("foo", &verify_law, &fn_sigs).is_none());
        assert!(!canonical_spec_shape("foo", &verify_law, "fooSpec"));
    }

    #[test]
    fn canonical_spec_ref_accepts_spec_call_on_either_side() {
        let fn_sigs = pure_sigs("fooSpec");
        // Mirrored orientation: the spec call is on the lhs, the implementation on the rhs.
        let verify_law = law(
            call("fooSpec", vec![ident("x")]),
            call("foo", vec![ident("x")]),
            "fooSpec",
        );

        assert!(canonical_spec_shape("foo", &verify_law, "fooSpec"));
        assert_eq!(
            canonical_spec_ref("foo", &verify_law, &fn_sigs),
            Some(VerifyLawSpecRef {
                spec_fn_name: "fooSpec".to_string()
            })
        );
    }

    #[test]
    fn nested_spec_call_is_detected_but_not_canonical() {
        let fn_sigs = pure_sigs("fooSpec");
        // `foo(fooSpec(x)) = x`: the spec is called, but only as an argument.
        let verify_law = law(
            call("foo", vec![call("fooSpec", vec![ident("x")])]),
            ident("x"),
            "fooSpec",
        );

        assert!(law_calls_function(&verify_law, "fooSpec"));
        assert!(law_spec_ref(&verify_law, &fn_sigs).is_some());
        assert!(canonical_spec_ref("foo", &verify_law, &fn_sigs).is_none());
    }
}