use crate::chc::{PredId, RuleId};
use oxiz_core::TermId;
use rustc_hash::FxHashSet;
use smallvec::SmallVec;
use std::sync::atomic::{AtomicU32, Ordering};
/// Identifier of a [`ReachFact`].
///
/// [`ReachFactStore`] hands these out sequentially and uses the wrapped value
/// as an index into its fact list. [`ReachFactId::new`] and
/// [`ReachFactId::raw`] convert to and from the raw `u32` in `const` context.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct ReachFactId(pub u32);

impl ReachFactId {
    /// Wraps a raw `u32` as an identifier.
    #[inline]
    #[must_use]
    pub const fn new(id: u32) -> Self {
        ReachFactId(id)
    }

    /// Returns the wrapped `u32`.
    #[inline]
    #[must_use]
    pub const fn raw(self) -> u32 {
        let ReachFactId(value) = self;
        value
    }
}
/// A fact (`fact`) recorded for predicate `pred`, produced by `rule`.
///
/// A fact may carry auxiliary variables, links to the facts that justify it,
/// and an optional tag term; all three are empty / `None` after construction.
/// NOTE(review): presumably `fact` describes states of `pred` known to be
/// reachable — confirm against the producers of these facts.
#[derive(Debug, Clone)]
pub struct ReachFact {
    /// Identifier assigned when the fact was created.
    pub id: ReachFactId,
    /// Predicate the fact is about.
    pub pred: PredId,
    /// The fact term itself.
    pub fact: TermId,
    /// Auxiliary variables set through `set_aux_vars` (empty by default).
    aux_vars: SmallVec<[TermId; 4]>,
    /// Rule that produced the fact.
    rule: RuleId,
    /// Facts recorded as justifying this one, in the order they were added.
    justification: SmallVec<[ReachFactId; 2]>,
    /// Optional tag term set through `set_tag`.
    tag: Option<TermId>,
    /// Whether the fact was created as an initial fact.
    is_init: bool,
}

impl ReachFact {
    /// Creates a fact with no auxiliary variables, no justifications and no tag.
    pub fn new(id: ReachFactId, pred: PredId, fact: TermId, rule: RuleId, is_init: bool) -> Self {
        let aux_vars = SmallVec::new();
        let justification = SmallVec::new();
        ReachFact {
            id,
            pred,
            fact,
            rule,
            is_init,
            aux_vars,
            justification,
            tag: None,
        }
    }

    /// Whether this fact was created as an initial fact.
    #[inline]
    #[must_use]
    pub fn is_init(&self) -> bool {
        self.is_init
    }

    /// The rule that produced this fact.
    #[inline]
    #[must_use]
    pub fn rule(&self) -> RuleId {
        self.rule
    }

    /// The tag term, if one has been set.
    #[must_use]
    pub fn tag(&self) -> Option<TermId> {
        self.tag
    }

    /// Sets (or overwrites) the tag term.
    pub fn set_tag(&mut self, tag: TermId) {
        self.tag = Some(tag)
    }

    /// Appends `fact` to the list of justifying facts.
    pub fn add_justification(&mut self, fact: ReachFactId) {
        self.justification.push(fact)
    }

    /// The justifying facts, in the order they were added.
    pub fn justifications(&self) -> &[ReachFactId] {
        self.justification.as_slice()
    }

    /// The auxiliary variables last set with `set_aux_vars`.
    pub fn aux_vars(&self) -> &[TermId] {
        self.aux_vars.as_slice()
    }

