use std::{pin::Pin, sync::Arc};
use arc_swap::ArcSwap;
use crate::{
BoxedError,
platform::{Duration, Instant, MaybeSendFuture, MaybeSendSync},
};
/// Source of fresh values for a [`Refreshable`].
///
/// Each call hands back a boxed future that resolves to either a new value or a
/// [`BoxedError`]. Any suitable closure implements this trait through the blanket
/// impl that follows.
pub(crate) trait RefreshFactory<V>: MaybeSendSync {
    /// Returns a future that produces a new value (or an error) when awaited.
    fn call(&self) -> Pin<Box<dyn MaybeSendFuture<Output = Result<V, BoxedError>>>>;
}
/// Lets any closure that returns a boxed value future act as a [`RefreshFactory`].
impl<V, F> RefreshFactory<V> for F
where
    F: MaybeSendSync + Fn() -> Pin<Box<dyn MaybeSendFuture<Output = Result<V, BoxedError>>>>,
{
    /// Invokes the wrapped closure and returns the future it builds.
    fn call(&self) -> Pin<Box<dyn MaybeSendFuture<Output = Result<V, BoxedError>>>> {
        // Deref past `&self` so the closure itself is called.
        (*self)()
    }
}
/// A value that can be re-fetched on demand from an async factory.
///
/// Readers (`load` / `load_full`) go straight to the [`ArcSwap`] and never take
/// `refresh_lock`; only `refresh` serialises on that lock while it awaits the factory.
pub(crate) struct Refreshable<V> {
    // Current value; replaced wholesale by `store` on each successful refresh.
    value: ArcSwap<V>,
    // Produces a fresh value each time `refresh` needs one.
    factory: Box<dyn RefreshFactory<V>>,
    // Serialises refreshes. This is a tokio mutex because its guard is held across
    // the factory's `.await` inside `refresh`.
    refresh_lock: tokio::sync::Mutex<()>,
}
impl<V> std::fmt::Debug for Refreshable<V>
where
    V: std::fmt::Debug,
{
    /// Shows only the current value. The factory (a trait object without `Debug`)
    /// and the refresh lock are elided, which `finish_non_exhaustive` marks with `..`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut view = f.debug_struct("Refreshable");
        view.field("value", &self.value);
        view.finish_non_exhaustive()
    }
}
#[bon::bon]
impl<V: std::fmt::Debug + MaybeSendSync + 'static> Refreshable<V> {
#[builder]
pub(crate) async fn new(
factory: impl Fn() -> Pin<Box<dyn MaybeSendFuture<Output = Result<V, BoxedError>>>>
+ MaybeSendSync
+ 'static,
) -> Result<Self, BoxedError> {
let initial = factory().await?;
Ok(Self {
value: ArcSwap::from_pointee(initial),
factory: Box::new(factory),
refresh_lock: tokio::sync::Mutex::new(()),
})
}
pub(crate) async fn refresh(&self) -> Result<bool, BoxedError> {
let cur = self.value.load_full();
let _lock = self.refresh_lock.lock().await;
if !Arc::ptr_eq(&self.value.load_full(), &cur) {
return Ok(false);
}
let new_value = self.factory.call().await?;
self.value.store(Arc::new(new_value));
Ok(true)
}
pub(crate) fn load(&self) -> arc_swap::Guard<Arc<V>> {
self.value.load()
}
pub(crate) fn load_full(&self) -> Arc<V> {
self.value.load_full()
}
}
/// Refresh history that `ScheduledRefreshable` uses to decide when a refresh is due.
// Every field deliberately starts with `last_`, which trips clippy's
// `struct_field_names` lint (shared field prefix). The naming is intentional, so
// the lint is silenced here.
#[allow(clippy::struct_field_names)]
struct RefreshTimestamps {
    // Time of the last successful refresh. Set to construction time initially.
    last_refreshed: Instant,
    // Time of the most recent failed refresh. Cleared by the next success.
    last_failed_refresh: Option<Instant>,
    // Time of the last recorded refresh, successful or not. `None` until the first one.
    last_refresh_attempt: Option<Instant>,
}
/// A [`Refreshable`] plus a time-based policy that decides when `try_refresh` may
/// actually call the factory. `refresh` bypasses that policy.
pub(crate) struct ScheduledRefreshable<V> {
    // The underlying refreshable that owns the value and the factory.
    inner: Refreshable<V>,
    // How old the last successful refresh must be before `try_refresh` refreshes again.
    ttl: Duration,
    // How long `try_refresh` waits after a failed refresh before retrying.
    failure_backoff: Duration,
    // Minimum gap `try_refresh` enforces after any recorded attempt, success or failure.
    min_refresh_interval: Duration,
    // Refresh history. A std mutex is enough: its guard is only taken inside
    // synchronous helpers and is never held across an `.await`.
    timestamps: std::sync::Mutex<RefreshTimestamps>,
}
impl<V: std::fmt::Debug> std::fmt::Debug for ScheduledRefreshable<V> {
    /// Formats the wrapped refreshable together with all three scheduling settings.
    ///
    /// The refresh timestamps (behind a mutex) are left out, which
    /// `finish_non_exhaustive` marks with a trailing `..`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ScheduledRefreshable")
            .field("inner", &self.inner)
            .field("ttl", &self.ttl)
            .field("failure_backoff", &self.failure_backoff)
            // Previously omitted even though it gates `try_refresh` just like the other two.
            .field("min_refresh_interval", &self.min_refresh_interval)
            .finish_non_exhaustive()
    }
}
#[bon::bon]
impl<V: std::fmt::Debug + MaybeSendSync + 'static> ScheduledRefreshable<V> {
    /// Creates a scheduled refreshable, awaiting `factory` once for the initial value.
    ///
    /// * `ttl`: how old the last success must be before `try_refresh` refreshes
    ///   again (default `Duration::from_hours(1)`).
    /// * `failure_backoff`: wait after a failed refresh before `try_refresh` retries
    ///   (default `Duration::from_secs(30)`).
    /// * `min_refresh_interval`: minimum gap `try_refresh` keeps after any recorded
    ///   attempt (default `Duration::from_mins(1)`).
    ///
    /// # Errors
    ///
    /// Returns the error from the initial `factory` call.
    #[builder]
    pub(crate) async fn new(
        factory: impl Fn() -> Pin<Box<dyn MaybeSendFuture<Output = Result<V, BoxedError>>>>
            + MaybeSendSync
            + 'static,
        #[builder(default = Duration::from_hours(1))]
        ttl: Duration,
        #[builder(default = Duration::from_secs(30))]
        failure_backoff: Duration,
        #[builder(default = Duration::from_mins(1))]
        min_refresh_interval: Duration,
    ) -> Result<Self, BoxedError> {
        let inner = Refreshable::builder().factory(factory).build().await?;
        // Stamped after the initial fetch completes, so the TTL counts from then.
        let history = RefreshTimestamps {
            last_refreshed: Instant::now(),
            last_failed_refresh: None,
            last_refresh_attempt: None,
        };
        Ok(Self {
            inner,
            ttl,
            failure_backoff,
            min_refresh_interval,
            timestamps: std::sync::Mutex::new(history),
        })
    }

    /// Decides whether `try_refresh` may attempt a refresh right now.
    ///
    /// A refresh is suppressed while any of these windows is still open:
    /// * the last recorded attempt is younger than `min_refresh_interval`;
    /// * the last success is younger than `ttl`;
    /// * the last failure is younger than `failure_backoff`.
    ///
    /// A timestamp for which `checked_duration_since` yields `None` (one that is
    /// later than `now`) never blocks a refresh. A poisoned lock is recovered
    /// rather than propagated.
    fn should_refresh(&self) -> bool {
        let now = Instant::now();
        let history = self
            .timestamps
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        let age = move |since: Instant| now.checked_duration_since(since);

        let throttled = history
            .last_refresh_attempt
            .and_then(age)
            .is_some_and(|elapsed| elapsed < self.min_refresh_interval);
        let still_fresh = age(history.last_refreshed).is_some_and(|elapsed| elapsed < self.ttl);
        let backing_off = history
            .last_failed_refresh
            .and_then(age)
            .is_some_and(|elapsed| elapsed < self.failure_backoff);

        !(throttled || still_fresh || backing_off)
    }

    /// Records the outcome of a refresh attempt, stamped with the current instant.
    ///
    /// Every attempt updates `last_refresh_attempt`. A success also advances
    /// `last_refreshed` and clears any pending failure. A failure sets
    /// `last_failed_refresh`.
    fn record_refresh(&self, success: bool) {
        let stamp = Instant::now();
        let mut history = self
            .timestamps
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        history.last_refresh_attempt = Some(stamp);
        if success {
            history.last_failed_refresh = None;
            history.last_refreshed = stamp;
        } else {
            history.last_failed_refresh = Some(stamp);
        }
    }

    /// Best-effort refresh, gated by the TTL / backoff / interval policy.
    ///
    /// Returns `false` without touching the factory when no refresh is due.
    /// Otherwise it refreshes, records the outcome, and returns whether the inner
    /// refresh succeeded. A factory error is discarded here; only the fact that it
    /// happened is recorded, for backoff. `true` is also returned when the inner
    /// refresh found the value already replaced by a concurrent refresh.
    pub(crate) async fn try_refresh(&self) -> bool {
        if self.should_refresh() {
            let succeeded = self.inner.refresh().await.is_ok();
            self.record_refresh(succeeded);
            succeeded
        } else {
            false
        }
    }

    /// Refreshes immediately, ignoring the schedule, and records the outcome.
    ///
    /// Returns whatever the inner refresh returns. `Ok(false)` means a concurrent
    /// refresh had already replaced the value.
    ///
    /// # Errors
    ///
    /// Propagates the factory's error. The failure is still recorded, so later
    /// `try_refresh` calls back off.
    pub(crate) async fn refresh(&self) -> Result<bool, BoxedError> {
        let outcome = self.inner.refresh().await;
        self.record_refresh(outcome.is_ok());
        outcome
    }

    /// Returns a lightweight guard to the current value (intended for short-lived use).
    pub(crate) fn load(&self) -> arc_swap::Guard<Arc<V>> {
        self.inner.load()
    }

    /// Returns an owned `Arc` to the current value.
    pub(crate) fn load_full(&self) -> Arc<V> {
        self.inner.load_full()
    }
}