concurrent_initializer/
lib.rs

1#[cfg(test)]
2#[macro_use]
3extern crate assert_matches;
4
5mod cht;
6
7use std::{
8  any::Any,
9  collections::hash_map::RandomState,
10  hash::{BuildHasher, Hash},
11  sync::Arc,
12};
13
14use parking_lot::RwLock;
15
/// How many segments the internal waiter map is split into.
const WAITER_MAP_NUM_SEGMENTS: usize = 64;

/// The outcome of [`ConcurrentInitializer::try_get_or_init`].
#[derive(Debug)]
pub enum InitResult<V, E> {
  /// The `init` closure evaluated by this call produced the value.
  Initialized(V),
  /// The value was not produced by this call's `init`: either this call's
  /// `get` closure found it, or another caller for the same key produced it.
  ReadExisting(V),
  /// A `get`/`init` closure returned an error. The error is behind an `Arc`
  /// because it may be shared with other callers waiting on the same key.
  InitErr(Arc<E>),
}
24
25type ErrorObject = Arc<dyn Any + Send + Sync + 'static>;
26type WaiterValue<V> = Option<Result<V, ErrorObject>>;
27type Waiter<V> = Arc<RwLock<WaiterValue<V>>>;
28
29pub struct ConcurrentInitializer<K, V, S = RandomState> {
30  waiters: crate::cht::SegmentedHashMap<Arc<K>, Waiter<V>, S>,
31}
32
33impl<K, V> ConcurrentInitializer<K, V>
34where
35  K: Eq + Hash,
36  V: Clone,
37{
38  pub fn new() -> Self {
39    Self::with_hasher(RandomState::new())
40  }
41}
42
43impl<K, V, S> ConcurrentInitializer<K, V, S>
44where
45  K: Eq + Hash,
46  V: Clone,
47  S: BuildHasher,
48{
49  pub fn with_hasher(build_hasher: S) -> Self {
50    Self {
51      waiters: cht::SegmentedHashMap::with_num_segments_and_hasher(
52        WAITER_MAP_NUM_SEGMENTS,
53        build_hasher,
54      ),
55    }
56  }
57
58  /// # Panics
59  /// Panics if the `init` closure has been panicked.
60  pub fn try_get_or_init<E>(
61    &self,
62    key: &Arc<K>,
63    // Closure to get an existing value from somewhere.
64    mut get: impl FnMut() -> Result<Option<V>, E>,
65    // Closure to initialize a new value.
66    init: impl FnOnce() -> Result<V, E>,
67  ) -> InitResult<V, E>
68  where
69    E: Send + Sync + 'static,
70  {
71    use std::panic::{catch_unwind, resume_unwind, AssertUnwindSafe};
72
73    use InitResult::*;
74
75    const MAX_RETRIES: usize = 200;
76    let mut retries = 0;
77
78    let (cht_key, hash) = self.cht_key_hash(key);
79
80    loop {
81      let waiter = Arc::new(RwLock::new(None));
82      let mut lock = waiter.write();
83
84      match self.try_insert_waiter(cht_key.clone(), hash, &waiter) {
85        None => {
86          // Our waiter was inserted.
87          // Check if the value has already been inserted by other thread.
88          match get() {
89            Ok(ok) => {
90              if let Some(value) = ok {
91                // Yes. Set the waiter value, remove our waiter, and return
92                // the existing value.
93                *lock = Some(Ok(value.clone()));
94                self.remove_waiter(cht_key, hash);
95                return InitResult::ReadExisting(value);
96              }
97            }
98            Err(err) => {
99              // Error. Set the waiter value, remove our waiter, and return
100              // the error.
101              let err: ErrorObject = Arc::new(err);
102              *lock = Some(Err(Arc::clone(&err)));
103              self.remove_waiter(cht_key, hash);
104              return InitErr(err.downcast().unwrap());
105            }
106          }
107
108          // The value still does not exist. Let's evaluate the init
109          // closure. Catching panic is safe here as we do not try to
110          // evaluate the closure again.
111          match catch_unwind(AssertUnwindSafe(init)) {
112            // Evaluated.
113            Ok(value) => {
114              let (waiter_val, init_res) = match value {
115                Ok(value) => (Some(Ok(value.clone())), InitResult::Initialized(value)),
116                Err(e) => {
117                  let err: ErrorObject = Arc::new(e);
118                  (Some(Err(Arc::clone(&err))), InitResult::InitErr(err.downcast().unwrap()))
119                }
120              };
121              *lock = waiter_val;
122              self.remove_waiter(cht_key, hash);
123              return init_res;
124            }
125            // Panicked.
126            Err(payload) => {
127              *lock = None;
128              // Remove the waiter so that others can retry.
129              self.remove_waiter(cht_key, hash);
130              resume_unwind(payload);
131            }
132          } // The write lock will be unlocked here.
133        }
134        Some(res) => {
135          // Somebody else's waiter already exists. Drop our write lock and
136          // wait for the read lock to become available.
137          std::mem::drop(lock);
138          match &*res.read() {
139            Some(Ok(value)) => return ReadExisting(value.clone()),
140            Some(Err(e)) => return InitErr(Arc::clone(e).downcast().unwrap()),
141            // None means somebody else's init closure has been panicked.
142            None => {
143              retries += 1;
144              if retries < MAX_RETRIES {
145                // Retry from the beginning.
146                continue;
147              } else {
148                panic!(
149                  "Too many retries. Tried to read the return value from the `init` \
150                                closure but failed {} times. Maybe the `init` kept panicking?",
151                  retries
152                );
153              }
154            }
155          }
156        }
157      }
158    }
159  }
160
161  #[inline]
162  fn remove_waiter(&self, cht_key: Arc<K>, hash: u64) {
163    self.waiters.remove(hash, |k| k == &cht_key);
164  }
165
166  #[inline]
167  fn try_insert_waiter(&self, cht_key: Arc<K>, hash: u64, waiter: &Waiter<V>) -> Option<Waiter<V>> {
168    let waiter = Arc::clone(waiter);
169    self.waiters.insert_if_not_present(cht_key, hash, waiter)
170  }
171
172  #[inline]
173  fn cht_key_hash(&self, key: &Arc<K>) -> (Arc<K>, u64) {
174    let cht_key = Arc::clone(key);
175    let hash = self.waiters.hash(&cht_key);
176    (cht_key, hash)
177  }
178}
179
#[cfg(test)]
mod tests {
  use std::{
    sync::atomic::{AtomicUsize, Ordering},
    thread,
  };

