// memscope_rs/core/sharded_locks.rs

//! Sharded lock system for reducing lock contention.
//!
//! This module provides sharded maps that spread lock contention across
//! multiple shards, improving concurrent performance: each key is hashed to
//! select one shard, and every shard is protected by its own lock.
5
6use parking_lot::{Mutex, RwLock};
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::hash::{DefaultHasher, Hash, Hasher};
10
/// Shard count used when a sharded map is built with `new()` or `Default`.
const DEFAULT_SHARD_COUNT: usize = 16;
13
14/// Sharded read-write lock for concurrent access
15#[derive(Debug)]
16pub struct ShardedRwLock<K, V>
17where
18    K: Hash + Eq,
19{
20    shards: Vec<RwLock<HashMap<K, V>>>,
21    shard_count: usize,
22}
23
24impl<K, V> ShardedRwLock<K, V>
25where
26    K: Hash + Eq,
27{
28    /// Create a new sharded RwLock with default shard count
29    pub fn new() -> Self {
30        Self::with_shard_count(DEFAULT_SHARD_COUNT)
31    }
32
33    /// Create a new sharded RwLock with specified shard count
34    pub fn with_shard_count(shard_count: usize) -> Self {
35        let mut shards = Vec::with_capacity(shard_count);
36        for _ in 0..shard_count {
37            shards.push(RwLock::new(HashMap::new()));
38        }
39
40        Self {
41            shards,
42            shard_count,
43        }
44    }
45
46    /// Get the shard index for a given key
47    fn get_shard_index<Q>(&self, key: &Q) -> usize
48    where
49        Q: Hash + ?Sized,
50    {
51        let mut hasher = DefaultHasher::new();
52        key.hash(&mut hasher);
53        (hasher.finish() as usize) % self.shard_count
54    }
55
56    /// Insert a key-value pair
57    pub fn insert(&self, key: K, value: V) -> Option<V> {
58        let shard_index = self.get_shard_index(&key);
59        let mut shard = self.shards[shard_index].write();
60        shard.insert(key, value)
61    }
62
63    /// Get a value by key
64    pub fn get<Q>(&self, key: &Q) -> Option<V>
65    where
66        K: std::borrow::Borrow<Q>,
67        Q: Hash + Eq + ?Sized,
68        V: Clone,
69    {
70        let shard_index = self.get_shard_index(key);
71        let shard = self.shards[shard_index].read();
72        shard.get(key).cloned()
73    }
74
75    /// Remove a key-value pair
76    pub fn remove<Q>(&self, key: &Q) -> Option<V>
77    where
78        K: std::borrow::Borrow<Q>,
79        Q: Hash + Eq + ?Sized,
80    {
81        let shard_index = self.get_shard_index(key);
82        let mut shard = self.shards[shard_index].write();
83        shard.remove(key)
84    }
85
86    /// Check if a key exists
87    pub fn contains_key<Q>(&self, key: &Q) -> bool
88    where
89        K: std::borrow::Borrow<Q>,
90        Q: Hash + Eq + ?Sized,
91    {
92        let shard_index = self.get_shard_index(key);
93        let shard = self.shards[shard_index].read();
94        shard.contains_key(key)
95    }
96
97    /// Get the total number of entries across all shards
98    pub fn len(&self) -> usize {
99        self.shards.iter().map(|shard| shard.read().len()).sum()
100    }
101
102    /// Check if the sharded map is empty
103    pub fn is_empty(&self) -> bool {
104        self.shards.iter().all(|shard| shard.read().is_empty())
105    }
106
107    /// Clear all entries from all shards
108    pub fn clear(&self) {
109        for shard in &self.shards {
110            shard.write().clear();
111        }
112    }
113
114    /// Execute a function with read access to a specific shard
115    pub fn with_shard_read<Q, F, R>(&self, key: &Q, f: F) -> R
116    where
117        K: std::borrow::Borrow<Q>,
118        Q: Hash + Eq + ?Sized,
119        F: FnOnce(&HashMap<K, V>) -> R,
120    {
121        let shard_index = self.get_shard_index(key);
122        let shard = self.shards[shard_index].read();
123        f(&*shard)
124    }
125
126    /// Execute a function with write access to a specific shard
127    pub fn with_shard_write<Q, F, R>(&self, key: &Q, f: F) -> R
128    where
129        K: std::borrow::Borrow<Q>,
130        Q: Hash + Eq + ?Sized,
131        F: FnOnce(&mut HashMap<K, V>) -> R,
132    {
133        let shard_index = self.get_shard_index(key);
134        let mut shard = self.shards[shard_index].write();
135        f(&mut *shard)
136    }
137
138    /// Get statistics about shard distribution
139    pub fn shard_stats(&self) -> ShardStats {
140        let shard_sizes: Vec<usize> = self.shards.iter().map(|shard| shard.read().len()).collect();
141
142        let total_entries: usize = shard_sizes.iter().sum();
143        let max_shard_size = shard_sizes.iter().max().copied().unwrap_or(0);
144        let min_shard_size = shard_sizes.iter().min().copied().unwrap_or(0);
145        let avg_shard_size = if self.shard_count > 0 {
146            total_entries as f64 / self.shard_count as f64
147        } else {
148            0.0
149        };
150
151        ShardStats {
152            shard_count: self.shard_count,
153            total_entries,
154            max_shard_size,
155            min_shard_size,
156            avg_shard_size,
157            shard_sizes,
158        }
159    }
160}
161
162impl<K, V> Default for ShardedRwLock<K, V>
163where
164    K: Hash + Eq,
165{
166    fn default() -> Self {
167        Self::new()
168    }
169}
170
171/// Statistics about shard distribution
172#[derive(Debug, Clone, Serialize, Deserialize)]
173pub struct ShardStats {
174    pub shard_count: usize,
175    pub total_entries: usize,
176    pub max_shard_size: usize,
177    pub min_shard_size: usize,
178    pub avg_shard_size: f64,
179    pub shard_sizes: Vec<usize>,
180}
181
182impl ShardStats {
183    /// Calculate load balance ratio (0.0 = perfectly balanced, 1.0 = completely unbalanced)
184    pub fn load_balance_ratio(&self) -> f64 {
185        if self.total_entries == 0 || self.avg_shard_size == 0.0 {
186            return 0.0;
187        }
188
189        let variance: f64 = self
190            .shard_sizes
191            .iter()
192            .map(|&size| {
193                let diff = size as f64 - self.avg_shard_size;
194                diff * diff
195            })
196            .sum::<f64>()
197            / self.shard_count as f64;
198
199        let std_dev = variance.sqrt();
200        std_dev / self.avg_shard_size
201    }
202}
203
204/// Sharded mutex for exclusive access
205#[derive(Debug)]
206pub struct ShardedMutex<K, V>
207where
208    K: Hash + Eq,
209{
210    shards: Vec<Mutex<HashMap<K, V>>>,
211    shard_count: usize,
212}
213
214impl<K, V> ShardedMutex<K, V>
215where
216    K: Hash + Eq,
217{
218    /// Create a new sharded Mutex with default shard count
219    pub fn new() -> Self {
220        Self::with_shard_count(DEFAULT_SHARD_COUNT)
221    }
222
223    /// Create a new sharded Mutex with specified shard count
224    pub fn with_shard_count(shard_count: usize) -> Self {
225        let mut shards = Vec::with_capacity(shard_count);
226        for _ in 0..shard_count {
227            shards.push(Mutex::new(HashMap::new()));
228        }
229
230        Self {
231            shards,
232            shard_count,
233        }
234    }
235
236    /// Get the shard index for a given key
237    fn get_shard_index<Q>(&self, key: &Q) -> usize
238    where
239        Q: Hash + ?Sized,
240    {
241        let mut hasher = DefaultHasher::new();
242        key.hash(&mut hasher);
243        (hasher.finish() as usize) % self.shard_count
244    }
245
246    /// Insert a key-value pair
247    pub fn insert(&self, key: K, value: V) -> Option<V> {
248        let shard_index = self.get_shard_index(&key);
249        let mut shard = self.shards[shard_index].lock();
250        shard.insert(key, value)
251    }
252
253    /// Get a value by key
254    pub fn get<Q>(&self, key: &Q) -> Option<V>
255    where
256        K: std::borrow::Borrow<Q>,
257        Q: Hash + Eq + ?Sized,
258        V: Clone,
259    {
260        let shard_index = self.get_shard_index(key);
261        let shard = self.shards[shard_index].lock();
262        shard.get(key).cloned()
263    }
264
265    /// Remove a key-value pair
266    pub fn remove<Q>(&self, key: &Q) -> Option<V>
267    where
268        K: std::borrow::Borrow<Q>,
269        Q: Hash + Eq + ?Sized,
270    {
271        let shard_index = self.get_shard_index(key);
272        let mut shard = self.shards[shard_index].lock();
273        shard.remove(key)
274    }
275
276    /// Execute a function with exclusive access to a specific shard
277    pub fn with_shard<Q, F, R>(&self, key: &Q, f: F) -> R
278    where
279        K: std::borrow::Borrow<Q>,
280        Q: Hash + Eq + ?Sized,
281        F: FnOnce(&mut HashMap<K, V>) -> R,
282    {
283        let shard_index = self.get_shard_index(key);
284        let mut shard = self.shards[shard_index].lock();
285        f(&mut *shard)
286    }
287
288    /// Get the total number of entries across all shards
289    pub fn len(&self) -> usize {
290        self.shards.iter().map(|shard| shard.lock().len()).sum()
291    }
292
293    /// Check if the sharded map is empty
294    pub fn is_empty(&self) -> bool {
295        self.shards.iter().all(|shard| shard.lock().is_empty())
296    }
297}
298
299impl<K, V> Default for ShardedMutex<K, V>
300where
301    K: Hash + Eq,
302{
303    fn default() -> Self {
304        Self::new()
305    }
306}
307
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sharded_rwlock_basic_operations() {
        let sharded = ShardedRwLock::new();

        // Test insert and get
        assert_eq!(sharded.insert("key1", "value1"), None);
        assert_eq!(sharded.get("key1"), Some("value1"));

        // Test update
        assert_eq!(sharded.insert("key1", "value2"), Some("value1"));
        assert_eq!(sharded.get("key1"), Some("value2"));

        // Test remove
        assert_eq!(sharded.remove("key1"), Some("value2"));
        assert_eq!(sharded.get("key1"), None);
    }

    #[test]
    fn test_sharded_rwlock_clear_and_contains() {
        let sharded = ShardedRwLock::with_shard_count(3);
        for i in 0..10 {
            sharded.insert(i, i);
        }
        assert!(sharded.contains_key(&5));
        assert_eq!(sharded.len(), 10);

        sharded.clear();
        assert!(sharded.is_empty());
        assert!(!sharded.contains_key(&5));
    }

    #[test]
    fn test_sharded_rwlock_concurrent_inserts() {
        let sharded = ShardedRwLock::new();
        std::thread::scope(|scope| {
            for t in 0..4 {
                let sharded = &sharded;
                scope.spawn(move || {
                    for i in 0..250 {
                        sharded.insert(t * 250 + i, i);
                    }
                });
            }
        });

        assert_eq!(sharded.len(), 1000);
        assert_eq!(sharded.shard_stats().total_entries, 1000);
    }

    #[test]
    fn test_shard_stats() {
        let sharded = ShardedRwLock::with_shard_count(4);

        // Insert some data
        for i in 0..100 {
            sharded.insert(i, format!("value_{i}"));
        }

        let stats = sharded.shard_stats();
        assert_eq!(stats.shard_count, 4);
        assert_eq!(stats.total_entries, 100);
        assert_eq!(stats.shard_sizes.len(), 4);
        assert_eq!(stats.shard_sizes.iter().sum::<usize>(), 100);
        assert!(stats.min_shard_size <= stats.max_shard_size);
        assert!((stats.avg_shard_size - 25.0).abs() < f64::EPSILON);
        assert!(stats.load_balance_ratio() >= 0.0);
    }

    #[test]
    fn test_sharded_mutex_basic_operations() {
        let sharded = ShardedMutex::with_shard_count(4);
        assert!(sharded.is_empty());

        assert_eq!(sharded.insert("key1", 1), None);
        assert_eq!(sharded.insert("key1", 2), Some(1));
        assert_eq!(sharded.get("key1"), Some(2));
        assert_eq!(sharded.len(), 1);

        // with_shard gives mutable access to the owning shard.
        let doubled = sharded.with_shard("key1", |shard| {
            shard.get_mut("key1").map(|v| {
                *v *= 2;
                *v
            })
        });
        assert_eq!(doubled, Some(4));

        assert_eq!(sharded.remove("key1"), Some(4));
        assert!(sharded.is_empty());
    }
}