#[cfg(not(feature = "std"))]
use alloc::format;
#[cfg(not(feature = "std"))]
use alloc::string::String;
/// Verifies the associative law for `op` on a single triple of samples.
///
/// The law holds when grouping does not matter:
/// `op(op(a, b), c) == op(a, op(b, c))`.
///
/// # Errors
///
/// Returns a [`LawViolation`] with `law == "associativity"` whose `left` and
/// `right` fields hold the `Debug` renderings of the two groupings.
pub fn check_associativity<T: Clone + PartialEq + core::fmt::Debug>(
    a: T,
    b: T,
    c: T,
    op: impl Fn(T, T) -> T,
) -> Result<(), LawViolation> {
    // (a ∘ b) ∘ c — clones keep the inputs available for the other grouping.
    let a_then_b = op(a.clone(), b.clone());
    let grouped_first = op(a_then_b, c.clone());
    // a ∘ (b ∘ c) — last use of the inputs, so they are moved in.
    let b_then_c = op(b, c);
    let grouped_last = op(a, b_then_c);

    if grouped_first == grouped_last {
        return Ok(());
    }
    Err(LawViolation {
        law: "associativity",
        left: format!("{grouped_first:?}"),
        right: format!("{grouped_last:?}"),
    })
}
/// Verifies that `e` acts as a left identity for `op` on the sample `a`,
/// i.e. that `op(e, a) == a`.
///
/// # Errors
///
/// Returns a [`LawViolation`] with `law == "left identity"`; `left` is the
/// `Debug` rendering of `op(e, a)` and `right` that of `a`.
pub fn check_left_identity<T: Clone + PartialEq + core::fmt::Debug>(
    a: T,
    e: T,
    op: impl Fn(T, T) -> T,
) -> Result<(), LawViolation> {
    let combined = op(e, a.clone());
    let holds = combined == a;
    if !holds {
        return Err(LawViolation {
            law: "left identity",
            left: format!("{combined:?}"),
            right: format!("{a:?}"),
        });
    }
    Ok(())
}
/// Verifies that `e` acts as a right identity for `op` on the sample `a`,
/// i.e. that `op(a, e) == a`.
///
/// # Errors
///
/// Returns a [`LawViolation`] with `law == "right identity"`; `left` is the
/// `Debug` rendering of `op(a, e)` and `right` that of `a`.
pub fn check_right_identity<T: Clone + PartialEq + core::fmt::Debug>(
    a: T,
    e: T,
    op: impl Fn(T, T) -> T,
) -> Result<(), LawViolation> {
    let combined = op(a.clone(), e);
    if combined == a {
        return Ok(());
    }
    let violation = LawViolation {
        law: "right identity",
        left: format!("{combined:?}"),
        right: format!("{a:?}"),
    };
    Err(violation)
}
/// Verifies that `op` is commutative on the sample pair, i.e. that
/// `op(a, b) == op(b, a)`.
///
/// # Errors
///
/// Returns a [`LawViolation`] with `law == "commutativity"`; `left` is the
/// `Debug` rendering of `op(a, b)` and `right` that of `op(b, a)`.
pub fn check_commutativity<T: Clone + PartialEq + core::fmt::Debug>(
    a: T,
    b: T,
    op: impl Fn(T, T) -> T,
) -> Result<(), LawViolation> {
    let forward = op(a.clone(), b.clone());
    // Operand order swapped; the originals are no longer needed afterwards.
    let backward = op(b, a);
    let symmetric = forward == backward;
    if !symmetric {
        return Err(LawViolation {
            law: "commutativity",
            left: format!("{forward:?}"),
            right: format!("{backward:?}"),
        });
    }
    Ok(())
}
/// Verifies that `inv(a)` is a left inverse of `a` under `op` with identity
/// `e`, i.e. that `op(inv(a), a) == e`.
///
/// # Errors
///
/// Returns a [`LawViolation`] with `law == "left inverse"`; `left` is the
/// `Debug` rendering of `op(inv(a), a)` and `right` that of `e`.
pub fn check_left_inverse<T: Clone + PartialEq + core::fmt::Debug>(
    a: T,
    e: T,
    op: impl Fn(T, T) -> T,
    inv: impl Fn(T) -> T,
) -> Result<(), LawViolation> {
    let inverted = inv(a.clone());
    let product = op(inverted, a);
    if product == e {
        return Ok(());
    }
    Err(LawViolation {
        law: "left inverse",
        left: format!("{product:?}"),
        right: format!("{e:?}"),
    })
}
/// Verifies that `inv(a)` is a right inverse of `a` under `op` with identity
/// `e`, i.e. that `op(a, inv(a)) == e`.
///
/// # Errors
///
/// Returns a [`LawViolation`] with `law == "right inverse"`; `left` is the
/// `Debug` rendering of `op(a, inv(a))` and `right` that of `e`.
pub fn check_right_inverse<T: Clone + PartialEq + core::fmt::Debug>(
    a: T,
    e: T,
    op: impl Fn(T, T) -> T,
    inv: impl Fn(T) -> T,
) -> Result<(), LawViolation> {
    // Copy `a` before handing the original to `inv`, matching the left-to-right
    // argument evaluation of `op(a.clone(), inv(a))`.
    let kept = a.clone();
    let inverted = inv(a);
    let product = op(kept, inverted);
    let cancels = product == e;
    if !cancels {
        return Err(LawViolation {
            law: "right inverse",
            left: format!("{product:?}"),
            right: format!("{e:?}"),
        });
    }
    Ok(())
}
/// Verifies that `mul` distributes over `add` from the left on the sample
/// triple, i.e. that `mul(a, add(b, c)) == add(mul(a, b), mul(a, c))`.
///
/// # Errors
///
/// Returns a [`LawViolation`] with `law == "left distributivity"`; `left` is
/// the `Debug` rendering of the factored form and `right` that of the
/// expanded form.
pub fn check_left_distributivity<T: Clone + PartialEq + core::fmt::Debug>(
    a: T,
    b: T,
    c: T,
    add: impl Fn(T, T) -> T,
    mul: impl Fn(T, T) -> T,
) -> Result<(), LawViolation> {
    // Factored: a · (b + c)
    let multiplier = a.clone();
    let sum = add(b.clone(), c.clone());
    let factored = mul(multiplier, sum);

    // Expanded: (a · b) + (a · c)
    let ab = mul(a.clone(), b);
    let ac = mul(a, c);
    let expanded = add(ab, ac);

    if factored == expanded {
        return Ok(());
    }
    Err(LawViolation {
        law: "left distributivity",
        left: format!("{factored:?}"),
        right: format!("{expanded:?}"),
    })
}
/// Verifies that `mul` distributes over `add` from the right on the sample
/// triple, i.e. that `mul(add(a, b), c) == add(mul(a, c), mul(b, c))`.
///
/// # Errors
///
/// Returns a [`LawViolation`] with `law == "right distributivity"`; `left`
/// is the `Debug` rendering of the factored form and `right` that of the
/// expanded form.
pub fn check_right_distributivity<T: Clone + PartialEq + core::fmt::Debug>(
    a: T,
    b: T,
    c: T,
    add: impl Fn(T, T) -> T,
    mul: impl Fn(T, T) -> T,
) -> Result<(), LawViolation> {
    // Factored: (a + b) · c
    let sum = add(a.clone(), b.clone());
    let factored = mul(sum, c.clone());

    // Expanded: (a · c) + (b · c)
    let ac = mul(a, c.clone());
    let bc = mul(b, c);
    let expanded = add(ac, bc);

    let distributes = factored == expanded;
    if !distributes {
        return Err(LawViolation {
            law: "right distributivity",
            left: format!("{factored:?}"),
            right: format!("{expanded:?}"),
        });
    }
    Ok(())
}
/// Verifies that `op` is idempotent on the sample `a`, i.e. that
/// `op(a, a) == a`.
///
/// # Errors
///
/// Returns a [`LawViolation`] with `law == "idempotency"`; `left` is the
/// `Debug` rendering of `op(a, a)` and `right` that of `a`.
pub fn check_idempotency<T: Clone + PartialEq + core::fmt::Debug>(
    a: T,
    op: impl Fn(T, T) -> T,
) -> Result<(), LawViolation> {
    // Both operands are clones: `a` itself is still needed for the comparison.
    let self_combined = op(a.clone(), a.clone());
    if self_combined == a {
        return Ok(());
    }
    Err(LawViolation {
        law: "idempotency",
        left: format!("{self_combined:?}"),
        right: format!("{a:?}"),
    })
}
/// Verifies both lattice absorption laws for `meet` (∧) and `join` (∨) on the
/// sample pair `(a, b)`:
///
/// * `meet(a, join(a, b)) == a`
/// * `join(a, meet(a, b)) == a`
///
/// A lattice requires both identities. Checking only the first would accept
/// operation pairs such as `meet = min`, `join = +` (on non-negative
/// integers), which are not lattices.
///
/// # Errors
///
/// Returns a [`LawViolation`] tagged `"absorption"` when the first identity
/// fails, or `"dual absorption"` when only the second one fails. In both
/// cases `left` is the `Debug` rendering of the computed value and `right`
/// that of `a`.
pub fn check_absorption<T: Clone + PartialEq + core::fmt::Debug>(
    a: T,
    b: T,
    meet: impl Fn(T, T) -> T,
    join: impl Fn(T, T) -> T,
) -> Result<(), LawViolation> {
    // a ∧ (a ∨ b) = a — checked first so failures of this identity are
    // reported exactly as before.
    let meet_of_join = meet(a.clone(), join(a.clone(), b.clone()));
    if meet_of_join != a {
        return Err(LawViolation {
            law: "absorption",
            left: format!("{meet_of_join:?}"),
            right: format!("{a:?}"),
        });
    }

    // a ∨ (a ∧ b) = a — the dual identity.
    let join_of_meet = join(a.clone(), meet(a.clone(), b));
    if join_of_meet != a {
        return Err(LawViolation {
            law: "dual absorption",
            left: format!("{join_of_meet:?}"),
            right: format!("{a:?}"),
        });
    }

    Ok(())
}
/// Description of a failed algebraic-law check.
///
/// The `check_*` functions in this module fill `left` and `right` with the
/// `Debug` renderings of the two sides of the law that were expected to be
/// equal but were not. `PartialEq`/`Eq` are derived so callers and tests can
/// compare violations directly.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct LawViolation {
    /// Human-readable name of the violated law, e.g. `"associativity"`.
    pub law: &'static str,
    /// Rendering of the law's left-hand side as computed.
    pub left: String,
    /// Rendering of the law's right-hand side as computed.
    pub right: String,
}

