//! `nu_utils/sync/keyed_lazy_lock.rs` — lazily initialized per-key values.
use std::{
    collections::HashMap,
    hash::Hash,
    sync::{LazyLock, OnceLock},
};

use parking_lot::RwLock;
8
9/// Lazily initializes values per key.
10///
11/// The first call to [`KeyedLazyLock::get`] for a key creates the value using `init`.
12/// Later calls return the same value.
13///
14/// Initialization for each key happens at most once.
15pub struct KeyedLazyLock<K, V, F = fn(&K) -> V> {
16    // Why `Box<OnceLock<V>>`
17    //
18    // Each key stores its own OnceLock. We allocate it in a Box so the address
19    // stays stable even if the HashMap grows and relocates entries.
20    //
21    // This lets us:
22    // 1. Grab a raw pointer to the OnceLock.
23    // 2. Drop the map lock.
24    // 3. Initialize the value outside the lock.
25    //
26    // Without the Box, the OnceLock could move during a HashMap resize,
27    // invalidating the pointer.
28    map: LazyLock<RwLock<HashMap<K, Box<OnceLock<V>>>>>,
29    init: F,
30}
31
32impl<K, V, F> KeyedLazyLock<K, V, F> {
33    pub const fn new(init: F) -> Self {
34        Self {
35            map: LazyLock::new(|| RwLock::new(HashMap::new())),
36            init,
37        }
38    }
39}
40
41impl<K, V, F> KeyedLazyLock<K, V, F>
42where
43    K: Eq + Hash + Clone,
44    F: Fn(&K) -> V,
45{
46    /// Returns the lazily initialized value for `key`.
47    ///
48    /// If the key has not been accessed before, `init(key)` will run exactly once.
49    /// Concurrent callers requesting the same key will wait for initialization.
50    ///
51    /// # Deadlocks
52    /// `init` must not call `get` with the same key.
53    pub fn get(&self, key: &K) -> &V {
54        // Fast path: try to find the cell with a read lock.
55        if let Some(cell_ptr) = self.try_get_cell_ptr(key) {
56            // SAFETY:
57            // - The pointer refers to a OnceLock stored inside a Box in the map.
58            // - Entries are never removed, so the Box lives until self is dropped.
59            // - Moving the Box inside the HashMap does not move the allocation.
60            let cell = unsafe { &*cell_ptr };
61
62            // init runs outside the map lock.
63            return cell.get_or_init(|| (self.init)(key));
64        }
65
66        // Slow path: insert the cell.
67        let cell_ptr = {
68            let mut write = self.map.write();
69
70            // Another thread may have inserted it already.
71            let cell_box = write
72                .entry(key.clone())
73                .or_insert_with(|| Box::new(OnceLock::new()));
74
75            // Grab pointer so we can drop the lock before initialization.
76            (&**cell_box) as *const OnceLock<V>
77        };
78
79        // SAFETY: same reasoning as above.
80        let cell = unsafe { &*cell_ptr };
81        cell.get_or_init(|| (self.init)(key))
82    }
83
84    #[inline]
85    fn try_get_cell_ptr(&self, key: &K) -> Option<*const OnceLock<V>> {
86        let read = self.map.read();
87        read.get(key)
88            .map(|cell_box| (&**cell_box) as *const OnceLock<V>)
89    }
90}
91
#[cfg(test)]
mod tests {
    use super::KeyedLazyLock;
    use std::sync::{
        Arc, Barrier,
        atomic::{AtomicUsize, Ordering},
    };

    /// Repeated `get` calls for one key run `init` once and return the same
    /// allocation, not merely equal values.
    #[test]
    fn initializes_once_per_key() {
        let counter = AtomicUsize::new(0);
        let lock = KeyedLazyLock::new(|_: &String| {
            counter.fetch_add(1, Ordering::SeqCst);
            42
        });

        let key = String::from("alpha");
        let first = lock.get(&key);
        let second = lock.get(&key);

        assert_eq!(*first, 42);
        // Same reference identity, proving the cached cell was reused.
        assert!(std::ptr::eq(first, second));
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    /// Many threads racing on one key still trigger exactly one `init`.
    #[test]
    fn initializes_once_with_concurrent_callers() {
        let counter = Arc::new(AtomicUsize::new(0));
        let lock = Arc::new(KeyedLazyLock::new({
            let counter = Arc::clone(&counter);
            move |_: &String| {
                counter.fetch_add(1, Ordering::SeqCst);
                7
            }
        }));

        let barrier = Arc::new(Barrier::new(8));
        let mut handles = Vec::new();

        for _ in 0..8 {
            let lock = Arc::clone(&lock);
            let barrier = Arc::clone(&barrier);
            handles.push(std::thread::spawn(move || {
                // Line all threads up so the `get` calls actually race.
                barrier.wait();
                let key = String::from("shared");
                let value = lock.get(&key);
                assert_eq!(*value, 7);
            }));
        }

        for handle in handles {
            // `expect` surfaces the worker's panic payload in the failure
            // message instead of discarding it like the old `match` did.
            handle.join().expect("thread panicked");
        }

        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    /// Distinct keys are initialized independently (one `init` per key).
    #[test]
    fn initializes_each_key_separately() {
        let counter = AtomicUsize::new(0);
        let lock = KeyedLazyLock::new(|_: &String| {
            counter.fetch_add(1, Ordering::SeqCst);
            1
        });

        let keys = ["a", "b", "c", "d"]
            .into_iter()
            .map(String::from)
            .collect::<Vec<_>>();
        for key in &keys {
            let value = lock.get(key);
            assert_eq!(*value, 1);
        }

        assert_eq!(counter.load(Ordering::SeqCst), keys.len());
    }

    /// The `Box` indirection must keep a value's address stable even after
    /// the map grows and rehashes.
    #[test]
    fn retains_value_address_after_rehash() {
        let lock = KeyedLazyLock::new(|key: &String| key.len());
        let seed = String::from("seed");
        let first = lock.get(&seed) as *const usize;

        // Insert enough keys to force several HashMap growths/rehashes.
        for index in 0..1500 {
            let key = format!("key-{index}");
            let _ = lock.get(&key);
        }

        let second = lock.get(&seed) as *const usize;
        assert!(std::ptr::eq(first, second));
    }
}