use std::cell::UnsafeCell;
use std::fmt;
use std::ops::{Add, Sub, BitAnd, BitOr, BitXor};
use crate::core::smutex::SMutex;
/// A value of type `T` guarded by an [`SMutex`].
///
/// Methods taking `&self` acquire `mutex` before touching `value`: the
/// read-only accessors in this file (`with`, `load`, `load_copy`) use
/// `lock_group()`, while mutating accessors use `lock()`. `get_mut` and
/// `into_inner` skip the lock because `&mut self` / `self` already guarantee
/// exclusive access.
///
/// NOTE(review): `lock_group()` presumably admits several concurrent holders
/// (a reader lock) — confirm against `SMutex`.
pub struct Atomic<T> {
// Guards every `&self` access to `value`.
mutex: SMutex,
// The protected value; `UnsafeCell` supplies the interior mutability that the
// lock is meant to make sound.
value: UnsafeCell<T>,
}
// SAFETY: every `&self` access to `value` happens while `mutex` is held.
// Methods that take `lock()` hand out `&mut T` (and can move `T` values in and
// out), which requires `T: Send`. Methods that take `lock_group()` (`with`,
// `load`, `load_copy`) hand out `&T`; if several threads can hold the group
// lock at once — as the read/write split suggests (NOTE(review): confirm
// against `SMutex`) — then `&T` is observed concurrently from multiple
// threads, which additionally requires `T: Sync`. These are the same bounds
// `std::sync::RwLock<T>` uses. With only `T: Send`, `Atomic<Cell<i32>>` would
// be `Sync` and `with(|c| c.set(..))` from two threads would be a data race.
unsafe impl<T: Send + Sync> Sync for Atomic<T> {}
// SAFETY: `Atomic<T>` owns its `T`, so sending it sends the `T`.
// NOTE(review): this also asserts that `SMutex` may be moved between threads.
unsafe impl<T: Send> Send for Atomic<T> {}
impl<T> Atomic<T> {
    /// Creates a new `Atomic` holding `value`.
    pub fn new(value: T) -> Self {
        Self {
            mutex: SMutex::new(),
            value: UnsafeCell::new(value),
        }
    }

    /// Consumes the wrapper and returns the contained value.
    ///
    /// No lock is taken: owning `self` rules out any other reference.
    pub fn into_inner(self) -> T {
        self.value.into_inner()
    }

    /// Returns a mutable reference to the contained value without locking.
    ///
    /// `&mut self` already guarantees exclusive access.
    pub fn get_mut(&mut self) -> &mut T {
        self.value.get_mut()
    }

    /// Calls `f` with a shared reference to the value while `lock_group()` is
    /// held, and returns whatever `f` returns.
    ///
    /// NOTE(review): whether `f` may call other locking methods on this same
    /// `Atomic` depends on `SMutex` re-entrancy — confirm before relying on it.
    pub fn with<F, R>(&self, f: F) -> R
    where
        F: FnOnce(&T) -> R,
    {
        let _guard = self.mutex.lock_group();
        // SAFETY: the group lock is held until `f` returns; mutation requires
        // `lock()`, which (assuming `SMutex` excludes it against group holders)
        // cannot produce a `&mut T` in the meantime.
        let value = unsafe { &*self.value.get() };
        f(value)
    }

    /// Calls `f` with a mutable reference to the value while `lock()` is
    /// held, and returns whatever `f` returns.
    pub fn with_mut<F, R>(&self, f: F) -> R
    where
        F: FnOnce(&mut T) -> R,
    {
        let _guard = self.mutex.lock();
        // SAFETY: `lock()` is held until `f` returns, so (assuming it is an
        // exclusive lock) this is the only live reference to the value.
        let value = unsafe { &mut *self.value.get() };
        f(value)
    }

    /// Replaces the value with `f(&current)` and returns the previous value.
    ///
    /// `f` only borrows the current value, and the store happens after `f`
    /// returns, so if `f` panics the stored value is left untouched.
    pub fn replace_with<F>(&self, f: F) -> T
    where
        F: FnOnce(&T) -> T,
    {
        self.with_mut(|slot| {
            // Compute the replacement before moving anything out of the cell;
            // a bitwise `ptr::read` here would double-drop the value on panic.
            let new = f(&*slot);
            std::mem::replace(slot, new)
        })
    }

    /// Mutates the value in place while `lock()` is held.
    pub fn update<F>(&self, f: F)
    where
        F: FnOnce(&mut T),
    {
        self.with_mut(f)
    }

