1use std::collections::{HashMap, HashSet};
17
18use crate::types::*;
19
20#[derive(Debug, Clone)]
22pub struct VerifyError {
23 pub kind: VerifyErrorKind,
24 pub trace: SourceTrace,
25}
26
27#[derive(Debug, Clone, PartialEq, Eq)]
28pub enum VerifyErrorKind {
29 UnboundVariable { name: String },
31 OldOutsidePoscondition,
33 PostconditionWithoutParams { function: String },
35 UnknownQuantifierType { ty: String },
37 DisconnectedPostcondition { function: String },
39}
40
41impl std::fmt::Display for VerifyError {
42 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
43 match &self.kind {
44 VerifyErrorKind::UnboundVariable { name } => {
45 write!(f, "unbound variable `{name}`")
46 }
47 VerifyErrorKind::OldOutsidePoscondition => {
48 write!(f, "`old()` used outside of postcondition")
49 }
50 VerifyErrorKind::PostconditionWithoutParams { function } => {
51 write!(
52 f,
53 "function `{function}` has postconditions but no parameters"
54 )
55 }
56 VerifyErrorKind::UnknownQuantifierType { ty } => {
57 write!(f, "quantifier references unknown type `{ty}`")
58 }
59 VerifyErrorKind::DisconnectedPostcondition { function } => {
60 write!(
61 f,
62 "postcondition in `{function}` doesn't reference any parameter"
63 )
64 }
65 }
66 }
67}
68
69pub fn verify_module(module: &Module) -> Vec<VerifyError> {
71 let mut errors = Vec::new();
72
73 let known_types: HashSet<&str> = module
75 .structs
76 .iter()
77 .map(|s| s.name.as_str())
78 .chain(module.functions.iter().map(|f| f.name.as_str()))
79 .collect();
80
81 let mut call_names = HashSet::new();
86 collect_module_call_names(module, &mut call_names);
87
88 for func in &module.functions {
89 verify_function(func, &known_types, &call_names, &mut errors);
90 }
91
92 for inv in &module.invariants {
93 verify_invariant(inv, &known_types, &call_names, &mut errors);
94 }
95
96 for guard in &module.edge_guards {
97 verify_edge_guard(guard, &known_types, &mut errors);
98 }
99
100 errors
101}
102
/// A proof obligation relating an action (function) to an invariant.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Obligation {
    /// Name of the function the obligation falls on.
    pub action: String,
    /// Name of the invariant involved.
    pub invariant: String,
    /// Name of the type the invariant quantifies over.
    pub entity: String,
    /// Fields of `entity` that are both constrained by the invariant and
    /// touched by the action; empty for temporal properties.
    pub fields: Vec<String>,
    /// Which flavour of obligation this is.
    pub kind: ObligationKind,
}

/// The flavours of [`Obligation`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ObligationKind {
    /// The action touches fields that an invariant constrains.
    InvariantPreservation,
    /// The invariant quantifies over the action itself.
    TemporalProperty,
}

