smt-scope 0.1.7

A library for parsing and analysing SMT traces.
Documentation
use std::collections::hash_map::Entry;

#[cfg(feature = "mem_dbg")]
use mem_dbg::{MemDbg, MemSize};
use typed_index_collections::TiSlice;

use crate::{
    error::Either,
    items::{
        InstProofLink, InstantiationKind, Meaning, ProofIdx, ProofStep, ProofStepKind, QuantIdx,
        Quantifier, Term, TermId, TermIdToIdxMap, TermIdx, TermKind,
    },
    Error, FxHashMap, Result, StringTable, TiVec,
};

use super::bugs::TermsBug;

/// An item that carries a [`TermId`]. [`TermStorage::new_term`] uses this id
/// to register the item so that it can later be looked up by id.
pub trait HasTermId {
    /// The id under which this item is registered.
    fn term_id(&self) -> TermId;
}

impl HasTermId for Term {
    /// A term's id is stored directly on the term.
    fn term_id(&self) -> TermId {
        let Self { id, .. } = self;
        *id
    }
}

impl HasTermId for ProofStep {
    /// A proof step's id is stored directly on the step.
    fn term_id(&self) -> TermId {
        let Self { id, .. } = self;
        *id
    }
}

/// Storage for items (terms or proof steps) indexed by `K`, together with a
/// map from each item's [`TermId`] to the key it was stored under.
#[cfg_attr(feature = "mem_dbg", derive(MemSize, MemDbg))]
#[derive(Debug)]
pub struct TermStorage<K: From<usize> + Copy, V: HasTermId> {
    /// Maps a `TermId` to the key registered for it by `new_term`.
    term_id_map: TermIdToIdxMap<K>,
    /// The stored items; `new_term` appends to the end.
    terms: TiVec<K, V>,
}

impl<K: From<usize> + Copy, V: HasTermId> Default for TermStorage<K, V> {
    /// Creates an empty storage with no items and no registered ids.
    fn default() -> Self {
        let term_id_map = Default::default();
        let terms = Default::default();
        Self { term_id_map, terms }
    }
}

impl<K: From<usize> + Copy, V: HasTermId> TermStorage<K, V> {
    /// Appends `term` and registers its [`TermId`] so that it can later be
    /// found by id. Returns the key the term was stored under.
    ///
    /// # Errors
    /// If reserving space for the new term fails, or if registering its id
    /// fails. NOTE(review): in the latter case the term has already been
    /// pushed but is left unregistered.
    pub(super) fn new_term(&mut self, term: V) -> Result<K> {
        self.terms.raw.try_reserve(1)?;
        let id = term.term_id();
        let idx = self.terms.push_and_get_key(term);
        self.term_id_map.register_term(id, idx)?;
        Ok(idx)
    }

    /// Looks up the key registered for `term_id`. When the id is unknown, the
    /// id itself is handed back as `Either::Right`.
    fn get_term(&self, term_id: TermId) -> Either<K, TermId> {
        self.term_id_map
            .get_term(&term_id)
            .map_or(Either::Right(term_id), Either::Left)
    }

    /// Parses `id` into a [`TermId`] (using `strings`) and looks it up.
    ///
    /// # Errors
    /// If `TermId::parse` rejects `id`.
    pub(super) fn parse_id(
        &self,
        strings: &mut StringTable,
        id: &str,
    ) -> Result<Either<K, TermId>> {
        let term_id = TermId::parse(strings, id)?;
        Ok(self.get_term(term_id))
    }

    /// Like [`Self::parse_id`], but the id must already be registered.
    ///
    /// # Errors
    /// If `id` fails to parse, or `Error::UnknownId` if no item is registered
    /// under it.
    pub(super) fn parse_existing_id(&self, strings: &mut StringTable, id: &str) -> Result<K> {
        self.parse_id(strings, id)?
            .into_result()
            .map_err(Error::UnknownId)
    }

    /// Read-only view of all stored items, indexed by `K`.
    pub(super) fn terms(&self) -> &TiSlice<K, V> {
        &self.terms
    }

