use crate::lattice::Lattice;
use serde::{Deserialize, Serialize};
use std::collections::BTreeMap;
/// A positive-negative counter CRDT.
///
/// Two tallies are kept per replica key: one for increments and one for
/// decrements. The observable value is the total of all increments minus the
/// total of all decrements. Replica states are merged through the `Lattice`
/// implementation.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct PNCounter<K>
where
    K: Ord + Clone,
{
    /// Running increment total, keyed by replica id.
    increments: BTreeMap<K, u64>,
    /// Running decrement total, keyed by replica id.
    decrements: BTreeMap<K, u64>,
}
impl<K: Ord + Clone> PNCounter<K> {
    /// Creates an empty counter with no recorded increments or decrements.
    pub fn new() -> Self {
        Self {
            increments: BTreeMap::new(),
            decrements: BTreeMap::new(),
        }
    }

    /// Adds `amount` to the increment total recorded for `replica_id`.
    ///
    /// The per-replica total saturates at `u64::MAX` instead of wrapping.
    pub fn increment(&mut self, replica_id: K, amount: u64) {
        Self::bump(&mut self.increments, replica_id, amount);
    }

    /// Adds `amount` to the decrement total recorded for `replica_id`.
    ///
    /// The per-replica total saturates at `u64::MAX` instead of wrapping.
    pub fn decrement(&mut self, replica_id: K, amount: u64) {
        Self::bump(&mut self.decrements, replica_id, amount);
    }

    /// Returns the counter's value: the sum of all increments minus the sum of
    /// all decrements, across every replica.
    ///
    /// The totals are accumulated in `i128`, so adding many large (possibly
    /// saturated) per-replica counts cannot overflow. A net result outside the
    /// `i64` range is clamped to `i64::MIN` / `i64::MAX` rather than wrapping.
    pub fn value(&self) -> i64 {
        let net = Self::total(&self.increments) - Self::total(&self.decrements);
        let clamped = net.clamp(i128::from(i64::MIN), i128::from(i64::MAX));
        // Cannot truncate: `clamped` lies within i64's range by construction.
        clamped as i64
    }

    /// Returns the increment total recorded for `replica_id`, or 0 if none.
    pub fn get_increment(&self, replica_id: &K) -> u64 {
        self.increments.get(replica_id).copied().unwrap_or(0)
    }

    /// Returns the decrement total recorded for `replica_id`, or 0 if none.
    pub fn get_decrement(&self, replica_id: &K) -> u64 {
        self.decrements.get(replica_id).copied().unwrap_or(0)
    }

    /// Read-only view of the per-replica increment totals.
    pub fn increments(&self) -> &BTreeMap<K, u64> {
        &self.increments
    }

    /// Read-only view of the per-replica decrement totals.
    pub fn decrements(&self) -> &BTreeMap<K, u64> {
        &self.decrements
    }

    /// Adds `amount` to `map[key]` (starting from 0 if absent), saturating at
    /// `u64::MAX`.
    fn bump(map: &mut BTreeMap<K, u64>, key: K, amount: u64) {
        let slot = map.entry(key).or_default();
        *slot = slot.saturating_add(amount);
    }

