use crate::client::backend::{
BackendError, BackendResult, BackendType, DatabaseBackend, HealthStatus, QueryLanguage,
QueryResult, TranslatedQuery,
};
use async_trait::async_trait;
use serde_json::{Map, Number, Value};
use sqlx::{Column, PgPool, Row, postgres::PgRow};
/// Database backend that talks to PostgreSQL through an sqlx connection pool.
///
/// Construct it from an existing pool with `PostgresBackend::new`, or let it
/// open one with `PostgresBackend::from_connection_string`.
pub struct PostgresBackend {
/// Pool used for every query and health check issued by this backend.
pool: PgPool,
}
impl PostgresBackend {
pub fn new(pool: PgPool) -> Self {
Self { pool }
}
pub fn pool(&self) -> &PgPool {
&self.pool
}
pub async fn from_connection_string(connection: &str) -> BackendResult<Self> {
let pool = PgPool::connect(connection)
.await
.map_err(|err| BackendError::Generic(err.to_string()))?;
Ok(Self { pool })
}
}
#[async_trait]
impl DatabaseBackend for PostgresBackend {
async fn execute_query(&self, query: TranslatedQuery) -> BackendResult<QueryResult> {
if !matches!(query.language, QueryLanguage::Sql) {
return Err(BackendError::Generic(
"Postgres backend only supports SQL".to_string(),
));
}
let rows = sqlx::query(&query.sql)
.fetch_all(&self.pool)
.await
.map_err(|err| BackendError::Generic(err.to_string()))?;
let mut data = Vec::new();
let column_names = rows
.first()
.map(|row| {
row.columns()
.iter()
.map(|col| col.name().to_string())
.collect()
})
.unwrap_or_default();
for row in &rows {
data.push(row_to_value(row)?);
}
Ok(QueryResult::new(data, column_names))
}
async fn health_check(&self) -> BackendResult<HealthStatus> {
match self.pool.acquire().await {
Ok(_) => Ok(HealthStatus::Healthy),
Err(err) => Err(BackendError::Generic(err.to_string())),
}
}
fn backend_type(&self) -> BackendType {
BackendType::PostgreSQL
}
fn supports_sql(&self) -> bool {
true
}
fn supports_cql(&self) -> bool {
false
}
fn as_any(&self) -> &dyn std::any::Any {
self
}
}
fn row_to_value(row: &PgRow) -> BackendResult<Value> {
let mut object = Map::new();
for column in row.columns() {
let value = match read_column_value(row, column) {
Ok(value) => value,
Err(err) => {
return Err(BackendError::Generic(format!(
"failed to decode column {}: {err}",
column.name()
)));
}
};
object.insert(column.name().to_string(), value);
}
Ok(Value::Object(object))
}
fn read_column_value(row: &PgRow, column: &impl Column) -> BackendResult<Value> {
let name = column.name();
if let Ok(json_value) = row.try_get::<serde_json::Value, _>(name) {
return Ok(json_value);
}
if let Ok(text) = row.try_get::<String, _>(name) {
return Ok(Value::String(text));
}
if let Ok(i) = row.try_get::<i64, _>(name) {
return Ok(Value::Number(Number::from(i)));
}
if let Ok(f) = row.try_get::<f64, _>(name)
&& let Some(number) = Number::from_f64(f)
{
return Ok(Value::Number(number));
}
if let Ok(b) = row.try_get::<bool, _>(name) {
return Ok(Value::Bool(b));
}
if let Ok(bytes) = row.try_get::<Vec<u8>, _>(name) {
return Ok(Value::String(String::from_utf8_lossy(&bytes).to_string()));
}
Ok(Value::String("<binary>".to_string()))
}