1use std::collections::HashMap;
2
3use crate::ast::{Expr, MatchArm, StrPart, VerifyLaw};
4use crate::types::Type;
5
6pub type FnSigMap = HashMap<String, (Vec<Type>, Type, Vec<String>)>;
7
/// A verify law's name resolved against a known function signature.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct NamedLawFunction {
    /// The function name (identical to the law's name).
    pub name: String,
    /// `true` when the function's signature declares no effects.
    pub is_pure: bool,
}
13
/// Reference to the pure function that a verify law names as its specification.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct VerifyLawSpecRef {
    /// Name of the specification function.
    pub spec_fn_name: String,
}
18
19pub fn named_law_function(law: &VerifyLaw, fn_sigs: &FnSigMap) -> Option<NamedLawFunction> {
20 let (_, _, effects) = fn_sigs.get(&law.name)?;
21 Some(NamedLawFunction {
22 name: law.name.clone(),
23 is_pure: effects.is_empty(),
24 })
25}
26
27pub fn declared_spec_ref(law: &VerifyLaw, fn_sigs: &FnSigMap) -> Option<VerifyLawSpecRef> {
28 let named = named_law_function(law, fn_sigs)?;
29 named.is_pure.then_some(VerifyLawSpecRef {
30 spec_fn_name: named.name,
31 })
32}
33
34pub fn law_spec_ref(law: &VerifyLaw, fn_sigs: &FnSigMap) -> Option<VerifyLawSpecRef> {
35 let spec = declared_spec_ref(law, fn_sigs)?;
36 law_calls_function(law, &spec.spec_fn_name).then_some(spec)
37}
38
39pub fn canonical_spec_ref(
40 fn_name: &str,
41 law: &VerifyLaw,
42 fn_sigs: &FnSigMap,
43) -> Option<VerifyLawSpecRef> {
44 let spec = law_spec_ref(law, fn_sigs)?;
45 canonical_spec_shape(fn_name, law, &spec.spec_fn_name).then_some(spec)
46}
47
48pub fn law_calls_function(law: &VerifyLaw, fn_name: &str) -> bool {
49 expr_calls_function(&law.lhs, fn_name) || expr_calls_function(&law.rhs, fn_name)
50}
51
52pub fn canonical_spec_shape(fn_name: &str, law: &VerifyLaw, spec_fn_name: &str) -> bool {
53 let try_side = |impl_side: &Expr, spec_side: &Expr| -> bool {
54 let Some((impl_callee, impl_args)) = direct_call(impl_side) else {
55 return false;
56 };
57 let Some((spec_callee, spec_args)) = direct_call(spec_side) else {
58 return false;
59 };
60 impl_callee == fn_name && spec_callee == spec_fn_name && impl_args == spec_args
61 };
62
63 try_side(&law.lhs, &law.rhs) || try_side(&law.rhs, &law.lhs)
64}
65
66fn expr_calls_function(expr: &Expr, fn_name: &str) -> bool {
67 match expr {
68 Expr::FnCall(callee, args) => {
69 expr_is_function_name(callee, fn_name)
70 || expr_calls_function(callee, fn_name)
71 || args.iter().any(|arg| expr_calls_function(arg, fn_name))
72 }
73 Expr::Attr(obj, _) => expr_calls_function(obj, fn_name),
74 Expr::BinOp(_, left, right) => {
75 expr_calls_function(left, fn_name) || expr_calls_function(right, fn_name)
76 }
77 Expr::Match { subject, arms, .. } => {
78 expr_calls_function(subject, fn_name)
79 || arms
80 .iter()
81 .any(|arm| match_arm_calls_function(arm, fn_name))
82 }
83 Expr::Constructor(_, Some(inner)) => expr_calls_function(inner, fn_name),
84 Expr::ErrorProp(inner) => expr_calls_function(inner, fn_name),
85 Expr::InterpolatedStr(parts) => parts.iter().any(|part| match part {
86 StrPart::Literal(_) => false,
87 StrPart::Parsed(expr) => expr_calls_function(expr, fn_name),
88 }),
89 Expr::List(items) | Expr::Tuple(items) => {
90 items.iter().any(|item| expr_calls_function(item, fn_name))
91 }
92 Expr::MapLiteral(entries) => entries.iter().any(|(key, value)| {
93 expr_calls_function(key, fn_name) || expr_calls_function(value, fn_name)
94 }),
95 Expr::RecordCreate { fields, .. } => fields
96 .iter()
97 .any(|(_, expr)| expr_calls_function(expr, fn_name)),
98 Expr::RecordUpdate { base, updates, .. } => {
99 expr_calls_function(base, fn_name)
100 || updates
101 .iter()
102 .any(|(_, expr)| expr_calls_function(expr, fn_name))
103 }
104 Expr::TailCall(boxed) => {
105 boxed.0 == fn_name || boxed.1.iter().any(|arg| expr_calls_function(arg, fn_name))
106 }
107 Expr::Literal(_) | Expr::Ident(_) | Expr::Resolved(_) | Expr::Constructor(_, None) => false,
108 }
109}
110
111fn match_arm_calls_function(arm: &MatchArm, fn_name: &str) -> bool {
112 expr_calls_function(&arm.body, fn_name)
113}
114
115fn expr_is_function_name(expr: &Expr, fn_name: &str) -> bool {
116 matches!(expr, Expr::Ident(name) if name == fn_name)
117}
118
119fn direct_call(expr: &Expr) -> Option<(&str, &[Expr])> {
120 let Expr::FnCall(callee, args) = expr else {
121 return None;
122 };
123 let Expr::Ident(name) = callee.as_ref() else {
124 return None;
125 };
126 Some((name.as_str(), args.as_slice()))
127}
128
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ast::{Literal, VerifyGiven, VerifyGivenDomain};

    /// A pure single-`Int`-argument signature (no declared effects).
    fn int_sig() -> (Vec<Type>, Type, Vec<String>) {
        (vec![Type::Int], Type::Int, vec![])
    }

    /// Shorthand for `Expr::Ident(name)`.
    fn ident(name: &str) -> Expr {
        Expr::Ident(name.to_string())
    }

    /// Shorthand for a direct call `callee(args…)` with a bare-identifier callee.
    fn call(callee: &str, args: Vec<Expr>) -> Expr {
        Expr::FnCall(Box::new(ident(callee)), args)
    }

    /// Builds a law named `name` with sides `lhs`/`rhs` and a single `x: Int` given.
    fn law(lhs: Expr, rhs: Expr, name: &str) -> VerifyLaw {
        VerifyLaw {
            name: name.to_string(),
            givens: vec![VerifyGiven {
                name: "x".to_string(),
                type_name: "Int".to_string(),
                domain: VerifyGivenDomain::Explicit(vec![Expr::Literal(Literal::Int(1))]),
            }],
            lhs,
            rhs,
        }
    }

    /// Pure `fooSpec` registered in a fresh signature map.
    fn pure_foo_spec_sigs() -> FnSigMap {
        let mut fn_sigs = FnSigMap::new();
        fn_sigs.insert("fooSpec".to_string(), int_sig());
        fn_sigs
    }

    fn foo_spec_ref() -> Option<VerifyLawSpecRef> {
        Some(VerifyLawSpecRef {
            spec_fn_name: "fooSpec".to_string(),
        })
    }

    #[test]
    fn pure_named_law_function_becomes_declared_spec_ref() {
        let fn_sigs = pure_foo_spec_sigs();
        let verify_law = law(
            call("foo", vec![ident("x")]),
            call("fooSpec", vec![ident("x")]),
            "fooSpec",
        );

        assert_eq!(declared_spec_ref(&verify_law, &fn_sigs), foo_spec_ref());
        assert_eq!(
            law_spec_ref(&verify_law, &fn_sigs),
            declared_spec_ref(&verify_law, &fn_sigs)
        );
        assert_eq!(
            canonical_spec_ref("foo", &verify_law, &fn_sigs),
            declared_spec_ref(&verify_law, &fn_sigs)
        );
    }

    #[test]
    fn effectful_named_law_function_is_not_a_spec_ref() {
        let mut fn_sigs = FnSigMap::new();
        fn_sigs.insert(
            "fooSpec".to_string(),
            (
                vec![Type::Int],
                Type::Int,
                vec!["Console.print".to_string()],
            ),
        );

        let verify_law = law(ident("x"), ident("x"), "fooSpec");

        assert!(declared_spec_ref(&verify_law, &fn_sigs).is_none());
        assert_eq!(
            named_law_function(&verify_law, &fn_sigs),
            Some(NamedLawFunction {
                name: "fooSpec".to_string(),
                is_pure: false
            })
        );
    }

    #[test]
    fn unknown_law_name_has_no_named_function() {
        let fn_sigs = FnSigMap::new();
        let verify_law = law(ident("x"), ident("x"), "fooSpec");

        assert!(named_law_function(&verify_law, &fn_sigs).is_none());
        assert!(declared_spec_ref(&verify_law, &fn_sigs).is_none());
        assert!(law_spec_ref(&verify_law, &fn_sigs).is_none());
        assert!(canonical_spec_ref("foo", &verify_law, &fn_sigs).is_none());
    }

    #[test]
    fn canonical_spec_ref_requires_call_to_named_function() {
        let fn_sigs = pure_foo_spec_sigs();
        let verify_law = law(ident("x"), ident("x"), "fooSpec");

        assert!(declared_spec_ref(&verify_law, &fn_sigs).is_some());
        assert!(law_spec_ref(&verify_law, &fn_sigs).is_none());
        assert!(canonical_spec_ref("foo", &verify_law, &fn_sigs).is_none());
        assert!(!law_calls_function(&verify_law, "fooSpec"));
    }

    #[test]
    fn canonical_spec_ref_requires_same_arguments_on_both_sides() {
        let fn_sigs = pure_foo_spec_sigs();
        let verify_law = law(
            call("foo", vec![ident("x")]),
            call(
                "fooSpec",
                vec![Expr::Literal(Literal::Int(5)), ident("x")],
            ),
            "fooSpec",
        );

        assert!(law_spec_ref(&verify_law, &fn_sigs).is_some());
        assert!(canonical_spec_ref("foo", &verify_law, &fn_sigs).is_none());
        assert!(!canonical_spec_shape("foo", &verify_law, "fooSpec"));
    }

    #[test]
    fn canonical_spec_ref_accepts_spec_call_on_left_hand_side() {
        let fn_sigs = pure_foo_spec_sigs();
        let verify_law = law(
            call("fooSpec", vec![ident("x")]),
            call("foo", vec![ident("x")]),
            "fooSpec",
        );

        assert!(canonical_spec_shape("foo", &verify_law, "fooSpec"));
        assert_eq!(
            canonical_spec_ref("foo", &verify_law, &fn_sigs),
            foo_spec_ref()
        );
    }

    #[test]
    fn nested_spec_call_counts_as_call_but_not_as_canonical_shape() {
        let fn_sigs = pure_foo_spec_sigs();
        // foo(fooSpec(x)) vs x: the spec is called, but neither side pairs a
        // direct `foo(..)` call with a direct `fooSpec(..)` call.
        let verify_law = law(
            call("foo", vec![call("fooSpec", vec![ident("x")])]),
            ident("x"),
            "fooSpec",
        );

        assert!(law_calls_function(&verify_law, "fooSpec"));
        assert_eq!(law_spec_ref(&verify_law, &fn_sigs), foo_spec_ref());
        assert!(!canonical_spec_shape("foo", &verify_law, "fooSpec"));
        assert!(canonical_spec_ref("foo", &verify_law, &fn_sigs).is_none());
    }
}