use std::{
collections::{HashMap, HashSet},
hash::Hash,
};
use crate::CustomSet;
/// A multiset (bag): an unordered collection in which each distinct element
/// is stored once together with the number of times it occurs.
///
/// Trait bounds are placed on the `impl` blocks that need them rather than on
/// the struct itself, so the type can be named (and cloned) without repeating
/// `Eq + Hash` everywhere.
#[derive(Clone)]
pub struct MultiSet<T> {
    /// Maps each distinct element to its multiplicity (occurrence count).
    elements: HashMap<T, usize>,
}
impl<T: Eq + Hash + Clone> MultiSet<T> {
    /// Creates a multiset with no elements.
    pub fn empty() -> Self {
        Self {
            elements: HashMap::new(),
        }
    }

    /// Builds a multiset by adding each item yielded by `iter` once.
    ///
    /// Repeated items accumulate, e.g. `[a, a, b]` gives `a` multiplicity 2
    /// and `b` multiplicity 1.
    pub fn new<I: IntoIterator<Item = T>>(iter: I) -> Self {
        let mut multiset = Self::empty();
        for element in iter {
            multiset.add(element, 1);
        }
        multiset
    }

    /// Returns how many times `element` occurs; 0 if it is absent.
    pub fn multiplicity(&self, element: &T) -> usize {
        self.elements.get(element).copied().unwrap_or(0)
    }

    /// Returns the total number of elements, counting every repetition.
    pub fn cardinality(&self) -> usize {
        self.elements.values().sum()
    }

    /// Returns the number of distinct elements, ignoring multiplicity.
    pub fn unique_count(&self) -> usize {
        self.elements.len()
    }

    /// Returns `true` if the multiset holds no elements.
    pub fn is_empty(&self) -> bool {
        self.elements.is_empty()
    }

    /// Adds `count` occurrences of `element`.
    ///
    /// A `count` of 0 is a no-op, so `add` never creates a zero-count entry
    /// (which would otherwise skew `unique_count`, `is_empty` and equality).
    ///
    /// # Panics
    ///
    /// Panics if the resulting multiplicity would exceed `usize::MAX`. A plain
    /// `+=` only traps in debug builds and would silently wrap in release.
    pub fn add(&mut self, element: T, count: usize) {
        if count == 0 {
            return;
        }
        let slot = self.elements.entry(element).or_insert(0);
        *slot = slot
            .checked_add(count)
            .expect("MultiSet::add: multiplicity overflowed usize");
    }

    /// Removes up to `count` occurrences of `element`.
    ///
    /// If `count` is at least the current multiplicity, the element is removed
    /// entirely. Removing an absent element does nothing.
    pub fn remove(&mut self, element: &T, count: usize) {
        if let Some(current) = self.elements.get_mut(element) {
            if *current <= count {
                self.elements.remove(element);
            } else {
                *current -= count;
            }
        }
    }

    /// Returns the multiset union: each element's multiplicity is the maximum
    /// of its multiplicities in `self` and `other`.
    pub fn union(&self, other: &Self) -> Self {
        // `self`'s keys are chained first, so on duplicates the key value from
        // `self` is the one kept (HashSet::insert does not replace).
        let keys: HashSet<&T> = self.elements.keys().chain(other.elements.keys()).collect();
        let mut result = Self::empty();
        for key in keys {
            let count = self.multiplicity(key).max(other.multiplicity(key));
            result.add(key.clone(), count);
        }
        result
    }

    /// Returns the multiset intersection: each element's multiplicity is the
    /// minimum of its multiplicities in `self` and `other`.
    pub fn intersection(&self, other: &Self) -> Self {
        let mut result = Self::empty();
        for (key, &count) in &self.elements {
            if let Some(&other_count) = other.elements.get(key) {
                // `add` ignores a zero minimum.
                result.add(key.clone(), count.min(other_count));
            }
        }
        result
    }

    /// Returns `self` minus `other`: each element's multiplicity is reduced by
    /// its multiplicity in `other`, saturating at zero (zero-count elements are
    /// omitted).
    pub fn difference(&self, other: &Self) -> Self {
        let mut result = Self::empty();
        for (key, &count) in &self.elements {
            let remaining = count.saturating_sub(other.multiplicity(key));
            // Check before cloning so fully-cancelled keys cost nothing.
            if remaining > 0 {
                result.add(key.clone(), remaining);
            }
        }
        result
    }

    /// Returns the multiset sum: each element's multiplicity is the sum of its
    /// multiplicities in `self` and `other`.
    ///
    /// # Panics
    ///
    /// Panics if any combined multiplicity would exceed `usize::MAX`.
    pub fn sum(&self, other: &Self) -> Self {
        let mut result = self.clone();
        for (key, &count) in &other.elements {
            // Going through `add` gives overflow checking in every build profile.
            result.add(key.clone(), count);
        }
        result
    }

    /// Returns the distinct elements as a [`CustomSet`], discarding multiplicities.
    pub fn to_set(&self) -> CustomSet<T> {
        CustomSet::new(self.elements.keys().cloned())
    }

    /// Iterates over the distinct elements, each yielded once regardless of
    /// its multiplicity, in unspecified (hash map) order.
    pub fn iter(&self) -> impl Iterator<Item = &T> {
        self.elements.keys()
    }
}
impl<T: Eq + Hash + Clone> From<Vec<T>> for MultiSet<T> {
fn from(vec: Vec<T>) -> Self {
MultiSet::new(vec)
}
}
impl<T: Eq + Hash + Clone> FromIterator<T> for MultiSet<T> {
fn from_iter<I: IntoIterator<Item = T>>(iter: I) -> Self {
MultiSet::new(iter)
}
}
/// User-facing rendering: `∅` for an empty multiset, otherwise
/// `{elem:count, elem:count}` in the backing map's iteration order.
impl<T> std::fmt::Display for MultiSet<T>
where
    T: Eq + Hash + Clone + std::fmt::Display,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        if self.is_empty() {
            return write!(f, "∅");
        }
        // Stream each entry straight into the formatter instead of building
        // intermediate strings; the separator goes before every entry but the first.
        f.write_str("{")?;
        for (position, (element, count)) in self.elements.iter().enumerate() {
            if position > 0 {
                f.write_str(", ")?;
            }
            write!(f, "{}:{}", element, count)?;
        }
        f.write_str("}")
    }
}
/// Diagnostic rendering as `MultiSet({elem: count, ...})`.
impl<T: Eq + Hash + Clone + std::fmt::Debug> std::fmt::Debug for MultiSet<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `debug_tuple` yields the same compact output as before for `{:?}`,
        // and unlike a hand-written `write!` it honours the alternate flag so
        // `{:#?}` pretty-prints the inner map.
        f.debug_tuple("MultiSet").field(&self.elements).finish()
    }
}
impl<T: Eq + Hash + Clone> PartialEq for MultiSet<T> {
fn eq(&self, other: &Self) -> bool {
self.elements == other.elements
}
}
/// Multiset equality is a full equivalence relation: it compares `T: Eq`
/// keys and `usize` counts, both of which are themselves `Eq`.
impl<T> Eq for MultiSet<T> where T: Eq + Hash + Clone {}