use async_trait::async_trait;
use sqlx::postgres::PgPool;
use sqlx::{Column as SqlxColumn, Row};
use crate::core::DatabaseAdapter;
use crate::core::error::{DbError, DbResult};
use crate::core::models::*;
/// PostgreSQL adapter that serves the `DatabaseAdapter` operations in this file
/// through a `sqlx` connection pool.
pub struct PostgresAdapter {
/// Connection pool every query issued by this adapter is executed against.
pool: PgPool,
}
impl PostgresAdapter {
    /// Opens a PostgreSQL connection pool for `connection_string` and wraps it
    /// in a new adapter.
    ///
    /// # Errors
    ///
    /// Returns [`DbError::ConnectionFailed`] carrying the driver's error text
    /// when the pool cannot be established.
    pub async fn connect(connection_string: &str) -> DbResult<Self> {
        match PgPool::connect(connection_string).await {
            Ok(pool) => Ok(Self { pool }),
            Err(cause) => Err(DbError::ConnectionFailed(cause.to_string())),
        }
    }
}
#[async_trait]
impl DatabaseAdapter for PostgresAdapter {
/// Display name of the database engine this adapter talks to.
fn name(&self) -> &str {
    const ENGINE_NAME: &str = "PostgreSQL";
    ENGINE_NAME
}
/// Returns `DatabaseType::PostgreSQL`, identifying the backend behind this adapter.
fn db_type(&self) -> DatabaseType {
DatabaseType::PostgreSQL
}
/// Lists the user-facing schemas of the connected database, ordered by name.
///
/// `information_schema` and every schema whose name starts with `pg_` are left
/// out. PostgreSQL reserves the `pg_` prefix for system schemas, so filtering on
/// the prefix hides `pg_catalog` and `pg_toast` as well as the per-session
/// `pg_temp_N` / `pg_toast_temp_N` schemas that a fixed name list would miss.
///
/// # Errors
///
/// Returns [`DbError::QueryFailed`] with the driver's message when the query fails.
async fn get_schemas(&self) -> DbResult<Vec<Schema>> {
    let rows = sqlx::query(
        "SELECT schema_name FROM information_schema.schemata \
         WHERE schema_name <> 'information_schema' \
         AND left(schema_name::text, 3) <> 'pg_' \
         ORDER BY schema_name",
    )
    .fetch_all(&self.pool)
    .await
    .map_err(|e| DbError::QueryFailed(e.to_string()))?;

    Ok(rows
        .iter()
        .map(|r| Schema {
            name: r.get("schema_name"),
        })
        .collect())
}
/// Lists the base tables (`table_type = 'BASE TABLE'`) of `schema`, ordered by
/// table name.
///
/// Every returned [`Table`] carries the `schema` argument as its schema.
///
/// # Errors
///
/// Returns [`DbError::QueryFailed`] with the driver's message when the query fails.
async fn get_tables(&self, schema: &str) -> DbResult<Vec<Table>> {
    const LIST_TABLES_SQL: &str = "SELECT table_name FROM information_schema.tables \
        WHERE table_schema = $1 AND table_type = 'BASE TABLE' \
        ORDER BY table_name";

    let fetched = sqlx::query(LIST_TABLES_SQL)
        .bind(schema)
        .fetch_all(&self.pool)
        .await
        .map_err(|err| DbError::QueryFailed(err.to_string()))?;

    let mut tables = Vec::with_capacity(fetched.len());
    for record in &fetched {
        tables.push(Table {
            name: record.get("table_name"),
            schema: schema.to_string(),
        });
    }
    Ok(tables)
}
/// Lists the views that `information_schema.views` reports for `schema`,
/// ordered by view name.
///
/// Every returned [`View`] is marked `valid: true`; no validity check is made.
///
/// # Errors
///
/// Returns [`DbError::QueryFailed`] with the driver's message when the query fails.
async fn get_views(&self, schema: &str) -> DbResult<Vec<View>> {
    const LIST_VIEWS_SQL: &str = "SELECT table_name FROM information_schema.views \
        WHERE table_schema = $1 \
        ORDER BY table_name";

    let fetched = sqlx::query(LIST_VIEWS_SQL)
        .bind(schema)
        .fetch_all(&self.pool)
        .await
        .map_err(|err| DbError::QueryFailed(err.to_string()))?;

    let mut views = Vec::with_capacity(fetched.len());
    for record in &fetched {
        let view = View {
            name: record.get("table_name"),
            schema: schema.to_string(),
            valid: true,
        };
        views.push(view);
    }
    Ok(views)
}
/// Lists the routines of `schema` whose `routine_type` is `PROCEDURE`, ordered
/// by routine name.
///
/// Every returned [`Procedure`] is marked `valid: true`; no validity check is made.
///
/// # Errors
///
/// Returns [`DbError::QueryFailed`] with the driver's message when the query fails.
async fn get_procedures(&self, schema: &str) -> DbResult<Vec<Procedure>> {
    let statement = sqlx::query(
        "SELECT routine_name FROM information_schema.routines \
         WHERE routine_schema = $1 AND routine_type = 'PROCEDURE' \
         ORDER BY routine_name",
    )
    .bind(schema);

    let fetched = statement
        .fetch_all(&self.pool)
        .await
        .map_err(|err| DbError::QueryFailed(err.to_string()))?;

    let procedures = fetched
        .into_iter()
        .map(|record| Procedure {
            name: record.get("routine_name"),
            schema: schema.to_string(),
            valid: true,
        })
        .collect();
    Ok(procedures)
}
/// Lists the routines of `schema` whose `routine_type` is `FUNCTION`, ordered
/// by routine name.
///
/// Every returned [`Function`] is marked `valid: true`; no validity check is made.
///
/// # Errors
///
/// Returns [`DbError::QueryFailed`] with the driver's message when the query fails.
async fn get_functions(&self, schema: &str) -> DbResult<Vec<Function>> {
    const LIST_FUNCTIONS_SQL: &str = "SELECT routine_name FROM information_schema.routines \
        WHERE routine_schema = $1 AND routine_type = 'FUNCTION' \
        ORDER BY routine_name";

    let fetched = sqlx::query(LIST_FUNCTIONS_SQL)
        .bind(schema)
        .fetch_all(&self.pool)
        .await
        .map_err(|err| DbError::QueryFailed(err.to_string()))?;

    let mut functions = Vec::with_capacity(fetched.len());
    functions.extend(fetched.iter().map(|record| Function {
        name: record.get("routine_name"),
        schema: schema.to_string(),
        valid: true,
    }));
    Ok(functions)
}
/// Describes the columns of `schema`.`table` in ordinal order.
///
/// Each [`Column`] carries the column name, its `information_schema` data type,
/// whether it is nullable (`is_nullable = 'YES'`) and whether it is part of the
/// table's primary key.
///
/// Primary-key membership is computed with a correlated `EXISTS` rather than by
/// joining the key views into the column list. A join yields one row per key
/// constraint the column takes part in, so a column that is both a primary key
/// and a foreign key would be returned twice with conflicting `is_pk` values.
/// The sub-query also matches on table name, because PostgreSQL constraint
/// names are only unique per table, not per schema.
///
/// # Errors
///
/// Returns [`DbError::QueryFailed`] with the driver's message when the query fails.
async fn get_columns(&self, schema: &str, table: &str) -> DbResult<Vec<Column>> {
    let rows = sqlx::query(
        "SELECT c.column_name, c.data_type, c.is_nullable, \
         EXISTS ( \
         SELECT 1 \
         FROM information_schema.table_constraints tc \
         JOIN information_schema.key_column_usage kcu \
         ON kcu.constraint_schema = tc.constraint_schema \
         AND kcu.constraint_name = tc.constraint_name \
         AND kcu.table_schema = tc.table_schema \
         AND kcu.table_name = tc.table_name \
         WHERE tc.constraint_type = 'PRIMARY KEY' \
         AND tc.table_schema = c.table_schema \
         AND tc.table_name = c.table_name \
         AND kcu.column_name = c.column_name \
         ) AS is_pk \
         FROM information_schema.columns c \
         WHERE c.table_schema = $1 AND c.table_name = $2 \
         ORDER BY c.ordinal_position",
    )
    .bind(schema)
    .bind(table)
    .fetch_all(&self.pool)
    .await
    .map_err(|e| DbError::QueryFailed(e.to_string()))?;

    Ok(rows
        .iter()
        .map(|r| {
            // information_schema reports nullability as the text 'YES' / 'NO'.
            let nullable_str: String = r.get("is_nullable");
            Column {
                name: r.get("column_name"),
                data_type: r.get("data_type"),
                nullable: nullable_str == "YES",
                is_primary_key: r.get::<bool, _>("is_pk"),
            }
        })
        .collect())
}
/// Runs `query` against the pool and renders every returned cell as text.
///
/// Column names are taken from the first returned row, so a statement that
/// produces no rows yields a [`QueryResult`] with empty `columns` and `rows`.
/// `elapsed` is always `None`; the call is not timed.
///
/// Cell rendering:
/// * SQL NULL is rendered as `"NULL"`.
/// * Text, 16/32/64-bit integers, 32/64-bit floats and booleans are rendered
///   with their `to_string()` form. sqlx matches Rust types to SQL types
///   exactly, so every integer and float width is tried explicitly.
/// * A non-null value of any other type is rendered as `<TYPE>` (the SQL type
///   name in angle brackets) so it cannot be mistaken for a NULL.
///
/// # Errors
///
/// Returns [`DbError::QueryFailed`] with the driver's message when the query fails.
async fn execute(&self, query: &str) -> DbResult<QueryResult> {
    let rows = sqlx::query(query)
        .fetch_all(&self.pool)
        .await
        .map_err(|e| DbError::QueryFailed(e.to_string()))?;

    // Column metadata is only reachable through a row here, so without rows
    // there is nothing to report.
    let first = match rows.first() {
        Some(row) => row,
        None => {
            return Ok(QueryResult {
                columns: vec![],
                rows: vec![],
                elapsed: None,
            });
        }
    };

    let columns: Vec<String> = first
        .columns()
        .iter()
        .map(|c| c.name().to_string())
        .collect();

    let data: Vec<Vec<String>> = rows
        .iter()
        .map(|row| {
            (0..columns.len())
                .map(|i| {
                    // Decide NULL-ness from the raw value so that a failed
                    // decode below is never confused with a missing value.
                    let raw = match row.try_get_raw(i) {
                        Ok(raw) => raw,
                        Err(_) => return "NULL".to_string(),
                    };
                    if sqlx::ValueRef::is_null(&raw) {
                        return "NULL".to_string();
                    }

                    row.try_get::<String, _>(i)
                        .or_else(|_| row.try_get::<i64, _>(i).map(|v| v.to_string()))
                        .or_else(|_| row.try_get::<i32, _>(i).map(|v| v.to_string()))
                        .or_else(|_| row.try_get::<i16, _>(i).map(|v| v.to_string()))
                        .or_else(|_| row.try_get::<f64, _>(i).map(|v| v.to_string()))
                        .or_else(|_| row.try_get::<f32, _>(i).map(|v| v.to_string()))
                        .or_else(|_| row.try_get::<bool, _>(i).map(|v| v.to_string()))
                        // Non-null but not decodable with the types above:
                        // show the SQL type name instead of a misleading NULL.
                        .unwrap_or_else(|_| format!("<{}>", sqlx::ValueRef::type_info(&raw)))
                })
                .collect()
        })
        .collect();

    Ok(QueryResult {
        columns,
        rows: data,
        elapsed: None,
    })
}
}