use std::{sync::Arc, time::Duration};
use rand::{thread_rng, Rng};
use tokio::sync::watch;
use tokio_util::sync::CancellationToken;
#[cfg(not(target_arch = "wasm32"))]
use tokio::time::sleep;
#[cfg(target_arch = "wasm32")]
use wasmtimer::tokio::sleep;
use crate::{
background::{AsyncRuntime, BackgroundRuntime},
configuration_fetcher::ConfigurationFetcher,
configuration_store::ConfigurationStore,
Error,
};
/// Timing settings for the background configuration poller.
///
/// The poller waits `interval` minus a random amount in `[0, jitter]` between two fetches
/// (saturating at zero), so fetches never happen less often than every `interval`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ConfigurationPollerConfig {
    /// Nominal delay between two consecutive configuration fetches.
    pub interval: Duration,
    /// Upper bound of the random amount subtracted from `interval` before each wait.
    pub jitter: Duration,
}
impl ConfigurationPollerConfig {
    /// Default delay between two consecutive configuration fetches (30 seconds).
    pub const DEFAULT_POLL_INTERVAL: Duration = Duration::from_secs(30);
    /// Default upper bound of the random amount subtracted from the poll interval (3 seconds).
    pub const DEFAULT_POLL_JITTER: Duration = Duration::from_secs(3);

    /// Creates a config with [`Self::DEFAULT_POLL_INTERVAL`] and [`Self::DEFAULT_POLL_JITTER`].
    ///
    /// Equivalent to [`ConfigurationPollerConfig::default`].
    #[must_use]
    pub fn new() -> ConfigurationPollerConfig {
        ConfigurationPollerConfig::default()
    }

    /// Returns this config with the poll interval replaced by `interval`.
    // `#[must_use]`: `self` is consumed, so discarding the result loses the whole config.
    #[must_use]
    pub fn with_interval(mut self, interval: Duration) -> ConfigurationPollerConfig {
        self.interval = interval;
        self
    }