    /// Performs a top-down DFS walk of the AST rooted at `idx`, calling `f` on
    /// each visited node.
    ///
    /// If `f` returns `Err(t)`, the walk stops immediately and `Some(t)` is
    /// returned. If it returns `Ok(children)`, only those children are
    /// scheduled to be walked next. The last returned child is visited first,
    /// because the children are pushed onto a stack. `None` is returned when
    /// the walk completes without `f` aborting it.
    ///
    /// No visited set is kept: a node reachable along several paths is
    /// visited (and passed to `f`) once per path.
    pub fn ast_walk<T>(
        &self,
        idx: K,
        mut f: impl FnMut(K, &V) -> core::result::Result<&[K], T>,
    ) -> Option<T>
    where
        usize: From<K>,
    {
        let mut todo = vec![idx];
        while let Some(idx) = todo.pop() {
            let next = &self.terms[idx];
            match f(idx, next) {
                Ok(to_walk) => todo.extend_from_slice(to_walk),
                Err(t) => return Some(t),
            }
        }
        None
    }

    /// Computes a value for `idx` in post-order (children before parents),
    /// memoising every computed value in `cache`, and returns a reference to
    /// the cached value for `idx`.
    ///
    /// - On the way down, `d` is called on each node that is not yet in
    ///   `cache`, together with the state of its parent (`state` for the
    ///   root). It returns the children to descend into and the state to
    ///   pass to those children.
    /// - On the way up, once every child returned by `d` has a value in
    ///   `cache`, `u` is called on the node. It receives the same state that
    ///   `d` received for that node, and its result is inserted into `cache`.
    /// - Nodes that already have a value in `cache` are skipped entirely:
    ///   neither `d` nor `u` is called for them.
    ///
    /// # Panics
    /// Asserts that no node's value is inserted into `cache` twice.
    pub fn ast_walk_cached<'a, D, I>(
        &self,
        idx: K,
        state: I,
        cache: &'a mut FxHashMap<K, D>,
        mut d: impl for<'c> FnMut(K, &'c V, &I) -> (&'c [K], I),
        mut u: impl FnMut(K, &V, &FxHashMap<K, D>, &I) -> D,
    ) -> &'a D
    where
        usize: From<K>,
        K: Eq + core::hash::Hash,
    {
        /// Advances `node` past its leading run of already-cached entries, so
        /// that its first element (if any) is the next node still to compute.
        fn trim<T: Eq + core::hash::Hash, V>(
            node: &mut core::slice::Iter<T>,
            cache: &FxHashMap<T, V>,
        ) {
            let slice = node.as_slice();
            let nk = slice.iter().position(|idx| !cache.contains_key(idx));
            *node = slice[nk.unwrap_or(slice.len())..].iter()
        }

        // Each stack entry is (state passed to `d` for these siblings,
        // remaining siblings). The first element of an entry's iterator is
        // the node currently being descended into.
        let mut todo = vec![(state, core::slice::from_ref(&idx).iter())];
        while let Some((mut state, mut node)) = todo.pop() {
            trim(&mut node, cache);
            // Descend along the first uncached child until reaching a node
            // whose children are all cached (or which has no children).
            while let Some(idx) = node.as_slice().first().copied() {
                let (new_node, new_state) = d(idx, &self.terms[idx], &state);
                todo.push((state, node));
                state = new_state;
                node = new_node.iter();
                trim(&mut node, cache);
            }

            // All children of the node at the front of the top entry are now
            // cached, so compute its value on the way up.
            let Some((state, node)) = todo.last_mut() else {
                break;
            };
            let idx = node.next().copied().unwrap();
            let next = &self.terms[idx];
            let data = u(idx, next, &*cache, &*state);
            let old = cache.insert(idx, data);
            assert!(old.is_none());
        }
        &cache[&idx]
    }
}

/// All terms and proof steps parsed from a trace, plus associated metadata.
#[cfg_attr(feature = "mem_dbg", derive(MemSize, MemDbg))]
#[derive(Debug, Default)]
pub struct Terms {
    /// Variable, application and quantifier terms, indexed by `TermIdx`.
    pub(super) app_terms: TermStorage<TermIdx, Term>,
    /// Proof steps, indexed by `ProofIdx`.
    pub(super) proof_terms: TermStorage<ProofIdx, ProofStep>,
    /// Bookkeeping for named assertions; see [`NamedAsserts`].
    pub(super) named_asserts: NamedAsserts,

