concurrent_initializer/
lib.rs1#[cfg(test)]
2#[macro_use]
3extern crate assert_matches;
4
5mod cht;
6
7use std::{
8 any::Any,
9 collections::hash_map::RandomState,
10 hash::{BuildHasher, Hash},
11 sync::Arc,
12};
13
14use parking_lot::RwLock;
15
/// Number of segments requested when the waiter map is constructed in
/// `ConcurrentInitializer::with_hasher`.
const WAITER_MAP_NUM_SEGMENTS: usize = 64;
17
/// Outcome of `ConcurrentInitializer::try_get_or_init`.
#[derive(Debug)]
pub enum InitResult<V, E> {
    /// The caller's own `init` closure ran and produced this value.
    Initialized(V),
    /// The value was not produced by the caller's `init`: it came either from
    /// the caller's `get` closure or from another caller that was already
    /// resolving the same key.
    ReadExisting(V),
    /// `get` or `init` returned an error, either in this call or in the
    /// concurrent call whose outcome this caller waited for.
    InitErr(Arc<E>),
}
24
/// Type-erased, shareable error; downcast back to the caller's concrete error
/// type before it is handed out in `InitResult::InitErr`.
type ErrorObject = Arc<dyn Any + Send + Sync + 'static>;
/// Outcome slot of a waiter: `None` until an outcome is stored (and left as
/// `None` when `init` panicked), otherwise the value or the error.
type WaiterValue<V> = Option<Result<V, ErrorObject>>;
/// Shared, lock-protected outcome slot. The resolving caller holds the write
/// lock while it works; other callers for the same key take the read lock.
type Waiter<V> = Arc<RwLock<WaiterValue<V>>>;
28
/// Coordinates callers that resolve the value for the same key at the same
/// time, so that a caller arriving while another one is working can reuse
/// that caller's outcome instead of running its own closures.
pub struct ConcurrentInitializer<K, V, S = RandomState> {
    /// Waiters of the resolutions currently in flight, keyed by `Arc<K>`.
    waiters: crate::cht::SegmentedHashMap<Arc<K>, Waiter<V>, S>,
}
32
33impl<K, V> ConcurrentInitializer<K, V>
34where
35 K: Eq + Hash,
36 V: Clone,
37{
38 pub fn new() -> Self {
39 Self::with_hasher(RandomState::new())
40 }
41}
42
43impl<K, V, S> ConcurrentInitializer<K, V, S>
44where
45 K: Eq + Hash,
46 V: Clone,
47 S: BuildHasher,
48{
49 pub fn with_hasher(build_hasher: S) -> Self {
50 Self {
51 waiters: cht::SegmentedHashMap::with_num_segments_and_hasher(
52 WAITER_MAP_NUM_SEGMENTS,
53 build_hasher,
54 ),
55 }
56 }
57
58 pub fn try_get_or_init<E>(
61 &self,
62 key: &Arc<K>,
63 mut get: impl FnMut() -> Result<Option<V>, E>,
65 init: impl FnOnce() -> Result<V, E>,
67 ) -> InitResult<V, E>
68 where
69 E: Send + Sync + 'static,
70 {
71 use std::panic::{catch_unwind, resume_unwind, AssertUnwindSafe};
72
73 use InitResult::*;
74
75 const MAX_RETRIES: usize = 200;
76 let mut retries = 0;
77
78 let (cht_key, hash) = self.cht_key_hash(key);
79
80 loop {
81 let waiter = Arc::new(RwLock::new(None));
82 let mut lock = waiter.write();
83
84 match self.try_insert_waiter(cht_key.clone(), hash, &waiter) {
85 None => {
86 match get() {
89 Ok(ok) => {
90 if let Some(value) = ok {
91 *lock = Some(Ok(value.clone()));
94 self.remove_waiter(cht_key, hash);
95 return InitResult::ReadExisting(value);
96 }
97 }
98 Err(err) => {
99 let err: ErrorObject = Arc::new(err);
102 *lock = Some(Err(Arc::clone(&err)));
103 self.remove_waiter(cht_key, hash);
104 return InitErr(err.downcast().unwrap());
105 }
106 }
107
108 match catch_unwind(AssertUnwindSafe(init)) {
112 Ok(value) => {
114 let (waiter_val, init_res) = match value {
115 Ok(value) => (Some(Ok(value.clone())), InitResult::Initialized(value)),
116 Err(e) => {
117 let err: ErrorObject = Arc::new(e);
118 (Some(Err(Arc::clone(&err))), InitResult::InitErr(err.downcast().unwrap()))
119 }
120 };
121 *lock = waiter_val;
122 self.remove_waiter(cht_key, hash);
123 return init_res;
124 }
125 Err(payload) => {
127 *lock = None;
128 self.remove_waiter(cht_key, hash);
130 resume_unwind(payload);
131 }
132 } }
134 Some(res) => {
135 std::mem::drop(lock);
138 match &*res.read() {
139 Some(Ok(value)) => return ReadExisting(value.clone()),
140 Some(Err(e)) => return InitErr(Arc::clone(e).downcast().unwrap()),
141 None => {
143 retries += 1;
144 if retries < MAX_RETRIES {
145 continue;
147 } else {
148 panic!(
149 "Too many retries. Tried to read the return value from the `init` \
150 closure but failed {} times. Maybe the `init` kept panicking?",
151 retries
152 );
153 }
154 }
155 }
156 }
157 }
158 }
159 }
160
161 #[inline]
162 fn remove_waiter(&self, cht_key: Arc<K>, hash: u64) {
163 self.waiters.remove(hash, |k| k == &cht_key);
164 }
165
166 #[inline]
167 fn try_insert_waiter(&self, cht_key: Arc<K>, hash: u64, waiter: &Waiter<V>) -> Option<Waiter<V>> {
168 let waiter = Arc::clone(waiter);
169 self.waiters.insert_if_not_present(cht_key, hash, waiter)
170 }
171
172 #[inline]
173 fn cht_key_hash(&self, key: &Arc<K>) -> (Arc<K>, u64) {
174 let cht_key = Arc::clone(key);
175 let hash = self.waiters.hash(&cht_key);
176 (cht_key, hash)
177 }
178}
179
#[cfg(test)]
mod tests {
    use std::{
        sync::atomic::{AtomicUsize, Ordering},
        thread,
    };

