1use parking_lot::{Mutex, RwLock};
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::hash::{DefaultHasher, Hash, Hasher};
10
/// Shard count used by `ShardedRwLock::new` and `ShardedMutex::new`.
const DEFAULT_SHARD_COUNT: usize = 1 << 4;
13
14#[derive(Debug)]
16pub struct ShardedRwLock<K, V>
17where
18 K: Hash + Eq,
19{
20 shards: Vec<RwLock<HashMap<K, V>>>,
21 shard_count: usize,
22}
23
24impl<K, V> ShardedRwLock<K, V>
25where
26 K: Hash + Eq,
27{
28 pub fn new() -> Self {
30 Self::with_shard_count(DEFAULT_SHARD_COUNT)
31 }
32
33 pub fn with_shard_count(shard_count: usize) -> Self {
35 let mut shards = Vec::with_capacity(shard_count);
36 for _ in 0..shard_count {
37 shards.push(RwLock::new(HashMap::new()));
38 }
39
40 Self {
41 shards,
42 shard_count,
43 }
44 }
45
46 fn get_shard_index<Q>(&self, key: &Q) -> usize
48 where
49 Q: Hash + ?Sized,
50 {
51 let mut hasher = DefaultHasher::new();
52 key.hash(&mut hasher);
53 (hasher.finish() as usize) % self.shard_count
54 }
55
56 pub fn insert(&self, key: K, value: V) -> Option<V> {
58 let shard_index = self.get_shard_index(&key);
59 let mut shard = self.shards[shard_index].write();
60 shard.insert(key, value)
61 }
62
63 pub fn get<Q>(&self, key: &Q) -> Option<V>
65 where
66 K: std::borrow::Borrow<Q>,
67 Q: Hash + Eq + ?Sized,
68 V: Clone,
69 {
70 let shard_index = self.get_shard_index(key);
71 let shard = self.shards[shard_index].read();
72 shard.get(key).cloned()
73 }
74
75 pub fn remove<Q>(&self, key: &Q) -> Option<V>
77 where
78 K: std::borrow::Borrow<Q>,
79 Q: Hash + Eq + ?Sized,
80 {
81 let shard_index = self.get_shard_index(key);
82 let mut shard = self.shards[shard_index].write();
83 shard.remove(key)
84 }
85
86 pub fn contains_key<Q>(&self, key: &Q) -> bool
88 where
89 K: std::borrow::Borrow<Q>,
90 Q: Hash + Eq + ?Sized,
91 {
92 let shard_index = self.get_shard_index(key);
93 let shard = self.shards[shard_index].read();
94 shard.contains_key(key)
95 }
96
97 pub fn len(&self) -> usize {
99 self.shards.iter().map(|shard| shard.read().len()).sum()
100 }
101
102 pub fn is_empty(&self) -> bool {
104 self.shards.iter().all(|shard| shard.read().is_empty())
105 }
106
107 pub fn clear(&self) {
109 for shard in &self.shards {
110 shard.write().clear();
111 }
112 }
113
114 pub fn with_shard_read<Q, F, R>(&self, key: &Q, f: F) -> R
116 where
117 K: std::borrow::Borrow<Q>,
118 Q: Hash + Eq + ?Sized,
119 F: FnOnce(&HashMap<K, V>) -> R,
120 {
121 let shard_index = self.get_shard_index(key);
122 let shard = self.shards[shard_index].read();
123 f(&*shard)
124 }
125
126 pub fn with_shard_write<Q, F, R>(&self, key: &Q, f: F) -> R
128 where
129 K: std::borrow::Borrow<Q>,
130 Q: Hash + Eq + ?Sized,
131 F: FnOnce(&mut HashMap<K, V>) -> R,
132 {
133 let shard_index = self.get_shard_index(key);
134 let mut shard = self.shards[shard_index].write();
135 f(&mut *shard)
136 }
137
138 pub fn shard_stats(&self) -> ShardStats {
140 let shard_sizes: Vec<usize> = self.shards.iter().map(|shard| shard.read().len()).collect();
141
142 let total_entries: usize = shard_sizes.iter().sum();
143 let max_shard_size = shard_sizes.iter().max().copied().unwrap_or(0);
144 let min_shard_size = shard_sizes.iter().min().copied().unwrap_or(0);
145 let avg_shard_size = if self.shard_count > 0 {
146 total_entries as f64 / self.shard_count as f64
147 } else {
148 0.0
149 };
150
151 ShardStats {
152 shard_count: self.shard_count,
153 total_entries,
154 max_shard_size,
155 min_shard_size,
156 avg_shard_size,
157 shard_sizes,
158 }
159 }
160}
161
162impl<K, V> Default for ShardedRwLock<K, V>
163where
164 K: Hash + Eq,
165{
166 fn default() -> Self {
167 Self::new()
168 }
169}
170
171#[derive(Debug, Clone, Serialize, Deserialize)]
173pub struct ShardStats {
174 pub shard_count: usize,
175 pub total_entries: usize,
176 pub max_shard_size: usize,
177 pub min_shard_size: usize,
178 pub avg_shard_size: f64,
179 pub shard_sizes: Vec<usize>,
180}
181
182impl ShardStats {
183 pub fn load_balance_ratio(&self) -> f64 {
185 if self.total_entries == 0 || self.avg_shard_size == 0.0 {
186 return 0.0;
187 }
188
189 let variance: f64 = self
190 .shard_sizes
191 .iter()
192 .map(|&size| {
193 let diff = size as f64 - self.avg_shard_size;
194 diff * diff
195 })
196 .sum::<f64>()
197 / self.shard_count as f64;
198
199 let std_dev = variance.sqrt();
200 std_dev / self.avg_shard_size
201 }
202}
203
204#[derive(Debug)]
206pub struct ShardedMutex<K, V>
207where
208 K: Hash + Eq,
209{
210 shards: Vec<Mutex<HashMap<K, V>>>,
211 shard_count: usize,
212}
213
214impl<K, V> ShardedMutex<K, V>
215where
216 K: Hash + Eq,
217{
218 pub fn new() -> Self {
220 Self::with_shard_count(DEFAULT_SHARD_COUNT)
221 }
222
223 pub fn with_shard_count(shard_count: usize) -> Self {
225 let mut shards = Vec::with_capacity(shard_count);
226 for _ in 0..shard_count {
227 shards.push(Mutex::new(HashMap::new()));
228 }
229
230 Self {
231 shards,
232 shard_count,
233 }
234 }
235
236 fn get_shard_index<Q>(&self, key: &Q) -> usize
238 where
239 Q: Hash + ?Sized,
240 {
241 let mut hasher = DefaultHasher::new();
242 key.hash(&mut hasher);
243 (hasher.finish() as usize) % self.shard_count
244 }
245
246 pub fn insert(&self, key: K, value: V) -> Option<V> {
248 let shard_index = self.get_shard_index(&key);
249 let mut shard = self.shards[shard_index].lock();
250 shard.insert(key, value)
251 }
252
253 pub fn get<Q>(&self, key: &Q) -> Option<V>
255 where
256 K: std::borrow::Borrow<Q>,
257 Q: Hash + Eq + ?Sized,
258 V: Clone,
259 {
260 let shard_index = self.get_shard_index(key);
261 let shard = self.shards[shard_index].lock();
262 shard.get(key).cloned()
263 }
264
265 pub fn remove<Q>(&self, key: &Q) -> Option<V>
267 where
268 K: std::borrow::Borrow<Q>,
269 Q: Hash + Eq + ?Sized,
270 {
271 let shard_index = self.get_shard_index(key);
272 let mut shard = self.shards[shard_index].lock();
273 shard.remove(key)
274 }
275
276 pub fn with_shard<Q, F, R>(&self, key: &Q, f: F) -> R
278 where
279 K: std::borrow::Borrow<Q>,
280 Q: Hash + Eq + ?Sized,
281 F: FnOnce(&mut HashMap<K, V>) -> R,
282 {
283 let shard_index = self.get_shard_index(key);
284 let mut shard = self.shards[shard_index].lock();
285 f(&mut *shard)
286 }
287
288 pub fn len(&self) -> usize {
290 self.shards.iter().map(|shard| shard.lock().len()).sum()
291 }
292
293 pub fn is_empty(&self) -> bool {
295 self.shards.iter().all(|shard| shard.lock().is_empty())
296 }
297}
298
299impl<K, V> Default for ShardedMutex<K, V>
300where
301 K: Hash + Eq,
302{
303 fn default() -> Self {
304 Self::new()
305 }
306}
307
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sharded_rwlock_basic_operations() {
        let map = ShardedRwLock::new();

        // A first insert reports no previous value.
        assert_eq!(map.insert("key1", "value1"), None);
        assert_eq!(map.get("key1"), Some("value1"));

        // Overwriting hands back the value being replaced.
        assert_eq!(map.insert("key1", "value2"), Some("value1"));
        assert_eq!(map.get("key1"), Some("value2"));

        // Removal returns the stored value and leaves the key absent.
        assert_eq!(map.remove("key1"), Some("value2"));
        assert_eq!(map.get("key1"), None);
    }

    #[test]
    fn test_shard_stats() {
        let map = ShardedRwLock::with_shard_count(4);
        (0..100).for_each(|i| {
            map.insert(i, format!("value_{i}"));
        });

        let stats = map.shard_stats();
        assert_eq!(stats.shard_count, 4);
        assert_eq!(stats.total_entries, 100);
        assert!(stats.avg_shard_size > 0.0);
        assert!(stats.load_balance_ratio() >= 0.0);
    }
}