use crate::config::{DatabaseBackend, SqlxConfig};
use crate::error::SqlxResult;
use crate::pool::SqlxPool;
use crate::row::SqlxRow;
use crate::types::quote_identifier;
use prax_query::QueryResult;
use prax_query::filter::FilterValue;
use prax_query::traits::{BoxFuture, Model, QueryEngine};
use sqlx::Row;
use std::sync::Arc;
use tracing::debug;
/// Query engine backed by an sqlx connection pool.
///
/// Cloning is cheap: the pool lives behind an [`Arc`], so every clone shares
/// the same underlying pool.
#[derive(Clone)]
pub struct SqlxEngine {
    /// Backend kind, used for dialect selection and identifier quoting.
    backend: DatabaseBackend,
    /// Shared connection pool.
    pool: Arc<SqlxPool>,
}
impl SqlxEngine {
    /// Opens a connection pool for `config` and wraps it in a new engine.
    ///
    /// # Errors
    ///
    /// Returns whatever error [`SqlxPool::connect`] reports.
    pub async fn new(config: SqlxConfig) -> SqlxResult<Self> {
        let pool = Arc::new(SqlxPool::connect(&config).await?);
        Ok(Self {
            pool,
            backend: config.backend,
        })
    }

    /// Wraps an already-connected pool, reading the backend kind from it.
    pub fn from_pool(pool: SqlxPool) -> Self {
        Self {
            backend: pool.backend(),
            pool: Arc::new(pool),
        }
    }

    /// Returns the database backend this engine talks to.
    pub fn backend(&self) -> DatabaseBackend {
        self.backend
    }

    /// Borrows the underlying connection pool.
    pub fn pool(&self) -> &SqlxPool {
        &*self.pool
    }

    /// Closes the underlying connection pool.
    pub async fn close(&self) {
        self.pool.close().await;
    }

    /// Runs `sql` with `params` bound in order and returns every resulting row.
    ///
    /// # Errors
    ///
    /// Propagates any driver error raised while executing the query.
    pub async fn raw_query_many(
        &self,
        sql: &str,
        params: &[FilterValue],
    ) -> SqlxResult<Vec<SqlxRow>> {
        debug!(sql = %sql, "Executing raw_query_many");
        let collected = match &*self.pool {
            #[cfg(feature = "postgres")]
            SqlxPool::Postgres(pg) => {
                let bound = params
                    .iter()
                    .fold(sqlx::query(sql), |q, p| bind_pg_param(q, p));
                let fetched = bound.fetch_all(pg).await?;
                fetched.into_iter().map(SqlxRow::Postgres).collect()
            }
            #[cfg(feature = "mysql")]
            SqlxPool::MySql(my) => {
                let bound = params
                    .iter()
                    .fold(sqlx::query(sql), |q, p| bind_mysql_param(q, p));
                let fetched = bound.fetch_all(my).await?;
                fetched.into_iter().map(SqlxRow::MySql).collect()
            }
            #[cfg(feature = "sqlite")]
            SqlxPool::Sqlite(lite) => {
                let bound = params
                    .iter()
                    .fold(sqlx::query(sql), |q, p| bind_sqlite_param(q, p));
                let fetched = bound.fetch_all(lite).await?;
                fetched.into_iter().map(SqlxRow::Sqlite).collect()
            }
        };
        Ok(collected)
    }

