use core::fmt;
use crate::spec::law::{AlgebraicLaw, MonotonicDirection};
/// A quantified first-order statement: a binder list over a predicate body.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FormalLaw {
    /// `forall vars . body` — the body must hold for every assignment.
    ForAll {
        /// Names of the bound variables, in binder order.
        vars: Vec<String>,
        /// The quantified predicate.
        body: Predicate,
    },
    /// `exists vars . body` — some assignment makes the body hold.
    Exists {
        /// Names of the bound variables, in binder order.
        vars: Vec<String>,
        /// The quantified predicate.
        body: Predicate,
    },
}
impl FormalLaw {
    /// Returns `true` when the quantified body is non-trivial.
    ///
    /// The binder list is not inspected: for both quantifiers the answer is
    /// exactly the body predicate's own non-triviality check.
    #[inline]
    pub fn is_non_trivial(&self) -> bool {
        let body = match self {
            Self::ForAll { body, .. } => body,
            Self::Exists { body, .. } => body,
        };
        body.is_non_trivial()
    }
}
/// A term inside a formal law: a named variable, an unsigned constant, or a
/// named function applied to argument terms.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Expr {
    /// A variable referenced by name (e.g. one bound by the enclosing quantifier).
    Var(String),
    /// A literal `u32` constant.
    Const(u32),
    /// Application of `function` to `args`.
    Call {
        /// Name of the applied function.
        function: String,
        /// Argument terms, in positional order.
        args: Vec<Expr>,
    },
}
/// A proposition over [`Expr`] terms, used as the body of a [`FormalLaw`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Predicate {
    /// `lhs = rhs`.
    Equality {
        /// Left operand.
        lhs: Expr,
        /// Right operand.
        rhs: Expr,
    },
    /// `lhs != rhs`.
    NotEqual {
        /// Left operand.
        lhs: Expr,
        /// Right operand.
        rhs: Expr,
    },
    /// `lhs <= rhs`.
    LessEqual {
        /// Left operand.
        lhs: Expr,
        /// Right operand.
        rhs: Expr,
    },
    /// `lhs >= rhs`.
    GreaterEqual {
        /// Left operand.
        lhs: Expr,
        /// Right operand.
        rhs: Expr,
    },
    /// `premise -> conclusion`.
    Implies {
        /// Antecedent of the implication.
        premise: Box<Predicate>,
        /// Consequent of the implication.
        conclusion: Box<Predicate>,
    },
    /// Conjunction of every item.
    And(Vec<Predicate>),
    /// Disjunction of every item.
    Or(Vec<Predicate>),
    /// Negation of the inner predicate.
    Not(Box<Predicate>),
    /// Exactly one of the items holds.
    ExactlyOne(Vec<Predicate>),
    /// An opaque named predicate applied to argument terms.
    Custom {
        /// Predicate name.
        name: String,
        /// Argument terms, in positional order.
        args: Vec<Expr>,
    },
}
impl Predicate {
    /// Reports whether this predicate actually constrains anything.
    ///
    /// Returns `false` (trivial/degenerate) when:
    /// - a comparison (`=`, `!=`, `<=`, `>=`) has structurally identical
    ///   operands, e.g. `x = x`;
    /// - an implication has identical premise and conclusion (`p -> p`), or
    ///   either side is itself trivial;
    /// - an `And`/`Or`/`ExactlyOne` is empty or contains any trivial item;
    /// - a `Not` wraps a trivial predicate;
    /// - a `Custom` predicate has an empty name or no arguments.
    fn is_non_trivial(&self) -> bool {
        match self {
            Self::Equality { lhs, rhs }
            | Self::NotEqual { lhs, rhs }
            | Self::LessEqual { lhs, rhs }
            | Self::GreaterEqual { lhs, rhs } => lhs != rhs,
            Self::Implies {
                premise,
                conclusion,
            } => {
                // `p -> p` is a tautology: the implication analogue of `x = x`,
                // which the comparison arm above already rejects.
                premise != conclusion
                    && premise.is_non_trivial()
                    && conclusion.is_non_trivial()
            }
            Self::And(items) | Self::Or(items) | Self::ExactlyOne(items) => {
                !items.is_empty() && items.iter().all(Self::is_non_trivial)
            }
            Self::Not(inner) => inner.is_non_trivial(),
            Self::Custom { name, args } => !name.is_empty() && !args.is_empty(),
        }
    }
}
/// Translates a law description into its quantified [`FormalLaw`] statement.
pub trait AlgebraicLawFormalSpec {
    /// Returns the formula expressing this law.
    ///
    /// In this module's implementation for `AlgebraicLaw`, `f` names the
    /// function the law constrains.
    fn formal_spec(&self) -> FormalLaw;
}
/// Formal specifications for each `AlgebraicLaw` variant.
///
/// `f` is always the function the law constrains. Other function names
/// (`inner_op`, `dual_op`, `over_op`, `complement_op`, the comparison ops, ...)
/// come from the variant's fields and are emitted verbatim as call targets.
/// Numeric constants are `u32` terms built with `c`. Every law is universally
/// quantified except `ZeroProduct { holds: false }`, which is existential.
impl AlgebraicLawFormalSpec for AlgebraicLaw {
fn formal_spec(&self) -> FormalLaw {
match self {
// forall a b . f(a, b) = f(b, a)
AlgebraicLaw::Commutative => forall(&["a", "b"], eq(f(&["a", "b"]), f(&["b", "a"]))),
// forall a b c . f(f(a, b), c) = f(a, f(b, c))
AlgebraicLaw::Associative => forall(
&["a", "b", "c"],
eq(
call("f", vec![f(&["a", "b"]), v("c")]),
call("f", vec![v("a"), f(&["b", "c"])]),
),
),
// Two-sided identity: forall a . f(a, e) = a /\ f(e, a) = a
AlgebraicLaw::Identity { element } => forall(
&["a"],
and(vec![
eq(call("f", vec![v("a"), c(*element)]), v("a")),
eq(call("f", vec![c(*element), v("a")]), v("a")),
]),
),
// forall a . f(e, a) = a
AlgebraicLaw::LeftIdentity { element } => {
forall(&["a"], eq(call("f", vec![c(*element), v("a")]), v("a")))
}
// forall a . f(a, e) = a
AlgebraicLaw::RightIdentity { element } => {
forall(&["a"], eq(call("f", vec![v("a"), c(*element)]), v("a")))
}
// forall a . f(a, a) = result
AlgebraicLaw::SelfInverse { result } => {
forall(&["a"], eq(call("f", vec![v("a"), v("a")]), c(*result)))
}
// forall a . f(a, a) = a
AlgebraicLaw::Idempotent => forall(&["a"], eq(call("f", vec![v("a"), v("a")]), v("a"))),
// Two-sided absorbing element: forall a . f(a, z) = z /\ f(z, a) = z
AlgebraicLaw::Absorbing { element } => forall(
&["a"],
and(vec![
eq(call("f", vec![v("a"), c(*element)]), c(*element)),
eq(call("f", vec![c(*element), v("a")]), c(*element)),
]),
),
// forall a . f(z, a) = z
AlgebraicLaw::LeftAbsorbing { element } => forall(
&["a"],
eq(call("f", vec![c(*element), v("a")]), c(*element)),
),
// forall a . f(a, z) = z
AlgebraicLaw::RightAbsorbing { element } => forall(
&["a"],
eq(call("f", vec![v("a"), c(*element)]), c(*element)),
),
// Unary f applied twice is the identity: forall a . f(f(a)) = a
AlgebraicLaw::Involution => {
forall(&["a"], eq(call("f", vec![call("f", vec![v("a")])]), v("a")))
}
// forall a b . f(inner_op(a, b)) = dual_op(f(a), f(b))
AlgebraicLaw::DeMorgan { inner_op, dual_op } => forall(
&["a", "b"],
eq(
call("f", vec![op(inner_op, &["a", "b"])]),
call(
*dual_op,
vec![call("f", vec![v("a")]), call("f", vec![v("b")])],
),
),
),
// `Monotone` is shorthand for non-decreasing monotonicity.
AlgebraicLaw::Monotone => monotonic(MonotonicDirection::NonDecreasing),
AlgebraicLaw::Monotonic { direction } => monotonic(*direction),
// forall a b . lo <= f(a, b) /\ f(a, b) <= hi
AlgebraicLaw::Bounded { lo, hi } => forall(
&["a", "b"],
and(vec![
Predicate::LessEqual {
lhs: c(*lo),
rhs: f(&["a", "b"]),
},
Predicate::LessEqual {
lhs: f(&["a", "b"]),
rhs: c(*hi),
},
]),
),
// forall a . f(a, complement_op(a)) = universe
AlgebraicLaw::Complement {
complement_op,
universe,
} => forall(
&["a"],
eq(
call("f", vec![v("a"), call(*complement_op, vec![v("a")])]),
c(*universe),
),
),
// Left distributivity: forall a b c . f(a, over_op(b, c)) = over_op(f(a, b), f(a, c))
AlgebraicLaw::DistributiveOver { over_op } => forall(
&["a", "b", "c"],
eq(
call("f", vec![v("a"), op(over_op, &["b", "c"])]),
call(*over_op, vec![f(&["a", "b"]), f(&["a", "c"])]),
),
),
// forall a b . f(a, dual_op(a, b)) = a
AlgebraicLaw::LatticeAbsorption { dual_op } => forall(
&["a", "b"],
eq(call("f", vec![v("a"), op(dual_op, &["a", "b"])]), v("a")),
),
// f undoes the variant's `op` field (bound here as `inverse_op`; `op` is also
// the name of a helper in this module) given the same second operand:
// forall a b . f(op(a, b), b) = a
AlgebraicLaw::InverseOf { op: inverse_op } => forall(
&["a", "b"],
eq(
call("f", vec![call(*inverse_op, vec![v("a"), v("b")]), v("b")]),
v("a"),
),
),
// Exactly one of less_op(a, b), equal_op(a, b), greater_op(a, b) equals 1;
// a result of 1 is what this encoding treats as "the relation holds".
AlgebraicLaw::Trichotomy {
less_op,
equal_op,
greater_op,
} => forall(
&["a", "b"],
Predicate::ExactlyOne(vec![
eq(call(*less_op, vec![v("a"), v("b")]), c(1)),
eq(call(*equal_op, vec![v("a"), v("b")]), c(1)),
eq(call(*greater_op, vec![v("a"), v("b")]), c(1)),
]),
),
// forall a b . f(a, b) = 0 -> (a = 0 \/ b = 0)
AlgebraicLaw::ZeroProduct { holds: true } => forall(
&["a", "b"],
Predicate::Implies {
premise: Box::new(eq(f(&["a", "b"]), c(0))),
conclusion: Box::new(or(vec![eq(v("a"), c(0)), eq(v("b"), c(0))])),
},
),
// Counterexample form: exists a b . f(a, b) = 0 /\ a != 0 /\ b != 0
AlgebraicLaw::ZeroProduct { holds: false } => FormalLaw::Exists {
vars: vec!["a".to_string(), "b".to_string()],
body: and(vec![
eq(f(&["a", "b"]), c(0)),
Predicate::NotEqual {
lhs: v("a"),
rhs: c(0),
},
Predicate::NotEqual {
lhs: v("b"),
rhs: c(0),
},
]),
},
// Opaque predicate `name(x0, ..., x{arity-1})` over fresh variables
// x0..x{arity-1}; arity 0 yields an empty binder list and no arguments.
AlgebraicLaw::Custom { name, arity, .. } => FormalLaw::ForAll {
vars: custom_vars(*arity),
body: Predicate::Custom {
name: (*name).to_string(),
args: (0..*arity).map(custom_var).collect(),
},
},
// Fallback for variants without a translation here: emits the opaque
// predicate `unhandled:<name>(a)` via `AlgebraicLaw::name` (defined elsewhere).
// NOTE(review): this placeholder has a non-empty name and one argument, so
// `Predicate::is_non_trivial` reports it as non-trivial; callers cannot use
// that check to detect an untranslated law.
other => FormalLaw::ForAll {
vars: vec!["a".to_string()],
body: Predicate::Custom {
name: format!("unhandled:{}", other.name()),
args: vec![v("a")],
},
},
}
}
}
/// Builds the monotonicity law for `f` in the given `direction`.
///
/// `NonDecreasing` yields `forall a b . (a <= b) -> (f(a) <= f(b))`.
/// `NonIncreasing` flips the conclusion to `f(a) >= f(b)`. Any other direction
/// has no translation here and yields the opaque predicate
/// `forall a b . unknown-monotonic-direction(a, b)`.
fn monotonic(direction: MonotonicDirection) -> FormalLaw {
    // `f` applied to a single named variable.
    let image = |name: &'static str| call("f", vec![v(name)]);

    let conclusion = match direction {
        MonotonicDirection::NonDecreasing => Predicate::LessEqual {
            lhs: image("a"),
            rhs: image("b"),
        },
        MonotonicDirection::NonIncreasing => Predicate::GreaterEqual {
            lhs: image("a"),
            rhs: image("b"),
        },
        _ => {
            let placeholder = Predicate::Custom {
                name: "unknown-monotonic-direction".to_string(),
                args: vec![v("a"), v("b")],
            };
            return forall(&["a", "b"], placeholder);
        }
    };

    let premise = Predicate::LessEqual {
        lhs: v("a"),
        rhs: v("b"),
    };
    forall(
        &["a", "b"],
        Predicate::Implies {
            premise: Box::new(premise),
            conclusion: Box::new(conclusion),
        },
    )
}
/// Builds the universally quantified law `forall vars . body`.
///
/// Variable names may be any borrowed strings, not only `'static` literals.
/// They are copied into the returned [`FormalLaw::ForAll`] in the given order.
fn forall(vars: &[&str], body: Predicate) -> FormalLaw {
    FormalLaw::ForAll {
        vars: vars.iter().map(ToString::to_string).collect(),
        body,
    }
}
/// Builds the variable term named `name`.
///
/// Accepts any borrowed string, not only `'static` literals. The name is
/// copied into the returned [`Expr::Var`].
fn v(name: &str) -> Expr {
    Expr::Var(name.to_owned())
}
/// Builds the constant term [`Expr::Const`] holding `value`.
fn c(value: u32) -> Expr {
    Expr::Const(value)
}
/// Applies the constrained function `f` to the named variables in order.
/// For example, `f(&["a", "b"])` is the term `f(a, b)`.
fn f(vars: &[&'static str]) -> Expr {
    const SUBJECT: &str = "f";
    op(SUBJECT, vars)
}
fn op(function: &str, vars: &[&'static str]) -> Expr {
call(function, vars.iter().copied().map(v).collect())
}
/// Builds the application term `function(args...)`.
///
/// `function` may be any string-like value (`&str`, `String`, ...).
fn call(function: impl Into<String>, args: Vec<Expr>) -> Expr {
    let function: String = function.into();
    Expr::Call { function, args }
}
/// Builds the equation `lhs = rhs`, keeping the operand order.
fn eq(lhs: Expr, rhs: Expr) -> Predicate {
    Predicate::Equality { lhs, rhs }
}
/// Builds the conjunction of `items`. An empty list is passed through as-is.
fn and(items: Vec<Predicate>) -> Predicate {
    Predicate::And(items)
}
/// Builds the disjunction of `items`. An empty list is passed through as-is.
fn or(items: Vec<Predicate>) -> Predicate {
    Predicate::Or(items)
}
/// Returns the binder names `x0, x1, ..., x{arity-1}` for a custom law.
/// The result is empty when `arity` is 0.
fn custom_vars(arity: usize) -> Vec<String> {
    let mut names = Vec::with_capacity(arity);
    for index in 0..arity {
        names.push(format!("x{index}"));
    }
    names
}
/// Builds the variable term for the `index`-th custom-law argument (`x0`,
/// `x1`, ...). It uses the same naming scheme as `custom_vars`.
fn custom_var(index: usize) -> Expr {
    let name = format!("x{index}");
    Expr::Var(name)
}
impl fmt::Display for FormalLaw {
    /// Renders `forall v1 v2 . body` or `exists v1 v2 . body`.
    ///
    /// Variables are separated by single spaces. An empty binder list still
    /// emits both surrounding spaces (`forall  . body`).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let (quantifier, vars, body) = match self {
            Self::ForAll { vars, body } => ("forall", vars, body),
            Self::Exists { vars, body } => ("exists", vars, body),
        };
        write!(f, "{quantifier} {} . {body}", vars.join(" "))
    }
}
impl fmt::Display for Expr {
    /// Renders variables by name and constants in decimal. Calls are rendered
    /// as `function(arg, arg, ...)` with `", "` between the arguments.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let (function, args) = match self {
            Self::Var(name) => return f.write_str(name),
            Self::Const(value) => return write!(f, "{value}"),
            Self::Call { function, args } => (function, args),
        };
        write!(f, "{function}(")?;
        for (position, arg) in args.iter().enumerate() {
            if position != 0 {
                f.write_str(", ")?;
            }
            write!(f, "{arg}")?;
        }
        f.write_str(")")
    }
}
impl fmt::Display for Predicate {
    /// Renders the predicate in a compact infix notation.
    ///
    /// - Comparisons print as `lhs OP rhs` (`=`, `!=`, `<=`, `>=`).
    /// - Implication prints as `(premise) -> (conclusion)`; negation as `!(inner)`.
    /// - `And`/`Or` print as parenthesised lists joined by `/\` and `\/`.
    /// - `ExactlyOne` prints as a parenthesised list joined by the word `exactly_one`.
    /// - Custom predicates use call syntax `name(arg, arg, ...)`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Shared renderer for the four binary comparison variants.
        fn comparison(
            f: &mut fmt::Formatter<'_>,
            lhs: &dyn fmt::Display,
            symbol: &str,
            rhs: &dyn fmt::Display,
        ) -> fmt::Result {
            write!(f, "{lhs} {symbol} {rhs}")
        }

