Skip to main content

mongo_keepalive/
lib.rs

1//! # mongo-keepalive
2//!
3//! Keep MongoDB Atlas free-tier clusters alive by periodically sending a `ping` command.
4
5use log::{error, info, warn};
6use mongodb::{bson::doc, Client};
7use std::time::Duration;
8use tokio::time;
9
/// Interval used by `start_keep_alive` when the caller passes an empty interval.
const DEFAULT_INTERVAL: &str = "12h";
/// How many ping attempts are made per round before giving up until the next tick.
const MAX_RETRIES: u32 = 3;
/// Pause between two consecutive failed ping attempts within one round.
const RETRY_DELAY: Duration = Duration::from_millis(5_000);
13
/// Parse an interval string like `"12h"`, `"30m"`, or `"60s"` into a [`Duration`].
///
/// The input is trimmed, then split into an unsigned integer followed by a
/// single unit character: `h` (hours), `m` (minutes) or `s` (seconds). The
/// unit is case-insensitive, so `"12H"` is accepted as well.
///
/// # Errors
/// Returns an error if the string is empty after trimming, if the part before
/// the unit is not a valid `u64`, if the unit is not one of `h`/`m`/`s`, or if
/// the resulting number of seconds does not fit in a `u64`.
pub fn parse_interval(interval: &str) -> Result<Duration, String> {
    let invalid = || {
        format!("Invalid interval \"{interval}\". Use e.g. \"12h\", \"30m\", or \"60s\".")
    };

    let s = interval.trim();
    // Take the last *character*, not the last byte: slicing at `len() - 1`
    // panics when the input ends in a multi-byte UTF-8 character (e.g. "12µ").
    let unit = match s.chars().next_back() {
        Some(c) => c,
        None => return Err("Interval must not be empty".into()),
    };
    let digits = &s[..s.len() - unit.len_utf8()];

    let value: u64 = digits.parse().map_err(|_| invalid())?;
    let secs_per_unit: u64 = match unit.to_ascii_lowercase() {
        'h' => 3600,
        'm' => 60,
        's' => 1,
        _ => return Err(invalid()),
    };

    // Checked so huge inputs are reported as errors instead of overflowing
    // (a panic in debug builds, a silently wrong duration in release builds).
    value
        .checked_mul(secs_per_unit)
        .map(Duration::from_secs)
        .ok_or_else(|| format!("Interval \"{interval}\" is too large."))
}
38
39/// Send a ping command with retry logic.
40async fn ping_with_retry(client: &Client) {
41    let db = client.database("admin");
42    for attempt in 1..=MAX_RETRIES {
43        match db.run_command(doc! { "ping": 1 }, None).await {
44            Ok(result) => {
45                info!("[mongo-keepalive] Ping successful: {:?}", result);
46                return;
47            }
48            Err(err) => {
49                warn!(
50                    "[mongo-keepalive] Ping attempt {}/{} failed: {}",
51                    attempt, MAX_RETRIES, err
52                );
53                if attempt < MAX_RETRIES {
54                    time::sleep(RETRY_DELAY).await;
55                }
56            }
57        }
58    }
59    error!(
60        "[mongo-keepalive] All {} ping attempts failed. Will retry at next interval.",
61        MAX_RETRIES
62    );
63}
64
65/// A handle to the keep-alive background task. Dropping it cancels the task.
66pub struct KeepAliveHandle {
67    handle: tokio::task::JoinHandle<()>,
68}
69
70impl KeepAliveHandle {
71    /// Stop the keep-alive loop.
72    pub fn stop(self) {
73        self.handle.abort();
74        info!("[mongo-keepalive] Stopped.");
75    }
76}
77
78/// Start the keep-alive loop.
79///
80/// Connects to MongoDB using `uri`, sends an initial ping, then repeats at the
81/// given `interval` (default `"12h"`).
82///
83/// Returns a [`KeepAliveHandle`] that can be used to stop the loop.
84///
85/// # Panics
86/// Panics if the MongoDB connection cannot be established.
87pub async fn start_keep_alive(uri: &str, interval: &str) -> KeepAliveHandle {
88    let interval = if interval.is_empty() {
89        DEFAULT_INTERVAL
90    } else {
91        interval
92    };
93
94    let dur = parse_interval(interval).expect("Invalid interval");
95    info!(
96        "[mongo-keepalive] Starting with interval {} ({:?})",
97        interval, dur
98    );
99
100    let client = Client::with_uri_str(uri)
101        .await
102        .expect("[mongo-keepalive] Failed to connect to MongoDB");
103
104    // Initial ping
105    ping_with_retry(&client).await;
106
107    let handle = tokio::spawn(async move {
108        let mut ticker = time::interval(dur);
109        // Skip the first tick (already pinged above)
110        ticker.tick().await;
111
112        loop {
113            ticker.tick().await;
114            ping_with_retry(&client).await;
115        }
116    });
117
118    KeepAliveHandle { handle }
119}
120
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_interval() {
        assert_eq!(parse_interval("12h").unwrap(), Duration::from_secs(43200));
        assert_eq!(parse_interval("30m").unwrap(), Duration::from_secs(1800));
        assert_eq!(parse_interval("60s").unwrap(), Duration::from_secs(60));
        assert!(parse_interval("abc").is_err());
        assert!(parse_interval("").is_err());
    }

    #[test]
    fn parse_interval_units_are_case_insensitive() {
        assert_eq!(parse_interval("2H").unwrap(), Duration::from_secs(7200));
        assert_eq!(parse_interval("5M").unwrap(), Duration::from_secs(300));
        assert_eq!(parse_interval("9S").unwrap(), Duration::from_secs(9));
    }

    #[test]
    fn parse_interval_trims_surrounding_whitespace() {
        assert_eq!(parse_interval("  15m\t").unwrap(), Duration::from_secs(900));
        // Whitespace-only input is empty after trimming.
        assert!(parse_interval("   ").is_err());
    }

    #[test]
    fn parse_interval_rejects_malformed_input() {
        assert!(parse_interval("h").is_err()); // unit without a number
        assert!(parse_interval("12").is_err()); // trailing "2" is not a unit
        assert!(parse_interval("12d").is_err()); // unsupported unit
        assert!(parse_interval("-5s").is_err()); // negative values
        assert!(parse_interval("1.5h").is_err()); // fractional values
    }
}