1use log::{error, info, warn};
6use mongodb::{bson::doc, Client};
7use std::time::Duration;
8use tokio::time;
9
/// Interval used by `start_keep_alive` when the caller passes an empty string.
const DEFAULT_INTERVAL: &str = "12h";
/// Number of ping attempts made in one keep-alive round before giving up until the next tick.
const MAX_RETRIES: u32 = 3;
/// Pause between consecutive failed ping attempts within one round.
const RETRY_DELAY: Duration = Duration::from_millis(5_000);
13
/// Parses a human-friendly interval such as `"12h"`, `"30m"` or `"60s"` into a [`Duration`].
///
/// Surrounding whitespace is ignored. The final character selects the unit
/// (`h`, `m` or `s`, case-insensitive). Everything before it must parse as a `u64`.
///
/// # Errors
///
/// Returns a human-readable message when the input is empty, the numeric part is not a
/// valid `u64`, the unit is not one of `h`/`m`/`s`, or the interval in seconds would
/// overflow a `u64`.
pub fn parse_interval(interval: &str) -> Result<Duration, String> {
    let s = interval.trim();

    // Take the last *character*, not the last byte. Splitting at `len() - 1` panics
    // when the input ends in a multi-byte UTF-8 character (e.g. "12é").
    let unit = match s.chars().next_back() {
        Some(c) => c,
        None => return Err("Interval must not be empty".into()),
    };
    let digits = &s[..s.len() - unit.len_utf8()];

    let invalid =
        || format!("Invalid interval \"{interval}\". Use e.g. \"12h\", \"30m\", or \"60s\".");

    let value: u64 = digits.parse().map_err(|_| invalid())?;

    let secs_per_unit: u64 = match unit.to_ascii_lowercase() {
        'h' => 3600,
        'm' => 60,
        's' => 1,
        _ => return Err(invalid()),
    };

    // An unchecked multiply panics in debug builds and silently wraps in release builds.
    value
        .checked_mul(secs_per_unit)
        .map(Duration::from_secs)
        .ok_or_else(|| format!("Interval \"{interval}\" is too large"))
}
38
39async fn ping_with_retry(client: &Client) {
41 let db = client.database("admin");
42 for attempt in 1..=MAX_RETRIES {
43 match db.run_command(doc! { "ping": 1 }, None).await {
44 Ok(result) => {
45 info!("[mongo-keepalive] Ping successful: {:?}", result);
46 return;
47 }
48 Err(err) => {
49 warn!(
50 "[mongo-keepalive] Ping attempt {}/{} failed: {}",
51 attempt, MAX_RETRIES, err
52 );
53 if attempt < MAX_RETRIES {
54 time::sleep(RETRY_DELAY).await;
55 }
56 }
57 }
58 }
59 error!(
60 "[mongo-keepalive] All {} ping attempts failed. Will retry at next interval.",
61 MAX_RETRIES
62 );
63}
64
65pub struct KeepAliveHandle {
67 handle: tokio::task::JoinHandle<()>,
68}
69
70impl KeepAliveHandle {
71 pub fn stop(self) {
73 self.handle.abort();
74 info!("[mongo-keepalive] Stopped.");
75 }
76}
77
78pub async fn start_keep_alive(uri: &str, interval: &str) -> KeepAliveHandle {
88 let interval = if interval.is_empty() {
89 DEFAULT_INTERVAL
90 } else {
91 interval
92 };
93
94 let dur = parse_interval(interval).expect("Invalid interval");
95 info!(
96 "[mongo-keepalive] Starting with interval {} ({:?})",
97 interval, dur
98 );
99
100 let client = Client::with_uri_str(uri)
101 .await
102 .expect("[mongo-keepalive] Failed to connect to MongoDB");
103
104 ping_with_retry(&client).await;
106
107 let handle = tokio::spawn(async move {
108 let mut ticker = time::interval(dur);
109 ticker.tick().await;
111
112 loop {
113 ticker.tick().await;
114 ping_with_retry(&client).await;
115 }
116 });
117
118 KeepAliveHandle { handle }
119}
120
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_interval() {
        // (input, expected seconds)
        let valid: &[(&str, u64)] = &[
            ("12h", 43_200),
            ("30m", 1_800),
            ("60s", 60),
            // The unit is case-insensitive.
            ("30M", 1_800),
            // Surrounding whitespace is trimmed.
            (" 5s ", 5),
        ];
        for &(input, secs) in valid {
            assert_eq!(
                parse_interval(input).unwrap(),
                Duration::from_secs(secs),
                "input {:?}",
                input
            );
        }

        let invalid = ["abc", "", "   ", "h", "12", "12x", "-5s"];
        for input in invalid.iter() {
            assert!(
                parse_interval(input).is_err(),
                "expected an error for {:?}",
                input
            );
        }
    }
}