1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
#![no_std]
//! This crate allows you to store a value that you can later take out atomically. As this
//! crate uses atomics, no locking is involved in taking the value out.
//!
//! As an example, you could store the [`Sender`] of a oneshot channel in an
//! [`AtomicTake`], which would allow you to notify the first time a closure is called.
//!
//! ```
//! use atomic_take::AtomicTake;
//! use futures::sync::oneshot;
//!
//! let (send, mut recv) = oneshot::channel();
//!
//! let take = AtomicTake::new(send);
//! let closure = move || {
//!     if let Some(send) = take.take() {
//!         // Notify the first time this closure is called.
//!         send.send(()).unwrap();
//!     }
//! };
//!
//! closure();
//! assert_eq!(recv.try_recv().unwrap(), Some(()));
//!
//! closure(); // This does nothing.
//! ```
//!
//! Additionally, the closure above can be called concurrently from many threads. For
//! example, if you put the `AtomicTake` in an [`Arc`], you can share it between several
//! threads and receive a message from the first thread to run.
//!
//! ```
//! use std::thread;
//! use std::sync::Arc;
//! use atomic_take::AtomicTake;
//! use futures::sync::oneshot;
//!
//! let (send, mut recv) = oneshot::channel();
//!
//! // Use an Arc to share the AtomicTake between several threads.
//! let take = Arc::new(AtomicTake::new(send));
//!
//! // Spawn three threads and try to send a message from each.
//! let mut handles = Vec::new();
//! for i in 0..3 {
//!     let take_clone = Arc::clone(&take);
//!     let join_handle = thread::spawn(move || {
//!
//!         // Check if this thread is first and send a message if so.
//!         if let Some(send) = take_clone.take() {
//!             // Send the index of the thread.
//!             send.send(i).unwrap();
//!         }
//!
//!     });
//!     handles.push(join_handle);
//! }
//! // Wait for all three threads to finish.
//! for handle in handles {
//!     handle.join().unwrap();
//! }
//!
//! // After all the threads finished, try to send again.
//! if let Some(send) = take.take() {
//!     // This will definitely not happen.
//!     send.send(100).unwrap();
//! }
//!
//! // Confirm that one of the first three threads got to send the message first.
//! assert!(recv.try_recv().unwrap().unwrap() < 3);
//! ```
//!
//! This crate does not require the standard library.
//!
//! [`Sender`]: https://docs.rs/futures/0.1.29/futures/sync/oneshot/struct.Sender.html
//! [`AtomicTake`]: struct.AtomicTake.html
//! [`Arc`]: https://doc.rust-lang.org/std/sync/struct.Arc.html

use core::cell::Cell;
use core::marker::PhantomData;
use core::mem::{self, ManuallyDrop};
use core::ptr;
use core::sync::atomic::{Ordering, AtomicBool};

/// Marker that suppresses the auto-derived `Sync` impl of `AtomicTake`.
///
/// `Cell<u8>` is `Send` but not `Sync`, so embedding it as `PhantomData` leaves the
/// auto-`Send` impl untouched while removing the auto-`Sync` impl. `Sync` is then
/// re-added manually (see the `unsafe impl` below) with the bound this type needs.
type PhantomUnsync = PhantomData<Cell<u8>>;

/// A container with an atomic take operation.
///
/// The stored value can be moved out through a shared reference at most once; every
/// later call to `take` returns `None` until a new value is stored with `insert`.
pub struct AtomicTake<T> {
    // `true` once the value has been moved out. While it is `false`, `value` holds an
    // initialized `T` owned by this container.
    taken: AtomicBool,
    // Wrapped in `ManuallyDrop` because the value may already have been moved out by
    // `take`/`take_mut`; `Drop` consults `taken` before dropping it.
    value: ManuallyDrop<T>,
    _unsync: PhantomUnsync,
}

impl<T> AtomicTake<T> {
    /// Create a new `AtomicTake` with the given value.
    pub fn new(value: T) -> Self {
        AtomicTake {
            taken: AtomicBool::new(false),
            value: ManuallyDrop::new(value),
            _unsync: PhantomData,
        }
    }

    /// Takes out the value from this `AtomicTake`.
    ///
    /// At most one caller receives the stored value, even when called concurrently
    /// from many threads; every other call returns `None`.
    pub fn take(&self) -> Option<T> {
        // `Relaxed` is sufficient: the swap is a single atomic read-modify-write, so
        // exactly one caller can observe `false`. The value itself was written either
        // before the container was shared or through `&mut self` (`insert`), and the
        // mechanism used to share it already orders that write before this read.
        if !self.taken.swap(true, Ordering::Relaxed) {
            // SAFETY: this caller is the only one that observed `taken == false`, so
            // the value is initialized and will never be read again. `taken` is now
            // `true`, so `Drop` will not drop the bitwise copy left behind.
            Some(unsafe { ptr::read(&*self.value) })
        } else {
            None
        }
    }

    /// This method does the same as `take`, but does not use an atomic swap.
    ///
    /// This is safe because you cannot call this method without unique access to the
    /// `AtomicTake`, so no other threads will try to take it concurrently.
    pub fn take_mut(&mut self) -> Option<T> {
        if !mem::replace(self.taken.get_mut(), true) {
            // SAFETY: `taken` was `false`, so the value is initialized, and `&mut self`
            // rules out a concurrent `take`. Marking it taken prevents a double drop.
            Some(unsafe { ptr::read(&*self.value) })
        } else {
            None
        }
    }