    /// Meanings attached to terms via `new_meaning` (at most one per index).
    meanings: FxHashMap<TermIdx, Meaning>,
    /// State for the workarounds in the `bugs` module (`TermsBug`).
    /// NOTE(review): exact contents defined in `bugs.rs` — confirm there.
    pub(super) bug: TermsBug,
}

impl Terms {
    /// Returns the meaning attached to `tidx` by [`Self::new_meaning`], if any.
    pub(crate) fn meaning(&self, tidx: TermIdx) -> Option<&Meaning> {
        self.meanings.get(&tidx)
    }

    /// Returns the body of an instantiation, i.e. the instance it produced.
    ///
    /// - `Axiom` instantiations carry their body directly.
    /// - If the instantiation has a proof step of kind `PR_QUANT_INST`, that
    ///   step's result is decomposed as described below. For any other proof
    ///   kind, the proof's result is returned as is.
    /// - With proofs disabled, the term stored in `ProofsDisabled` is
    ///   decomposed, or `None` is returned if there is no such term.
    ///
    /// The decomposed term is expected to have the shape `¬(quant) ∨ (inst)`,
    /// and its second child `inst` is returned.
    ///
    /// # Panics
    /// If the decomposed term does not have exactly two children.
    pub(super) fn get_instantiation_body(&self, inst: InstantiationKind) -> Option<TermIdx> {
        let proved_term = match inst {
            InstantiationKind::Axiom { body } => return Some(body),
            InstantiationKind::NonAxiom { proof, .. } => match proof {
                InstProofLink::HasProof(proof_idx) => {
                    let proof = &self[proof_idx];
                    if matches!(proof.kind, ProofStepKind::PR_QUANT_INST) {
                        proof.result
                    } else {
                        return Some(proof.result);
                    }
                }
                InstProofLink::ProofsDisabled(term_idx) => term_idx?,
            },
        };
        // The proved term is of the form `quant-inst(¬(quant) ∨ (inst))`.
        let proved_term = &self[proved_term];
        assert_eq!(proved_term.child_ids.len(), 2);
        Some(proved_term.child_ids[1])
    }

    /// Returns the quantifier index of the term `quant`.
    ///
    /// # Errors
    /// `Error::UnknownQuantifierIdx` if the term has no quantifier index.
    pub(super) fn quant(&self, quant: TermIdx) -> Result<QuantIdx> {
        self[quant]
            .quant_idx()
            .ok_or(Error::UnknownQuantifierIdx(quant))
    }

    /// Attaches `meaning` to the term `tidx` and returns the index it ended up
    /// on.
    ///
    /// If `tidx` already has the same meaning, nothing changes. If it already
    /// has a *different* meaning, the term is cloned into a new entry, the
    /// meaning is attached to that clone, and the clone's index is returned.
    ///
    /// # Errors
    /// If reserving space in the meaning map or storing the cloned term fails.
    pub(super) fn new_meaning(&mut self, mut tidx: TermIdx, meaning: Meaning) -> Result<TermIdx> {
        self.meanings.try_reserve(1)?;
        match self.meanings.entry(tidx) {
            Entry::Occupied(old) => {
                if old.get() != &meaning {
                    let term = self.app_terms.terms[tidx].clone();
                    tidx = self.app_terms.new_term(term)?;
                    self.meanings.insert(tidx, meaning);
                }
            }
            Entry::Vacant(empty) => {
                empty.insert(meaning);
            }
        };
        Ok(tidx)
    }

    /// Performs a top-down walk of the AST rooted at `tidx`, calling `f` on
    /// each node of kind `App` and walking the children it returns. `Quant`
    /// and `Var` nodes are skipped, and so are their children.
    ///
    /// Returns `Err(t)` as soon as `f` returns `Err(t)`, and `Ok(())` otherwise.
    pub(super) fn app_walk<T>(
        &self,
        tidx: TermIdx,
        mut f: impl FnMut(TermIdx, &Term) -> core::result::Result<&[TermIdx], T>,
    ) -> core::result::Result<(), T> {
        self.app_terms
            .ast_walk(tidx, |tidx, term| match term.kind() {
                TermKind::Var(_) => Ok(&[]),
                TermKind::App(_) => f(tidx, term),
                TermKind::Quant(_) => Ok(&[]),
            })
            .map_or(Ok(()), Err)
    }

