// sieve_cache/sharded.rs

use crate::SieveCache;
use std::borrow::Borrow;
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
use std::sync::{Arc, Mutex, MutexGuard, PoisonError};

/// Default number of shards to use if not specified explicitly.
/// This value was chosen as a good default that balances memory overhead
/// with concurrency in most practical scenarios.
const DEFAULT_SHARDS: usize = 16;

/// A thread-safe implementation of `SieveCache` that uses multiple shards to reduce contention.
///
/// This provides better concurrency than `SyncSieveCache` by splitting the cache into multiple
/// independent shards, each protected by its own mutex. Operations on different shards can
/// proceed in parallel, which can significantly improve throughput in multi-threaded environments.
///
/// # How Sharding Works
///
/// The cache is partitioned into multiple independent segments (shards) based on the hash of the key.
/// Each shard has its own mutex, allowing operations on different shards to proceed concurrently.
/// This reduces lock contention compared to a single-mutex approach, especially under high
/// concurrency with access patterns distributed across different keys.
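///
/// Conceptually, a key's shard is its hash modulo the shard count, mirroring the private
/// `get_shard_index` helper defined later in this file:
///
/// ```
/// use std::collections::hash_map::DefaultHasher;
/// use std::hash::{Hash, Hasher};
///
/// let num_shards = 16;
/// let mut hasher = DefaultHasher::new();
/// "key1".hash(&mut hasher);
/// let shard_index = (hasher.finish() as usize) % num_shards;
/// assert!(shard_index < num_shards);
/// ```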
///
/// # Performance Considerations
///
/// - For workloads with high concurrency across different keys, `ShardedSieveCache` typically offers
///   better performance than `SyncSieveCache`
/// - The benefits increase with the number of concurrent threads and the distribution of keys
/// - More shards reduce contention but increase memory overhead
/// - If most operations target the same few keys (which map to the same shards), the benefits of
///   sharding may be limited
///
/// # Examples
///
/// ```
/// # use sieve_cache::ShardedSieveCache;
/// // Create a cache with default number of shards (16)
/// let cache: ShardedSieveCache<String, String> = ShardedSieveCache::new(1000).unwrap();
///
/// // Or specify a custom number of shards
/// let cache: ShardedSieveCache<String, String> = ShardedSieveCache::with_shards(1000, 32).unwrap();
///
/// cache.insert("key1".to_string(), "value1".to_string());
/// assert_eq!(cache.get(&"key1".to_string()), Some("value1".to_string()));
/// ```
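///
/// Because every method takes `&self`, the cache can be shared across threads by wrapping it
/// in an `Arc` (a minimal sketch of the pattern exercised by this module's tests):
///
/// ```
/// # use sieve_cache::ShardedSieveCache;
/// use std::sync::Arc;
/// use std::thread;
///
/// let cache: Arc<ShardedSieveCache<String, String>> =
///     Arc::new(ShardedSieveCache::new(1000).unwrap());
///
/// let handles: Vec<_> = (0..4)
///     .map(|t| {
///         let cache = Arc::clone(&cache);
///         thread::spawn(move || {
///             // Each thread writes its own key; distinct keys usually land in different shards.
///             cache.insert(format!("key{}", t), format!("value{}", t));
///         })
///     })
///     .collect();
///
/// for handle in handles {
///     handle.join().unwrap();
/// }
///
/// assert_eq!(cache.len(), 4);
/// ```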
#[derive(Clone)]
pub struct ShardedSieveCache<K, V>
where
    K: Eq + Hash + Clone + Send + Sync,
    V: Send + Sync,
{
    /// Array of shard mutexes, each containing a separate SieveCache instance
    shards: Vec<Arc<Mutex<SieveCache<K, V>>>>,
    /// Number of shards in the cache - kept as a separate field for quick access
    num_shards: usize,
}

impl<K, V> ShardedSieveCache<K, V>
where
    K: Eq + Hash + Clone + Send + Sync,
    V: Send + Sync,
{
    /// Creates a new sharded cache with the specified capacity, using the default number of shards.
    ///
    /// The capacity will be divided evenly among the shards. The default shard count (16)
    /// provides a good balance between concurrency and memory overhead for most workloads.
    ///
    /// # Errors
    ///
    /// Returns an error if the capacity is zero.
    ///
    /// # Examples
    ///
    /// ```
    /// # use sieve_cache::ShardedSieveCache;
    /// let cache: ShardedSieveCache<String, String> = ShardedSieveCache::new(1000).unwrap();
    /// assert_eq!(cache.num_shards(), 16); // Default shard count
    /// ```
    pub fn new(capacity: usize) -> Result<Self, &'static str> {
        Self::with_shards(capacity, DEFAULT_SHARDS)
    }

    /// Creates a new sharded cache with the specified capacity and number of shards.
    ///
    /// The capacity is divided among the shards, with any remainder going to the first shards
    /// so that the total capacity is at least the requested amount (see the second example below).
    ///
    /// # Arguments
    ///
    /// * `capacity` - The total capacity of the cache
    /// * `num_shards` - The number of shards to divide the cache into
    ///
    /// # Errors
    ///
    /// Returns an error if either the capacity or number of shards is zero.
    ///
    /// # Performance Impact
    ///
    /// - More shards can reduce contention in highly concurrent environments
    /// - However, each shard has memory overhead, so very high shard counts may
    ///   increase memory usage without providing additional performance benefits
    /// - For most workloads, a value between 8 and 32 shards is optimal
    ///
    /// # Examples
    ///
    /// ```
    /// # use sieve_cache::ShardedSieveCache;
    /// // Create a cache with 8 shards
    /// let cache: ShardedSieveCache<String, u32> = ShardedSieveCache::with_shards(1000, 8).unwrap();
    /// assert_eq!(cache.num_shards(), 8);
    /// assert!(cache.capacity() >= 1000);
    /// ```
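    ///
    /// For example, a capacity of 10 across 3 shards yields per-shard capacities of
    /// 4, 3, and 3:
    ///
    /// ```
    /// # use sieve_cache::ShardedSieveCache;
    /// let cache: ShardedSieveCache<String, u32> = ShardedSieveCache::with_shards(10, 3).unwrap();
    /// assert_eq!(cache.num_shards(), 3);
    /// assert!(cache.capacity() >= 10);
    /// ```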
    pub fn with_shards(capacity: usize, num_shards: usize) -> Result<Self, &'static str> {
        if capacity == 0 {
            return Err("capacity must be greater than 0");
        }
        if num_shards == 0 {
            return Err("number of shards must be greater than 0");
        }

        // Calculate per-shard capacity
        let base_capacity_per_shard = capacity / num_shards;
        let remaining = capacity % num_shards;

        let mut shards = Vec::with_capacity(num_shards);
        for i in 0..num_shards {
            // Distribute the remaining capacity to the first 'remaining' shards
            let shard_capacity = if i < remaining {
                base_capacity_per_shard + 1
            } else {
                base_capacity_per_shard
            };

            // Ensure at least capacity 1 per shard
            let shard_capacity = std::cmp::max(1, shard_capacity);
            shards.push(Arc::new(Mutex::new(SieveCache::new(shard_capacity)?)));
        }

        Ok(Self { shards, num_shards })
    }

    /// Returns the shard index for a given key.
    ///
    /// This function computes a hash of the key and uses it to determine which shard
    /// the key belongs to.
    #[inline]
    fn get_shard_index<Q>(&self, key: &Q) -> usize
    where
        Q: Hash + ?Sized,
    {
        let mut hasher = DefaultHasher::new();
        key.hash(&mut hasher);
        let hash = hasher.finish() as usize;
        hash % self.num_shards
    }

    /// Returns a reference to the shard for a given key.
    ///
    /// This is an internal helper method that maps a key to its corresponding shard.
    #[inline]
    fn get_shard<Q>(&self, key: &Q) -> &Arc<Mutex<SieveCache<K, V>>>
    where
        Q: Hash + ?Sized,
    {
        let index = self.get_shard_index(key);
        &self.shards[index]
    }

    /// Returns a locked reference to the shard for a given key.
    ///
    /// This is an internal helper method to abstract away the lock handling and error recovery.
    /// If the mutex is poisoned because another thread panicked while holding it, the poison
    /// error is recovered by calling `into_inner()` to access the underlying data.
    #[inline]
    fn locked_shard<Q>(&self, key: &Q) -> MutexGuard<'_, SieveCache<K, V>>
    where
        Q: Hash + ?Sized,
    {
        self.get_shard(key)
            .lock()
            .unwrap_or_else(PoisonError::into_inner)
    }

    /// Returns the total capacity of the cache (sum of all shard capacities).
    ///
    /// # Examples
    ///
    /// ```
    /// # use sieve_cache::ShardedSieveCache;
    /// let cache: ShardedSieveCache<String, u32> = ShardedSieveCache::new(1000).unwrap();
    /// assert!(cache.capacity() >= 1000);
    /// ```
    pub fn capacity(&self) -> usize {
        self.shards
            .iter()
            .map(|shard| {
                shard
                    .lock()
                    .unwrap_or_else(PoisonError::into_inner)
                    .capacity()
            })
            .sum()
    }

    /// Returns the total number of entries in the cache (sum of all shard lengths).
    ///
    /// Note that this operation requires acquiring a lock on each shard, so it may
    /// cause temporary contention if called frequently in a high-concurrency environment.
    ///
    /// # Examples
    ///
    /// ```
    /// # use sieve_cache::ShardedSieveCache;
    /// let cache: ShardedSieveCache<String, String> = ShardedSieveCache::new(100).unwrap();
    ///
    /// cache.insert("key1".to_string(), "value1".to_string());
    /// cache.insert("key2".to_string(), "value2".to_string());
    ///
    /// assert_eq!(cache.len(), 2);
    /// ```
    pub fn len(&self) -> usize {
        self.shards
            .iter()
            .map(|shard| shard.lock().unwrap_or_else(PoisonError::into_inner).len())
            .sum()
    }

    /// Returns `true` when no values are currently cached in any shard.
    ///
    /// Note that this operation requires acquiring a lock on each shard, so it may
    /// cause temporary contention if called frequently in a high-concurrency environment.
    ///
    /// # Examples
    ///
    /// ```
    /// # use sieve_cache::ShardedSieveCache;
    /// let cache: ShardedSieveCache<String, String> = ShardedSieveCache::new(100).unwrap();
    /// assert!(cache.is_empty());
    ///
    /// cache.insert("key".to_string(), "value".to_string());
    /// assert!(!cache.is_empty());
    /// ```
    pub fn is_empty(&self) -> bool {
        self.shards.iter().all(|shard| {
            shard
                .lock()
                .unwrap_or_else(PoisonError::into_inner)
                .is_empty()
        })
    }

    /// Returns `true` if there is a value in the cache mapped to by `key`.
    ///
    /// This operation only locks the specific shard containing the key.
    ///
    /// # Examples
    ///
    /// ```
    /// # use sieve_cache::ShardedSieveCache;
    /// let cache: ShardedSieveCache<String, String> = ShardedSieveCache::new(100).unwrap();
    /// cache.insert("key".to_string(), "value".to_string());
    ///
    /// assert!(cache.contains_key(&"key".to_string()));
    /// assert!(!cache.contains_key(&"missing".to_string()));
    /// ```
    pub fn contains_key<Q>(&self, key: &Q) -> bool
    where
        Q: Hash + Eq + ?Sized,
        K: Borrow<Q>,
    {
        let mut guard = self.locked_shard(key);
        guard.contains_key(key)
    }

    /// Gets a clone of the value in the cache mapped to by `key`.
    ///
    /// If no value exists for `key`, this returns `None`. This operation only locks
    /// the specific shard containing the key.
    ///
    /// # Note
    ///
    /// This method returns a clone of the value rather than a reference, since the
    /// mutex guard would be dropped after this method returns. This means that
    /// `V` must implement `Clone`.
    ///
    /// # Examples
    ///
    /// ```
    /// # use sieve_cache::ShardedSieveCache;
    /// let cache: ShardedSieveCache<String, String> = ShardedSieveCache::new(100).unwrap();
    /// cache.insert("key".to_string(), "value".to_string());
    ///
    /// assert_eq!(cache.get(&"key".to_string()), Some("value".to_string()));
    /// assert_eq!(cache.get(&"missing".to_string()), None);
    /// ```
    pub fn get<Q>(&self, key: &Q) -> Option<V>
    where
        Q: Hash + Eq + ?Sized,
        K: Borrow<Q>,
        V: Clone,
    {
        let mut guard = self.locked_shard(key);
        guard.get(key).cloned()
    }

    /// Maps `key` to `value` in the cache, possibly evicting old entries from the appropriate shard.
    ///
    /// This method returns `true` when this is a new entry, and `false` if an existing entry was
    /// updated. This operation only locks the specific shard containing the key.
    ///
    /// # Examples
    ///
    /// ```
    /// # use sieve_cache::ShardedSieveCache;
    /// let cache: ShardedSieveCache<String, String> = ShardedSieveCache::new(100).unwrap();
    ///
    /// // Insert a new key
    /// assert!(cache.insert("key1".to_string(), "value1".to_string()));
    ///
    /// // Update an existing key
    /// assert!(!cache.insert("key1".to_string(), "updated_value".to_string()));
    /// ```
    pub fn insert(&self, key: K, value: V) -> bool {
        let mut guard = self.locked_shard(&key);
        guard.insert(key, value)
    }

    /// Removes the cache entry mapped to by `key`.
    ///
    /// This method returns the value removed from the cache. If `key` did not map to any value,
    /// then this returns `None`. This operation only locks the specific shard containing the key.
    ///
    /// # Examples
    ///
    /// ```
    /// # use sieve_cache::ShardedSieveCache;
    /// let cache: ShardedSieveCache<String, String> = ShardedSieveCache::new(100).unwrap();
    /// cache.insert("key".to_string(), "value".to_string());
    ///
    /// // Remove an existing key
    /// assert_eq!(cache.remove(&"key".to_string()), Some("value".to_string()));
    ///
    /// // Attempt to remove a missing key
    /// assert_eq!(cache.remove(&"key".to_string()), None);
    /// ```
    pub fn remove<Q>(&self, key: &Q) -> Option<V>
    where
        K: Borrow<Q>,
        Q: Eq + Hash + ?Sized,
    {
        let mut guard = self.locked_shard(key);
        guard.remove(key)
    }

    /// Removes and returns a value from the cache that was not recently accessed.
    ///
    /// This method tries each shard in turn until it finds an entry to evict.
    /// If no suitable value exists in any shard, this returns `None`.
    ///
    /// # Note
    ///
    /// This implementation differs from the non-sharded version in that it checks each shard
    /// in sequence until it finds a suitable entry to evict. This may not provide the globally
    /// optimal eviction decision across all shards, but it avoids the need to lock all shards
    /// simultaneously.
    ///
    /// # Examples
    ///
    /// ```
    /// # use sieve_cache::ShardedSieveCache;
    /// let cache: ShardedSieveCache<String, String> = ShardedSieveCache::with_shards(10, 2).unwrap();
    ///
    /// // Fill the cache
    /// for i in 0..15 {
    ///     cache.insert(format!("key{}", i), format!("value{}", i));
    /// }
    ///
    /// // Should be able to evict something
    /// assert!(cache.evict().is_some());
    /// ```
    pub fn evict(&self) -> Option<V> {
        // Try each shard in turn
        for shard in &self.shards {
            let result = shard.lock().unwrap_or_else(PoisonError::into_inner).evict();
            if result.is_some() {
                return result;
            }
        }
        None
    }

    /// Gets exclusive access to a specific shard based on the key.
    ///
    /// This can be useful for performing multiple operations atomically on entries
    /// that share the same shard. Note that only keys that hash to the same shard
    /// can be manipulated within a single transaction.
    ///
    /// # Examples
    ///
    /// ```
    /// # use sieve_cache::ShardedSieveCache;
    /// let cache: ShardedSieveCache<String, String> = ShardedSieveCache::new(100).unwrap();
    ///
    /// // Perform multiple operations atomically
    /// cache.with_key_lock(&"foo", |shard| {
    ///     // All operations within this closure have exclusive access to the shard
    ///     shard.insert("key1".to_string(), "value1".to_string());
    ///     shard.insert("key2".to_string(), "value2".to_string());
    ///
    ///     // We can check state mid-transaction
    ///     assert!(shard.contains_key(&"key1".to_string()));
    /// });
    /// ```
    pub fn with_key_lock<Q, F, T>(&self, key: &Q, f: F) -> T
    where
        Q: Hash + ?Sized,
        F: FnOnce(&mut SieveCache<K, V>) -> T,
    {
        let mut guard = self.locked_shard(key);
        f(&mut guard)
    }

    /// Returns the number of shards in this cache.
    ///
    /// # Examples
    ///
    /// ```
    /// # use sieve_cache::ShardedSieveCache;
    /// let cache: ShardedSieveCache<String, String> = ShardedSieveCache::with_shards(1000, 32).unwrap();
    /// assert_eq!(cache.num_shards(), 32);
    /// ```
    pub fn num_shards(&self) -> usize {
        self.num_shards
    }

    /// Gets a specific shard by index.
    ///
    /// This is mainly useful for advanced use cases and maintenance operations.
    /// Returns `None` if the index is out of bounds.
    ///
    /// # Examples
    ///
    /// ```
    /// # use sieve_cache::ShardedSieveCache;
    /// let cache: ShardedSieveCache<String, String> = ShardedSieveCache::with_shards(1000, 8).unwrap();
    ///
    /// // Access a valid shard
    /// assert!(cache.get_shard_by_index(0).is_some());
    ///
    /// // Out of bounds index
    /// assert!(cache.get_shard_by_index(100).is_none());
    /// ```
    pub fn get_shard_by_index(&self, index: usize) -> Option<&Arc<Mutex<SieveCache<K, V>>>> {
        self.shards.get(index)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::thread;
    use std::time::Duration;

    #[test]
    fn test_sharded_cache_basics() {
        let cache = ShardedSieveCache::new(100).unwrap();

        // Insert a value
        assert!(cache.insert("key1".to_string(), "value1".to_string()));

        // Read back the value
        assert_eq!(cache.get(&"key1".to_string()), Some("value1".to_string()));

        // Check contains_key
        assert!(cache.contains_key(&"key1".to_string()));

        // Check capacity and length
        assert!(cache.capacity() >= 100); // May be slightly higher due to rounding up per shard
        assert_eq!(cache.len(), 1);

        // Remove a value
        assert_eq!(
            cache.remove(&"key1".to_string()),
            Some("value1".to_string())
        );
        assert_eq!(cache.len(), 0);
        assert!(cache.is_empty());
    }

    #[test]
    fn test_custom_shard_count() {
        let cache = ShardedSieveCache::with_shards(100, 4).unwrap();
        assert_eq!(cache.num_shards(), 4);

        for i in 0..10 {
            let key = format!("key{}", i);
            let value = format!("value{}", i);
            cache.insert(key, value);
        }

        assert_eq!(cache.len(), 10);
    }

    #[test]
    fn test_parallel_access() {
        let cache = Arc::new(ShardedSieveCache::with_shards(1000, 16).unwrap());
        let mut handles = vec![];

        // Spawn 8 threads that each insert 100 items
        for t in 0..8 {
            let cache_clone = Arc::clone(&cache);
            let handle = thread::spawn(move || {
                for i in 0..100 {
                    let key = format!("thread{}key{}", t, i);
                    let value = format!("value{}_{}", t, i);
                    cache_clone.insert(key, value);
                }
            });
            handles.push(handle);
        }

        // Wait for all threads to complete
        for handle in handles {
            handle.join().unwrap();
        }

        // Verify total item count
        assert_eq!(cache.len(), 800);

        // Check a few random keys
        assert_eq!(
            cache.get(&"thread0key50".to_string()),
            Some("value0_50".to_string())
        );
        assert_eq!(
            cache.get(&"thread7key99".to_string()),
            Some("value7_99".to_string())
        );
    }

    #[test]
    fn test_with_key_lock() {
        let cache = ShardedSieveCache::new(100).unwrap();

        // Perform multiple operations atomically on keys that map to the same shard
        cache.with_key_lock(&"test_key", |shard| {
            shard.insert("key1".to_string(), "value1".to_string());
            shard.insert("key2".to_string(), "value2".to_string());
            shard.insert("key3".to_string(), "value3".to_string());
        });

        assert_eq!(cache.len(), 3);
    }

    #[test]
    fn test_eviction() {
        let cache = ShardedSieveCache::with_shards(10, 2).unwrap();

        // Fill the cache
        for i in 0..15 {
            let key = format!("key{}", i);
            let value = format!("value{}", i);
            cache.insert(key, value);
        }

        // The cache should not exceed its capacity
        assert!(cache.len() <= 10);

        // We should be able to evict items
        let evicted = cache.evict();
        assert!(evicted.is_some());
    }

    #[test]
    fn test_contention() {
        let cache = Arc::new(ShardedSieveCache::with_shards(1000, 16).unwrap());
        let mut handles = vec![];

        // Create distinct keys that are likely to spread across different shards
        let keys: Vec<String> = (0..16).map(|i| format!("shard_key_{}", i)).collect();

        // Spawn 16 threads, each hammering a different key
        for i in 0..16 {
            let cache_clone = Arc::clone(&cache);
            let key = keys[i].clone();

            let handle = thread::spawn(move || {
                for j in 0..1000 {
                    cache_clone.insert(key.clone(), format!("value_{}", j));
                    let _ = cache_clone.get(&key);

                    // Small sleep to make contention more likely
                    if j % 100 == 0 {
                        thread::sleep(Duration::from_micros(1));
                    }
                }
            });

            handles.push(handle);
        }

        // Wait for all threads to complete
        for handle in handles {
            handle.join().unwrap();
        }

        // All keys should still be present
        for key in keys {
            assert!(cache.contains_key(&key));
        }
    }
}