impl core::fmt::Display for LawViolation {
    /// Formats the violation as `Law violation (<law>): <left> != <right>`.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        write!(
            f,
            "Law violation ({}): {} != {}",
            self.law, self.left, self.right
        )
    }
}
#[cfg(test)]
mod tests {
    // Covers both the passing and the failing path of every law checker, and
    // pins the contents (law name, Debug-rendered sides) of the reported
    // `LawViolation`.
    use super::*;

    #[test]
    fn associativity_holds_for_addition() {
        assert!(check_associativity(1i32, 2, 3, |a, b| a.wrapping_add(b)).is_ok());
    }

    #[test]
    fn associativity_fails_for_subtraction() {
        assert!(check_associativity(10i32, 5, 3, |a, b| a - b).is_err());
    }

    #[test]
    fn associativity_violation_reports_both_groupings() {
        // (10 - 5) - 3 == 2, while 10 - (5 - 3) == 8.
        let err = check_associativity(10i32, 5, 3, |a, b| a - b).unwrap_err();
        assert_eq!(err.law, "associativity");
        assert_eq!(err.left, "2");
        assert_eq!(err.right, "8");
    }

    #[test]
    fn identity_holds() {
        assert!(check_left_identity(5i32, 0, |a, b| a + b).is_ok());
        assert!(check_right_identity(5i32, 0, |a, b| a + b).is_ok());
    }

