tla-checker 0.3.9

A TLA+ model checker written in Rust
Documentation
use std::cell::RefCell;
use std::collections::{BTreeMap, BTreeSet};
use std::sync::Arc;

use crate::ast::{Env, Expr, Value};
use crate::checker::format_value;

use super::Definitions;
use super::core::eval;
use super::error::{EvalError, Result, value_type_name};

pub(crate) fn eval_fn_def_recursive(
    fn_name: &Arc<str>,
    param: &Arc<str>,
    domain: &[Value],
    body: &Expr,
    env: &mut Env,
    defs: &Definitions,
) -> Result<BTreeMap<Value, Value>> {
    let memo: RefCell<BTreeMap<Value, Value>> = RefCell::new(BTreeMap::new());

    let prev = env.remove(param);
    for val in domain.iter() {
        if memo.borrow().contains_key(val) {
            continue;
        }
        env.insert(param.clone(), val.clone());
        let result = eval_with_memo(body, env, defs, fn_name, param, domain, &memo)?;
        memo.borrow_mut().insert(val.clone(), result);
    }
    match prev {
        Some(old) => {
            env.insert(param.clone(), old);
        }
        None => {
            env.remove(param);
        }
    }

    Ok(memo.into_inner())
}

#[allow(clippy::only_used_in_recursion)]
pub(crate) fn eval_with_memo(
    expr: &Expr,
    env: &mut Env,
    defs: &Definitions,
    fn_name: &Arc<str>,
    fn_param: &Arc<str>,
    fn_domain: &[Value],
    memo: &RefCell<BTreeMap<Value, Value>>,
) -> Result<Value> {
    match expr {
        Expr::FnApp(f, arg) => {
            if let Expr::Var(name) = f.as_ref()
                && name == fn_name
            {
                let av = eval_with_memo(arg, env, defs, fn_name, fn_param, fn_domain, memo)?;
                if let Some(v) = memo.borrow().get(&av) {
                    return Ok(v.clone());
                }
                if let Some((params, fn_body)) = defs.get(name)
                    && params.is_empty()
                    && let Expr::FnDef(p, _, body) = fn_body
                {
                    let prev_p = env.insert(p.clone(), av.clone());
                    let result =
                        eval_with_memo(body, env, defs, fn_name, fn_param, fn_domain, memo)?;
                    match prev_p {
                        Some(old) => {
                            env.insert(p.clone(), old);
                        }
                        None => {
                            env.remove(p);
                        }
                    }
                    memo.borrow_mut().insert(av, result.clone());
                    return Ok(result);
                }
            }
            let fval = eval_with_memo(f, env, defs, fn_name, fn_param, fn_domain, memo)?;
            let av = eval_with_memo(arg, env, defs, fn_name, fn_param, fn_domain, memo)?;
            match fval {
                Value::Fn(fv) => fv.get(&av).cloned().ok_or_else(|| {
                    EvalError::domain_error(format!(
                        "key {} not in function domain",
                        format_value(&av)
                    ))
                }),
                Value::Tuple(tv) => {
                    if let Value::Int(idx) = av {
                        let i = idx as usize;
                        if i >= 1 && i <= tv.len() {
                            Ok(tv[i - 1].clone())
                        } else {
                            Err(EvalError::domain_error(format!(
                                "sequence index {} out of bounds (sequence has {} elements)",
                                idx,
                                tv.len()
                            )))
                        }
                    } else {
                        Err(EvalError::TypeMismatch {
                            expected: "Int",
                            got: av,
                            context: Some("sequence index"),
                            span: None,
                        })
                    }
                }
                other => Err(EvalError::type_mismatch_ctx(
                    "Fn or Sequence",
                    other,
                    "function application",
                )),
            }
        }

        Expr::Let(var, binding, body) => {
            let mut local_defs = defs.clone();
            local_defs.insert(var.clone(), (vec![], (**binding).clone()));
            eval_with_memo(body, env, &local_defs, fn_name, fn_param, fn_domain, memo)
        }

        Expr::If(cond, then_br, else_br) => {
            let cv = eval_with_memo(cond, env, defs, fn_name, fn_param, fn_domain, memo)?;
            match cv {
                Value::Bool(true) => {
                    eval_with_memo(then_br, env, defs, fn_name, fn_param, fn_domain, memo)
                }
                Value::Bool(false) => {
                    eval_with_memo(else_br, env, defs, fn_name, fn_param, fn_domain, memo)
                }
                _ => Err(EvalError::type_mismatch_ctx("Bool", cv, "IF condition")),
            }
        }

        Expr::Add(l, r) => {
            let lv = eval_with_memo(l, env, defs, fn_name, fn_param, fn_domain, memo)?;
            let rv = eval_with_memo(r, env, defs, fn_name, fn_param, fn_domain, memo)?;
            match (lv, rv) {
                (Value::Int(a), Value::Int(b)) => Ok(Value::Int(a + b)),
                (a, b) => Err(EvalError::domain_error(format!(
                    "cannot add {} and {} (expected Int + Int)",
                    value_type_name(&a),
                    value_type_name(&b)
                ))),
            }
        }

        Expr::Eq(l, r) => {
            let lv = eval_with_memo(l, env, defs, fn_name, fn_param, fn_domain, memo)?;
            let rv = eval_with_memo(r, env, defs, fn_name, fn_param, fn_domain, memo)?;
            Ok(Value::Bool(lv == rv))
        }

        Expr::SetMinus(l, r) => {
            let lv = eval_with_memo(l, env, defs, fn_name, fn_param, fn_domain, memo)?;
            let rv = eval_with_memo(r, env, defs, fn_name, fn_param, fn_domain, memo)?;
            match (lv, rv) {
                (Value::Set(a), Value::Set(b)) => {
                    Ok(Value::Set(a.difference(&b).cloned().collect()))
                }
                (a, b) => Err(EvalError::domain_error(format!(
                    "set minus requires Set \\ Set, got {} \\ {}",
                    value_type_name(&a),
                    value_type_name(&b)
                ))),
            }
        }

        Expr::SetEnum(elems) => {
            let mut result = BTreeSet::new();
            for e in elems {
                result.insert(eval_with_memo(
                    e, env, defs, fn_name, fn_param, fn_domain, memo,
                )?);
            }
            Ok(Value::Set(result))
        }

        Expr::Var(name) => {
            if let Some(val) = env.get(name) {
                return Ok(val.clone());
            }
            if let Some((params, body)) = defs.get(name)
                && params.is_empty()
            {
                return eval_with_memo(body, env, defs, fn_name, fn_param, fn_domain, memo);
            }
            Err(EvalError::undefined_var_with_env(name.clone(), env, defs))
        }

        _ => eval(expr, env, defs),
    }
}