use crate::eval::value::Value;
use crate::containers::{Container, ContainerError, ContainerResult};
use crate::containers::comparator::HashComparator;
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
/// A multiset ("bag") of [`Value`]s: each distinct value is stored once,
/// together with the number of times it has been added.
#[derive(Debug, Clone)]
pub struct Bag {
    // Distinct value -> occurrence count. The mutating methods never leave an
    // entry with a count of 0: zero-count adjoins are ignored and removing the
    // last occurrence removes the key, so `is_empty` can rely on `len() == 0`.
    elements: HashMap<Value, usize>,
    // Forwarded to derived containers (`to_set` and set-operation results) and
    // exposed via `comparator()`; lookups in `elements` use `Value`'s own
    // `Hash`/`Eq`, not this comparator.
    comparator: HashComparator,
    // Optional label; ignored by `PartialEq`, which compares `elements` only.
    debug_name: Option<String>,
}
impl Bag {
pub fn new() -> Self {
Self {
elements: HashMap::new(),
comparator: HashComparator::with_default(),
debug_name: None,
}
}
pub fn with_comparator(comparator: HashComparator) -> Self {
Self {
elements: HashMap::new(),
comparator,
debug_name: None,
}
}
pub fn from_iter_with_comparator<I>(iter: I, comparator: HashComparator) -> Self
where
I: IntoIterator<Item = Value>,
{
let mut bag = Self::with_comparator(comparator);
for value in iter {
bag.adjoin(value);
}
bag
}
pub fn adjoin(&mut self, value: Value) -> usize {
let count = self.elements.entry(value).or_insert(0);
*count += 1;
*count
}
pub fn adjoin_count(&mut self, value: Value, count: usize) -> usize {
if count == 0 {
return self.element_count(&value);
}
let current_count = self.elements.entry(value).or_insert(0);
*current_count += count;
*current_count
}
pub fn delete(&mut self, value: &Value) -> bool {
if let Some(count) = self.elements.get_mut(value) {
if *count > 1 {
*count -= 1;
true
} else {
self.elements.remove(value);
true
}
} else {
false
}
}
pub fn delete_all(&mut self, value: &Value) -> usize {
self.elements.remove(value).unwrap_or(0)
}
pub fn contains(&self, value: &Value) -> bool {
self.elements.get(value).is_some_and(|&count| count > 0)
}
pub fn element_count(&self, value: &Value) -> usize {
self.elements.get(value).copied().unwrap_or(0)
}
pub fn unique_size(&self) -> usize {
self.elements.len()
}
pub fn total_size(&self) -> usize {
self.elements.values().sum()
}
pub fn size(&self) -> usize {
self.total_size()
}
pub fn is_empty(&self) -> bool {
self.elements.is_empty()
}
pub fn unique_iter(&self) -> impl Iterator<Item = &Value> {
self.elements.keys()
}
pub fn iter_with_counts(&self) -> impl Iterator<Item = (&Value, &usize)> {
self.elements.iter()
}
pub fn to_vec(&self) -> Vec<Value> {
let mut result = Vec::with_capacity(self.total_size());
for (value, &count) in &self.elements {
for _ in 0..count {
result.push(value.clone());
}
}
result
}
pub fn unique_to_vec(&self) -> Vec<Value> {
self.elements.keys().cloned().collect()
}
pub fn to_set(&self) -> crate::containers::set::Set {
crate::containers::set::Set::from_iter_with_comparator(
self.elements.keys().cloned(),
self.comparator.clone()
)
}
pub fn comparator(&self) -> &HashComparator {
&self.comparator
}
pub fn set_debug_name(&mut self, name: impl Into<String>) {
self.debug_name = Some(name.into());
}
pub fn debug_name(&self) -> Option<&str> {
self.debug_name.as_deref()
}
pub fn clear_debug_name(&mut self) {
self.debug_name = None;
}
pub fn increment(&mut self, value: &Value) -> usize {
let count = self.elements.entry(value.clone()).or_insert(0);
*count += 1;
*count
}
pub fn decrement(&mut self, value: &Value) -> usize {
if let Some(count) = self.elements.get_mut(value) {
if *count > 1 {
*count -= 1;
*count
} else {
self.elements.remove(value);
0
}
} else {
0
}
}
pub fn union(&self, other: &Bag) -> Bag {
let mut result = self.clone();
for (value, &count) in &other.elements {
result.adjoin_count(value.clone(), count);
}
result
}
pub fn intersection(&self, other: &Bag) -> Bag {
let mut result = Bag::with_comparator(self.comparator.clone());
for (value, &self_count) in &self.elements {
if let Some(&other_count) = other.elements.get(value) {
let min_count = self_count.min(other_count);
if min_count > 0 {
result.adjoin_count(value.clone(), min_count);
}
}
}
result
}
pub fn difference(&self, other: &Bag) -> Bag {
let mut result = Bag::with_comparator(self.comparator.clone());
for (value, &self_count) in &self.elements {
let other_count = other.element_count(value);
if self_count > other_count {
result.adjoin_count(value.clone(), self_count - other_count);
}
}
result
}
pub fn sum(&self, other: &Bag) -> Bag {
self.union(other)
}
pub fn product(&self, other: &Bag) -> Bag {
let mut result = Bag::with_comparator(self.comparator.clone());
for (value, &self_count) in &self.elements {
if let Some(&other_count) = other.elements.get(value) {
let product_count = self_count * other_count;
if product_count > 0 {
result.adjoin_count(value.clone(), product_count);
}
}
}
result
}
pub fn is_subbag(&self, other: &Bag) -> bool {
self.elements.iter().all(|(value, &count)| {
other.element_count(value) >= count
})
}
pub fn is_disjoint(&self, other: &Bag) -> bool {
!self.elements.keys().any(|value| other.contains(value))
}
}
impl Default for Bag {
fn default() -> Self {
Self::new()
}
}
impl Container for Bag {
    /// Length counts every occurrence, duplicates included (same as `Bag::size`).
    fn len(&self) -> usize {
        self.size()
    }