    #[test]
    fn zero_is_only_a_right_identity_for_subtraction() {
        assert!(check_right_identity(5i32, 0, |a, b| a - b).is_ok());
        let err = check_left_identity(5i32, 0, |a, b| a - b).unwrap_err();
        assert_eq!(err.law, "left identity");
        assert_eq!(err.left, "-5");
        assert_eq!(err.right, "5");
    }

    #[test]
    fn right_identity_fails_for_non_identity_element() {
        let err = check_right_identity(5i32, 1, |a, b| a + b).unwrap_err();
        assert_eq!(err.law, "right identity");
        assert_eq!(err.left, "6");
        assert_eq!(err.right, "5");
    }

    #[test]
    fn commutativity_holds_for_addition() {
        assert!(check_commutativity(3i32, 7, |a, b| a.wrapping_add(b)).is_ok());
    }

    #[test]
    fn commutativity_fails_for_subtraction() {
        assert!(check_commutativity(3i32, 7, |a, b| a - b).is_err());
    }

    #[test]
    fn commutativity_violation_reports_both_orders() {
        let err = check_commutativity(3i32, 7, |a, b| a - b).unwrap_err();
        assert_eq!(err.law, "commutativity");
        assert_eq!(err.left, "-4");
        assert_eq!(err.right, "4");
    }