    /// Replaces the auxiliary variables with the items of `vars`.
    pub fn set_aux_vars(&mut self, vars: impl IntoIterator<Item = TermId>) {
        self.aux_vars = SmallVec::from_iter(vars);
    }
}
/// Append-only arena of [`ReachFact`]s, indexed by id, by predicate, and by
/// whether the fact was added as an initial fact.
///
/// Invariant: the fact with id `n` is stored at `facts[n]`, which is what lets
/// [`ReachFactStore::get`] be a plain index. `add` and `clear` are the only
/// methods that change the length of `facts`, and both keep `next_id` equal to
/// that length.
#[derive(Debug)]
pub struct ReachFactStore {
    /// Every recorded fact; position `n` holds the fact with id `n`.
    facts: Vec<ReachFact>,
    /// Ids of the facts recorded for each predicate, in insertion order.
    by_pred: rustc_hash::FxHashMap<PredId, SmallVec<[ReachFactId; 8]>>,
    /// Ids of the facts added with `is_init == true`, in insertion order.
    init_facts: SmallVec<[ReachFactId; 4]>,
    /// Next id to hand out (kept equal to `facts.len()`).
    next_id: AtomicU32,
}

impl Default for ReachFactStore {
    fn default() -> Self {
        Self::new()
    }
}

impl ReachFactStore {
    /// Creates an empty store; the first fact added gets id 0.
    pub fn new() -> Self {
        Self {
            facts: Vec::new(),
            by_pred: rustc_hash::FxHashMap::default(),
            init_facts: SmallVec::new(),
            next_id: AtomicU32::new(0),
        }
    }

    /// Records a new fact and returns its freshly allocated id.
    ///
    /// The id is also appended to the per-predicate index for `pred`. When
    /// `is_init` is true, it is appended to the initial-fact list as well.
    ///
    /// # Panics
    ///
    /// Panics if the `u32` id space is exhausted.
    pub fn add(&mut self, pred: PredId, fact: TermId, rule: RuleId, is_init: bool) -> ReachFactId {
        // A plain `fetch_add` wraps to 0 after `u32::MAX`. `get` would then
        // resolve new ids to unrelated older facts, so refuse via `checked_add`.
        // `Relaxed` is sufficient: `&mut self` rules out concurrent access.
        let raw = self
            .next_id
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |n| n.checked_add(1))
            .expect("ReachFactStore: ReachFactId space (u32) exhausted");
        debug_assert_eq!(
            raw as usize,
            self.facts.len(),
            "ReachFactId must match the fact's index in `facts`"
        );
        let id = ReachFactId(raw);
        self.facts.push(ReachFact::new(id, pred, fact, rule, is_init));
        self.by_pred.entry(pred).or_default().push(id);
        if is_init {
            self.init_facts.push(id);
        }
        id
    }

    /// Returns the fact with the given id, or `None` if no such fact exists.
    #[must_use]
    pub fn get(&self, id: ReachFactId) -> Option<&ReachFact> {
        self.facts.get(id.0 as usize)
    }

    /// Mutable counterpart of [`ReachFactStore::get`].
    pub fn get_mut(&mut self, id: ReachFactId) -> Option<&mut ReachFact> {
        self.facts.get_mut(id.0 as usize)
    }

    /// Iterates the facts recorded for `pred`, in insertion order.
    ///
    /// Yields nothing if no fact was ever recorded for `pred`.
    pub fn for_pred(&self, pred: PredId) -> impl Iterator<Item = &ReachFact> {
        self.by_pred
            .get(&pred)
            .into_iter()
            .flat_map(|ids| ids.iter())
            .filter_map(|&id| self.get(id))
    }

    /// Iterates the facts added with `is_init == true`, in insertion order.
    pub fn init_facts(&self) -> impl Iterator<Item = &ReachFact> {
        self.init_facts.iter().filter_map(|&id| self.get(id))
    }

    /// Number of facts recorded.
    #[inline]
    #[must_use]
    pub fn len(&self) -> usize {
        self.facts.len()
    }

    /// `true` when no facts are recorded.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.facts.is_empty()
    }

    /// Removes every fact and index entry, and restarts id allocation at 0.
    pub fn clear(&mut self) {
        self.facts.clear();
        self.by_pred.clear();
        self.init_facts.clear();
        // `&mut self` gives direct access to the counter; no new atomic needed.
        *self.next_id.get_mut() = 0;
    }
}
/// A counterexample trace: an ordered list of [`CexState`]s plus a flag
/// recording whether the trace has been marked spurious.
#[derive(Debug, Clone)]
pub struct Counterexample {
    /// Steps of the trace, in the order they were pushed (or reversed).
    states: Vec<CexState>,
    /// Set by `mark_spurious`; `false` on construction.
    spurious: bool,
}

