1use std::{
8 hash::{BuildHasher, BuildHasherDefault, DefaultHasher, Hash},
9 marker::PhantomData,
10 num::NonZeroUsize,
11 sync::{Mutex, MutexGuard},
12};
13
14use crate::batch::{KeyBatch, MAX_BATCH_KEYS};
15
/// Error returned when a stripe of a [`StripedLock`] is poisoned.
///
/// A stripe becomes poisoned when a thread panics while holding its guard.
/// [`StripedLock::lock`] and [`StripedLock::lock_batch`] report that condition
/// with this error instead of handing out a guard.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct StripedPoisonError;

impl std::fmt::Display for StripedPoisonError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("striped lock poisoned: a thread panicked while holding a stripe")
    }
}

impl std::error::Error for StripedPoisonError {}
18
/// Guard for a single stripe of a striped lock.
///
/// The stripe stays locked for as long as this value is alive and is unlocked
/// when it is dropped.
#[must_use = "the stripe is unlocked as soon as the guard is dropped"]
pub struct StripedLockGuard<'l> {
    // Never read; held only so that dropping the guard unlocks the stripe.
    _guard: MutexGuard<'l, ()>,
}
22
23pub struct StripedBatchLockGuard<'l> {
24 _guards: [Option<MutexGuard<'l, ()>>; MAX_BATCH_KEYS],
25}
26
/// A fixed set of mutexes ("stripes") addressed by the hash of a key.
///
/// A key of type `K` is hashed with `H` and the hash selects one of the
/// stripes, so distinct keys may end up sharing a stripe. Keys are only ever
/// hashed; the lock never stores them.
pub struct StripedLock<K, H = BuildHasherDefault<DefaultHasher>>
where
    K: Hash,
    H: BuildHasher,
{
    /// Builds the hashers used to map keys to stripes.
    hasher_builder: H,
    /// The stripes; the length is fixed at construction time.
    locks: Box<[Mutex<()>]>,
    /// Ties the lock to its key type without owning a `K`. Using `fn() -> K`
    /// keeps `K` covariant while making `Send`/`Sync` independent of `K`,
    /// which is sound because no `K` value is ever stored.
    phantom: PhantomData<fn() -> K>,
}
36
37impl<K> StripedLock<K, BuildHasherDefault<DefaultHasher>>
38where
39 K: Hash,
40{
41 pub fn new(locks: NonZeroUsize) -> Self {
47 Self::with_hasher(BuildHasherDefault::default(), locks)
48 }
49}
50
51impl<K, H> StripedLock<K, H>
52where
53 K: Hash,
54 H: BuildHasher,
55{
56 pub fn with_hasher(hasher_builder: H, locks: NonZeroUsize) -> Self {
63 let locks = (0..locks.get())
64 .map(|_| Mutex::new(()))
65 .collect::<Vec<_>>()
66 .into_boxed_slice();
67
68 Self {
69 hasher_builder,
70 locks,
71 phantom: PhantomData::default(),
72 }
73 }
74
75 pub fn lock(&self, key: K) -> Result<StripedLockGuard, StripedPoisonError> {
82 fn inner(locks: &[Mutex<()>], key: u64) -> Result<StripedLockGuard, StripedPoisonError> {
83 let idx = (key % locks.len() as u64) as usize;
84 let lock = &locks[idx];
85
86 match lock.lock() {
87 Ok(guard) => Ok(StripedLockGuard { _guard: guard }),
88 Err(_) => Err(StripedPoisonError),
89 }
90 }
91
92 let hash = self.hasher_builder.hash_one(key);
93 inner(&self.locks, hash)
94 }
95
96 pub fn lock_batch<B>(&self, batch: B) -> Result<StripedBatchLockGuard, StripedPoisonError>
113 where
114 B: KeyBatch<K, H>,
115 {
116 fn inner<'l>(
117 locks: &'l [Mutex<()>],
118 batch: &mut [u64],
119 ) -> Result<StripedBatchLockGuard<'l>, StripedPoisonError> {
120 const ARRAY_REPEAT_VALUE: Option<MutexGuard<()>> = None;
121
122 assert!(batch.len() > 0);
123 assert!(batch.len() <= MAX_BATCH_KEYS);
124
125 for key in batch.iter_mut() {
127 *key %= locks.len() as u64;
128 }
129
130 batch.sort_unstable();
132
133 let mut guards = [ARRAY_REPEAT_VALUE; MAX_BATCH_KEYS];
134
135 guards[0] = Some(
136 locks[batch[0] as usize]
137 .lock()
138 .map_err(|_| StripedPoisonError)?,
139 );
140
141 for i in 1..batch.len() {
142 if batch[i] != batch[i - 1] {
144 guards[i] = Some(
145 locks[batch[i] as usize]
146 .lock()
147 .map_err(|_| StripedPoisonError)?,
148 );
149 }
150 }
151
152 Ok(StripedBatchLockGuard { _guards: guards })
153 }
154
155 let (mut arr, filled) = batch.into_hash_array(&self.hasher_builder);
156 let batch = &mut arr[..filled];
157 inner(&self.locks, batch)
158 }
159
160 pub fn is_poisoned(&self, key: K) -> bool {
162 fn inner(locks: &[Mutex<()>], key: u64) -> bool {
163 let idx = (key % locks.len() as u64) as usize;
164 let lock = &locks[idx];
165 lock.is_poisoned()
166 }
167
168 let key = self.hasher_builder.hash_one(key);
169 inner(&self.locks, key)
170 }
171
172 pub fn clear_poison(&self, key: K) {
174 fn inner(locks: &[Mutex<()>], key: u64) {
175 let idx = (key % locks.len() as u64) as usize;
176 let lock = &locks[idx];
177 lock.clear_poison();
178 }
179
180 let key = self.hasher_builder.hash_one(key);
181 inner(&self.locks, key);
182 }
183}