use arc_swap::ArcSwap;
use std::sync::Arc;
/// A value that can be read and replaced atomically from many threads at once.
///
/// The current value lives behind an [`Arc`] held by an [`ArcSwap`]. Readers receive
/// reference-counted snapshots, and writers install a whole new value rather than mutating the
/// existing one in place.
#[derive(Debug)]
pub struct AtomicValue<T> {
    /// Atomically swappable pointer to the current value.
    inner: ArcSwap<T>,
}
impl<T> AtomicValue<T> {
    /// Creates a cell whose initial value is `value`.
    pub fn new(value: T) -> Self {
        let inner = ArcSwap::from_pointee(value);
        Self { inner }
    }

    /// Returns a snapshot of the current value.
    ///
    /// The returned `Arc` keeps that snapshot alive even if the cell is overwritten afterwards.
    pub fn load(&self) -> Arc<T> {
        self.inner.load_full()
    }

    /// Replaces the current value with `value`.
    ///
    /// Snapshots previously obtained through [`load`](Self::load) remain valid.
    pub fn store(&self, value: T) {
        self.store_arc(Arc::new(value));
    }

    /// Replaces the current value with `value` and hands back the value it displaced.
    pub fn swap(&self, value: T) -> Arc<T> {
        self.swap_arc(Arc::new(value))
    }

    /// Like [`store`](Self::store), but installs an already-allocated `Arc` as-is.
    pub fn store_arc(&self, value: Arc<T>) {
        self.inner.store(value);
    }

    /// Like [`swap`](Self::swap), but installs an already-allocated `Arc` as-is.
    pub fn swap_arc(&self, value: Arc<T>) -> Arc<T> {
        self.inner.swap(value)
    }
}
impl<T: Clone> AtomicValue<T> {
    /// Atomically replaces the value with `f(&current)` and returns a clone of what was stored.
    ///
    /// This is a read-copy-update loop (`ArcSwap::rcu`). `f` is applied to a snapshot of the
    /// current value, and the result is installed only if no other writer replaced the value in
    /// the meantime. Under contention `f` is re-run against the newer value, so it may be called
    /// more than once. Avoid side effects in `f` that must happen exactly once.
    ///
    /// `F: FnMut` accepts every closure the former `F: Fn` bound did, plus stateful ones.
    pub fn update<F>(&self, mut f: F) -> T
    where
        F: FnMut(&T) -> T,
    {
        // Holds the value produced by the most recent attempt. `rcu` installs the result of its
        // final closure invocation, so after it returns this is exactly what was stored.
        let mut installed: Option<Arc<T>> = None;
        self.inner.rcu(|current| {
            let next = Arc::new(f(current.as_ref()));
            installed = Some(Arc::clone(&next));
            next
        });
        let installed = installed.expect("rcu always invokes the closure at least once");
        (*installed).clone()
    }

    /// Returns an owned clone of the current value.
    pub fn get(&self) -> T {
        // Clone through the temporary guard instead of `load_full`, which would bump the
        // refcount only to drop it again right after the clone.
        let guard = self.inner.load();
        (**guard).clone()
    }
}
impl<T: Default> Default for AtomicValue<T> {
fn default() -> Self {
Self::new(T::default())
}
}
impl<T> Clone for AtomicValue<T> {
fn clone(&self) -> Self {
Self {
inner: ArcSwap::new(self.inner.load_full()),
}
}
}
impl<T> From<T> for AtomicValue<T> {
fn from(value: T) -> Self {
Self::new(value)
}
}
impl<T> From<Arc<T>> for AtomicValue<T> {
fn from(value: Arc<T>) -> Self {
Self {
inner: ArcSwap::new(value),
}
}
}
#[cfg(test)]
mod tests {
    //! Unit tests for `AtomicValue`: single-threaded API behavior, allocation sharing, and
    //! concurrent reads/writes/updates.

    use super::*;
    use std::thread;

    #[test]
    fn test_new_and_load() {
        let value: AtomicValue<i32> = AtomicValue::new(42);
        assert_eq!(*value.load(), 42);
    }

    #[test]
    fn test_store() {
        let value: AtomicValue<i32> = AtomicValue::new(42);
        value.store(100);
        assert_eq!(*value.load(), 100);
    }

    #[test]
    fn test_swap() {
        let value: AtomicValue<i32> = AtomicValue::new(42);
        let old = value.swap(100);
        assert_eq!(*old, 42);
        assert_eq!(*value.load(), 100);
    }

    #[test]
    fn test_store_arc() {
        let value: AtomicValue<i32> = AtomicValue::new(42);
        let replacement = Arc::new(100);
        value.store_arc(Arc::clone(&replacement));
        assert_eq!(*value.load(), 100);
        // `store_arc` must install the given allocation rather than copying its contents.
        assert!(Arc::ptr_eq(&replacement, &value.load()));
    }

    #[test]
    fn test_swap_arc() {
        let value: AtomicValue<i32> = AtomicValue::new(42);
        let replacement = Arc::new(100);
        let old = value.swap_arc(Arc::clone(&replacement));
        assert_eq!(*old, 42);
        assert_eq!(*value.load(), 100);
        assert!(Arc::ptr_eq(&replacement, &value.load()));
    }

