take_once/
lib.rs

1//! A thread-safe container for one-time storage and one-time consumption of a value.
2//!
3//! This module provides the [`TakeOnce`] type, which enables safe storage
4//! and consumption of a value in concurrent contexts. It ensures that:
5//!
6//! * A value can be stored exactly once
7//! * A stored value can be taken out exactly once
8//! * All operations are thread-safe
9//!
10//! This is similar to [`std::sync::OnceLock`], but with different semantics regarding
11//! value consumption. `TakeOnce` allows the stored value to be taken out (moved) exactly once,
12//! whereas `OnceLock` allows the value to be accessed in place multiple times.
13//!
14//! # Example
15//!
16#![cfg_attr(feature = "_shuttle", doc = "```ignore")]
17#![cfg_attr(not(feature = "_shuttle"), doc = "```rust")]
18//! use take_once::TakeOnce;
19//!
20//! let cell = TakeOnce::new();
21//!
22//! // Store a value
23//! assert_eq!(cell.store(42), Ok(()));
24//!
25//! // Subsequent stores do not consume the value.
26//! assert_eq!(cell.store(24), Err(24));
27//!
28//! // Take the stored value
29//! assert_eq!(cell.take(), Some(42));
30//!
31//! // Value can only be taken once
32//! assert_eq!(cell.take(), None);
33//! ```
34//!
35//! # Thread Safety
36//!
37//! `TakeOnce<T>` is both [`Send`] and [`Sync`] when `T: Send`, making it suitable for
38//! sharing across thread boundaries. All operations are atomic and properly synchronized.
39//!
40#![cfg_attr(feature = "_shuttle", doc = "```ignore")]
41#![cfg_attr(not(feature = "_shuttle"), doc = "```rust")]
42//! # use std::sync::Arc;
43//! # use std::thread;
44//! # use take_once::TakeOnce;
45//! let shared = Arc::new(TakeOnce::new());
46//! let shared2 = shared.clone();
47//!
48//! // Store in one thread
49//! thread::spawn(move || {
50//!     shared.store(42).unwrap();
51//! }).join().unwrap();
52//!
53//! // Take in another thread
54//! thread::spawn(move || {
55//!     if let Some(value) = shared2.take() {
56//!         println!("Got value: {}", value);
57//!     }
//! }).join().unwrap();
//! ```
59
60use std::marker::PhantomData;
61
62#[cfg(feature = "_shuttle")]
63pub(crate) use shuttle::sync::{
64    atomic::{AtomicPtr, Ordering},
65    Once,
66};
67
68#[cfg(not(feature = "_shuttle"))]
69pub(crate) use std::sync::{
70    atomic::{AtomicPtr, Ordering},
71    Once,
72};
73
/// A thread-safe container that allows storing a value once and taking it exactly once.
/// This is useful in scenarios where:
///
/// * A value needs to be initialized lazily but only once
/// * The initialized value should be consumable (moved out) rather than just accessible.
///   See [`std::sync::OnceLock`] for a similar use case where the value can be accessed in place.
/// * Multiple threads might attempt to initialize or consume the value, but only one should succeed
///
/// # Thread Safety
///
/// `TakeOnce<T>` implements `Send` and `Sync` when `T: Send`, making it safe to share
/// across thread boundaries. `T: Sync` is not required because the cell never hands out a
/// shared reference to the stored value; the value is only ever moved in or moved out.
/// A [`store`](TakeOnce::store) that races with another in-progress `store` blocks until
/// that one has finished.
///
/// # Memory Management
///
/// A stored value lives in its own heap allocation. [`take`](TakeOnce::take) moves the value
/// back out and frees that allocation. A value that is never taken is dropped, and its
/// allocation freed, when the `TakeOnce` itself is dropped.
///
/// # Examples
///
/// Basic usage:
#[cfg_attr(feature = "_shuttle", doc = "```ignore")]
#[cfg_attr(not(feature = "_shuttle"), doc = "```rust")]
/// use take_once::TakeOnce;
///
/// let cell = TakeOnce::new();
///
/// // Initial store succeeds
/// assert_eq!(cell.store(42), Ok(()));
///
/// // Subsequent stores return the provided value
/// assert_eq!(cell.store(24), Err(24));
///
/// // Take the value
/// assert_eq!(cell.take(), Some(42));
///
/// // Can't take twice
/// assert_eq!(cell.take(), None);
/// ```
///
/// Concurrent usage:
#[cfg_attr(feature = "_shuttle", doc = "```ignore")]
#[cfg_attr(not(feature = "_shuttle"), doc = "```rust")]
/// use std::sync::Arc;
/// use std::thread;
/// use take_once::TakeOnce;
///
/// let shared = Arc::new(TakeOnce::new());
/// let threads: Vec<_> = (0..3)
///     .map(|i| {
///         let shared = shared.clone();
///         thread::spawn(move || {
///             // Only one thread will successfully store
///             shared.store(i)
///         })
///     })
///     .collect();
///
/// let successes = threads
///     .into_iter()
///     .map(|t| t.join().unwrap())
///     .filter(|r| r.is_ok())
///     .count();
/// assert_eq!(successes, 1);
///
/// // The winning value can then be taken exactly once.
/// assert!(shared.take().is_some());
/// assert!(shared.take().is_none());
/// ```
#[derive(Debug)]
pub struct TakeOnce<T> {
    /// Records whether a value has been stored (`once.is_completed()`), and guarantees that
    /// only the first `store` ever writes to `value`.
    once: Once,
    /// Null until the first `store` publishes a `Box<T>` allocation here, and null again
    /// once `take` has moved the value out.
    value: AtomicPtr<T>,
    /// Declares that a `TakeOnce<T>` logically owns a `T`; a bare `AtomicPtr<T>` does not
    /// express that ownership.
    _marker: PhantomData<T>,
}
138
139impl<T> TakeOnce<T> {
140    /// Creates a new empty cell.
141    #[inline]
142    #[must_use]
143    pub const fn new() -> TakeOnce<T> {
144        TakeOnce {
145            once: Once::new(),
146            value: AtomicPtr::new(std::ptr::null_mut()),
147            _marker: PhantomData,
148        }
149    }
150
151    /// Create and store a cell in a single operation
152    ///
153    #[cfg_attr(feature = "_shuttle", doc = "```ignore")]
154    #[cfg_attr(not(feature = "_shuttle"), doc = "```rust")]
155    /// use take_once::TakeOnce;
156    ///
157    /// let initialized = TakeOnce::new_with(true);
158    /// assert_eq!(initialized.store(false), Err(false));
159    /// assert_eq!(initialized.take(), Some(true));
160    /// ```
161    #[must_use]
162    pub fn new_with(val: T) -> TakeOnce<T> {
163        let cell = TakeOnce::new();
164        let _ = cell.store(val);
165        cell
166    }
167
168    /// Stores a value into this `TakeOnce` if it has not been initialized.
169    ///
170    /// If the `TakeOnce` has already been initialized, the value is returned as [`Err`].
171    /// Otherwise, the value is stored and [`Ok`] is returned.
172    ///
173    /// This method allocates memory on the heap to store the value.
174    #[inline]
175    pub fn store(&self, val: T) -> Result<(), T> {
176        let mut val = Some(val);
177        self.once.call_once(|| {
178            let val = val.take().unwrap();
179            let ptr = Box::into_raw(Box::new(val));
180            self.value.store(ptr, Ordering::Release);
181        });
182
183        val.map_or(Ok(()), Err)
184    }
185
186    /// Takes the value out of this `TakeOnce`, if it has been initialized.
187    ///
188    /// If the `TakeOnce` has not been initialized, or if the value has already been taken,
189    /// this method returns `None`.
190    #[inline]
191    #[must_use]
192    pub fn take(&self) -> Option<T> {
193        if self.once.is_completed() {
194            let ptr = self.value.swap(std::ptr::null_mut(), Ordering::Acquire);
195            if ptr.is_null() {
196                None
197            } else {
198                // SAFETY: `self.value` is initialized (since `self.once.is_completed()`)
199                // and has not been taken before (since `ptr` is not null).
200                Some(*unsafe { Box::from_raw(ptr) })
201            }
202        } else {
203            None
204        }
205    }
206
207    /// Returns true if the value has been initialized, regardless of whether it has been taken.
208    /// In other words, this returns true if `store` has been called at least once.
209    #[inline]
210    #[must_use]
211    pub fn is_completed(&self) -> bool {
212        self.once.is_completed()
213    }
214}
215
216impl<T> Default for TakeOnce<T> {
217    fn default() -> Self {
218        Self::new()
219    }
220}
221
222impl<T> Drop for TakeOnce<T> {
223    fn drop(&mut self) {
224        if self.once.is_completed() {
225            let ptr = self.value.swap(std::ptr::null_mut(), Ordering::Acquire);
226            if !ptr.is_null() {
227                // SAFETY: `self.value` is initialized (since `self.once.is_completed()`)
228                // and has not been taken before (since `ptr` is not null).
229                drop(unsafe { Box::from_raw(ptr) });
230            }
231        }
232    }
233}
234
235// SAFETY: `TakeOnce` is `Send` iff `T` is `Send`.
236unsafe impl<T: Send> Send for TakeOnce<T> {}
237// SAFETY: `TakeOnce` does not allow shared access to the inner value.
238unsafe impl<T: Send> Sync for TakeOnce<T> {}
239
// These tests run `TakeOnce` under `shuttle`'s randomized scheduler, so `shuttle` must be
// available whenever this crate's tests are compiled. The cell's internal `Once` and
// `AtomicPtr` come from shuttle only when the `_shuttle` feature is enabled; otherwise
// shuttle schedules the test threads while the cell uses the `std` primitives.
#[cfg(test)]
mod tests {
    use super::TakeOnce;
    use shuttle::sync::Arc;
    use shuttle::thread;

