use std::cell::RefCell;
use std::rc::Rc;
use std::sync::Arc;
/// A clonable handle to a configuration value of type `T`.
///
/// The implementations in this file hand out handles that all refer to the
/// same underlying value: an `update` made through one handle is observed by
/// `with` on any of its clones.
pub trait ConfigCell<T>: Clone + 'static
where
    T: 'static,
{
    /// Creates a cell that initially holds `value`.
    fn new(value: T) -> Self;

    /// Calls `f` with a shared reference to the current value and returns
    /// whatever `f` returns.
    fn with<R>(&self, f: impl FnOnce(&T) -> R) -> R;

    /// Gives `f` mutable access to the value so it can be edited.
    ///
    /// # Errors
    ///
    /// Returns the error produced by `f`. The implementations in this file run
    /// `f` on a private copy and only install that copy when `f` returns `Ok`,
    /// so a rejected edit leaves the stored value unchanged.
    fn update<E>(&self, f: impl FnOnce(&mut T) -> Result<(), E>) -> Result<(), E>;
}
/// [`ConfigCell`] handle built from `Rc` and `RefCell`, so it is confined to a
/// single thread.
///
/// The outer `Rc` is what `Clone` duplicates, so every clone points at the same
/// `RefCell`. The inner `Rc<T>` is the current value; `update` installs a newly
/// allocated `Rc<T>` rather than mutating the stored `T` in place.
pub struct LocalConfigCell<T>(Rc<RefCell<Rc<T>>>);
impl<T> Clone for LocalConfigCell<T> {
fn clone(&self) -> Self {
Self(Rc::clone(&self.0))
}
}
/// [`ConfigCell`] implementation for the single-threaded, `Rc`-based cell.
///
/// Neither `with` nor `update` keeps the `RefCell` borrowed while the caller's
/// closure runs, so a closure may call back into the same cell (or any clone
/// of it) without triggering a `RefCell` borrow panic.
impl<T: Clone + 'static> ConfigCell<T> for LocalConfigCell<T> {
    /// Creates a cell that initially holds `value`.
    fn new(value: T) -> Self {
        Self(Rc::new(RefCell::new(Rc::new(value))))
    }

    /// Runs `f` against a snapshot of the value that was current when `with`
    /// was called and returns `f`'s result.
    ///
    /// If `f` itself calls `update` on this cell, the reference passed to `f`
    /// keeps pointing at the old snapshot; later `with` calls see the new value.
    #[inline]
    fn with<R>(&self, f: impl FnOnce(&T) -> R) -> R {
        // Take our own `Rc` to the current value and let the `RefCell` borrow
        // end with this statement. Holding the borrow across `f` would make a
        // nested `update` panic in `borrow_mut`.
        let snapshot = Rc::clone(&self.0.borrow());
        f(&snapshot)
    }

    /// Applies `f` to a private copy of the current value and installs the
    /// copy only if `f` returns `Ok`.
    ///
    /// If `f` calls `update` on the same cell, the nested change is installed
    /// first and then overwritten when the outer call stores its own copy.
    ///
    /// # Errors
    ///
    /// Returns `f`'s error unchanged; the stored value is left as it was.
    fn update<E>(&self, f: impl FnOnce(&mut T) -> Result<(), E>) -> Result<(), E> {
        // The temporary borrow is released at the end of this statement, before
        // `f` runs.
        let mut next: T = (**self.0.borrow()).clone();
        f(&mut next)?;
        *self.0.borrow_mut() = Rc::new(next);
        Ok(())
    }
}
/// [`ConfigCell`] handle whose state lives behind an `Arc`.
///
/// `Clone` duplicates only the `Arc`, so all clones share one
/// `ArcSwapConfigInner` (the swappable value plus the writer lock).
pub struct ArcSwapConfigCell<T>(Arc<ArcSwapConfigInner<T>>);
/// State shared by every clone of an [`ArcSwapConfigCell`].
struct ArcSwapConfigInner<T> {
// Current value. `with` reads it via `load()`; `update` replaces it via `store()`.
value: arc_swap::ArcSwap<T>,
// Taken for the whole duration of `update`, so only one read-copy-store
// sequence runs at a time. `with` never touches it.
writer: parking_lot::Mutex<()>,
}
impl<T> Clone for ArcSwapConfigCell<T> {
fn clone(&self) -> Self {
Self(Arc::clone(&self.0))
}
}
/// [`ConfigCell`] implementation backed by `arc_swap::ArcSwap`.
///
/// Reads go straight to the `ArcSwap` and never take the writer lock; updates
/// are serialized by that lock so each one starts from the value the previous
/// one stored.
impl<T: Clone + 'static> ConfigCell<T> for ArcSwapConfigCell<T> {
    /// Creates a cell that initially holds `value`.
    fn new(value: T) -> Self {
        let inner = ArcSwapConfigInner {
            writer: parking_lot::Mutex::new(()),
            value: arc_swap::ArcSwap::from_pointee(value),
        };
        Self(Arc::new(inner))
    }

    /// Runs `f` against the value loaded at call time and returns `f`'s result.
    ///
    /// The load guard stays alive until `f` has returned.
    #[inline]
    fn with<R>(&self, f: impl FnOnce(&T) -> R) -> R {
        f(&self.0.value.load())
    }

    /// Applies `f` to a private copy of the current value and stores the copy
    /// only if `f` returns `Ok`.
    ///
    /// The writer lock is held from before the value is read until after it is
    /// stored (or `f` fails).
    ///
    /// NOTE(review): `parking_lot::Mutex` is not reentrant, so `f` must not
    /// call `update` on the same cell.
    ///
    /// # Errors
    ///
    /// Returns `f`'s error unchanged; nothing is stored in that case.
    fn update<E>(&self, f: impl FnOnce(&mut T) -> Result<(), E>) -> Result<(), E> {
        let _exclusive = self.0.writer.lock();
        // The load guard is a temporary of this statement, so it is released
        // before `f` runs.
        let mut draft = T::clone(&**self.0.value.load());
        f(&mut draft).map(|()| self.0.value.store(Arc::new(draft)))
    }
}
// Unit tests. The three generic `assert_*` helpers state the contract every
// `ConfigCell` backend must satisfy and are instantiated once per backend; the
// remaining tests cover panic and concurrency behaviour of `ArcSwapConfigCell`.
#[cfg(test)]
mod tests {
use super::*;
// A successful `update` must be observable through a later `with`.
fn assert_update_visible<Cell: ConfigCell<i32>>() {
let cell = Cell::new(1);
cell.update::<()>(|v| {
*v = 42;
Ok(())
})
.expect("update succeeds");
assert_eq!(cell.with(|v| *v), 42);
}
// The closure mutates its argument and then fails: the error must be returned
// as-is and the stored value must stay at its prior value.
fn assert_err_keeps_prior<Cell: ConfigCell<i32>>() {
let cell = Cell::new(7);
let result = cell.update(|v| {
*v = 99;
Err("rejected")
});
assert_eq!(result, Err("rejected"));
assert_eq!(cell.with(|v| *v), 7);
}
// An update made through one handle must be visible through its clone.
fn assert_clone_shares_underlying<Cell: ConfigCell<i32>>() {
let cell = Cell::new(0);
let other = cell.clone();
cell.update::<()>(|v| {
*v = 5;
Ok(())
})
.expect("update succeeds");
assert_eq!(other.with(|v| *v), 5);
}
#[test]
fn local_update_visible_through_with() {
assert_update_visible::<LocalConfigCell<i32>>();
}
#[test]
fn local_err_keeps_prior_value() {
assert_err_keeps_prior::<LocalConfigCell<i32>>();
}
#[test]
fn local_clone_shares_underlying() {
assert_clone_shares_underlying::<LocalConfigCell<i32>>();
}
#[test]
fn arc_swap_update_visible_through_with() {
assert_update_visible::<ArcSwapConfigCell<i32>>();
}
#[test]
fn arc_swap_err_keeps_prior_value() {
assert_err_keeps_prior::<ArcSwapConfigCell<i32>>();
}
#[test]
fn arc_swap_clone_shares_underlying() {
assert_clone_shares_underlying::<ArcSwapConfigCell<i32>>();
}
// A closure that panics inside `update` must not wedge the cell: the panic is
// caught, and a follow-up `update` has to acquire the writer lock again and
// succeed.
#[test]
fn arc_swap_panic_in_update_keeps_prior_value() {
let cell = ArcSwapConfigCell::new(3);
let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
cell.update::<()>(|v| {
*v = 100;
panic!("boom");
})
.expect("unreachable: closure panics");
}));
assert!(result.is_err());
// This call would block if the writer lock were still held after the unwind.
cell.update::<()>(|v| {
*v = 4;
Ok(())
})
.expect("writer lock is released during unwind");
assert_eq!(cell.with(|v| *v), 4);
}
// One writer thread repeatedly sets both fields to the same round number while
// four reader threads check that every value they observe has matching fields.
#[test]
fn arc_swap_concurrent_store_load_no_torn_read() {
use std::sync::Arc;
use std::thread;
#[derive(Clone)]
struct Paired {
high: u64,
low: u64,
}
const ROUNDS: u64 = 5_000;
let cell = Arc::new(ArcSwapConfigCell::new(Paired { high: 0, low: 0 }));
thread::scope(|scope| {
let writer = Arc::clone(&cell);
scope.spawn(move || {
for round in 1..=ROUNDS {
writer
.update::<()>(|v| {
v.high = round;
v.low = round;
Ok(())
})
.expect("update succeeds");
}
});
for _ in 0..4 {
let reader = Arc::clone(&cell);
scope.spawn(move || {
for _ in 0..ROUNDS {
let torn = reader.with(|v| v.high != v.low);
assert!(!torn, "observed a torn read");
}
});
}
});
// The scope has joined all threads, so the last round must be the stored value.
assert_eq!(cell.with(|v| (v.high, v.low)), (ROUNDS, ROUNDS));
}
// Two threads each set a different field. Both changes must survive, which
// requires the second update to start from the value the first one stored.
#[test]
fn arc_swap_concurrent_disjoint_updates_are_not_lost() {
use std::sync::{Arc, Barrier};
use std::thread;
#[derive(Clone, Default, PartialEq, Eq, Debug)]
struct Pair {
left: bool,
right: bool,
}
let cell = Arc::new(ArcSwapConfigCell::new(Pair::default()));
// The test thread takes the writer lock up front and releases it only after
// the barrier, so neither update can begin before both threads have been
// released from the barrier. Presumably this is meant to maximize contention
// between the two updates.
let writer = cell.0.writer.lock();
// Three parties: the two updating threads plus this test thread.
let start = Arc::new(Barrier::new(3));
thread::scope(|scope| {
let left_cell = Arc::clone(&cell);
let left_start = Arc::clone(&start);
scope.spawn(move || {
left_start.wait();
left_cell
.update::<()>(|value| {
value.left = true;
Ok(())
})
.expect("left update succeeds");
});
let right_cell = Arc::clone(&cell);
let right_start = Arc::clone(&start);
scope.spawn(move || {
right_start.wait();
right_cell
.update::<()>(|value| {
value.right = true;
Ok(())
})
.expect("right update succeeds");
});
start.wait();
drop(writer);
});
assert_eq!(
cell.with(Clone::clone),
Pair {
left: true,
right: true,
}
);
}
}