Skip to main content

dbmcp_sql/
connection.rs

//! Connection abstraction shared across database backends.
//!
//! Defines [`Connection`] — the single trait every backend implements.
//! Backends provide pool resolution and timeout configuration; default method
//! implementations handle query execution.
6
7use crate::SqlError;
8use serde_json::Value;
9use sqlx::{AssertSqlSafe, Decode, Execute, Executor, FromRow, Row, SqlSafeStr, SqlStr, Type};
10use sqlx_json::{QueryResult as _, RowExt};
11
12use crate::timeout::execute_with_timeout;
13
14/// Splits a query into its SQL text and bound arguments for safe execution.
15///
16/// Lets [`Connection`] query methods accept either a bindless `&str` — wrapped
17/// as [`sqlx::AssertSqlSafe`] and run through the unprepared text protocol — or a
18/// parameterized `sqlx::query(..).bind(..)` value, without callers writing the
19/// wrapper at every call site. Callers remain responsible for ensuring bindless
20/// strings carry no injection (via read-only validation and identifier quoting).
21///
22/// The returned [`SqlStr`] owns its text, so no input borrow escapes. The
23/// `(SqlStr, Option<_>)` pair is itself an [`sqlx::Execute`] value: a `None`
24/// argument set routes through the unprepared text protocol, `Some` through a
25/// prepared statement.
26pub trait IntoSafeQuery<DB: sqlx::Database> {
27    /// Returns the SQL text and the bound arguments, if any.
28    ///
29    /// # Errors
30    ///
31    /// [`SqlError::Query`] — extracting bound arguments failed.
32    fn into_sql_and_args(self) -> Result<(SqlStr, Option<DB::Arguments>), SqlError>;
33}
34
35impl<DB: sqlx::Database> IntoSafeQuery<DB> for &str {
36    fn into_sql_and_args(self) -> Result<(SqlStr, Option<DB::Arguments>), SqlError> {
37        Ok((AssertSqlSafe(self).into_sql_str(), None))
38    }
39}
40
41impl<DB: sqlx::Database, A> IntoSafeQuery<DB> for sqlx::query::Query<'_, DB, A>
42where
43    A: Send + sqlx::IntoArguments<DB>,
44{
45    fn into_sql_and_args(mut self) -> Result<(SqlStr, Option<DB::Arguments>), SqlError> {
46        let arguments = self.take_arguments().map_err(|e| SqlError::Query(e.to_string()))?;
47        Ok((self.sql(), arguments))
48    }
49}
50
51/// Unified query surface every backend tool handler uses.
52///
53/// Backends supply three required items — [`DB`](Connection::DB),
54/// [`pool`](Connection::pool), and [`query_timeout`](Connection::query_timeout)
55/// — and receive default implementations for query execution.
56///
57/// Query methods accept any [`IntoSafeQuery`] value: a bindless `&str` (run
58/// through the unprepared text protocol, required for statements like `MySQL`
59/// `USE`) or a parameterized `sqlx::query(sql).bind(...)` value (run as a
60/// prepared statement).
61///
62/// # Errors
63///
64/// Query methods may return:
65///
66/// - [`SqlError::InvalidIdentifier`] — `database` failed identifier validation.
67/// - [`SqlError::Connection`] — the underlying driver failed.
68/// - [`SqlError::QueryTimeout`] — the query exceeded the configured timeout.
69#[allow(async_fn_in_trait)]
70pub trait Connection: Send + Sync
71where
72    for<'c> &'c mut <Self::DB as sqlx::Database>::Connection: Executor<'c, Database = Self::DB>,
73    usize: sqlx::ColumnIndex<<Self::DB as sqlx::Database>::Row>,
74    <Self::DB as sqlx::Database>::Row: RowExt,
75    <Self::DB as sqlx::Database>::QueryResult: sqlx_json::QueryResult,
76{
77    /// The sqlx database driver type (e.g. `sqlx::MySql`).
78    type DB: sqlx::Database;
79
80    /// Resolves the connection pool for the given target database.
81    ///
82    /// # Errors
83    ///
84    /// - [`SqlError::InvalidIdentifier`] — `target` failed validation.
85    async fn pool(&self, target: Option<&str>) -> Result<sqlx::Pool<Self::DB>, SqlError>;
86
87    /// Returns the configured query timeout in seconds, if any.
88    fn query_timeout(&self) -> Option<u64>;
89
90    /// Runs a statement that returns no meaningful rows.
91    ///
92    /// # Errors
93    ///
94    /// See trait-level documentation.
95    async fn execute<Q>(&self, query: Q, database: Option<&str>) -> Result<u64, SqlError>
96    where
97        Q: IntoSafeQuery<Self::DB>,
98    {
99        let (sql, arguments) = query.into_sql_and_args()?;
100        let pool = self.pool(database).await?;
101        execute_with_timeout(self.query_timeout(), sql, |sql| async move {
102            Ok(pool.execute((sql, arguments)).await?.rows_affected())
103        })
104        .await
105    }
106
107    /// Runs a statement and collects every result row as JSON.
108    ///
109    /// # Errors
110    ///
111    /// See trait-level documentation.
112    async fn fetch_json<Q>(&self, query: Q, database: Option<&str>) -> Result<Vec<Value>, SqlError>
113    where
114        Q: IntoSafeQuery<Self::DB>,
115    {
116        let (sql, arguments) = query.into_sql_and_args()?;
117        let pool = self.pool(database).await?;
118        execute_with_timeout(self.query_timeout(), sql, |sql| async move {
119            let rows = pool.fetch_all((sql, arguments)).await?;
120            Ok(rows.iter().map(RowExt::to_json).collect())
121        })
122        .await
123    }
124
125    /// Runs a query and extracts column 0 from the first row, if any.
126    ///
127    /// Returns `None` for both "no row returned" and "row where column 0
128    /// is NULL" (decode errors are caught, not propagated).
129    ///
130    /// # Errors
131    ///
132    /// See trait-level documentation.
133    async fn fetch_optional<Q, T>(&self, query: Q, database: Option<&str>) -> Result<Option<T>, SqlError>
134    where
135        Q: IntoSafeQuery<Self::DB>,
136        T: for<'r> Decode<'r, Self::DB> + Type<Self::DB> + Send + Unpin,
137    {
138        let (sql, arguments) = query.into_sql_and_args()?;
139        let pool = self.pool(database).await?;
140        execute_with_timeout(self.query_timeout(), sql, |sql| async move {
141            let row = pool.fetch_optional((sql, arguments)).await?;
142            Ok(row.and_then(|r| r.try_get(0usize).ok()))
143        })
144        .await
145    }
146
147    /// Runs a query and extracts the first column of every row.
148    ///
149    /// # Errors
150    ///
151    /// See trait-level documentation.
152    async fn fetch_scalar<Q, T>(&self, query: Q, database: Option<&str>) -> Result<Vec<T>, SqlError>
153    where
154        Q: IntoSafeQuery<Self::DB>,
155        T: for<'r> Decode<'r, Self::DB> + Type<Self::DB> + Send + Unpin,
156    {
157        let (sql, arguments) = query.into_sql_and_args()?;
158        let pool = self.pool(database).await?;
159        execute_with_timeout(self.query_timeout(), sql, |sql| async move {
160            let rows = pool.fetch_all((sql, arguments)).await?;
161            rows.iter().map(|r| r.try_get(0usize)).collect()
162        })
163        .await
164    }
165
166    /// Runs a query and decodes every row into `T` via [`sqlx::FromRow`].
167    ///
168    /// # Errors
169    ///
170    /// See trait-level documentation. Row decode failures (column type
171    /// mismatch, malformed JSON inside a [`sqlx::types::Json`] column, etc.)
172    /// surface as [`SqlError::Query`].
173    async fn fetch<Q, T>(&self, query: Q, database: Option<&str>) -> Result<Vec<T>, SqlError>
174    where
175        Q: IntoSafeQuery<Self::DB>,
176        T: for<'r> FromRow<'r, <Self::DB as sqlx::Database>::Row> + Send + Unpin,
177    {
178        let (sql, arguments) = query.into_sql_and_args()?;
179        let pool = self.pool(database).await?;
180        execute_with_timeout(self.query_timeout(), sql, |sql| async move {
181            let rows = pool.fetch_all((sql, arguments)).await?;
182            rows.iter().map(T::from_row).collect()
183        })
184        .await
185    }
186}