/// One step of a [`Counterexample`].
#[derive(Debug, Clone)]
pub struct CexState {
    /// Predicate of this step.
    pub pred: PredId,
    /// Term describing the state at this step.
    pub state: TermId,
    /// Rule associated with this step, if any.
    pub rule: Option<RuleId>,
    /// Pairs of terms attached to this step.
    /// NOTE(review): presumably `(variable, value)` — confirm with producers.
    pub assignments: SmallVec<[(TermId, TermId); 4]>,
}

impl Counterexample {
    /// Creates an empty trace that is not marked spurious.
    pub fn new() -> Self {
        Counterexample {
            spurious: false,
            states: Vec::new(),
        }
    }

    /// Appends `state` to the end of the trace.
    pub fn push(&mut self, state: CexState) {
        self.states.push(state)
    }

    /// All steps of the trace, in order.
    pub fn states(&self) -> &[CexState] {
        &self.states[..]
    }

    /// Number of steps in the trace.
    #[must_use]
    pub fn len(&self) -> usize {
        self.states().len()
    }

    /// `true` when the trace has no steps.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.states().is_empty()
    }

    /// Flags the trace as spurious. There is no way to clear the flag.
    pub fn mark_spurious(&mut self) {
        self.spurious = true
    }

    /// Whether `mark_spurious` has been called.
    #[must_use]
    pub fn is_spurious(&self) -> bool {
        self.spurious
    }

    /// Reverses the order of the steps in place.
    pub fn reverse(&mut self) {
        self.states.reverse()
    }
}

impl Default for Counterexample {
    /// Same as [`Counterexample::new`].
    fn default() -> Self {
        Counterexample::new()
    }
}
/// A formula paired with two variable lists, `projected_vars` and `aux_vars`.
///
/// Both lists start empty. The fields are public and are filled in by callers.
#[derive(Debug, Clone)]
pub struct Projection {
    /// The formula.
    pub formula: TermId,
    /// Variables recorded as projected.
    pub projected_vars: SmallVec<[TermId; 4]>,
    /// Auxiliary variables recorded alongside the formula.
    pub aux_vars: SmallVec<[TermId; 4]>,
}

impl Projection {
    /// Wraps `formula` with empty `projected_vars` and `aux_vars`.
    pub fn new(formula: TermId) -> Self {
        Projection {
            formula,
            projected_vars: SmallVec::default(),
            aux_vars: SmallVec::default(),
        }
    }
}
/// Bookkeeping for generalizing a cube of literals.
///
/// `drop_literal` only records a literal in `dropped`; it does not remove it
/// from `cube`. NOTE(review): confirm that callers treat `cube` minus
/// `dropped` as the generalized cube.
#[derive(Debug, Clone)]
pub struct Generalization {
    /// The literals passed to `new`.
    pub cube: SmallVec<[TermId; 8]>,
    /// Set by `mark_inductive`; `false` on construction.
    pub inductive: bool,
    /// Literals recorded through `drop_literal`, in call order.
    pub dropped: SmallVec<[TermId; 4]>,
}

impl Generalization {
    /// Starts from the literals of `cube`: not inductive, nothing dropped.
    pub fn new(cube: impl IntoIterator<Item = TermId>) -> Self {
        let literals: SmallVec<[TermId; 8]> = SmallVec::from_iter(cube);
        Generalization {
            cube: literals,
            inductive: false,
            dropped: SmallVec::default(),
        }
    }

    /// Sets `inductive` to `true`.
    pub fn mark_inductive(&mut self) {
        self.inductive = true
    }

    /// Appends `lit` to `dropped` (leaves `cube` untouched).
    pub fn drop_literal(&mut self, lit: TermId) {
        self.dropped.push(lit)
    }
}
/// Reachability queries over predicate states.
///
/// NOTE(review): the exact meaning of "reachable" (exact vs. approximate,
/// which frame or level) is left to implementors; nothing in this module
/// pins it down.
pub trait ReachabilityChecker {
    /// Whether `state` of predicate `pred` is considered reachable.
    fn is_reachable(&self, pred: PredId, state: TermId) -> bool;

