#![forbid(unsafe_code)]
use std::sync::{Arc, Mutex, RwLock};
use arc_swap::ArcSwap;
/// A thread-shareable cell holding one value of type `T` that can be read
/// and replaced through `&self`.
///
/// The `Send + Sync` supertraits let a single instance be shared between
/// threads (for example behind an `Arc`) and used as a trait object.
pub trait ReadOptimized<T>: Send + Sync
where
    T: Clone + Send + Sync,
{
    /// Returns the value currently held, as an owned `T`.
    fn load(&self) -> T;

    /// Replaces the value currently held with `val`.
    fn store(&self, val: T);
}
/// Store backed by `arc_swap::ArcSwap`.
///
/// `store` publishes a freshly allocated `Arc<T>`. `load` clones the value
/// out of whichever snapshot is current. `load_ref` hands out the guard
/// itself so callers can read the value without cloning `T`.
pub struct ArcSwapStore<T> {
    inner: ArcSwap<T>,
}

impl<T> ArcSwapStore<T>
where
    T: Clone + Send + Sync,
{
    /// Creates a store whose initial value is `val`.
    pub fn new(val: T) -> Self {
        let initial = Arc::new(val);
        Self {
            inner: ArcSwap::new(initial),
        }
    }

    /// Returns the `arc_swap` guard for the current value, giving read access
    /// to it without cloning `T`.
    pub fn load_ref(&self) -> arc_swap::Guard<Arc<T>> {
        self.inner.load()
    }
}

impl<T> ReadOptimized<T> for ArcSwapStore<T>
where
    T: Clone + Send + Sync,
{
    /// Clones the current value out of the loaded snapshot.
    fn load(&self) -> T {
        // Guard<Arc<T>> -> Arc<T> -> T, then clone the pointee.
        (**self.inner.load()).clone()
    }

    /// Publishes `val` as the new current value.
    fn store(&self, val: T) {
        let replacement = Arc::new(val);
        self.inner.store(replacement);
    }
}
/// Store backed by a [`std::sync::RwLock`].
///
/// `load` takes the shared read lock and clones the value out. `store` takes
/// the exclusive write lock and overwrites the value.
pub struct RwLockStore<T> {
    inner: RwLock<T>,
}

impl<T: Clone + Send + Sync> RwLockStore<T> {
    /// Creates a store whose initial value is `val`.
    pub fn new(val: T) -> Self {
        Self {
            inner: RwLock::new(val),
        }
    }
}

/// Lock poisoning is recovered from rather than turned into a panic.
///
/// The only mutation this store performs is the whole-value assignment in
/// `store`, so the protected value is always a complete `T`, even if a thread
/// panicked while holding the write lock (e.g. in the old value's `Drop`).
/// Propagating the poison would make the store permanently unusable for no
/// integrity benefit.
impl<T: Clone + Send + Sync> ReadOptimized<T> for RwLockStore<T> {
    fn load(&self) -> T {
        let guard = self
            .inner
            .read()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        T::clone(&guard)
    }

    fn store(&self, val: T) {
        let mut guard = self
            .inner
            .write()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        *guard = val;
    }
}
/// Store backed by a [`std::sync::Mutex`].
///
/// Both `load` and `store` take the same exclusive lock, so reads are
/// serialized with each other as well as with writes.
pub struct MutexStore<T> {
    inner: Mutex<T>,
}

impl<T: Clone + Send + Sync> MutexStore<T> {
    /// Creates a store whose initial value is `val`.
    pub fn new(val: T) -> Self {
        Self {
            inner: Mutex::new(val),
        }
    }
}

/// Lock poisoning is recovered from rather than turned into a panic.
///
/// `load` runs `T::clone` while holding the mutex, so a panicking clone
/// poisons the lock even though the stored value is untouched. The only
/// mutation is the whole-value assignment in `store`, so the value is always a
/// complete `T`; propagating the poison would make the store permanently
/// unusable for no integrity benefit.
impl<T: Clone + Send + Sync> ReadOptimized<T> for MutexStore<T> {
    fn load(&self) -> T {
        let guard = self
            .inner
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        T::clone(&guard)
    }

