use crate::merge::Merge;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use std::fmt;
/// Identifies a single `add`: the id of the node that performed it, paired with
/// that node's per-node sequence number at the time of the add.
type Tag = (String, u64);

/// Observed-remove set (OR-Set) of constraint ids with add-wins merge semantics.
///
/// Every `add` attaches a fresh tag to the constraint id; `remove` moves the tags
/// it has observed into `tombstones`. A constraint is present while its entry in
/// `elements` exists, and merging drops entries whose tags are all tombstoned.
///
/// `Default` is derived so the type satisfies the `new()`/`Default` convention;
/// the default value is the same empty set `new()` returns.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ConstraintORSet {
    /// Present constraint ids mapped to the add-tags observed for each.
    elements: HashMap<String, HashSet<Tag>>,
    /// Tags whose adds have been removed; consulted during merge so removed
    /// tags arriving from other replicas are not reinstated.
    tombstones: HashSet<Tag>,
    /// Highest sequence number known per node id; `add` increments it to mint tags.
    seq: HashMap<String, u64>,
}
impl ConstraintORSet {
    /// Creates an empty set with no elements, tombstones, or sequence counters.
    pub fn new() -> Self {
        Self {
            elements: HashMap::new(),
            tombstones: HashSet::new(),
            seq: HashMap::new(),
        }
    }

    /// Adds `constraint_id`, tagging this addition with a fresh `(node, seq)` tag.
    ///
    /// The per-node counter is incremented before the tag is built, so successive
    /// adds from the same `node` get distinct tags. Adding a constraint that is
    /// already present attaches an additional tag rather than being a no-op; that
    /// extra tag is what lets a re-add survive a concurrent remove of older tags.
    ///
    /// NOTE(review): tag uniqueness across replicas assumes each `node` id is used
    /// by only one replica — confirm with callers.
    pub fn add(&mut self, constraint_id: &str, node: &str) {
        let counter = self.seq.entry(node.to_string()).or_insert(0);
        *counter += 1;
        let tag = (node.to_string(), *counter);
        self.elements
            .entry(constraint_id.to_string())
            .or_default()
            .insert(tag);
    }

    /// Removes `constraint_id` by tombstoning every tag this replica has observed for it.
    ///
    /// Only locally observed tags are tombstoned, so an add made concurrently on
    /// another replica (carrying a tag unknown here) survives a later merge.
    /// Removing an absent constraint does nothing.
    pub fn remove(&mut self, constraint_id: &str) {
        // One lookup: take ownership of the tag set and move it into the tombstones.
        if let Some(tags) = self.elements.remove(constraint_id) {
            self.tombstones.extend(tags);
        }
    }

    /// Returns `true` if `constraint_id` is currently present.
    pub fn contains(&self, constraint_id: &str) -> bool {
        self.elements.contains_key(constraint_id)
    }

    /// Returns the ids of all present constraints, sorted ascending.
    pub fn active_constraints(&self) -> Vec<String> {
        let mut ids: Vec<String> = self.elements.keys().cloned().collect();
        // HashMap iteration order is unspecified and differs between replicas;
        // sorting makes equal sets report their constraints in the same order.
        ids.sort_unstable();
        ids
    }

    /// Returns the number of present constraints.
    pub fn len(&self) -> usize {
        self.elements.len()
    }

    /// Returns `true` if no constraint is present.
    pub fn is_empty(&self) -> bool {
        self.elements.is_empty()
    }

    /// Returns the number of tombstoned tags currently retained.
    pub fn tombstone_count(&self) -> usize {
        self.tombstones.len()
    }

    /// Discards tombstones minted by `node` whose sequence number is strictly
    /// less than `before_seq`.
    ///
    /// NOTE(review): dropping a tombstone is only safe once every replica has seen
    /// the corresponding remove. If another replica still holds a tag covered by a
    /// discarded tombstone, merging that replica in will re-add the constraint.
    /// Choosing a safe `before_seq` is left to the caller.
    pub fn gc_tombstones(&mut self, node: &str, before_seq: u64) {
        self.tombstones
            .retain(|(tag_node, tag_seq)| !(tag_node == node && *tag_seq < before_seq));
    }
}
impl Merge for ConstraintORSet {
    /// Joins `other` into `self`.
    ///
    /// The merge unions the element tags, tombstones, and per-node counters. It then
    /// removes every tombstoned tag from every entry and drops constraints left with
    /// no live tag.
    ///
    /// Pruning dead tags from surviving entries keeps `elements` and `tombstones`
    /// disjoint. If a dead tag stayed inside a surviving entry, it would count as
    /// live again once its tombstone is garbage-collected. It could then keep the
    /// constraint alive after a later remove of the entry's other tags.
    fn merge(&mut self, other: &Self) {
        // Union tombstones first so tags from both sides are filtered against the full set.
        self.tombstones.extend(other.tombstones.iter().cloned());

        for (constraint, tags) in &other.elements {
            self.elements
                .entry(constraint.clone())
                .or_default()
                .extend(tags.iter().cloned());
        }

        // Borrow the tombstone field on its own so `elements` can be mutated alongside it.
        let tombstones = &self.tombstones;
        self.elements.retain(|_, tags| {
            tags.retain(|tag| !tombstones.contains(tag));
            !tags.is_empty()
        });

        // Keep the highest counter seen per node so future local adds mint fresh tags.
        for (node, &other_seq) in &other.seq {
            let local = self.seq.entry(node.clone()).or_insert(0);
            *local = (*local).max(other_seq);
        }
    }
}
impl PartialEq for ConstraintORSet {
fn eq(&self, other: &Self) -> bool {
let self_keys: HashSet<_> = self.elements.keys().collect();
let other_keys: HashSet<_> = other.elements.keys().collect();
self_keys == other_keys
}
}
impl fmt::Display for ConstraintORSet {
    /// Writes a one-line summary: the present-constraint count and the tombstone count.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let active = self.len();
        let tombstones = self.tombstone_count();
        write!(
            f,
            "ConstraintORSet({} active, {} tombstones)",
            active, tombstones
        )
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::merge::laws;

