atomic_try_update/once.rs

//! A wait-free alternative to `std::sync::OnceLock`, with helper methods that make it easier to
//! correctly register state at startup.
use std::{error::Error, fmt::Display, ptr::null_mut};

use num_enum::{IntoPrimitive, TryFromPrimitive};

use crate::{
    atomic_try_update,
    bits::{Align8, FlagPtr},
    Atom,
};

#[derive(IntoPrimitive, TryFromPrimitive)]
#[repr(usize)]
enum Lifecycle {
    NotSet = 0,
    Setting,
    Set,
    Dead,
}

/// Not exposed in the external API.  We panic on the `UseAfterFreeBug` variant, and map
/// everything else to `OnceLockFreeError` before returning it to callers.
enum OnceLockFreeInternalError {
    AlreadySet,
    AttemptToReadWhenUnset,
    AttemptToSetConcurrently,
    UseAfterFreeBug,
    UnpreparedForSet,
}

#[derive(Debug, PartialEq, Eq)]
pub enum OnceLockFreeError {
    AlreadySet,
    AttemptToReadWhenUnset,
    AttemptToSetConcurrently,
    UnpreparedForSet,
}

impl Error for OnceLockFreeError {}

impl Display for OnceLockFreeError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{self:?}")
    }
}

fn panic_on_memory_bug(err: OnceLockFreeInternalError) -> OnceLockFreeError {
    match err {
        OnceLockFreeInternalError::AlreadySet => OnceLockFreeError::AlreadySet,
        OnceLockFreeInternalError::AttemptToReadWhenUnset => {
            OnceLockFreeError::AttemptToReadWhenUnset
        }
        OnceLockFreeInternalError::AttemptToSetConcurrently => {
            OnceLockFreeError::AttemptToSetConcurrently
        }
        OnceLockFreeInternalError::UseAfterFreeBug => {
            panic!("Encountered use-after-free in OnceLockFree");
        }
        OnceLockFreeInternalError::UnpreparedForSet => OnceLockFreeError::UnpreparedForSet,
    }
}

#[derive(Default)]
struct OnceLockFreeState<T> {
    flag_ptr: FlagPtr<Align8<T>>,
}

/// A wait-free alternative to `std::sync::OnceLock`
///
/// This includes a few special-purpose helper methods for various use cases.  The main advantage
/// of these helpers is that they map unexpected states into `OnceLockFreeError` values.
///
/// The helper methods are designed to be used in pairs:
///
/// If you need to wait until a value has been registered, use `get_poll` to read it,
/// and `set` to set it.
///
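/// A minimal sketch of that pairing (illustrative only, and not compiled as a doctest):
///
/// ```ignore
/// let cell = OnceLockFree::new();
/// assert!(cell.get_poll().is_none()); // nothing has been registered yet
/// let v = cell.set(7).unwrap();       // the first (and only) set succeeds
/// assert_eq!(cell.get_poll(), Some(v));
/// ```
///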
/// If you want to set a value exactly once, wait until everything is set, and then later read
/// the value, you have a few options.
///
/// If you need to memoize the result, use `get_or_prepare_to_set()` to check whether the value
/// has been set, and then use `set_prepared()` to install the value.  Do this in a way that
/// guarantees that callers will not race to set the value.  After all the sets have completed, you
/// can use `get()` or `get_or_prepare_to_set()` to read values that must be present.
///
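/// A rough sketch of the memoization pattern (illustrative only; assumes a single preparing
/// caller, as required above):
///
/// ```ignore
/// let cell: OnceLockFree<u32> = OnceLockFree::new();
/// // The first caller observes `None` and becomes responsible for the set.
/// assert!(cell.get_or_prepare_to_set().unwrap().is_none());
/// // Install the value; later readers observe it through `get`.
/// assert_eq!(cell.set_prepared(42).unwrap(), &42);
/// assert_eq!(cell.get().unwrap(), &42);
/// ```
///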
/// If you want to guarantee that no setters succeed after the first `get()`, and you cannot
/// guarantee that all values are set by the time initialization completes, use `get_or_seal()`.
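///
/// For instance (an illustrative sketch, not compiled as a doctest):
///
/// ```ignore
/// let cell: OnceLockFree<u32> = OnceLockFree::new();
/// // Sealing an unset cell: readers observe `None` from now on...
/// assert_eq!(cell.get_or_seal().unwrap(), None);
/// // ...and later setters are rejected instead of racing with the readers.
/// assert_eq!(cell.set(1), Err(OnceLockFreeError::AlreadySet));
/// ```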
pub struct OnceLockFree<T> {
    inner: Atom<OnceLockFreeState<T>, u64>,
}

impl<'a, T> OnceLockFree<T> {
    /// Creates a new empty cell.
    pub fn new() -> Self {
        Default::default()
    }

    /// Returns the value if it has already been set.  Otherwise, marks this cell as being
    /// prepared for a set and returns `Ok(None)`; the caller is then responsible for calling
    /// `set_prepared`.
    ///
    /// Returns an error if another caller is concurrently preparing the cell.
    pub fn get_or_prepare_to_set(&'a self) -> Result<Option<&'a T>, OnceLockFreeError> {
        unsafe {
            Ok(
                atomic_try_update(&self.inner, |s| match s.flag_ptr.get_flag().try_into() {
                    Ok(Lifecycle::NotSet) => {
                        s.flag_ptr.set_flag(Lifecycle::Setting.into());
                        (true, Ok(None))
                    }
                    Ok(Lifecycle::Setting) => (
                        false,
                        Err(OnceLockFreeInternalError::AttemptToSetConcurrently),
                    ),
                    Ok(Lifecycle::Set) => {
                        let ptr = s.flag_ptr.get_ptr();
                        (false, Ok(if ptr.is_null() { None } else { Some(ptr) }))
                    }
                    Ok(Lifecycle::Dead) => (false, Err(OnceLockFreeInternalError::UseAfterFreeBug)),
                    Err(_) => {
                        panic!("torn read?")
                    }
                })
                .map_err(panic_on_memory_bug)?
                .map(|ptr| &(*ptr).inner),
            )
        }
    }

    /// Gets a reference to the underlying value.
    ///
    /// Unlike `OnceCell` and `OnceLock`, which return an `Option<&T>`, this returns
    /// an error if the value has not yet been set.  There are a few other
    /// variants of `get()` that are appropriate for other use cases.
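    ///
    /// A short sketch of the unset case (illustrative only, not compiled as a doctest); note
    /// that this goes through `get_or_seal`, so reading an unset cell also seals it:
    ///
    /// ```ignore
    /// let cell: OnceLockFree<u32> = OnceLockFree::new();
    /// assert_eq!(cell.get(), Err(OnceLockFreeError::AttemptToReadWhenUnset));
    /// ```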
    pub fn get(&'a self) -> Result<&'a T, OnceLockFreeError> {
        match self.get_or_seal()? {
            Some(t) => Ok(t),
            None => Err(OnceLockFreeInternalError::AttemptToReadWhenUnset),
        }
        .map_err(panic_on_memory_bug)
    }

    /// Gets a reference to the underlying value, or `None` if the value has
    /// not been set yet.
    pub fn get_poll(&'a self) -> Option<&'a T> {
        unsafe {
            atomic_try_update(&self.inner, |s| match s.flag_ptr.get_flag().try_into() {
                Ok(Lifecycle::Set) => {
                    let ptr = s.flag_ptr.get_ptr();
                    (false, if ptr.is_null() { None } else { Some(ptr) })
                }
                _ => (false, None),
            })
            .map(|ptr| &(*ptr).inner)
        }
    }

    /// Gets a reference to the underlying value, or "seals" `self` so that it
    /// can never be set to a value from this point on.
    ///
    /// Returns an error if another thread concurrently prepares `self`, or during shutdown.
    pub fn get_or_seal(&'a self) -> Result<Option<&'a T>, OnceLockFreeError> {
        unsafe {
            Ok(
                atomic_try_update(&self.inner, |s| match s.flag_ptr.get_flag().try_into() {
                    Ok(Lifecycle::NotSet) => {
                        s.flag_ptr.set_flag(Lifecycle::Set.into());
                        s.flag_ptr.set_ptr(null_mut());
                        (true, Ok(None))
                    }
                    Ok(Lifecycle::Setting) => (
                        false,
                        Err(OnceLockFreeInternalError::AttemptToSetConcurrently),
                    ),
                    Ok(Lifecycle::Set) => {
                        let ptr = s.flag_ptr.get_ptr();
                        (false, Ok(if ptr.is_null() { None } else { Some(ptr) }))
                    }
                    Ok(Lifecycle::Dead) => (false, Err(OnceLockFreeInternalError::UseAfterFreeBug)),
                    Err(_) => {
                        panic!("torn read?")
                    }
                })
                .map_err(panic_on_memory_bug)?
                .map(|ptr| &(*ptr).inner),
            )
        }
    }

    /// Sets the value after a call to `get_or_prepare_to_set` returned `None`.  This is done in
    /// two phases so that racing sets are more likely to be noticed, and to help callers
    /// improve error messages when that happens.
    ///
    /// TODO: Add an Error state, and transition into it when racing `get_or_prepare_to_set`
    ///       calls occur.
    ///
    /// Returns an error if the value is already set, or if this cell has not been prepared.
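    ///
    /// A brief sketch of the error case (illustrative only, not compiled as a doctest):
    ///
    /// ```ignore
    /// let cell: OnceLockFree<u32> = OnceLockFree::new();
    /// // Skipping `get_or_prepare_to_set` is reported instead of silently storing the value.
    /// assert_eq!(cell.set_prepared(1), Err(OnceLockFreeError::UnpreparedForSet));
    /// ```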
    pub fn set_prepared(&'a self, val: T) -> Result<&'a T, OnceLockFreeError> {
        // This ensures the ptr is 8-byte aligned (or more), so that flag_ptr can steal
        // the three least significant bits.
        let ptr: *mut Align8<T> = Box::into_raw(Box::new(val.into()));
        unsafe {
            atomic_try_update(&self.inner, |s| match s.flag_ptr.get_flag().try_into() {
                Ok(Lifecycle::NotSet) => (false, Err(OnceLockFreeInternalError::UnpreparedForSet)),
                Ok(Lifecycle::Setting) => {
                    s.flag_ptr.set_flag(Lifecycle::Set.into());
                    s.flag_ptr.set_ptr(ptr);
                    (true, Ok(()))
                }
                Ok(Lifecycle::Set) => (false, Err(OnceLockFreeInternalError::AlreadySet)),
                Ok(Lifecycle::Dead) => (false, Err(OnceLockFreeInternalError::UseAfterFreeBug)),
                Err(_) => {
                    panic!("torn read?")
                }
            })
            .map_err(|err| {
                // The value was never installed; reclaim the allocation so the error path
                // does not leak it.
                drop(Box::from_raw(ptr));
                panic_on_memory_bug(err)
            })?;
            Ok(&(*ptr).inner)
        }
    }

    /// Sets this to the provided value.  Wait-free.
    ///
    /// Returns an error if this has already been prepared or set, and a reference to the stored
    /// value on success.
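    ///
    /// A short sketch (illustrative only, not compiled as a doctest):
    ///
    /// ```ignore
    /// let cell = OnceLockFree::new();
    /// assert_eq!(cell.set(1).unwrap(), &1);
    /// // A second set is rejected; the first value stays in place.
    /// assert_eq!(cell.set(2), Err(OnceLockFreeError::AlreadySet));
    /// ```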
    pub fn set(&'a self, val: T) -> Result<&'a T, OnceLockFreeError> {
        // Box the value so the pointer is at least 8-byte aligned and flag_ptr can steal its
        // low bits for the lifecycle flag.
        let ptr: *mut Align8<T> = Box::into_raw(Box::new(val.into()));
        unsafe {
            atomic_try_update(&self.inner, |s| match s.flag_ptr.get_flag().try_into() {
                Ok(Lifecycle::NotSet) => {
                    s.flag_ptr.set_flag(Lifecycle::Set.into());
                    s.flag_ptr.set_ptr(ptr);
                    (true, Ok(()))
                }
                Ok(Lifecycle::Setting) => (
                    false,
                    Err(OnceLockFreeInternalError::AttemptToSetConcurrently),
                ),
                Ok(Lifecycle::Set) => (false, Err(OnceLockFreeInternalError::AlreadySet)),
                Ok(Lifecycle::Dead) => (false, Err(OnceLockFreeInternalError::UseAfterFreeBug)),
                Err(_) => {
                    panic!("torn read?")
                }
            })
            .map_err(|err| {
                // The value was never installed; reclaim the allocation so the error path
                // does not leak it.
                drop(Box::from_raw(ptr));
                panic_on_memory_bug(err)
            })?;
            Ok(&(*ptr).inner)
        }
    }
}

impl<T> Default for OnceLockFree<T> {
    fn default() -> Self {
        Self {
            inner: Default::default(),
        }
    }
}

impl<T> Drop for OnceLockFree<T> {
    fn drop(&mut self) {
        unsafe {
            match atomic_try_update(&self.inner, |s| {
                match s.flag_ptr.get_flag().try_into() {
                    Ok(Lifecycle::NotSet) => {
                        s.flag_ptr.set_flag(Lifecycle::Dead.into());
                        (true, Ok(None))
                    }
                    Ok(Lifecycle::Setting) => {
                        s.flag_ptr.set_flag(Lifecycle::Dead.into());
                        (true, Ok(None))
                    }
                    Ok(Lifecycle::Set) => {
                        s.flag_ptr.set_flag(Lifecycle::Dead.into());
                        let ptr = s.flag_ptr.get_ptr();
                        (
                            true,
                            if ptr.is_null() {
                                Ok(None)
                            } else {
                                Ok(Some(ptr))
                            },
                        )
                    }
                    Ok(Lifecycle::Dead) => {
                        // TODO: report double free (as a panic outside the atomic_try_update)
                        (false, Err(OnceLockFreeInternalError::UseAfterFreeBug))
                        // don't want to double free!
                    }
                    Err(_) => {
                        (true, Ok(None)) // CAS from torn read should fail.
                    }
                }
            })
            .map_err(panic_on_memory_bug)
            .unwrap()
            {
                None => (),
                Some(ptr) => {
                    let _drop = Box::from_raw(ptr);
                }
            };
        }
    }
}