use std::sync::{
atomic::{AtomicBool, Ordering},
Arc,
};
use arc_swap::{ArcSwap, Guard};
/// A cached payload tagged with the version it was computed for.
///
/// `Cache` compares `version` against the version a caller requests to decide
/// whether the stored `value` can be returned or must be recomputed.
pub struct CacheValue<T> {
    /// Version this value corresponds to.
    version: usize,
    /// The cached payload.
    value: T,
}
impl<T> CacheValue<T> {
    /// Borrows the cached payload without copying it.
    pub fn get_ref(&self) -> &T {
        let Self { value: payload, .. } = self;
        payload
    }
}
/// A single-slot cache holding one versioned value.
///
/// Readers load the current value through an `ArcSwap`. Refreshing is gated by the
/// `is_updating` flag, so at most one caller computes a replacement at a time.
pub struct Cache<T> {
    /// The current value and its version, replaced wholesale on refresh.
    value: ArcSwap<CacheValue<T>>,
    /// Set while some caller is computing a replacement value.
    is_updating: AtomicBool,
}
/// Handle returned by `Cache::value` on success; it dereferences to the `CacheValue`
/// that was current when it was loaded.
pub type CacheValueType<T> = Guard<Arc<CacheValue<T>>>;
impl<T> Cache<T> {
pub fn new(value: T, version: usize) -> Self {
Cache {
value: ArcSwap::new(CacheValue::<T> { version, value }.into()),
is_updating: AtomicBool::new(false),
}
}
fn finish_update(&self) {
self.is_updating.store(false, Ordering::SeqCst);
}
pub fn value(
&self,
version: usize,
f: impl FnOnce() -> T,
) -> Result<CacheValueType<T>, impl FnOnce() -> T> {
let v = self.value.load();
match v.version.cmp(&version) {
std::cmp::Ordering::Equal => Ok(v),
std::cmp::Ordering::Greater => Err(f), std::cmp::Ordering::Less => {
drop(v);
match self.is_updating.compare_exchange(
false,
true,
Ordering::SeqCst,
Ordering::SeqCst,
) {
Ok(_) => {
let v = self.value.load();
match v.version.cmp(&version) {
std::cmp::Ordering::Equal => {
self.finish_update();
Ok(v)
}
std::cmp::Ordering::Greater => {
self.finish_update();
Err(f)
}
std::cmp::Ordering::Less => {
drop(v);
self.value.store(
CacheValue {
value: f(),
version,
}
.into(),
);
let v = self.value.load(); self.finish_update();
Ok(v)
}
}
}
Err(_) => Err(f),
}
}
}
}
}
#[cfg(test)]
mod tests {
    use std::{
        sync::{mpsc, Arc},
        thread,
    };

    use super::Cache;

    /// Reads the cached string for `version`.
    ///
    /// A refresh computes `new_value`. Returns `""` when `value` hands the closure
    /// back instead of a cached value.
    fn read(cache: &Cache<String>, version: usize, new_value: &str) -> String {
        cache
            .value(version, || new_value.to_string())
            .map(|v| v.get_ref().clone())
            .unwrap_or_default()
    }

    #[test]
    fn test_cache() {
        let cache = Cache::<String>::new("0".to_string(), 0);
        // Same version: the cached value is returned and the closure is not run.
        assert_eq!(read(&cache, 0, "1"), "0");
        // Newer version: the closure runs and its result becomes the cached value.
        assert_eq!(read(&cache, 1, "1"), "1");
        // Older than cached: the closure is handed back.
        assert!(cache.value(0, || "2".to_string()).is_err());

        let cache = Arc::new(cache);
        let updater_cache = Arc::clone(&cache);
        // Handshake channels replace sleeps. They guarantee the updater is inside
        // its refresh closure (holding the update flag) while the main thread probes.
        let (started_tx, started_rx) = mpsc::channel::<()>();
        let (release_tx, release_rx) = mpsc::channel::<()>();
        let updater = thread::spawn(move || {
            let res = updater_cache.value(2, || {
                started_tx.send(()).unwrap();
                release_rx.recv().unwrap();
                "2".to_string()
            });
            assert_eq!(
                res.as_ref().map(|v| v.get_ref().as_str()).unwrap_or(""),
                "2"
            );
        });

        started_rx
            .recv()
            .expect("updater thread never started refreshing");
        // A refresh is in progress, so other callers get their closure back.
        assert!(cache.value(2, || "".to_string()).is_err());
        release_tx.send(()).expect("updater thread hung up");
        // Joining propagates any assertion failure from the updater thread.
        updater.join().expect("updater thread panicked");

        assert_eq!(read(&cache, 2, ""), "2");
    }
}