nodedb_types/approx/
spacesaving.rs1use std::collections::HashMap;
6
7#[derive(Debug, serde::Serialize, serde::Deserialize)]
13pub struct SpaceSaving {
14 items: HashMap<u64, (u64, u64)>,
15 max_items: usize,
16}
17
18impl SpaceSaving {
19 pub fn new(k: usize) -> Self {
20 Self {
21 items: HashMap::with_capacity(k + 1),
22 max_items: k.max(1),
23 }
24 }
25
26 pub fn add(&mut self, item: u64) {
27 if let Some(entry) = self.items.get_mut(&item) {
28 entry.0 += 1;
29 return;
30 }
31
32 if self.items.len() < self.max_items {
33 self.items.insert(item, (1, 0));
34 } else {
35 let Some((&min_key, &(min_count, _))) =
36 self.items.iter().min_by_key(|(_, (count, _))| *count)
37 else {
38 return;
39 };
40 self.items.remove(&min_key);
41 self.items.insert(item, (min_count + 1, min_count));
42 }
43 }
44
45 pub fn add_batch(&mut self, items: &[u64]) {
46 for &item in items {
47 self.add(item);
48 }
49 }
50
51 pub fn top_k(&self) -> Vec<(u64, u64, u64)> {
55 let mut result: Vec<(u64, u64, u64)> = self
56 .items
57 .iter()
58 .map(|(&item, &(count, error))| (item, count, error))
59 .collect();
60 result.sort_by_key(|item| std::cmp::Reverse(item.1));
61 result
62 }
63
64 pub fn merge(&mut self, other: &SpaceSaving) {
65 for (&item, &(count, error)) in &other.items {
66 let entry = self.items.entry(item).or_insert((0, 0));
67 entry.0 += count;
68 entry.1 += error;
69 }
70
71 while self.items.len() > self.max_items {
72 let Some((&min_key, _)) = self.items.iter().min_by_key(|(_, (count, _))| *count) else {
73 break;
74 };
75 self.items.remove(&min_key);
76 }
77 }
78}
79
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn topk_basic() {
        let mut ss = SpaceSaving::new(3);
        for _ in 0..100 {
            ss.add(1);
        }
        for _ in 0..50 {
            ss.add(2);
        }
        for _ in 0..30 {
            ss.add(3);
        }
        for _ in 0..10 {
            ss.add(4);
        }
        let top = ss.top_k();
        assert_eq!(top[0].0, 1);
        assert_eq!(top[0].1, 100);
    }

    #[test]
    fn topk_merge() {
        let mut a = SpaceSaving::new(5);
        let mut b = SpaceSaving::new(5);
        for _ in 0..100 {
            a.add(1);
        }
        for _ in 0..80 {
            b.add(1);
        }
        for _ in 0..50 {
            b.add(2);
        }
        a.merge(&b);
        let top = a.top_k();
        assert_eq!(top[0].0, 1);
        assert_eq!(top[0].1, 180);
    }

    #[test]
    fn topk_eviction() {
        let mut ss = SpaceSaving::new(3);
        for i in 0..10u64 {
            for _ in 0..(10 - i) {
                ss.add(i);
            }
        }
        let top = ss.top_k();
        assert_eq!(top.len(), 3);
        assert!(top[0].1 >= top[1].1);
        assert!(top[1].1 >= top[2].1);
    }

    #[test]
    fn spacesaving_serde_roundtrip_merge_semantics() {
        let mut a = SpaceSaving::new(5);
        let mut b = SpaceSaving::new(5);
        for _ in 0..100u64 {
            a.add(1);
        }
        for _ in 0..80u64 {
            b.add(1);
        }
        for _ in 0..50u64 {
            b.add(2);
        }

        let bytes = serde_json::to_vec(&a).expect("serialize SpaceSaving");
        let mut a_prime: SpaceSaving =
            serde_json::from_slice(&bytes).expect("deserialize SpaceSaving");

        a_prime.merge(&b);
        a.merge(&b);

        let top_orig = a.top_k();
        let top_rt = a_prime.top_k();
        assert_eq!(
            top_orig[0].0, top_rt[0].0,
            "top item mismatch after roundtrip"
        );
        assert_eq!(
            top_orig[0].1, top_rt[0].1,
            "top count mismatch after roundtrip"
        );
    }

    /// `new(0)` is clamped to a single slot; the second item evicts the first
    /// and inherits its count as error.
    #[test]
    fn new_clamps_zero_capacity_to_one() {
        let mut ss = SpaceSaving::new(0);
        ss.add(1);
        ss.add(2);
        assert_eq!(ss.top_k(), vec![(2, 2, 1)]);
    }

    /// Without eviction, `add_batch` yields exact counts with zero error.
    #[test]
    fn add_batch_counts_exactly_without_eviction() {
        let mut ss = SpaceSaving::new(3);
        ss.add_batch(&[5, 5, 5, 7, 7, 9]);
        assert_eq!(ss.top_k(), vec![(5, 3, 0), (7, 2, 0), (9, 1, 0)]);
    }

    /// With evictions, every tracked estimate still brackets the true
    /// frequency: `count - error <= truth <= count`.
    #[test]
    fn estimates_bracket_true_frequency() {
        let mut ss = SpaceSaving::new(3);
        for i in 0..10u64 {
            for _ in 0..(10 - i) {
                ss.add(i);
            }
        }
        for (item, count, error) in ss.top_k() {
            let truth = 10 - item;
            assert!(error <= count, "error exceeds count for item {item}");
            assert!(count >= truth, "count underestimates item {item}");
            assert!(count - error <= truth, "lower bound too high for item {item}");
        }
    }
}