pingora_limits/
estimator.rs

1// Copyright 2025 Cloudflare, Inc.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! The estimator module contains a Count-Min Sketch type to help estimate the frequency of an item.
16
17use crate::hash;
18use crate::RandomState;
19use std::hash::Hash;
20use std::sync::atomic::{AtomicIsize, Ordering};
21
/// One row of the sketch: the row's counters (one per slot) paired with the `RandomState` used
/// to hash keys into that row.
type SketchRow = (Box<[AtomicIsize]>, RandomState);

/// A lock-free count–min sketch frequency estimator backed by atomic counters. The [wikipedia]
/// article describes the data structure in more detail.
///
/// [wikipedia]: https://en.wikipedia.org/wiki/Count%E2%80%93min_sketch
pub struct Estimator {
    /// The rows of the sketch; `Estimator::new` creates one row per hash.
    estimator: Box<[SketchRow]>,
}
29
30impl Estimator {
31    /// Create a new `Estimator` with the given amount of hashes and columns (slots).
32    pub fn new(hashes: usize, slots: usize) -> Self {
33        Self {
34            estimator: (0..hashes)
35                .map(|_| (0..slots).map(|_| AtomicIsize::new(0)).collect::<Vec<_>>())
36                .map(|slot| (slot.into_boxed_slice(), RandomState::new()))
37                .collect::<Vec<_>>()
38                .into_boxed_slice(),
39        }
40    }
41
42    /// Increment `key` by the value given. Return the new estimated value as a result.
43    /// Note: overflow can happen. When some of the internal counters overflow, a negative number
44    /// will be returned. It is up to the caller to catch and handle this case.
45    pub fn incr<T: Hash>(&self, key: T, value: isize) -> isize {
46        self.estimator
47            .iter()
48            .fold(isize::MAX, |min, (slot, hasher)| {
49                let hash = hash(&key, hasher) as usize;
50                let counter = &slot[hash % slot.len()];
51                // Overflow is allowed for simplicity
52                let current = counter.fetch_add(value, Ordering::Relaxed);
53                std::cmp::min(min, current + value)
54            })
55    }
56
57    /// Decrement `key` by the value given.
58    pub fn decr<T: Hash>(&self, key: T, value: isize) {
59        for (slot, hasher) in self.estimator.iter() {
60            let hash = hash(&key, hasher) as usize;
61            let counter = &slot[hash % slot.len()];
62            counter.fetch_sub(value, Ordering::Relaxed);
63        }
64    }
65
66    /// Get the estimated frequency of `key`.
67    pub fn get<T: Hash>(&self, key: T) -> isize {
68        self.estimator
69            .iter()
70            .fold(isize::MAX, |min, (slot, hasher)| {
71                let hash = hash(&key, hasher) as usize;
72                let counter = &slot[hash % slot.len()];
73                let current = counter.load(Ordering::Relaxed);
74                std::cmp::min(min, current)
75            })
76    }
77
78    /// Reset all values inside this `Estimator`.
79    pub fn reset(&self) {
80        self.estimator.iter().for_each(|(slot, _)| {
81            slot.iter()
82                .for_each(|counter| counter.store(0, Ordering::Relaxed))
83        });
84    }
85}
86
#[cfg(test)]
mod tests {
    use super::*;

    /// `incr` returns the running estimate of each key.
    #[test]
    fn incr() {
        let est = Estimator::new(8, 8);
        let v = est.incr("a", 1);
        assert_eq!(v, 1);
        let v = est.incr("b", 1);
        assert_eq!(v, 1);
        let v = est.incr("a", 2);
        assert_eq!(v, 3);
        let v = est.incr("b", 2);
        assert_eq!(v, 3);
    }

    /// `decr` lowers the estimate that `get` reports.
    #[test]
    fn decr() {
        let est = Estimator::new(8, 8);
        est.incr("a", 3);
        est.incr("b", 3);
        est.decr("a", 1);
        est.decr("b", 1);
        assert_eq!(est.get("a"), 2);
        assert_eq!(est.get("b"), 2);
    }

    /// `get` reports the sum of earlier increments.
    #[test]
    fn get() {
        let est = Estimator::new(8, 8);
        est.incr("a", 1);
        est.incr("a", 2);
        est.incr("b", 1);
        est.incr("b", 2);
        assert_eq!(est.get("a"), 3);
        assert_eq!(est.get("b"), 3);
    }

    /// A key that was never incremented reads as zero on a fresh estimator.
    #[test]
    fn get_unseen_key() {
        let est = Estimator::new(8, 8);
        assert_eq!(est.get("a"), 0);
    }

    /// `reset` zeroes every counter regardless of earlier updates.
    #[test]
    fn reset() {
        let est = Estimator::new(8, 8);
        est.incr("a", 1);
        est.incr("a", 2);
        est.incr("b", 1);
        est.incr("b", 2);
        est.decr("b", 1);
        est.reset();
        assert_eq!(est.get("a"), 0);
        assert_eq!(est.get("b"), 0);
    }
}