    /// Whether `rule` can be applied starting from `from_state` of `from_pred`.
    fn can_transition(&self, from_pred: PredId, from_state: TermId, rule: RuleId) -> bool;

    /// The reachability facts the implementor holds for `pred`, as terms.
    fn reach_facts(&self, pred: PredId) -> Vec<TermId>;
}
/// Sets of state terms recorded per predicate.
///
/// Iteration order of `states` follows the internal hash set and is unspecified.
#[derive(Debug)]
pub struct UnderApproximation {
    /// Recorded states, grouped by predicate.
    states: rustc_hash::FxHashMap<PredId, FxHashSet<TermId>>,
}

impl Default for UnderApproximation {
    /// Same as [`UnderApproximation::new`].
    fn default() -> Self {
        UnderApproximation::new()
    }
}

impl UnderApproximation {
    /// Creates an approximation with no recorded states.
    pub fn new() -> Self {
        UnderApproximation {
            states: Default::default(),
        }
    }

    /// Records `state` for `pred`. Adding the same state twice has no extra effect.
    pub fn add(&mut self, pred: PredId, state: TermId) {
        let bucket = self.states.entry(pred).or_default();
        bucket.insert(state);
    }

    /// Whether `state` has been recorded for `pred`.
    #[must_use]
    pub fn contains(&self, pred: PredId, state: TermId) -> bool {
        match self.states.get(&pred) {
            Some(set) => set.contains(&state),
            None => false,
        }
    }

    /// Iterates the states recorded for `pred` (empty if none were recorded).
    pub fn states(&self, pred: PredId) -> impl Iterator<Item = TermId> + '_ {
        self.states.get(&pred).into_iter().flatten().copied()
    }

    /// Forgets every recorded state for every predicate.
    pub fn clear(&mut self) {
        self.states.clear()
    }
}
/// Lemmas recorded per predicate, grouped by the level they were added at.
///
/// Querying a level returns the lemmas stored at that level and at every
/// higher level. The order follows the internal hash map and is unspecified.
#[derive(Debug)]
pub struct OverApproximation {
    /// predicate -> level -> lemmas added at exactly that level.
    blocked: rustc_hash::FxHashMap<PredId, rustc_hash::FxHashMap<u32, Vec<TermId>>>,
}

impl Default for OverApproximation {
    /// Same as [`OverApproximation::new`].
    fn default() -> Self {
        OverApproximation::new()
    }
}

impl OverApproximation {
    /// Creates an approximation with no lemmas.
    pub fn new() -> Self {
        OverApproximation {
            blocked: Default::default(),
        }
    }

    /// Records `lemma` for `pred` at `level`. Duplicates are kept.
    pub fn add_blocked(&mut self, pred: PredId, level: u32, lemma: TermId) {
        let by_level = self.blocked.entry(pred).or_default();
        by_level.entry(level).or_default().push(lemma);
    }