    /// Runs `f` on the value while `lock()` is held and maps its verdict to a
    /// `Result`: `true` becomes `Ok(())`, `false` becomes `Err(())`.
    ///
    /// There is no rollback: changes `f` made before returning `false` stay.
    pub fn try_update<F>(&self, f: F) -> Result<(), ()>
    where
        F: FnOnce(&mut T) -> bool,
    {
        if self.with_mut(f) {
            Ok(())
        } else {
            Err(())
        }
    }
}
impl<T: Clone> Atomic<T> {
    /// Returns a clone of the current value, taken while `lock_group()` is held.
    pub fn load(&self) -> T {
        let _guard = self.mutex.lock_group();
        // SAFETY: the group lock is held while the value is borrowed for cloning.
        let current = unsafe { &*self.value.get() };
        current.clone()
    }

    /// Stores `val`, discarding the previous value.
    ///
    /// The previous value is dropped after the lock has been released, so a
    /// slow `Drop` impl (or one that touches this `Atomic`) does not run while
    /// the lock is held.
    pub fn store(&self, val: T) {
        drop(self.swap(val));
    }

    /// Stores `val` and returns the previous value.
    ///
    /// The previous value is moved out of the cell, not cloned.
    pub fn swap(&self, val: T) -> T {
        let _guard = self.mutex.lock();
        // SAFETY: `lock()` is held, so this is the only live reference to the value.
        let slot = unsafe { &mut *self.value.get() };
        std::mem::replace(slot, val)
    }

    /// Replaces the value with `T::default()` and returns the previous value.
    pub fn take_default(&self) -> T
    where
        T: Default,
    {
        self.swap(T::default())
    }
}
impl<T: Copy> Atomic<T> {
    /// Returns a bitwise copy of the current value, read under `lock_group()`.
    pub fn load_copy(&self) -> T {
        let _shared = self.mutex.lock_group();
        // SAFETY: the group lock is held for the read; `T: Copy`, so reading
        // does not move ownership out of the cell.
        unsafe { self.value.get().read() }
    }

    /// Writes `val` under `lock()` and returns the value it replaced.
    pub fn swap_copy(&self, val: T) -> T {
        let _exclusive = self.mutex.lock();
        // SAFETY: `lock()` is held, so this is the only live reference.
        let slot = unsafe { &mut *self.value.get() };
        std::mem::replace(slot, val)
    }
}
impl<T: Clone + PartialEq> Atomic<T> {
    /// Stores `new` if the current value equals `current` (tested as
    /// `stored == current`).
    ///
    /// Returns `Ok(previous)` on success, where `previous` is moved out of the
    /// cell. Returns `Err(actual)` on failure, where `actual` is a clone of the
    /// value left in place.
    pub fn compare_exchange(&self, current: T, new: T) -> Result<T, T> {
        let _guard = self.mutex.lock();
        // SAFETY: `lock()` is held, so this is the only live reference.
        let slot = unsafe { &mut *self.value.get() };
        if *slot == current {
            // Move the old value out instead of cloning it.
            Ok(std::mem::replace(slot, new))
        } else {
            Err(slot.clone())
        }
    }

    /// Same as [`compare_exchange`](Self::compare_exchange).
    ///
    /// Unlike `std::sync::atomic`'s weak variant, this never fails spuriously.
    pub fn compare_exchange_weak(&self, current: T, new: T) -> Result<T, T> {
        self.compare_exchange(current, new)
    }

    /// Like [`compare_exchange`](Self::compare_exchange), but returns the
    /// previous value whether or not the exchange happened.
    pub fn compare_and_swap(&self, current: T, new: T) -> T {
        match self.compare_exchange(current, new) {
            Ok(val) | Err(val) => val,
        }
    }

