use database_mcp_server::AppError;
use serde_json::Value;
use sqlx::Executor;
use sqlx_to_json::QueryResult as _;
use crate::identifier;
use crate::timeout::execute_with_timeout;
#[allow(async_fn_in_trait)]
pub trait Connection: Send + Sync {
type DB: sqlx::Database;
const IDENTIFIER_QUOTE: char;
async fn pool(&self, target: Option<&str>) -> Result<sqlx::Pool<Self::DB>, AppError>;
fn query_timeout(&self) -> Option<u64>;
async fn execute(&self, query: &str, database: Option<&str>) -> Result<u64, AppError>
where
for<'c> &'c mut <Self::DB as sqlx::Database>::Connection: Executor<'c, Database = Self::DB>,
<Self::DB as sqlx::Database>::QueryResult: sqlx_to_json::QueryResult,
{
let pool = self.pool(database).await?;
let sql = query.to_owned();
execute_with_timeout(self.query_timeout(), query, async move {
let mut conn = pool.acquire().await?;
let result = (&mut *conn).execute(sql.as_str()).await?;
Ok::<_, sqlx::Error>(result.rows_affected())
})
.await
}
async fn fetch(&self, query: &str, database: Option<&str>) -> Result<Vec<Value>, AppError>
where
for<'c> &'c mut <Self::DB as sqlx::Database>::Connection: Executor<'c, Database = Self::DB>,
<Self::DB as sqlx::Database>::Row: sqlx_to_json::RowExt,
{
let pool = self.pool(database).await?;
let sql = query.to_owned();
execute_with_timeout(self.query_timeout(), query, async move {
let mut conn = pool.acquire().await?;
let rows = (&mut *conn).fetch_all(sql.as_str()).await?;
Ok::<_, sqlx::Error>(rows.iter().map(sqlx_to_json::RowExt::to_json).collect())
})
.await
}
async fn fetch_optional(&self, query: &str, database: Option<&str>) -> Result<Option<Value>, AppError>
where
for<'c> &'c mut <Self::DB as sqlx::Database>::Connection: Executor<'c, Database = Self::DB>,
<Self::DB as sqlx::Database>::Row: sqlx_to_json::RowExt,
{
let pool = self.pool(database).await?;
let sql = query.to_owned();
execute_with_timeout(self.query_timeout(), query, async move {
let mut conn = pool.acquire().await?;
let row = (&mut *conn).fetch_optional(sql.as_str()).await?;
Ok::<_, sqlx::Error>(row.as_ref().map(sqlx_to_json::RowExt::to_json))
})
.await
}
fn quote_identifier(&self, name: &str) -> String {
identifier::quote_identifier(name, Self::IDENTIFIER_QUOTE)
}
fn quote_string(&self, value: &str) -> String {
identifier::quote_string(value)
}
}