impl std::fmt::Display for Obligation {
    /// Renders the obligation as a single human-readable sentence.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            action,
            invariant,
            entity,
            fields,
            kind,
        } = self;

        match kind {
            ObligationKind::TemporalProperty => {
                write!(f, "{} must satisfy temporal property {}", action, invariant)
            }
            ObligationKind::InvariantPreservation => {
                // Field list is rendered inside literal braces: `Entity.{a, b}`.
                let touched = fields.join(", ");
                write!(
                    f,
                    "{} modifies {}.{{{}}} (constrained by {})",
                    action, entity, touched, invariant
                )
            }
        }
    }
}
154
155pub fn analyze_obligations(module: &Module) -> Vec<Obligation> {
161 let mut obligations = Vec::new();
162
163 let struct_fields: HashMap<&str, Vec<&str>> = module
165 .structs
166 .iter()
167 .map(|s| {
168 (
169 s.name.as_str(),
170 s.fields.iter().map(|f| f.name.as_str()).collect(),
171 )
172 })
173 .collect();
174
175 let func_entity_params: HashMap<&str, Vec<(&str, &str)>> = module
178 .functions
179 .iter()
180 .map(|func| {
181 let entity_params: Vec<(&str, &str)> = func
182 .params
183 .iter()
184 .filter_map(|p| match &p.ty {
185 IrType::Named(t) | IrType::Struct(t)
186 if struct_fields.contains_key(t.as_str()) =>
187 {
188 Some((p.name.as_str(), t.as_str()))
189 }
190 _ => None,
191 })
192 .collect();
193 (func.name.as_str(), entity_params)
194 })
195 .collect();
196
197 let mut modified_fields: HashMap<&str, HashSet<(&str, &str)>> = HashMap::new();
200 for func in &module.functions {
201 let entity_params = &func_entity_params[func.name.as_str()];
202 let param_to_entity: HashMap<&str, &str> = entity_params.iter().copied().collect();
203 let mut fields = HashSet::new();
204 for post in &func.postconditions {
205 let exprs: Vec<&IrExpr> = match post {
206 Postcondition::Always { expr, .. } => vec![expr],
207 Postcondition::When { guard, expr, .. } => vec![guard, expr],
208 };
209 for expr in exprs {
210 collect_old_field_accesses(expr, ¶m_to_entity, &mut fields);
211 }
212 }
213 modified_fields.insert(func.name.as_str(), fields);
214 }
215
216 for inv in &module.invariants {
218 if let IrExpr::Forall { binding, ty, body } = &inv.expr {
219 let is_action = module.functions.iter().any(|f| f.name == *ty);
221 if is_action {
222 obligations.push(Obligation {
224 action: ty.clone(),
225 invariant: inv.name.clone(),
226 entity: ty.clone(),
227 fields: vec![],
228 kind: ObligationKind::TemporalProperty,
229 });
230 continue;
231 }
232
233 let constrained = collect_field_accesses_on(body, binding);
236
237 for func in &module.functions {
239 if let Some(mods) = modified_fields.get(func.name.as_str()) {
240 let overlapping: Vec<String> = constrained
241 .iter()
242 .filter(|f| mods.contains(&(ty.as_str(), f.as_str())))
243 .cloned()
244 .collect();
245 if !overlapping.is_empty() {
246 obligations.push(Obligation {
247 action: func.name.clone(),
248 invariant: inv.name.clone(),
249 entity: ty.clone(),
250 fields: overlapping,
251 kind: ObligationKind::InvariantPreservation,
252 });
253 }
254 }
255 }
256 }
257 }
258
259 obligations
260}
261
262fn collect_old_field_accesses<'a>(
268 expr: &'a IrExpr,
269 param_to_entity: &HashMap<&str, &'a str>,
270 result: &mut HashSet<(&'a str, &'a str)>,
271) {
272 match expr {
273 IrExpr::Old(inner) => {
274 collect_inner_field_accesses(inner, param_to_entity, result);
275 }
276 _ => {
277 match expr {
279 IrExpr::Compare { left, right, .. }
280 | IrExpr::Arithmetic { left, right, .. }
281 | IrExpr::And(left, right)
282 | IrExpr::Or(left, right)
283 | IrExpr::Implies(left, right) => {
284 collect_old_field_accesses(left, param_to_entity, result);
285 collect_old_field_accesses(right, param_to_entity, result);
286 }
287 IrExpr::Not(inner) => {
288 collect_old_field_accesses(inner, param_to_entity, result);
289 }
290 IrExpr::FieldAccess { root, .. } => {
291 collect_old_field_accesses(root, param_to_entity, result);
292 }
293 IrExpr::Forall { body, .. } | IrExpr::Exists { body, .. } => {
294 collect_old_field_accesses(body, param_to_entity, result);
295 }
296 IrExpr::Call { args, .. } => {
297 for arg in args {
298 collect_old_field_accesses(arg, param_to_entity, result);
299 }
300 }
301 IrExpr::Var(_) | IrExpr::Literal(_) | IrExpr::Old(_) => {}
302 }
303 }
304 }
305}
306
307fn collect_inner_field_accesses<'a>(
309 expr: &'a IrExpr,
310 param_to_entity: &HashMap<&str, &'a str>,
311 result: &mut HashSet<(&'a str, &'a str)>,
312) {
313 match expr {
314 IrExpr::FieldAccess { root, field } => {
315 if let IrExpr::Var(var) = root.as_ref()
317 && let Some(&entity) = param_to_entity.get(var.as_str())
318 {
319 result.insert((entity, field.as_str()));
320 }
321 collect_inner_field_accesses(root, param_to_entity, result);
323 }
324 _ => match expr {
325 IrExpr::Compare { left, right, .. }
326 | IrExpr::Arithmetic { left, right, .. }
327 | IrExpr::And(left, right)
328 | IrExpr::Or(left, right)
329 | IrExpr::Implies(left, right) => {
330 collect_inner_field_accesses(left, param_to_entity, result);
331 collect_inner_field_accesses(right, param_to_entity, result);
332 }
333 IrExpr::Not(inner) | IrExpr::Old(inner) => {
334 collect_inner_field_accesses(inner, param_to_entity, result);
335 }
336 IrExpr::FieldAccess { .. } => unreachable!(),
337 IrExpr::Forall { body, .. } | IrExpr::Exists { body, .. } => {
338 collect_inner_field_accesses(body, param_to_entity, result);
339 }
340 IrExpr::Call { args, .. } => {
341 for arg in args {
342 collect_inner_field_accesses(arg, param_to_entity, result);
343 }
344 }
345 IrExpr::Var(_) | IrExpr::Literal(_) => {}
346 },
347 }
348}
349
350fn collect_field_accesses_on(expr: &IrExpr, binding: &str) -> Vec<String> {
355 let mut fields = Vec::new();
356 collect_fields_on_inner(expr, binding, &mut fields);
357 fields.sort();
358 fields.dedup();
359 fields
360}
361
362fn collect_fields_on_inner(expr: &IrExpr, binding: &str, fields: &mut Vec<String>) {
363 match expr {
364 IrExpr::FieldAccess { root, field } => {
365 if let IrExpr::Var(var) = root.as_ref()
366 && var == binding
367 {
368 fields.push(field.clone());
369 }
370 collect_fields_on_inner(root, binding, fields);
371 }
372 _ => for_each_child(expr, |child| {
373 collect_fields_on_inner(child, binding, fields)
374 }),
375 }
376}
377
378fn collect_module_call_names<'a>(module: &'a Module, names: &mut HashSet<&'a str>) {
382 for func in &module.functions {
383 for pre in &func.preconditions {
384 collect_call_names(&pre.expr, names);
385 }
386 for post in &func.postconditions {
387 match post {
388 Postcondition::Always { expr, .. } => collect_call_names(expr, names),
389 Postcondition::When { guard, expr, .. } => {
390 collect_call_names(guard, names);
391 collect_call_names(expr, names);
392 }
393 }
394 }
395 }
396 for inv in &module.invariants {
397 collect_call_names(&inv.expr, names);
398 }
399 for guard in &module.edge_guards {
400 collect_call_names(&guard.condition, names);
401 for (_, arg) in &guard.args {
402 collect_call_names(arg, names);
403 }
404 }
405}
406
407fn collect_call_names<'a>(expr: &'a IrExpr, names: &mut HashSet<&'a str>) {
408 if let IrExpr::Call { name, args } = expr {
409 names.insert(name.as_str());
410 for arg in args {
411 collect_call_names(arg, names);
412 }
413 return;
414 }
415 match expr {
416 IrExpr::Compare { left, right, .. }
417 | IrExpr::Arithmetic { left, right, .. }
418 | IrExpr::And(left, right)
419 | IrExpr::Or(left, right)
420 | IrExpr::Implies(left, right) => {
421 collect_call_names(left, names);
422 collect_call_names(right, names);
423 }
424 IrExpr::Not(inner) | IrExpr::Old(inner) => collect_call_names(inner, names),
425 IrExpr::FieldAccess { root, .. } => collect_call_names(root, names),
426 IrExpr::Forall { body, .. } | IrExpr::Exists { body, .. } => {
427 collect_call_names(body, names);
428 }
429 IrExpr::Var(_) | IrExpr::Literal(_) | IrExpr::Call { .. } => {}
430 }
431}
432
433fn verify_function(
434 func: &Function,
435 known_types: &HashSet<&str>,
436 call_names: &HashSet<&str>,
437 errors: &mut Vec<VerifyError>,
438) {
439 let param_names: HashSet<&str> = func.params.iter().map(|p| p.name.as_str()).collect();
440
441 if !func.postconditions.is_empty() && func.params.is_empty() {
443 errors.push(VerifyError {
444 kind: VerifyErrorKind::PostconditionWithoutParams {
445 function: func.name.clone(),
446 },
447 trace: func.trace.clone(),
448 });
449 }
450
451 for pre in &func.preconditions {
453 check_no_old(&pre.expr, &pre.trace, errors);
454 check_bound_vars(
455 &pre.expr,
456 ¶m_names,
457 &HashSet::new(),
458 call_names,
459 &pre.trace,
460 errors,
461 );
462 }
463
464 for post in &func.postconditions {
466 let (expr, trace) = match post {
467 Postcondition::Always { expr, trace } => (expr, trace),
468 Postcondition::When { guard, expr, trace } => {
469 check_bound_vars(
470 guard,
471 ¶m_names,
472 &HashSet::new(),
473 call_names,
474 trace,
475 errors,
476 );
477 (expr, trace)
478 }
479 };
480 check_bound_vars(
481 expr,
482 ¶m_names,
483 &HashSet::new(),
484 call_names,
485 trace,
486 errors,
487 );
488
489 let vars = collect_vars(expr);
491 if !vars.iter().any(|v| param_names.contains(v.as_str())) {
492 errors.push(VerifyError {
493 kind: VerifyErrorKind::DisconnectedPostcondition {
494 function: func.name.clone(),
495 },
496 trace: trace.clone(),
497 });
498 }
499 }
500
501 for pre in &func.preconditions {
503 check_quantifier_types(&pre.expr, known_types, &pre.trace, errors);
504 }
505 for post in &func.postconditions {
506 match post {
507 Postcondition::Always { expr, trace } => {
508 check_quantifier_types(expr, known_types, trace, errors);
509 }
510 Postcondition::When {
511 guard, expr, trace, ..
512 } => {
513 check_quantifier_types(guard, known_types, trace, errors);
514 check_quantifier_types(expr, known_types, trace, errors);
515 }
516 }
517 }
518}
519
520fn verify_invariant(
521 inv: &Invariant,
522 known_types: &HashSet<&str>,
523 call_names: &HashSet<&str>,
524 errors: &mut Vec<VerifyError>,
525) {
526 check_quantifier_types(&inv.expr, known_types, &inv.trace, errors);
528 check_bound_vars(
530 &inv.expr,
531 &HashSet::new(),
532 &HashSet::new(),
533 call_names,
534 &inv.trace,
535 errors,
536 );
537}
538
539fn verify_edge_guard(
540 guard: &EdgeGuard,
541 known_types: &HashSet<&str>,
542 errors: &mut Vec<VerifyError>,
543) {
544 check_no_old(&guard.condition, &guard.trace, errors);
545 check_quantifier_types(&guard.condition, known_types, &guard.trace, errors);
546 for (_, arg_expr) in &guard.args {
547 check_no_old(arg_expr, &guard.trace, errors);
548 }
549}
550
551fn check_no_old(expr: &IrExpr, trace: &SourceTrace, errors: &mut Vec<VerifyError>) {
555 match expr {
556 IrExpr::Old(_) => {
557 errors.push(VerifyError {
558 kind: VerifyErrorKind::OldOutsidePoscondition,
559 trace: trace.clone(),
560 });
561 }
562 _ => {
563 for_each_child(expr, |child| check_no_old(child, trace, errors));
564 }
565 }
566}
567
568fn check_bound_vars(
570 expr: &IrExpr,
571 params: &HashSet<&str>,
572 quantifier_bindings: &HashSet<&str>,
573 call_names: &HashSet<&str>,
574 trace: &SourceTrace,
575 errors: &mut Vec<VerifyError>,
576) {
577 match expr {
578 IrExpr::Var(name) => {
579 let is_variant = name.starts_with(|c: char| c.is_ascii_uppercase());
581 let is_call = call_names.contains(name.as_str());
583 if !is_variant
584 && !is_call
585 && !params.contains(name.as_str())
586 && !quantifier_bindings.contains(name.as_str())
587 {
588 errors.push(VerifyError {
589 kind: VerifyErrorKind::UnboundVariable { name: name.clone() },
590 trace: trace.clone(),
591 });
592 }
593 }
594 IrExpr::Forall { binding, body, .. } | IrExpr::Exists { binding, body, .. } => {
595 let mut extended = quantifier_bindings.clone();
596 extended.insert(binding.as_str());
597 check_bound_vars(body, params, &extended, call_names, trace, errors);
598 }
599 _ => {
600 for_each_child(expr, |child| {
601 check_bound_vars(
602 child,
603 params,
604 quantifier_bindings,
605 call_names,
606 trace,
607 errors,
608 );
609 });
610 }
611 }
612}
613
614fn check_quantifier_types(
616 expr: &IrExpr,
617 known_types: &HashSet<&str>,
618 trace: &SourceTrace,
619 errors: &mut Vec<VerifyError>,
620) {
621 match expr {
622 IrExpr::Forall { ty, body, .. } | IrExpr::Exists { ty, body, .. } => {
623 if !known_types.contains(ty.as_str()) {
624 errors.push(VerifyError {
625 kind: VerifyErrorKind::UnknownQuantifierType { ty: ty.clone() },
626 trace: trace.clone(),
627 });
628 }
629 check_quantifier_types(body, known_types, trace, errors);
630 }
631 _ => {
632 for_each_child(expr, |child| {
633 check_quantifier_types(child, known_types, trace, errors);
634 });
635 }
636 }
637}
638
639fn collect_vars(expr: &IrExpr) -> Vec<String> {
641 let mut vars = Vec::new();
642 collect_vars_inner(expr, &mut vars);
643 vars
644}
645
646fn collect_vars_inner(expr: &IrExpr, vars: &mut Vec<String>) {
647 match expr {
648 IrExpr::Var(name) => vars.push(name.clone()),
649 _ => for_each_child(expr, |child| collect_vars_inner(child, vars)),
650 }
651}
652
653fn for_each_child(expr: &IrExpr, mut f: impl FnMut(&IrExpr)) {
655 match expr {
656 IrExpr::Compare { left, right, .. }
657 | IrExpr::Arithmetic { left, right, .. }
658 | IrExpr::And(left, right)
659 | IrExpr::Or(left, right)
660 | IrExpr::Implies(left, right) => {
661 f(left);
662 f(right);
663 }
664 IrExpr::Not(inner) | IrExpr::Old(inner) => f(inner),
665 IrExpr::FieldAccess { root, .. } => f(root),
666 IrExpr::Forall { body, .. } | IrExpr::Exists { body, .. } => f(body),
667 IrExpr::Call { args, .. } => {
668 for arg in args {
669 f(arg);
670 }
671 }
672 IrExpr::Var(_) | IrExpr::Literal(_) => {}
673 }
674}