use futures::future;
use futures::future::FutureExt;
use futures::future::Shared;
use futures::future::TryFutureExt;
use log::debug;
use log::error;
use log::info;
use std::clone::Clone;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::Mutex;
use std::sync::RwLock;
use thiserror::Error;
/// Errors returned by the futures that a `RemoteStore` hands out.
///
/// The type is `Clone` because results are distributed through a shared
/// future, which clones its output for every waiter.
#[derive(Debug, Error, Clone)]
pub enum ExpiryGetError {
/// A lock guarding the store's shared state was poisoned; carries the
/// rendered poison error.
#[error("Poisoned lock: {0}")]
PoisonedLock(String),
/// No value is available. In the code visible here it is only produced by
/// the always-failing provider used in the tests.
#[error("not initialized")]
NotInitialized,
/// The refresh that was triggered for an expired value failed; carries the
/// rendered error of that refresh.
#[error("Inner update future failed: {0}")]
InnerFutureFailed(String),
/// An update failed. Not constructed in the code visible here; presumably
/// intended for `Provider` implementations to report fetch failures.
#[error("Update failed: {0}")]
UpdateFailed(String),
}
/// Boxed, pinned, `Send` future resolving to either a value of type `T` or an
/// `ExpiryGetError`. This is the future type produced by `Provider::update`
/// and returned by `RemoteStore::get`.
pub type ExpiryFut<T> = Pin<Box<dyn Future<Output = Result<T, ExpiryGetError>> + Send>>;
/// Logs a poisoned-lock error and evaluates to an already-failed shared future
/// carrying `ExpiryGetError::PoisonedLock` with the rendered error.
///
/// `$e` must be an identifier bound to a value implementing `Display`
/// (in this file: a `PoisonError` from a `Mutex` or `RwLock`).
macro_rules! poison_err_future {
    ($e:ident) => {{
        error!("poisoned lock: {}", $e);
        let failure = ExpiryGetError::PoisonedLock($e.to_string());
        let failed = future::err(failure.into());
        failed.boxed().shared()
    }};
}
/// A cached value that can tell whether it may still be served.
pub trait Expiry {
/// Returns `true` while the value is still usable. `RemoteStore` serves the
/// cached value when this returns `true` and asks the provider for a fresh
/// one when it returns `false`.
fn valid(&self) -> bool;
}
/// Wrapper around the shared future that resolves to the most recently
/// requested value (or to the error of that request).
#[derive(Clone)]
struct Fu<T: Expiry + Clone + Send + 'static> {
/// Shared handle to the fetch; every clone resolves to a clone of the same
/// result.
pub f: Shared<ExpiryFut<T>>,
}
/// Source of fresh values for a `RemoteStore`.
pub trait Provider<T: Expiry + Clone + 'static> {
/// Returns a future that resolves to a freshly fetched value or an error.
/// `RemoteStore` calls this once on construction and again for every
/// refresh of an expired value.
fn update(&self) -> ExpiryFut<T>;
}
/// Cache in front of a `Provider`: serves the last fetched value while it is
/// valid and asks the provider for a new one once it has expired.
///
/// All fields are `Arc`s, so clones of a store operate on the same state.
pub struct RemoteStore<T: Expiry + Clone + Sync + Send + 'static, P: Provider<T> + 'static> {
// Source of fresh values.
provider: Arc<P>,
// Shared future of the most recently requested value.
remote: Arc<RwLock<Fu<T>>>,
// `true` from the moment a refresh is started until its future completes;
// used to avoid starting a second refresh while one is running.
inflight: Arc<Mutex<bool>>,
}
/// Manual `Clone` so that the provider type `P` does not have to be `Clone`
/// itself: only the `Arc` handles are duplicated, and the clone shares all
/// state with the original store.
impl<T, P> Clone for RemoteStore<T, P>
where
    T: Expiry + Clone + Sync + Send + 'static,
    P: Provider<T> + 'static,
{
    fn clone(&self) -> Self {
        let Self {
            provider,
            remote,
            inflight,
        } = self;
        Self {
            provider: Arc::clone(provider),
            remote: Arc::clone(remote),
            inflight: Arc::clone(inflight),
        }
    }
}
impl<T: Expiry + Clone + Sync + Send + 'static, P: Provider<T> + Sync + Send + 'static>
RemoteStore<T, P>
{
#[allow(clippy::mutex_atomic)]
pub fn new(p: P) -> Self {
info!("initializing from remote");
let remote = Arc::new(RwLock::new(Fu {
f: FutureExt::shared(p.update()),
}));
RemoteStore {
provider: Arc::new(p),
remote,
inflight: Arc::new(Mutex::new(false)),
}
}
fn update(self) -> Shared<ExpiryFut<T>> {
debug!("lock for updating from remote");
match self.inflight.lock() {
Ok(mut lock) => {
if !*lock {
*lock = true;
info!("updating from remote");
match self.remote.write() {
Ok(mut r) => {
let unlock = Arc::clone(&self.inflight);
let f = self.provider.update().map(move |f| {
if let Ok(mut unlock) = unlock.lock() {
*unlock = false;
}
f
});
*r = Fu {
f: FutureExt::shared(f.boxed()),
};
}
Err(e) => return poison_err_future!(e),
}
} else {
debug!("update ongoing")
}
match self.remote.read() {
Ok(r) => r.f.clone(),
Err(e) => return poison_err_future!(e),
}
}
Err(e) => poison_err_future!(e),
}
}
fn get_or_update(self, t: T) -> ExpiryFut<T> {
debug!("get or update");
if t.valid() {
debug!("reusing cached data");
future::ok::<T, ExpiryGetError>(t).boxed()
} else {
self.update()
.map_err(|e| ExpiryGetError::InnerFutureFailed(e.to_string()))
.boxed()
}
}
pub fn get(&self) -> ExpiryFut<T> {
let s = (*self).clone();
match self.remote.read() {
Ok(ref f) => {
debug!("reading remote");
f.f.clone()
.and_then(move |item| s.get_or_update(item))
.boxed()
}
Err(e) => future::err::<T, ExpiryGetError>(ExpiryGetError::PoisonedLock(e.to_string()))
.boxed(),
}
}
}
/// Tests using values that expire after a wall-clock time-to-live.
#[cfg(test)]
mod test_timed {
    extern crate chrono;
    extern crate futures_timer;
    use super::*;
    use chrono::DateTime;
    use chrono::Utc;
    use futures::executor::block_on;
    use futures_timer::Delay;
    use std::sync::atomic::AtomicI64;
    use std::sync::atomic::Ordering;
    use std::thread;
    use std::time::Duration;

