1use alloc::vec::Vec;
43use core::hash::Hash;
44use std::collections::HashMap;
45
46use crate::error::{RcfError, RcfResult};
47
/// Number of tracked-key slots used by [`SpaceSaving::with_default_capacity`].
pub const DEFAULT_CAPACITY: usize = 128;
51
/// Counter state stored for a single tracked key.
///
/// Barring saturation, `estimate - error` is the weight that was observed for
/// the key since it was (re)inserted into the sketch.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct HeavyHitterEntry {
    /// Weight credited to the key, including the count inherited from the
    /// entry it replaced (if any). Accumulated with saturating addition.
    pub estimate: u64,
    /// Estimate of the entry that was evicted to make room for this key;
    /// `0` when the key was inserted into a free slot.
    pub error: u64,
}
63
/// One row of the ranking produced by [`SpaceSaving::top_k`].
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct HeavyHitter<K> {
    /// Zero-based position in the ranking (`0` is the largest estimate);
    /// clamped to `u32::MAX` if the position does not fit in a `u32`.
    pub rank: u32,
    /// Clone of the tracked key.
    pub key: K,
    /// The key's [`HeavyHitterEntry::estimate`] at the time of the query.
    pub estimate: u64,
    /// The key's [`HeavyHitterEntry::error`] at the time of the query.
    pub error: u64,
}
81
/// Bounded-size frequency sketch that keeps counters for at most `capacity`
/// distinct keys.
///
/// Once every slot is occupied, observing an untracked key evicts the entry
/// with the smallest estimate and hands its count over to the newcomer (see
/// [`SpaceSaving::observe_weighted`]).
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct SpaceSaving<K>
where
    K: Hash + Eq + Clone,
{
    /// Per-key counters for the keys currently tracked.
    counts: HashMap<K, HeavyHitterEntry>,
    /// Maximum number of keys tracked before evictions start.
    capacity: usize,
    /// Saturating sum of every non-zero weight observed since creation or the
    /// last `reset`.
    total: u64,
}
115
116impl<K> SpaceSaving<K>
117where
118 K: Hash + Eq + Clone,
119{
120 pub fn new(capacity: usize) -> RcfResult<Self> {
126 if capacity == 0 {
127 return Err(RcfError::InvalidConfig(
128 alloc::string::ToString::to_string("SpaceSaving: capacity must be > 0").into(),
129 ));
130 }
131 Ok(Self {
132 counts: HashMap::with_capacity(capacity),
133 capacity,
134 total: 0,
135 })
136 }
137
138 pub fn with_default_capacity() -> RcfResult<Self> {
145 Self::new(DEFAULT_CAPACITY)
146 }
147
148 #[must_use]
150 pub fn capacity(&self) -> usize {
151 self.capacity
152 }
153
154 #[must_use]
156 pub fn len(&self) -> usize {
157 self.counts.len()
158 }
159
160 #[must_use]
162 pub fn is_empty(&self) -> bool {
163 self.counts.is_empty()
164 }
165
166 #[must_use]
168 pub fn total(&self) -> u64 {
169 self.total
170 }
171
172 #[must_use]
175 pub fn error_bound(&self) -> u64 {
176 if self.capacity == 0 {
177 return 0;
178 }
179 self.total / (self.capacity as u64)
180 }
181
182 #[inline]
184 pub fn observe(&mut self, key: K) {
185 self.observe_weighted(key, 1);
186 }
187
188 #[inline]
192 pub fn observe_weighted(&mut self, key: K, weight: u64) {
193 if weight == 0 {
194 return;
195 }
196 self.total = self.total.saturating_add(weight);
197
198 if let Some(entry) = self.counts.get_mut(&key) {
199 entry.estimate = entry.estimate.saturating_add(weight);
200 return;
201 }
202
203 if self.counts.len() < self.capacity {
204 self.counts.insert(
205 key,
206 HeavyHitterEntry {
207 estimate: weight,
208 error: 0,
209 },
210 );
211 return;
212 }
213
214 if let Some((min_key, min_entry)) = self.find_min() {
217 self.counts.remove(&min_key);
218 let boosted = HeavyHitterEntry {
219 estimate: min_entry.estimate.saturating_add(weight),
220 error: min_entry.estimate,
221 };
222 self.counts.insert(key, boosted);
223 }
224 }
225
226 #[must_use]
230 pub fn estimate(&self, key: &K) -> Option<HeavyHitterEntry> {
231 self.counts.get(key).copied()
232 }
233
234 #[must_use]
237 pub fn top_k(&self, n: usize) -> Vec<HeavyHitter<K>> {
238 let mut entries: Vec<(K, HeavyHitterEntry)> =
239 self.counts.iter().map(|(k, e)| (k.clone(), *e)).collect();
240 entries.sort_by_key(|(_, e)| core::cmp::Reverse(e.estimate));
241 entries.truncate(n);
242 entries
243 .into_iter()
244 .enumerate()
245 .map(|(idx, (k, e))| HeavyHitter {
246 rank: u32::try_from(idx).unwrap_or(u32::MAX),
247 key: k,
248 estimate: e.estimate,
249 error: e.error,
250 })
251 .collect()
252 }
253
254 pub fn iter(&self) -> impl Iterator<Item = (&K, &HeavyHitterEntry)> {
258 self.counts.iter()
259 }
260
261 pub fn reset(&mut self) {
263 self.counts.clear();
264 self.total = 0;
265 }
266
267 fn find_min(&self) -> Option<(K, HeavyHitterEntry)> {
272 self.counts
273 .iter()
274 .min_by_key(|(_, e)| e.estimate)
275 .map(|(k, e)| (k.clone(), *e))
276 }
277}
278
#[cfg(test)]
mod tests {
    use super::*;

