// miden_utils_sync/racy_lock.rs
1use alloc::boxed::Box;
2use core::{
3 fmt,
4 ops::Deref,
5 ptr,
6 sync::atomic::{AtomicPtr, Ordering},
7};
8
/// Thread-safe, non-blocking, lazily evaluated lock with the same interface
/// as [`std::sync::LazyLock`].
///
/// Concurrent threads will race to set the value atomically, and memory allocated by losing threads
/// will be dropped immediately after they fail to set the pointer.
///
/// The underlying implementation is based on `once_cell::race::OnceBox` which relies on
/// [`core::sync::atomic::AtomicPtr`] to ensure that the data race results in a single successful
/// write to the relevant pointer, namely the first write.
/// See <https://github.com/matklad/once_cell/blob/v1.19.0/src/race.rs#L294>.
///
/// Performs lazy evaluation and can be used for statics.
pub struct RacyLock<T, F = fn() -> T> {
    // Null until the first successful initialization; afterwards points to a
    // heap-allocated `T` produced by `Box::into_raw` and freed only in `Drop`.
    inner: AtomicPtr<T>,
    // Initializer. Several racing threads may each invoke it once, but only one
    // result is ever published; the losers' allocations are dropped in `force`.
    f: F,
}
25
#[cfg(all(loom, test))]
mod unsound_demo {
    //! Loom-only demonstration of why blanket `Sync`/`Send` impls that ignore `T`
    //! are unsound for a racy lock: sharing a `RefCell` across threads through such
    //! an impl lets loom find a schedule with overlapping `borrow_mut` calls.
    //!
    //! NOTE(review): this module uses `std::time::Duration` while the surrounding
    //! file only imports from `alloc`/`core` — presumably `std` is linked under
    //! `cfg(all(loom, test))`; confirm against the crate's test setup.

    use alloc::boxed::Box;
    use core::{
        cell::RefCell,
        ptr,
        sync::atomic::{AtomicPtr, Ordering},
    };

    use loom::{hint, model::Builder, sync::Arc, thread};

    // Deliberately unsound lock that ignores `T` in Sync/Send bounds to demonstrate the failure.
    // Structurally a copy of `RacyLock`, minus the `T: Send`/`T: Sync` requirements.
    struct BadLock<T, F: Fn() -> T> {
        inner: AtomicPtr<T>,
        f: F,
    }

    impl<T, F: Fn() -> T> BadLock<T, F> {
        // Starts uninitialized (null pointer); `f` runs lazily on first `force`.
        pub const fn new(f: F) -> Self {
            Self {
                inner: AtomicPtr::new(ptr::null_mut()),
                f,
            }
        }

        // Same publish protocol as `RacyLock::force`: allocate, then race to CAS the
        // pointer from null; the loser frees its allocation and adopts the winner's.
        pub fn force(&self) -> &T {
            let mut p = self.inner.load(Ordering::Acquire);
            if p.is_null() {
                let v = (self.f)();
                p = Box::into_raw(Box::new(v));
                if let Err(old) = self.inner.compare_exchange(
                    ptr::null_mut(),
                    p,
                    Ordering::AcqRel,
                    Ordering::Acquire,
                ) {
                    // Another thread won; drop our allocation and use the existing pointer
                    // SAFETY: `p` came from `Box::into_raw` above and was never published
                    // (the CAS failed), so this thread still uniquely owns it.
                    drop(unsafe { Box::from_raw(p) });
                    p = old;
                }
            }
            // SAFETY: `p` is non-null and was published by a successful CAS, so it points
            // to a fully initialized `T` that lives until `Drop`.
            unsafe { &*p }
        }
    }

    impl<T, F: Fn() -> T> Drop for BadLock<T, F> {
        fn drop(&mut self) {
            let p = *self.inner.get_mut();
            if !p.is_null() {
                // SAFETY: `&mut self` guarantees exclusive access; a non-null pointer
                // always originates from `Box::into_raw` and is freed exactly once here.
                drop(unsafe { Box::from_raw(p) });
            }
        }
    }

    // UNSOUND: `Sync` and `Send` do not depend on `T`.
    unsafe impl<T, F: Fn() -> T + Sync> Sync for BadLock<T, F> {}
    unsafe impl<T, F: Fn() -> T + Send> Send for BadLock<T, F> {}