    /// Runs `sql` with `params` bound in order and returns exactly one row.
    ///
    /// # Errors
    ///
    /// Propagates any driver error, including the one sqlx raises when the
    /// query yields no row.
    pub async fn raw_query_one(&self, sql: &str, params: &[FilterValue]) -> SqlxResult<SqlxRow> {
        debug!(sql = %sql, "Executing raw_query_one");
        let single = match &*self.pool {
            #[cfg(feature = "postgres")]
            SqlxPool::Postgres(pg) => {
                let bound = params
                    .iter()
                    .fold(sqlx::query(sql), |q, p| bind_pg_param(q, p));
                SqlxRow::Postgres(bound.fetch_one(pg).await?)
            }
            #[cfg(feature = "mysql")]
            SqlxPool::MySql(my) => {
                let bound = params
                    .iter()
                    .fold(sqlx::query(sql), |q, p| bind_mysql_param(q, p));
                SqlxRow::MySql(bound.fetch_one(my).await?)
            }
            #[cfg(feature = "sqlite")]
            SqlxPool::Sqlite(lite) => {
                let bound = params
                    .iter()
                    .fold(sqlx::query(sql), |q, p| bind_sqlite_param(q, p));
                SqlxRow::Sqlite(bound.fetch_one(lite).await?)
            }
        };
        Ok(single)
    }

    /// Runs `sql` with `params` bound in order and returns the first row, if any.
    ///
    /// # Errors
    ///
    /// Propagates any driver error raised while executing the query.
    pub async fn raw_query_optional(
        &self,
        sql: &str,
        params: &[FilterValue],
    ) -> SqlxResult<Option<SqlxRow>> {
        debug!(sql = %sql, "Executing raw_query_optional");
        let maybe_row = match &*self.pool {
            #[cfg(feature = "postgres")]
            SqlxPool::Postgres(pg) => {
                let bound = params
                    .iter()
                    .fold(sqlx::query(sql), |q, p| bind_pg_param(q, p));
                bound.fetch_optional(pg).await?.map(SqlxRow::Postgres)
            }
            #[cfg(feature = "mysql")]
            SqlxPool::MySql(my) => {
                let bound = params
                    .iter()
                    .fold(sqlx::query(sql), |q, p| bind_mysql_param(q, p));
                bound.fetch_optional(my).await?.map(SqlxRow::MySql)
            }
            #[cfg(feature = "sqlite")]
            SqlxPool::Sqlite(lite) => {
                let bound = params
                    .iter()
                    .fold(sqlx::query(sql), |q, p| bind_sqlite_param(q, p));
                bound.fetch_optional(lite).await?.map(SqlxRow::Sqlite)
            }
        };
        Ok(maybe_row)
    }

    /// Executes a statement with `params` bound in order and returns the
    /// number of rows it affected.
    ///
    /// # Errors
    ///
    /// Propagates any driver error raised while executing the statement.
    pub async fn raw_execute(&self, sql: &str, params: &[FilterValue]) -> SqlxResult<u64> {
        debug!(sql = %sql, "Executing raw_execute");
        let affected = match &*self.pool {
            #[cfg(feature = "postgres")]
            SqlxPool::Postgres(pg) => {
                let bound = params
                    .iter()
                    .fold(sqlx::query(sql), |q, p| bind_pg_param(q, p));
                bound.execute(pg).await?.rows_affected()
            }
            #[cfg(feature = "mysql")]
            SqlxPool::MySql(my) => {
                let bound = params
                    .iter()
                    .fold(sqlx::query(sql), |q, p| bind_mysql_param(q, p));
                bound.execute(my).await?.rows_affected()
            }
            #[cfg(feature = "sqlite")]
            SqlxPool::Sqlite(lite) => {
                let bound = params
                    .iter()
                    .fold(sqlx::query(sql), |q, p| bind_sqlite_param(q, p));
                bound.execute(lite).await?.rows_affected()
            }
        };
        Ok(affected)
    }

