reddb_server/storage/index/
heavy_hitters.rs1use std::cmp::Reverse;
28use std::collections::BinaryHeap;
29
30use super::{IndexBase, IndexKind, IndexStats};
31use crate::storage::primitives::count_min_sketch::CountMinSketch;
32
/// Number of heavy hitters retained by [`HeavyHitters::new`] when the caller
/// does not choose `k` explicitly.
const DEFAULT_K: usize = 16;

36pub struct HeavyHitters {
38 name: String,
39 k: usize,
40 cms: CountMinSketch,
41 top: BinaryHeap<Reverse<(u64, Vec<u8>)>>,
44 total_observed: u64,
46}
47
48impl HeavyHitters {
49 pub fn new(name: impl Into<String>) -> Self {
52 Self::with_params(name, DEFAULT_K, 1000, 5)
53 }
54
55 pub fn with_params(
58 name: impl Into<String>,
59 k: usize,
60 cms_width: usize,
61 cms_depth: usize,
62 ) -> Self {
63 Self {
64 name: name.into(),
65 k: k.max(1),
66 cms: CountMinSketch::new(cms_width, cms_depth),
67 top: BinaryHeap::new(),
68 total_observed: 0,
69 }
70 }
71
72 pub fn observe(&mut self, key: &[u8]) {
75 self.observe_n(key, 1);
76 }
77
78 pub fn observe_n(&mut self, key: &[u8], count: u64) {
80 if count == 0 {
81 return;
82 }
83 self.cms.add(key, count);
84 self.total_observed = self.total_observed.saturating_add(count);
85
86 let estimate = self.cms.estimate(key);
87
88 let mut kept: Vec<(u64, Vec<u8>)> = self
92 .top
93 .drain()
94 .map(|Reverse(pair)| pair)
95 .filter(|(_, k)| k != key)
96 .collect();
97 kept.push((estimate, key.to_vec()));
98 kept.sort_by_key(|b| std::cmp::Reverse(b.0));
99 kept.truncate(self.k);
100 self.top = kept.into_iter().map(Reverse).collect();
101 }
102
103 pub fn top_k(&self) -> Vec<(Vec<u8>, u64)> {
105 let mut out: Vec<(u64, Vec<u8>)> = self
106 .top
107 .iter()
108 .map(|Reverse((c, k))| (*c, k.clone()))
109 .collect();
110 out.sort_by_key(|b| std::cmp::Reverse(b.0));
111 out.into_iter().map(|(c, k)| (k, c)).collect()
112 }
113
114 pub fn estimate(&self, key: &[u8]) -> u64 {
116 self.cms.estimate(key)
117 }
118
119 pub fn total_observed(&self) -> u64 {
122 self.total_observed
123 }
124
125 pub fn relative_frequency(&self, key: &[u8]) -> f64 {
128 if self.total_observed == 0 {
129 return 0.0;
130 }
131 self.estimate(key) as f64 / self.total_observed as f64
132 }
133
134 pub fn k(&self) -> usize {
136 self.k
137 }
138
139 pub fn clear(&mut self) {
141 self.cms.clear();
142 self.top.clear();
143 self.total_observed = 0;
144 }
145}
146
147impl IndexBase for HeavyHitters {
148 fn name(&self) -> &str {
149 &self.name
150 }
151
152 fn kind(&self) -> IndexKind {
153 IndexKind::HeavyHitters
154 }
155
156 fn stats(&self) -> IndexStats {
157 IndexStats {
158 entries: self.total_observed as usize,
159 distinct_keys: self.top.len(),
162 approx_bytes: 0,
163 kind: IndexKind::HeavyHitters,
164 has_bloom: false,
165 index_correlation: 0.0,
166 }
167 }
168}
169
#[cfg(test)]
mod tests {
    use super::*;

    /// Observes `key` one occurrence at a time, `times` times over.
    fn feed(hh: &mut HeavyHitters, key: &[u8], times: usize) {
        for _ in 0..times {
            hh.observe(key);
        }
    }

