// intent_ir/verify.rs

//! IR verification pass.
//!
//! Validates structural and logical properties of the IR:
//! - All variable references in expressions are bound (params or quantifier bindings)
//!   (uppercase identifiers are treated as union variant labels and skipped)
//! - `old()` does not appear in preconditions or edge guards (it is allowed in
//!   postconditions and in invariants, where it expresses temporal properties)
//! - Postconditions reference at least one parameter (otherwise they're trivially unverifiable)
//! - Quantifiers reference known types (structs or functions) in this module
//! - Functions with postconditions have at least one parameter (nothing to ensure about)
//!
//! Also performs coherence analysis:
//! - Extracts verification obligations (invariant-action relationships)
//! - Tracks which entity fields each action modifies (via `old()` in postconditions)
//! - Matches modified fields against invariant constraints
15
16use std::collections::{HashMap, HashSet};
17
18use crate::types::*;
19
/// A verification diagnostic produced by [`verify_module`].
///
/// `trace` is cloned from the declaration being checked (function,
/// precondition, postcondition, invariant, or edge guard), so it locates the
/// enclosing declaration rather than the offending sub-expression.
#[derive(Debug, Clone)]
pub struct VerifyError {
    /// What went wrong, including any offending name.
    pub kind: VerifyErrorKind,
    /// Source location of the enclosing declaration.
    pub trace: SourceTrace,
}
26
/// The category of a [`VerifyError`], carrying the offending name where relevant.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum VerifyErrorKind {
    /// A variable is referenced but not bound as a parameter or quantifier binding.
    ///
    /// Names starting with an ASCII uppercase letter (union variant labels) and
    /// names used in call position elsewhere in the module are exempt.
    UnboundVariable { name: String },
    /// `old()` appears outside of a postcondition context (a precondition or an
    /// edge guard).
    ///
    /// NOTE: the variant name is misspelled ("Poscondition"); renaming it would
    /// break any code that constructs or matches on it, so it is left as-is.
    OldOutsidePoscondition,
    /// A function has postconditions but no parameters.
    PostconditionWithoutParams { function: String },
    /// A quantifier references a type not defined in this module (struct or function).
    UnknownQuantifierType { ty: String },
    /// A postcondition doesn't reference any function parameter.
    DisconnectedPostcondition { function: String },
}
40
41impl std::fmt::Display for VerifyError {
42    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
43        match &self.kind {
44            VerifyErrorKind::UnboundVariable { name } => {
45                write!(f, "unbound variable `{name}`")
46            }
47            VerifyErrorKind::OldOutsidePoscondition => {
48                write!(f, "`old()` used outside of postcondition")
49            }
50            VerifyErrorKind::PostconditionWithoutParams { function } => {
51                write!(
52                    f,
53                    "function `{function}` has postconditions but no parameters"
54                )
55            }
56            VerifyErrorKind::UnknownQuantifierType { ty } => {
57                write!(f, "quantifier references unknown type `{ty}`")
58            }
59            VerifyErrorKind::DisconnectedPostcondition { function } => {
60                write!(
61                    f,
62                    "postcondition in `{function}` doesn't reference any parameter"
63                )
64            }
65        }
66    }
67}
68
69/// Run verification checks on an IR module.
70pub fn verify_module(module: &Module) -> Vec<VerifyError> {
71    let mut errors = Vec::new();
72
73    // Known types for quantifiers: both structs and functions (actions).
74    let known_types: HashSet<&str> = module
75        .structs
76        .iter()
77        .map(|s| s.name.as_str())
78        .chain(module.functions.iter().map(|f| f.name.as_str()))
79        .collect();
80
81    // Collect all names that appear in Call positions — these are domain-level
82    // functions (now, lookup, etc.) and should be treated as implicitly bound.
83    // Also includes names the parser may lower as Var instead of Call (e.g., `now()`
84    // becomes Var("now") due to grammar limitations).
85    let mut call_names = HashSet::new();
86    collect_module_call_names(module, &mut call_names);
87
88    for func in &module.functions {
89        verify_function(func, &known_types, &call_names, &mut errors);
90    }
91
92    for inv in &module.invariants {
93        verify_invariant(inv, &known_types, &call_names, &mut errors);
94    }
95
96    for guard in &module.edge_guards {
97        verify_edge_guard(guard, &known_types, &mut errors);
98    }
99
100    errors
101}
102
103// ── Coherence analysis ─────────────────────────────────────
104
/// A verification obligation — something that needs to be proven
/// for the module to be correct.
///
/// Produced by [`analyze_obligations`]; informational, not an error.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Obligation {
    /// The action (function) that triggers this obligation.
    pub action: String,
    /// The name of the invariant that must be preserved.
    pub invariant: String,
    /// The entity type involved. For [`ObligationKind::TemporalProperty`] this is
    /// the action name itself, because the invariant quantifies over the action.
    pub entity: String,
    /// The specific fields that the action modifies and the invariant constrains
    /// (sorted, no duplicates). Empty for [`ObligationKind::TemporalProperty`].
    pub fields: Vec<String>,
    /// The kind of obligation.
    pub kind: ObligationKind,
}
120
/// How an [`Obligation`] was derived from an invariant.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ObligationKind {
    /// Action modifies fields that an entity invariant constrains.
    /// The invariant quantifies over the entity type (e.g., `forall a: Account => ...`).
    ///
    /// "Modifies" is approximated: a field counts as modified when the action's
    /// postconditions mention it as `old(param.field)` on a struct-typed parameter.
    InvariantPreservation,
    /// A temporal invariant directly references this action via quantifier
    /// (e.g., `forall t: Transfer => old(...) == ...`).
    TemporalProperty,
}
130
131impl std::fmt::Display for Obligation {
132    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
133        match &self.kind {
134            ObligationKind::InvariantPreservation => {
135                write!(
136                    f,
137                    "{} modifies {}.{{{}}} (constrained by {})",
138                    self.action,
139                    self.entity,
140                    self.fields.join(", "),
141                    self.invariant,
142                )
143            }
144            ObligationKind::TemporalProperty => {
145                write!(
146                    f,
147                    "{} must satisfy temporal property {}",
148                    self.action, self.invariant,
149                )
150            }
151        }
152    }
153}
154
155/// Analyze a verified IR module for verification obligations.
156///
157/// Returns a list of obligations that describe what logical properties
158/// need to hold for the module to be correct. These are informational —
159/// not errors — representing proof goals a formal verifier would check.
160pub fn analyze_obligations(module: &Module) -> Vec<Obligation> {
161    let mut obligations = Vec::new();
162
163    // Build a map of struct name → field names for lookup.
164    let struct_fields: HashMap<&str, Vec<&str>> = module
165        .structs
166        .iter()
167        .map(|s| {
168            (
169                s.name.as_str(),
170                s.fields.iter().map(|f| f.name.as_str()).collect(),
171            )
172        })
173        .collect();
174
175    // Build a map of param name → entity type for each function.
176    // Only includes params whose type is an entity (struct).
177    let func_entity_params: HashMap<&str, Vec<(&str, &str)>> = module
178        .functions
179        .iter()
180        .map(|func| {
181            let entity_params: Vec<(&str, &str)> = func
182                .params
183                .iter()
184                .filter_map(|p| match &p.ty {
185                    IrType::Named(t) | IrType::Struct(t)
186                        if struct_fields.contains_key(t.as_str()) =>
187                    {
188                        Some((p.name.as_str(), t.as_str()))
189                    }
190                    _ => None,
191                })
192                .collect();
193            (func.name.as_str(), entity_params)
194        })
195        .collect();
196
197    // For each function, collect fields modified in postconditions (via old()).
198    // Result: function name → set of (entity_type, field_name).
199    let mut modified_fields: HashMap<&str, HashSet<(&str, &str)>> = HashMap::new();
200    for func in &module.functions {
201        let entity_params = &func_entity_params[func.name.as_str()];
202        let param_to_entity: HashMap<&str, &str> = entity_params.iter().copied().collect();
203        let mut fields = HashSet::new();
204        for post in &func.postconditions {
205            let exprs: Vec<&IrExpr> = match post {
206                Postcondition::Always { expr, .. } => vec![expr],
207                Postcondition::When { guard, expr, .. } => vec![guard, expr],
208            };
209            for expr in exprs {
210                collect_old_field_accesses(expr, &param_to_entity, &mut fields);
211            }
212        }
213        modified_fields.insert(func.name.as_str(), fields);
214    }
215
216    // For each invariant, determine what it constrains.
217    for inv in &module.invariants {
218        if let IrExpr::Forall { binding, ty, body } = &inv.expr {
219            // Check if this is a temporal invariant (quantifies over an action).
220            let is_action = module.functions.iter().any(|f| f.name == *ty);
221            if is_action {
222                // Temporal property: directly references an action.
223                obligations.push(Obligation {
224                    action: ty.clone(),
225                    invariant: inv.name.clone(),
226                    entity: ty.clone(),
227                    fields: vec![],
228                    kind: ObligationKind::TemporalProperty,
229                });
230                continue;
231            }
232
233            // Entity invariant: quantifies over an entity type.
234            // Collect fields the invariant constrains.
235            let constrained = collect_field_accesses_on(body, binding);
236
237            // Find all actions that modify any of these fields on this entity type.
238            for func in &module.functions {
239                if let Some(mods) = modified_fields.get(func.name.as_str()) {
240                    let overlapping: Vec<String> = constrained
241                        .iter()
242                        .filter(|f| mods.contains(&(ty.as_str(), f.as_str())))
243                        .cloned()
244                        .collect();
245                    if !overlapping.is_empty() {
246                        obligations.push(Obligation {
247                            action: func.name.clone(),
248                            invariant: inv.name.clone(),
249                            entity: ty.clone(),
250                            fields: overlapping,
251                            kind: ObligationKind::InvariantPreservation,
252                        });
253                    }
254                }
255            }
256        }
257    }
258
259    obligations
260}
261
262/// Collect field accesses inside `old()` expressions, mapping them to entity types
263/// via the param→entity mapping.
264///
265/// For an expression like `old(from.balance)`, if `from` maps to entity `Account`,
266/// this records `("Account", "balance")`.
267fn collect_old_field_accesses<'a>(
268    expr: &'a IrExpr,
269    param_to_entity: &HashMap<&str, &'a str>,
270    result: &mut HashSet<(&'a str, &'a str)>,
271) {
272    match expr {
273        IrExpr::Old(inner) => {
274            collect_inner_field_accesses(inner, param_to_entity, result);
275        }
276        _ => {
277            // Manual match for recursion (avoid lifetime issues with for_each_child closure)
278            match expr {
279                IrExpr::Compare { left, right, .. }
280                | IrExpr::Arithmetic { left, right, .. }
281                | IrExpr::And(left, right)
282                | IrExpr::Or(left, right)
283                | IrExpr::Implies(left, right) => {
284                    collect_old_field_accesses(left, param_to_entity, result);
285                    collect_old_field_accesses(right, param_to_entity, result);
286                }
287                IrExpr::Not(inner) => {
288                    collect_old_field_accesses(inner, param_to_entity, result);
289                }
290                IrExpr::FieldAccess { root, .. } => {
291                    collect_old_field_accesses(root, param_to_entity, result);
292                }
293                IrExpr::Forall { body, .. } | IrExpr::Exists { body, .. } => {
294                    collect_old_field_accesses(body, param_to_entity, result);
295                }
296                IrExpr::Call { args, .. } => {
297                    for arg in args {
298                        collect_old_field_accesses(arg, param_to_entity, result);
299                    }
300                }
301                IrExpr::Var(_) | IrExpr::Literal(_) | IrExpr::Old(_) => {}
302            }
303        }
304    }
305}
306
307/// Collect field accesses within an old() body, resolving param names to entity types.
308fn collect_inner_field_accesses<'a>(
309    expr: &'a IrExpr,
310    param_to_entity: &HashMap<&str, &'a str>,
311    result: &mut HashSet<(&'a str, &'a str)>,
312) {
313    match expr {
314        IrExpr::FieldAccess { root, field } => {
315            // Check if root is a direct param reference: old(param.field)
316            if let IrExpr::Var(var) = root.as_ref()
317                && let Some(&entity) = param_to_entity.get(var.as_str())
318            {
319                result.insert((entity, field.as_str()));
320            }
321            // Also check for chained access: old(param.sub.field)
322            collect_inner_field_accesses(root, param_to_entity, result);
323        }
324        _ => match expr {
325            IrExpr::Compare { left, right, .. }
326            | IrExpr::Arithmetic { left, right, .. }
327            | IrExpr::And(left, right)
328            | IrExpr::Or(left, right)
329            | IrExpr::Implies(left, right) => {
330                collect_inner_field_accesses(left, param_to_entity, result);
331                collect_inner_field_accesses(right, param_to_entity, result);
332            }
333            IrExpr::Not(inner) | IrExpr::Old(inner) => {
334                collect_inner_field_accesses(inner, param_to_entity, result);
335            }
336            IrExpr::FieldAccess { .. } => unreachable!(),
337            IrExpr::Forall { body, .. } | IrExpr::Exists { body, .. } => {
338                collect_inner_field_accesses(body, param_to_entity, result);
339            }
340            IrExpr::Call { args, .. } => {
341                for arg in args {
342                    collect_inner_field_accesses(arg, param_to_entity, result);
343                }
344            }
345            IrExpr::Var(_) | IrExpr::Literal(_) => {}
346        },
347    }
348}
349
350/// Collect field names accessed on a specific binding variable in an expression.
351///
352/// For `forall a: Account => a.balance >= 0`, calling this with binding="a"
353/// returns `["balance"]`.
354fn collect_field_accesses_on(expr: &IrExpr, binding: &str) -> Vec<String> {
355    let mut fields = Vec::new();
356    collect_fields_on_inner(expr, binding, &mut fields);
357    fields.sort();
358    fields.dedup();
359    fields
360}
361
362fn collect_fields_on_inner(expr: &IrExpr, binding: &str, fields: &mut Vec<String>) {
363    match expr {
364        IrExpr::FieldAccess { root, field } => {
365            if let IrExpr::Var(var) = root.as_ref()
366                && var == binding
367            {
368                fields.push(field.clone());
369            }
370            collect_fields_on_inner(root, binding, fields);
371        }
372        _ => for_each_child(expr, |child| {
373            collect_fields_on_inner(child, binding, fields)
374        }),
375    }
376}
377
378// ── Structural verification helpers ────────────────────────
379
380/// Collect all function names used in Call expressions across the module.
381fn collect_module_call_names<'a>(module: &'a Module, names: &mut HashSet<&'a str>) {
382    for func in &module.functions {
383        for pre in &func.preconditions {
384            collect_call_names(&pre.expr, names);
385        }
386        for post in &func.postconditions {
387            match post {
388                Postcondition::Always { expr, .. } => collect_call_names(expr, names),
389                Postcondition::When { guard, expr, .. } => {
390                    collect_call_names(guard, names);
391                    collect_call_names(expr, names);
392                }
393            }
394        }
395    }
396    for inv in &module.invariants {
397        collect_call_names(&inv.expr, names);
398    }
399    for guard in &module.edge_guards {
400        collect_call_names(&guard.condition, names);
401        for (_, arg) in &guard.args {
402            collect_call_names(arg, names);
403        }
404    }
405}
406
407fn collect_call_names<'a>(expr: &'a IrExpr, names: &mut HashSet<&'a str>) {
408    if let IrExpr::Call { name, args } = expr {
409        names.insert(name.as_str());
410        for arg in args {
411            collect_call_names(arg, names);
412        }
413        return;
414    }
415    match expr {
416        IrExpr::Compare { left, right, .. }
417        | IrExpr::Arithmetic { left, right, .. }
418        | IrExpr::And(left, right)
419        | IrExpr::Or(left, right)
420        | IrExpr::Implies(left, right) => {
421            collect_call_names(left, names);
422            collect_call_names(right, names);
423        }
424        IrExpr::Not(inner) | IrExpr::Old(inner) => collect_call_names(inner, names),
425        IrExpr::FieldAccess { root, .. } => collect_call_names(root, names),
426        IrExpr::Forall { body, .. } | IrExpr::Exists { body, .. } => {
427            collect_call_names(body, names);
428        }
429        IrExpr::Var(_) | IrExpr::Literal(_) | IrExpr::Call { .. } => {}
430    }
431}
432
433fn verify_function(
434    func: &Function,
435    known_types: &HashSet<&str>,
436    call_names: &HashSet<&str>,
437    errors: &mut Vec<VerifyError>,
438) {
439    let param_names: HashSet<&str> = func.params.iter().map(|p| p.name.as_str()).collect();
440
441    // Check: postconditions without parameters
442    if !func.postconditions.is_empty() && func.params.is_empty() {
443        errors.push(VerifyError {
444            kind: VerifyErrorKind::PostconditionWithoutParams {
445                function: func.name.clone(),
446            },
447            trace: func.trace.clone(),
448        });
449    }
450
451    // Check preconditions: no old(), variables must be bound
452    for pre in &func.preconditions {
453        check_no_old(&pre.expr, &pre.trace, errors);
454        check_bound_vars(
455            &pre.expr,
456            &param_names,
457            &HashSet::new(),
458            call_names,
459            &pre.trace,
460            errors,
461        );
462    }
463
464    // Check postconditions: variables must be bound, check param references
465    for post in &func.postconditions {
466        let (expr, trace) = match post {
467            Postcondition::Always { expr, trace } => (expr, trace),
468            Postcondition::When { guard, expr, trace } => {
469                check_bound_vars(
470                    guard,
471                    &param_names,
472                    &HashSet::new(),
473                    call_names,
474                    trace,
475                    errors,
476                );
477                (expr, trace)
478            }
479        };
480        check_bound_vars(
481            expr,
482            &param_names,
483            &HashSet::new(),
484            call_names,
485            trace,
486            errors,
487        );
488
489        // Check postcondition references at least one parameter
490        let vars = collect_vars(expr);
491        if !vars.iter().any(|v| param_names.contains(v.as_str())) {
492            errors.push(VerifyError {
493                kind: VerifyErrorKind::DisconnectedPostcondition {
494                    function: func.name.clone(),
495                },
496                trace: trace.clone(),
497            });
498        }
499    }
500
501    // Check quantifier types in all expressions
502    for pre in &func.preconditions {
503        check_quantifier_types(&pre.expr, known_types, &pre.trace, errors);
504    }
505    for post in &func.postconditions {
506        match post {
507            Postcondition::Always { expr, trace } => {
508                check_quantifier_types(expr, known_types, trace, errors);
509            }
510            Postcondition::When {
511                guard, expr, trace, ..
512            } => {
513                check_quantifier_types(guard, known_types, trace, errors);
514                check_quantifier_types(expr, known_types, trace, errors);
515            }
516        }
517    }
518}
519
520fn verify_invariant(
521    inv: &Invariant,
522    known_types: &HashSet<&str>,
523    call_names: &HashSet<&str>,
524    errors: &mut Vec<VerifyError>,
525) {
526    // old() is valid in invariants for temporal properties (e.g., conservation laws)
527    check_quantifier_types(&inv.expr, known_types, &inv.trace, errors);
528    // Invariant body variables are bound by quantifiers, so we check with empty params
529    check_bound_vars(
530        &inv.expr,
531        &HashSet::new(),
532        &HashSet::new(),
533        call_names,
534        &inv.trace,
535        errors,
536    );
537}
538
539fn verify_edge_guard(
540    guard: &EdgeGuard,
541    known_types: &HashSet<&str>,
542    errors: &mut Vec<VerifyError>,
543) {
544    check_no_old(&guard.condition, &guard.trace, errors);
545    check_quantifier_types(&guard.condition, known_types, &guard.trace, errors);
546    for (_, arg_expr) in &guard.args {
547        check_no_old(arg_expr, &guard.trace, errors);
548    }
549}
550
551// ── Expression walkers ──────────────────────────────────────
552
553/// Check that `old()` does not appear in this expression.
554fn check_no_old(expr: &IrExpr, trace: &SourceTrace, errors: &mut Vec<VerifyError>) {
555    match expr {
556        IrExpr::Old(_) => {
557            errors.push(VerifyError {
558                kind: VerifyErrorKind::OldOutsidePoscondition,
559                trace: trace.clone(),
560            });
561        }
562        _ => {
563            for_each_child(expr, |child| check_no_old(child, trace, errors));
564        }
565    }
566}
567
568/// Check that all variable references are bound.
569fn check_bound_vars(
570    expr: &IrExpr,
571    params: &HashSet<&str>,
572    quantifier_bindings: &HashSet<&str>,
573    call_names: &HashSet<&str>,
574    trace: &SourceTrace,
575    errors: &mut Vec<VerifyError>,
576) {
577    match expr {
578        IrExpr::Var(name) => {
579            // Uppercase identifiers are union variant labels (Active, Frozen, etc.)
580            let is_variant = name.starts_with(|c: char| c.is_ascii_uppercase());
581            // Names that appear as function calls elsewhere are domain-level references
582            let is_call = call_names.contains(name.as_str());
583            if !is_variant
584                && !is_call
585                && !params.contains(name.as_str())
586                && !quantifier_bindings.contains(name.as_str())
587            {
588                errors.push(VerifyError {
589                    kind: VerifyErrorKind::UnboundVariable { name: name.clone() },
590                    trace: trace.clone(),
591                });
592            }
593        }
594        IrExpr::Forall { binding, body, .. } | IrExpr::Exists { binding, body, .. } => {
595            let mut extended = quantifier_bindings.clone();
596            extended.insert(binding.as_str());
597            check_bound_vars(body, params, &extended, call_names, trace, errors);
598        }
599        _ => {
600            for_each_child(expr, |child| {
601                check_bound_vars(
602                    child,
603                    params,
604                    quantifier_bindings,
605                    call_names,
606                    trace,
607                    errors,
608                );
609            });
610        }
611    }
612}
613
614/// Check that quantifier types reference known types (structs or functions).
615fn check_quantifier_types(
616    expr: &IrExpr,
617    known_types: &HashSet<&str>,
618    trace: &SourceTrace,
619    errors: &mut Vec<VerifyError>,
620) {
621    match expr {
622        IrExpr::Forall { ty, body, .. } | IrExpr::Exists { ty, body, .. } => {
623            if !known_types.contains(ty.as_str()) {
624                errors.push(VerifyError {
625                    kind: VerifyErrorKind::UnknownQuantifierType { ty: ty.clone() },
626                    trace: trace.clone(),
627                });
628            }
629            check_quantifier_types(body, known_types, trace, errors);
630        }
631        _ => {
632            for_each_child(expr, |child| {
633                check_quantifier_types(child, known_types, trace, errors);
634            });
635        }
636    }
637}
638
639/// Collect all variable names referenced in an expression.
640fn collect_vars(expr: &IrExpr) -> Vec<String> {
641    let mut vars = Vec::new();
642    collect_vars_inner(expr, &mut vars);
643    vars
644}
645
646fn collect_vars_inner(expr: &IrExpr, vars: &mut Vec<String>) {
647    match expr {
648        IrExpr::Var(name) => vars.push(name.clone()),
649        _ => for_each_child(expr, |child| collect_vars_inner(child, vars)),
650    }
651}
652
653/// Visit each immediate child of an IR expression.
654fn for_each_child(expr: &IrExpr, mut f: impl FnMut(&IrExpr)) {
655    match expr {
656        IrExpr::Compare { left, right, .. }
657        | IrExpr::Arithmetic { left, right, .. }
658        | IrExpr::And(left, right)
659        | IrExpr::Or(left, right)
660        | IrExpr::Implies(left, right) => {
661            f(left);
662            f(right);
663        }
664        IrExpr::Not(inner) | IrExpr::Old(inner) => f(inner),
665        IrExpr::FieldAccess { root, .. } => f(root),
666        IrExpr::Forall { body, .. } | IrExpr::Exists { body, .. } => f(body),
667        IrExpr::Call { args, .. } => {
668            for arg in args {
669                f(arg);
670            }
671        }
672        IrExpr::Var(_) | IrExpr::Literal(_) => {}
673    }
674}