1use std::time::{Duration, Instant};
8
9use sqlx::SqlStr;
10
11use crate::SqlError;
12
13pub async fn execute_with_timeout<T>(
32 timeout_secs: Option<u64>,
33 sql: SqlStr,
34 op: impl AsyncFnOnce(SqlStr) -> Result<T, sqlx::Error>,
35) -> Result<T, SqlError> {
36 let result = match timeout_secs {
37 Some(secs) if secs > 0 => {
38 let start = Instant::now();
39 let err_sql = sql.clone();
40 tokio::time::timeout(Duration::from_secs(secs), op(sql))
41 .await
42 .map_err(|_| SqlError::QueryTimeout {
43 elapsed_secs: start.elapsed().as_secs_f64(),
44 sql: err_sql.as_str().to_owned(),
45 })?
46 }
47 _ => op(sql).await,
48 };
49 result.map_err(|e| SqlError::Query(e.to_string()))
50}
51
#[cfg(test)]
mod tests {
    use super::*;

    /// An operation that completes inside the limit yields its value unchanged.
    #[tokio::test]
    async fn fast_query_succeeds_with_timeout() {
        let value = execute_with_timeout(Some(5), SqlStr::from_static("SELECT 1"), |_| async {
            Ok(42)
        })
        .await
        .expect("should succeed");
        assert_eq!(value, 42);
    }

    /// A failure reported by the operation surfaces as `SqlError::Query`
    /// containing the original error text.
    #[tokio::test]
    async fn query_error_propagates_as_app_error() {
        let outcome: Result<i32, SqlError> =
            execute_with_timeout(Some(5), SqlStr::from_static("BAD SQL"), |_| async {
                Err(sqlx::Error::Configuration("syntax error".into()))
            })
            .await;
        let err = outcome.expect_err("should fail");
        let is_query_error =
            matches!(&err, SqlError::Query(msg) if msg.contains("syntax error"));
        assert!(is_query_error, "unexpected error: {err}");
    }

    /// An operation that outlives a one-second limit is reported as
    /// `SqlError::QueryTimeout` with the elapsed time and SQL text.
    ///
    /// Note: this waits roughly one second of real time for the limit to expire.
    #[tokio::test]
    async fn slow_query_times_out() {
        let outcome: Result<i32, SqlError> =
            execute_with_timeout(Some(1), SqlStr::from_static("SELECT SLEEP(60)"), |_| async {
                tokio::time::sleep(Duration::from_mins(1)).await;
                Ok(0)
            })
            .await;
        let failure = outcome.expect_err("should time out");
        match failure {
            SqlError::QueryTimeout { elapsed_secs, sql } => {
                // Allow a little slack below the nominal one-second limit.
                assert!(elapsed_secs >= 0.9, "elapsed too small: {elapsed_secs}");
                assert_eq!(sql, "SELECT SLEEP(60)");
            }
            other => panic!("expected QueryTimeout, got: {other}"),
        }
    }

    /// `None` means no limit is applied and the operation simply runs.
    #[tokio::test]
    async fn none_timeout_runs_without_limit() {
        let value = execute_with_timeout(None, SqlStr::from_static("SELECT 1"), |_| async {
            Ok(1)
        })
        .await
        .expect("should succeed");
        assert_eq!(value, 1);
    }

    /// `Some(0)` is treated as "disabled" rather than as an immediate timeout.
    #[tokio::test]
    async fn zero_timeout_disables_limit() {
        let value = execute_with_timeout(Some(0), SqlStr::from_static("SELECT 1"), |_| async {
            Ok(1)
        })
        .await
        .expect("should succeed");
        assert_eq!(value, 1);
    }
}