    /// Iterates every lemma of `pred` stored at a level `>= level`.
    pub fn blocked_at(&self, pred: PredId, level: u32) -> impl Iterator<Item = TermId> + '_ {
        self.blocked
            .get(&pred)
            .into_iter()
            .flat_map(|levels| levels.iter())
            .filter(move |(lvl, _)| **lvl >= level)
            .flat_map(|(_, lemmas)| lemmas.iter().copied())
    }

    /// Removes every lemma for every predicate.
    pub fn clear(&mut self) {
        self.blocked.clear()
    }
}
pub struct ConcreteWitnessExtractor<'a> {
terms: &'a mut oxiz_core::TermManager,
system: &'a crate::chc::ChcSystem,
}
impl<'a> ConcreteWitnessExtractor<'a> {
pub fn new(terms: &'a mut oxiz_core::TermManager, system: &'a crate::chc::ChcSystem) -> Self {
Self { terms, system }
}
pub fn extract_witness(
&mut self,
cex: &Counterexample,
models: &[crate::smt::Model],
) -> Result<ConcreteWitness, WitnessError> {
use tracing::debug;
debug!("Extracting concrete witness from {} states", cex.len());
if cex.is_empty() {
return Err(WitnessError::EmptyTrace);
}
if models.len() != cex.len() {
return Err(WitnessError::ModelMismatch {
expected: cex.len(),
got: models.len(),
});
}
let mut witness = ConcreteWitness::new();
for (i, (state, model)) in cex.states().iter().zip(models.iter()).enumerate() {
let assignments = self.extract_assignments(state.pred, model)?;
let concrete_state = ConcreteState {
step: i,
pred: state.pred,
state: state.state,
rule: state.rule,
assignments: assignments.clone(),
model_values: self.model_to_values(model),
};
witness.add_state(concrete_state);
debug!(
"Extracted witness state {}: {} assignments",
i,
assignments.len()
);
}
Ok(witness)
}
fn extract_assignments(
&mut self,
pred: PredId,
model: &crate::smt::Model,
) -> Result<Vec<(TermId, TermId)>, WitnessError> {
let mut assignments = Vec::new();
let pred_info = self
.system
.get_predicate(pred)
.ok_or(WitnessError::PredicateNotFound(pred))?;
for (idx, &sort) in pred_info.params.iter().enumerate() {
if let Some(value) = model.get(idx) {
let var_name = format!("{}_{}", pred_info.name, idx);
let var = self.terms.mk_var(&var_name, sort);
assignments.push((var, value));
}
}
Ok(assignments)
}
fn model_to_values(&self, model: &crate::smt::Model) -> Vec<ConcreteValue> {
use oxiz_core::TermKind;
model
.assignments()
.iter()
.filter_map(|&term_id| {
let term = self.terms.get(term_id)?;
match &term.kind {
TermKind::IntConst(n) => Some(ConcreteValue::Int(n.clone())),
TermKind::RealConst(r) => {
use num_bigint::BigInt;
Some(ConcreteValue::Real(
BigInt::from(*r.numer()),
BigInt::from(*r.denom()),
))
}
TermKind::True => Some(ConcreteValue::Bool(true)),
TermKind::False => Some(ConcreteValue::Bool(false)),
TermKind::BitVecConst { value, width } => {
Some(ConcreteValue::BitVec(value.clone(), *width))
}
_ => None,
}
})
.collect()
}
pub fn validate_witness(&self, witness: &ConcreteWitness) -> Result<bool, WitnessError> {
use tracing::trace;
trace!("Validating witness with {} states", witness.states.len());
if let Some(first) = witness.states.first()
&& first.step != 0
{
return Ok(false);
}
for i in 1..witness.states.len() {
let prev = &witness.states[i - 1];
let curr = &witness.states[i];
if let Some(rule_id) = curr.rule
&& self.system.get_rule(rule_id).is_none()
{
return Ok(false);
}
trace!(
"Validated transition {} -> {} via rule {:?}",
prev.step, curr.step, curr.rule
);
}
Ok(true)
}
}
/// A sequence of [`ConcreteState`]s plus a validation flag.
///
/// [`ConcreteWitnessExtractor::extract_witness`] builds these.
/// [`ConcreteWitnessExtractor::validate_witness`] takes the witness by shared
/// reference and does not set the flag, so callers must call
/// `mark_validated` themselves.
#[derive(Debug, Clone)]
pub struct ConcreteWitness {
    /// States in trace order.
    pub states: Vec<ConcreteState>,
    /// Set by `mark_validated`; `false` on construction.
    validated: bool,
}

impl ConcreteWitness {
    /// Creates an empty witness that is not marked validated.
    pub fn new() -> Self {
        Self {
            states: Vec::new(),
            validated: false,
        }
    }

    /// Appends `state` to the end of the trace.
    pub fn add_state(&mut self, state: ConcreteState) {
        self.states.push(state);
    }

    /// Marks the witness as validated. There is no way to clear the flag.
    pub fn mark_validated(&mut self) {
        self.validated = true;
    }