    /// Counts the rows of `table`, optionally restricted by `filter`.
    ///
    /// The table name is quoted for the active backend, but `filter` is
    /// appended verbatim as the `WHERE` clause. It must therefore never carry
    /// untrusted input.
    ///
    /// # Errors
    ///
    /// Propagates driver errors from the query and from decoding the
    /// `count` column as `i64`.
    pub async fn count_table(&self, table: &str, filter: Option<&str>) -> SqlxResult<u64> {
        let quoted = quote_identifier(self.backend, table);
        let mut sql = format!("SELECT COUNT(*) as count FROM {}", quoted);
        if let Some(condition) = filter {
            sql.push_str(" WHERE ");
            sql.push_str(condition);
        }
        let total: i64 = match self.raw_query_one(&sql, &[]).await? {
            #[cfg(feature = "postgres")]
            SqlxRow::Postgres(r) => r.try_get("count")?,
            #[cfg(feature = "mysql")]
            SqlxRow::MySql(r) => r.try_get("count")?,
            #[cfg(feature = "sqlite")]
            SqlxRow::Sqlite(r) => r.try_get("count")?,
        };
        // COUNT(*) is never negative, so the sign-dropping cast cannot wrap here.
        Ok(total as u64)
    }
}
/// Appends one [`FilterValue`] to a Postgres query as the next bound parameter.
///
/// JSON values are bound natively. Lists are converted to a JSON value, falling
/// back to JSON `null` if serialization fails. `Null` is bound as
/// `Option::<String>::None`.
///
/// NOTE(review): that NULL carries a text type, and Postgres may reject it
/// where a non-text type is expected. Confirm that callers only rely on it in
/// contexts that accept text.
#[cfg(feature = "postgres")]
fn bind_pg_param<'q>(
    query: sqlx::query::Query<'q, sqlx::Postgres, sqlx::postgres::PgArguments>,
    value: &'q FilterValue,
) -> sqlx::query::Query<'q, sqlx::Postgres, sqlx::postgres::PgArguments> {
    match value {
        FilterValue::Null => query.bind(Option::<String>::None),
        FilterValue::Bool(flag) => query.bind(*flag),
        FilterValue::Int(n) => query.bind(*n),
        FilterValue::Float(x) => query.bind(*x),
        FilterValue::String(text) => query.bind(text.as_str()),
        FilterValue::Json(doc) => query.bind(doc.clone()),
        FilterValue::List(items) => query.bind(serde_json::to_value(items).unwrap_or_default()),
    }
}
/// Appends one [`FilterValue`] to a MySQL query as the next bound parameter.
///
/// JSON values and lists are both sent as serialized JSON text. A list that
/// fails to serialize is bound as an empty string. `Null` is bound as
/// `Option::<String>::None`.
#[cfg(feature = "mysql")]
fn bind_mysql_param<'q>(
    query: sqlx::query::Query<'q, sqlx::MySql, sqlx::mysql::MySqlArguments>,
    value: &'q FilterValue,
) -> sqlx::query::Query<'q, sqlx::MySql, sqlx::mysql::MySqlArguments> {
    match value {
        FilterValue::Null => query.bind(Option::<String>::None),
        FilterValue::Bool(flag) => query.bind(*flag),
        FilterValue::Int(n) => query.bind(*n),
        FilterValue::Float(x) => query.bind(*x),
        FilterValue::String(text) => query.bind(text.as_str()),
        FilterValue::Json(doc) => query.bind(doc.to_string()),
        FilterValue::List(items) => query.bind(serde_json::to_string(items).unwrap_or_default()),
    }
}
/// Appends one [`FilterValue`] to a SQLite query as the next bound parameter.
///
/// JSON values and lists are both sent as serialized JSON text. A list that
/// fails to serialize is bound as an empty string. `Null` is bound as
/// `Option::<String>::None`.
#[cfg(feature = "sqlite")]
fn bind_sqlite_param<'q>(
    query: sqlx::query::Query<'q, sqlx::Sqlite, sqlx::sqlite::SqliteArguments<'q>>,
    value: &'q FilterValue,
) -> sqlx::query::Query<'q, sqlx::Sqlite, sqlx::sqlite::SqliteArguments<'q>> {
    match value {
        FilterValue::Null => query.bind(Option::<String>::None),
        FilterValue::Bool(flag) => query.bind(*flag),
        FilterValue::Int(n) => query.bind(*n),
        FilterValue::Float(x) => query.bind(*x),
        FilterValue::String(text) => query.bind(text.as_str()),
        FilterValue::Json(doc) => query.bind(doc.to_string()),
        FilterValue::List(items) => query.bind(serde_json::to_string(items).unwrap_or_default()),
    }
}
impl QueryEngine for SqlxEngine {
fn dialect(&self) -> &dyn prax_query::dialect::SqlDialect {
match self.backend {
DatabaseBackend::Postgres => &prax_query::dialect::Postgres,
DatabaseBackend::MySql => &prax_query::dialect::Mysql,
DatabaseBackend::Sqlite => &prax_query::dialect::Sqlite,
}
}
fn query_many<T: Model + prax_query::row::FromRow + Send + 'static>(
&self,
sql: &str,
params: Vec<FilterValue>,
) -> BoxFuture<'_, QueryResult<Vec<T>>> {
let sql = sql.to_string();
Box::pin(async move {
debug!(sql = %sql, "Executing query_many via QueryEngine");
let rows = self
.raw_query_many(&sql, ¶ms)
.await
.map_err(|e| prax_query::QueryError::database(e.to_string()))?;
rows.iter()
.map(|r| {
let rr = crate::row_ref::SqlxRowRef::from_sqlx(r).map_err(|e| {
let msg = e.to_string();
prax_query::QueryError::deserialization(msg).with_source(e)
})?;
T::from_row(&rr).map_err(|e| {
let msg = e.to_string();
prax_query::QueryError::deserialization(msg).with_source(e)
})
})
.collect()
})
}
fn query_one<T: Model + prax_query::row::FromRow + Send + 'static>(
&self,
sql: &str,
params: Vec<FilterValue>,
) -> BoxFuture<'_, QueryResult<T>> {
let sql = sql.to_string();
Box::pin(async move {
debug!(sql = %sql, "Executing query_one via QueryEngine");
let row = self.raw_query_one(&sql, ¶ms).await.map_err(|e| {
let msg = e.to_string();
if msg.contains("no rows") {
prax_query::QueryError::not_found(T::MODEL_NAME)
} else {
prax_query::QueryError::database(msg)
}
})?;
let rr = crate::row_ref::SqlxRowRef::from_sqlx(&row).map_err(|e| {
let msg = e.to_string();
prax_query::QueryError::deserialization(msg).with_source(e)
})?;
T::from_row(&rr).map_err(|e| {
let msg = e.to_string();
prax_query::QueryError::deserialization(msg).with_source(e)
})
})
}
fn query_optional<T: Model + prax_query::row::FromRow + Send + 'static>(
&self,
sql: &str,
params: Vec<FilterValue>,
) -> BoxFuture<'_, QueryResult<Option<T>>> {
let sql = sql.to_string();
Box::pin(async move {
debug!(sql = %sql, "Executing query_optional via QueryEngine");
let row = self
.raw_query_optional(&sql, ¶ms)
.await
.map_err(|e| prax_query::QueryError::database(e.to_string()))?;
match row {
Some(r) => {
let rr = crate::row_ref::SqlxRowRef::from_sqlx(&r).map_err(|e| {
let msg = e.to_string();
prax_query::QueryError::deserialization(msg).with_source(e)
})?;
T::from_row(&rr).map(Some).map_err(|e| {
let msg = e.to_string();
prax_query::QueryError::deserialization(msg).with_source(e)
})
}
None => Ok(None),
}
})
}
fn execute_insert<T: Model + prax_query::row::FromRow + Send + 'static>(
&self,
sql: &str,
params: Vec<FilterValue>,
) -> BoxFuture<'_, QueryResult<T>> {
let sql = sql.to_string();
Box::pin(async move {
debug!(sql = %sql, "Executing execute_insert via QueryEngine");
let row = self
.raw_query_one(&sql, ¶ms)
.await
.map_err(|e| prax_query::QueryError::database(e.to_string()))?;
let rr = crate::row_ref::SqlxRowRef::from_sqlx(&row).map_err(|e| {
let msg = e.to_string();
prax_query::QueryError::deserialization(msg).with_source(e)
})?;
T::from_row(&rr).map_err(|e| {
let msg = e.to_string();
prax_query::QueryError::deserialization(msg).with_source(e)
})
})
}
fn execute_update<T: Model + prax_query::row::FromRow + Send + 'static>(
&self,
sql: &str,
params: Vec<FilterValue>,
) -> BoxFuture<'_, QueryResult<Vec<T>>> {
let sql = sql.to_string();
Box::pin(async move {
debug!(sql = %sql, "Executing execute_update via QueryEngine");
let rows = self
.raw_query_many(&sql, ¶ms)
.await
.map_err(|e| prax_query::QueryError::database(e.to_string()))?;
rows.iter()
.map(|r| {
let rr = crate::row_ref::SqlxRowRef::from_sqlx(r).map_err(|e| {
let msg = e.to_string();
prax_query::QueryError::deserialization(msg).with_source(e)
})?;
T::from_row(&rr).map_err(|e| {
let msg = e.to_string();
prax_query::QueryError::deserialization(msg).with_source(e)
})
})
.collect()
})
}
fn execute_delete(
&self,
sql: &str,
params: Vec<FilterValue>,
) -> BoxFuture<'_, QueryResult<u64>> {
let sql = sql.to_string();
Box::pin(async move {
debug!(sql = %sql, "Executing execute_delete via QueryEngine");
let affected = self
.raw_execute(&sql, ¶ms)
.await
.map_err(|e| prax_query::QueryError::database(e.to_string()))?;
Ok(affected)
})
}
fn execute_raw(&self, sql: &str, params: Vec<FilterValue>) -> BoxFuture<'_, QueryResult<u64>> {
let sql = sql.to_string();
Box::pin(async move {
debug!(sql = %sql, "Executing execute_raw via QueryEngine");
let affected = self
.raw_execute(&sql, ¶ms)
.await
.map_err(|e| prax_query::QueryError::database(e.to_string()))?;
Ok(affected)
})
}
fn count(&self, sql: &str, params: Vec<FilterValue>) -> BoxFuture<'_, QueryResult<u64>> {
let sql = sql.to_string();
Box::pin(async move {
debug!(sql = %sql, "Executing count via QueryEngine");
let row = self
.raw_query_one(&sql, ¶ms)
.await
.map_err(|e| prax_query::QueryError::database(e.to_string()))?;
let count = match row {
#[cfg(feature = "postgres")]
SqlxRow::Postgres(r) => r
.try_get::<i64, _>(0)
.map_err(|e| prax_query::QueryError::database(e.to_string()))?
as u64,
#[cfg(feature = "mysql")]
SqlxRow::MySql(r) => r
.try_get::<i64, _>(0)
.map_err(|e| prax_query::QueryError::database(e.to_string()))?
as u64,
#[cfg(feature = "sqlite")]
SqlxRow::Sqlite(r) => r
.try_get::<i64, _>(0)
.map_err(|e| prax_query::QueryError::database(e.to_string()))?
as u64,
};
Ok(count)
})
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::placeholder;

    /// Postgres uses `$1`-style placeholders; MySQL and SQLite use `?`.
    #[test]
    fn test_placeholder_generation() {
        let expectations = [
            (DatabaseBackend::Postgres, "$1"),
            (DatabaseBackend::MySql, "?"),
            (DatabaseBackend::Sqlite, "?"),
        ];
        for (backend, expected) in expectations {
            assert_eq!(placeholder(backend, 1), expected);
        }
    }

    /// Postgres quotes identifiers with double quotes, MySQL with backticks.
    #[test]
    fn test_quote_identifier() {
        let expectations = [
            (DatabaseBackend::Postgres, "\"users\""),
            (DatabaseBackend::MySql, "`users`"),
        ];
        for (backend, expected) in expectations {
            assert_eq!(quote_identifier(backend, "users"), expected);
        }
    }
}