use std::future::Future;
use std::sync::Arc;
use parking_lot::Mutex;
use tokio::runtime::Handle;
use tokio::sync::oneshot;
use tokio::task::JoinHandle;
use tokio::time::MissedTickBehavior;
/// Bookkeeping for one spawned refresh loop, stored inside [`PeriodicRefresher`].
struct RefreshRuntime {
    /// Sender half of the loop's stop channel. `start` stores it as `Some`; `stop` takes it
    /// out and sends on it before awaiting `join`.
    stop_tx: Option<oneshot::Sender<()>>,
    /// Handle to the spawned loop task. `stop` awaits it, and `start`/`is_running` call
    /// `is_finished` on it to tell whether the loop is still alive.
    join: JoinHandle<()>,
}
/// Owns at most one background task that periodically runs a refresh callback.
///
/// Clones share the same state through an `Arc`, so any clone can start, stop, or query
/// the loop. When the last clone is dropped, the stored stop sender is dropped with it,
/// and the loop treats a dropped sender the same as a stop request.
#[derive(Clone)]
pub struct PeriodicRefresher {
    /// `Some` while a loop has been started and not yet stopped (or observed as finished).
    runtime: Arc<Mutex<Option<RefreshRuntime>>>,
}
impl PeriodicRefresher {
    /// Creates a refresher with no loop running.
    pub fn new() -> Self {
        let runtime = Arc::new(Mutex::new(None));
        Self { runtime }
    }

    /// Spawns a background loop that calls `refresh_fn` once every `interval`.
    ///
    /// The loop is spawned on the Tokio runtime of the calling context. The first call to
    /// `refresh_fn` happens one `interval` after starting, not immediately.
    ///
    /// # Errors
    ///
    /// Returns an error message if `interval` is zero, if no Tokio runtime is available in
    /// the current context, or if a previously started loop is still running. A recorded
    /// loop whose task has already finished does not block a new start; it is replaced.
    pub fn start<F, Fut>(
        &self,
        interval: std::time::Duration,
        refresh_fn: F,
    ) -> Result<(), String>
    where
        F: Fn() -> Fut + Send + Sync + 'static,
        Fut: Future<Output = ()> + Send + 'static,
    {
        if interval.is_zero() {
            return Err("interval must be non-zero".into());
        }
        let rt_handle = match Handle::try_current() {
            Ok(found) => found,
            Err(no_runtime) => return Err(no_runtime.to_string()),
        };

        // The check and the store below happen under one lock acquisition, so two
        // concurrent `start` calls cannot both succeed.
        let mut slot = self.runtime.lock();
        let still_active = matches!(slot.as_ref(), Some(active) if !active.join.is_finished());
        if still_active {
            return Err("periodic refresh already running".into());
        }

        let (signal, listener) = oneshot::channel();
        let task = rt_handle.spawn(periodic_loop(interval, listener, refresh_fn));
        *slot = Some(RefreshRuntime {
            stop_tx: Some(signal),
            join: task,
        });
        Ok(())
    }

    /// Signals the loop to stop and waits for its task to exit.
    ///
    /// Returns `false` if no loop was recorded. Otherwise it returns `true`, even when the
    /// recorded task had already exited on its own. A refresh call that is in progress when
    /// the signal arrives runs to completion before the task exits. The task's join result
    /// (for example, a panic raised by `refresh_fn`) is discarded.
    ///
    /// The recorded state is cleared before waiting, so a concurrent `start` may spawn a new
    /// loop while the old one is still shutting down.
    pub async fn stop(&self) -> bool {
        // Take the state in its own statement so the lock guard is released before `.await`.
        let taken = self.runtime.lock().take();
        match taken {
            None => false,
            Some(mut active) => {
                if let Some(signal) = active.stop_tx.take() {
                    // A failed send only means the receiver, owned by the loop task, is gone.
                    let _ = signal.send(());
                }
                let _ = active.join.await;
                true
            }
        }
    }

