Skip to main content

intent_ir/
verify.rs

//! IR verification pass.
//!
//! Validates structural and logical properties of the IR:
//! - Every variable referenced in an expression is bound, either as a parameter or by a
//!   quantifier (identifiers starting with an uppercase letter are treated as union variant
//!   labels and skipped)
//! - `old()` appears only in postconditions or temporal invariants
//! - Each postcondition references at least one parameter (otherwise it is trivially unverifiable)
//! - Quantifiers range over types known to this module (structs or functions)
//! - Functions with postconditions have at least one parameter (otherwise there is nothing to ensure)
//!
//! Also performs coherence analysis:
//! - Extracts verification obligations (invariant-action relationships)
//! - Tracks which entity fields each action modifies (via `old()` in postconditions)
//! - Matches modified fields against invariant constraints
15
16use std::collections::{HashMap, HashSet};
17
18use serde::{Deserialize, Serialize};
19
20use crate::types::*;
21
22/// A verification diagnostic.
23#[derive(Debug, Clone, Serialize, Deserialize)]
24pub struct VerifyError {
25    pub kind: VerifyErrorKind,
26    pub trace: SourceTrace,
27}
28
29#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
30pub enum VerifyErrorKind {
31    /// A variable is referenced but not bound as a parameter or quantifier binding.
32    UnboundVariable { name: String },
33    /// `old()` appears outside of a postcondition context.
34    OldOutsidePoscondition,
35    /// A function has postconditions but no parameters.
36    PostconditionWithoutParams { function: String },
37    /// A quantifier references a type not defined in this module (struct or function).
38    UnknownQuantifierType { ty: String },
39    /// A postcondition doesn't reference any function parameter.
40    DisconnectedPostcondition { function: String },
41}
42
43impl std::fmt::Display for VerifyError {
44    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
45        match &self.kind {
46            VerifyErrorKind::UnboundVariable { name } => {
47                write!(f, "unbound variable `{name}`")
48            }
49            VerifyErrorKind::OldOutsidePoscondition => {
50                write!(f, "`old()` used outside of postcondition")
51            }
52            VerifyErrorKind::PostconditionWithoutParams { function } => {
53                write!(
54                    f,
55                    "function `{function}` has postconditions but no parameters"
56                )
57            }
58            VerifyErrorKind::UnknownQuantifierType { ty } => {
59                write!(f, "quantifier references unknown type `{ty}`")
60            }
61            VerifyErrorKind::DisconnectedPostcondition { function } => {
62                write!(
63                    f,
64                    "postcondition in `{function}` doesn't reference any parameter"
65                )
66            }
67        }
68    }
69}
70
71/// Run verification checks on an IR module.
72pub fn verify_module(module: &Module) -> Vec<VerifyError> {
73    let mut errors = Vec::new();
74
75    // Known types for quantifiers: both structs and functions (actions).
76    let known_types: HashSet<&str> = module
77        .structs
78        .iter()
79        .map(|s| s.name.as_str())
80        .chain(module.functions.iter().map(|f| f.name.as_str()))
81        .collect();
82
83    // Collect all names that appear in Call positions — these are domain-level
84    // functions (now, lookup, etc.) and should be treated as implicitly bound.
85    // Also includes names the parser may lower as Var instead of Call (e.g., `now()`
86    // becomes Var("now") due to grammar limitations).
87    let mut call_names = HashSet::new();
88    collect_module_call_names(module, &mut call_names);
89
90    for func in &module.functions {
91        verify_function(func, &known_types, &call_names, &mut errors);
92    }
93
94    for inv in &module.invariants {
95        verify_invariant(inv, &known_types, &call_names, &mut errors);
96    }
97
98    for guard in &module.edge_guards {
99        verify_edge_guard(guard, &known_types, &mut errors);
100    }
101
102    errors
103}
104
105// ── Coherence analysis ─────────────────────────────────────
106
107/// A verification obligation — something that needs to be proven
108/// for the module to be correct.
109#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
110pub struct Obligation {
111    /// The action (function) that triggers this obligation.
112    pub action: String,
113    /// The invariant that must be preserved.
114    pub invariant: String,
115    /// The entity type involved.
116    pub entity: String,
117    /// The specific fields that the action modifies and the invariant constrains.
118    pub fields: Vec<String>,
119    /// The kind of obligation.
120    pub kind: ObligationKind,
121}
122
123#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
124pub enum ObligationKind {
125    /// Action modifies fields that an entity invariant constrains.
126    /// The invariant quantifies over the entity type (e.g., `forall a: Account => ...`).
127    InvariantPreservation,
128    /// A temporal invariant directly references this action via quantifier
129    /// (e.g., `forall t: Transfer => old(...) == ...`).
130    TemporalProperty,
131}
132
133impl std::fmt::Display for Obligation {
134    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
135        match &self.kind {
136            ObligationKind::InvariantPreservation => {
137                write!(
138                    f,
139                    "{} modifies {}.{{{}}} (constrained by {})",
140                    self.action,
141                    self.entity,
142                    self.fields.join(", "),
143                    self.invariant,
144                )
145            }
146            ObligationKind::TemporalProperty => {
147                write!(
148                    f,
149                    "{} must satisfy temporal property {}",
150                    self.action, self.invariant,
151                )
152            }
153        }
154    }
155}
156
157/// Analyze a verified IR module for verification obligations.
158///
159/// Returns a list of obligations that describe what logical properties
160/// need to hold for the module to be correct. These are informational —
161/// not errors — representing proof goals a formal verifier would check.
162pub fn analyze_obligations(module: &Module) -> Vec<Obligation> {
163    let mut obligations = Vec::new();
164
165    // Build a map of struct name → field names for lookup.
166    let struct_fields: HashMap<&str, Vec<&str>> = module
167        .structs
168        .iter()
169        .map(|s| {
170            (
171                s.name.as_str(),
172                s.fields.iter().map(|f| f.name.as_str()).collect(),
173            )
174        })
175        .collect();
176
177    // Build a map of param name → entity type for each function.
178    // Only includes params whose type is an entity (struct).
179    let func_entity_params: HashMap<&str, Vec<(&str, &str)>> = module
180        .functions
181        .iter()
182        .map(|func| {
183            let entity_params: Vec<(&str, &str)> = func
184                .params
185                .iter()
186                .filter_map(|p| match &p.ty {
187                    IrType::Named(t) | IrType::Struct(t)
188                        if struct_fields.contains_key(t.as_str()) =>
189                    {
190                        Some((p.name.as_str(), t.as_str()))
191                    }
192                    _ => None,
193                })
194                .collect();
195            (func.name.as_str(), entity_params)
196        })
197        .collect();
198
199    // For each function, collect fields modified in postconditions (via old()).
200    // Result: function name → set of (entity_type, field_name).
201    let mut modified_fields: HashMap<&str, HashSet<(&str, &str)>> = HashMap::new();
202    for func in &module.functions {
203        let entity_params = &func_entity_params[func.name.as_str()];
204        let param_to_entity: HashMap<&str, &str> = entity_params.iter().copied().collect();
205        let mut fields = HashSet::new();
206        for post in &func.postconditions {
207            let exprs: Vec<&IrExpr> = match post {
208                Postcondition::Always { expr, .. } => vec![expr],
209                Postcondition::When { guard, expr, .. } => vec![guard, expr],
210            };
211            for expr in exprs {
212                collect_old_field_accesses(expr, &param_to_entity, &mut fields);
213            }
214        }
215        modified_fields.insert(func.name.as_str(), fields);
216    }
217
218    // For each invariant, determine what it constrains.
219    for inv in &module.invariants {
220        if let IrExpr::Forall { binding, ty, body } = &inv.expr {
221            // Check if this is a temporal invariant (quantifies over an action).
222            let is_action = module.functions.iter().any(|f| f.name == *ty);
223            if is_action {
224                // Temporal property: directly references an action.
225                obligations.push(Obligation {
226                    action: ty.clone(),
227                    invariant: inv.name.clone(),
228                    entity: ty.clone(),
229                    fields: vec![],
230                    kind: ObligationKind::TemporalProperty,
231                });
232                continue;
233            }
234
235            // Entity invariant: quantifies over an entity type.
236            // Collect fields the invariant constrains.
237            let constrained = collect_field_accesses_on(body, binding);
238
239            // Find all actions that modify any of these fields on this entity type.
240            for func in &module.functions {
241                if let Some(mods) = modified_fields.get(func.name.as_str()) {
242                    let overlapping: Vec<String> = constrained
243                        .iter()
244                        .filter(|f| mods.contains(&(ty.as_str(), f.as_str())))
245                        .cloned()
246                        .collect();
247                    if !overlapping.is_empty() {
248                        obligations.push(Obligation {
249                            action: func.name.clone(),
250                            invariant: inv.name.clone(),
251                            entity: ty.clone(),
252                            fields: overlapping,
253                            kind: ObligationKind::InvariantPreservation,
254                        });
255                    }
256                }
257            }
258        }
259    }
260
261    obligations
262}
263
264/// Collect field accesses inside `old()` expressions, mapping them to entity types
265/// via the param→entity mapping.
266///
267/// For an expression like `old(from.balance)`, if `from` maps to entity `Account`,
268/// this records `("Account", "balance")`.
269fn collect_old_field_accesses<'a>(
270    expr: &'a IrExpr,
271    param_to_entity: &HashMap<&str, &'a str>,
272    result: &mut HashSet<(&'a str, &'a str)>,
273) {
274    match expr {
275        IrExpr::Old(inner) => {
276            collect_inner_field_accesses(inner, param_to_entity, result);
277        }
278        _ => {
279            // Manual match for recursion (avoid lifetime issues with for_each_child closure)
280            match expr {
281                IrExpr::Compare { left, right, .. }
282                | IrExpr::Arithmetic { left, right, .. }
283                | IrExpr::And(left, right)
284                | IrExpr::Or(left, right)
285                | IrExpr::Implies(left, right) => {
286                    collect_old_field_accesses(left, param_to_entity, result);
287                    collect_old_field_accesses(right, param_to_entity, result);
288                }
289                IrExpr::Not(inner) => {
290                    collect_old_field_accesses(inner, param_to_entity, result);
291                }
292                IrExpr::FieldAccess { root, .. } => {
293                    collect_old_field_accesses(root, param_to_entity, result);
294                }
295                IrExpr::Forall { body, .. } | IrExpr::Exists { body, .. } => {
296                    collect_old_field_accesses(body, param_to_entity, result);
297                }
298                IrExpr::Call { args, .. } => {
299                    for arg in args {
300                        collect_old_field_accesses(arg, param_to_entity, result);
301                    }
302                }
303                IrExpr::Var(_) | IrExpr::Literal(_) | IrExpr::Old(_) => {}
304            }
305        }
306    }
307}
308
309/// Collect field accesses within an old() body, resolving param names to entity types.
310fn collect_inner_field_accesses<'a>(
311    expr: &'a IrExpr,
312    param_to_entity: &HashMap<&str, &'a str>,
313    result: &mut HashSet<(&'a str, &'a str)>,
314) {
315    match expr {
316        IrExpr::FieldAccess { root, field } => {
317            // Check if root is a direct param reference: old(param.field)
318            if let IrExpr::Var(var) = root.as_ref()
319                && let Some(&entity) = param_to_entity.get(var.as_str())
320            {
321                result.insert((entity, field.as_str()));
322            }
323            // Also check for chained access: old(param.sub.field)
324            collect_inner_field_accesses(root, param_to_entity, result);
325        }
326        _ => match expr {
327            IrExpr::Compare { left, right, .. }
328            | IrExpr::Arithmetic { left, right, .. }
329            | IrExpr::And(left, right)
330            | IrExpr::Or(left, right)
331            | IrExpr::Implies(left, right) => {
332                collect_inner_field_accesses(left, param_to_entity, result);
333                collect_inner_field_accesses(right, param_to_entity, result);
334            }
335            IrExpr::Not(inner) | IrExpr::Old(inner) => {
336                collect_inner_field_accesses(inner, param_to_entity, result);
337            }
338            IrExpr::FieldAccess { .. } => unreachable!(),
339            IrExpr::Forall { body, .. } | IrExpr::Exists { body, .. } => {
340                collect_inner_field_accesses(body, param_to_entity, result);
341            }
342            IrExpr::Call { args, .. } => {
343                for arg in args {
344                    collect_inner_field_accesses(arg, param_to_entity, result);
345                }
346            }
347            IrExpr::Var(_) | IrExpr::Literal(_) => {}
348        },
349    }
350}
351
352/// Collect field names accessed on a specific binding variable in an expression.
353///
354/// For `forall a: Account => a.balance >= 0`, calling this with binding="a"
355/// returns `["balance"]`.
356fn collect_field_accesses_on(expr: &IrExpr, binding: &str) -> Vec<String> {
357    let mut fields = Vec::new();
358    collect_fields_on_inner(expr, binding, &mut fields);
359    fields.sort();
360    fields.dedup();
361    fields
362}
363
364fn collect_fields_on_inner(expr: &IrExpr, binding: &str, fields: &mut Vec<String>) {
365    match expr {
366        IrExpr::FieldAccess { root, field } => {
367            if let IrExpr::Var(var) = root.as_ref()
368                && var == binding
369            {
370                fields.push(field.clone());
371            }
372            collect_fields_on_inner(root, binding, fields);
373        }
374        _ => for_each_child(expr, |child| {
375            collect_fields_on_inner(child, binding, fields)
376        }),
377    }
378}
379
380// ── Structural verification helpers ────────────────────────
381
382/// Collect all function names used in Call expressions across the module.
383pub(crate) fn collect_module_call_names<'a>(module: &'a Module, names: &mut HashSet<&'a str>) {
384    for func in &module.functions {
385        for pre in &func.preconditions {
386            collect_call_names(&pre.expr, names);
387        }
388        for post in &func.postconditions {
389            match post {
390                Postcondition::Always { expr, .. } => collect_call_names(expr, names),
391                Postcondition::When { guard, expr, .. } => {
392                    collect_call_names(guard, names);
393                    collect_call_names(expr, names);
394                }
395            }
396        }
397    }
398    for inv in &module.invariants {
399        collect_call_names(&inv.expr, names);
400    }
401    for guard in &module.edge_guards {
402        collect_call_names(&guard.condition, names);
403        for (_, arg) in &guard.args {
404            collect_call_names(arg, names);
405        }
406    }
407}
408
409fn collect_call_names<'a>(expr: &'a IrExpr, names: &mut HashSet<&'a str>) {
410    if let IrExpr::Call { name, args } = expr {
411        names.insert(name.as_str());
412        for arg in args {
413            collect_call_names(arg, names);
414        }
415        return;
416    }
417    match expr {
418        IrExpr::Compare { left, right, .. }
419        | IrExpr::Arithmetic { left, right, .. }
420        | IrExpr::And(left, right)
421        | IrExpr::Or(left, right)
422        | IrExpr::Implies(left, right) => {
423            collect_call_names(left, names);
424            collect_call_names(right, names);
425        }
426        IrExpr::Not(inner) | IrExpr::Old(inner) => collect_call_names(inner, names),
427        IrExpr::FieldAccess { root, .. } => collect_call_names(root, names),
428        IrExpr::Forall { body, .. } | IrExpr::Exists { body, .. } => {
429            collect_call_names(body, names);
430        }
431        IrExpr::Var(_) | IrExpr::Literal(_) | IrExpr::Call { .. } => {}
432    }
433}
434
435pub(crate) fn verify_function(
436    func: &Function,
437    known_types: &HashSet<&str>,
438    call_names: &HashSet<&str>,
439    errors: &mut Vec<VerifyError>,
440) {
441    let param_names: HashSet<&str> = func.params.iter().map(|p| p.name.as_str()).collect();
442
443    // Check: postconditions without parameters
444    if !func.postconditions.is_empty() && func.params.is_empty() {
445        errors.push(VerifyError {
446            kind: VerifyErrorKind::PostconditionWithoutParams {
447                function: func.name.clone(),
448            },
449            trace: func.trace.clone(),
450        });
451    }
452
453    // Check preconditions: no old(), variables must be bound
454    for pre in &func.preconditions {
455        check_no_old(&pre.expr, &pre.trace, errors);
456        check_bound_vars(
457            &pre.expr,
458            &param_names,
459            &HashSet::new(),
460            call_names,
461            &pre.trace,
462            errors,
463        );
464    }
465
466    // Check postconditions: variables must be bound, check param references
467    for post in &func.postconditions {
468        let (expr, trace) = match post {
469            Postcondition::Always { expr, trace } => (expr, trace),
470            Postcondition::When { guard, expr, trace } => {
471                check_bound_vars(
472                    guard,
473                    &param_names,
474                    &HashSet::new(),
475                    call_names,
476                    trace,
477                    errors,
478                );
479                (expr, trace)
480            }
481        };
482        check_bound_vars(
483            expr,
484            &param_names,
485            &HashSet::new(),
486            call_names,
487            trace,
488            errors,
489        );
490
491        // Check postcondition references at least one parameter
492        let vars = collect_vars(expr);
493        if !vars.iter().any(|v| param_names.contains(v.as_str())) {
494            errors.push(VerifyError {
495                kind: VerifyErrorKind::DisconnectedPostcondition {
496                    function: func.name.clone(),
497                },
498                trace: trace.clone(),
499            });
500        }
501    }
502
503    // Check quantifier types in all expressions
504    for pre in &func.preconditions {
505        check_quantifier_types(&pre.expr, known_types, &pre.trace, errors);
506    }
507    for post in &func.postconditions {
508        match post {
509            Postcondition::Always { expr, trace } => {
510                check_quantifier_types(expr, known_types, trace, errors);
511            }
512            Postcondition::When {
513                guard, expr, trace, ..
514            } => {
515                check_quantifier_types(guard, known_types, trace, errors);
516                check_quantifier_types(expr, known_types, trace, errors);
517            }
518        }
519    }
520}
521
522pub(crate) fn verify_invariant(
523    inv: &Invariant,
524    known_types: &HashSet<&str>,
525    call_names: &HashSet<&str>,
526    errors: &mut Vec<VerifyError>,
527) {
528    // old() is valid in invariants for temporal properties (e.g., conservation laws)
529    check_quantifier_types(&inv.expr, known_types, &inv.trace, errors);
530    // Invariant body variables are bound by quantifiers, so we check with empty params
531    check_bound_vars(
532        &inv.expr,
533        &HashSet::new(),
534        &HashSet::new(),
535        call_names,
536        &inv.trace,
537        errors,
538    );
539}
540
541pub(crate) fn verify_edge_guard(
542    guard: &EdgeGuard,
543    known_types: &HashSet<&str>,
544    errors: &mut Vec<VerifyError>,
545) {
546    check_no_old(&guard.condition, &guard.trace, errors);
547    check_quantifier_types(&guard.condition, known_types, &guard.trace, errors);
548    for (_, arg_expr) in &guard.args {
549        check_no_old(arg_expr, &guard.trace, errors);
550    }
551}
552
553// ── Expression walkers ──────────────────────────────────────
554
555/// Check that `old()` does not appear in this expression.
556fn check_no_old(expr: &IrExpr, trace: &SourceTrace, errors: &mut Vec<VerifyError>) {
557    match expr {
558        IrExpr::Old(_) => {
559            errors.push(VerifyError {
560                kind: VerifyErrorKind::OldOutsidePoscondition,
561                trace: trace.clone(),
562            });
563        }
564        _ => {
565            for_each_child(expr, |child| check_no_old(child, trace, errors));
566        }
567    }
568}
569
570/// Check that all variable references are bound.
571fn check_bound_vars(
572    expr: &IrExpr,
573    params: &HashSet<&str>,
574    quantifier_bindings: &HashSet<&str>,
575    call_names: &HashSet<&str>,
576    trace: &SourceTrace,
577    errors: &mut Vec<VerifyError>,
578) {
579    match expr {
580        IrExpr::Var(name) => {
581            // Uppercase identifiers are union variant labels (Active, Frozen, etc.)
582            let is_variant = name.starts_with(|c: char| c.is_ascii_uppercase());
583            // Names that appear as function calls elsewhere are domain-level references
584            let is_call = call_names.contains(name.as_str());
585            if !is_variant
586                && !is_call
587                && !params.contains(name.as_str())
588                && !quantifier_bindings.contains(name.as_str())
589            {
590                errors.push(VerifyError {
591                    kind: VerifyErrorKind::UnboundVariable { name: name.clone() },
592                    trace: trace.clone(),
593                });
594            }
595        }
596        IrExpr::Forall { binding, body, .. } | IrExpr::Exists { binding, body, .. } => {
597            let mut extended = quantifier_bindings.clone();
598            extended.insert(binding.as_str());
599            check_bound_vars(body, params, &extended, call_names, trace, errors);
600        }
601        _ => {
602            for_each_child(expr, |child| {
603                check_bound_vars(
604                    child,
605                    params,
606                    quantifier_bindings,
607                    call_names,
608                    trace,
609                    errors,
610                );
611            });
612        }
613    }
614}
615
616/// Check that quantifier types reference known types (structs or functions).
617fn check_quantifier_types(
618    expr: &IrExpr,
619    known_types: &HashSet<&str>,
620    trace: &SourceTrace,
621    errors: &mut Vec<VerifyError>,
622) {
623    match expr {
624        IrExpr::Forall { ty, body, .. } | IrExpr::Exists { ty, body, .. } => {
625            if !known_types.contains(ty.as_str()) {
626                errors.push(VerifyError {
627                    kind: VerifyErrorKind::UnknownQuantifierType { ty: ty.clone() },
628                    trace: trace.clone(),
629                });
630            }
631            check_quantifier_types(body, known_types, trace, errors);
632        }
633        _ => {
634            for_each_child(expr, |child| {
635                check_quantifier_types(child, known_types, trace, errors);
636            });
637        }
638    }
639}
640
641/// Collect all variable names referenced in an expression.
642fn collect_vars(expr: &IrExpr) -> Vec<String> {
643    let mut vars = Vec::new();
644    collect_vars_inner(expr, &mut vars);
645    vars
646}
647
648fn collect_vars_inner(expr: &IrExpr, vars: &mut Vec<String>) {
649    match expr {
650        IrExpr::Var(name) => vars.push(name.clone()),
651        _ => for_each_child(expr, |child| collect_vars_inner(child, vars)),
652    }
653}
654
655/// Visit each immediate child of an IR expression.
656fn for_each_child(expr: &IrExpr, mut f: impl FnMut(&IrExpr)) {
657    match expr {
658        IrExpr::Compare { left, right, .. }
659        | IrExpr::Arithmetic { left, right, .. }
660        | IrExpr::And(left, right)
661        | IrExpr::Or(left, right)
662        | IrExpr::Implies(left, right) => {
663            f(left);
664            f(right);
665        }
666        IrExpr::Not(inner) | IrExpr::Old(inner) => f(inner),
667        IrExpr::FieldAccess { root, .. } => f(root),
668        IrExpr::Forall { body, .. } | IrExpr::Exists { body, .. } => f(body),
669        IrExpr::Call { args, .. } => {
670            for arg in args {
671                f(arg);
672            }
673        }
674        IrExpr::Var(_) | IrExpr::Literal(_) => {}
675    }
676}