Skip to main content

database_mcp_sql/
connection.rs

//! Connection abstraction shared across database backends.
//!
//! Defines [`Connection`] — the single trait every backend implements.
//! Backends supply the driver type, pool resolution, and timeout configuration;
//! the trait's default methods implement query execution on top of them.
6
7use crate::SqlError;
8use serde_json::Value;
9use sqlx::{Decode, Executor, Row, Type};
10use sqlx_to_json::{QueryResult as _, RowExt};
11
12use crate::timeout::execute_with_timeout;
13
14/// Unified query surface every backend tool handler uses.
15///
16/// Backends supply three required items — [`DB`](Connection::DB),
17/// [`pool`](Connection::pool), and [`query_timeout`](Connection::query_timeout)
18/// — and receive default implementations for query execution.
19///
20/// # Errors
21///
22/// Query methods may return:
23///
24/// - [`SqlError::InvalidIdentifier`] — `database` failed identifier validation.
25/// - [`SqlError::Connection`] — the underlying driver failed.
26/// - [`SqlError::QueryTimeout`] — the query exceeded the configured timeout.
27#[allow(async_fn_in_trait)]
28pub trait Connection: Send + Sync
29where
30    for<'c> &'c mut <Self::DB as sqlx::Database>::Connection: Executor<'c, Database = Self::DB>,
31    usize: sqlx::ColumnIndex<<Self::DB as sqlx::Database>::Row>,
32    <Self::DB as sqlx::Database>::Row: RowExt,
33    <Self::DB as sqlx::Database>::QueryResult: sqlx_to_json::QueryResult,
34{
35    /// The sqlx database driver type (e.g. `sqlx::MySql`).
36    type DB: sqlx::Database;
37
38    /// Resolves the connection pool for the given target database.
39    ///
40    /// # Errors
41    ///
42    /// - [`SqlError::InvalidIdentifier`] — `target` failed validation.
43    async fn pool(&self, target: Option<&str>) -> Result<sqlx::Pool<Self::DB>, SqlError>;
44
45    /// Returns the configured query timeout in seconds, if any.
46    fn query_timeout(&self) -> Option<u64>;
47
48    /// Runs a statement that returns no meaningful rows.
49    ///
50    /// # Errors
51    ///
52    /// See trait-level documentation.
53    async fn execute(&self, query: &str, database: Option<&str>) -> Result<u64, SqlError> {
54        let pool = self.pool(database).await?;
55        execute_with_timeout(self.query_timeout(), query, async {
56            Ok(pool.execute(query).await?.rows_affected())
57        })
58        .await
59    }
60
61    /// Runs a statement and collects every result row as JSON.
62    ///
63    /// # Errors
64    ///
65    /// See trait-level documentation.
66    async fn fetch_json(&self, query: &str, database: Option<&str>) -> Result<Vec<Value>, SqlError> {
67        let pool = self.pool(database).await?;
68        execute_with_timeout(self.query_timeout(), query, async {
69            Ok(pool.fetch_all(query).await?.iter().map(RowExt::to_json).collect())
70        })
71        .await
72    }
73
74    /// Runs a query and extracts column 0 from the first row, if any.
75    ///
76    /// Returns `None` for both "no row returned" and "row where column 0
77    /// is NULL" (decode errors are caught, not propagated).
78    ///
79    /// # Errors
80    ///
81    /// See trait-level documentation.
82    async fn fetch_optional<T>(&self, query: &str, database: Option<&str>) -> Result<Option<T>, SqlError>
83    where
84        T: for<'r> Decode<'r, Self::DB> + Type<Self::DB> + Send + Unpin,
85    {
86        let pool = self.pool(database).await?;
87        execute_with_timeout(self.query_timeout(), query, async {
88            Ok(pool.fetch_optional(query).await?.and_then(|r| r.try_get(0usize).ok()))
89        })
90        .await
91    }
92
93    /// Runs a query and extracts the first column of every row.
94    ///
95    /// # Errors
96    ///
97    /// See trait-level documentation.
98    async fn fetch_scalar<T>(&self, query: &str, database: Option<&str>) -> Result<Vec<T>, SqlError>
99    where
100        T: for<'r> Decode<'r, Self::DB> + Type<Self::DB> + Send + Unpin,
101    {
102        let pool = self.pool(database).await?;
103        execute_with_timeout(self.query_timeout(), query, async {
104            let rows = pool.fetch_all(query).await?;
105            rows.iter().map(|r| r.try_get(0usize)).collect()
106        })
107        .await
108    }
109}