    /// Heuristic to get the body of an instantiated quantifier. See the
    /// documentation of [`InstProofLink::ProofsDisabled`].
    ///
    /// Takes the most recently created term whose name is not `=`. That term
    /// is returned only if it has the shape `(or (not <quantifier>) _)`. In
    /// debug builds, a `debug_assert!` fires when it does not.
    pub(super) fn last_term_from_instance(&self, strings: &StringTable) -> Option<TermIdx> {
        let last_non_eq = self
            .app_terms
            .terms
            .iter_enumerated()
            .rev()
            .find(|(_, term)| term.app_name().is_none_or(|name| &strings[*name] != "="));
        let last_term = last_non_eq.filter(|(_, term)| {
            term.app_name().is_some_and(|name| &strings[*name] == "or")
                && term.child_ids.len() == 2
                && {
                    let neg_quant = &self[term.child_ids[0]];
                    neg_quant
                        .app_name()
                        .is_some_and(|name| &strings[*name] == "not")
                        && neg_quant.child_ids.len() == 1
                        && self[neg_quant.child_ids[0]].quant_idx().is_some()
                }
        });
        debug_assert!(
            last_term.is_some(),
            "{:?}",
            last_non_eq.map(|(_, t)| t.app_name().map(|n| &strings[*n]))
        );
        last_term.map(|(idx, _)| idx)
    }

    /// Whether `tidx` is the `true` constant, identified here as the term with
    /// no namespace and numeric id 1.
    pub fn is_true_const(&self, tidx: TermIdx) -> bool {
        let id = self[tidx].id;
        id.namespace.is_none() && id.id.is_some_and(|id| id.get() == 1)
    }
    /// Whether `tidx` is the `false` constant, identified here as the term with
    /// no namespace and numeric id 2.
    pub fn is_false_const(&self, tidx: TermIdx) -> bool {
        let id = self[tidx].id;
        id.namespace.is_none() && id.id.is_some_and(|id| id.get() == 2)
    }
    /// Whether `tidx` is either the `true` or the `false` constant.
    pub fn is_bool_const(&self, tidx: TermIdx) -> bool {
        self.is_true_const(tidx) || self.is_false_const(tidx)
    }

    /// Used only to give access to the `app_terms` field in `bugs.rs`.
    ///
    /// # Errors
    /// `Error::UnknownId` if no application term is registered under `term_id`.
    pub(super) fn get_app_term_bug(&self, term_id: TermId) -> Result<TermIdx> {
        self.app_terms
            .get_term(term_id)
            .into_result()
            .map_err(Error::UnknownId)
    }

    /// Processes the newly created proof step `pidx`. Only `asserted` steps
    /// are handled; all others are ignored. For an asserted step:
    ///
    /// - Every quantifier reachable in the asserted term gets its `blame` set
    ///   to `pidx`. The walk does not descend into, or past, terms for which
    ///   `has_var()` is `Some`.
    /// - Named assertions are tracked in [`NamedAsserts`] (see its docs).
    ///
    /// # Errors
    /// If reserving space in the named-assertion maps fails.
    pub(super) fn new_proof(
        &mut self,
        quants: &mut TiVec<QuantIdx, Quantifier>,
        pidx: ProofIdx,
        strings: &StringTable,
    ) -> Result<()> {
        let proof = &self[pidx];
        if !proof.kind.is_asserted() {
            return Ok(());
        }
        let ridx = proof.result;

        // Mark any quantifiers in the asserted term as blaming this proof step.
        // Terms are shared (a DAG), and `ast_walk` keeps no visited set. Track
        // visited nodes here: setting the blame is idempotent and exploring a
        // node depends only on the node itself, so skipping repeats changes
        // nothing except avoiding an exponential walk over heavily shared
        // assertions.
        let mut visited: FxHashMap<TermIdx, ()> = FxHashMap::default();
        self.app_terms.ast_walk::<super::Never>(ridx, |tidx, term| {
            if visited.insert(tidx, ()).is_some() || term.has_var().is_some() {
                return Ok(&[]);
            }
            if let TermKind::Quant(qidx) = term.kind() {
                // TODO: how to handle multiple blames here? Currently we use
                // the latest one only.
                // debug_assert!(quants[qidx].blame.is_none());
                quants[qidx].blame = Some(pidx);
            }
            Ok(&term.child_ids)
        });

        let result = &self[ridx];
        match result.child_ids.len() {
            // `asserted my_name`: if `(=> my_name _)` was asserted earlier,
            // link the two proof steps as a named assertion.
            0 => {
                if let Some(implication) = self.named_asserts.seen.remove(&ridx) {
                    self.named_asserts.named.try_reserve(2)?;
                    let old = self.named_asserts.named.insert(implication, Some(pidx));
                    debug_assert!(old.is_none());
                    let old = self.named_asserts.named.insert(pidx, None);
                    debug_assert!(old.is_none());
                }
            }
            // `asserted (=> lhs _)` where `lhs` is a nullary application
            // (a candidate name constant): remember it until `lhs` is asserted.
            2 if result
                .app_name()
                .is_some_and(|name| &strings[*name] == "=>") =>
            {
                let lidx = result.child_ids[0];
                let lhs = &self[lidx];
                if lhs.app_name().is_some() && lhs.child_ids.is_empty() {
                    self.named_asserts.seen.try_reserve(1)?;
                    self.named_asserts.seen.insert(lidx, pidx);
                }
            }
            _ => (),
        };
        Ok(())
    }
}