    /// Calls `f` once with the current value under `lock()`.
    ///
    /// If `f` returns `Some(new)`, stores `new` and returns `Ok(previous)`
    /// (moved out of the cell). If it returns `None`, leaves the value
    /// unchanged and returns `Err(clone_of_current)`. Because the lock is held
    /// throughout, there is no retry loop.
    pub fn fetch_update<F>(&self, mut f: F) -> Result<T, T>
    where
        F: FnMut(&T) -> Option<T>,
    {
        let _guard = self.mutex.lock();
        // SAFETY: `lock()` is held, so this is the only live reference.
        let slot = unsafe { &mut *self.value.get() };
        match f(&*slot) {
            Some(new) => Ok(std::mem::replace(slot, new)),
            None => Err(slot.clone()),
        }
    }
}
impl<T: Copy + Add<Output = T>> Atomic<T> {
    /// Adds `val` to the stored value while `lock()` is held and returns the
    /// value from before the addition.
    ///
    /// Overflow follows `T`'s `+` (for primitive integers: panic when overflow
    /// checks are on, wrap otherwise). This differs from `std::sync::atomic`,
    /// which always wraps. If `+` panics, the stored value is unchanged.
    pub fn fetch_add(&self, val: T) -> T {
        let _exclusive = self.mutex.lock();
        // SAFETY: `lock()` is held, so this is the only live reference.
        let slot = unsafe { &mut *self.value.get() };
        let before = *slot;
        *slot = before + val;
        before
    }
}
impl<T: Copy + Sub<Output = T>> Atomic<T> {
    /// Subtracts `val` from the stored value while `lock()` is held and
    /// returns the value from before the subtraction.
    ///
    /// Underflow follows `T`'s `-` operator. If `-` panics, the stored value
    /// is unchanged.
    pub fn fetch_sub(&self, val: T) -> T {
        let _exclusive = self.mutex.lock();
        let cell = self.value.get();
        // SAFETY: `lock()` is held, so no other reference to the value exists;
        // `T: Copy`, so the read/write pair neither moves nor drops anything.
        unsafe {
            let before = cell.read();
            cell.write(before - val);
            before
        }
    }
}
impl<T: Copy + BitAnd<Output = T>> Atomic<T> {
    /// Replaces the stored value with `current & val` while `lock()` is held
    /// and returns the value from before the update.
    pub fn fetch_and_bits(&self, val: T) -> T {
        let _exclusive = self.mutex.lock();
        // SAFETY: `lock()` is held, so this is the only live reference.
        let cell = unsafe { &mut *self.value.get() };
        let masked = *cell & val;
        std::mem::replace(cell, masked)
    }
}
impl<T: Copy + BitOr<Output = T>> Atomic<T> {
    /// Replaces the stored value with `current | val` while `lock()` is held
    /// and returns the value from before the update.
    pub fn fetch_or_bits(&self, val: T) -> T {
        let _exclusive = self.mutex.lock();
        // SAFETY: `lock()` is held, so this is the only live reference.
        let cell = unsafe { &mut *self.value.get() };
        let merged = *cell | val;
        std::mem::replace(cell, merged)
    }
}
impl<T: Copy + BitXor<Output = T>> Atomic<T> {
    /// Replaces the stored value with `current ^ val` while `lock()` is held
    /// and returns the value from before the update.
    pub fn fetch_xor_bits(&self, val: T) -> T {
        let _exclusive = self.mutex.lock();
        // SAFETY: `lock()` is held, so this is the only live reference.
        let cell = unsafe { &mut *self.value.get() };
        let toggled = *cell ^ val;
        std::mem::replace(cell, toggled)
    }
}
impl<T: Copy + PartialOrd> Atomic<T> {
    /// Stores `val` if `val > current`, under `lock()`, and returns the value
    /// from before the call.
    ///
    /// When the comparison is false — including when the values are unordered,
    /// e.g. NaN — the stored value is kept.
    pub fn fetch_max(&self, val: T) -> T {
        let _exclusive = self.mutex.lock();
        // SAFETY: `lock()` is held, so this is the only live reference.
        let slot = unsafe { &mut *self.value.get() };
        let previous = *slot;
        if val > previous {
            *slot = val;
        }
        previous
    }

    /// Stores `val` if `val < current`, under `lock()`, and returns the value
    /// from before the call.
    ///
    /// When the comparison is false — including when the values are unordered,
    /// e.g. NaN — the stored value is kept.
    pub fn fetch_min(&self, val: T) -> T {
        let _exclusive = self.mutex.lock();
        // SAFETY: `lock()` is held, so this is the only live reference.
        let slot = unsafe { &mut *self.value.get() };
        let previous = *slot;
        if val < previous {
            *slot = val;
        }
        previous
    }
}
impl Atomic<bool> {
    /// Applies `op` to the current flag while `lock()` is held, stores the
    /// result, and returns the flag as it was before.
    fn modify_bool<F: FnOnce(bool) -> bool>(&self, op: F) -> bool {
        let _exclusive = self.mutex.lock();
        // SAFETY: `lock()` is held, so this is the only live reference.
        let flag = unsafe { &mut *self.value.get() };
        let before = *flag;
        *flag = op(before);
        before
    }

    /// Logical AND: stores `current && val` and returns the previous flag.
    pub fn fetch_and(&self, val: bool) -> bool {
        self.modify_bool(|old| old & val)
    }

