1use std::collections::{HashMap, HashSet};
17
18use serde::{Deserialize, Serialize};
19
20use crate::types::*;
21
22#[derive(Debug, Clone, Serialize, Deserialize)]
24pub struct VerifyError {
25 pub kind: VerifyErrorKind,
26 pub trace: SourceTrace,
27}
28
29#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
30pub enum VerifyErrorKind {
31 UnboundVariable { name: String },
33 OldOutsidePoscondition,
35 PostconditionWithoutParams { function: String },
37 UnknownQuantifierType { ty: String },
39 DisconnectedPostcondition { function: String },
41}
42
43impl std::fmt::Display for VerifyError {
44 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
45 match &self.kind {
46 VerifyErrorKind::UnboundVariable { name } => {
47 write!(f, "unbound variable `{name}`")
48 }
49 VerifyErrorKind::OldOutsidePoscondition => {
50 write!(f, "`old()` used outside of postcondition")
51 }
52 VerifyErrorKind::PostconditionWithoutParams { function } => {
53 write!(
54 f,
55 "function `{function}` has postconditions but no parameters"
56 )
57 }
58 VerifyErrorKind::UnknownQuantifierType { ty } => {
59 write!(f, "quantifier references unknown type `{ty}`")
60 }
61 VerifyErrorKind::DisconnectedPostcondition { function } => {
62 write!(
63 f,
64 "postcondition in `{function}` doesn't reference any parameter"
65 )
66 }
67 }
68 }
69}
70
71pub fn verify_module(module: &Module) -> Vec<VerifyError> {
73 let mut errors = Vec::new();
74
75 let known_types: HashSet<&str> = module
77 .structs
78 .iter()
79 .map(|s| s.name.as_str())
80 .chain(module.functions.iter().map(|f| f.name.as_str()))
81 .collect();
82
83 let mut call_names = HashSet::new();
88 collect_module_call_names(module, &mut call_names);
89
90 for func in &module.functions {
91 verify_function(func, &known_types, &call_names, &mut errors);
92 }
93
94 for inv in &module.invariants {
95 verify_invariant(inv, &known_types, &call_names, &mut errors);
96 }
97
98 for guard in &module.edge_guards {
99 verify_edge_guard(guard, &known_types, &mut errors);
100 }
101
102 errors
103}
104
105#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
110pub struct Obligation {
111 pub action: String,
113 pub invariant: String,
115 pub entity: String,
117 pub fields: Vec<String>,
119 pub kind: ObligationKind,
121}
122
123#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
124pub enum ObligationKind {
125 InvariantPreservation,
128 TemporalProperty,
131}
132
133impl std::fmt::Display for Obligation {
134 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
135 match &self.kind {
136 ObligationKind::InvariantPreservation => {
137 write!(
138 f,
139 "{} modifies {}.{{{}}} (constrained by {})",
140 self.action,
141 self.entity,
142 self.fields.join(", "),
143 self.invariant,
144 )
145 }
146 ObligationKind::TemporalProperty => {
147 write!(
148 f,
149 "{} must satisfy temporal property {}",
150 self.action, self.invariant,
151 )
152 }
153 }
154 }
155}
156
157pub fn analyze_obligations(module: &Module) -> Vec<Obligation> {
163 let mut obligations = Vec::new();
164
165 let struct_fields: HashMap<&str, Vec<&str>> = module
167 .structs
168 .iter()
169 .map(|s| {
170 (
171 s.name.as_str(),
172 s.fields.iter().map(|f| f.name.as_str()).collect(),
173 )
174 })
175 .collect();
176
177 let func_entity_params: HashMap<&str, Vec<(&str, &str)>> = module
180 .functions
181 .iter()
182 .map(|func| {
183 let entity_params: Vec<(&str, &str)> = func
184 .params
185 .iter()
186 .filter_map(|p| match &p.ty {
187 IrType::Named(t) | IrType::Struct(t)
188 if struct_fields.contains_key(t.as_str()) =>
189 {
190 Some((p.name.as_str(), t.as_str()))
191 }
192 _ => None,
193 })
194 .collect();
195 (func.name.as_str(), entity_params)
196 })
197 .collect();
198
199 let mut modified_fields: HashMap<&str, HashSet<(&str, &str)>> = HashMap::new();
202 for func in &module.functions {
203 let entity_params = &func_entity_params[func.name.as_str()];
204 let param_to_entity: HashMap<&str, &str> = entity_params.iter().copied().collect();
205 let mut fields = HashSet::new();
206 for post in &func.postconditions {
207 let exprs: Vec<&IrExpr> = match post {
208 Postcondition::Always { expr, .. } => vec![expr],
209 Postcondition::When { guard, expr, .. } => vec![guard, expr],
210 };
211 for expr in exprs {
212 collect_old_field_accesses(expr, ¶m_to_entity, &mut fields);
213 }
214 }
215 modified_fields.insert(func.name.as_str(), fields);
216 }
217
218 for inv in &module.invariants {
220 if let IrExpr::Forall { binding, ty, body } = &inv.expr {
221 let is_action = module.functions.iter().any(|f| f.name == *ty);
223 if is_action {
224 obligations.push(Obligation {
226 action: ty.clone(),
227 invariant: inv.name.clone(),
228 entity: ty.clone(),
229 fields: vec![],
230 kind: ObligationKind::TemporalProperty,
231 });
232 continue;
233 }
234
235 let constrained = collect_field_accesses_on(body, binding);
238
239 for func in &module.functions {
241 if let Some(mods) = modified_fields.get(func.name.as_str()) {
242 let overlapping: Vec<String> = constrained
243 .iter()
244 .filter(|f| mods.contains(&(ty.as_str(), f.as_str())))
245 .cloned()
246 .collect();
247 if !overlapping.is_empty() {
248 obligations.push(Obligation {
249 action: func.name.clone(),
250 invariant: inv.name.clone(),
251 entity: ty.clone(),
252 fields: overlapping,
253 kind: ObligationKind::InvariantPreservation,
254 });
255 }
256 }
257 }
258 }
259 }
260
261 obligations
262}
263
264fn collect_old_field_accesses<'a>(
270 expr: &'a IrExpr,
271 param_to_entity: &HashMap<&str, &'a str>,
272 result: &mut HashSet<(&'a str, &'a str)>,
273) {
274 match expr {
275 IrExpr::Old(inner) => {
276 collect_inner_field_accesses(inner, param_to_entity, result);
277 }
278 _ => {
279 match expr {
281 IrExpr::Compare { left, right, .. }
282 | IrExpr::Arithmetic { left, right, .. }
283 | IrExpr::And(left, right)
284 | IrExpr::Or(left, right)
285 | IrExpr::Implies(left, right) => {
286 collect_old_field_accesses(left, param_to_entity, result);
287 collect_old_field_accesses(right, param_to_entity, result);
288 }
289 IrExpr::Not(inner) => {
290 collect_old_field_accesses(inner, param_to_entity, result);
291 }
292 IrExpr::FieldAccess { root, .. } => {
293 collect_old_field_accesses(root, param_to_entity, result);
294 }
295 IrExpr::Forall { body, .. } | IrExpr::Exists { body, .. } => {
296 collect_old_field_accesses(body, param_to_entity, result);
297 }
298 IrExpr::Call { args, .. } => {
299 for arg in args {
300 collect_old_field_accesses(arg, param_to_entity, result);
301 }
302 }
303 IrExpr::Var(_) | IrExpr::Literal(_) | IrExpr::Old(_) => {}
304 }
305 }
306 }
307}
308
309fn collect_inner_field_accesses<'a>(
311 expr: &'a IrExpr,
312 param_to_entity: &HashMap<&str, &'a str>,
313 result: &mut HashSet<(&'a str, &'a str)>,
314) {
315 match expr {
316 IrExpr::FieldAccess { root, field } => {
317 if let IrExpr::Var(var) = root.as_ref()
319 && let Some(&entity) = param_to_entity.get(var.as_str())
320 {
321 result.insert((entity, field.as_str()));
322 }
323 collect_inner_field_accesses(root, param_to_entity, result);
325 }
326 _ => match expr {
327 IrExpr::Compare { left, right, .. }
328 | IrExpr::Arithmetic { left, right, .. }
329 | IrExpr::And(left, right)
330 | IrExpr::Or(left, right)
331 | IrExpr::Implies(left, right) => {
332 collect_inner_field_accesses(left, param_to_entity, result);
333 collect_inner_field_accesses(right, param_to_entity, result);
334 }
335 IrExpr::Not(inner) | IrExpr::Old(inner) => {
336 collect_inner_field_accesses(inner, param_to_entity, result);
337 }
338 IrExpr::FieldAccess { .. } => unreachable!(),
339 IrExpr::Forall { body, .. } | IrExpr::Exists { body, .. } => {
340 collect_inner_field_accesses(body, param_to_entity, result);
341 }
342 IrExpr::Call { args, .. } => {
343 for arg in args {
344 collect_inner_field_accesses(arg, param_to_entity, result);
345 }
346 }
347 IrExpr::Var(_) | IrExpr::Literal(_) => {}
348 },
349 }
350}
351
352fn collect_field_accesses_on(expr: &IrExpr, binding: &str) -> Vec<String> {
357 let mut fields = Vec::new();
358 collect_fields_on_inner(expr, binding, &mut fields);
359 fields.sort();
360 fields.dedup();
361 fields
362}
363
364fn collect_fields_on_inner(expr: &IrExpr, binding: &str, fields: &mut Vec<String>) {
365 match expr {
366 IrExpr::FieldAccess { root, field } => {
367 if let IrExpr::Var(var) = root.as_ref()
368 && var == binding
369 {
370 fields.push(field.clone());
371 }
372 collect_fields_on_inner(root, binding, fields);
373 }
374 _ => for_each_child(expr, |child| {
375 collect_fields_on_inner(child, binding, fields)
376 }),
377 }
378}
379
380pub(crate) fn collect_module_call_names<'a>(module: &'a Module, names: &mut HashSet<&'a str>) {
384 for func in &module.functions {
385 for pre in &func.preconditions {
386 collect_call_names(&pre.expr, names);
387 }
388 for post in &func.postconditions {
389 match post {
390 Postcondition::Always { expr, .. } => collect_call_names(expr, names),
391 Postcondition::When { guard, expr, .. } => {
392 collect_call_names(guard, names);
393 collect_call_names(expr, names);
394 }
395 }
396 }
397 }
398 for inv in &module.invariants {
399 collect_call_names(&inv.expr, names);
400 }
401 for guard in &module.edge_guards {
402 collect_call_names(&guard.condition, names);
403 for (_, arg) in &guard.args {
404 collect_call_names(arg, names);
405 }
406 }
407}
408
409fn collect_call_names<'a>(expr: &'a IrExpr, names: &mut HashSet<&'a str>) {
410 if let IrExpr::Call { name, args } = expr {
411 names.insert(name.as_str());
412 for arg in args {
413 collect_call_names(arg, names);
414 }
415 return;
416 }
417 match expr {
418 IrExpr::Compare { left, right, .. }
419 | IrExpr::Arithmetic { left, right, .. }
420 | IrExpr::And(left, right)
421 | IrExpr::Or(left, right)
422 | IrExpr::Implies(left, right) => {
423 collect_call_names(left, names);
424 collect_call_names(right, names);
425 }
426 IrExpr::Not(inner) | IrExpr::Old(inner) => collect_call_names(inner, names),
427 IrExpr::FieldAccess { root, .. } => collect_call_names(root, names),
428 IrExpr::Forall { body, .. } | IrExpr::Exists { body, .. } => {
429 collect_call_names(body, names);
430 }
431 IrExpr::Var(_) | IrExpr::Literal(_) | IrExpr::Call { .. } => {}
432 }
433}
434
435pub(crate) fn verify_function(
436 func: &Function,
437 known_types: &HashSet<&str>,
438 call_names: &HashSet<&str>,
439 errors: &mut Vec<VerifyError>,
440) {
441 let param_names: HashSet<&str> = func.params.iter().map(|p| p.name.as_str()).collect();
442
443 if !func.postconditions.is_empty() && func.params.is_empty() {
445 errors.push(VerifyError {
446 kind: VerifyErrorKind::PostconditionWithoutParams {
447 function: func.name.clone(),
448 },
449 trace: func.trace.clone(),
450 });
451 }
452
453 for pre in &func.preconditions {
455 check_no_old(&pre.expr, &pre.trace, errors);
456 check_bound_vars(
457 &pre.expr,
458 ¶m_names,
459 &HashSet::new(),
460 call_names,
461 &pre.trace,
462 errors,
463 );
464 }
465
466 for post in &func.postconditions {
468 let (expr, trace) = match post {
469 Postcondition::Always { expr, trace } => (expr, trace),
470 Postcondition::When { guard, expr, trace } => {
471 check_bound_vars(
472 guard,
473 ¶m_names,
474 &HashSet::new(),
475 call_names,
476 trace,
477 errors,
478 );
479 (expr, trace)
480 }
481 };
482 check_bound_vars(
483 expr,
484 ¶m_names,
485 &HashSet::new(),
486 call_names,
487 trace,
488 errors,
489 );
490
491 let vars = collect_vars(expr);
493 if !vars.iter().any(|v| param_names.contains(v.as_str())) {
494 errors.push(VerifyError {
495 kind: VerifyErrorKind::DisconnectedPostcondition {
496 function: func.name.clone(),
497 },
498 trace: trace.clone(),
499 });
500 }
501 }
502
503 for pre in &func.preconditions {
505 check_quantifier_types(&pre.expr, known_types, &pre.trace, errors);
506 }
507 for post in &func.postconditions {
508 match post {
509 Postcondition::Always { expr, trace } => {
510 check_quantifier_types(expr, known_types, trace, errors);
511 }
512 Postcondition::When {
513 guard, expr, trace, ..
514 } => {
515 check_quantifier_types(guard, known_types, trace, errors);
516 check_quantifier_types(expr, known_types, trace, errors);
517 }
518 }
519 }
520}
521
522pub(crate) fn verify_invariant(
523 inv: &Invariant,
524 known_types: &HashSet<&str>,
525 call_names: &HashSet<&str>,
526 errors: &mut Vec<VerifyError>,
527) {
528 check_quantifier_types(&inv.expr, known_types, &inv.trace, errors);
530 check_bound_vars(
532 &inv.expr,
533 &HashSet::new(),
534 &HashSet::new(),
535 call_names,
536 &inv.trace,
537 errors,
538 );
539}
540
541pub(crate) fn verify_edge_guard(
542 guard: &EdgeGuard,
543 known_types: &HashSet<&str>,
544 errors: &mut Vec<VerifyError>,
545) {
546 check_no_old(&guard.condition, &guard.trace, errors);
547 check_quantifier_types(&guard.condition, known_types, &guard.trace, errors);
548 for (_, arg_expr) in &guard.args {
549 check_no_old(arg_expr, &guard.trace, errors);
550 }
551}
552
553fn check_no_old(expr: &IrExpr, trace: &SourceTrace, errors: &mut Vec<VerifyError>) {
557 match expr {
558 IrExpr::Old(_) => {
559 errors.push(VerifyError {
560 kind: VerifyErrorKind::OldOutsidePoscondition,
561 trace: trace.clone(),
562 });
563 }
564 _ => {
565 for_each_child(expr, |child| check_no_old(child, trace, errors));
566 }
567 }
568}
569
570fn check_bound_vars(
572 expr: &IrExpr,
573 params: &HashSet<&str>,
574 quantifier_bindings: &HashSet<&str>,
575 call_names: &HashSet<&str>,
576 trace: &SourceTrace,
577 errors: &mut Vec<VerifyError>,
578) {
579 match expr {
580 IrExpr::Var(name) => {
581 let is_variant = name.starts_with(|c: char| c.is_ascii_uppercase());
583 let is_call = call_names.contains(name.as_str());
585 if !is_variant
586 && !is_call
587 && !params.contains(name.as_str())
588 && !quantifier_bindings.contains(name.as_str())
589 {
590 errors.push(VerifyError {
591 kind: VerifyErrorKind::UnboundVariable { name: name.clone() },
592 trace: trace.clone(),
593 });
594 }
595 }
596 IrExpr::Forall { binding, body, .. } | IrExpr::Exists { binding, body, .. } => {
597 let mut extended = quantifier_bindings.clone();
598 extended.insert(binding.as_str());
599 check_bound_vars(body, params, &extended, call_names, trace, errors);
600 }
601 _ => {
602 for_each_child(expr, |child| {
603 check_bound_vars(
604 child,
605 params,
606 quantifier_bindings,
607 call_names,
608 trace,
609 errors,
610 );
611 });
612 }
613 }
614}
615
616fn check_quantifier_types(
618 expr: &IrExpr,
619 known_types: &HashSet<&str>,
620 trace: &SourceTrace,
621 errors: &mut Vec<VerifyError>,
622) {
623 match expr {
624 IrExpr::Forall { ty, body, .. } | IrExpr::Exists { ty, body, .. } => {
625 if !known_types.contains(ty.as_str()) {
626 errors.push(VerifyError {
627 kind: VerifyErrorKind::UnknownQuantifierType { ty: ty.clone() },
628 trace: trace.clone(),
629 });
630 }
631 check_quantifier_types(body, known_types, trace, errors);
632 }
633 _ => {
634 for_each_child(expr, |child| {
635 check_quantifier_types(child, known_types, trace, errors);
636 });
637 }
638 }
639}
640
641fn collect_vars(expr: &IrExpr) -> Vec<String> {
643 let mut vars = Vec::new();
644 collect_vars_inner(expr, &mut vars);
645 vars
646}
647
648fn collect_vars_inner(expr: &IrExpr, vars: &mut Vec<String>) {
649 match expr {
650 IrExpr::Var(name) => vars.push(name.clone()),
651 _ => for_each_child(expr, |child| collect_vars_inner(child, vars)),
652 }
653}
654
655fn for_each_child(expr: &IrExpr, mut f: impl FnMut(&IrExpr)) {
657 match expr {
658 IrExpr::Compare { left, right, .. }
659 | IrExpr::Arithmetic { left, right, .. }
660 | IrExpr::And(left, right)
661 | IrExpr::Or(left, right)
662 | IrExpr::Implies(left, right) => {
663 f(left);
664 f(right);
665 }
666 IrExpr::Not(inner) | IrExpr::Old(inner) => f(inner),
667 IrExpr::FieldAccess { root, .. } => f(root),
668 IrExpr::Forall { body, .. } | IrExpr::Exists { body, .. } => f(body),
669 IrExpr::Call { args, .. } => {
670 for arg in args {
671 f(arg);
672 }
673 }
674 IrExpr::Var(_) | IrExpr::Literal(_) => {}
675 }
676}