menmos_std/sync/sharded_mutex.rs

1use std::collections::hash_map::DefaultHasher;
2use std::hash::{Hash, Hasher};
3
4use tokio::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard};
5
/// Returns how many lock buckets are needed so that `concurrent_calls`
/// simultaneous keys share a bucket with probability of roughly
/// `collision_probability`.
///
/// Uses the birthday-problem approximation `p ≈ k(k - 1) / (2N)` described in
/// <https://preshing.com/20110504/hash-collision-probabilities/>, solved for the
/// bucket count `N = k(k - 1) / (2p)` and rounded up.
///
/// The result is always at least 1. With `0` or `1` concurrent calls the
/// formula gives 0, which is bumped to a single bucket.
///
/// A `collision_probability` of exactly `0.0` is replaced by [`f64::EPSILON`]
/// to avoid dividing by zero. `f64::EPSILON` is machine epsilon (about 2.2e-16),
/// not the smallest positive `f64`. The resulting count is enormous: already
/// 2^52 buckets for two concurrent calls. In practice a zero probability is
/// therefore only usable when `concurrent_calls <= 1`.
///
/// NOTE(review): negative or NaN probabilities are not rejected. The float
/// result then casts (saturating) to 0 and the function returns 1 bucket.
fn optimal_bucket_count(concurrent_calls: usize, collision_probability: f64) -> usize {
    let probability = if collision_probability == 0.0 {
        f64::EPSILON
    } else {
        collision_probability
    };

    let calls = concurrent_calls as f64;
    let bucket_count = ((calls - 1.0) * calls) / (2.0 * probability);

    // `as` saturates: negative/NaN become 0 and values past `usize::MAX` clamp to it.
    (bucket_count.ceil() as usize).max(1)
}
17
18pub struct ShardedMutex {
19    buf: Vec<RwLock<()>>,
20}
21
22impl ShardedMutex {
23    pub fn new(concurrent_calls: usize, collision_probability: f64) -> Self {
24        let bucket_count = optimal_bucket_count(concurrent_calls, collision_probability);
25
26        let mut buf = Vec::with_capacity(bucket_count);
27        for _ in 0..bucket_count {
28            buf.push(RwLock::new(()));
29        }
30
31        Self { buf }
32    }
33
34    fn get_lock_id<H: Hash>(&self, key: &H) -> usize {
35        // TODO: Faster hashing algo? Look into what hashmap does.
36        let mut hasher = DefaultHasher::new();
37        key.hash(&mut hasher);
38        let hash_value = hasher.finish();
39
40        let mod_val = hash_value % (self.buf.len() as u64);
41
42        debug_assert!(
43            mod_val <= (self.buf.len() as u64),
44            "mod of length should give a value withing length bounds"
45        );
46
47        mod_val as usize
48    }
49
50    pub async fn read<'a, H: Hash>(&'a self, key: &H) -> RwLockReadGuard<'a, ()> {
51        let lock_id = self.get_lock_id(key);
52        self.buf[lock_id].read().await
53    }
54
55    pub async fn write<'a, H: Hash>(&'a self, key: &H) -> RwLockWriteGuard<'a, ()> {
56        let lock_id = self.get_lock_id(key);
57        self.buf[lock_id].write().await
58    }
59}
60
#[cfg(test)]
mod tests {
    use super::{optimal_bucket_count, ShardedMutex};

    #[test]
    fn optimal_bucket_count_basic() {
        let actual = optimal_bucket_count(2, 0.5);
        assert_eq!(actual, 2);
    }

    #[test]
    fn optimal_bucket_count_is_at_least_one() {
        // Fewer than two callers can never collide, but one bucket is still required.
        assert_eq!(optimal_bucket_count(0, 0.5), 1);
        assert_eq!(optimal_bucket_count(1, 0.5), 1);
        assert_eq!(optimal_bucket_count(1, 0.0), 1);
    }

    #[test]
    fn optimal_bucket_count_rounds_up() {
        // 2 * 1 / (2 * 0.375) = 2.67 -> 3
        assert_eq!(optimal_bucket_count(2, 0.375), 3);
        // 4 * 3 / (2 * 0.125) = 48 exactly
        assert_eq!(optimal_bucket_count(4, 0.125), 48);
    }

    #[test]
    #[cfg(target_pointer_width = "64")]
    fn optimal_bucket_count_zero_probability_uses_epsilon() {
        // 2 * 1 / (2 * f64::EPSILON) == 2^52
        assert_eq!(optimal_bucket_count(2, 0.0) as u64, 1u64 << 52);
    }

    #[test]
    fn lock_id_is_stable_and_in_bounds() {
        // 10 * 9 / (2 * 0.5) = 90 shards
        let mutex = ShardedMutex::new(10, 0.5);
        assert_eq!(mutex.buf.len(), 90);

        for key in 0..1000u64 {
            let id = mutex.get_lock_id(&key);
            assert!(id < mutex.buf.len());
            assert_eq!(id, mutex.get_lock_id(&key));
        }
    }
}