    use super::*;

    /// Sixteen threads resolve the same key at once; each must end up with a
    /// value, either initialized by itself or read from elsewhere.
    #[test]
    fn test_concurrent() {
        let initializer: Arc<ConcurrentInitializer<String, u64>> =
            Arc::new(ConcurrentInitializer::new());
        // Stands in for the backing store that `get` reads and `init` fills.
        let store = Arc::new(AtomicUsize::new(0));

        let mut handles = Vec::new();
        for thread_id in 0..16_u8 {
            let shared_initializer = Arc::clone(&initializer);
            let shared_store = Arc::clone(&store);

            handles.push(thread::spawn(move || {
                println!("Thread {} started.", thread_id);

                let key = Arc::new("key1".to_owned());

                // A stored size of zero means "not initialized yet".
                let read_store = || {
                    let size = shared_store.load(Ordering::SeqCst) as u64;
                    if size > 0 {
                        Ok(Some(size))
                    } else {
                        Ok(None)
                    }
                };

                let fill_store = || {
                    println!("The init closure called by thread {}.", thread_id);
                    let size = std::fs::metadata("./Cargo.toml")?.len();
                    shared_store.store(size as usize, Ordering::SeqCst);
                    Ok(size)
                };

                let value: InitResult<u64, std::io::Error> =
                    shared_initializer.try_get_or_init(&key, read_store, fill_store);

                assert_matches!(
                    value,
                    InitResult::Initialized(_) | InitResult::ReadExisting(_)
                );

                println!("Thread {} got the value. (len: {:?})", thread_id, value);
            }));
        }

        for handle in handles {
            handle.join().expect("Thread failed");
        }
    }
}