tla-checker 0.3.9

A TLA+ model checker written in Rust
Documentation
use super::Definitions;
use super::ast_utils::{
    collect_conjuncts, collect_disjuncts_with_labels, contains_prime_ref, format_expr_brief,
    infer_action_name,
};
use super::candidates::infer_all_candidates;
use super::error::Result;
#[cfg(feature = "profiling")]
use super::global_state::PROFILING_STATS;
use super::helpers::{eval_bool, eval_set};
use super::state::env_to_next_state;
use crate::ast::{Env, Expr, GuardEval, Transition, Value};
use std::sync::Arc;
#[cfg(feature = "profiling")]
use std::time::Instant;

/// Reports the guard evaluations for one action of `next`.
///
/// The labelled disjuncts of `next` come from `collect_disjuncts_with_labels`.
/// The first disjunct whose label agrees with `action` is analysed: two
/// present labels must be equal, and an absent `action` only agrees with an
/// unlabelled disjunct. When no disjunct agrees, `next` is analysed as a whole.
///
/// # Errors
///
/// Propagates whatever `extract_guards_from_expr` returns.
pub(crate) fn extract_guards_for_action(
    next: &Expr,
    env: &mut Env,
    defs: &Definitions,
    action: Option<&Arc<str>>,
) -> Result<Vec<GuardEval>> {
    let labelled = collect_disjuncts_with_labels(next, defs);

    let chosen = labelled.iter().find(|(_, label)| match (action, label) {
        (Some(wanted), Some(found)) => wanted == found,
        (None, None) => true,
        _ => false,
    });

    match chosen {
        Some((disjunct, _)) => extract_guards_from_expr(disjunct, env, defs),
        // No disjunct carries a matching label: fall back to the full expression.
        None => extract_guards_from_expr(next, env, defs),
    }
}

/// Evaluates every conjunct of `expr` that `contains_prime_ref` reports as
/// prime-free and records the outcome as a [`GuardEval`].
///
/// Each record holds the brief textual form of the conjunct and its boolean
/// result; `bindings` is always left empty here. A conjunct whose evaluation
/// fails is recorded with `result == false` rather than aborting the report,
/// so this function currently never returns `Err`.
fn extract_guards_from_expr(
    expr: &Expr,
    env: &mut Env,
    defs: &Definitions,
) -> Result<Vec<GuardEval>> {
    let guards = collect_conjuncts(expr)
        .into_iter()
        .filter(|conjunct| !contains_prime_ref(*conjunct, defs))
        .map(|conjunct| GuardEval {
            expression: format_expr_brief(conjunct),
            // Evaluation errors are deliberately reported as a failed guard.
            result: eval_bool(conjunct, env, defs).unwrap_or(false),
            bindings: Vec::new(),
        })
        .collect();

    Ok(guards)
}