/// Named assertions are required to get an unsat core. They are given to z3 as
/// follows: `(assert (! my_assertion :named my_name))`. This is turned into:
/// ```smt2
/// (declare-const my_name Bool)
/// (assert (=> my_name my_assertion))
/// (assert my_name)
/// ```
/// Unfortunately, the log does not let us tell whether the user wrote the
/// above directly or used `:named`, so we treat both the same (unlike z3,
/// which distinguishes the two).
///
/// When we see the proof step for the implication:\
/// `[mk-proof] #_ asserted (=> my_name #_)`\
/// we store that proof step in `seen`, keyed by the `my_name` term. Then, when
/// we see the proof step asserting the name itself:\
/// `[mk-proof] #_ asserted my_name`\
/// we look `my_name` up in `seen` and move the pair into `named`.
#[cfg_attr(feature = "mem_dbg", derive(MemSize, MemDbg))]
#[derive(Debug, Default)]
pub struct NamedAsserts {
    /// Maps a name constant (the left-hand side of an asserted `=>`) to the
    /// proof step asserting that implication, until the name itself is
    /// asserted.
    pub seen: FxHashMap<TermIdx, ProofIdx>,
    /// Mapping from `assertion -> Some(name)` or `name -> None`. That is, if a
    /// `ProofIdx` is in this map, it is either a named assertion or the proof
    /// step asserting the boolean name variable itself. To tell the two apart,
    /// the assertion's proof step maps to `Some(name)`, while the name's proof
    /// step maps to `None`.
    pub named: FxHashMap<ProofIdx, Option<ProofIdx>>,
}

impl std::ops::Index<TermIdx> for Terms {
    type Output = Term;

    /// Borrows the application term stored at `idx`; panics if out of bounds.
    fn index(&self, idx: TermIdx) -> &Self::Output {
        let storage = &self.app_terms.terms;
        &storage[idx]
    }
}

impl std::ops::IndexMut<TermIdx> for Terms {
    /// Mutably borrows the application term stored at `idx`; panics if out
    /// of bounds.
    fn index_mut(&mut self, idx: TermIdx) -> &mut Self::Output {
        let storage = &mut self.app_terms.terms;
        &mut storage[idx]
    }
}

impl std::ops::Index<ProofIdx> for Terms {
    type Output = ProofStep;

    /// Borrows the proof step stored at `idx`; panics if out of bounds.
    fn index(&self, idx: ProofIdx) -> &Self::Output {
        let storage = &self.proof_terms.terms;
        &storage[idx]
    }
}

impl std::ops::IndexMut<ProofIdx> for Terms {
    /// Mutably borrows the proof step stored at `idx`; panics if out of
    /// bounds.
    fn index_mut(&mut self, idx: ProofIdx) -> &mut Self::Output {
        let storage = &mut self.proof_terms.terms;
        &mut storage[idx]
    }
}