    /// Number of randomized schedules explored per test.
    const SHUTTLE_ITERS: usize = 10_000;

    #[test]
    fn concurrent_store_operations() {
        shuttle::check_random(
            || {
                let once_take = Arc::new(TakeOnce::new());
                let num_threads = 6;
                let threads: Vec<_> = (0..num_threads)
                    .map(|i| {
                        let once_take = once_take.clone();
                        thread::spawn(move || once_take.store(i))
                    })
                    .collect();

                let results: Vec<_> = threads.into_iter().map(|t| t.join().unwrap()).collect();

                // Exactly one store should succeed
                assert_eq!(results.iter().filter(|r| r.is_ok()).count(), 1);
                // All other stores should return Err
                assert_eq!(
                    results.iter().filter(|r| r.is_err()).count(),
                    num_threads - 1
                );
            },
            SHUTTLE_ITERS,
        );
    }

    #[test]
    fn concurrent_take_operations() {
        shuttle::check_random(
            || {
                let once_take = Arc::new(TakeOnce::new());

                // First store
                assert_eq!(once_take.store(42), Ok(()));

                // Concurrent takes
                let threads: Vec<_> = (0..3)
                    .map(|_| {
                        let once_take = once_take.clone();
                        thread::spawn(move || once_take.take())
                    })
                    .collect();

                let results: Vec<_> = threads.into_iter().map(|t| t.join().unwrap()).collect();

                // Exactly one take should succeed
                assert_eq!(results.iter().filter(|r| r.is_some()).count(), 1);
                // The successful take should return 42
                assert!(results.iter().any(|r| r == &Some(42)));
                // All other takes should return None
                assert_eq!(results.iter().filter(|r| r.is_none()).count(), 2);
            },
            SHUTTLE_ITERS,
        );
    }

