Skip to main content

database_mcp_sql/
timeout.rs

1//! Query-level timeout wrapper for SQL operations.
2//!
3//! Provides [`execute_with_timeout`] which wraps any async query future
4//! with an optional `tokio::time::timeout` guard.  All backend crates
5//! use this single function instead of duplicating timeout logic.
6
7use std::fmt::Display;
8use std::time::{Duration, Instant};
9
10use database_mcp_server::AppError;
11
12/// Executes `fut` with an optional query timeout.
13///
14/// When `timeout_secs` is `Some(n)` where `n > 0`, the future is wrapped
15/// with [`tokio::time::timeout`].  On expiry the future is dropped
16/// (cancelling the in-flight query) and [`AppError::QueryTimeout`] is
17/// returned with the wall-clock elapsed time and the original SQL text.
18///
19/// When `timeout_secs` is `None` or `Some(0)`, the future runs without
20/// any timeout.
21///
22/// # Errors
23///
24/// * [`AppError::QueryTimeout`] — the query exceeded the configured
25///   timeout.
26/// * [`AppError::Query`] — the underlying query failed for a
27///   non-timeout reason (e.g. syntax error, connection loss).
28pub async fn execute_with_timeout<T, E: Display>(
29    timeout_secs: Option<u64>,
30    sql: &str,
31    fut: impl Future<Output = Result<T, E>>,
32) -> Result<T, AppError> {
33    match timeout_secs.filter(|&t| t > 0) {
34        Some(secs) => {
35            let start = Instant::now();
36            tokio::time::timeout(Duration::from_secs(secs), fut)
37                .await
38                .map_err(|_| AppError::QueryTimeout {
39                    elapsed_secs: start.elapsed().as_secs_f64(),
40                    sql: sql.to_string(),
41                })?
42                .map_err(|e| AppError::Query(e.to_string()))
43        }
44        None => fut.await.map_err(|e| AppError::Query(e.to_string())),
45    }
46}
47
#[cfg(test)]
mod tests {
    use super::*;
    use std::fmt;

    /// Minimal `Display`-able error standing in for a driver error.
    #[derive(Debug)]
    struct TestError(String);

    impl fmt::Display for TestError {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            write!(f, "{}", self.0)
        }
    }

    #[tokio::test]
    async fn fast_query_succeeds_with_timeout() {
        let result: Result<i32, AppError> =
            execute_with_timeout(Some(5), "SELECT 1", async { Ok::<_, TestError>(42) }).await;
        assert_eq!(result.expect("should succeed"), 42);
    }

    #[tokio::test]
    async fn query_error_propagates_as_app_error() {
        let result: Result<i32, AppError> = execute_with_timeout(Some(5), "BAD SQL", async {
            Err::<i32, _>(TestError("syntax error".into()))
        })
        .await;
        let err = result.expect_err("should fail");
        assert!(
            matches!(err, AppError::Query(ref msg) if msg.contains("syntax error")),
            "unexpected error: {err}"
        );
    }

    #[tokio::test]
    async fn untimed_query_error_propagates_as_app_error() {
        // The untimed branch (`None` and `Some(0)`) maps errors separately
        // from the timed branch, so exercise its error path explicitly.
        for timeout in [None, Some(0)] {
            let result: Result<i32, AppError> = execute_with_timeout(timeout, "BAD SQL", async {
                Err::<i32, _>(TestError("syntax error".into()))
            })
            .await;
            let err = result.expect_err("should fail");
            assert!(
                matches!(err, AppError::Query(ref msg) if msg.contains("syntax error")),
                "unexpected error for timeout {timeout:?}: {err}"
            );
        }
    }

    #[tokio::test]
    async fn slow_query_times_out() {
        let result: Result<i32, AppError> = execute_with_timeout(Some(1), "SELECT SLEEP(60)", async {
            tokio::time::sleep(Duration::from_secs(60)).await;
            Ok::<_, TestError>(0)
        })
        .await;
        let err = result.expect_err("should time out");
        match err {
            AppError::QueryTimeout { elapsed_secs, sql } => {
                assert!(elapsed_secs >= 0.9, "elapsed too small: {elapsed_secs}");
                assert_eq!(sql, "SELECT SLEEP(60)");
            }
            other => panic!("expected QueryTimeout, got: {other}"),
        }
    }

    #[tokio::test]
    async fn none_timeout_runs_without_limit() {
        let result: Result<i32, AppError> =
            execute_with_timeout(None, "SELECT 1", async { Ok::<_, TestError>(1) }).await;
        assert_eq!(result.expect("should succeed"), 1);
    }

    #[tokio::test]
    async fn zero_timeout_disables_limit() {
        let result: Result<i32, AppError> =
            execute_with_timeout(Some(0), "SELECT 1", async { Ok::<_, TestError>(1) }).await;
        assert_eq!(result.expect("should succeed"), 1);
    }
}