    /// Whether `mark_validated` has been called.
    #[must_use]
    pub fn is_validated(&self) -> bool {
        self.validated
    }

    /// Number of states in the witness.
    #[must_use]
    pub fn len(&self) -> usize {
        self.states.len()
    }

    /// `true` when the witness has no states.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.states.is_empty()
    }
}

impl Default for ConcreteWitness {
    fn default() -> Self {
        Self::new()
    }
}
/// One step of a [`ConcreteWitness`].
///
/// The field notes below describe how `ConcreteWitnessExtractor::extract_witness`
/// fills these in; other constructors may differ.
#[derive(Debug, Clone)]
pub struct ConcreteState {
/// Index of this state in the trace (0-based).
pub step: usize,
/// Predicate of this step, copied from the counterexample state.
pub pred: PredId,
/// State term, copied from the counterexample state.
pub state: TermId,
/// Rule copied from the counterexample state, if any.
pub rule: Option<RuleId>,
/// `(variable, value)` pairs; each variable is named `{pred_name}_{param_index}`.
pub assignments: Vec<(TermId, TermId)>,
/// The model's assignment terms that are literal constants, converted to values.
pub model_values: Vec<ConcreteValue>,
}
/// A literal value extracted from a model term.
#[derive(Debug, Clone)]
pub enum ConcreteValue {
/// Boolean constant.
Bool(bool),
/// Integer constant.
Int(num_bigint::BigInt),
/// Rational constant as `(numerator, denominator)`.
Real(num_bigint::BigInt, num_bigint::BigInt),
/// Bit-vector constant as `(value, width)`.
BitVec(num_bigint::BigInt, u32),
}
/// Errors reported while building a concrete witness.
#[derive(Debug, thiserror::Error)]
pub enum WitnessError {
/// The counterexample contained no states.
#[error("empty counterexample trace")]
EmptyTrace,
/// The number of models (`got`) differs from the number of counterexample
/// states (`expected`).
#[error("model count mismatch: expected {expected}, got {got}")]
ModelMismatch { expected: usize, got: usize },
/// A state referenced a predicate unknown to the CHC system.
#[error("predicate not found: {0:?}")]
PredicateNotFound(PredId),
/// A malformed state; the payload describes the problem.
/// NOTE(review): not constructed in this module — check other producers.
#[error("invalid state: {0}")]
InvalidState(String),
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_reach_fact_creation() {
        let mut store = ReachFactStore::new();
        let pred = PredId::new(0);
        let fact = oxiz_core::TermId::new(42);
        let rule = RuleId::new(0);
        let id = store.add(pred, fact, rule, true);
        let reach_fact = store.get(id).expect("key should exist in map");
        assert!(reach_fact.is_init());
        assert_eq!(reach_fact.rule(), rule);
        assert_eq!(reach_fact.fact, fact);
    }

    #[test]
    fn test_reach_fact_justification() {
        let mut store = ReachFactStore::new();
        let pred = PredId::new(0);
        let rule = RuleId::new(0);
        let fact1 = oxiz_core::TermId::new(1);
        let fact2 = oxiz_core::TermId::new(2);
        let id1 = store.add(pred, fact1, rule, true);
        let id2 = store.add(pred, fact2, rule, false);
        store
            .get_mut(id2)
            .expect("test operation should succeed")
            .add_justification(id1);
        let reach_fact = store.get(id2).expect("key should exist in map");
        assert_eq!(reach_fact.justifications(), &[id1]);
    }