    fn store(&self, val: T) {
        let mut guard = self
            .inner
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        *guard = val;
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Barrier;
    use std::thread;

    /// Heap-owning payload so the stores are also exercised with a non-`Copy` `T`.
    #[derive(Debug, Clone, PartialEq, Eq)]
    struct Config {
        name: String,
        value: u64,
    }

    fn make_config(n: u64) -> Config {
        Config {
            name: format!("cfg-{n}"),
            value: n,
        }
    }

    /// Shared read/write stress harness used by every backend.
    ///
    /// Spawns `readers` reader threads and one writer thread, all released
    /// together by a [`Barrier`]. The writer stores `1..=iters` in increasing
    /// order. Because it is the only writer, each reader must observe a
    /// non-decreasing sequence of values. Once every thread has joined, the
    /// store must hold `iters`. Callers should pass a store initialised to `0`.
    fn check_concurrent_read_write<S>(store: S, readers: usize, iters: u64)
    where
        S: ReadOptimized<u64> + 'static,
    {
        let store = Arc::new(store);
        let barrier = Arc::new(Barrier::new(readers + 1));
        let reader_handles: Vec<_> = (0..readers)
            .map(|_| {
                let s = Arc::clone(&store);
                let b = Arc::clone(&barrier);
                thread::spawn(move || {
                    b.wait();
                    let mut last = 0u64;
                    for _ in 0..iters {
                        let v = s.load();
                        assert!(v >= last, "stale read: got {v}, expected >= {last}");
                        last = v;
                    }
                })
            })
            .collect();
        let writer = {
            let s = Arc::clone(&store);
            let b = Arc::clone(&barrier);
            thread::spawn(move || {
                b.wait();
                for i in 1..=iters {
                    s.store(i);
                }
            })
        };
        writer.join().unwrap();
        for h in reader_handles {
            h.join().unwrap();
        }
        assert_eq!(store.load(), iters);
    }

    #[test]
    fn arcswap_load_returns_initial_value() {
        let store = ArcSwapStore::new(42u64);
        assert_eq!(store.load(), 42);
    }

    #[test]
    fn arcswap_store_then_load() {
        let store = ArcSwapStore::new(0u64);
        store.store(99);
        assert_eq!(store.load(), 99);
    }

    #[test]
    fn arcswap_load_ref_borrows_without_clone() {
        let store = ArcSwapStore::new(make_config(1));
        let guard = store.load_ref();
        assert_eq!(guard.name, "cfg-1");
        assert_eq!(guard.value, 1);
    }

    #[test]
    fn arcswap_multiple_stores_last_wins() {
        let store = ArcSwapStore::new(0u64);
        for i in 1..=100 {
            store.store(i);
        }
        assert_eq!(store.load(), 100);
    }

    #[test]
    fn arcswap_concurrent_reads_never_panic() {
        let store = Arc::new(ArcSwapStore::new(make_config(0)));
        let barrier = Arc::new(Barrier::new(8));
        let handles: Vec<_> = (0..8)
            .map(|_| {
                let s = Arc::clone(&store);
                let b = Arc::clone(&barrier);
                thread::spawn(move || {
                    b.wait();
                    for _ in 0..1000 {
                        let _ = s.load();
                    }
                })
            })
            .collect();
        for h in handles {
            h.join().unwrap();
        }
    }

    #[test]
    fn arcswap_concurrent_read_write() {
        check_concurrent_read_write(ArcSwapStore::new(0u64), 8, 10_000);
    }

    #[test]
    fn rwlock_load_returns_initial_value() {
        let store = RwLockStore::new(42u64);
        assert_eq!(store.load(), 42);
    }

    #[test]
    fn rwlock_store_then_load() {
        let store = RwLockStore::new(0u64);
        store.store(99);
        assert_eq!(store.load(), 99);
    }

    #[test]
    fn rwlock_concurrent_read_write() {
        check_concurrent_read_write(RwLockStore::new(0u64), 4, 5_000);
    }

    #[test]
    fn mutex_load_returns_initial_value() {
        let store = MutexStore::new(42u64);
        assert_eq!(store.load(), 42);
    }

    #[test]
    fn mutex_store_then_load() {
        let store = MutexStore::new(0u64);
        store.store(99);
        assert_eq!(store.load(), 99);
    }

    #[test]
    fn mutex_concurrent_read_write() {
        check_concurrent_read_write(MutexStore::new(0u64), 4, 5_000);
    }

    #[test]
    fn trait_object_arcswap() {
        let store: Box<dyn ReadOptimized<u64>> = Box::new(ArcSwapStore::new(10));
        assert_eq!(store.load(), 10);
        store.store(20);
        assert_eq!(store.load(), 20);
    }

    #[test]
    fn trait_object_rwlock() {
        let store: Box<dyn ReadOptimized<u64>> = Box::new(RwLockStore::new(10));
        assert_eq!(store.load(), 10);
        store.store(20);
        assert_eq!(store.load(), 20);
    }

    #[test]
    fn trait_object_mutex() {
        let store: Box<dyn ReadOptimized<u64>> = Box::new(MutexStore::new(10));
        assert_eq!(store.load(), 10);
        store.store(20);
        assert_eq!(store.load(), 20);
    }

    /// Small `Copy` payload standing in for a capabilities struct.
    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    struct FakeCaps {
        true_color: bool,
        sync_output: bool,
        mouse_sgr: bool,
    }

    #[test]
    fn arcswap_with_copy_type() {
        let caps = FakeCaps {
            true_color: true,
            sync_output: false,
            mouse_sgr: true,
        };
        let store = ArcSwapStore::new(caps);
        assert_eq!(store.load(), caps);
        let updated = FakeCaps {
            true_color: true,
            sync_output: true,
            mouse_sgr: true,
        };
        store.store(updated);
        assert_eq!(store.load(), updated);
    }

    #[test]
    fn concurrent_copy_type_reads() {
        let caps = FakeCaps {
            true_color: true,
            sync_output: false,
            mouse_sgr: true,
        };
        let store = Arc::new(ArcSwapStore::new(caps));
        // 8 readers + 1 writer.
        let barrier = Arc::new(Barrier::new(9));
        let readers: Vec<_> = (0..8)
            .map(|_| {
                let s = Arc::clone(&store);
                let b = Arc::clone(&barrier);
                thread::spawn(move || {
                    b.wait();
                    for _ in 0..10_000 {
                        let v = s.load();
                        assert!(v.true_color);
                        assert!(v.mouse_sgr);
                    }
                })
            })
            .collect();
        // The writer flips only `sync_output`, so whichever snapshot a reader
        // loads, the other two flags must still be set.
        let writer = {
            let s = Arc::clone(&store);
            let b = Arc::clone(&barrier);
            thread::spawn(move || {
                b.wait();
                for i in 0..10_000u32 {
                    s.store(FakeCaps {
                        sync_output: i % 2 == 1,
                        ..caps
                    });
                }
            })
        };
        writer.join().unwrap();
        for h in readers {
            h.join().unwrap();
        }
    }
}