  use super::*;

  #[test]
  fn test_concurrent() {
    const NUM_THREADS: u8 = 16;

    let initializer: Arc<ConcurrentInitializer<String, u64>> =
      Arc::new(ConcurrentInitializer::new());
    let store = Arc::new(AtomicUsize::new(0));
    // Counts how many times the init closure runs across all threads.
    let init_calls = Arc::new(AtomicUsize::new(0));

    // Spawn the threads.
    let threads: Vec<_> = (0..NUM_THREADS)
      .map(|thread_id| {
        let my_initializer = Arc::clone(&initializer);
        let my_store = Arc::clone(&store);
        let my_init_calls = Arc::clone(&init_calls);

        thread::spawn(move || {
          println!("Thread {} started.", thread_id);

          // Try to insert and get the value for key1. Although all the
          // threads will call `try_get_or_init` at the same time,
          // the init closure must be called only once.
          let value: InitResult<u64, std::io::Error> = my_initializer.try_get_or_init(
            &Arc::new("key1".to_owned()),
            || {
              let size = my_store.load(Ordering::SeqCst) as u64;
              Ok(if size > 0 { Some(size) } else { None })
            },
            || {
              println!("The init closure called by thread {}.", thread_id);
              my_init_calls.fetch_add(1, Ordering::SeqCst);
              let size = std::fs::metadata("./Cargo.toml")?.len();
              my_store.store(size as usize, Ordering::SeqCst);
              Ok(size)
            },
          );

          // Ensure the value exists now.
          assert_matches!(value, InitResult::Initialized(_) | InitResult::ReadExisting(_));

          println!("Thread {} got the value. (len: {:?})", thread_id, value);
        })
      })
      .collect();

    // Wait for all threads to complete.
    threads.into_iter().for_each(|t| t.join().expect("Thread failed"));

    // Verify the invariant stated above: exactly one init call in total.
    assert_eq!(init_calls.load(Ordering::SeqCst), 1);
  }
}