use crate::merge::Merge;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fmt;
/// A positive-negative counter that can be merged across replicas.
///
/// Increments and decrements are tracked per node in two separate maps; the
/// counter's value is the sum of all increments minus the sum of all decrements.
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
pub struct PNCounter {
    /// Total increments recorded by each node, keyed by node id.
    p: HashMap<String, u64>,
    /// Total decrements recorded by each node, keyed by node id.
    n: HashMap<String, u64>,
}
impl PNCounter {
pub fn new() -> Self {
Self {
p: HashMap::new(),
n: HashMap::new(),
}
}
pub fn increment(&mut self, node: &str, amount: u64) {
*self.p.entry(node.to_string()).or_insert(0) += amount;
}
pub fn decrement(&mut self, node: &str, amount: u64) {
*self.n.entry(node.to_string()).or_insert(0) += amount;
}
pub fn value(&self) -> i64 {
let pos: u64 = self.p.values().sum();
let neg: u64 = self.n.values().sum();
pos as i64 - neg as i64
}
pub fn total_positive(&self) -> u64 {
self.p.values().sum()
}
pub fn total_negative(&self) -> u64 {
self.n.values().sum()
}
pub fn node_value(&self, node: &str) -> i64 {
let pos = self.p.get(node).copied().unwrap_or(0);
let neg = self.n.get(node).copied().unwrap_or(0);
pos as i64 - neg as i64
}
pub fn node_count(&self) -> usize {
let mut nodes: std::collections::HashSet<&str> = std::collections::HashSet::new();
for k in self.p.keys() { nodes.insert(k); }
for k in self.n.keys() { nodes.insert(k); }
nodes.len()
}
}
impl Merge for PNCounter {
fn merge(&mut self, other: &Self) {
for (node, count) in &other.p {
let entry = self.p.entry(node.clone()).or_insert(0);
*entry = (*entry).max(*count);
}
for (node, count) in &other.n {
let entry = self.n.entry(node.clone()).or_insert(0);
*entry = (*entry).max(*count);
}
}
}
impl PartialEq for PNCounter {
fn eq(&self, other: &Self) -> bool {
self.value() == other.value()
}
}
impl fmt::Display for PNCounter {
    /// Renders the counter as
    /// `PNCounter(+<increments> -<decrements> = <value>, <count> nodes)`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let plus = self.total_positive();
        let minus = self.total_negative();
        let net = self.value();
        let nodes = self.node_count();
        write!(
            f,
            "PNCounter(+{} -{} = {}, {} nodes)",
            plus, minus, net, nodes
        )
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::merge::laws;

    /// Builds a counter by applying the given `(node, amount)` increments, then
    /// the given decrements.
    fn counter(incs: &[(&str, u64)], decs: &[(&str, u64)]) -> PNCounter {
        let mut c = PNCounter::new();
        for &(node, amount) in incs {
            c.increment(node, amount);
        }
        for &(node, amount) in decs {
            c.decrement(node, amount);
        }
        c
    }

    #[test]
    fn test_empty() {
        let c = PNCounter::new();
        assert_eq!(c.value(), 0);
        assert_eq!(c.total_positive(), 0);
        assert_eq!(c.total_negative(), 0);
        assert_eq!(c.node_count(), 0);
    }

    #[test]
    fn test_increment_decrement() {
        let c = counter(&[("a", 100)], &[("a", 30)]);
        assert_eq!(c.value(), 70);
    }

    #[test]
    fn test_repeated_updates_accumulate() {
        let mut c = PNCounter::new();
        c.increment("a", 2);
        c.increment("a", 3);
        c.decrement("a", 1);
        c.decrement("a", 1);
        assert_eq!(c.total_positive(), 5);
        assert_eq!(c.total_negative(), 2);
        assert_eq!(c.value(), 3);
    }

    #[test]
    fn test_multi_node() {
        let c = counter(&[("a", 100), ("b", 200)], &[("a", 50)]);
        assert_eq!(c.value(), 250);
    }

    #[test]
    fn test_node_value_and_count() {
        let c = counter(&[("a", 10), ("b", 1)], &[("a", 4), ("c", 2)]);
        assert_eq!(c.node_value("a"), 6);
        assert_eq!(c.node_value("b"), 1);
        assert_eq!(c.node_value("c"), -2);
        assert_eq!(c.node_value("missing"), 0);
        // "a" appears in both maps but must be counted only once.
        assert_eq!(c.node_count(), 3);
    }

    #[test]
    fn test_merge_combines_nodes() {
        let a = counter(&[("a", 100)], &[]);
        let b = counter(&[("b", 200)], &[("b", 50)]);
        let merged = a.merged(&b);
        assert_eq!(merged.value(), 250);
        assert_eq!(merged.node_count(), 2);
    }

    #[test]
    fn test_merge_takes_max_per_node() {
        let a = counter(&[("x", 5)], &[("x", 2)]);
        let mut b = a.clone();
        b.increment("x", 3);
        b.decrement("x", 1);
        // Overlapping per-node totals are joined by max (8 - 3), not summed.
        let mut ab = a.clone();
        ab.merge(&b);
        assert_eq!(ab.value(), 5);
        let mut ba = b.clone();
        ba.merge(&a);
        assert_eq!(ba.value(), 5);
    }

    #[test]
    fn test_merge_commutative() {
        let a = counter(&[("a", 100)], &[]);
        let b = counter(&[("b", 200)], &[]);
        assert!(laws::check_commutative(&a, &b));
    }

    #[test]
    fn test_merge_commutative_overlapping() {
        let a = counter(&[("x", 1), ("y", 4)], &[("x", 1)]);
        let b = counter(&[("x", 2)], &[("y", 3)]);
        assert!(laws::check_commutative(&a, &b));
    }

    #[test]
    fn test_merge_associative() {
        let a = counter(&[("a", 1)], &[]);
        let b = counter(&[("b", 2)], &[]);
        let c = counter(&[("c", 3)], &[]);
        assert!(laws::check_associative(&a, &b, &c));
    }

    #[test]
    fn test_merge_idempotent() {
        let a = counter(&[("a", 100)], &[("a", 30)]);
        assert!(laws::check_idempotent(&a));
    }

    #[test]
    fn test_negative_value() {
        let mut c = PNCounter::new();
        c.decrement("a", 100);
        c.increment("a", 30);
        assert_eq!(c.value(), -70);
    }

    #[test]
    fn test_display() {
        let c = counter(&[("a", 100)], &[("a", 30)]);
        // Pin the full rendering; a substring check like "70" also matches "+70".
        assert_eq!(c.to_string(), "PNCounter(+100 -30 = 70, 1 nodes)");
    }
}