atomic_take/lib.rs
#![no_std]
// The `taken` flag is compared with explicit `== false` as a deliberate readability
// choice ("not yet taken"), which this clippy lint would otherwise flag.
#![allow(clippy::bool_comparison)]
3//! This crate allows you to store a value that you can later take out atomically. As this
4//! crate uses atomics, no locking is involved in taking the value out.
5//!
//! As an example, you could store the [`Sender`] of a oneshot channel in an
//! [`AtomicTake`], which lets you send a notification the first time a closure is called.
8//!
9//! ```
10//! use atomic_take::AtomicTake;
11//! use tokio::sync::oneshot;
12//!
13//! let (send, mut recv) = oneshot::channel();
14//!
15//! let take = AtomicTake::new(send);
16//! let closure = move || {
//!     if let Some(send) = take.take() {
//!         // Notify the first time this closure is called.
//!         send.send(()).unwrap();
//!     }
21//! };
22//!
23//! closure();
24//! assert!(recv.try_recv().is_ok());
25//!
26//! closure(); // This does nothing.
27//! ```
28//!
//! Additionally, the closure above can be called concurrently from many threads. For
//! example, if you put the `AtomicTake` in an [`Arc`], you can share it between several
//! threads and receive a message from whichever thread takes the value first.
32//!
33//! ```
34//! use std::thread;
35//! use std::sync::Arc;
36//! use atomic_take::AtomicTake;
37//! use tokio::sync::oneshot;
38//!
39//! let (send, mut recv) = oneshot::channel();
40//!
41//! // Use an Arc to share the AtomicTake between several threads.
42//! let take = Arc::new(AtomicTake::new(send));
43//!
44//! // Spawn three threads and try to send a message from each.
45//! let mut handles = Vec::new();
46//! for i in 0..3 {
//!     let take_clone = Arc::clone(&take);
//!     let join_handle = thread::spawn(move || {
//!         // Check if this thread is first and send a message if so.
//!         if let Some(send) = take_clone.take() {
//!             // Send the index of the thread.
//!             send.send(i).unwrap();
//!         }
//!     });
//!     handles.push(join_handle);
58//! }
59//! // Wait for all three threads to finish.
60//! for handle in handles {
//!     handle.join().unwrap();
62//! }
63//!
//! // After all the threads have finished, try to send again.
//! if let Some(send) = take.take() {
//!     // This never runs: one of the threads above already took the sender.
//!     send.send(100).unwrap();
68//! }
69//!
70//! // Confirm that one of the first three threads got to send the message first.
71//! assert!(recv.try_recv().unwrap() < 3);
72//! ```
73//!
74//! This crate does not require the standard library.
75//!
76//! [`Sender`]: https://docs.rs/tokio/latest/tokio/sync/oneshot/struct.Sender.html
//! [`AtomicTake`]: crate::AtomicTake
78//! [`Arc`]: https://doc.rust-lang.org/std/sync/struct.Arc.html
79
80use core::cell::Cell;
81use core::marker::PhantomData;
82use core::mem::{self, MaybeUninit};
83use core::ptr;
84use core::sync::atomic::{AtomicBool, Ordering};
85
86use core::fmt;
87
/// Zero-sized marker that is `Send` but never `Sync` (because `Cell` is never `Sync`).
///
/// Embedding it in a struct removes that struct's automatic `Sync` impl, so `Sync`
/// can instead be granted by hand with a different bound.
type PhantomUnsync = PhantomData<Cell<()>>;
89
90/// A container with an atomic take operation.
91pub struct AtomicTake<T> {
92 taken: AtomicBool,
93 value: MaybeUninit<T>,
94 _unsync: PhantomUnsync,
95}
96
97impl<T> AtomicTake<T> {
98 /// Create a new `AtomicTake` with the given value.
99 pub const fn new(value: T) -> Self {
100 AtomicTake {
101 taken: AtomicBool::new(false),
102 value: MaybeUninit::new(value),
103 _unsync: PhantomData,
104 }
105 }
106 /// Create an empty `AtomicTake` that contains no value.
107 pub const fn empty() -> Self {
108 AtomicTake {
109 taken: AtomicBool::new(true),
110 value: MaybeUninit::uninit(),
111 _unsync: PhantomData,
112 }
113 }
114 /// Takes out the value from this `AtomicTake`. It is guaranteed that exactly one
115 /// caller will receive the value and all others will receive `None`.
116 pub fn take(&self) -> Option<T> {
117 if self.taken.swap(true, Ordering::Relaxed) == false {
118 unsafe { Some(ptr::read(self.value.as_ptr())) }
119 } else {
120 None
121 }
122 }
123 /// This methods does the same as `take`, but does not use an atomic swap.
124 ///
125 /// This is safe because you cannot call this method without unique access to the
126 /// `AtomicTake`, so no other threads will try to take it concurrently.
127 pub fn take_mut(&mut self) -> Option<T> {
128 if mem::replace(self.taken.get_mut(), true) == false {
129 unsafe { Some(ptr::read(self.value.as_ptr())) }
130 } else {
131 None
132 }
133 }
134
135 /// Check whether the value is taken. Note that if this returns `false`, then this
136 /// is immediately stale if another thread could be concurrently trying to take it.
137 pub fn is_taken(&self) -> bool {
138 self.taken.load(Ordering::Relaxed)
139 }
140
141 /// Insert a new value into the `AtomicTake` and return the previous value.
142 ///
143 /// This function requires unique access to ensure no other threads accesses the
144 /// `AtomicTake` concurrently, as this operation cannot be performed atomically
145 /// without a lock.
146 pub fn insert(&mut self, value: T) -> Option<T> {
147 let previous = self.take_mut();
148
149 unsafe {
150 ptr::write(self.value.as_mut_ptr(), value);
151 *self.taken.get_mut() = false;
152 }
153
154 // Could also be written as below, but this avoids running the destructor.
155 // *self = AtomicTake::new(value);
156
157 previous
158 }
159}
160
161impl<T> Drop for AtomicTake<T> {
162 fn drop(&mut self) {
163 if !*self.taken.get_mut() {
164 unsafe {
165 ptr::drop_in_place(self.value.as_mut_ptr());
166 }
167 }
168 }
169}
170
171// As this api can only be used to move values between threads, Sync is not needed.
172unsafe impl<T: Send> Sync for AtomicTake<T> {}
173
174impl<T> From<T> for AtomicTake<T> {
175 fn from(t: T) -> AtomicTake<T> {
176 AtomicTake::new(t)
177 }
178}
179
180impl<T> fmt::Debug for AtomicTake<T> {
181 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
182 if self.is_taken() {
183 write!(f, "Empty AtomicTake")
184 } else {
185 write!(f, "Non-Empty AtomicTake")
186 }
187 }
188}
189
#[cfg(test)]
mod tests {
    use crate::AtomicTake;
    use core::cell::Cell;