/// Computes the transitions enabled by the next-state relation `next` in the
/// state described by `base_env`.
///
/// `next` is first resolved through zero-argument definitions (see
/// `resolve_next`). The shape of the resolved expression picks the strategy:
///
/// * an existential is expanded with `expand_and_enumerate`, and the results
///   are de-duplicated while keeping first-seen order;
/// * a disjunction is split into labelled disjuncts, each enumerated on its
///   own (existential disjuncts via `expand_and_enumerate`, all others via
///   `enumerate_next`), with results de-duplicated across all disjuncts;
/// * anything else goes straight to `enumerate_next`, and its results are
///   returned as produced, without de-duplication.
///
/// # Errors
///
/// Returns the first error raised by the enumeration helpers.
pub(crate) fn next_states_impl(
    next: &Expr,
    base_env: &mut Env,
    vars: &[Arc<str>],
    primed_vars: &[Arc<str>],
    defs: &Definitions,
) -> Result<Vec<Transition>> {
    let ctx = EnumCtx {
        vars,
        primed_vars,
        defs,
    };
    let effective = resolve_next(next, defs);

    match effective {
        Expr::Exists(..) => {
            let label = infer_action_name(effective, defs);
            let mut found = Vec::new();
            expand_and_enumerate(effective, base_env, &ctx, label, &mut found)?;

            let mut unique = indexmap::IndexSet::new();
            unique.extend(found);
            Ok(unique.into_iter().collect())
        }
        Expr::Or(..) => {
            let labelled = collect_disjuncts_with_labels(effective, defs);
            let mut unique = indexmap::IndexSet::new();

            for (disjunct, label) in &labelled {
                let mut found = Vec::new();
                if matches!(disjunct, Expr::Exists(..)) {
                    expand_and_enumerate(disjunct, base_env, &ctx, label.clone(), &mut found)?;
                } else {
                    enumerate_next(disjunct, base_env, &ctx, label.clone(), &mut found)?;
                }
                unique.extend(found);
            }

            Ok(unique.into_iter().collect())
        }
        _ => {
            let label = infer_action_name(effective, defs);
            let mut found = Vec::new();
            enumerate_next(effective, base_env, &ctx, label, &mut found)?;
            Ok(found)
        }
    }
}

/// Follows references to zero-argument definitions until reaching an
/// expression that is not one.
///
/// `expr` is returned unchanged unless it is a call with no arguments whose
/// name is found in `defs` with an empty parameter list. In that case the
/// definition body is resolved recursively.
fn resolve_next<'a>(expr: &'a Expr, defs: &'a Definitions) -> &'a Expr {
    let Expr::FnCall(name, args) = expr else {
        return expr;
    };
    if !args.is_empty() {
        return expr;
    }

    match defs.get(name) {
        Some((params, body)) if params.is_empty() => resolve_next(body, defs),
        _ => expr,
    }
}

/// Enumerates the transitions of `expr`, unfolding existential quantifiers and
/// disjunctions before handing the remaining expression to `enumerate_next`.
///
/// * `Exists(var, domain, body)`: `domain` is evaluated to a set and `body` is
///   expanded once per element with `var` bound to that element in `env`.
/// * `Or`: each labelled disjunct is processed with its own label when it has
///   one, otherwise with the inherited `action`. Existential disjuncts recurse
///   here; all others go to `enumerate_next`.
/// * Anything else is passed to `enumerate_next` unchanged.
///
/// Transitions are appended to `results`.
///
/// # Errors
///
/// Returns the first error from evaluating a domain or enumerating a body.
/// The quantified variable is removed from `env` on the error path as well,
/// so a failure does not leave a stale binding in the caller's environment.
///
/// NOTE(review): the binding is removed after the loop, not restored. If `env`
/// already held a value under the same name, that value is not put back.
/// Confirm whether shadowing can occur here.
fn expand_and_enumerate(
    expr: &Expr,
    env: &mut Env,
    ctx: &EnumCtx<'_>,
    action: Option<Arc<str>>,
    results: &mut Vec<Transition>,
) -> Result<()> {
    match expr {
        Expr::Exists(var, domain, body) => {
            let dom = eval_set(domain, env, ctx.defs)?;
            let var = var.clone();

            // Hold the outcome instead of using `?` inside the loop, so that
            // the cleanup below also runs when a body fails.
            let mut outcome: Result<()> = Ok(());
            for val in dom {
                env.insert(var.clone(), val);
                outcome = expand_and_enumerate(body, env, ctx, action.clone(), results);
                if outcome.is_err() {
                    break;
                }
            }
            env.remove(&var);
            outcome
        }
        Expr::Or(_, _) => {
            let disjuncts = collect_disjuncts_with_labels(expr, ctx.defs);
            for (disjunct, sub_action) in &disjuncts {
                // Prefer the disjunct's own label; clone the inherited one
                // only when it is actually needed.
                let effective_action = sub_action.clone().or_else(|| action.clone());
                if let Expr::Exists(_, _, _) = disjunct {
                    expand_and_enumerate(disjunct, env, ctx, effective_action, results)?;
                } else {
                    enumerate_next(disjunct, env, ctx, effective_action, results)?;
                }
            }
            Ok(())
        }
        _ => enumerate_next(expr, env, ctx, action, results),
    }
}