    #[test]
    fn inverse_holds() {
        assert!(check_left_inverse(5i32, 0, |a, b| a + b, |a| -a).is_ok());
        assert!(check_right_inverse(5i32, 0, |a, b| a + b, |a| -a).is_ok());
    }

    #[test]
    fn inverse_fails_for_identity_function() {
        // Using `x -> x` as the "inverse" gives 5 + 5 == 10, not the identity 0.
        let err = check_left_inverse(5i32, 0, |a, b| a + b, |a| a).unwrap_err();
        assert_eq!(err.law, "left inverse");
        assert_eq!(err.left, "10");
        assert_eq!(err.right, "0");

        let err = check_right_inverse(5i32, 0, |a, b| a + b, |a| a).unwrap_err();
        assert_eq!(err.law, "right inverse");
        assert_eq!(err.left, "10");
        assert_eq!(err.right, "0");
    }

    #[test]
    fn distributivity_holds_for_integers() {
        assert!(
            check_left_distributivity(
                2i32,
                3,
                4,
                |a, b| a.wrapping_add(b),
                |a, b| a.wrapping_mul(b),
            )
            .is_ok()
        );
        assert!(
            check_right_distributivity(
                2i32,
                3,
                4,
                |a, b| a.wrapping_add(b),
                |a, b| a.wrapping_mul(b),
            )
            .is_ok()
        );
    }

    #[test]
    fn distributivity_fails_when_operations_are_swapped() {
        // Addition does not distribute over multiplication.
        // Left: 2 + (3 * 4) == 14 vs (2 + 3) * (2 + 4) == 30.
        let err =
            check_left_distributivity(2i32, 3, 4, |a, b| a * b, |a, b| a + b).unwrap_err();
        assert_eq!(err.law, "left distributivity");
        assert_eq!(err.left, "14");
        assert_eq!(err.right, "30");

        // Right: (2 * 3) + 4 == 10 vs (2 + 4) * (3 + 4) == 42.
        let err =
            check_right_distributivity(2i32, 3, 4, |a, b| a * b, |a, b| a + b).unwrap_err();
        assert_eq!(err.law, "right distributivity");
        assert_eq!(err.left, "10");
        assert_eq!(err.right, "42");
    }

    #[test]
    fn idempotency_holds_for_min() {
        assert!(check_idempotency(5i32, |a, b| a.min(b)).is_ok());
    }

    #[test]
    fn idempotency_fails_for_addition() {
        assert!(check_idempotency(5i32, |a, b| a + b).is_err());
    }

    #[test]
    fn idempotency_violation_reports_values() {
        let err = check_idempotency(5i32, |a, b| a + b).unwrap_err();
        assert_eq!(err.law, "idempotency");
        assert_eq!(err.left, "10");
        assert_eq!(err.right, "5");
    }

    #[test]
    fn absorption_holds_for_bool() {
        assert!(check_absorption(true, false, |a, b| a && b, |a, b| a || b).is_ok());
        assert!(check_absorption(false, true, |a, b| a && b, |a, b| a || b).is_ok());
    }

    #[test]
    fn absorption_fails_when_join_is_xor() {
        // true && (true ^ true) == false, which is not `true`.
        let err = check_absorption(true, true, |a, b| a && b, |a, b| a ^ b).unwrap_err();
        assert_eq!(err.law, "absorption");
        assert_eq!(err.left, "false");
        assert_eq!(err.right, "true");
    }

    #[test]
    fn law_violation_display() {
        let v = LawViolation {
            law: "test",
            left: "1".into(),
            right: "2".into(),
        };
        assert_eq!(format!("{v}"), "Law violation (test): 1 != 2");
    }
}