    // This test demonstrates the failure mode: sharing `&RefCell<_>` across threads via
    // an unsound `Sync` impl allows concurrent `borrow_mut`, which panics at runtime.
    #[test]
    #[should_panic]
    fn bad_sync_loom_allows_cross_thread_refcell_borrow_mut_panic() {
        let mut builder = Builder::default();
        // Bound loom's exhaustive schedule exploration so the test terminates.
        builder.max_duration = Some(std::time::Duration::from_secs(10));
        builder.check(|| {
            let lock = Arc::new(BadLock::new(|| RefCell::new(0u32)));
            let l1 = lock.clone();
            let l2 = lock.clone();

            let t1 = thread::spawn(move || {
                let c1 = l1.force();
                let _g1 = c1.borrow_mut();
                // Keep the mutable borrow alive to maximize overlap
                for _ in 0..100 {
                    hint::spin_loop();
                }
            });

            let t2 = thread::spawn(move || {
                let c2 = l2.force();
                // This will panic in schedules where t1 holds the mutable borrow
                let _g2 = c2.borrow_mut();
            });

            let _ = t1.join();
            let _ = t2.join();
        });
    }
}
116
impl<T, F> RacyLock<T, F>
where
    F: Fn() -> T,
{
    /// Creates a new lazy, racy value with the given initializing function.
    ///
    /// `const` so it can be used to initialize `static` items; no allocation
    /// happens until the first dereference.
    pub const fn new(f: F) -> Self {
        Self {
            inner: AtomicPtr::new(ptr::null_mut()),
            f,
        }
    }

    /// Forces the evaluation of the locked value and returns a reference to
    /// the result. This is equivalent to the [`Self::deref`].
    ///
    /// There is no blocking involved in this operation. Instead, concurrent
    /// threads will race to set the underlying pointer. Memory allocated by
    /// losing threads will be dropped immediately after they fail to set the pointer.
    ///
    /// This function's interface is designed around [`std::sync::LazyLock::force`] but
    /// the implementation is derived from `once_cell::race::OnceBox::get_or_try_init`.
    pub fn force(this: &RacyLock<T, F>) -> &T {
        // `Acquire` pairs with the winning `AcqRel` CAS below so a non-null pointer
        // implies the pointee is fully initialized.
        let mut ptr = this.inner.load(Ordering::Acquire);

        // Pointer is not yet set, attempt to set it ourselves.
        if ptr.is_null() {
            // Execute the initialization function and allocate.
            let val = (this.f)();
            ptr = Box::into_raw(Box::new(val));

            // Attempt atomic store. Success publishes our allocation with `AcqRel`;
            // failure returns the pointer the winning thread published (`Acquire`).
            let exchange = this.inner.compare_exchange(
                ptr::null_mut(),
                ptr,
                Ordering::AcqRel,
                Ordering::Acquire,
            );

            // Pointer already set, load.
            if let Err(old) = exchange {
                // SAFETY: `ptr` was produced by `Box::into_raw` above and was never
                // published (the CAS failed), so this thread still uniquely owns the
                // allocation and may free it exactly once here.
                drop(unsafe { Box::from_raw(ptr) });
                ptr = old;
            }
        }

        // SAFETY: `ptr` is non-null and was published via `Box::into_raw` — either by
        // this thread or by the CAS winner. The `Acquire` loads above synchronize with
        // that publication, and the allocation is only freed in `Drop`, which requires
        // exclusive access; hence the reference is valid for the borrow of `this`.
        unsafe { &*ptr }
    }
}
165
166impl<T: Default> Default for RacyLock<T> {
167 /// Creates a new lock that will evaluate the underlying value based on `T::default`.
168 #[inline]
169 fn default() -> RacyLock<T> {
170 RacyLock::new(T::default)
171 }
172}
173
174impl<T, F> fmt::Debug for RacyLock<T, F>
175where
176 T: fmt::Debug,
177 F: Fn() -> T,
178{
179 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
180 write!(f, "RacyLock({:?})", self.inner.load(Ordering::Relaxed))
181 }
182}
183
184impl<T, F> Deref for RacyLock<T, F>
185where
186 F: Fn() -> T,
187{
188 type Target = T;
189
190 /// Either sets or retrieves the value, and dereferences it.
191 ///
192 /// See [`Self::force`] for more details.
193 #[inline]
194 fn deref(&self) -> &T {
195 RacyLock::force(self)
196 }
197}
198
199impl<T, F> Drop for RacyLock<T, F> {
200 /// Drops the underlying pointer.
201 fn drop(&mut self) {
202 let ptr = *self.inner.get_mut();
203 if !ptr.is_null() {
204 // SAFETY: for any given value of `ptr`, we are guaranteed to have at most a single
205 // instance of `RacyLock` holding that value. Hence, synchronizing threads
206 // in `drop()` is not necessary, and we are guaranteed never to double-free.
207 // In short, since `RacyLock` doesn't implement `Clone`, the only scenario
208 // where there can be multiple instances of `RacyLock` across multiple threads
209 // referring to the same `ptr` value is when `RacyLock` is used in a static variable.
210 drop(unsafe { Box::from_raw(ptr) });
211 }
212 }
213}
214
// Ensure `RacyLock` only implements auto-traits when it is sound to do so.
// `Send` requires ability to move the owned initializer and the (possibly
// newly allocated) `T` across threads safely.
unsafe impl<T: Send, F: Send> Send for RacyLock<T, F> {}