/// Checks whether every conjunct of `expr` that `contains_prime_ref` reports
/// as prime-free evaluates to true in `env`.
///
/// Conjuncts are checked in order; the first false one ends the check with
/// `Ok(false)`. Conjuncts that reference primed variables are skipped.
///
/// # Errors
///
/// Propagates the first evaluation error from `eval_bool`.
fn evaluate_guards(expr: &Expr, env: &mut Env, defs: &Definitions) -> Result<bool> {
    for conjunct in collect_conjuncts(expr) {
        if contains_prime_ref(conjunct, defs) {
            continue;
        }

        let holds = eval_bool(conjunct, env, defs)?;
        if !holds {
            return Ok(false);
        }
    }

    Ok(true)
}

/// Borrowed inputs shared by the enumeration routines in this file, bundled
/// so they can be passed around as a single argument.
struct EnumCtx<'a> {
    /// State variable names. Its length decides when a candidate assignment
    /// is complete in `enumerate_combinations`.
    vars: &'a [Arc<str>],
    /// Names under which candidate next-state values are bound in the
    /// environment. Indexed with the same positions as the candidate lists,
    /// so it is assumed to be parallel to `vars` (TODO confirm at the caller).
    primed_vars: &'a [Arc<str>],
    /// Definitions passed to evaluation and candidate inference.
    defs: &'a Definitions,
}

/// Enumerates the transitions of `next` by delegating to
/// `enumerate_next_with_refinement`.
///
/// With the `profiling` feature enabled, the elapsed time of the delegated
/// call and a call counter are added to `PROFILING_STATS`. Without the
/// feature this is a plain pass-through.
///
/// # Errors
///
/// Returns whatever `enumerate_next_with_refinement` returns.
fn enumerate_next(
    next: &Expr,
    env: &mut Env,
    ctx: &EnumCtx<'_>,
    action: Option<Arc<str>>,
    results: &mut Vec<Transition>,
) -> Result<()> {
    #[cfg(feature = "profiling")]
    let started_at = Instant::now();

    let outcome = enumerate_next_with_refinement(next, env, ctx, action, results);

    #[cfg(feature = "profiling")]
    PROFILING_STATS.with(|cell| {
        let elapsed_ns = started_at.elapsed().as_nanos();
        let mut stats = cell.borrow_mut();
        stats.enumerate_next_calls += 1;
        stats.enumerate_next_time_ns += elapsed_ns;
    });

    outcome
}

