database_mcp_sql/connection.rs
1//! Connection abstraction shared across database backends.
2//!
3//! Defines [`Connection`] — the single trait every backend implements.
4//! Backends provide pool resolution and timeout config; default method
5//! implementations handle query execution.
6
7use crate::SqlError;
8use serde_json::Value;
9use sqlx::{Decode, Executor, Row, Type};
10use sqlx_to_json::{QueryResult as _, RowExt};
11
12use crate::timeout::execute_with_timeout;
13
14/// Unified query surface every backend tool handler uses.
15///
16/// Backends supply three required items — [`DB`](Connection::DB),
17/// [`pool`](Connection::pool), and [`query_timeout`](Connection::query_timeout)
18/// — and receive default implementations for query execution.
19///
20/// # Errors
21///
22/// Query methods may return:
23///
24/// - [`SqlError::InvalidIdentifier`] — `database` failed identifier validation.
25/// - [`SqlError::Connection`] — the underlying driver failed.
26/// - [`SqlError::QueryTimeout`] — the query exceeded the configured timeout.
27#[allow(async_fn_in_trait)]
28pub trait Connection: Send + Sync
29where
30 for<'c> &'c mut <Self::DB as sqlx::Database>::Connection: Executor<'c, Database = Self::DB>,
31 usize: sqlx::ColumnIndex<<Self::DB as sqlx::Database>::Row>,
32 <Self::DB as sqlx::Database>::Row: RowExt,
33 <Self::DB as sqlx::Database>::QueryResult: sqlx_to_json::QueryResult,
34{
35 /// The sqlx database driver type (e.g. `sqlx::MySql`).
36 type DB: sqlx::Database;
37
38 /// Resolves the connection pool for the given target database.
39 ///
40 /// # Errors
41 ///
42 /// - [`SqlError::InvalidIdentifier`] — `target` failed validation.
43 async fn pool(&self, target: Option<&str>) -> Result<sqlx::Pool<Self::DB>, SqlError>;
44
45 /// Returns the configured query timeout in seconds, if any.
46 fn query_timeout(&self) -> Option<u64>;
47
48 /// Runs a statement that returns no meaningful rows.
49 ///
50 /// # Errors
51 ///
52 /// See trait-level documentation.
53 async fn execute(&self, query: &str, database: Option<&str>) -> Result<u64, SqlError> {
54 let pool = self.pool(database).await?;
55 execute_with_timeout(self.query_timeout(), query, async {
56 Ok(pool.execute(query).await?.rows_affected())
57 })
58 .await
59 }
60
61 /// Runs a statement and collects every result row as JSON.
62 ///
63 /// # Errors
64 ///
65 /// See trait-level documentation.
66 async fn fetch_json(&self, query: &str, database: Option<&str>) -> Result<Vec<Value>, SqlError> {
67 let pool = self.pool(database).await?;
68 execute_with_timeout(self.query_timeout(), query, async {
69 Ok(pool.fetch_all(query).await?.iter().map(RowExt::to_json).collect())
70 })
71 .await
72 }
73
74 /// Runs a query and extracts column 0 from the first row, if any.
75 ///
76 /// Returns `None` for both "no row returned" and "row where column 0
77 /// is NULL" (decode errors are caught, not propagated).
78 ///
79 /// # Errors
80 ///
81 /// See trait-level documentation.
82 async fn fetch_optional<T>(&self, query: &str, database: Option<&str>) -> Result<Option<T>, SqlError>
83 where
84 T: for<'r> Decode<'r, Self::DB> + Type<Self::DB> + Send + Unpin,
85 {
86 let pool = self.pool(database).await?;
87 execute_with_timeout(self.query_timeout(), query, async {
88 Ok(pool.fetch_optional(query).await?.and_then(|r| r.try_get(0usize).ok()))
89 })
90 .await
91 }
92
93 /// Runs a query and extracts the first column of every row.
94 ///
95 /// # Errors
96 ///
97 /// See trait-level documentation.
98 async fn fetch_scalar<T>(&self, query: &str, database: Option<&str>) -> Result<Vec<T>, SqlError>
99 where
100 T: for<'r> Decode<'r, Self::DB> + Type<Self::DB> + Send + Unpin,
101 {
102 let pool = self.pool(database).await?;
103 execute_with_timeout(self.query_timeout(), query, async {
104 let rows = pool.fetch_all(query).await?;
105 rows.iter().map(|r| r.try_get(0usize)).collect()
106 })
107 .await
108 }
109}