// `Sync` requires that shared access through `&self` is safe, which implies
// both the stored `T` and the initializer `F` can be shared across threads.
// NOTE(review): unlike `std::sync::LazyLock`, where `Once` guarantees a single
// thread runs the initializer, racing threads here may invoke `(self.f)()`
// concurrently through a shared `&F`. That suggests this impl should require
// `F: Send + Sync` (an `F: Send + !Sync` initializer, e.g. one capturing a
// `Cell`, could be called concurrently) — confirm with a loom model before
// relying on the current, looser bound. Tightening it is an API break, so it
// is flagged rather than changed here.
unsafe impl<T: Send + Sync, F: Send> Sync for RacyLock<T, F> {}
223
#[cfg(test)]
mod tests {
    use alloc::vec::Vec;

    use super::*;

    #[test]
    fn deref_default() {
        // A defaulted lock lazily produces `i32::default()`.
        let cell: RacyLock<i32> = RacyLock::default();
        assert_eq!(*cell, 0);
    }

    #[test]
    fn deref_copy() {
        // A closure-initialized lock yields the closure's value on deref.
        let cell = RacyLock::new(|| 42);
        assert_eq!(*cell, 42);
    }

    #[test]
    fn deref_clone() {
        // Initialize with a heap-allocated (non-`Copy`) value.
        let cell = RacyLock::new(|| Vec::from([1, 2, 3]));

        // `clone()` resolves through `Deref` to `Vec::clone`, proving deref
        // hands out a usable `&Vec<_>`; mutate the clone to force the copy.
        let mut copied = cell.clone();
        copied.push(4);

        assert_eq!(copied, Vec::from([1, 2, 3, 4]));
    }

    #[test]
    fn deref_static() {
        // Statics are the primary use case: the value is computed once in place.
        static VEC: RacyLock<Vec<i32>> = RacyLock::new(|| Vec::from([1, 2, 3]));

        // The value's address must be stable across repeated derefs.
        let first = &*VEC as *const Vec<i32>;
        for _ in 0..5 {
            assert_eq!(*VEC, [1, 2, 3]);
            assert_eq!(first, &(*VEC) as *const Vec<i32>)
        }
    }

    #[test]
    fn type_inference() {
        // `T` should be inferable from the closure's return type alone.
        let _ = RacyLock::new(|| ());
    }

    #[test]
    fn is_sync_send() {
        fn assert_traits<T: Send + Sync>() {}
        assert_traits::<RacyLock<Vec<i32>>>();
    }

    #[test]
    fn is_send() {
        fn assert_send<T: Send>() {}
        assert_send::<RacyLock<i32>>();
    }

    #[test]
    fn is_sync() {
        fn assert_sync<T: Sync>() {}
        assert_sync::<RacyLock<i32>>();
    }
}
293}