    #[test]
    fn test_add_contains() {
        let mut s = ConstraintORSet::new();
        s.add("bounds_check", "node-a");
        assert!(s.contains("bounds_check"));
        assert_eq!(s.len(), 1);
    }

    #[test]
    fn test_remove() {
        let mut s = ConstraintORSet::new();
        s.add("bounds_check", "node-a");
        s.remove("bounds_check");
        assert!(!s.contains("bounds_check"));
    }

    #[test]
    fn test_remove_absent_is_noop() {
        let mut s = ConstraintORSet::new();
        s.remove("missing");
        assert!(s.is_empty());
        assert_eq!(s.tombstone_count(), 0);
    }

    #[test]
    fn test_add_wins_on_merge() {
        let mut a = ConstraintORSet::new();
        a.add("bounds_check", "node-a");
        let mut b = a.clone();
        b.remove("bounds_check");
        b.add("bounds_check", "node-b");
        let merged = a.merged(&b);
        assert!(merged.contains("bounds_check"),
            "add-wins: re-added constraint should survive");
    }

    #[test]
    fn test_remove_propagates_on_merge() {
        let mut a = ConstraintORSet::new();
        a.add("bounds_check", "node-a");
        let mut b = a.clone();
        b.remove("bounds_check");
        // The remove observed every tag `a` has, so it must win in both directions.
        assert!(!a.clone().merged(&b).contains("bounds_check"));
        assert!(!b.clone().merged(&a).contains("bounds_check"));
    }

    #[test]
    fn test_concurrent_readd_survives_remove() {
        let mut a = ConstraintORSet::new();
        a.add("c1", "node-a");
        let mut b = a.clone();
        b.remove("c1");
        // `a` re-adds without having seen the remove; `b` never observed the new tag.
        a.add("c1", "node-a");
        assert!(a.clone().merged(&b).contains("c1"));
        assert!(b.clone().merged(&a).contains("c1"));
    }

    #[test]
    fn test_merge_combines_constraints() {
        let mut a = ConstraintORSet::new();
        a.add("bounds_check", "node-a");
        let mut b = ConstraintORSet::new();
        b.add("norm_check", "node-b");
        let merged = a.merged(&b);
        assert!(merged.contains("bounds_check"));
        assert!(merged.contains("norm_check"));
        assert_eq!(merged.len(), 2);
    }

    #[test]
    fn test_active_constraints_lists_members() {
        let mut s = ConstraintORSet::new();
        s.add("c2", "node-a");
        s.add("c1", "node-b");
        s.add("c3", "node-a");
        s.remove("c3");
        let mut active = s.active_constraints();
        active.sort();
        assert_eq!(active, vec!["c1".to_string(), "c2".to_string()]);
    }

    #[test]
    fn test_merge_commutative() {
        let mut a = ConstraintORSet::new();
        a.add("c1", "node-a");
        let mut b = ConstraintORSet::new();
        b.add("c2", "node-b");
        assert!(laws::check_commutative(&a, &b));
    }

    #[test]
    fn test_merge_associative() {
        let mut a = ConstraintORSet::new();
        a.add("c1", "a");
        let mut b = ConstraintORSet::new();
        b.add("c2", "b");
        let mut c = ConstraintORSet::new();
        c.add("c3", "c");
        assert!(laws::check_associative(&a, &b, &c));
    }

    #[test]
    fn test_merge_idempotent() {
        let mut a = ConstraintORSet::new();
        a.add("c1", "node-a");
        a.add("c2", "node-a");
        assert!(laws::check_idempotent(&a));
    }

    #[test]
    fn test_gc_tombstones() {
        let mut s = ConstraintORSet::new();
        s.add("c1", "node-a");
        s.remove("c1");
        assert_eq!(s.tombstone_count(), 1);
        s.gc_tombstones("node-a", 10);
        assert_eq!(s.tombstone_count(), 0);
    }

    #[test]
    fn test_gc_is_scoped_by_node_and_seq() {
        let mut s = ConstraintORSet::new();
        s.add("c1", "node-a"); // tag (node-a, 1)
        s.add("c2", "node-a"); // tag (node-a, 2)
        s.add("c3", "node-b"); // tag (node-b, 1)
        s.remove("c1");
        s.remove("c2");
        s.remove("c3");
        assert_eq!(s.tombstone_count(), 3);
        // `before_seq` is exclusive: only (node-a, 1) qualifies.
        s.gc_tombstones("node-a", 2);
        assert_eq!(s.tombstone_count(), 2);
        // Tombstones minted by other nodes are never touched.
        s.gc_tombstones("node-a", u64::MAX);
        assert_eq!(s.tombstone_count(), 1);
    }

    #[test]
    fn test_display() {
        let mut s = ConstraintORSet::new();
        s.add("c1", "node-a");
        s.add("c2", "node-a");
        s.remove("c2");
        assert_eq!(s.to_string(), "ConstraintORSet(1 active, 1 tombstones)");
    }

    #[test]
    fn test_eq_compares_membership_only() {
        let mut a = ConstraintORSet::new();
        a.add("c1", "node-a");
        let mut b = ConstraintORSet::new();
        b.add("c1", "node-b");
        b.add("c2", "node-b");
        b.remove("c2");
        // Different tags and tombstones, same visible membership.
        assert_eq!(a, b);
        b.add("c3", "node-b");
        assert_ne!(a, b);
    }
}