    #[test]
    fn test_update() {
        let value: AtomicValue<i32> = AtomicValue::new(42);
        let new_value = value.update(|v| v + 1);
        assert_eq!(new_value, 43);
        assert_eq!(*value.load(), 43);
    }

    #[test]
    fn test_update_non_copy() {
        let value: AtomicValue<String> = AtomicValue::new("a".to_string());
        let returned = value.update(|s| {
            let mut next = s.clone();
            next.push('b');
            next
        });
        assert_eq!(returned, "ab");
        assert_eq!(*value.load(), "ab");
    }

    #[test]
    fn test_get() {
        let value: AtomicValue<String> = AtomicValue::new("hello".to_string());
        let mut cloned: String = value.get();
        assert_eq!(cloned, "hello");
        // `get` hands out an owned copy; mutating it must not touch the stored value.
        cloned.push_str(", world");
        assert_eq!(*value.load(), "hello");
    }

    #[test]
    fn test_default() {
        let value: AtomicValue<i32> = AtomicValue::default();
        assert_eq!(*value.load(), 0);
        let value: AtomicValue<String> = AtomicValue::default();
        assert_eq!(*value.load(), "");
    }

    #[test]
    fn test_clone() {
        let value1: AtomicValue<i32> = AtomicValue::new(42);
        let value2 = value1.clone();
        assert_eq!(*value1.load(), 42);
        assert_eq!(*value2.load(), 42);
        // Cloning the cell shares the current allocation instead of cloning `T`.
        assert!(Arc::ptr_eq(&value1.load(), &value2.load()));
        value1.store(100);
        assert_eq!(*value1.load(), 100);
        assert_eq!(*value2.load(), 42);
    }

    #[test]
    fn test_from_value() {
        let value: AtomicValue<i32> = AtomicValue::from(42);
        assert_eq!(*value.load(), 42);
    }

    #[test]
    fn test_from_arc() {
        let arc = Arc::new(42);
        let value: AtomicValue<i32> = AtomicValue::from(Arc::clone(&arc));
        assert_eq!(*value.load(), 42);
        // Converting from an `Arc` must reuse the allocation rather than copying the value.
        assert!(Arc::ptr_eq(&arc, &value.load()));
    }

    #[test]
    fn test_with_string() {
        let value: AtomicValue<String> = AtomicValue::new("initial".to_string());
        assert_eq!(value.load().as_ref(), "initial");
        value.store("updated".to_string());
        assert_eq!(value.load().as_ref(), "updated");
    }

    #[test]
    fn test_with_struct() {
        #[derive(Debug, Clone, PartialEq)]
        struct Config {
            host: String,
            port: u16,
        }
        let config = Config {
            host: "localhost".to_string(),
            port: 8080,
        };
        let value: AtomicValue<Config> = AtomicValue::new(config);
        assert_eq!(value.load().host, "localhost");
        assert_eq!(value.load().port, 8080);
        value.store(Config {
            host: "0.0.0.0".to_string(),
            port: 9090,
        });
        assert_eq!(value.load().host, "0.0.0.0");
        assert_eq!(value.load().port, 9090);
    }

    #[test]
    fn test_concurrent_reads() {
        let value: Arc<AtomicValue<i32>> = Arc::new(AtomicValue::new(42));
        let handles: Vec<_> = (0..10)
            .map(|_| {
                let value = Arc::clone(&value);
                thread::spawn(move || {
                    for _ in 0..1000 {
                        // Check every read, not just the final one, so a torn or stale read
                        // inside a reader thread fails the test.
                        assert_eq!(*value.load(), 42);
                    }
                })
            })
            .collect();
        for handle in handles {
            handle.join().unwrap();
        }
        assert_eq!(*value.load(), 42);
    }

    #[test]
    fn test_concurrent_writes() {
        let value: Arc<AtomicValue<i32>> = Arc::new(AtomicValue::new(0));
        let handles: Vec<_> = (0..10)
            .map(|i| {
                let value = Arc::clone(&value);
                thread::spawn(move || {
                    for _ in 0..100 {
                        value.store(i);
                    }
                })
            })
            .collect();
        for handle in handles {
            handle.join().unwrap();
        }
        let final_value = *value.load();
        assert!((0..10).contains(&final_value));
    }

    #[test]
    fn test_concurrent_updates() {
        let value: Arc<AtomicValue<i32>> = Arc::new(AtomicValue::new(0));
        let handles: Vec<_> = (0..10)
            .map(|_| {
                let value = Arc::clone(&value);
                thread::spawn(move || {
                    for _ in 0..100 {
                        let observed = value.update(|v| v + 1);
                        // Each update returns the value it installed, which is always one of
                        // the 1..=1000 intermediate counts.
                        assert!((1..=1000).contains(&observed));
                    }
                })
            })
            .collect();
        for handle in handles {
            handle.join().unwrap();
        }
        let final_value = *value.load();
        assert_eq!(final_value, 1000);
    }
}