#![warn(clippy::all, clippy::pedantic, clippy::nursery, clippy::cargo)]
use clap::ValueEnum;
use std::fmt::{Debug, Display};
use std::hash::Hash;
/// Index of a propositional variable.
pub type Variable = u32;

/// A propositional literal: a [`Variable`] together with a polarity
/// (`true` = positive, `false` = negated).
///
/// Implementations differ only in their in-memory encoding. The provided
/// methods build DIMACS-style `i32` conversions and a dense `usize` index on
/// top of the four required methods.
pub trait Literal:
    Copy + Debug + Eq + Hash + Default + Ord + PartialOrd + PartialEq + Send + Sync
{
    /// Builds a literal for `var` with the given `polarity`.
    fn new(var: Variable, polarity: bool) -> Self;

    /// Returns the variable this literal refers to.
    fn variable(self) -> Variable;

    /// Returns `true` for a positive literal and `false` for a negated one.
    fn polarity(self) -> bool;

    /// Returns the literal with the same variable and the opposite polarity.
    #[must_use]
    fn negated(self) -> Self;

    /// Returns `true` when the literal is negated (polarity `false`).
    fn is_negated(self) -> bool {
        !self.polarity()
    }

    /// Returns `true` when the literal is positive (polarity `true`).
    fn is_positive(self) -> bool {
        self.polarity()
    }

    /// Decodes a DIMACS-style signed integer: the magnitude is the variable
    /// and the sign is the polarity.
    ///
    /// `0` decodes to variable `0` with negative polarity, because `0` is not
    /// positive.
    #[must_use]
    fn from_i32(value: i32) -> Self {
        let polarity = value.is_positive();
        let var = value.unsigned_abs();
        Self::new(var, polarity)
    }

    /// Encodes the literal as a DIMACS-style signed integer (`var` or `-var`).
    ///
    /// The sign is applied in `i64`. Variable `2^31` with negative polarity
    /// therefore yields `i32::MIN` instead of overflowing on negation, so
    /// `from_i32(i32::MIN).to_i32()` round-trips.
    ///
    /// # Panics
    ///
    /// In debug builds, panics if the signed value does not fit in an `i32`.
    /// In release builds such a value wraps modulo 2^32.
    #[allow(clippy::cast_possible_truncation)] // range is checked by the debug_assert below
    fn to_i32(&self) -> i32 {
        let magnitude = i64::from(self.variable());
        let signed = if self.polarity() { magnitude } else { -magnitude };
        debug_assert!(
            i32::try_from(signed).is_ok(),
            "literal value {signed} does not fit in an i32"
        );
        signed as i32
    }

    /// Dense index `2 * var + polarity`, suitable for per-literal tables.
    ///
    /// The negated literal gets the even slot and the positive literal the odd
    /// slot. The arithmetic wraps if `2 * var + 1` overflows `usize`, which is
    /// only possible when `usize` is 32 bits wide.
    fn index(self) -> usize {
        let polarity_bit = usize::from(self.polarity());
        let var_usize = self.variable() as usize;
        var_usize.wrapping_mul(2).wrapping_add(polarity_bit)
    }

    /// Inverse of [`Literal::index`].
    ///
    /// Variable numbers that do not fit in a [`Variable`] are truncated.
    #[must_use]
    fn from_index(index: usize) -> Self {
        let polarity = (index % 2) != 0;
        #[allow(clippy::cast_possible_truncation)] // documented truncation
        let var = (index / 2) as Variable;
        Self::new(var, polarity)
    }
}
/// Literal packed into a single `u32`.
///
/// The low 31 bits hold the variable and the most significant bit holds the
/// polarity (set = positive). Because the derived `Ord` compares the raw
/// `u32`, every negated literal sorts before every positive one. Variables
/// wider than 31 bits are silently truncated by [`Literal::new`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, Default)]
pub struct PackedLiteral(u32);

/// Mask selecting the variable bits of a [`PackedLiteral`].
const VAR_MASK: u32 = 0x7FFF_FFFF;
/// Bit position of the polarity flag in a [`PackedLiteral`].
const LSHIFT: u32 = 31;

