use std::collections::BTreeSet;
use std::fmt;
use std::hash::Hash;
use smallvec::SmallVec;
/// Length of the inline array backing each `SetWeight`'s `SmallVec`
/// (`SmallVec<[T; INLINE_CAPACITY]>`).
const INLINE_CAPACITY: usize = 8;
/// A set of `T` values used as a weight, plus a distinguished "universe" value.
///
/// Finite sets are kept as a sorted, duplicate-free sequence; the constructors
/// and set operations in this file maintain that order and `contains` relies on
/// it for binary search. The universe is represented by a flag rather than by
/// an explicit list of elements.
#[derive(Clone)]
pub struct SetWeight<T>
where
    T: Clone + Eq + Ord + Hash + Send + Sync + 'static,
{
    /// Elements of a finite set, in ascending order without duplicates.
    /// The `universe()` constructor leaves this empty.
    elements: SmallVec<[T; INLINE_CAPACITY]>,
    /// `true` when this value stands for the universe; `elements` is then
    /// ignored by equality, hashing and membership tests.
    is_universe: bool,
}
impl<T> SetWeight<T>
where
    T: Clone + Eq + Ord + Hash + Send + Sync + 'static,
{
    /// Creates the empty set: no elements and not the universe.
    #[inline]
    pub fn empty() -> Self {
        Self {
            is_universe: false,
            elements: SmallVec::new(),
        }
    }

    /// Creates the universe value, which `contains` every element.
    #[inline]
    pub fn universe() -> Self {
        Self {
            is_universe: true,
            elements: SmallVec::new(),
        }
    }

    /// Creates a finite set holding exactly `element`.
    #[inline]
    pub fn singleton(element: T) -> Self {
        Self {
            elements: std::iter::once(element).collect(),
            is_universe: false,
        }
    }

    /// Builds a finite set from arbitrary items.
    ///
    /// The items are sorted and de-duplicated so that the stored sequence is
    /// strictly ascending, which the other methods depend on.
    pub fn from_iter(iter: impl IntoIterator<Item = T>) -> Self {
        let mut sorted: SmallVec<[T; INLINE_CAPACITY]> = iter.into_iter().collect();
        sorted.sort();
        sorted.dedup();
        Self {
            elements: sorted,
            is_universe: false,
        }
    }

    /// Builds a finite set from a `BTreeSet`.
    ///
    /// A `BTreeSet` already iterates in ascending order without duplicates,
    /// so no extra sorting is performed.
    pub fn from_set(set: BTreeSet<T>) -> Self {
        let elements = set.into_iter().collect();
        Self {
            elements,
            is_universe: false,
        }
    }

    /// Returns `true` only for a finite set with no elements.
    ///
    /// The universe is never empty, even though its `len()` is 0.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.elements.is_empty() && !self.is_universe
    }

    /// Returns `true` when this value is the universe.
    #[inline]
    pub fn is_universal(&self) -> bool {
        self.is_universe
    }

    /// Returns the number of explicitly stored elements.
    ///
    /// The universe reports 0 because it has no explicit element list; use
    /// `is_universal` to tell it apart from the empty set.
    #[inline]
    pub fn len(&self) -> usize {
        if self.is_universe {
            return 0;
        }
        self.elements.len()
    }

    /// Tests membership: always `true` for the universe, otherwise a binary
    /// search over the sorted elements.
    #[inline]
    pub fn contains(&self, element: &T) -> bool {
        self.is_universe || self.elements.binary_search(element).is_ok()
    }

    /// Iterates over the stored elements in ascending order.
    ///
    /// The universe has no stored elements, so its iterator yields nothing.
    pub fn iter(&self) -> impl Iterator<Item = &T> {
        self.elements.as_slice().iter()
    }

    /// Copies the elements of a finite set into a `BTreeSet`.
    ///
    /// # Panics
    ///
    /// Panics if this value is the universe, which cannot be enumerated.
    pub fn to_set(&self) -> BTreeSet<T> {
        assert!(!self.is_universe, "Cannot convert universe to finite set");
        self.elements.iter().cloned().collect()
    }

    /// Set union. The universe absorbs everything; two finite sets are
    /// combined with a linear merge of their sorted element sequences.
    fn set_union(&self, other: &Self) -> Self {
        use std::cmp::Ordering;

        if self.is_universe || other.is_universe {
            return Self::universe();
        }

        let capacity = self.elements.len() + other.elements.len();
        let mut merged: SmallVec<[T; INLINE_CAPACITY]> = SmallVec::with_capacity(capacity);
        let mut left = self.elements.iter().peekable();
        let mut right = other.elements.iter().peekable();

        // Walk both sorted sequences in lock-step, always emitting the smaller
        // head; equal heads are emitted once (taken from `self`).
        while let (Some(&a), Some(&b)) = (left.peek(), right.peek()) {
            match a.cmp(b) {
                Ordering::Less => {
                    merged.push(a.clone());
                    left.next();
                }
                Ordering::Greater => {
                    merged.push(b.clone());
                    right.next();
                }
                Ordering::Equal => {
                    merged.push(a.clone());
                    left.next();
                    right.next();
                }
            }
        }

        // At most one side still has elements; append whatever remains.
        merged.extend(left.cloned());
        merged.extend(right.cloned());

        Self {
            elements: merged,
            is_universe: false,
        }
    }

    /// Set intersection. The empty set annihilates, the universe is the
    /// identity, and two finite sets are intersected with a linear merge.
    fn set_intersection(&self, other: &Self) -> Self {
        use std::cmp::Ordering;

        if self.is_empty() || other.is_empty() {
            return Self::empty();
        }
        match (self.is_universe, other.is_universe) {
            (true, _) => return other.clone(),
            (false, true) => return self.clone(),
            (false, false) => {}
        }

        let mut common: SmallVec<[T; INLINE_CAPACITY]> = SmallVec::new();
        let mut left = self.elements.iter().peekable();
        let mut right = other.elements.iter().peekable();

        // Advance whichever side has the smaller head; keep only equal heads.
        while let (Some(&a), Some(&b)) = (left.peek(), right.peek()) {
            match a.cmp(b) {
                Ordering::Less => {
                    left.next();
                }
                Ordering::Greater => {
                    right.next();
                }
                Ordering::Equal => {
                    common.push(a.clone());
                    left.next();
                    right.next();
                }
            }
        }

        Self {
            elements: common,
            is_universe: false,
        }
    }
}
impl<T> PartialEq for SetWeight<T>
where
    T: Clone + Eq + Ord + Hash + Send + Sync + 'static,
{
    /// Two weights are equal when both are the universe, or when both are
    /// finite and hold the same elements. A universe never equals a finite set.
    fn eq(&self, other: &Self) -> bool {
        match (self.is_universe, other.is_universe) {
            (true, true) => true,
            (false, false) => self.elements == other.elements,
            _ => false,
        }
    }
}
// Marker impl: `Eq` adds no methods. It declares that the `PartialEq`
// comparison for `SetWeight` is a full equivalence relation, which relies on
// the element type itself being `Eq`.
impl<T: Clone + Eq + Ord + Hash + Send + Sync + 'static> Eq for SetWeight<T> {}
impl<T> Hash for SetWeight<T>
where
    T: Clone + Eq + Ord + Hash + Send + Sync + 'static,
{
    /// Feeds the universe flag and, for finite sets only, each element in
    /// stored order into `state`. Skipping the elements of a universe keeps
    /// hashing consistent with `PartialEq`, which treats all universes as equal.
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        self.is_universe.hash(state);
        if self.is_universe {
            return;
        }
        self.elements.iter().for_each(|element| element.hash(state));
    }
}
impl<T> fmt::Debug for SetWeight<T>
where
    T: Clone + Eq + Ord + Hash + fmt::Debug + Send + Sync + 'static,
{
    /// Formats as `SetWeight(Universe)` for the universe, otherwise as
    /// `SetWeight([..])` with the elements printed as a slice.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        if self.is_universe {
            return f.write_str("SetWeight(Universe)");
        }
        let items: &[T] = self.elements.as_slice();
        write!(f, "SetWeight({:?})", items)
    }
}
impl<T: Clone + Eq + Ord + Hash + Send + Sync + 'static> Default for SetWeight<T> {
fn default() -> Self {
Self::one()
}
}
impl<T: Clone + Eq + Ord + Hash + Send + Sync + 'static> SetWeight<T> {
#[inline]
pub fn zero() -> Self {
SetWeight::empty()
}
#[inline]
pub fn one() -> Self {
SetWeight::universe()
}
pub fn plus(&self, other: &Self) -> Self {
self.set_union(other)
}
pub fn times(&self, other: &Self) -> Self {
self.set_intersection(other)
}
#[inline]
pub fn is_zero(&self) -> bool {
self.is_empty()
}
#[inline]
pub fn is_one(&self) -> bool {
self.is_universal()
}
pub fn approx_eq(&self, other: &Self, _epsilon: f64) -> bool {
self == other
}
pub fn natural_less(&self, other: &Self) -> Option<bool> {
match (self.is_universe, other.is_universe) {
(true, true) => Some(false),
(true, false) => Some(false), (false, true) => Some(true), (false, false) => Some(self.elements.len() < other.elements.len()),
}
}
pub fn to_bytes(&self) -> Vec<u8> {
let mut bytes = Vec::new();
bytes.push(if self.is_universe { 1 } else { 0 });
if !self.is_universe {
bytes.extend((self.elements.len() as u64).to_le_bytes());
}
bytes
}
}
impl<T: Clone + Eq + Ord + Hash + Send + Sync + 'static> std::ops::Add for SetWeight<T> {
type Output = Self;
#[inline]
fn add(self, other: Self) -> Self {
self.plus(&other)
}
}
impl<T: Clone + Eq + Ord + Hash + Send + Sync + 'static> std::ops::Mul for SetWeight<T> {
type Output = Self;
#[inline]
fn mul(self, other: Self) -> Self {
self.times(&other)
}
}
impl<T> std::ops::AddAssign for SetWeight<T>
where
    T: Clone + Eq + Ord + Hash + Send + Sync + 'static,
{
    /// `a += b` replaces `a` with `a.plus(&b)`.
    #[inline]
    fn add_assign(&mut self, rhs: Self) {
        let sum = self.plus(&rhs);
        *self = sum;
    }
}
impl<T> std::ops::MulAssign for SetWeight<T>
where
    T: Clone + Eq + Ord + Hash + Send + Sync + 'static,
{
    /// `a *= b` replaces `a` with `a.times(&b)`.
    #[inline]
    fn mul_assign(&mut self, rhs: Self) {
        let product = self.times(&rhs);
        *self = product;
    }
}
/// `SetWeight` whose elements are owned `String`s.
pub type StringSetWeight = SetWeight<String>;
/// `SetWeight` whose elements are `&'static str` slices.
pub type StrSetWeight = SetWeight<&'static str>;
/// `SetWeight` whose elements are `u32` values.
// NOTE(review): presumably these are numeric feature identifiers — confirm with callers.
pub type FeatureSetWeight = SetWeight<u32>;
/// Unit and property tests for `SetWeight`: basic set behavior, the semiring
/// laws (identities, annihilation, idempotence, commutativity, distributivity)
/// and agreement of union/intersection with `HashSet` as a reference model.
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;
    use std::collections::HashSet;

    /// Union and intersection of two overlapping sets.
    #[test]
    fn test_basic_operations() {
        let a = SetWeight::from_iter(vec![1, 2, 3]);
        let b = SetWeight::from_iter(vec![2, 3, 4]);

        let union = a.plus(&b);
        assert_eq!(union.len(), 4);
        assert!(union.contains(&1));
        assert!(union.contains(&2));
        assert!(union.contains(&3));
        assert!(union.contains(&4));

        let intersection = a.times(&b);
        assert_eq!(intersection.len(), 2);
        assert!(!intersection.contains(&1));
        assert!(intersection.contains(&2));
        assert!(intersection.contains(&3));
        assert!(!intersection.contains(&4));
    }

    /// The empty set is the identity for `plus`; the universe for `times`.
    #[test]
    fn test_identities() {
        let a = SetWeight::from_iter(vec![1, 2, 3]);
        let empty: SetWeight<i32> = SetWeight::empty();
        let universe: SetWeight<i32> = SetWeight::universe();
        assert_eq!(a.plus(&empty), a);
        assert_eq!(empty.plus(&a), a);
        assert_eq!(a.times(&universe), a);
        assert_eq!(universe.times(&a), a);
    }

    /// Multiplying by the empty set yields zero on either side.
    #[test]
    fn test_annihilation() {
        let a = SetWeight::from_iter(vec![1, 2, 3]);
        let empty: SetWeight<i32> = SetWeight::empty();
        assert!(a.times(&empty).is_zero());
        assert!(empty.times(&a).is_zero());
    }

    /// Adding the universe yields the universe on either side.
    #[test]
    fn test_universe_absorption() {
        let a = SetWeight::from_iter(vec![1, 2, 3]);
        let universe: SetWeight<i32> = SetWeight::universe();
        assert!(a.plus(&universe).is_one());
        assert!(universe.plus(&a).is_one());
    }

    /// Both operations are idempotent.
    #[test]
    fn test_idempotence() {
        let a = SetWeight::from_iter(vec![1, 2, 3]);
        assert_eq!(a.plus(&a), a);
        assert_eq!(a.times(&a), a);
    }

    /// Both operations are commutative.
    #[test]
    fn test_commutativity() {
        let a = SetWeight::from_iter(vec![1, 2]);
        let b = SetWeight::from_iter(vec![2, 3]);
        assert_eq!(a.plus(&b), b.plus(&a));
        assert_eq!(a.times(&b), b.times(&a));
    }

    /// `times` distributes over `plus`.
    #[test]
    fn test_distributivity() {
        let a = SetWeight::from_iter(vec![1, 2]);
        let b = SetWeight::from_iter(vec![2, 3]);
        let c = SetWeight::from_iter(vec![3, 4]);
        let left = a.times(&b.plus(&c));
        let right = a.times(&b).plus(&a.times(&c));
        assert_eq!(left, right);
    }

    /// A singleton holds exactly its one element.
    #[test]
    fn test_singleton() {
        let a = SetWeight::singleton(42i32);
        assert_eq!(a.len(), 1);
        assert!(a.contains(&42));
        assert!(!a.contains(&0));
    }

    /// Round-trip through `BTreeSet`.
    #[test]
    fn test_from_set() {
        let mut set = BTreeSet::new();
        set.insert(1);
        set.insert(2);
        set.insert(3);
        let a = SetWeight::from_set(set.clone());
        assert_eq!(a.to_set(), set);
    }

    /// Smaller sets are less than larger ones, and everything finite is less
    /// than the universe.
    #[test]
    fn test_natural_ordering() {
        let small = SetWeight::from_iter(vec![1]);
        let medium = SetWeight::from_iter(vec![1, 2, 3]);
        let universe: SetWeight<i32> = SetWeight::universe();
        assert_eq!(small.natural_less(&medium), Some(true));
        assert_eq!(medium.natural_less(&small), Some(false));
        assert_eq!(small.natural_less(&universe), Some(true));
        assert_eq!(universe.natural_less(&small), Some(false));
    }

    /// The operations work for non-`Copy` element types.
    #[test]
    fn test_string_sets() {
        let a = SetWeight::from_iter(vec!["noun".to_string(), "verb".to_string()]);
        let b = SetWeight::from_iter(vec!["verb".to_string(), "adj".to_string()]);
        let union = a.plus(&b);
        assert_eq!(union.len(), 3);
        let intersection = a.times(&b);
        assert_eq!(intersection.len(), 1);
        assert!(intersection.contains(&"verb".to_string()));
    }

    proptest! {
        /// Semiring axioms on randomly generated finite sets.
        #[test]
        fn proptest_semiring_axioms(
            a_elems in prop::collection::vec(0u32..50, 0..5),
            b_elems in prop::collection::vec(0u32..50, 0..5),
            c_elems in prop::collection::vec(0u32..50, 0..5)
        ) {
            let a = SetWeight::from_iter(a_elems);
            let b = SetWeight::from_iter(b_elems);
            let c = SetWeight::from_iter(c_elems);
            let zero: SetWeight<u32> = SetWeight::zero();
            let one: SetWeight<u32> = SetWeight::one();
            // Addition: associative, commutative, zero is the identity.
            prop_assert_eq!(a.plus(&b).plus(&c), a.plus(&b.plus(&c)));
            prop_assert_eq!(a.plus(&b), b.plus(&a));
            prop_assert_eq!(a.plus(&zero), a.clone());
            // Multiplication: associative, one is the identity, zero annihilates.
            prop_assert_eq!(a.times(&b).times(&c), a.times(&b.times(&c)));
            prop_assert_eq!(a.times(&one), a.clone());
            prop_assert!(a.times(&zero).is_zero());
            // Left and right distributivity.
            prop_assert_eq!(a.times(&b.plus(&c)), a.times(&b).plus(&a.times(&c)));
            prop_assert_eq!(a.plus(&b).times(&c), a.times(&c).plus(&b.times(&c)));
        }

        /// Both operations are idempotent on random sets.
        #[test]
        fn proptest_idempotent(a_elems in prop::collection::vec(0u32..50, 0..5)) {
            let a = SetWeight::from_iter(a_elems);
            prop_assert_eq!(a.plus(&a), a.clone());
            prop_assert_eq!(a.times(&a), a);
        }

        /// Multiplication is commutative on random sets.
        #[test]
        fn proptest_commutative_times(
            a_elems in prop::collection::vec(0u32..50, 0..5),
            b_elems in prop::collection::vec(0u32..50, 0..5)
        ) {
            let a = SetWeight::from_iter(a_elems);
            let b = SetWeight::from_iter(b_elems);
            prop_assert_eq!(a.times(&b), b.times(&a));
        }

        /// A sum can only be zero when both operands are zero.
        #[test]
        fn proptest_zero_sum_free(
            a_elems in prop::collection::vec(0u32..50, 0..5),
            b_elems in prop::collection::vec(0u32..50, 0..5)
        ) {
            let a = SetWeight::from_iter(a_elems);
            let b = SetWeight::from_iter(b_elems);
            let sum = a.plus(&b);
            if sum.is_zero() {
                prop_assert!(a.is_zero(), "a should be zero when a ⊕ b = 0");
                prop_assert!(b.is_zero(), "b should be zero when a ⊕ b = 0");
            }
        }

        /// Union agrees with `HashSet::union` in both directions.
        #[test]
        fn proptest_union_correct(
            a_elems in prop::collection::vec(0u32..100, 0..10),
            b_elems in prop::collection::vec(0u32..100, 0..10)
        ) {
            let a = SetWeight::from_iter(a_elems.clone());
            let b = SetWeight::from_iter(b_elems.clone());
            let union = a.plus(&b);
            let a_hs: HashSet<_> = a_elems.iter().collect();
            let b_hs: HashSet<_> = b_elems.iter().collect();
            // Every expected element is present...
            for elem in a_hs.union(&b_hs) {
                prop_assert!(union.contains(elem));
            }
            // ...and nothing extra or duplicated was produced.
            prop_assert_eq!(union.len(), a_hs.union(&b_hs).count());
            for elem in union.iter() {
                prop_assert!(a.contains(elem) || b.contains(elem));
            }
            // The merge must keep elements strictly ascending, since
            // `contains` performs a binary search.
            prop_assert!(union.iter().zip(union.iter().skip(1)).all(|(x, y)| x < y));
        }

        /// Intersection agrees with `HashSet::intersection` in both directions.
        #[test]
        fn proptest_intersection_correct(
            a_elems in prop::collection::vec(0u32..100, 0..10),
            b_elems in prop::collection::vec(0u32..100, 0..10)
        ) {
            let a = SetWeight::from_iter(a_elems.clone());
            let b = SetWeight::from_iter(b_elems.clone());
            let intersection = a.times(&b);
            let a_hs: HashSet<_> = a_elems.iter().collect();
            let b_hs: HashSet<_> = b_elems.iter().collect();
            for elem in a_hs.intersection(&b_hs) {
                prop_assert!(intersection.contains(elem));
            }
            for elem in intersection.iter() {
                prop_assert!(a.contains(elem) && b.contains(elem));
            }
            // Guards against duplicated elements in the result.
            prop_assert_eq!(intersection.len(), a_hs.intersection(&b_hs).count());
        }

        /// The universe is a two-sided identity for multiplication.
        #[test]
        fn proptest_universe_identity(a_elems in prop::collection::vec(0u32..100, 0..10)) {
            let a = SetWeight::from_iter(a_elems);
            let universe: SetWeight<u32> = SetWeight::universe();
            prop_assert_eq!(a.times(&universe), a.clone());
            prop_assert_eq!(universe.times(&a), a);
        }
    }
}