// nu_utils/sync/keyed_lazy_lock.rs
use std::{
2 collections::HashMap,
3 hash::Hash,
4 sync::{LazyLock, OnceLock},
5};
6
7use parking_lot::RwLock;
8
/// A map of lazily-initialized values: the value for each key is computed
/// at most once by the `init` function (defaulting to a plain `fn(&K) -> V`),
/// then cached and handed out by shared reference.
///
/// The backing map itself is created lazily via [`LazyLock`], which is what
/// allows `new` to be a `const fn`.
pub struct KeyedLazyLock<K, V, F = fn(&K) -> V> {
    // Each cell is boxed so the `OnceLock` has a stable heap address:
    // references returned by `get` must stay valid even after the `HashMap`
    // rehashes and relocates its entries.
    map: LazyLock<RwLock<HashMap<K, Box<OnceLock<V>>>>>,
    // Per-key initializer; invoked at most once per key by `get`.
    init: F,
}
31
32impl<K, V, F> KeyedLazyLock<K, V, F> {
33 pub const fn new(init: F) -> Self {
34 Self {
35 map: LazyLock::new(|| RwLock::new(HashMap::new())),
36 init,
37 }
38 }
39}
40
impl<K, V, F> KeyedLazyLock<K, V, F>
where
    K: Eq + Hash + Clone,
    F: Fn(&K) -> V,
{
    /// Returns a reference to the value for `key`, running `self.init` on
    /// first access.
    ///
    /// The initializer runs at most once per key (guarded by the per-key
    /// `OnceLock`), and it runs *without* holding the map's `RwLock`, so a
    /// slow initializer for one key does not block access to other keys.
    pub fn get(&self, key: &K) -> &V {
        // Fast path: the cell already exists; only a read lock was taken.
        if let Some(cell_ptr) = self.try_get_cell_ptr(key) {
            // SAFETY: `cell_ptr` points into a `Box` owned by the map. No
            // method in this file ever removes or replaces an entry, and the
            // box gives the `OnceLock` a stable heap address across rehashes,
            // so the pointee lives as long as `self` — which outlives the
            // `&V` borrow returned here.
            let cell = unsafe { &*cell_ptr };

            return cell.get_or_init(|| (self.init)(key));
        }

        // Slow path: insert an empty cell under the write lock, but capture
        // only its address so the lock is released before `init` runs.
        let cell_ptr = {
            let mut write = self.map.write();

            // `entry` absorbs the race where another thread inserted the
            // cell between the read probe above and taking this write lock.
            let cell_box = write
                .entry(key.clone())
                .or_insert_with(|| Box::new(OnceLock::new()));

            (&**cell_box) as *const OnceLock<V>
        };

        // SAFETY: same invariant as the fast path — boxed cells are never
        // removed, so the address remains valid for the lifetime of `self`
        // even though the write guard has been dropped.
        let cell = unsafe { &*cell_ptr };
        cell.get_or_init(|| (self.init)(key))
    }

    /// Looks up the cell for `key` under a read lock and returns a raw
    /// pointer to it. The guard is dropped on return; see the SAFETY
    /// comments in `get` for why the pointer stays dereferenceable.
    #[inline]
    fn try_get_cell_ptr(&self, key: &K) -> Option<*const OnceLock<V>> {
        let read = self.map.read();
        read.get(key)
            .map(|cell_box| (&**cell_box) as *const OnceLock<V>)
    }
}
91
#[cfg(test)]
mod tests {
    use super::KeyedLazyLock;
    use std::sync::{
        Arc, Barrier,
        atomic::{AtomicUsize, Ordering},
    };

    /// Repeated `get` calls for one key run the initializer exactly once
    /// and return references to the same cached cell.
    #[test]
    fn initializes_once_per_key() {
        let counter = AtomicUsize::new(0);
        let lock = KeyedLazyLock::new(|_: &String| {
            counter.fetch_add(1, Ordering::SeqCst);
            42
        });

        let key = String::from("alpha");
        let first = lock.get(&key);
        let second = lock.get(&key);

        assert_eq!(*first, 42);
        // Same address proves the second call hit the cache.
        assert!(std::ptr::eq(first, second));
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    /// Eight threads racing on the same key still trigger exactly one
    /// initializer call (`OnceLock` serializes first initialization).
    #[test]
    fn initializes_once_with_concurrent_callers() {
        let counter = Arc::new(AtomicUsize::new(0));
        let lock = Arc::new(KeyedLazyLock::new({
            let counter = Arc::clone(&counter);
            move |_: &String| {
                counter.fetch_add(1, Ordering::SeqCst);
                7
            }
        }));

        // Barrier maximizes the chance that all threads enter `get` at once.
        let barrier = Arc::new(Barrier::new(8));
        let mut handles = Vec::new();

        for _ in 0..8 {
            let lock = Arc::clone(&lock);
            let barrier = Arc::clone(&barrier);
            handles.push(std::thread::spawn(move || {
                barrier.wait();
                let key = String::from("shared");
                let value = lock.get(&key);
                assert_eq!(*value, 7);
            }));
        }

        for handle in handles {
            // `expect` replaces the match-and-panic boilerplate and still
            // fails the test if a worker thread panicked.
            handle.join().expect("thread panicked");
        }

        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    /// Distinct keys each get an independent initializer call.
    #[test]
    fn initializes_each_key_separately() {
        let counter = AtomicUsize::new(0);
        let lock = KeyedLazyLock::new(|_: &String| {
            counter.fetch_add(1, Ordering::SeqCst);
            1
        });

        let keys = ["a", "b", "c", "d"]
            .into_iter()
            .map(String::from)
            .collect::<Vec<_>>();
        for key in &keys {
            let value = lock.get(key);
            assert_eq!(*value, 1);
        }

        assert_eq!(counter.load(Ordering::SeqCst), keys.len());
    }

    /// Inserting many keys forces the map to rehash; the boxed cells must
    /// keep earlier values at a stable address.
    #[test]
    fn retains_value_address_after_rehash() {
        let lock = KeyedLazyLock::new(|key: &String| key.len());
        let seed = String::from("seed");
        let first = lock.get(&seed) as *const usize;

        // 1500 insertions comfortably exceed default capacity growth
        // thresholds, guaranteeing at least one rehash.
        for index in 0..1500 {
            let key = format!("key-{index}");
            let _ = lock.get(&key);
        }

        let second = lock.get(&seed) as *const usize;
        assert!(std::ptr::eq(first, second));
    }
}