    /// Insert a new value into the `AtomicTake` and return the previous value, or
    /// `None` if it had already been taken.
    ///
    /// This function requires unique access to ensure no other thread accesses the
    /// `AtomicTake` concurrently, as this operation cannot be performed atomically
    /// without a lock.
    pub fn insert(&mut self, value: T) -> Option<T> {
        // Move the old value out (if still present). Afterwards the slot holds no live
        // value, so overwriting it below neither leaks nor double-drops anything.
        let previous = self.take_mut();

        // Assigning to a `ManuallyDrop` field never runs the old value's destructor,
        // which is what we want: it was moved into `previous` or taken earlier.
        self.value = ManuallyDrop::new(value);
        *self.taken.get_mut() = false;

        previous
    }
}

impl<T> Drop for AtomicTake<T> {
    fn drop(&mut self) {
        // `&mut self` guarantees no other thread is racing with us, so the flag can be
        // read without an atomic operation.
        if !*self.taken.get_mut() {
            // SAFETY: the value was never taken, so it is still initialized and owned
            // by this container, and this is the only place it is dropped.
            unsafe {
                ManuallyDrop::drop(&mut self.value);
            }
        }
    }
}

// SAFETY: a shared `&AtomicTake<T>` only allows moving the `T` out to whichever thread
// calls `take` first, so `T: Send` is required. No `&T` is ever handed out, so
// `T: Sync` is not needed.
unsafe impl<T: Send> Sync for AtomicTake<T> {}

#[cfg(test)]
mod tests {
    use crate::AtomicTake;
    use core::cell::Cell;

    /// Increments the shared counter every time it is dropped.
    ///
    /// Uses a `&Cell<u32>` so the tests need no `unsafe` raw-pointer writes.
    struct CountDrops<'a> {
        counter: &'a Cell<u32>,
    }
    impl<'a> Drop for CountDrops<'a> {
        fn drop(&mut self) {
            self.counter.set(self.counter.get() + 1);
        }
    }

    #[test]
    fn drop_calls_drop() {
        let counter = Cell::new(0);

        let take = AtomicTake::new(CountDrops { counter: &counter });
        drop(take);

        assert_eq!(counter.get(), 1);
    }

    #[test]
    fn taken_not_dropped_twice() {
        let counter = Cell::new(0);

        let take = AtomicTake::new(CountDrops { counter: &counter });
        // Dropping the taken value runs its destructor once...
        drop(take.take());

        assert_eq!(counter.get(), 1);

        // ...and dropping the emptied container must not run it again.
        drop(take);

        assert_eq!(counter.get(), 1);
    }

    #[test]
    fn taken_mut_not_dropped_twice() {
        let counter = Cell::new(0);

        let mut take = AtomicTake::new(CountDrops { counter: &counter });
        drop(take.take_mut());

        assert_eq!(counter.get(), 1);

        drop(take);

        assert_eq!(counter.get(), 1);
    }

    #[test]
    fn insert_dropped_once() {
        let counter1 = Cell::new(0);
        let counter2 = Cell::new(0);

        let mut take = AtomicTake::new(CountDrops { counter: &counter1 });
        drop(take.insert(CountDrops { counter: &counter2 }));
        drop(take);

        assert_eq!(counter1.get(), 1);
        assert_eq!(counter2.get(), 1);
    }

    #[test]
    fn insert_take() {
        let counter1 = Cell::new(0);
        let counter2 = Cell::new(0);

        let mut take = AtomicTake::new(CountDrops { counter: &counter1 });
        // The previous value is returned and dropped right away here.
        drop(take.insert(CountDrops { counter: &counter2 }));

        assert_eq!(counter1.get(), 1);
        assert_eq!(counter2.get(), 0);

        drop(take);

        assert_eq!(counter1.get(), 1);
        assert_eq!(counter2.get(), 1);
    }

    #[test]
    fn insert_after_take_dropped_once() {
        let counter1 = Cell::new(0);
        let counter2 = Cell::new(0);

        let mut take = AtomicTake::new(CountDrops { counter: &counter1 });
        drop(take.take());
        assert_eq!(counter1.get(), 1);

        // Nothing to return: the first value was already taken.
        assert!(take.insert(CountDrops { counter: &counter2 }).is_none());
        assert_eq!(counter2.get(), 0);

        drop(take);
        assert_eq!(counter1.get(), 1);
        assert_eq!(counter2.get(), 1);
    }

    #[test]
    fn take_returns_value_once() {
        let take = AtomicTake::new(5);
        assert_eq!(take.take(), Some(5));
        assert_eq!(take.take(), None);
    }

    #[test]
    fn take_mut_returns_value_once() {
        let mut take = AtomicTake::new(5);
        assert_eq!(take.take_mut(), Some(5));
        assert_eq!(take.take_mut(), None);
        assert_eq!(take.take(), None);
    }

    #[test]
    fn insert_returns_previous_value() {
        let mut take = AtomicTake::new(1);
        assert_eq!(take.insert(2), Some(1));
        assert_eq!(take.take(), Some(2));
        // Inserting into an emptied container returns `None` and makes it takeable again.
        assert_eq!(take.insert(3), None);
        assert_eq!(take.take(), Some(3));
    }
}