database_mcp_sql/
timeout.rs1use std::fmt::Display;
8use std::time::{Duration, Instant};
9
10use database_mcp_server::AppError;
11
12pub async fn execute_with_timeout<T, E: Display>(
29 timeout_secs: Option<u64>,
30 sql: &str,
31 fut: impl Future<Output = Result<T, E>>,
32) -> Result<T, AppError> {
33 match timeout_secs.filter(|&t| t > 0) {
34 Some(secs) => {
35 let start = Instant::now();
36 tokio::time::timeout(Duration::from_secs(secs), fut)
37 .await
38 .map_err(|_| AppError::QueryTimeout {
39 elapsed_secs: start.elapsed().as_secs_f64(),
40 sql: sql.to_string(),
41 })?
42 .map_err(|e| AppError::Query(e.to_string()))
43 }
44 None => fut.await.map_err(|e| AppError::Query(e.to_string())),
45 }
46}
47
#[cfg(test)]
mod tests {
    use super::*;
    use std::fmt;

    /// Minimal `Display` error standing in for a database driver's error type.
    #[derive(Debug)]
    struct TestError(String);

    impl fmt::Display for TestError {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            write!(f, "{}", self.0)
        }
    }

    #[tokio::test]
    async fn fast_query_succeeds_with_timeout() {
        let result: Result<i32, AppError> =
            execute_with_timeout(Some(5), "SELECT 1", async { Ok::<_, TestError>(42) }).await;
        assert_eq!(result.expect("should succeed"), 42);
    }

    #[tokio::test]
    async fn query_error_propagates_as_app_error() {
        let result: Result<i32, AppError> = execute_with_timeout(Some(5), "BAD SQL", async {
            Err::<i32, _>(TestError("syntax error".into()))
        })
        .await;
        let err = result.expect_err("should fail");
        assert!(
            matches!(err, AppError::Query(ref msg) if msg.contains("syntax error")),
            "unexpected error: {err}"
        );
    }

    // Runs in real time (~1 s). `elapsed_secs` is measured with `std::time::Instant`,
    // which tokio's paused test clock does not advance, so `start_paused` cannot be used here.
    #[tokio::test]
    async fn slow_query_times_out() {
        let result: Result<i32, AppError> = execute_with_timeout(Some(1), "SELECT SLEEP(60)", async {
            tokio::time::sleep(Duration::from_secs(60)).await;
            Ok::<_, TestError>(0)
        })
        .await;
        let err = result.expect_err("should time out");
        match err {
            AppError::QueryTimeout { elapsed_secs, sql } => {
                assert!(elapsed_secs >= 0.9, "elapsed too small: {elapsed_secs}");
                // The 60 s future must have been cut off by the 1 s limit, not awaited to completion.
                assert!(elapsed_secs < 30.0, "timeout did not fire early: {elapsed_secs}");
                assert_eq!(sql, "SELECT SLEEP(60)");
            }
            other => panic!("expected QueryTimeout, got: {other}"),
        }
    }

    #[tokio::test]
    async fn none_timeout_runs_without_limit() {
        let result: Result<i32, AppError> =
            execute_with_timeout(None, "SELECT 1", async { Ok::<_, TestError>(1) }).await;
        assert_eq!(result.expect("should succeed"), 1);
    }

    #[tokio::test]
    async fn zero_timeout_disables_limit() {
        let result: Result<i32, AppError> =
            execute_with_timeout(Some(0), "SELECT 1", async { Ok::<_, TestError>(1) }).await;
        assert_eq!(result.expect("should succeed"), 1);
    }

    /// The no-timeout path must map driver errors exactly like the timed path does.
    #[tokio::test]
    async fn query_error_propagates_without_timeout() {
        let result: Result<i32, AppError> = execute_with_timeout(None, "BAD SQL", async {
            Err::<i32, _>(TestError("syntax error".into()))
        })
        .await;
        let err = result.expect_err("should fail");
        assert!(
            matches!(err, AppError::Query(ref msg) if msg.contains("syntax error")),
            "unexpected error: {err}"
        );
    }
}