Skip to main content

dbmcp_sql/
timeout.rs

1//! Query-level timeout wrapper for SQL operations.
2//!
3//! Provides [`execute_with_timeout`] which runs a query operation under
4//! an optional `tokio::time::timeout` guard.  All backend crates use this
5//! single function instead of duplicating timeout logic.
6
7use std::time::{Duration, Instant};
8
9use sqlx::SqlStr;
10
11use crate::SqlError;
12
13/// Runs a query operation with an optional query timeout.
14///
15/// `op` is handed the [`SqlStr`] to execute and returns the in-flight
16/// query future. When `timeout_secs` is `Some(n)` with `n > 0`, that
17/// future is wrapped with [`tokio::time::timeout`]; on expiry it is
18/// dropped (cancelling the in-flight query) and [`SqlError::QueryTimeout`]
19/// is returned with the elapsed time and the SQL text. `None` or
20/// `Some(0)` runs without a timeout.
21///
22/// The SQL text is only copied into the error on the timeout path, so the
23/// success path adds no allocation.
24///
25/// # Errors
26///
27/// * [`SqlError::QueryTimeout`] — the query exceeded the configured
28///   timeout.
29/// * [`SqlError::Query`] — the underlying query failed for a
30///   non-timeout reason (e.g. syntax error, connection loss).
31pub async fn execute_with_timeout<T>(
32    timeout_secs: Option<u64>,
33    sql: SqlStr,
34    op: impl AsyncFnOnce(SqlStr) -> Result<T, sqlx::Error>,
35) -> Result<T, SqlError> {
36    let result = match timeout_secs {
37        Some(secs) if secs > 0 => {
38            let start = Instant::now();
39            let err_sql = sql.clone();
40            tokio::time::timeout(Duration::from_secs(secs), op(sql))
41                .await
42                .map_err(|_| SqlError::QueryTimeout {
43                    elapsed_secs: start.elapsed().as_secs_f64(),
44                    sql: err_sql.as_str().to_owned(),
45                })?
46        }
47        _ => op(sql).await,
48    };
49    result.map_err(|e| SqlError::Query(e.to_string()))
50}
51
#[cfg(test)]
mod tests {
    use super::*;

    /// An operation that completes immediately under a 5-second limit
    /// yields its value unchanged.
    #[tokio::test]
    async fn fast_query_succeeds_with_timeout() {
        let outcome = execute_with_timeout(
            Some(5),
            SqlStr::from_static("SELECT 1"),
            |_| async { Ok(42) },
        )
        .await;

        assert_eq!(outcome.expect("should succeed"), 42);
    }

    /// A `sqlx::Error` returned by the operation surfaces as
    /// `SqlError::Query` containing the original message.
    #[tokio::test]
    async fn query_error_propagates_as_app_error() {
        let outcome: Result<i32, SqlError> = execute_with_timeout(
            Some(5),
            SqlStr::from_static("BAD SQL"),
            |_| async { Err(sqlx::Error::Configuration("syntax error".into())) },
        )
        .await;

        let err = outcome.expect_err("should fail");
        let is_query_error =
            matches!(err, SqlError::Query(ref msg) if msg.contains("syntax error"));
        assert!(is_query_error, "unexpected error: {err}");
    }

    /// An operation that sleeps for a minute under a 1-second limit is
    /// reported as `SqlError::QueryTimeout` with the elapsed time and SQL.
    #[tokio::test]
    async fn slow_query_times_out() {
        let outcome: Result<i32, SqlError> = execute_with_timeout(
            Some(1),
            SqlStr::from_static("SELECT SLEEP(60)"),
            |_| async {
                tokio::time::sleep(Duration::from_mins(1)).await;
                Ok(0)
            },
        )
        .await;

        match outcome.expect_err("should time out") {
            SqlError::QueryTimeout { elapsed_secs, sql } => {
                // Allow a little slack below the 1-second limit.
                assert!(elapsed_secs >= 0.9, "elapsed too small: {elapsed_secs}");
                assert_eq!(sql, "SELECT SLEEP(60)");
            }
            other => panic!("expected QueryTimeout, got: {other}"),
        }
    }

    /// With `None` as the timeout the operation is awaited directly.
    #[tokio::test]
    async fn none_timeout_runs_without_limit() {
        let outcome = execute_with_timeout(
            None,
            SqlStr::from_static("SELECT 1"),
            |_| async { Ok(1) },
        )
        .await;

        assert_eq!(outcome.expect("should succeed"), 1);
    }

    /// `Some(0)` is treated the same as having no timeout configured.
    #[tokio::test]
    async fn zero_timeout_disables_limit() {
        let outcome = execute_with_timeout(
            Some(0),
            SqlStr::from_static("SELECT 1"),
            |_| async { Ok(1) },
        )
        .await;

        assert_eq!(outcome.expect("should succeed"), 1);
    }
}