#![no_std]
use core::cell::Cell;
use core::marker::PhantomData;
use core::mem::{self, ManuallyDrop};
use core::ptr;
use core::sync::atomic::{Ordering, AtomicBool};
/// Marker that makes a containing type `!Sync` while leaving it `Send`.
///
/// `Cell<u8>` is `Send` but not `Sync`, so embedding `PhantomData<Cell<u8>>`
/// opts `AtomicTake` out of the automatically derived `Sync` impl. `Sync` is
/// then re-added by hand, with the weaker bound `T: Send`, further below.
type PhantomUnsync = PhantomData<Cell<u8>>;

/// A container whose value can be taken out at most once, even through a
/// shared reference (`&self`) that may be used from several threads.
///
/// An atomic flag arbitrates access: only the caller whose swap flips it from
/// `false` to `true` receives the value.
pub struct AtomicTake<T> {
    /// `true` once the value has been moved out; reset to `false` by `insert`.
    taken: AtomicBool,
    /// The stored value. Wrapped in `ManuallyDrop` because it may already have
    /// been moved out bitwise by `take`/`take_mut`; `Drop` consults `taken`
    /// before dropping it.
    value: ManuallyDrop<T>,
    /// Disables the automatic `Sync` impl; see `PhantomUnsync`.
    _unsync: PhantomUnsync,
}

impl<T> AtomicTake<T> {
    /// Creates a new `AtomicTake` holding `value`, not yet taken.
    pub fn new(value: T) -> Self {
        AtomicTake {
            taken: AtomicBool::new(false),
            value: ManuallyDrop::new(value),
            _unsync: PhantomData,
        }
    }

    /// Takes the value out through a shared reference.
    ///
    /// Returns `Some(value)` to the first caller after construction (or after
    /// the most recent `insert`) and `None` to every later caller, including
    /// callers racing on other threads.
    pub fn take(&self) -> Option<T> {
        // `Relaxed` is sufficient: the atomic only has to pick a winner, and the
        // read-modify-write's single modification order guarantees exactly one
        // swap observes `false`. `value` itself is only ever written through
        // `&mut self` (`new`, `insert`), so any thread able to call `take` was
        // already synchronized with that write when the `&self` was shared.
        if self.taken.swap(true, Ordering::Relaxed) {
            return None;
        }
        // SAFETY: this call is the unique winner of the swap above, so no other
        // caller reads `value` until `insert` (which requires `&mut self`)
        // stores a fresh one. `taken` is now `true`, so `Drop` will not drop
        // the bitwise-copied original a second time.
        Some(unsafe { ptr::read(&*self.value) })
    }

    /// Takes the value out through an exclusive reference.
    ///
    /// Same contract as `take`, but `&mut self` rules out concurrent callers,
    /// so the flag is accessed non-atomically via `get_mut`.
    pub fn take_mut(&mut self) -> Option<T> {
        if mem::replace(self.taken.get_mut(), true) {
            return None;
        }
        // SAFETY: the flag was `false`, so `value` still holds a live `T` that
        // has not been moved out; the flag is now `true`, so it will not be
        // read or dropped again.
        Some(unsafe { ptr::read(&*self.value) })
    }

    /// Stores `value`, returning the previously held value if it had not been
    /// taken yet (otherwise `None`). Afterwards the new value is takeable.
    pub fn insert(&mut self, value: T) -> Option<T> {
        // Move any still-present old value out first so it is returned to the
        // caller instead of being leaked.
        let previous = self.take_mut();
        // Overwriting a `ManuallyDrop` does not run `T`'s destructor, which is
        // intended: the old contents have already been moved out (into
        // `previous`, or by an earlier take).
        self.value = ManuallyDrop::new(value);
        *self.taken.get_mut() = false;
        previous
    }
}

impl<T> Drop for AtomicTake<T> {
    /// Drops the stored value unless it has already been taken.
    fn drop(&mut self) {
        if !*self.taken.get_mut() {
            // SAFETY: the flag is unset, so `value` still holds a live `T` that
            // no `take`/`take_mut` has moved out; it is dropped exactly once here.
            unsafe { ManuallyDrop::drop(&mut self.value) }
        }
    }
}