    #[test]
    fn mixed_store_take_operations() {
        shuttle::check_random(
            || {
                let once_take = Arc::new(TakeOnce::new());

                // Even-numbered threads store their index; odd-numbered threads take.
                let num_threads = 6;
                let threads: Vec<_> = (0..num_threads)
                    .map(|i| {
                        let once_take = once_take.clone();
                        thread::spawn(move || {
                            if i % 2 == 0 {
                                once_take.store(i)
                            } else {
                                once_take.take().map_or(Err(i), |_| Ok(()))
                            }
                        })
                    })
                    .collect();

                let results = threads
                    .into_iter()
                    .map(|t| t.join().unwrap())
                    .collect::<Vec<_>>();

                // `results` is indexed by thread number, so parity separates stores from takes.
                let successes_with_parity = |parity: usize| {
                    results
                        .iter()
                        .enumerate()
                        .filter(|&(idx, r)| idx % 2 == parity && r.is_ok())
                        .count()
                };
                // Exactly one store wins, however it interleaves with the takes.
                assert_eq!(successes_with_parity(0), 1);
                // At most one take can ever observe the value.
                assert!(successes_with_parity(1) <= 1);
            },
            SHUTTLE_ITERS,
        );
    }

    #[test]
    fn completion_status_consistency() {
        shuttle::check_random(
            || {
                let once_take = Arc::new(TakeOnce::new());
                let observer = once_take.clone();

                assert!(!once_take.is_completed());

                let t1 = thread::spawn(move || {
                    once_take.store(42).unwrap();
                    once_take.is_completed()
                });

                let completed_after_store = t1.join().unwrap();
                assert!(completed_after_store);
                // After joining, the completed store must be visible from this thread too.
                assert!(observer.is_completed());
                assert_eq!(observer.take(), Some(42));
            },
            SHUTTLE_ITERS,
        );
    }

