Skip to main content

nodedb_types/approx/
spacesaving.rs

1// SPDX-License-Identifier: Apache-2.0
2
3//! SpaceSaving — approximate top-K heavy hitters (bounded memory).
4
5use std::collections::HashMap;
6
7/// Space-saving algorithm for approximate top-K heavy hitters.
8///
9/// Tracks the K most frequent items with bounded memory. Items not in
10/// the top K are approximated — their counts may be over-estimated by
11/// at most the minimum count in the structure.
12#[derive(Debug, serde::Serialize, serde::Deserialize)]
13pub struct SpaceSaving {
14    items: HashMap<u64, (u64, u64)>,
15    max_items: usize,
16}
17
18impl SpaceSaving {
19    pub fn new(k: usize) -> Self {
20        Self {
21            items: HashMap::with_capacity(k + 1),
22            max_items: k.max(1),
23        }
24    }
25
26    pub fn add(&mut self, item: u64) {
27        if let Some(entry) = self.items.get_mut(&item) {
28            entry.0 += 1;
29            return;
30        }
31
32        if self.items.len() < self.max_items {
33            self.items.insert(item, (1, 0));
34        } else {
35            let Some((&min_key, &(min_count, _))) =
36                self.items.iter().min_by_key(|(_, (count, _))| *count)
37            else {
38                return;
39            };
40            self.items.remove(&min_key);
41            self.items.insert(item, (min_count + 1, min_count));
42        }
43    }
44
45    pub fn add_batch(&mut self, items: &[u64]) {
46        for &item in items {
47            self.add(item);
48        }
49    }
50
51    /// Get the top-K items sorted by count (descending).
52    ///
53    /// Returns `(item, count, error_bound)` tuples.
54    pub fn top_k(&self) -> Vec<(u64, u64, u64)> {
55        let mut result: Vec<(u64, u64, u64)> = self
56            .items
57            .iter()
58            .map(|(&item, &(count, error))| (item, count, error))
59            .collect();
60        result.sort_by_key(|item| std::cmp::Reverse(item.1));
61        result
62    }
63
64    pub fn merge(&mut self, other: &SpaceSaving) {
65        for (&item, &(count, error)) in &other.items {
66            let entry = self.items.entry(item).or_insert((0, 0));
67            entry.0 += count;
68            entry.1 += error;
69        }
70
71        while self.items.len() > self.max_items {
72            let Some((&min_key, _)) = self.items.iter().min_by_key(|(_, (count, _))| *count) else {
73                break;
74            };
75            self.items.remove(&min_key);
76        }
77    }
78}
79
#[cfg(test)]
mod tests {
    use super::*;

    /// Feeds `item` into `sketch` exactly `times` times.
    fn feed(sketch: &mut SpaceSaving, item: u64, times: u64) {
        for _ in 0..times {
            sketch.add(item);
        }
    }

    /// The most frequent item is ranked first with its exact count.
    #[test]
    fn topk_basic() {
        let mut sketch = SpaceSaving::new(3);
        for (item, times) in [(1, 100), (2, 50), (3, 30), (4, 10)] {
            feed(&mut sketch, item, times);
        }

        let ranked = sketch.top_k();
        assert_eq!(ranked[0].0, 1);
        assert_eq!(ranked[0].1, 100);
    }

    /// Merging adds up the counts of an item present in both summaries.
    #[test]
    fn topk_merge() {
        let mut left = SpaceSaving::new(5);
        let mut right = SpaceSaving::new(5);
        feed(&mut left, 1, 100);
        feed(&mut right, 1, 80);
        feed(&mut right, 2, 50);

        left.merge(&right);

        let ranked = left.top_k();
        assert_eq!(ranked[0].0, 1);
        assert_eq!(ranked[0].1, 180);
    }

    /// With more distinct items than slots, exactly `k` entries survive and
    /// they come back in non-increasing count order.
    #[test]
    fn topk_eviction() {
        let mut sketch = SpaceSaving::new(3);
        (0..10u64).for_each(|item| feed(&mut sketch, item, 10 - item));

        let ranked = sketch.top_k();
        assert_eq!(ranked.len(), 3);
        assert!(ranked[0].1 >= ranked[1].1);
        assert!(ranked[1].1 >= ranked[2].1);
    }

    /// A summary that went through a JSON round-trip merges the same way as
    /// the original it was serialized from.
    #[test]
    fn spacesaving_serde_roundtrip_merge_semantics() {
        let mut original = SpaceSaving::new(5);
        let mut incoming = SpaceSaving::new(5);
        feed(&mut original, 1, 100);
        feed(&mut incoming, 1, 80);
        feed(&mut incoming, 2, 50);

        let encoded = serde_json::to_vec(&original).expect("serialize SpaceSaving");
        let mut restored: SpaceSaving =
            serde_json::from_slice(&encoded).expect("deserialize SpaceSaving");

        restored.merge(&incoming);
        original.merge(&incoming);

        // The leading entry must agree between the original and the
        // round-tripped summary.
        let ranked_original = original.top_k();
        let ranked_restored = restored.top_k();
        assert_eq!(
            ranked_original[0].0, ranked_restored[0].0,
            "top item mismatch after roundtrip"
        );
        assert_eq!(
            ranked_original[0].1, ranked_restored[0].1,
            "top count mismatch after roundtrip"
        );
    }
}