// SAFETY: the only operation available through `&AtomicTake<T>` is `take`,
// which moves the `T` by value to at most one thread and never hands out a
// `&T`. Transferring a `T` to another thread only requires `T: Send`, so
// `T: Sync` is not needed.
unsafe impl<T: Send> Sync for AtomicTake<T> {}
#[cfg(test)]
mod tests {
    // The crate is `no_std`; the test harness links `std`, but it must be
    // named explicitly to use threads and `Arc` here.
    extern crate std;

    use crate::AtomicTake;
    use core::cell::Cell;

    /// Increments a shared counter each time it is dropped, so tests can count
    /// destructor runs without raw pointers or `unsafe`.
    struct CountDrops<'a> {
        counter: &'a Cell<u32>,
    }

    impl<'a> Drop for CountDrops<'a> {
        fn drop(&mut self) {
            self.counter.set(self.counter.get() + 1);
        }
    }

    #[test]
    fn drop_calls_drop() {
        let counter = Cell::new(0);
        let take = AtomicTake::new(CountDrops { counter: &counter });
        drop(take);
        assert_eq!(counter.get(), 1);
    }

    #[test]
    fn taken_not_dropped_twice() {
        let counter = Cell::new(0);
        let take = AtomicTake::new(CountDrops { counter: &counter });
        // The returned temporary is dropped at the end of the statement.
        assert!(take.take().is_some());
        assert_eq!(counter.get(), 1);
        assert!(take.take().is_none());
        drop(take);
        assert_eq!(counter.get(), 1);
    }

    #[test]
    fn taken_mut_not_dropped_twice() {
        let counter = Cell::new(0);
        let mut take = AtomicTake::new(CountDrops { counter: &counter });
        assert!(take.take_mut().is_some());
        assert_eq!(counter.get(), 1);
        assert!(take.take_mut().is_none());
        drop(take);
        assert_eq!(counter.get(), 1);
    }

    #[test]
    fn insert_dropped_once() {
        let counter1 = Cell::new(0);
        let counter2 = Cell::new(0);
        let mut take = AtomicTake::new(CountDrops { counter: &counter1 });
        take.insert(CountDrops { counter: &counter2 });
        drop(take);
        assert_eq!(counter1.get(), 1);
        assert_eq!(counter2.get(), 1);
    }

    #[test]
    fn insert_take() {
        let counter1 = Cell::new(0);
        let counter2 = Cell::new(0);
        let mut take = AtomicTake::new(CountDrops { counter: &counter1 });
        // The displaced value is handed back to the caller, not dropped inside.
        let previous = take.insert(CountDrops { counter: &counter2 });
        assert!(previous.is_some());
        assert_eq!(counter1.get(), 0);
        drop(previous);
        assert_eq!(counter1.get(), 1);
        assert_eq!(counter2.get(), 0);
        // The freshly inserted value can be taken again.
        assert!(take.take().is_some());
        assert_eq!(counter2.get(), 1);
        drop(take);
        assert_eq!(counter1.get(), 1);
        assert_eq!(counter2.get(), 1);
    }

    #[test]
    fn insert_after_take_returns_none() {
        let mut take = AtomicTake::new(1u32);
        assert_eq!(take.take(), Some(1));
        assert_eq!(take.insert(2), None);
        assert_eq!(take.take_mut(), Some(2));
        assert_eq!(take.take(), None);
    }

    #[test]
    fn concurrent_take_has_single_winner() {
        let shared = std::sync::Arc::new(AtomicTake::new(42u32));
        let handles: std::vec::Vec<_> = (0..8)
            .map(|_| {
                let shared = std::sync::Arc::clone(&shared);
                std::thread::spawn(move || shared.take())
            })
            .collect();
        let winners: std::vec::Vec<u32> = handles
            .into_iter()
            .filter_map(|handle| handle.join().expect("worker thread panicked"))
            .collect();
        assert_eq!(winners.len(), 1);
        assert_eq!(winners[0], 42);
    }
}