    /// Logical OR: stores `current || val` and returns the previous flag.
    pub fn fetch_or(&self, val: bool) -> bool {
        self.modify_bool(|old| old | val)
    }

    /// Logical XOR: stores `current != val` and returns the previous flag.
    pub fn fetch_xor(&self, val: bool) -> bool {
        self.modify_bool(|old| old ^ val)
    }

    /// Logical NAND: stores `!(current && val)` and returns the previous flag.
    pub fn fetch_nand(&self, val: bool) -> bool {
        self.modify_bool(|old| !(old & val))
    }
}
// Generates `fetch_nand` for each listed primitive integer type. The generated
// method stores the bitwise NAND `!(current & val)` under `lock()` and returns
// the value from before the update.
macro_rules! impl_nand {
    ($($int:ty),*) => {
        $(
            impl Atomic<$int> {
                /// Stores `!(current & val)` and returns the previous value.
                pub fn fetch_nand(&self, val: $int) -> $int {
                    let _exclusive = self.mutex.lock();
                    // SAFETY: `lock()` is held, so this is the only live reference.
                    let slot = unsafe { &mut *self.value.get() };
                    let nand = !(*slot & val);
                    std::mem::replace(slot, nand)
                }
            }
        )*
    };
}

impl_nand!(i8, i16, i32, i64, i128, isize, u8, u16, u32, u64, u128, usize);
impl Atomic<String> {
    /// Appends `s` to the string (via `update`).
    pub fn push_str(&self, s: &str) {
        self.update(|text| text.push_str(s));
    }

    /// Empties the string (via `update`, using `String::clear`).
    pub fn clear(&self) {
        self.update(String::clear);
    }

    /// Returns the string's length in bytes (via `with`).
    pub fn len(&self) -> usize {
        self.with(String::len)
    }

    /// Returns `true` if the string is empty (via `with`).
    pub fn is_empty(&self) -> bool {
        self.with(String::is_empty)
    }
}
// NOTE(review): none of these methods clone; the `T: Clone` bound could be
// relaxed without affecting existing callers.
impl<T: Clone> Atomic<Vec<T>> {
    /// Appends `value` to the end of the vector (via `update`).
    pub fn push(&self, value: T) {
        self.update(move |items| items.push(value));
    }

    /// Removes and returns the last element, or `None` if empty (via `with_mut`).
    pub fn pop(&self) -> Option<T> {
        self.with_mut(Vec::pop)
    }

    /// Returns the number of elements (via `with`).
    pub fn len(&self) -> usize {
        self.with(Vec::len)
    }

    /// Returns `true` if the vector has no elements (via `with`).
    pub fn is_empty(&self) -> bool {
        self.with(Vec::is_empty)
    }

    /// Removes all elements (via `update`).
    pub fn clear(&self) {
        self.update(Vec::clear);
    }

    /// Appends every item yielded by `iter` (via `update`).
    ///
    /// The iterator is consumed while the lock is held.
    pub fn extend<I>(&self, iter: I)
    where
        I: IntoIterator<Item = T>,
    {
        self.update(move |items| items.extend(iter));
    }
}
// NOTE(review): none of these methods clone; the `T: Clone` bound could be
// relaxed without affecting existing callers.
impl<T: Clone> Atomic<Option<T>> {
    /// Returns `true` if a value is present (via `with`).
    pub fn is_some(&self) -> bool {
        self.with(Option::is_some)
    }

    /// Returns `true` if no value is present (via `with`).
    pub fn is_none(&self) -> bool {
        self.with(Option::is_none)
    }

    /// Moves the value out, leaving `None` behind (via `with_mut`).
    pub fn take(&self) -> Option<T> {
        self.with_mut(Option::take)
    }

    /// Stores `Some(value)` and returns whatever was there before (via `with_mut`).
    pub fn replace(&self, value: T) -> Option<T> {
        self.with_mut(move |slot| slot.replace(value))
    }
}
impl<T: Clone + fmt::Debug> fmt::Debug for Atomic<T> {
    /// Formats as `Atomic { value: .. }` using a snapshot obtained via `load`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `load` clones under the lock and releases it before returning, so
        // `T`'s `Debug` impl runs on the snapshot, not under the lock.
        let snapshot = self.load();
        let mut builder = f.debug_struct("Atomic");
        builder.field("value", &snapshot);
        builder.finish()
    }
}
impl<T: Default> Default for Atomic<T> {
fn default() -> Self {
Self::new(T::default())
}
}
impl<T> From<T> for Atomic<T> {
fn from(value: T) -> Self {
Self::new(value)
}
}