    /// Returns this config with the poll jitter replaced by `jitter`.
    #[must_use]
    pub fn with_jitter(mut self, jitter: Duration) -> ConfigurationPollerConfig {
        self.jitter = jitter;
        self
    }
}
impl Default for ConfigurationPollerConfig {
fn default() -> ConfigurationPollerConfig {
ConfigurationPollerConfig {
interval: ConfigurationPollerConfig::DEFAULT_POLL_INTERVAL,
jitter: ConfigurationPollerConfig::DEFAULT_POLL_JITTER,
}
}
}
/// Handle to a background configuration poller started by [`start_configuration_poller`].
///
/// Use [`ConfigurationPoller::wait_for_configuration`] to wait for the poller's first reported
/// outcome, and [`ConfigurationPoller::stop`] to cancel the background task.
#[derive(Debug)]
pub struct ConfigurationPoller {
    /// Latest outcome published by the poller task.
    ///
    /// It is `None` until the first successful fetch or unrecoverable error. After that it is
    /// `Some(Ok(()))` or `Some(Err(_))`.
    status: watch::Receiver<Option<Result<(), crate::Error>>>,
    /// Token whose cancellation stops the poller task.
    cancellation_token: CancellationToken,
}
impl ConfigurationPoller {
pub async fn wait_for_configuration(&self) -> Result<(), crate::Error> {
let mut status_rx = self.status.clone();
let status = status_rx
.wait_for(|status| status.is_some())
.await
.map_err(|_| Error::PollerThreadPanicked)?;
status
.as_ref()
.cloned()
.expect("option should always be Some because it's checked in .wait_for()")
}
pub fn stop(&self) {
self.cancellation_token.cancel();
}
}
/// Starts a background task that periodically fetches configuration into `store`.
///
/// The task runs `configuration_poller` until the cancellation token obtained from `runtime`
/// is cancelled. It is spawned differently per target:
/// - on native targets, via `runtime.spawn_untracked`;
/// - on `wasm32`, via `wasm_bindgen_futures::spawn_local`.
///
/// Returns a [`ConfigurationPoller`] handle for observing the poller's status and stopping it.
pub fn start_configuration_poller<AR: AsyncRuntime>(
    runtime: &BackgroundRuntime<AR>,
    fetcher: ConfigurationFetcher,
    store: Arc<ConfigurationStore>,
    config: ConfigurationPollerConfig,
) -> ConfigurationPoller {
    // `None` until the poller publishes its first outcome.
    let (status_tx, status_rx) = watch::channel(None);
    let cancellation_token = runtime.cancellation_token();

    log::info!(target: "eppo", "starting configuration poller");

    // The task owns its own clone of the token; the handle keeps the original for `stop()`.
    let task_token = cancellation_token.clone();
    let task = async move {
        task_token
            .run_until_cancelled(configuration_poller(fetcher, store, config, status_tx))
            .await;
    };

    #[cfg(not(target_arch = "wasm32"))]
    runtime.spawn_untracked(task);
    #[cfg(target_arch = "wasm32")]
    wasm_bindgen_futures::spawn_local(task);

    ConfigurationPoller {
        status: status_rx,
        cancellation_token,
    }
}
async fn configuration_poller(
mut fetcher: ConfigurationFetcher,
store: Arc<ConfigurationStore>,
config: ConfigurationPollerConfig,
status: watch::Sender<Option<Result<(), crate::Error>>>,
) {
let update_status = move |next: Result<(), crate::Error>| {
status.send_if_modified(|value| {
let update = value.as_ref().is_none()
|| value
.as_ref()
.is_some_and(|prev| prev.is_ok() != next.is_ok());
if update {
*value = Some(next);
}
update
});
};
loop {
match fetcher.fetch_configuration().await {
Ok(configuration) => {
store.set_configuration(Arc::new(configuration));
update_status(Ok(()));
}
Err(err @ (Error::Unauthorized | Error::InvalidBaseUrl(_))) => {
update_status(Err(Error::from(err)));
return;
}
_ => {
}
}
let timeout = jitter(config.interval, config.jitter);
sleep(timeout).await;
}
}
/// Returns `interval` shortened by a uniformly random amount in `[0, jitter]`.
///
/// The subtraction saturates at zero, so the result is always within
/// `[interval.saturating_sub(jitter), interval]`.
fn jitter(interval: Duration, jitter: Duration) -> Duration {
    let offset = thread_rng().gen_range(Duration::ZERO..=jitter);
    interval.saturating_sub(offset)
}
#[cfg(test)]
mod jitter_tests {
    use std::time::Duration;

    /// `jitter` is random, so a single draw can pass by luck. Randomized checks sample this
    /// many times.
    const SAMPLES: usize = 1_000;

    #[test]
    fn jitter_is_subtractive() {
        let interval = Duration::from_secs(30);
        let jitter = Duration::from_secs(30);
        for _ in 0..SAMPLES {
            let result = super::jitter(interval, jitter);
            assert!(result <= interval, "{result:?} must be <= {interval:?}");
        }
    }

    #[test]
    fn jitter_stays_within_window() {
        let interval = Duration::from_secs(30);
        let jitter = Duration::from_secs(3);
        let lower = interval - jitter;
        for _ in 0..SAMPLES {
            let result = super::jitter(interval, jitter);
            assert!(
                (lower..=interval).contains(&result),
                "{result:?} must be within [{lower:?}, {interval:?}]"
            );
        }
    }

    #[test]
    fn jitter_truncates_to_zero() {
        let interval = Duration::ZERO;
        let jitter = Duration::from_secs(30);
        for _ in 0..SAMPLES {
            let result = super::jitter(interval, jitter);
            assert_eq!(result, Duration::ZERO);
        }
    }

    #[test]
    fn jitter_works_with_zero_jitter() {
        let interval = Duration::from_secs(30);
        let jitter = Duration::ZERO;
        let result = super::jitter(interval, jitter);
        assert_eq!(result, Duration::from_secs(30));
    }
}