    /// Ids are dense from 0 and double as indices; the per-predicate and
    /// initial-fact indices keep insertion order.
    #[test]
    fn test_reach_fact_store_indices() {
        let mut store = ReachFactStore::new();
        let (p0, p1) = (PredId::new(0), PredId::new(1));
        let rule = RuleId::new(0);
        let a = store.add(p0, oxiz_core::TermId::new(1), rule, true);
        let b = store.add(p1, oxiz_core::TermId::new(2), rule, false);
        let c = store.add(p0, oxiz_core::TermId::new(3), rule, false);
        assert_eq!(store.len(), 3);
        assert!(!store.is_empty());
        assert_eq!([a.raw(), b.raw(), c.raw()], [0, 1, 2]);
        let p0_ids: Vec<_> = store.for_pred(p0).map(|f| f.id).collect();
        assert_eq!(p0_ids, vec![a, c]);
        let init_ids: Vec<_> = store.init_facts().map(|f| f.id).collect();
        assert_eq!(init_ids, vec![a]);
        assert!(store.get(ReachFactId::new(3)).is_none());
        assert_eq!(store.for_pred(PredId::new(7)).count(), 0);
    }

    /// After `clear`, numbering restarts so new ids still index the fact list.
    #[test]
    fn test_reach_fact_store_clear_resets_ids() {
        let mut store = ReachFactStore::new();
        let pred = PredId::new(0);
        let rule = RuleId::new(0);
        store.add(pred, oxiz_core::TermId::new(1), rule, true);
        store.add(pred, oxiz_core::TermId::new(2), rule, false);
        store.clear();
        assert!(store.is_empty());
        assert_eq!(store.init_facts().count(), 0);
        assert_eq!(store.for_pred(pred).count(), 0);
        let id = store.add(pred, oxiz_core::TermId::new(3), rule, false);
        assert_eq!(id.raw(), 0);
        assert_eq!(
            store.get(id).map(|f| f.fact),
            Some(oxiz_core::TermId::new(3))
        );
    }

    #[test]
    fn test_counterexample() {
        let mut cex = Counterexample::new();
        let pred = PredId::new(0);
        let state = oxiz_core::TermId::new(42);
        cex.push(CexState {
            pred,
            state,
            rule: None,
            assignments: SmallVec::new(),
        });
        assert_eq!(cex.len(), 1);
        assert!(!cex.is_spurious());
        cex.mark_spurious();
        assert!(cex.is_spurious());
    }

    #[test]
    fn test_under_approximation() {
        let mut under = UnderApproximation::new();
        let pred = PredId::new(0);
        let state1 = oxiz_core::TermId::new(1);
        let state2 = oxiz_core::TermId::new(2);
        under.add(pred, state1);
        under.add(pred, state2);
        assert!(under.contains(pred, state1));
        assert!(under.contains(pred, state2));
        assert!(!under.contains(pred, oxiz_core::TermId::new(3)));
    }

    #[test]
    fn test_over_approximation() {
        let mut over = OverApproximation::new();
        let pred = PredId::new(0);
        let lemma1 = oxiz_core::TermId::new(1);
        let lemma2 = oxiz_core::TermId::new(2);
        over.add_blocked(pred, 1, lemma1);
        over.add_blocked(pred, 2, lemma2);
        let blocked: Vec<_> = over.blocked_at(pred, 1).collect();
        assert_eq!(blocked.len(), 2);
        let blocked: Vec<_> = over.blocked_at(pred, 2).collect();
        assert_eq!(blocked.len(), 1);
    }

    /// Queries above every stored level, for unknown predicates, or after
    /// `clear` yield nothing.
    #[test]
    fn test_over_approximation_empty_queries() {
        let mut over = OverApproximation::new();
        let pred = PredId::new(0);
        over.add_blocked(pred, 1, oxiz_core::TermId::new(1));
        assert_eq!(over.blocked_at(pred, 2).count(), 0);
        assert_eq!(over.blocked_at(PredId::new(1), 0).count(), 0);
        over.clear();
        assert_eq!(over.blocked_at(pred, 0).count(), 0);
    }

    #[test]
    fn test_generalization() {
        let cube = [
            oxiz_core::TermId::new(1),
            oxiz_core::TermId::new(2),
            oxiz_core::TermId::new(3),
        ];
        let mut generalization = Generalization::new(cube);
        assert_eq!(generalization.cube.len(), 3);
        assert!(!generalization.inductive);
        generalization.drop_literal(oxiz_core::TermId::new(2));
        generalization.mark_inductive();
        assert!(generalization.inductive);
        assert_eq!(generalization.dropped.len(), 1);
    }
}