    /// Reports whether a started loop is still running.
    ///
    /// If the recorded task has already exited (e.g. because `refresh_fn` panicked), the
    /// stale state is cleared and `false` is returned.
    pub fn is_running(&self) -> bool {
        let mut slot = self.runtime.lock();
        let exited = match slot.as_ref() {
            None => return false,
            Some(active) => active.join.is_finished(),
        };
        if exited {
            *slot = None;
        }
        !exited
    }
}
impl Default for PeriodicRefresher {
fn default() -> Self {
Self::new()
}
}
/// Background task body: invokes `refresh_fn` once per `interval` until told to stop.
///
/// A Tokio interval's first tick completes immediately, so it is consumed up front. The
/// first refresh therefore runs one full `interval` after the task starts. Missed ticks
/// (for instance, when a refresh takes longer than `interval`) are skipped rather than
/// replayed in a burst.
///
/// The loop exits when `stop_rx` resolves, either because a stop signal was sent or because
/// the sender was dropped. The signal is only observed between refreshes: a `refresh_fn`
/// call that is already in progress is awaited to completion first.
async fn periodic_loop<F, Fut>(
    interval: std::time::Duration,
    mut stop_rx: oneshot::Receiver<()>,
    refresh_fn: F,
) where
    F: Fn() -> Fut + Send + Sync + 'static,
    Fut: Future<Output = ()> + Send,
{
    let mut ticker = tokio::time::interval(interval);
    ticker.set_missed_tick_behavior(MissedTickBehavior::Skip);
    // Consume the immediate first tick so the first refresh waits a full interval.
    ticker.tick().await;
    loop {
        tokio::select! {
            // Poll the stop signal first. Without `biased`, `select!` picks randomly among
            // ready branches, so a tick that is due at the same moment as a stop request
            // could trigger one extra refresh after the caller asked to stop.
            biased;
            _ = &mut stop_rx => break,
            _ = ticker.tick() => {
                refresh_fn().await;
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::time::Duration;

    /// Builds a refresh callback that increments `counter` each time it is invoked.
    fn counting(
        counter: &Arc<AtomicU32>,
    ) -> impl Fn() -> std::future::Ready<()> + Send + Sync + 'static {
        let counter = Arc::clone(counter);
        move || {
            counter.fetch_add(1, Ordering::SeqCst);
            std::future::ready(())
        }
    }

    #[tokio::test]
    async fn start_stop_basic() {
        let refresher = PeriodicRefresher::new();
        let counter = Arc::new(AtomicU32::new(0));
        refresher
            .start(Duration::from_millis(10), counting(&counter))
            .unwrap();
        assert!(refresher.is_running());
        tokio::time::sleep(Duration::from_millis(50)).await;
        assert!(refresher.stop().await);
        assert!(!refresher.is_running());
        let after_stop = counter.load(Ordering::SeqCst);
        assert!(after_stop > 0);
        // Once `stop` has returned the task is gone, so no further refreshes may happen.
        tokio::time::sleep(Duration::from_millis(30)).await;
        assert_eq!(counter.load(Ordering::SeqCst), after_stop);
    }

    #[tokio::test]
    async fn first_refresh_waits_one_interval() {
        let refresher = PeriodicRefresher::new();
        let counter = Arc::new(AtomicU32::new(0));
        refresher
            .start(Duration::from_secs(60), counting(&counter))
            .unwrap();
        tokio::time::sleep(Duration::from_millis(20)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 0);
        assert!(refresher.stop().await);
    }

    #[tokio::test]
    async fn zero_interval_rejected() {
        let refresher = PeriodicRefresher::new();
        let err = refresher.start(Duration::ZERO, || async {}).unwrap_err();
        assert!(err.contains("non-zero"));
        assert!(!refresher.is_running());
    }

    #[test]
    fn start_outside_runtime_fails() {
        let refresher = PeriodicRefresher::new();
        assert!(refresher
            .start(Duration::from_millis(10), || async {})
            .is_err());
        assert!(!refresher.is_running());
    }

    #[tokio::test]
    async fn double_start_rejected() {
        let refresher = PeriodicRefresher::new();
        refresher
            .start(Duration::from_secs(60), || async {})
            .unwrap();
        let err = refresher
            .start(Duration::from_secs(60), || async {})
            .unwrap_err();
        assert!(err.contains("already running"));
        assert!(refresher.stop().await);
    }

    #[tokio::test]
    async fn stop_when_not_running() {
        let refresher = PeriodicRefresher::new();
        assert!(!refresher.stop().await);
        refresher
            .start(Duration::from_secs(60), || async {})
            .unwrap();
        assert!(refresher.stop().await);
        // A second stop has nothing left to tear down.
        assert!(!refresher.stop().await);
    }

    #[tokio::test]
    async fn can_restart_after_stop() {
        let refresher = PeriodicRefresher::new();
        refresher
            .start(Duration::from_millis(10), || async {})
            .unwrap();
        assert!(refresher.stop().await);
        refresher
            .start(Duration::from_millis(10), || async {})
            .unwrap();
        assert!(refresher.is_running());
        assert!(refresher.stop().await);
        assert!(!refresher.is_running());
    }

    #[tokio::test]
    async fn default_is_idle() {
        let refresher = PeriodicRefresher::default();
        assert!(!refresher.is_running());
        assert!(!refresher.stop().await);
    }

    #[tokio::test]
    async fn clones_share_state() {
        let refresher = PeriodicRefresher::new();
        let twin = refresher.clone();
        refresher
            .start(Duration::from_secs(60), || async {})
            .unwrap();
        assert!(twin.is_running());
        assert!(twin.stop().await);
        assert!(!refresher.is_running());
    }
}