impl Literal for PackedLiteral {
    #[inline]
    fn new(var: Variable, polarity: bool) -> Self {
        let polarity_bit: u32 = if polarity { 1 << LSHIFT } else { 0 };
        Self(polarity_bit | (var & VAR_MASK))
    }

    #[inline]
    fn variable(self) -> Variable {
        VAR_MASK & self.0
    }

    #[inline]
    fn polarity(self) -> bool {
        // Everything outside VAR_MASK is the single polarity bit.
        (self.0 & !VAR_MASK) != 0
    }

    #[inline]
    fn negated(self) -> Self {
        let flag: u32 = 1 << LSHIFT;
        Self(self.0 ^ flag)
    }
}
/// Literal stored as an explicit `(variable, polarity)` pair.
///
/// This is the simplest encoding and places no restriction on the variable
/// range.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, Default)]
pub struct StructLiteral {
    value: u32,
    polarity: bool,
}

impl Literal for StructLiteral {
    #[inline]
    fn new(var: Variable, polarity: bool) -> Self {
        Self { polarity, value: var }
    }

    #[inline]
    fn variable(self) -> Variable {
        let Self { value, .. } = self;
        value
    }

    #[inline]
    fn polarity(self) -> bool {
        let Self { polarity, .. } = self;
        polarity
    }

    #[inline]
    fn negated(self) -> Self {
        // Keep the variable and flip only the polarity flag.
        Self {
            polarity: !self.polarity,
            ..self
        }
    }
}
/// Literal encoded as `2 * var + polarity` in a `u32`.
///
/// The raw value is directly the dense [`Literal::index`]: the negated
/// literal is even and the positive literal is odd. Variables above
/// `u32::MAX / 2` lose their top bit in [`Literal::new`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, Default)]
pub struct DoubleLiteral(u32);

impl Literal for DoubleLiteral {
    #[inline]
    fn new(var: Variable, polarity: bool) -> Self {
        // `<< 1` leaves bit 0 clear, so OR-ing in the polarity equals adding it.
        Self((var << 1) | u32::from(polarity))
    }

    #[inline]
    fn variable(self) -> Variable {
        self.0 >> 1
    }

    #[inline]
    fn polarity(self) -> bool {
        (self.0 & 1) == 1
    }

    #[inline]
    fn negated(self) -> Self {
        Self(self.0 ^ 1)
    }

    /// The raw encoding already is `2 * var + polarity`, so no arithmetic is needed.
    #[inline]
    fn index(self) -> usize {
        self.0 as usize
    }
}
/// Literal stored as a DIMACS-style signed integer: `var` when positive,
/// `-var` when negated.
///
/// Limitations of this encoding:
/// - Variable `0` cannot carry a polarity. `new(0, true)` reads back as
///   negated, and [`Literal::negated`] leaves it unchanged.
/// - Variables above `i32::MAX` are not representable, because the `as i32`
///   cast wraps them.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, Default)]
pub struct NegativeLiteral(i32);

impl Literal for NegativeLiteral {
    #[inline]
    fn new(var: Variable, polarity: bool) -> Self {
        #[allow(clippy::cast_possible_wrap)] // variables are expected to fit in 31 bits
        let magnitude = var as i32;
        if polarity {
            Self(magnitude)
        } else {
            Self(-magnitude)
        }
    }

    #[inline]
    fn variable(self) -> Variable {
        self.0.unsigned_abs()
    }

    #[inline]
    fn polarity(self) -> bool {
        self.0 > 0
    }

    #[inline]
    fn negated(self) -> Self {
        Self(-self.0)
    }
}
/// Re-encodes `lit` in another [`Literal`] representation.
///
/// The variable and polarity are preserved, subject to the range limits of
/// the target encoding (e.g. [`PackedLiteral`] masks the variable to 31 bits).
#[must_use]
pub fn convert<T: Literal, U: Literal>(lit: &T) -> U {
    U::new(lit.variable(), lit.polarity())
}
/// Runtime-selectable literal that wraps one of the concrete encodings.
///
/// The derived `PartialEq`/`Ord`/`Hash` compare the variant first. The same
/// `(variable, polarity)` pair held in two different encodings is therefore
/// *not* equal.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum LiteralImpls {
    Packed(PackedLiteral),
    Struct(StructLiteral),
    Double(DoubleLiteral),
    Negative(NegativeLiteral),
}

