database_mcp_sql/
timeout.rs1use std::time::{Duration, Instant};
8
9use crate::SqlError;
10
11pub async fn execute_with_timeout<T>(
28 timeout_secs: Option<u64>,
29 sql: &str,
30 fut: impl Future<Output = Result<T, sqlx::Error>>,
31) -> Result<T, SqlError> {
32 match timeout_secs.filter(|&t| t > 0) {
33 Some(secs) => {
34 let start = Instant::now();
35 tokio::time::timeout(Duration::from_secs(secs), fut)
36 .await
37 .map_err(|_| SqlError::QueryTimeout {
38 elapsed_secs: start.elapsed().as_secs_f64(),
39 sql: sql.to_string(),
40 })?
41 .map_err(|e| SqlError::Query(e.to_string()))
42 }
43 None => fut.await.map_err(|e| SqlError::Query(e.to_string())),
44 }
45}
46
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn fast_query_succeeds_with_timeout() {
        let result = execute_with_timeout(Some(5), "SELECT 1", async { Ok(42) }).await;
        assert_eq!(result.expect("should succeed"), 42);
    }

    #[tokio::test]
    async fn query_error_propagates_as_app_error() {
        let result: Result<i32, SqlError> = execute_with_timeout(Some(5), "BAD SQL", async {
            Err(sqlx::Error::Configuration("syntax error".into()))
        })
        .await;
        let err = result.expect_err("should fail");
        assert!(
            matches!(err, SqlError::Query(ref msg) if msg.contains("syntax error")),
            "unexpected error: {err}"
        );
    }

    // The unlimited path (`None` and `Some(0)`) maps driver errors separately
    // from the timed path, so cover it explicitly.
    #[tokio::test]
    async fn query_error_propagates_without_timeout() {
        for limit in [None, Some(0)] {
            let result: Result<i32, SqlError> = execute_with_timeout(limit, "BAD SQL", async {
                Err(sqlx::Error::Configuration("syntax error".into()))
            })
            .await;
            let err = result.expect_err("should fail");
            assert!(
                matches!(err, SqlError::Query(ref msg) if msg.contains("syntax error")),
                "unexpected error for timeout {limit:?}: {err}"
            );
        }
    }

    // This test takes ~1s of real time. `elapsed_secs` is measured with
    // `std::time::Instant`, which tokio's paused clock does not advance, so
    // `start_paused` cannot be used to speed it up.
    #[tokio::test]
    async fn slow_query_times_out() {
        let result: Result<i32, SqlError> = execute_with_timeout(Some(1), "SELECT SLEEP(60)", async {
            tokio::time::sleep(Duration::from_secs(60)).await;
            Ok(0)
        })
        .await;
        let err = result.expect_err("should time out");
        match err {
            SqlError::QueryTimeout { elapsed_secs, sql } => {
                assert!(elapsed_secs >= 0.9, "elapsed too small: {elapsed_secs}");
                assert_eq!(sql, "SELECT SLEEP(60)");
            }
            other => panic!("expected QueryTimeout, got: {other}"),
        }
    }

    #[tokio::test]
    async fn none_timeout_runs_without_limit() {
        let result = execute_with_timeout(None, "SELECT 1", async { Ok(1) }).await;
        assert_eq!(result.expect("should succeed"), 1);
    }

    #[tokio::test]
    async fn zero_timeout_disables_limit() {
        let result = execute_with_timeout(Some(0), "SELECT 1", async { Ok(1) }).await;
        assert_eq!(result.expect("should succeed"), 1);
    }
}