        match self {
            Self::Equality { lhs, rhs } => comparison(f, lhs, "=", rhs),
            Self::NotEqual { lhs, rhs } => comparison(f, lhs, "!=", rhs),
            Self::LessEqual { lhs, rhs } => comparison(f, lhs, "<=", rhs),
            Self::GreaterEqual { lhs, rhs } => comparison(f, lhs, ">=", rhs),
            Self::Implies {
                premise,
                conclusion,
            } => write!(f, "({premise}) -> ({conclusion})"),
            Self::Not(inner) => write!(f, "!({inner})"),
            Self::And(items) => join_predicates(f, "/\\", items),
            Self::Or(items) => join_predicates(f, "\\/", items),
            // NOTE(review): with three or more items this reads like a chained
            // binary operator; a prefix form such as `exactly_one(p, q, r)` would
            // be unambiguous, but would change the rendered text.
            Self::ExactlyOne(items) => join_predicates(f, "exactly_one", items),
            Self::Custom { name, args } => {
                f.write_str(name)?;
                f.write_str("(")?;
                for (position, arg) in args.iter().enumerate() {
                    if position > 0 {
                        f.write_str(", ")?;
                    }
                    write!(f, "{arg}")?;
                }
                f.write_str(")")
            }
        }
    }
}
/// Writes `items` as a parenthesised list separated by ` {sep} `.
///
/// Produces `(a SEP b SEP c)`. A single item gives `(a)` and an empty slice
/// gives `()`. Each item is rendered through its own `Display` impl with
/// default formatting options.
///
/// Items are written straight into the formatter rather than collected into
/// intermediate `String`s. The helper accepts any `Display` item, so existing
/// `&[Predicate]` callers keep working unchanged.
fn join_predicates<T: fmt::Display>(
    f: &mut fmt::Formatter<'_>,
    sep: &str,
    items: &[T],
) -> fmt::Result {
    f.write_str("(")?;
    for (index, item) in items.iter().enumerate() {
        if index > 0 {
            // Same spacing as joining with `format!(" {sep} ")`.
            write!(f, " {sep} ")?;
        }
        write!(f, "{item}")?;
    }
    f.write_str(")")
}