impl Default for LiteralImpls {
    /// The zero literal in the [`DoubleLiteral`] encoding.
    fn default() -> Self {
        Self::Double(Default::default())
    }
}

impl Literal for LiteralImpls {
    /// Always builds the [`PackedLiteral`] encoding.
    ///
    /// The provided `from_i32`/`from_index` methods and `convert` into this
    /// type all go through `new`, so they also yield `Packed` values.
    ///
    /// NOTE(review): `Default` yields `Double` instead; confirm which encoding
    /// callers expect from `new`.
    fn new(var: Variable, polarity: bool) -> Self {
        Self::Packed(PackedLiteral::new(var, polarity))
    }

    fn variable(self) -> Variable {
        match self {
            Self::Packed(inner) => inner.variable(),
            Self::Struct(inner) => inner.variable(),
            Self::Double(inner) => inner.variable(),
            Self::Negative(inner) => inner.variable(),
        }
    }

    fn polarity(self) -> bool {
        match self {
            Self::Packed(inner) => inner.polarity(),
            Self::Struct(inner) => inner.polarity(),
            Self::Double(inner) => inner.polarity(),
            Self::Negative(inner) => inner.polarity(),
        }
    }

    /// Negates within the same encoding; the variant is preserved.
    fn negated(self) -> Self {
        match self {
            Self::Packed(inner) => Self::Packed(inner.negated()),
            Self::Struct(inner) => Self::Struct(inner.negated()),
            Self::Double(inner) => Self::Double(inner.negated()),
            Self::Negative(inner) => Self::Negative(inner.negated()),
        }
    }
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, Default, ValueEnum)]
pub enum LiteralType {
Packed,
Struct,
#[default]
Double,
Negative,
}
impl Display for LiteralType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Packed => write!(f, "packed"),
Self::Struct => write!(f, "struct"),
Self::Double => write!(f, "double"),
Self::Negative => write!(f, "negative"),
}
}
}
impl LiteralType {
#[allow(dead_code)]
#[must_use]
pub fn to_impl(self, var: Variable, polarity: bool) -> LiteralImpls {
match self {
Self::Packed => LiteralImpls::Packed(PackedLiteral::new(var, polarity)),
Self::Struct => LiteralImpls::Struct(StructLiteral::new(var, polarity)),
Self::Double => LiteralImpls::Double(DoubleLiteral::new(var, polarity)),
Self::Negative => LiteralImpls::Negative(NegativeLiteral::new(var, polarity)),
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks the whole [`Literal`] contract for one `(var_id, polarity)` pair.
    ///
    /// Covers the accessors, negation, the DIMACS `i32` round trip and the
    /// dense-index round trip.
    ///
    /// Callers must not pass `var_id == 0` with positive polarity: `to_i32`
    /// maps it to `0`, which `from_i32` decodes as negated.
    fn test_literal_implementation<L: Literal>(var_id: Variable, initial_polarity: bool) {
        let lit = L::new(var_id, initial_polarity);
        assert_eq!(lit.variable(), var_id, "Variable ID mismatch");
        assert_eq!(
            lit.polarity(),
            initial_polarity,
            "Initial polarity mismatch"
        );
        if initial_polarity {
            assert!(lit.is_positive(), "Should be positive");
            assert!(!lit.is_negated(), "Should not be negated");
        } else {
            assert!(!lit.is_positive(), "Should not be positive");
            assert!(lit.is_negated(), "Should be negated");
        }

        let neg_lit = lit.negated();
        assert_eq!(
            neg_lit.variable(),
            var_id,
            "Variable ID mismatch after negation"
        );
        assert_eq!(
            neg_lit.polarity(),
            !initial_polarity,
            "Polarity should flip after negation"
        );

        let double_neg_lit = neg_lit.negated();
        assert_eq!(
            double_neg_lit.variable(),
            var_id,
            "Variable ID mismatch after double negation"
        );
        assert_eq!(
            double_neg_lit.polarity(),
            initial_polarity,
            "Polarity should revert after double negation"
        );
        assert_eq!(
            double_neg_lit, lit,
            "Double negation should return original literal"
        );

        let i32_val = lit.to_i32();
        let lit_from_i32 = L::from_i32(i32_val);
        assert_eq!(
            lit_from_i32, lit,
            "from_i32(to_i32(lit)) should be lit. Got: L={lit:?}, i32={i32_val}, L'={lit_from_i32:?}"
        );

        // Build the expected DIMACS value in i64 so no lossy casts (and no
        // blanket clippy allows) are needed.
        let magnitude = i64::from(var_id);
        let expected_i32_val = if initial_polarity { magnitude } else { -magnitude };
        assert_eq!(
            i64::from(i32_val),
            expected_i32_val,
            "to_i32() value incorrect. Expected {expected_i32_val}, Got {i32_val}"
        );

        // DoubleLiteral is the one encoding that overrides `index`, so it must
        // be covered by this round trip rather than skipped.
        let lit_from_idx = L::from_index(lit.index());
        assert_eq!(lit_from_idx, lit, "from_index(index(lit)) should be lit");
    }

    #[test]
    fn test_all_literal_implementations() {
        let test_cases = [
            (1u32, true),
            (1u32, false),
            (VAR_MASK, true),
            (VAR_MASK, false),
            (12345u32, true),
            (67890u32, false),
        ];
        for &(var_id, polarity) in &test_cases {
            // PackedLiteral only keeps 31 bits of the variable.
            let packed_var_id = var_id & VAR_MASK;
            test_literal_implementation::<PackedLiteral>(packed_var_id, polarity);
            test_literal_implementation::<StructLiteral>(var_id, polarity);
            test_literal_implementation::<DoubleLiteral>(var_id, polarity);
            // NegativeLiteral cannot encode a polarity for variable 0.
            if var_id != 0 {
                test_literal_implementation::<NegativeLiteral>(var_id, polarity);
            }
        }
    }

    #[test]
    fn test_double_literal_index() {
        let lit_pos = DoubleLiteral::new(5, true);
        assert_eq!(lit_pos.index(), 11);
        assert_eq!(DoubleLiteral::from_index(11), lit_pos);
        let lit_neg = DoubleLiteral::new(5, false);
        assert_eq!(lit_neg.index(), 10);
        assert_eq!(DoubleLiteral::from_index(10), lit_neg);
    }

    #[test]
    fn test_literal_negation_consistency() {
        assert_eq!(
            PackedLiteral::new(1, false).negated().negated(),
            PackedLiteral::new(1, false)
        );
        assert_eq!(
            PackedLiteral::new(1, true).negated().negated(),
            PackedLiteral::new(1, true)
        );
        assert_eq!(
            StructLiteral::new(1, false).negated().negated(),
            StructLiteral::new(1, false)
        );
        assert_eq!(
            StructLiteral::new(1, true).negated().negated(),
            StructLiteral::new(1, true)
        );
        assert_eq!(
            DoubleLiteral::new(1, false).negated().negated(),
            DoubleLiteral::new(1, false)
        );
        assert_eq!(
            DoubleLiteral::new(1, true).negated().negated(),
            DoubleLiteral::new(1, true)
        );
        assert_eq!(
            NegativeLiteral::new(1, false).negated().negated(),
            NegativeLiteral::new(1, false)
        );
        assert_eq!(
            NegativeLiteral::new(1, true).negated().negated(),
            NegativeLiteral::new(1, true)
        );
    }

    #[test]
    fn test_conversion_function() {
        let p_lit = PackedLiteral::new(10, true);
        let s_lit: StructLiteral = convert(&p_lit);
        assert_eq!(s_lit.variable(), 10);
        assert!(s_lit.polarity());
        let d_lit: DoubleLiteral = convert(&s_lit);
        assert_eq!(d_lit.variable(), 10);
        assert!(d_lit.polarity());
        let n_lit: NegativeLiteral = convert(&d_lit);
        assert_eq!(n_lit.variable(), 10);
        assert!(n_lit.polarity());
        let p_lit_again: PackedLiteral = convert(&n_lit);
        assert_eq!(p_lit_again, p_lit);
    }
}