    /// Provider that counts how many fetches have completed.
    struct P1 {
        counter: Arc<AtomicI64>,
    }

    /// Value that is valid until the instant stored in `expire`.
    #[derive(Clone)]
    struct E1 {
        expire: DateTime<Utc>,
        payload: String,
    }

    impl Expiry for E1 {
        fn valid(&self) -> bool {
            self.expire > Utc::now()
        }
    }

    /// Asserts that `t` is `Ok` and that the contained value is still valid.
    fn check_ok_and_expiry<T: Expiry + Clone + 'static>(t: Result<T, ExpiryGetError>) {
        assert!(t.is_ok());
        let t = t.unwrap();
        assert!(t.valid());
    }

    impl Provider<E1> for P1 {
        /// Takes 10 ms, bumps the counter and yields a value valid for 200 ms.
        fn update(&self) -> ExpiryFut<E1> {
            let c = Arc::clone(&self.counter);
            Delay::new(Duration::from_millis(10))
                .map(move |_| {
                    c.fetch_add(1, Ordering::SeqCst);
                    Ok::<E1, ExpiryGetError>(E1 {
                        expire: Utc::now() + chrono::Duration::milliseconds(200),
                        payload: String::from("foobar"),
                    })
                })
                .map_err(Into::into)
                .into_future()
                .boxed()
        }
    }

    /// Asserts that the provider completed exactly `should` fetches.
    fn check_counter(counter: &Arc<AtomicI64>, should: i64) {
        assert_eq!(counter.load(Ordering::SeqCst), should);
    }

    /// Readers within the time-to-live share one fetch; readers after expiry
    /// cause exactly one more.
    #[test]
    fn threaded_with_ttl() {
        let provider = P1 {
            counter: Arc::new(AtomicI64::default()),
        };
        let counter = Arc::clone(&provider.counter);
        let rs = RemoteStore::new(provider);
        let c = block_on(rs.get());
        check_ok_and_expiry(c);
        let mut threads = vec![];
        for _ in 0..10 {
            let rs_c = rs.clone();
            // The thread has to drive the future itself. Returning an `async`
            // block from the closure would hand an unpolled future to `join`,
            // so neither `get` nor the assertions would ever run.
            let child = thread::spawn(move || {
                thread::sleep(Duration::from_millis(50));
                let c = block_on(rs_c.get());
                check_ok_and_expiry(c);
            });
            threads.push(child);
        }
        let c = block_on(rs.get());
        check_ok_and_expiry(c);
        assert!(threads.into_iter().map(|c| c.join()).all(|j| j.is_ok()));
        check_counter(&counter, 1);
        // Let the value expire; nothing refreshes it until somebody asks.
        thread::sleep(Duration::from_millis(300));
        check_counter(&counter, 1);
        let rs_c = rs.clone();
        let child = thread::spawn(move || {
            let c = block_on(rs_c.get());
            check_ok_and_expiry(c);
        });
        let c = block_on(rs.get());
        check_ok_and_expiry(c);
        assert!(child.join().is_ok());
        check_counter(&counter, 2);
    }

    /// Threads arriving at staggered times (0–290 ms) always get a valid value.
    #[test]
    fn many_threads() {
        let rs = RemoteStore::new(P1 {
            counter: Arc::new(AtomicI64::default()),
        });
        let c = block_on(rs.get());
        check_ok_and_expiry(c);
        let mut threads = vec![];
        for i in 0..30 {
            let rs_c = rs.clone();
            let child = thread::spawn(move || {
                thread::sleep(Duration::from_millis(i * 10));
                let c = block_on(rs_c.get());
                check_ok_and_expiry(c);
            });
            threads.push(child);
        }
        assert!(threads.into_iter().map(|c| c.join()).all(|j| j.is_ok()));
    }

    /// Threads arriving at the same time always get a valid value.
    #[test]
    fn many_threads_same_delay() {
        let rs = RemoteStore::new(P1 {
            counter: Arc::new(AtomicI64::default()),
        });
        let c = block_on(rs.get());
        check_ok_and_expiry(c);
        let mut threads = vec![];
        for _ in 0..30 {
            let rs_c = rs.clone();
            let child = thread::spawn(move || {
                thread::sleep(Duration::from_millis(10));
                let c = block_on(rs_c.get());
                check_ok_and_expiry(c);
            });
            threads.push(child);
        }
        assert!(threads.into_iter().map(|c| c.join()).all(|j| j.is_ok()));
    }
}
/// Async tests using values that expire after a fixed number of uses.
#[cfg(test)]
mod test_async {
    use super::*;
    use futures::future::join_all;
    use futures_timer::Delay;
    use std::sync::atomic::AtomicI64;
    use std::sync::atomic::Ordering;
    use std::time::Duration;

