use std::future::Future;
use std::time::Duration;
use tokio::sync::watch;
use tracing::{debug, error};
use rust_tg_bot_raw::error::TelegramError;
/// Settings for [`network_retry_loop`], destructured by value when the loop starts.
///
/// `A` produces one run of the action, `E` is the optional error observer, and
/// `'a` bounds the borrowed `description` and the boxed `is_running` predicate.
pub struct NetworkLoopConfig<'a, A, E> {
/// Called once per loop iteration to produce the future for one run of the action.
pub action_cb: A,
/// Optional observer invoked with errors other than `RetryAfter`, `TimedOut` and
/// `InvalidToken`, before the retry budget is checked.
pub on_err_cb: Option<E>,
/// Human-readable name embedded in the `Network Retry Loop (...)` log prefix.
pub description: &'a str,
/// Delay in seconds applied after a successful run when `repeat_on_success` is set
/// (values `<= 0` mean no delay); also the starting point for the error back-off.
pub interval: f64,
/// Optional stop signal raced against each run of the action; the loop stops with
/// `Ok(())` when it observes the watched value as `true` or the sender is dropped.
pub stop_rx: Option<watch::Receiver<bool>>,
/// Optional predicate checked before every iteration; the loop ends with `Ok(())`
/// when it returns `false`. `None` behaves like `|| true`.
pub is_running: Option<Box<dyn Fn() -> bool + Send + Sync + 'a>>,
/// Retry budget: once the loop's retry counter reaches it, the failing error is
/// returned. A negative value means unlimited. `InvalidToken` is never retried.
pub max_retries: i32,
/// When `false`, the first successful run ends the loop with `Ok(())`;
/// when `true`, the loop keeps running the action after successes.
pub repeat_on_success: bool,
}
pub async fn network_retry_loop<'a, A, AF, E>(
config: NetworkLoopConfig<'a, A, E>,
) -> Result<(), TelegramError>
where
A: Fn() -> AF,
AF: Future<Output = Result<(), TelegramError>>,
E: Fn(&TelegramError),
{
let NetworkLoopConfig {
action_cb,
on_err_cb,
description,
interval,
mut stop_rx,
is_running,
max_retries,
repeat_on_success,
} = config;
let log_prefix = format!("Network Retry Loop ({description}):");
let effective_is_running = is_running.unwrap_or_else(|| Box::new(|| true));
debug!("{log_prefix} Starting");
let mut cur_interval = interval;
let mut retries: i32 = 0;
while effective_is_running() {
let action_result = match stop_rx.as_mut() {
Some(rx) => {
tokio::select! {
biased;
_ = wait_for_stop(rx) => {
debug!("{log_prefix} Cancelled via stop signal");
return Ok(());
}
res = action_cb() => res,
}
}
None => action_cb().await,
};
match action_result {
Ok(()) => {
if !repeat_on_success {
debug!("{log_prefix} Action succeeded. Stopping loop.");
return Ok(());
}
cur_interval = interval;
}
Err(TelegramError::RetryAfter { retry_after }) => {
let slack = Duration::from_millis(500);
cur_interval = (retry_after + slack).as_secs_f64();
if check_max_retries(retries, max_retries, &log_prefix) {
return Err(TelegramError::RetryAfter { retry_after });
}
}
Err(TelegramError::TimedOut(_)) => {
cur_interval = 0.0;
if check_max_retries(retries, max_retries, &log_prefix) {
return Err(TelegramError::TimedOut("timed out".into()));
}
}
Err(TelegramError::InvalidToken(msg)) => {
error!("{log_prefix} Invalid token. Aborting retry loop.");
return Err(TelegramError::InvalidToken(msg));
}
Err(ref e) => {
if let Some(ref cb) = on_err_cb {
cb(e);
}
if check_max_retries(retries, max_retries, &log_prefix) {
return Err(action_result.unwrap_err());
}
cur_interval = if cur_interval == 0.0 {
1.0
} else {
(1.5 * cur_interval).min(30.0)
};
}
}
retries += 1;
if cur_interval > 0.0 {
tokio::time::sleep(Duration::from_secs_f64(cur_interval)).await;
}
}
Ok(())
}
/// Completes once the watched flag reads `true`.
///
/// Also completes when every sender has been dropped, because the flag can no
/// longer change. Each check marks the current value as seen, so only later
/// updates wake the `changed` wait.
async fn wait_for_stop(rx: &mut watch::Receiver<bool>) {
    loop {
        let stop_requested = *rx.borrow_and_update();
        if stop_requested {
            return;
        }
        // `changed` fails only once the sender side is gone; treat that as a stop.
        if rx.changed().await.is_err() {
            return;
        }
    }
}
fn check_max_retries(current: i32, max: i32, prefix: &str) -> bool {
if max < 0 || current < max {
debug!("{prefix} Failed run {current} of {max}. Retrying.",);
false
} else {
error!("{prefix} Failed run {current} of {max}. Aborting.",);
true
}
}
/// Behavioural tests for `network_retry_loop`.
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Arc;

    #[tokio::test]
    async fn succeeds_on_first_try() {
        let result = network_retry_loop(NetworkLoopConfig {
            action_cb: || async { Ok(()) },
            on_err_cb: None::<fn(&TelegramError)>,
            description: "test",
            interval: 0.0,
            stop_rx: None,
            is_running: None,
            max_retries: 0,
            repeat_on_success: false,
        })
        .await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn retries_and_succeeds() {
        let counter = Arc::new(AtomicU32::new(0));
        let c = counter.clone();
        let result = network_retry_loop(NetworkLoopConfig {
            action_cb: move || {
                let c = c.clone();
                async move {
                    let n = c.fetch_add(1, Ordering::SeqCst);
                    if n < 2 {
                        Err(TelegramError::Network("fail".into()))
                    } else {
                        Ok(())
                    }
                }
            },
            on_err_cb: None::<fn(&TelegramError)>,
            description: "retry-test",
            interval: 0.0,
            stop_rx: None,
            is_running: None,
            max_retries: -1,
            repeat_on_success: false,
        })
        .await;
        assert!(result.is_ok());
        assert_eq!(counter.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn aborts_after_max_retries() {
        let counter = Arc::new(AtomicU32::new(0));
        let c = counter.clone();
        let result = network_retry_loop(NetworkLoopConfig {
            action_cb: move || {
                let c = c.clone();
                async move {
                    c.fetch_add(1, Ordering::SeqCst);
                    Err::<(), _>(TelegramError::Network("always fail".into()))
                }
            },
            on_err_cb: None::<fn(&TelegramError)>,
            description: "abort-test",
            interval: 0.0,
            stop_rx: None,
            is_running: None,
            max_retries: 2,
            repeat_on_success: false,
        })
        .await;
        assert!(matches!(result, Err(TelegramError::Network(_))));
        // One initial attempt plus `max_retries` retries.
        assert_eq!(counter.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn timed_out_counts_against_max_retries() {
        let counter = Arc::new(AtomicU32::new(0));
        let c = counter.clone();
        let result = network_retry_loop(NetworkLoopConfig {
            action_cb: move || {
                let c = c.clone();
                async move {
                    c.fetch_add(1, Ordering::SeqCst);
                    Err::<(), _>(TelegramError::TimedOut("slow".into()))
                }
            },
            on_err_cb: None::<fn(&TelegramError)>,
            description: "timeout-test",
            interval: 0.0,
            stop_rx: None,
            is_running: None,
            max_retries: 1,
            repeat_on_success: false,
        })
        .await;
        assert!(matches!(result, Err(TelegramError::TimedOut(_))));
        // Timeouts retry immediately, so this finishes without sleeping.
        assert_eq!(counter.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn on_err_cb_receives_generic_errors() {
        let seen = Arc::new(AtomicU32::new(0));
        let s = seen.clone();
        let result = network_retry_loop(NetworkLoopConfig {
            action_cb: || async { Err::<(), _>(TelegramError::Network("boom".into())) },
            on_err_cb: Some(move |_: &TelegramError| {
                s.fetch_add(1, Ordering::SeqCst);
            }),
            description: "callback-test",
            interval: 0.0,
            stop_rx: None,
            is_running: None,
            max_retries: 0,
            repeat_on_success: false,
        })
        .await;
        assert!(result.is_err());
        assert_eq!(seen.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn not_running_skips_action() {
        let counter = Arc::new(AtomicU32::new(0));
        let c = counter.clone();
        let result = network_retry_loop(NetworkLoopConfig {
            action_cb: move || {
                let c = c.clone();
                async move {
                    c.fetch_add(1, Ordering::SeqCst);
                    Ok(())
                }
            },
            on_err_cb: None::<fn(&TelegramError)>,
            description: "not-running-test",
            interval: 0.0,
            stop_rx: None,
            is_running: Some(Box::new(|| false)),
            max_retries: -1,
            repeat_on_success: true,
        })
        .await;
        assert!(result.is_ok());
        assert_eq!(counter.load(Ordering::SeqCst), 0);
    }

    #[tokio::test]
    async fn invalid_token_aborts_immediately() {
        let counter = Arc::new(AtomicU32::new(0));
        let c = counter.clone();
        let result = network_retry_loop(NetworkLoopConfig {
            action_cb: move || {
                let c = c.clone();
                async move {
                    c.fetch_add(1, Ordering::SeqCst);
                    Err::<(), _>(TelegramError::InvalidToken("bad".into()))
                }
            },
            on_err_cb: None::<fn(&TelegramError)>,
            description: "token-test",
            interval: 0.0,
            stop_rx: None,
            is_running: None,
            max_retries: -1,
            repeat_on_success: false,
        })
        .await;
        assert!(result.is_err());
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn stop_signal_cancels_loop() {
        let (tx, rx) = watch::channel(false);
        let counter = Arc::new(AtomicU32::new(0));
        let c = counter.clone();
        let handle = tokio::spawn(async move {
            network_retry_loop(NetworkLoopConfig {
                action_cb: move || {
                    let c = c.clone();
                    async move {
                        c.fetch_add(1, Ordering::SeqCst);
                        Ok(())
                    }
                },
                on_err_cb: None::<fn(&TelegramError)>,
                description: "stop-test",
                interval: 0.01,
                stop_rx: Some(rx),
                is_running: None,
                max_retries: -1,
                repeat_on_success: true,
            })
            .await
        });
        tokio::time::sleep(Duration::from_millis(80)).await;
        tx.send(true).unwrap();
        let result = handle.await.unwrap();
        assert!(result.is_ok());
        assert!(counter.load(Ordering::SeqCst) >= 1);
    }
}