    // Construction must fail for a sketch with no slots.
    #[test]
    fn new_rejects_zero_capacity() {
        assert!(SpaceSaving::<u32>::new(0).is_err());
    }

    // With fewer distinct keys than slots nothing is evicted, so every count
    // is exact and carries no error.
    #[test]
    fn exact_counts_within_capacity() {
        let mut ss: SpaceSaving<u32> = SpaceSaving::new(8).unwrap();
        // Key `i` is observed `i + 1` times.
        for i in 0..5_u32 {
            for _ in 0..=u64::from(i) {
                ss.observe(i);
            }
        }
        let top = ss.top_k(5);
        assert_eq!(top.len(), 5);
        for hh in &top {
            assert_eq!(hh.error, 0);
        }
        assert_eq!(top[0].key, 4);
        assert_eq!(top[0].estimate, 5);
    }

    // A key with 1 000 hits must survive a following stream of 2 000
    // one-off keys.
    #[test]
    fn heavy_hitter_always_retained() {
        let mut ss: SpaceSaving<u32> = SpaceSaving::new(8).unwrap();
        for _ in 0..1_000 {
            ss.observe(0_u32);
        }
        for i in 1..2_001_u32 {
            ss.observe(i);
        }
        let h = ss
            .top_k(8)
            .into_iter()
            .find(|hh| hh.key == 0)
            .expect("heavy hitter retained");
        assert!(h.estimate >= 1_000);
    }

    // Every key's true count is 10; `estimate - error` must not exceed it.
    #[test]
    fn error_bound_sandwiches_true_count() {
        let mut ss: SpaceSaving<u32> = SpaceSaving::new(16).unwrap();
        for i in 0..100_u32 {
            for _ in 0..10 {
                ss.observe(i);
            }
        }
        for hh in ss.top_k(16) {
            assert!(hh.estimate >= hh.error);
            let lower = hh.estimate - hh.error;
            assert!(lower <= 10, "lower={lower}");
            assert!(hh.estimate >= 10 || hh.error > 0);
        }
    }

    // `estimate` reports tracked keys and returns `None` for unseen ones.
    #[test]
    fn estimate_returns_none_for_untracked() {
        let mut ss: SpaceSaving<u32> = SpaceSaving::new(2).unwrap();
        ss.observe(1);
        ss.observe(2);
        // Key 3 arrives when the sketch is full and takes over a slot.
        for _ in 0..5 {
            ss.observe(3);
        }
        assert!(ss.estimate(&3).is_some());
        assert!(ss.estimate(&100).is_none());
    }