    /// Removes all occurrences; the comparator and debug name are kept.
    fn clear(&mut self) {
        self.elements.clear()
    }
}
/// Bags compare equal when they hold the same values with the same counts;
/// the comparator and debug name are not compared.
impl PartialEq for Bag {
    fn eq(&self, other: &Self) -> bool {
        let (mine, theirs) = (&self.elements, &other.elements);
        mine.eq(theirs)
    }
}

/// Count-map equality is reflexive, so `Bag` is `Eq`.
impl Eq for Bag {}
/// A [`Bag`] behind an `Arc<RwLock<_>>` so it can be shared between threads.
///
/// `Clone` is derived over the `Arc`, so a clone is another handle to the
/// same underlying bag, not a copy. Every accessor returns a
/// `ContainerResult` because acquiring the lock fails if it is poisoned.
#[derive(Debug, Clone)]
pub struct ThreadSafeBag {
    inner: Arc<RwLock<Bag>>,
}
impl ThreadSafeBag {
pub fn new() -> Self {
Self {
inner: Arc::new(RwLock::new(Bag::new())),
}
}
pub fn with_comparator(comparator: HashComparator) -> Self {
Self {
inner: Arc::new(RwLock::new(Bag::with_comparator(comparator))),
}
}
pub fn adjoin(&self, value: Value) -> ContainerResult<usize> {
Ok(self
.inner
.write()
.map_err(|_| ContainerError::InvalidComparator {
message: "Failed to acquire write lock".to_string(),
})?
.adjoin(value))
}
pub fn adjoin_count(&self, value: Value, count: usize) -> ContainerResult<usize> {
Ok(self
.inner
.write()
.map_err(|_| ContainerError::InvalidComparator {
message: "Failed to acquire write lock".to_string(),
})?
.adjoin_count(value, count))
}
pub fn delete(&self, value: &Value) -> ContainerResult<bool> {
Ok(self
.inner
.write()
.map_err(|_| ContainerError::InvalidComparator {
message: "Failed to acquire write lock".to_string(),
})?
.delete(value))
}
pub fn contains(&self, value: &Value) -> ContainerResult<bool> {
Ok(self
.inner
.read()
.map_err(|_| ContainerError::InvalidComparator {
message: "Failed to acquire read lock".to_string(),
})?
.contains(value))
}
pub fn element_count(&self, value: &Value) -> ContainerResult<usize> {
Ok(self
.inner
.read()
.map_err(|_| ContainerError::InvalidComparator {
message: "Failed to acquire read lock".to_string(),
})?
.element_count(value))
}
pub fn unique_size(&self) -> ContainerResult<usize> {
Ok(self
.inner
.read()
.map_err(|_| ContainerError::InvalidComparator {
message: "Failed to acquire read lock".to_string(),
})?
.unique_size())
}
pub fn total_size(&self) -> ContainerResult<usize> {
Ok(self
.inner
.read()
.map_err(|_| ContainerError::InvalidComparator {
message: "Failed to acquire read lock".to_string(),
})?
.total_size())
}
pub fn size(&self) -> ContainerResult<usize> {
self.total_size()
}
pub fn is_empty(&self) -> ContainerResult<bool> {
Ok(self
.inner
.read()
.map_err(|_| ContainerError::InvalidComparator {
message: "Failed to acquire read lock".to_string(),
})?
.is_empty())
}
pub fn to_vec(&self) -> ContainerResult<Vec<Value>> {
Ok(self
.inner
.read()
.map_err(|_| ContainerError::InvalidComparator {
message: "Failed to acquire read lock".to_string(),
})?
.to_vec())
}
pub fn increment(&self, value: &Value) -> ContainerResult<usize> {
Ok(self
.inner
.write()
.map_err(|_| ContainerError::InvalidComparator {
message: "Failed to acquire write lock".to_string(),
})?
.increment(value))
}
pub fn decrement(&self, value: &Value) -> ContainerResult<usize> {
Ok(self
.inner
.write()
.map_err(|_| ContainerError::InvalidComparator {
message: "Failed to acquire write lock".to_string(),
})?
.decrement(value))
}
pub fn clear(&self) -> ContainerResult<()> {
self.inner
.write()
.map_err(|_| ContainerError::InvalidComparator {
message: "Failed to acquire write lock".to_string(),
})?
.clear();
Ok(())
}
pub fn set_debug_name(&self, name: impl Into<String>) -> ContainerResult<()> {
self.inner
.write()
.map_err(|_| ContainerError::InvalidComparator {
message: "Failed to acquire write lock".to_string(),
})?
.set_debug_name(name);
Ok(())
}
pub fn debug_name(&self) -> ContainerResult<Option<String>> {
Ok(self
.inner
.read()
.map_err(|_| ContainerError::InvalidComparator {
message: "Failed to acquire read lock".to_string(),
})?
.debug_name()
.map(|s| s.to_string()))
}
pub fn union(&self, other: &ThreadSafeBag) -> ContainerResult<ThreadSafeBag> {
let self_bag = self
.inner
.read()
.map_err(|_| ContainerError::InvalidComparator {
message: "Failed to acquire read lock".to_string(),
})?;
let other_bag = other
.inner
.read()
.map_err(|_| ContainerError::InvalidComparator {
message: "Failed to acquire read lock".to_string(),
})?;
let result = self_bag.union(&other_bag);
Ok(ThreadSafeBag {
inner: Arc::new(RwLock::new(result)),
})
}
pub fn intersection(&self, other: &ThreadSafeBag) -> ContainerResult<ThreadSafeBag> {
let self_bag = self
.inner
.read()
.map_err(|_| ContainerError::InvalidComparator {
message: "Failed to acquire read lock".to_string(),
})?;
let other_bag = other
.inner
.read()
.map_err(|_| ContainerError::InvalidComparator {
message: "Failed to acquire read lock".to_string(),
})?;
let result = self_bag.intersection(&other_bag);
Ok(ThreadSafeBag {
inner: Arc::new(RwLock::new(result)),
})
}
pub fn difference(&self, other: &ThreadSafeBag) -> ContainerResult<ThreadSafeBag> {
let self_bag = self
.inner
.read()
.map_err(|_| ContainerError::InvalidComparator {
message: "Failed to acquire read lock".to_string(),
})?;
let other_bag = other
.inner
.read()
.map_err(|_| ContainerError::InvalidComparator {
message: "Failed to acquire read lock".to_string(),
})?;
let result = self_bag.difference(&other_bag);
Ok(ThreadSafeBag {
inner: Arc::new(RwLock::new(result)),
})
}
pub fn product(&self, other: &ThreadSafeBag) -> ContainerResult<ThreadSafeBag> {
let self_bag = self
.inner
.read()
.map_err(|_| ContainerError::InvalidComparator {
message: "Failed to acquire read lock".to_string(),
})?;
let other_bag = other
.inner
.read()
.map_err(|_| ContainerError::InvalidComparator {
message: "Failed to acquire read lock".to_string(),
})?;
let result = self_bag.product(&other_bag);
Ok(ThreadSafeBag {
inner: Arc::new(RwLock::new(result)),
})
}
pub fn is_subbag(&self, other: &ThreadSafeBag) -> ContainerResult<bool> {
let self_bag = self
.inner
.read()
.map_err(|_| ContainerError::InvalidComparator {
message: "Failed to acquire read lock".to_string(),
})?;
let other_bag = other
.inner
.read()
.map_err(|_| ContainerError::InvalidComparator {
message: "Failed to acquire read lock".to_string(),
})?;
Ok(self_bag.is_subbag(&other_bag))
}
}
impl Default for ThreadSafeBag {
fn default() -> Self {
Self::new()
}
}
impl std::iter::FromIterator<Value> for Bag {
fn from_iter<I: IntoIterator<Item = Value>>(iter: I) -> Self {
let mut bag = Self::new();
for value in iter {
bag.adjoin(value);
}
bag
}
}
impl std::iter::FromIterator<Value> for ThreadSafeBag {
fn from_iter<I: IntoIterator<Item = Value>>(iter: I) -> Self {
Self {
inner: Arc::new(RwLock::new(iter.into_iter().collect())),
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_bag_new() {
        let bag = Bag::new();
        assert!(bag.is_empty());
        assert_eq!(bag.total_size(), 0);
        assert_eq!(bag.unique_size(), 0);
    }

    #[test]
    fn test_bag_adjoin() {
        let mut bag = Bag::new();
        let v1 = Value::number(42.0);
        let v2 = Value::string("hello");
        assert_eq!(bag.adjoin(v1.clone()), 1);
        assert_eq!(bag.total_size(), 1);
        assert_eq!(bag.unique_size(), 1);
        assert!(bag.contains(&v1));
        assert_eq!(bag.element_count(&v1), 1);
        assert_eq!(bag.adjoin(v1.clone()), 2);
        assert_eq!(bag.total_size(), 2);
        assert_eq!(bag.unique_size(), 1);
        assert_eq!(bag.element_count(&v1), 2);
        assert_eq!(bag.adjoin(v2.clone()), 1);
        assert_eq!(bag.total_size(), 3);
        assert_eq!(bag.unique_size(), 2);
        assert!(bag.contains(&v2));
    }

    #[test]
    fn test_bag_adjoin_count_zero_is_noop() {
        let mut bag = Bag::new();
        let v1 = Value::number(7.0);
        assert_eq!(bag.adjoin_count(v1.clone(), 0), 0);
        assert!(bag.is_empty());
        assert!(!bag.contains(&v1));
        bag.adjoin(v1.clone());
        assert_eq!(bag.adjoin_count(v1.clone(), 0), 1);
    }

    #[test]
    fn test_bag_delete() {
        let mut bag = Bag::new();
        let v1 = Value::number(42.0);
        bag.adjoin(v1.clone());
        bag.adjoin(v1.clone());
        bag.adjoin(v1.clone());
        assert_eq!(bag.element_count(&v1), 3);
        assert!(bag.delete(&v1));
        assert_eq!(bag.element_count(&v1), 2);
        assert_eq!(bag.total_size(), 2);
        assert!(bag.delete(&v1));
        assert_eq!(bag.element_count(&v1), 1);
        assert!(bag.delete(&v1));
        assert_eq!(bag.element_count(&v1), 0);
        assert!(!bag.contains(&v1));
        assert_eq!(bag.total_size(), 0);
        assert!(!bag.delete(&v1));
    }

    #[test]
    fn test_bag_delete_all() {
        let mut bag = Bag::new();
        let v1 = Value::number(1.0);
        bag.adjoin_count(v1.clone(), 4);
        assert_eq!(bag.delete_all(&v1), 4);
        assert!(bag.is_empty());
        assert_eq!(bag.delete_all(&v1), 0);
    }

    #[test]
    fn test_bag_increment_decrement() {
        let mut bag = Bag::new();
        let v1 = Value::number(42.0);
        assert_eq!(bag.increment(&v1), 1);
        assert_eq!(bag.increment(&v1), 2);
        assert_eq!(bag.element_count(&v1), 2);
        assert_eq!(bag.decrement(&v1), 1);
        assert_eq!(bag.element_count(&v1), 1);
        assert_eq!(bag.decrement(&v1), 0);
        assert!(!bag.contains(&v1));
        assert!(bag.is_empty());
        assert_eq!(bag.decrement(&v1), 0);
    }

    #[test]
    fn test_bag_to_vec_repeats_elements() {
        let mut bag = Bag::new();
        let v1 = Value::number(1.0);
        let v2 = Value::number(2.0);
        bag.adjoin_count(v1.clone(), 3);
        bag.adjoin(v2.clone());
        let items = bag.to_vec();
        assert_eq!(items.len(), 4);
        assert_eq!(items.iter().filter(|&v| *v == v1).count(), 3);
        assert_eq!(items.iter().filter(|&v| *v == v2).count(), 1);
        assert_eq!(bag.unique_to_vec().len(), 2);
    }

    #[test]
    fn test_bag_operations() {
        let mut bag1 = Bag::new();
        let mut bag2 = Bag::new();
        let v1 = Value::number(1.0);
        let v2 = Value::number(2.0);
        let v3 = Value::number(3.0);
        bag1.adjoin_count(v1.clone(), 2);
        bag1.adjoin_count(v2.clone(), 3);
        bag2.adjoin_count(v2.clone(), 1);
        bag2.adjoin_count(v3.clone(), 2);

        let union = bag1.union(&bag2);
        assert_eq!(union.element_count(&v1), 2);
        assert_eq!(union.element_count(&v2), 4);
        assert_eq!(union.element_count(&v3), 2);
        assert_eq!(union.total_size(), 8);

        let intersection = bag1.intersection(&bag2);
        assert_eq!(intersection.element_count(&v1), 0);
        assert_eq!(intersection.element_count(&v2), 1);
        assert_eq!(intersection.element_count(&v3), 0);
        assert_eq!(intersection.total_size(), 1);

        let diff = bag1.difference(&bag2);
        assert_eq!(diff.element_count(&v1), 2);
        assert_eq!(diff.element_count(&v2), 2);
        assert_eq!(diff.element_count(&v3), 0);
        assert_eq!(diff.total_size(), 4);
    }

    #[test]
    fn test_bag_product() {
        let mut bag1 = Bag::new();
        let mut bag2 = Bag::new();
        let v1 = Value::number(1.0);
        let v2 = Value::number(2.0);
        bag1.adjoin_count(v1.clone(), 3);
        bag1.adjoin_count(v2.clone(), 2);
        bag2.adjoin_count(v1.clone(), 2);
        bag2.adjoin_count(v2.clone(), 4);
        let product = bag1.product(&bag2);
        assert_eq!(product.element_count(&v1), 6);
        assert_eq!(product.element_count(&v2), 8);
        assert_eq!(product.total_size(), 14);
    }

    #[test]
    fn test_bag_subbag() {
        let mut bag1 = Bag::new();
        let mut bag2 = Bag::new();
        let v1 = Value::number(1.0);
        let v2 = Value::number(2.0);
        bag1.adjoin_count(v1.clone(), 1);
        bag1.adjoin_count(v2.clone(), 2);
        bag2.adjoin_count(v1.clone(), 2);
        bag2.adjoin_count(v2.clone(), 3);
        assert!(bag1.is_subbag(&bag2));
        assert!(!bag2.is_subbag(&bag1));
    }

    #[test]
    fn test_bag_is_disjoint() {
        let mut bag1 = Bag::new();
        let mut bag2 = Bag::new();
        bag1.adjoin(Value::number(1.0));
        bag2.adjoin(Value::number(2.0));
        assert!(bag1.is_disjoint(&bag2));
        assert!(bag2.is_disjoint(&bag1));
        bag2.adjoin(Value::number(1.0));
        assert!(!bag1.is_disjoint(&bag2));
        assert!(!bag2.is_disjoint(&bag1));
    }

    #[test]
    fn test_bag_equality_ignores_insertion_order() {
        let a: Bag = vec![Value::number(1.0), Value::number(2.0), Value::number(1.0)]
            .into_iter()
            .collect();
        let b: Bag = vec![Value::number(2.0), Value::number(1.0), Value::number(1.0)]
            .into_iter()
            .collect();
        assert_eq!(a, b);
        let c: Bag = vec![Value::number(1.0), Value::number(2.0)].into_iter().collect();
        assert_ne!(a, c);
    }

    #[test]
    fn test_bag_to_set() {
        let mut bag = Bag::new();
        let v1 = Value::number(1.0);
        let v2 = Value::number(2.0);
        bag.adjoin_count(v1.clone(), 3);
        bag.adjoin_count(v2.clone(), 2);
        let set = bag.to_set();
        assert_eq!(set.size(), 2);
        assert!(set.contains(&v1));
        assert!(set.contains(&v2));
    }

    #[test]
    fn test_thread_safe_bag() {
        let bag = ThreadSafeBag::new();
        let v1 = Value::number(42.0);
        assert_eq!(bag.adjoin(v1.clone()).unwrap(), 1);
        assert_eq!(bag.total_size().unwrap(), 1);
        assert!(bag.contains(&v1).unwrap());
        assert_eq!(bag.increment(&v1).unwrap(), 2);
        assert_eq!(bag.element_count(&v1).unwrap(), 2);
        assert!(bag.delete(&v1).unwrap());
        assert_eq!(bag.element_count(&v1).unwrap(), 1);
    }

    #[test]
    fn test_thread_safe_bag_clone_shares_storage() {
        let bag = ThreadSafeBag::new();
        let shared = bag.clone();
        let v1 = Value::number(5.0);
        shared.adjoin(v1.clone()).unwrap();
        assert_eq!(bag.element_count(&v1).unwrap(), 1);
    }

    #[test]
    fn test_thread_safe_bag_operations() {
        let v1 = Value::number(1.0);
        let v2 = Value::number(2.0);
        let a: ThreadSafeBag = vec![v1.clone(), v1.clone(), v2.clone()].into_iter().collect();
        let b: ThreadSafeBag = vec![v2.clone()].into_iter().collect();

        let union = a.union(&b).unwrap();
        assert_eq!(union.total_size().unwrap(), 4);
        assert_eq!(union.element_count(&v2).unwrap(), 2);

        let intersection = a.intersection(&b).unwrap();
        assert_eq!(intersection.total_size().unwrap(), 1);

        let diff = a.difference(&b).unwrap();
        assert_eq!(diff.element_count(&v1).unwrap(), 2);
        assert_eq!(diff.element_count(&v2).unwrap(), 0);

        assert!(b.is_subbag(&a).unwrap());
        assert!(!a.is_subbag(&b).unwrap());
    }
}