    /// Observes the big-endian encoding of every integer in `0..upper` once.
    fn feed_range(hh: &mut HeavyHitters, upper: u32) {
        (0..upper).for_each(|i| hh.observe(&i.to_be_bytes()));
    }

    #[test]
    fn observes_and_tracks_top_k() {
        let mut tracker = HeavyHitters::with_params("test", 3, 256, 4);
        feed(&mut tracker, b"alpha", 100);
        feed(&mut tracker, b"beta", 50);
        feed(&mut tracker, b"charlie", 10);
        feed(&mut tracker, b"delta", 1);

        let ranked = tracker.top_k();
        assert_eq!(ranked.len(), 3);

        let (first_key, first_count) = &ranked[0];
        assert_eq!(*first_key, b"alpha".to_vec());
        assert!(*first_count >= 100);

        let (second_key, second_count) = &ranked[1];
        assert_eq!(*second_key, b"beta".to_vec());
        assert!(*second_count >= 50);

        let (third_key, _) = &ranked[2];
        assert_eq!(*third_key, b"charlie".to_vec());
    }

    #[test]
    fn estimate_never_underestimates() {
        let mut tracker = HeavyHitters::with_params("test", 8, 1024, 4);
        feed_range(&mut tracker, 500);
        for i in 0..500u32 {
            assert!(tracker.estimate(&i.to_be_bytes()) >= 1);
        }
    }

    #[test]
    fn relative_frequency_scales_with_total() {
        let mut tracker = HeavyHitters::new("t");
        feed(&mut tracker, b"hot", 400);
        feed(&mut tracker, b"cold", 100);

        let hot = tracker.relative_frequency(b"hot");
        let cold = tracker.relative_frequency(b"cold");
        assert!(hot > cold);
        assert!(hot >= 0.75);
    }

    #[test]
    fn skewed_distribution_surfaces_hot_keys() {
        let mut tracker = HeavyHitters::with_params("t", 5, 4096, 5);
        feed(&mut tracker, b"hotA", 1000);
        feed(&mut tracker, b"hotB", 800);
        feed(&mut tracker, b"hotC", 600);
        feed_range(&mut tracker, 1000);

        let ranked = tracker.top_k();
        let surfaced = |needle: &[u8]| ranked.iter().any(|(key, _)| key.as_slice() == needle);
        assert!(surfaced(&b"hotA"[..]));
        assert!(surfaced(&b"hotB"[..]));
        assert!(surfaced(&b"hotC"[..]));
    }

    #[test]
    fn observe_n_is_equivalent_to_looped_observe() {
        let mut bulk = HeavyHitters::with_params("a", 4, 512, 4);
        let mut looped = HeavyHitters::with_params("b", 4, 512, 4);

        bulk.observe_n(b"bulk", 1000);
        feed(&mut looped, b"bulk", 1000);

        assert_eq!(bulk.estimate(b"bulk"), looped.estimate(b"bulk"));
        assert_eq!(bulk.total_observed(), looped.total_observed());
    }

    #[test]
    fn clear_resets_state() {
        let mut tracker = HeavyHitters::new("t");
        tracker.observe(b"x");
        tracker.clear();

        assert_eq!(tracker.total_observed(), 0);
        assert!(tracker.top_k().is_empty());
        assert_eq!(tracker.estimate(b"x"), 0);
    }

    #[test]
    fn stats_surface_totals_and_kind() {
        let mut tracker = HeavyHitters::with_params("t", 4, 256, 3);
        feed_range(&mut tracker, 50);

        let stats = tracker.stats();
        assert_eq!(stats.entries, 50);
        assert_eq!(stats.kind, IndexKind::HeavyHitters);
        assert!(stats.distinct_keys <= 4);
    }

    #[test]
    fn zero_count_observation_is_noop() {
        let mut tracker = HeavyHitters::new("t");
        tracker.observe_n(b"ghost", 0);

        assert_eq!(tracker.total_observed(), 0);
        assert!(tracker.top_k().is_empty());
    }
}