    /// Provider that counts completed fetches and hands out values that stay
    /// valid for `valid_for` uses.
    struct P2 {
        counter: Arc<AtomicI64>,
        valid_for: i64,
    }

    /// Value with a use counter shared by all of its clones.
    #[derive(Clone)]
    struct E2 {
        counter: Arc<AtomicI64>,
        payload: usize,
        valid_for: i64,
    }

    impl Expiry for E2 {
        /// Valid while fewer than `valid_for` uses have been recorded.
        fn valid(&self) -> bool {
            let uses = self.counter.load(Ordering::SeqCst);
            uses < self.valid_for
        }
    }

    impl Provider<E2> for P2 {
        /// Takes 10 ms, bumps the fetch counter and yields an unused value.
        fn update(&self) -> ExpiryFut<E2> {
            let fetches = Arc::clone(&self.counter);
            let valid_for = self.valid_for;
            let fetch = Delay::new(Duration::from_millis(10)).map(move |_| {
                fetches.fetch_add(1, Ordering::SeqCst);
                Ok::<E2, ExpiryGetError>(E2 {
                    counter: Arc::new(AtomicI64::new(0)),
                    payload: 0,
                    valid_for,
                })
            });
            fetch.map_err(Into::into).into_future().boxed()
        }
    }

    /// Asserts that `t` is `Ok` and records one use of the value.
    fn check_and_increment(t: Result<E2, ExpiryGetError>) {
        assert!(t.is_ok());
        let item = t.unwrap();
        item.counter.fetch_add(1, Ordering::SeqCst);
    }

    /// Asserts that the provider completed exactly `should` fetches.
    fn check_counter(counter: &Arc<AtomicI64>, should: i64) {
        let seen = counter.load(Ordering::SeqCst);
        assert_eq!(seen, should);
    }

    /// Twenty uses are served by one fetch; the uses after that trigger
    /// exactly one more.
    #[tokio::test]
    async fn async_with_counter() {
        let fetches = Arc::new(AtomicI64::default());
        let rs = RemoteStore::new(P2 {
            counter: Arc::clone(&fetches),
            valid_for: 20,
        });

        // Futures are only built here; they run inside `join_all` below.
        let first_wave: Vec<_> = (0..19)
            .map(|_| {
                let handle = rs.clone();
                async move {
                    let outcome = handle.get().await;
                    check_and_increment(outcome);
                }
            })
            .collect();
        let outcome = rs.get().await;
        check_and_increment(outcome);
        join_all(first_wave).await;
        check_counter(&fetches, 1);

        let second_wave: Vec<_> = (0..3)
            .map(|_| {
                let handle = rs.clone();
                async move {
                    let outcome = handle.get().await;
                    check_and_increment(outcome);
                }
            })
            .collect();
        join_all(second_wave).await;
        check_counter(&fetches, 2);
    }
}
/// Threaded tests using values that expire after a fixed number of uses.
#[cfg(test)]
mod test_counted {
    use super::*;
    use futures::executor::block_on;
    use futures_timer::Delay;
    use std::sync::atomic::AtomicI64;
    use std::sync::atomic::Ordering;
    use std::thread;
    use std::time::Duration;

