menmos_std/sync/
sharded_mutex.rs1use std::collections::hash_map::DefaultHasher;
2use std::hash::{Hash, Hasher};
3
4use tokio::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard};
5
/// Estimates how many buckets are needed for `concurrent_calls` keys so that the
/// chance of any two sharing a bucket is about `collision_probability`.
///
/// This is the birthday-bound approximation `p ≈ k * (k - 1) / (2 * N)` solved
/// for `N` (with `k = concurrent_calls`, `p = collision_probability`), rounded up.
///
/// Edge cases:
/// * A probability of exactly `0.0` (or `-0.0`) would divide by zero, so it is
///   replaced with `f64::EPSILON`. The resulting count is enormous — `2^52` for
///   two calls on a 64-bit target.
/// * The result is never below 1. Quotients that are negative or NaN (from a
///   negative or NaN probability) become 0 in the saturating float-to-int cast
///   and are then raised to 1.
fn optimal_bucket_count(concurrent_calls: usize, collision_probability: f64) -> usize {
    let p = if collision_probability == 0.0 {
        f64::EPSILON
    } else {
        collision_probability
    };

    let k = concurrent_calls as f64;
    let bucket_estimate = k * (k - 1.0) / (2.0 * p);

    usize::max(bucket_estimate.ceil() as usize, 1)
}

18pub struct ShardedMutex {
19 buf: Vec<RwLock<()>>,
20}
21
22impl ShardedMutex {
23 pub fn new(concurrent_calls: usize, collision_probability: f64) -> Self {
24 let bucket_count = optimal_bucket_count(concurrent_calls, collision_probability);
25
26 let mut buf = Vec::with_capacity(bucket_count);
27 for _ in 0..bucket_count {
28 buf.push(RwLock::new(()));
29 }
30
31 Self { buf }
32 }
33
34 fn get_lock_id<H: Hash>(&self, key: &H) -> usize {
35 let mut hasher = DefaultHasher::new();
37 key.hash(&mut hasher);
38 let hash_value = hasher.finish();
39
40 let mod_val = hash_value % (self.buf.len() as u64);
41
42 debug_assert!(
43 mod_val <= (self.buf.len() as u64),
44 "mod of length should give a value withing length bounds"
45 );
46
47 mod_val as usize
48 }
49
50 pub async fn read<'a, H: Hash>(&'a self, key: &H) -> RwLockReadGuard<'a, ()> {
51 let lock_id = self.get_lock_id(key);
52 self.buf[lock_id].read().await
53 }
54
55 pub async fn write<'a, H: Hash>(&'a self, key: &H) -> RwLockWriteGuard<'a, ()> {
56 let lock_id = self.get_lock_id(key);
57 self.buf[lock_id].write().await
58 }
59}
60
#[cfg(test)]
mod tests {
    use super::optimal_bucket_count;

    #[test]
    fn optimal_bucket_count_basic() {
        let actual = optimal_bucket_count(2, 0.5);
        assert_eq!(actual, 2);
    }

    /// Fewer than two callers cannot collide, so a single bucket suffices.
    #[test]
    fn optimal_bucket_count_trivial_call_counts() {
        assert_eq!(optimal_bucket_count(0, 0.5), 1);
        assert_eq!(optimal_bucket_count(1, 0.01), 1);
    }

    /// Values chosen so every intermediate is exactly representable.
    #[test]
    fn optimal_bucket_count_follows_formula() {
        // 4 * 3 / (2 * 0.25) = 24
        assert_eq!(optimal_bucket_count(4, 0.25), 24);
        // 5 * 4 / (2 * 0.5) = 20
        assert_eq!(optimal_bucket_count(5, 0.5), 20);
    }

    /// Fractional estimates are rounded up, never down.
    #[test]
    fn optimal_bucket_count_rounds_up() {
        // 2 * 1 / (2 * 0.75) = 1.33.. -> 2
        assert_eq!(optimal_bucket_count(2, 0.75), 2);
        // 5 * 4 / (2 * 0.75) = 13.33.. -> 14
        assert_eq!(optimal_bucket_count(5, 0.75), 14);
    }

    /// A zero probability is replaced by `f64::EPSILON` instead of dividing by zero.
    #[test]
    fn optimal_bucket_count_zero_probability_uses_epsilon() {
        // 2 * 1 / (2 * EPSILON) == 1 / EPSILON
        assert_eq!(optimal_bucket_count(2, 0.0), (1.0 / f64::EPSILON) as usize);
    }

    /// `ShardedMutex` takes the hash modulo the bucket count, so the count must
    /// never be zero, even for nonsensical probabilities.
    #[test]
    fn optimal_bucket_count_never_zero() {
        assert_eq!(optimal_bucket_count(3, -1.0), 1);
        assert_eq!(optimal_bucket_count(3, f64::NAN), 1);
    }
}