use std::sync::Arc;
use serde::{Deserialize, Serialize};
use crate::eq::Term;
/// Evidence that the term `lhs` equals the term `rhs`.
///
/// `justification` records how the equality was obtained. Compound
/// justifications embed the sub-witnesses they were built from, so a witness
/// forms a tree that can be measured with `EqWitness::depth` and serialized
/// with serde.
///
/// Field order is significant to the derived `Hash` and to the field order of
/// serialized output; keep it stable.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct EqWitness {
/// Left-hand side of the equality.
pub lhs: Term,
/// Right-hand side of the equality.
pub rhs: Term,
/// The rule (plus any sub-witnesses) supporting `lhs = rhs`.
pub justification: WitnessJustification,
}
/// The reasoning step by which an [`EqWitness`] establishes `lhs = rhs`.
///
/// The smart constructors on `EqWitness` derive `lhs`/`rhs` from their inputs
/// but do not check that sub-witnesses fit together. For example, they do not
/// check that the middle terms of a transitivity step agree.
///
/// Variant order is significant to the derived `Hash` and to the variant
/// indices used by index-based serde formats; append new variants at the end.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum WitnessJustification {
/// Both sides are the same term (built by `EqWitness::reflexivity`).
Reflexivity,
/// The equality is taken from an axiom; the payload is the axiom's name.
Axiom(Arc<str>),
/// The boxed witness proves the same equality with its sides swapped.
Symmetry(Box<EqWitness>),
/// Two chained witnesses `a = b` and `b = c`, together giving `a = c`.
Transitivity(Box<EqWitness>, Box<EqWitness>),
/// `op` applied to arguments that are pairwise proven equal.
Congruence {
/// Name of the applied operator.
op: Arc<str>,
/// One witness per argument position, in argument order.
arg_witnesses: Vec<EqWitness>,
},
/// Equality established by a check performed outside the witness tree.
RuntimeChecked {
/// NOTE(review): presumably a human-readable note on what was checked;
/// nothing in this file constructs this variant — confirm with producers.
description: String,
},
}
impl EqWitness {
#[must_use]
pub fn reflexivity(term: Term) -> Self {
Self {
lhs: term.clone(),
rhs: term,
justification: WitnessJustification::Reflexivity,
}
}
#[must_use]
pub fn axiom(name: impl Into<Arc<str>>, lhs: Term, rhs: Term) -> Self {
Self {
lhs,
rhs,
justification: WitnessJustification::Axiom(name.into()),
}
}
#[must_use]
pub fn transitivity(ab: Self, bc: Self) -> Self {
Self {
lhs: ab.lhs.clone(),
rhs: bc.rhs.clone(),
justification: WitnessJustification::Transitivity(Box::new(ab), Box::new(bc)),
}
}
#[must_use]
pub fn symmetry(witness: Self) -> Self {
Self {
lhs: witness.rhs.clone(),
rhs: witness.lhs.clone(),
justification: WitnessJustification::Symmetry(Box::new(witness)),
}
}
#[must_use]
pub fn congruence(op: impl Into<Arc<str>>, arg_witnesses: Vec<Self>) -> Self {
let op = op.into();
let lhs_args: Vec<Term> = arg_witnesses.iter().map(|w| w.lhs.clone()).collect();
let rhs_args: Vec<Term> = arg_witnesses.iter().map(|w| w.rhs.clone()).collect();
Self {
lhs: Term::app(Arc::clone(&op), lhs_args),
rhs: Term::app(Arc::clone(&op), rhs_args),
justification: WitnessJustification::Congruence { op, arg_witnesses },
}
}
#[must_use]
pub fn depth(&self) -> usize {
match &self.justification {
WitnessJustification::Reflexivity
| WitnessJustification::Axiom(_)
| WitnessJustification::RuntimeChecked { .. } => 1,
WitnessJustification::Symmetry(w) => 1 + w.depth(),
WitnessJustification::Transitivity(a, b) => 1 + a.depth().max(b.depth()),
WitnessJustification::Congruence { arg_witnesses, .. } => {
1 + arg_witnesses.iter().map(Self::depth).max().unwrap_or(0)
}
}
}
}
#[cfg(test)]
// Panicking on an unexpected value is the intended failure mode inside tests.
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    #[test]
    fn reflexivity_witness() {
        let t = Term::var("x");
        let w = EqWitness::reflexivity(t.clone());
        assert_eq!(w.lhs, t);
        assert_eq!(w.rhs, t);
        assert_eq!(w.depth(), 1);
    }

    #[test]
    fn axiom_witness() {
        let lhs = Term::app("add", vec![Term::var("x"), Term::constant("zero")]);
        let rhs = Term::var("x");
        let w = EqWitness::axiom("right_identity", lhs.clone(), rhs.clone());
        assert_eq!(w.lhs, lhs);
        assert_eq!(w.rhs, rhs);
        assert_eq!(w.depth(), 1);
    }

    #[test]
    fn transitivity_chain() {
        let a = Term::var("a");
        let b = Term::var("b");
        let c = Term::var("c");
        let ab = EqWitness::axiom("ax1", a.clone(), b.clone());
        let bc = EqWitness::axiom("ax2", b, c.clone());
        let ac = EqWitness::transitivity(ab, bc);
        assert_eq!(ac.lhs, a);
        assert_eq!(ac.rhs, c);
        assert_eq!(ac.depth(), 2);
    }

    #[test]
    fn symmetry_witness() {
        let a = Term::var("a");
        let b = Term::var("b");
        let ab = EqWitness::axiom("ax", a.clone(), b.clone());
        let ba = EqWitness::symmetry(ab);
        assert_eq!(ba.lhs, b);
        assert_eq!(ba.rhs, a);
        assert_eq!(ba.depth(), 2);
    }

    #[test]
    fn congruence_of_reflexivity() {
        let x = Term::var("x");
        let w = EqWitness::reflexivity(x.clone());
        let cong = EqWitness::congruence("f", vec![w]);
        assert_eq!(cong.lhs, Term::app("f", vec![x.clone()]));
        assert_eq!(cong.rhs, Term::app("f", vec![x]));
        assert_eq!(cong.depth(), 2);
    }

    #[test]
    fn congruence_witness() {
        // Distinct sides, so mixing up the lhs/rhs argument lists is detected.
        let x = Term::var("x");
        let y = Term::var("y");
        let xy = EqWitness::axiom("ax", x.clone(), y.clone());
        let cong = EqWitness::congruence("f", vec![xy]);
        assert_eq!(cong.lhs, Term::app("f", vec![x]));
        assert_eq!(cong.rhs, Term::app("f", vec![y]));
        assert_eq!(cong.depth(), 2);
    }

    #[test]
    fn congruence_preserves_argument_order() {
        let a = Term::var("a");
        let b = Term::var("b");
        let c = Term::var("c");
        let d = Term::var("d");
        let ab = EqWitness::axiom("ax1", a.clone(), b.clone());
        let cd = EqWitness::axiom("ax2", c.clone(), d.clone());
        let cong = EqWitness::congruence("g", vec![ab, cd]);
        assert_eq!(cong.lhs, Term::app("g", vec![a, c]));
        assert_eq!(cong.rhs, Term::app("g", vec![b, d]));
        assert_eq!(cong.depth(), 2);
    }

    #[test]
    fn nullary_congruence_has_depth_one() {
        let cong = EqWitness::congruence("c", Vec::new());
        assert_eq!(cong.lhs, cong.rhs);
        assert_eq!(cong.depth(), 1);
    }

    #[test]
    fn depth_follows_deepest_branch() {
        let a = Term::var("a");
        let b = Term::var("b");
        let shallow = EqWitness::reflexivity(a.clone());
        let deep = EqWitness::symmetry(EqWitness::symmetry(EqWitness::axiom("ax", a, b)));
        assert_eq!(deep.depth(), 3);
        let cong = EqWitness::congruence("g", vec![shallow, deep]);
        assert_eq!(cong.depth(), 4);
    }

    #[test]
    fn runtime_checked_is_a_leaf() {
        let w = EqWitness {
            lhs: Term::var("a"),
            rhs: Term::var("b"),
            justification: WitnessJustification::RuntimeChecked {
                description: "checked at runtime".to_string(),
            },
        };
        assert_eq!(w.depth(), 1);
        let json = serde_json::to_string(&w).expect("serialize");
        let back: EqWitness = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(w, back);
    }

    #[test]
    fn serialization_round_trip() {
        let w = EqWitness::axiom("ax", Term::var("a"), Term::var("b"));
        let json = serde_json::to_string(&w).expect("serialize");
        let deserialized: EqWitness = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(w, deserialized);
    }

    #[test]
    fn nested_serialization_round_trip() {
        let ab = EqWitness::axiom("ax1", Term::var("a"), Term::var("b"));
        let bc = EqWitness::axiom("ax2", Term::var("b"), Term::var("c"));
        let ac = EqWitness::transitivity(ab, bc);
        let w = EqWitness::congruence("f", vec![EqWitness::symmetry(ac)]);
        let json = serde_json::to_string(&w).expect("serialize");
        let deserialized: EqWitness = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(w, deserialized);
    }
}