    /// Provider that counts completed fetches and hands out values that stay
    /// valid for `valid_for` uses.
    struct P2 {
        counter: Arc<AtomicI64>,
        valid_for: i64,
    }

    /// Value with a use counter shared by all of its clones.
    #[derive(Clone)]
    struct E2 {
        counter: Arc<AtomicI64>,
        payload: usize,
        valid_for: i64,
    }

    impl Expiry for E2 {
        /// Valid while fewer than `valid_for` uses have been recorded.
        fn valid(&self) -> bool {
            let uses = self.counter.load(Ordering::SeqCst);
            uses < self.valid_for
        }
    }

    impl Provider<E2> for P2 {
        /// Takes 10 ms, bumps the fetch counter and yields an unused value.
        fn update(&self) -> ExpiryFut<E2> {
            let fetches = Arc::clone(&self.counter);
            let valid_for = self.valid_for;
            let fetch = Delay::new(Duration::from_millis(10)).map(move |_| {
                fetches.fetch_add(1, Ordering::SeqCst);
                Ok::<E2, ExpiryGetError>(E2 {
                    counter: Arc::new(AtomicI64::new(0)),
                    payload: 0,
                    valid_for,
                })
            });
            fetch.map_err(Into::into).into_future().boxed()
        }
    }

    /// Asserts that `t` is `Ok` and records one use of the value.
    fn check_and_increment(t: Result<E2, ExpiryGetError>) {
        assert!(t.is_ok());
        let item = t.unwrap();
        item.counter.fetch_add(1, Ordering::SeqCst);
    }

    /// Asserts that the provider completed exactly `should` fetches.
    fn check_counter(counter: &Arc<AtomicI64>, should: i64) {
        let seen = counter.load(Ordering::SeqCst);
        assert_eq!(seen, should);
    }

    /// Ten uses are served by one fetch; the two readers after that share
    /// exactly one more.
    #[test]
    fn threaded_with_counter() {
        let fetches = Arc::new(AtomicI64::default());
        let rs = RemoteStore::new(P2 {
            counter: Arc::clone(&fetches),
            valid_for: 10,
        });

        let outcome = block_on(rs.get());
        check_and_increment(outcome);
        check_counter(&fetches, 1);

        // `collect` spawns all workers before the main thread reads again.
        let workers: Vec<_> = (0..8)
            .map(|_| {
                let handle = rs.clone();
                thread::spawn(move || {
                    let outcome = block_on(handle.get());
                    check_and_increment(outcome);
                })
            })
            .collect();
        let outcome = block_on(rs.get());
        check_and_increment(outcome);
        assert!(workers.into_iter().map(|w| w.join()).all(|j| j.is_ok()));
        check_counter(&fetches, 1);

        let handle = rs.clone();
        let late_reader = thread::spawn(move || {
            let outcome = block_on(handle.get());
            check_and_increment(outcome);
        });
        let outcome = block_on(rs.get());
        check_and_increment(outcome);
        assert!(late_reader.join().is_ok());
        check_counter(&fetches, 2);
    }
}
/// Tests with a provider that never succeeds.
#[cfg(test)]
mod test_failing {
    use super::*;
    use futures::executor::block_on;
    use std::thread;
    use std::time::Duration;

    /// Provider whose every fetch fails immediately.
    #[derive(Clone)]
    struct P3 {}

    /// Value type that would never expire; it is never actually produced.
    #[derive(Clone)]
    struct E3 {}

    impl Expiry for E3 {
        fn valid(&self) -> bool {
            true
        }
    }

    impl Provider<E3> for P3 {
        /// Yields an already-failed future carrying `NotInitialized`.
        fn update(&self) -> ExpiryFut<E3> {
            let failure = ExpiryGetError::NotInitialized;
            future::err(failure.into()).into_future().boxed()
        }
    }

    /// Every reader, whenever it arrives, sees an error.
    #[test]
    fn many_threads_fail() {
        let rs = RemoteStore::new(P3 {});
        let _ = block_on(rs.get());
        let workers: Vec<_> = (0..30u64)
            .map(|i| {
                let handle = rs.clone();
                thread::spawn(move || {
                    thread::sleep(Duration::from_millis(i * 10));
                    let outcome = block_on(handle.get());
                    assert!(outcome.is_err());
                })
            })
            .collect();
        assert!(workers.into_iter().map(|w| w.join()).all(|j| j.is_ok()));
    }
}