    // Weighted observations add up on both the entry and the running total.
    #[test]
    fn weighted_observe_accumulates() {
        let mut ss: SpaceSaving<u32> = SpaceSaving::new(4).unwrap();
        ss.observe_weighted(7, 1_000);
        ss.observe_weighted(7, 500);
        let h = ss.estimate(&7).expect("tracked");
        assert_eq!(h.estimate, 1_500);
        assert_eq!(ss.total(), 1_500);
    }

    // A zero weight must leave the sketch untouched.
    #[test]
    fn zero_weight_is_noop() {
        let mut ss: SpaceSaving<u32> = SpaceSaving::new(4).unwrap();
        ss.observe_weighted(1, 0);
        assert!(ss.is_empty());
        assert_eq!(ss.total(), 0);
    }

    // 1 000 observations over 10 slots give a bound of 1 000 / 10.
    #[test]
    fn error_bound_grows_linearly() {
        let mut ss: SpaceSaving<u32> = SpaceSaving::new(10).unwrap();
        for i in 0..1_000_u32 {
            ss.observe(i);
        }
        assert_eq!(ss.error_bound(), 100);
    }

    // Ranking is by descending estimate and ranks are zero-based.
    #[test]
    fn top_k_ranks_descending() {
        let mut ss: SpaceSaving<u32> = SpaceSaving::new(8).unwrap();
        for (key, count) in [(1_u32, 100_u64), (2, 50), (3, 25), (4, 10)] {
            for _ in 0..count {
                ss.observe(key);
            }
        }
        let top = ss.top_k(4);
        assert_eq!(top[0].key, 1);
        assert_eq!(top[1].key, 2);
        assert_eq!(top[2].key, 3);
        assert_eq!(top[3].key, 4);
        assert_eq!(top[0].rank, 0);
        assert_eq!(top[3].rank, 3);
    }

    // Asking for more rows than tracked keys (or for none) is fine.
    #[test]
    fn top_k_clamps_to_len() {
        let mut ss: SpaceSaving<u32> = SpaceSaving::new(8).unwrap();
        ss.observe(1);
        assert_eq!(ss.top_k(10).len(), 1);
        assert_eq!(ss.top_k(0).len(), 0);
    }

    // `reset` drops all keys and zeroes the running total.
    #[test]
    fn reset_clears_everything() {
        let mut ss: SpaceSaving<u32> = SpaceSaving::new(4).unwrap();
        for i in 0..100_u32 {
            ss.observe(i);
        }
        ss.reset();
        assert!(ss.is_empty());
        assert_eq!(ss.total(), 0);
        assert_eq!(ss.top_k(4).len(), 0);
    }

    // Fixed-size byte arrays work as keys.
    #[test]
    fn byte_key_roundtrip() {
        let mut ss: SpaceSaving<[u8; 16]> = SpaceSaving::new(4).unwrap();
        let k = [0x01_u8; 16];
        for _ in 0..10 {
            ss.observe(k);
        }
        assert_eq!(ss.estimate(&k).unwrap().estimate, 10);
    }

    // Serializing and deserializing must reproduce the same ranking.
    #[cfg(all(feature = "serde", feature = "postcard"))]
    #[test]
    fn postcard_roundtrip_preserves_top_k() {
        let mut ss: SpaceSaving<u32> = SpaceSaving::new(8).unwrap();
        for i in 0..20_u32 {
            for _ in 0..=u64::from(i) {
                ss.observe(i);
            }
        }
        let bytes = postcard::to_allocvec(&ss).expect("serde ok");
        let back: SpaceSaving<u32> = postcard::from_bytes(&bytes).expect("serde ok");
        assert_eq!(back.capacity(), ss.capacity());
        assert_eq!(back.total(), ss.total());
        let a = ss.top_k(8);
        let b = back.top_k(8);
        // Without this check `zip` would silently pass on a shorter or empty
        // deserialized ranking.
        assert_eq!(a.len(), b.len());
        for (x, y) in a.iter().zip(b.iter()) {
            assert_eq!(x.key, y.key);
            assert_eq!(x.estimate, y.estimate);
            assert_eq!(x.error, y.error);
        }
    }
}