Skip to main content

database_mcp_sql/
timeout.rs

1//! Query-level timeout wrapper for SQL operations.
2//!
3//! Provides [`execute_with_timeout`] which wraps any async query future
4//! with an optional `tokio::time::timeout` guard.  All backend crates
5//! use this single function instead of duplicating timeout logic.
6
7use std::time::{Duration, Instant};
8
9use crate::SqlError;
10
11/// Executes `fut` with an optional query timeout.
12///
13/// When `timeout_secs` is `Some(n)` where `n > 0`, the future is wrapped
14/// with [`tokio::time::timeout`].  On expiry the future is dropped
15/// (cancelling the in-flight query) and [`SqlError::QueryTimeout`] is
16/// returned with the wall-clock elapsed time and the original SQL text.
17///
18/// When `timeout_secs` is `None` or `Some(0)`, the future runs without
19/// any timeout.
20///
21/// # Errors
22///
23/// * [`SqlError::QueryTimeout`] — the query exceeded the configured
24///   timeout.
25/// * [`SqlError::Query`] — the underlying query failed for a
26///   non-timeout reason (e.g. syntax error, connection loss).
27pub async fn execute_with_timeout<T>(
28    timeout_secs: Option<u64>,
29    sql: &str,
30    fut: impl Future<Output = Result<T, sqlx::Error>>,
31) -> Result<T, SqlError> {
32    match timeout_secs.filter(|&t| t > 0) {
33        Some(secs) => {
34            let start = Instant::now();
35            tokio::time::timeout(Duration::from_secs(secs), fut)
36                .await
37                .map_err(|_| SqlError::QueryTimeout {
38                    elapsed_secs: start.elapsed().as_secs_f64(),
39                    sql: sql.to_string(),
40                })?
41                .map_err(|e| SqlError::Query(e.to_string()))
42        }
43        None => fut.await.map_err(|e| SqlError::Query(e.to_string())),
44    }
45}
46
#[cfg(test)]
mod tests {
    // Unit tests for `execute_with_timeout`. They use synthetic futures in
    // place of real queries, so no database connection is required.
    use super::*;

    #[tokio::test]
    async fn fast_query_succeeds_with_timeout() {
        let result = execute_with_timeout(Some(5), "SELECT 1", async { Ok(42) }).await;
        assert_eq!(result.expect("should succeed"), 42);
    }

    // A driver error on the timed path must surface as `SqlError::Query`,
    // keeping the original error text.
    #[tokio::test]
    async fn query_error_propagates_as_query_error() {
        let result: Result<i32, SqlError> = execute_with_timeout(Some(5), "BAD SQL", async {
            Err(sqlx::Error::Configuration("syntax error".into()))
        })
        .await;
        let err = result.expect_err("should fail");
        assert!(
            matches!(err, SqlError::Query(ref msg) if msg.contains("syntax error")),
            "unexpected error: {err}"
        );
    }

    // The no-limit path maps errors separately from the timed path, so it
    // gets its own coverage.
    #[tokio::test]
    async fn query_error_propagates_without_timeout() {
        let result: Result<i32, SqlError> = execute_with_timeout(None, "BAD SQL", async {
            Err(sqlx::Error::Configuration("syntax error".into()))
        })
        .await;
        let err = result.expect_err("should fail");
        assert!(
            matches!(err, SqlError::Query(ref msg) if msg.contains("syntax error")),
            "unexpected error: {err}"
        );
    }

    // This test waits ~1 s of real time. The timeout has whole-second
    // resolution, and `elapsed_secs` is measured with `std::time::Instant`
    // (real wall-clock), so the `>= 0.9` assertion needs real time to pass.
    #[tokio::test]
    async fn slow_query_times_out() {
        let result: Result<i32, SqlError> = execute_with_timeout(Some(1), "SELECT SLEEP(60)", async {
            tokio::time::sleep(Duration::from_mins(1)).await;
            Ok(0)
        })
        .await;
        let err = result.expect_err("should time out");
        match err {
            SqlError::QueryTimeout { elapsed_secs, sql } => {
                assert!(elapsed_secs >= 0.9, "elapsed too small: {elapsed_secs}");
                assert_eq!(sql, "SELECT SLEEP(60)");
            }
            other => panic!("expected QueryTimeout, got: {other}"),
        }
    }

    #[tokio::test]
    async fn none_timeout_runs_without_limit() {
        let result = execute_with_timeout(None, "SELECT 1", async { Ok(1) }).await;
        assert_eq!(result.expect("should succeed"), 1);
    }

    #[tokio::test]
    async fn zero_timeout_disables_limit() {
        let result = execute_with_timeout(Some(0), "SELECT 1", async { Ok(1) }).await;
        assert_eq!(result.expect("should succeed"), 1);
    }
}