    /// Increments a shared counter every time it is dropped, so tests can check
    /// exactly how many times a stored value was destroyed.
    ///
    /// Borrowing a `Cell` (instead of holding a raw `*mut u32`) keeps the tests free
    /// of `unsafe` and lets the borrow checker verify the counter outlives the value.
    struct CountDrops<'a> {
        counter: &'a Cell<u32>,
    }
    impl<'a> Drop for CountDrops<'a> {
        fn drop(&mut self) {
            self.counter.set(self.counter.get() + 1);
        }
    }

    /// Panics if it is ever dropped; used to check that an empty slot is never dropped.
    struct PanicOnDrop;
    impl Drop for PanicOnDrop {
        fn drop(&mut self) {
            panic!("Panic on drop called.");
        }
    }

    #[test]
    fn drop_calls_drop() {
        let counter = Cell::new(0);

        let take = AtomicTake::new(CountDrops { counter: &counter });
        drop(take);

        assert_eq!(counter.get(), 1);
    }

    #[test]
    fn take_yields_value_once() {
        let take = AtomicTake::new(7u32);

        assert_eq!(take.take(), Some(7));
        assert!(take.is_taken());
        assert_eq!(take.take(), None);
    }

    #[test]
    fn taken_not_dropped_twice() {
        let counter = Cell::new(0);

        let take = AtomicTake::new(CountDrops { counter: &counter });
        // The taken value is dropped at the end of this statement.
        take.take();

        assert_eq!(counter.get(), 1);

        drop(take);

        assert_eq!(counter.get(), 1);
    }

    #[test]
    fn taken_mut_not_dropped_twice() {
        let counter = Cell::new(0);

        let mut take = AtomicTake::new(CountDrops { counter: &counter });
        take.take_mut();

        assert_eq!(counter.get(), 1);

        drop(take);

        assert_eq!(counter.get(), 1);
    }

    #[test]
    fn insert_dropped_once() {
        let counter1 = Cell::new(0);
        let counter2 = Cell::new(0);

        let mut take = AtomicTake::new(CountDrops { counter: &counter1 });
        assert!(!take.is_taken());
        take.insert(CountDrops { counter: &counter2 });
        assert!(!take.is_taken());
        drop(take);

        assert_eq!(counter1.get(), 1);
        assert_eq!(counter2.get(), 1);
    }

    #[test]
    fn insert_take() {
        let counter1 = Cell::new(0);
        let counter2 = Cell::new(0);

        let mut take = AtomicTake::new(CountDrops { counter: &counter1 });
        // The previous value is returned by `insert` and dropped right here.
        take.insert(CountDrops { counter: &counter2 });

        assert!(!take.is_taken());

        assert_eq!(counter1.get(), 1);
        assert_eq!(counter2.get(), 0);

        drop(take);

        assert_eq!(counter1.get(), 1);
        assert_eq!(counter2.get(), 1);
    }

    #[test]
    fn empty_no_drop() {
        let take: AtomicTake<PanicOnDrop> = AtomicTake::empty();
        assert!(take.is_taken());
        drop(take);
    }

    #[test]
    fn empty_insert() {
        let counter = Cell::new(0);
        let mut take = AtomicTake::empty();

        assert!(take.is_taken());

        take.insert(CountDrops { counter: &counter });

        assert!(!take.is_taken());

        drop(take);

        assert_eq!(counter.get(), 1);
    }

    #[test]
    fn sync_requires_only_send() {
        fn assert_sync<S: Sync>() {}
        // `Cell` is `Send` but not `Sync`, yet the container is still `Sync`.
        assert_sync::<AtomicTake<Cell<u32>>>();
    }
}