    /// Sums a per-replica tally in `i128`. Overflow would need more than 2^63
    /// entries, so the sum is exact for any map that fits in memory.
    fn total(map: &BTreeMap<K, u64>) -> i128 {
        map.values().copied().map(i128::from).sum()
    }
}
impl<K: Ord + Clone> Default for PNCounter<K> {
fn default() -> Self {
Self::new()
}
}
impl<K: Ord + Clone> Lattice for PNCounter<K> {
fn bottom() -> Self {
Self::new()
}
fn join(&self, other: &Self) -> Self {
let mut increments = self.increments.clone();
let mut decrements = self.decrements.clone();
for (k, v) in &other.increments {
increments
.entry(k.clone())
.and_modify(|e| *e = (*e).max(*v))
.or_insert(*v);
}
for (k, v) in &other.decrements {
decrements
.entry(k.clone())
.and_modify(|e| *e = (*e).max(*v))
.or_insert(*v);
}
Self {
increments,
decrements,
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_pncounter_basic_operations() {
        let mut counter = PNCounter::new();
        counter.increment("A", 5);
        assert_eq!(counter.value(), 5);
        counter.decrement("B", 2);
        assert_eq!(counter.value(), 3);
        counter.increment("A", 3);
        assert_eq!(counter.value(), 6);
        // Repeated increments on one replica accumulate in that replica's slot.
        assert_eq!(counter.get_increment(&"A"), 8);
        assert_eq!(counter.get_decrement(&"A"), 0);
    }

    #[test]
    fn test_pncounter_join_idempotent() {
        let mut c1 = PNCounter::new();
        c1.increment("A", 5);
        c1.decrement("B", 2);
        let joined = c1.join(&c1);
        // Idempotence must hold for the whole state, not only the observed value.
        assert_eq!(joined, c1);
        assert_eq!(joined.value(), 3);
    }

    #[test]
    fn test_pncounter_join_commutative() {
        let mut c1 = PNCounter::new();
        c1.increment("A", 5);
        let mut c2 = PNCounter::new();
        c2.increment("B", 3);
        c2.decrement("A", 1);
        let joined1 = c1.join(&c2);
        let joined2 = c2.join(&c1);
        assert_eq!(joined1, joined2);
        assert_eq!(joined1.value(), 7);
        assert_eq!(joined1.get_increment(&"A"), 5);
        assert_eq!(joined1.get_increment(&"B"), 3);
        assert_eq!(joined1.get_decrement(&"A"), 1);
    }

    #[test]
    fn test_pncounter_join_associative() {
        let mut c1 = PNCounter::new();
        c1.increment("A", 1);
        let mut c2 = PNCounter::new();
        c2.increment("B", 2);
        c2.decrement("A", 1);
        let mut c3 = PNCounter::new();
        c3.decrement("C", 1);
        // Overlapping replica keys make this exercise the per-replica max.
        c3.increment("A", 4);
        let left = c1.join(&c2).join(&c3);
        let right = c1.join(&c2.join(&c3));
        assert_eq!(left, right);
        assert_eq!(left.get_increment(&"A"), 4);
        assert_eq!(left.value(), 4);
    }

    #[test]
    fn test_pncounter_bottom_is_identity() {
        let mut counter = PNCounter::new();
        counter.increment("A", 5);
        counter.decrement("B", 2);
        let bottom = PNCounter::bottom();
        assert_eq!(counter.join(&bottom), counter);
        assert_eq!(bottom.join(&counter), counter);
    }

    #[test]
    fn test_pncounter_join_keeps_max_per_replica() {
        let mut c1 = PNCounter::new();
        c1.increment("A", 5);
        c1.decrement("A", 1);
        let mut c2 = PNCounter::new();
        c2.increment("A", 3);
        c2.decrement("A", 4);
        let joined = c1.join(&c2);
        // Totals are merged with max, not summed, so two copies of the same
        // replica's history are never double-counted.
        assert_eq!(joined.get_increment(&"A"), 5);
        assert_eq!(joined.get_decrement(&"A"), 4);
        assert_eq!(joined.value(), 1);
    }

    #[test]
    fn test_pncounter_convergence_different_order() {
        let mut c1 = PNCounter::new();
        c1.increment("X", 10);
        c1.decrement("Y", 3);
        let mut c2 = PNCounter::new();
        c2.increment("Z", 5);
        c2.decrement("X", 2);
        let mut state1 = PNCounter::bottom();
        state1.join_assign(&c1);
        state1.join_assign(&c2);
        let mut state2 = PNCounter::bottom();
        state2.join_assign(&c2);
        state2.join_assign(&c1);
        assert_eq!(state1, state2);
        assert_eq!(state1.value(), 10);
    }

    #[test]
    fn test_pncounter_serialization() {
        let mut counter = PNCounter::new();
        counter.increment("replica1", 100);
        counter.decrement("replica2", 25);
        let serialized = serde_json::to_string(&counter).expect("counter serializes");
        let deserialized: PNCounter<String> =
            serde_json::from_str(&serialized).expect("counter deserializes");
        assert_eq!(deserialized.value(), counter.value());
        assert_eq!(deserialized.get_increment(&"replica1".to_string()), 100);
        assert_eq!(deserialized.get_decrement(&"replica2".to_string()), 25);
    }

    #[test]
    fn test_pncounter_serialization_round_trip_preserves_state() {
        let mut counter: PNCounter<String> = PNCounter::new();
        counter.increment("replica1".to_string(), 100);
        counter.decrement("replica2".to_string(), 25);
        let serialized = serde_json::to_string(&counter).expect("counter serializes");
        let deserialized: PNCounter<String> =
            serde_json::from_str(&serialized).expect("counter deserializes");
        assert_eq!(deserialized, counter);
    }
}