use arc_swap::ArcSwap;
use std::sync::Arc;
/// Read-mostly holder for the currently active router, backed by `ArcSwap`.
///
/// Readers obtain a snapshot of the current router; writers publish a
/// replacement with a single atomic pointer swap. A snapshot a reader already
/// holds keeps the old router alive and unchanged until it is dropped.
pub struct RouterCache<T> {
    /// Atomically swappable pointer to the current router.
    inner: ArcSwap<T>,
}

impl<T> RouterCache<T> {
    /// Creates a cache whose current router is `initial`.
    pub fn new(initial: T) -> Self {
        Self {
            inner: ArcSwap::from_pointee(initial),
        }
    }

    /// Returns an owned snapshot of the current router.
    ///
    /// Later calls to [`store`](Self::store) do not affect the returned `Arc`.
    #[must_use]
    pub fn load(&self) -> Arc<T> {
        self.inner.load_full()
    }

    /// Atomically replaces the current router with `new_router`.
    pub fn store(&self, new_router: T) {
        self.inner.store(Arc::new(new_router));
    }

    /// Runs `f` against the current router without cloning the `Arc`.
    ///
    /// The snapshot is pinned by an `arc_swap` guard for the duration of `f`,
    /// so `f` should be short.
    // Public API that is not exercised by every build configuration.
    #[allow(dead_code)]
    pub fn with_current<R>(&self, f: impl FnOnce(&T) -> R) -> R {
        let current = self.inner.load();
        f(&**current)
    }

    /// Replaces the current router with `new_router` only if `predicate`
    /// accepts the router that is current when this is called.
    ///
    /// The check and the replacement are tied together: the swap succeeds only
    /// if the router the predicate inspected is still installed. If another
    /// thread stores a different router in between, nothing is written and
    /// `false` is returned, because the predicate never saw that newer router.
    /// `new_router` is dropped whenever `false` is returned.
    ///
    /// Returns `true` if `new_router` was installed.
    // Public API that is not exercised by every build configuration.
    #[allow(dead_code)]
    pub fn update_if<P>(&self, predicate: P, new_router: T) -> bool
    where
        P: FnOnce(&T) -> bool,
    {
        let current = self.inner.load();
        if !predicate(&**current) {
            return false;
        }
        // Swap only if `current` is still the installed router. The guard keeps
        // that allocation alive, so its address cannot be reused and the pointer
        // comparison cannot produce a false positive.
        let previous = self
            .inner
            .compare_and_swap(&*current, Arc::new(new_router));
        Arc::ptr_eq(&*previous, &*current)
    }