/// Enumerates the transitions of `next` after refining the candidate values
/// of each primed variable.
///
/// Steps:
/// 1. If any prime-free guard of `next` is false, nothing is enabled and the
///    function returns without touching `results`.
/// 2. Candidate values are inferred for every variable.
/// 3. Each primed variable is provisionally bound to its first candidate and
///    inference is repeated, for at most `MAX_REFINEMENT_PASSES` passes or
///    until a pass changes nothing.
/// 4. The provisional bindings are removed and every combination of the final
///    candidates is checked by `enumerate_combinations`.
///
/// When a re-inferred candidate list is empty, the previously inserted
/// provisional binding for that variable is left in place until step 4.
///
/// # Errors
///
/// Propagates errors from guard evaluation, candidate inference and
/// combination checking. The provisional primed bindings are removed from
/// `env` even when a refinement pass fails, so an error does not leave them
/// in the caller's environment.
///
/// # Panics
///
/// Indexes the candidate lists by position in `ctx.primed_vars`, and
/// `ctx.primed_vars` by candidate-list position. This assumes
/// `infer_all_candidates` yields exactly one list per primed variable
/// (TODO confirm).
fn enumerate_next_with_refinement(
    next: &Expr,
    env: &mut Env,
    ctx: &EnumCtx<'_>,
    action: Option<Arc<str>>,
    results: &mut Vec<Transition>,
) -> Result<()> {
    /// Upper bound on re-inference passes; keeps refinement from looping
    /// forever when the candidate lists never settle.
    const MAX_REFINEMENT_PASSES: usize = 3;

    if !evaluate_guards(next, env, ctx.defs)? {
        return Ok(());
    }

    let mut all_candidates = infer_all_candidates(next, env, ctx.vars, ctx.defs)?;

    // Seed the environment with a provisional value for each primed variable.
    for (i, primed) in ctx.primed_vars.iter().enumerate() {
        if let Some(first) = all_candidates[i].first() {
            env.insert(primed.clone(), first.clone());
        }
    }

    // The outcome is held rather than propagated with `?`, so the provisional
    // bindings are always removed below, including on failure.
    let mut refinement: Result<()> = Ok(());
    for _ in 0..MAX_REFINEMENT_PASSES {
        let new_all = match infer_all_candidates(next, env, ctx.vars, ctx.defs) {
            Ok(candidates) => candidates,
            Err(err) => {
                refinement = Err(err);
                break;
            }
        };

        let mut changed = false;
        for (i, new_candidates) in new_all.into_iter().enumerate() {
            if new_candidates != all_candidates[i] {
                all_candidates[i] = new_candidates;
                changed = true;
                if let Some(first) = all_candidates[i].first() {
                    env.insert(ctx.primed_vars[i].clone(), first.clone());
                }
            }
        }

        if !changed {
            break;
        }
    }

    for primed in ctx.primed_vars {
        env.remove(primed);
    }
    refinement?;

    enumerate_combinations(next, env, ctx, 0, &all_candidates, &action, results)
}

/// Recursively tries every combination of candidate values for the primed
/// variables from position `idx` onward.
///
/// At each level the primed variable at `idx` is bound in `env` to each of its
/// candidates in turn, and the remaining positions are handled by recursion.
/// Once `idx` reaches `ctx.vars.len()` the assignment is complete: `next` is
/// evaluated, and if it holds, the resulting state is pushed onto `results`
/// together with a clone of `action`.
///
/// # Errors
///
/// Returns the first error raised while evaluating `next`. The primed
/// variable bound at each level is removed from `env` on the error path too,
/// so a failure deep in the recursion unwinds without leaving primed bindings
/// behind.
///
/// # Panics
///
/// Indexes `ctx.primed_vars` and `all_candidates` with `idx`, which is only
/// bounded by `ctx.vars.len()`. Both are assumed to be at least that long
/// (TODO confirm at the caller).
fn enumerate_combinations(
    next: &Expr,
    env: &mut Env,
    ctx: &EnumCtx<'_>,
    idx: usize,
    all_candidates: &[Vec<Value>],
    action: &Option<Arc<str>>,
    results: &mut Vec<Transition>,
) -> Result<()> {
    if idx >= ctx.vars.len() {
        if eval_bool(next, env, ctx.defs)? {
            results.push(Transition {
                state: env_to_next_state(env, ctx.vars, ctx.primed_vars),
                action: action.clone(),
            });
        }
        return Ok(());
    }

    let primed = &ctx.primed_vars[idx];

    // Hold the outcome instead of using `?` inside the loop, so the binding
    // for this level is removed even when a deeper level fails.
    let mut outcome: Result<()> = Ok(());
    for candidate in &all_candidates[idx] {
        env.insert(primed.clone(), candidate.clone());
        outcome = enumerate_combinations(next, env, ctx, idx + 1, all_candidates, action, results);
        if outcome.is_err() {
            break;
        }
    }
    env.remove(primed);

    outcome
}