    #[test]
    fn store_take_ordering() {
        shuttle::check_random(
            || {
                let once_take = Arc::new(TakeOnce::new());
                let storer = once_take.clone();

                let t1 = thread::spawn(move || storer.store(42));

                // Two concurrent takers that only try to take once the store looks complete.
                let takers: Vec<_> = (0..2)
                    .map(|_| {
                        let once_take = once_take.clone();
                        thread::spawn(move || {
                            if once_take.is_completed() {
                                once_take.take()
                            } else {
                                None
                            }
                        })
                    })
                    .collect();

                assert_eq!(t1.join().unwrap(), Ok(()));
                let mut taken: Vec<_> = takers
                    .into_iter()
                    .filter_map(|t| t.join().unwrap())
                    .collect();
                // Whatever the interleaving, the value is observed exactly once overall:
                // by one of the takers, or by this final take after every thread has joined.
                taken.extend(once_take.take());
                assert_eq!(taken, vec![42]);
            },
            SHUTTLE_ITERS,
        );
    }

    #[test]
    fn drop_consistency() {
        shuttle::check_random(
            || {
                let once_take = Arc::new(TakeOnce::new());
                let once_take2 = once_take.clone();

                // Store a value that implements Drop
                #[derive(Debug, PartialEq)]
                struct DropTest(i32);
                static DROPPED: shuttle::sync::Once = shuttle::sync::Once::new();
                impl Drop for DropTest {
                    fn drop(&mut self) {
                        let mut called = false;
                        DROPPED.call_once(|| called = true);
                        assert!(called);
                    }
                }

                once_take.store(DropTest(42)).unwrap();

                let t = thread::spawn(move || {
                    // Take the value in another thread
                    once_take2.take()
                });

                // The value should only be dropped once.
                let taken = t.join().unwrap();
                if let Some(val) = taken {
                    assert_eq!(val.0, 42);
                    drop(val);
                }

                assert!(DROPPED.is_completed());
            },
            SHUTTLE_ITERS,
        );
    }

    #[test]
    fn drop_untaken_value() {
        shuttle::check_random(
            || {
                // Records that the value was dropped, and panics if it is dropped twice.
                #[derive(Debug)]
                struct DropTest;
                static DROPPED: shuttle::sync::Once = shuttle::sync::Once::new();
                impl Drop for DropTest {
                    fn drop(&mut self) {
                        let mut called = false;
                        DROPPED.call_once(|| called = true);
                        assert!(called);
                    }
                }

                let once_take = Arc::new(TakeOnce::new());
                once_take.store(DropTest).unwrap();
                assert!(!DROPPED.is_completed());

                // Release the last handle on another thread without ever taking the value:
                // `TakeOnce`'s own `Drop` must free it.
                thread::spawn(move || drop(once_take)).join().unwrap();

                assert!(DROPPED.is_completed());
            },
            SHUTTLE_ITERS,
        );
    }
}