    /// Installs `new_router` only if the current router is the very same
    /// allocation as `expected`.
    ///
    /// The comparison is pointer identity (`Arc::ptr_eq`), not value equality:
    /// an `Arc` holding an equal-but-distinct router does not match.
    ///
    /// # Errors
    ///
    /// Returns `Err` with the router that was installed at the time of the
    /// attempt when it is not `expected`. `new_router` is dropped in that case.
    // `expected` is taken by value to keep the public signature stable, and the
    // method is not exercised by every build configuration.
    #[allow(dead_code, clippy::needless_pass_by_value)]
    pub fn compare_and_swap(&self, expected: Arc<T>, new_router: T) -> Result<(), Arc<T>> {
        let previous = self.inner.compare_and_swap(&expected, Arc::new(new_router));
        if Arc::ptr_eq(&*previous, &expected) {
            Ok(())
        } else {
            Err(Arc::clone(&*previous))
        }
    }
}
impl<T> Clone for RouterCache<T> {
fn clone(&self) -> Self {
Self {
inner: ArcSwap::new(self.inner.load_full()),
}
}
}
impl<T> Default for RouterCache<T>
where
T: Default,
{
fn default() -> Self {
Self::new(T::default())
}
}
impl<T: std::fmt::Debug> std::fmt::Debug for RouterCache<T> {
    /// Formats as `RouterCache { current: <router> }`, showing the router that
    /// is installed at the moment of formatting.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("RouterCache");
        let snapshot = self.inner.load();
        builder.field("current", &**snapshot);
        builder.finish()
    }
}
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::thread;
    use std::time::Duration;

    /// Minimal router stand-in that is comparable by value.
    #[derive(Debug, Clone, PartialEq)]
    struct TestRouter {
        id: usize,
        name: String,
    }

    impl TestRouter {
        fn new(id: usize, name: &str) -> Self {
            Self {
                id,
                name: name.to_owned(),
            }
        }
    }

    #[test]
    fn test_basic_load_store() {
        let initial = TestRouter::new(1, "initial");
        let cache = RouterCache::new(initial.clone());
        let loaded = cache.load();
        assert_eq!(*loaded, initial);
        let new_router = TestRouter::new(2, "updated");
        cache.store(new_router.clone());
        let loaded = cache.load();
        assert_eq!(*loaded, new_router);
    }

    #[test]
    fn test_concurrent_reads() {
        let initial = TestRouter::new(1, "concurrent_test");
        let cache = Arc::new(RouterCache::new(initial));
        let mut handles = vec![];
        for i in 0..10 {
            let cache_clone = Arc::clone(&cache);
            let handle = thread::spawn(move || {
                for _ in 0..100 {
                    let router = cache_clone.load();
                    assert_eq!(router.name, "concurrent_test");
                    thread::sleep(Duration::from_micros(i * 10));
                }
            });
            handles.push(handle);
        }
        for handle in handles {
            handle.join().unwrap();
        }
    }

    #[test]
    fn test_concurrent_read_write() {
        let initial = TestRouter::new(0, "router_0");
        let cache = Arc::new(RouterCache::new(initial));
        let update_count = Arc::new(AtomicUsize::new(0));
        let mut handles = vec![];
        for _ in 0..5 {
            let cache_clone = Arc::clone(&cache);
            let handle = thread::spawn(move || {
                for _ in 0..50 {
                    let router = cache_clone.load();
                    assert!(router.name.starts_with("router_"));
                    thread::sleep(Duration::from_micros(10));
                }
            });
            handles.push(handle);
        }
        for _ in 0..2 {
            let cache_clone = Arc::clone(&cache);
            let count_clone = Arc::clone(&update_count);
            let handle = thread::spawn(move || {
                for _ in 0..10 {
                    let id = count_clone.fetch_add(1, Ordering::SeqCst);
                    let new_router = TestRouter::new(id, &format!("router_{id}"));
                    cache_clone.store(new_router);
                    thread::sleep(Duration::from_millis(1));
                }
            });
            handles.push(handle);
        }
        for handle in handles {
            handle.join().unwrap();
        }
        let final_router = cache.load();
        assert!(final_router.name.starts_with("router_"));
    }

    #[test]
    fn test_with_current() {
        let initial = TestRouter::new(42, "test_with_current");
        let cache = RouterCache::new(initial);
        let result = cache.with_current(|router| format!("{}_{}", router.name, router.id));
        assert_eq!(result, "test_with_current_42");
    }

    #[test]
    fn test_update_if() {
        let initial = TestRouter::new(1, "conditional");
        let cache = RouterCache::new(initial);
        let new_router1 = TestRouter::new(2, "updated");
        let updated = cache.update_if(|r| r.id == 1, new_router1.clone());
        assert!(updated);
        let current = cache.load();
        assert_eq!(*current, new_router1);
        let new_router2 = TestRouter::new(3, "should_not_update");
        let updated = cache.update_if(|r| r.id == 5, new_router2);
        assert!(!updated);
        let current = cache.load();
        assert_eq!(*current, new_router1);
    }

    #[test]
    fn test_compare_and_swap() {
        let initial = TestRouter::new(1, "cas_test");
        let cache = RouterCache::new(initial);
        let current = cache.load();
        let new_router = TestRouter::new(2, "cas_updated");
        let result = cache.compare_and_swap(current, new_router.clone());
        assert!(result.is_ok());
        let updated = cache.load();
        assert_eq!(*updated, new_router);
        let wrong_expected = Arc::new(TestRouter::new(99, "wrong"));
        let another_router = TestRouter::new(3, "cas_failed");
        let result = cache.compare_and_swap(wrong_expected, another_router);
        // On failure the caller receives the router that is actually installed.
        let actual = result.expect_err("swap against an unrelated Arc must fail");
        assert_eq!(*actual, new_router);
        let current = cache.load();
        assert_eq!(*current, new_router);
    }

    #[test]
    fn test_compare_and_swap_uses_pointer_identity() {
        let cache = RouterCache::new(TestRouter::new(1, "identity"));
        // Equal by value, but a different allocation than the installed router.
        let lookalike = Arc::new(TestRouter::new(1, "identity"));
        assert_eq!(*lookalike, *cache.load());
        let result = cache.compare_and_swap(lookalike, TestRouter::new(2, "replaced"));
        assert!(result.is_err());
        assert_eq!(cache.load().id, 1);
    }

    #[test]
    fn test_compare_and_swap_rejects_stale_snapshot() {
        let cache = RouterCache::new(TestRouter::new(1, "first"));
        let stale = cache.load();
        cache.store(TestRouter::new(2, "second"));
        let result = cache.compare_and_swap(stale, TestRouter::new(3, "third"));
        let actual = result.expect_err("a stale snapshot must be rejected");
        assert_eq!(actual.id, 2);
        assert_eq!(cache.load().id, 2);
    }

    #[test]
    fn test_clone_and_debug() {
        let original = TestRouter::new(1, "clone_test");
        let cache1 = RouterCache::new(original.clone());
        let cache2 = cache1.clone();
        assert_eq!(*cache1.load(), *cache2.load());
        let new_router = TestRouter::new(2, "updated");
        cache1.store(new_router.clone());
        assert_eq!(*cache1.load(), new_router);
        assert_eq!(*cache2.load(), original);
        let debug_str = format!("{cache1:?}");
        assert!(debug_str.contains("RouterCache"));
        assert!(debug_str.contains("updated"));
    }
}