// rowdy-db 0.8.2
//
// A fast, modern, and rowdy TUI database management tool written in Rust.
use async_trait::async_trait;
use sqlx::{
    postgres::{PgPool, PgRow},
    Column as SqlxColumn, Row as SqlxRow, TypeInfo, ValueRef,
};
use crate::db::error::DbError;
use crate::db::traits::SqlClient;
use crate::db::types::{Column, ColumnSchema, DbQueryResult, ForeignKey, Row, TableKind, TableObject, Value};

/// PostgreSQL implementation of [`SqlClient`], backed by an optional sqlx pool.
///
/// The pool is `None` until `connect` succeeds and again after `disconnect`.
/// While it is `None`, every query method fails with [`DbError::NotConnected`].
#[derive(Default)]
pub struct PostgresConnector {
    pool: Option<PgPool>,
}

impl PostgresConnector {
    /// Creates a connector with no active pool; call `connect` before querying.
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns the active pool, or [`DbError::NotConnected`] if there is none.
    fn pool(&self) -> Result<&PgPool, DbError> {
        self.pool.as_ref().ok_or(DbError::NotConnected)
    }
}

#[async_trait]
impl SqlClient for PostgresConnector {
    /// Opens a pool for `url` and makes it the active pool.
    ///
    /// If another pool was active, it is closed after the new one connects.
    /// On failure, the previous pool (if any) is left in place.
    ///
    /// # Errors
    /// [`DbError::ConnectionFailed`] when sqlx cannot establish the pool.
    async fn connect(&mut self, url: &str) -> Result<(), DbError> {
        let pool = PgPool::connect(url)
            .await
            .map_err(|e| DbError::ConnectionFailed(e.to_string()))?;
        if let Some(previous) = self.pool.replace(pool) {
            // Close the replaced pool explicitly, the same way `disconnect` does.
            previous.close().await;
        }
        Ok(())
    }

    /// Closes and forgets the active pool. This is a no-op when not connected.
    async fn disconnect(&mut self) -> Result<(), DbError> {
        if let Some(pool) = self.pool.take() {
            pool.close().await;
        }
        Ok(())
    }

    /// Runs a statement and returns the server-reported number of affected rows.
    ///
    /// # Errors
    /// [`DbError::NotConnected`] without a pool; [`DbError::QueryFailed`] on SQL errors.
    async fn execute(&self, query: &str) -> Result<u64, DbError> {
        let result = sqlx::query(query)
            .execute(self.pool()?)
            .await
            .map_err(|e| DbError::QueryFailed(e.to_string()))?;
        Ok(result.rows_affected())
    }

    /// Runs a query and converts every row into display-oriented [`Value`]s.
    ///
    /// Column metadata is read from the first row, so an empty result has no
    /// columns. `rows_affected` is set to the number of rows returned.
    ///
    /// # Errors
    /// [`DbError::NotConnected`] without a pool; [`DbError::QueryFailed`] on SQL errors.
    async fn fetch_all(&self, query: &str) -> Result<DbQueryResult, DbError> {
        let rows: Vec<PgRow> = sqlx::query(query)
            .fetch_all(self.pool()?)
            .await
            .map_err(|e| DbError::QueryFailed(e.to_string()))?;

        if rows.is_empty() {
            return Ok(DbQueryResult {
                columns: vec![],
                rows: vec![],
                rows_affected: 0,
            });
        }

        let columns: Vec<Column> = rows[0]
            .columns()
            .iter()
            .map(|c| Column {
                name: c.name().to_string(),
                type_name: c.type_info().name().to_string(),
            })
            .collect();

        let mapped_rows: Vec<Row> = rows
            .iter()
            .map(|r| Row {
                values: (0..r.len()).map(|i| pg_value(r, i)).collect(),
            })
            .collect();

        let count = mapped_rows.len() as u64;
        Ok(DbQueryResult {
            columns,
            rows: mapped_rows,
            rows_affected: count,
        })
    }

    /// Lists base tables in the `public` schema, sorted by name.
    ///
    /// A name that fails to decode is returned as an empty string.
    async fn get_tables(&self) -> Result<Vec<String>, DbError> {
        let rows: Vec<PgRow> = sqlx::query(
            "SELECT table_name FROM information_schema.tables \
             WHERE table_schema = 'public' AND table_type = 'BASE TABLE' \
             ORDER BY table_name",
        )
        .fetch_all(self.pool()?)
        .await
        .map_err(|e| DbError::QueryFailed(e.to_string()))?;

        Ok(rows
            .iter()
            .map(|r| r.try_get::<String, _>(0).unwrap_or_default())
            .collect())
    }

    /// Lists base tables and views in the `public` schema, sorted by name.
    ///
    /// `table_type = 'VIEW'` maps to [`TableKind::View`]; everything else is a table.
    async fn get_table_objects(&self) -> Result<Vec<TableObject>, DbError> {
        let rows: Vec<PgRow> = sqlx::query(
            "SELECT table_name, table_type FROM information_schema.tables \
             WHERE table_schema = 'public' AND table_type IN ('BASE TABLE', 'VIEW') \
             ORDER BY table_name",
        )
        .fetch_all(self.pool()?)
        .await
        .map_err(|e| DbError::QueryFailed(e.to_string()))?;

        Ok(rows.iter().map(|r| {
            let name = r.try_get::<String, _>(0).unwrap_or_default();
            let type_str = r.try_get::<String, _>(1).unwrap_or_default();
            let kind = if type_str == "VIEW" { TableKind::View } else { TableKind::Table };
            TableObject { name, kind }
        }).collect())
    }

    /// Describes the live columns of `public.<table>` in attribute order.
    ///
    /// For each column it reports the formatted type, nullability, whether the
    /// column is part of the primary key, and the first foreign key it belongs to.
    ///
    /// # Errors
    /// [`DbError::NotConnected`] without a pool; [`DbError::QueryFailed`] on SQL errors.
    async fn get_schema(&self, table: &str) -> Result<Vec<ColumnSchema>, DbError> {
        // Correlated subqueries avoid JOIN fan-out and the array_position()
        // type-resolution edge-cases that silently broke an earlier query.
        //
        // The table name is sent as a bind parameter rather than spliced into
        // the SQL text, so names containing quotes are matched exactly and can
        // never alter the statement.
        //
        // conkey/confkey are parallel arrays. unnest() pairs each local column
        // with the referenced column at the same position, which is correct for
        // composite keys. Both FK subqueries order by constraint OID so the table
        // and column come from the same constraint when a column belongs to
        // several foreign keys.
        const QUERY: &str = r#"SELECT
                a.attname::text AS column_name,
                pg_catalog.format_type(a.atttypid, a.atttypmod) AS data_type,
                (NOT a.attnotnull) AS is_nullable,
                EXISTS(
                    SELECT 1 FROM pg_constraint c
                    WHERE c.conrelid = a.attrelid
                      AND c.contype  = 'p'
                      AND a.attnum   = ANY(c.conkey)
                ) AS is_pk,
                (
                    SELECT rc.relname::text
                    FROM pg_constraint c
                    JOIN pg_class rc ON rc.oid = c.confrelid
                    WHERE c.conrelid = a.attrelid
                      AND c.contype  = 'f'
                      AND a.attnum   = ANY(c.conkey)
                    ORDER BY c.oid
                    LIMIT 1
                ) AS foreign_table_name,
                (
                    SELECT ra.attname::text
                    FROM pg_constraint c
                    CROSS JOIN LATERAL unnest(c.conkey, c.confkey) AS k(src_attnum, dst_attnum)
                    JOIN pg_attribute ra ON ra.attrelid = c.confrelid
                                        AND ra.attnum   = k.dst_attnum
                    WHERE c.conrelid   = a.attrelid
                      AND c.contype    = 'f'
                      AND k.src_attnum = a.attnum
                    ORDER BY c.oid
                    LIMIT 1
                ) AS foreign_column_name
            FROM pg_attribute a
            JOIN pg_class     cl ON cl.oid = a.attrelid
            JOIN pg_namespace n  ON n.oid  = cl.relnamespace
            WHERE n.nspname  = 'public'
              AND cl.relname = $1::name
              AND a.attnum   > 0
              AND NOT a.attisdropped
            ORDER BY a.attnum"#;

        let pool = self.pool()?;
        let rows: Vec<PgRow> = sqlx::query(QUERY)
            .bind(table)
            .fetch_all(pool)
            .await
            .map_err(|e| DbError::QueryFailed(e.to_string()))?;

        let mut schema = Vec::with_capacity(rows.len());
        for row in &rows {
            let name: String = row.try_get("column_name").unwrap_or_default();
            let type_name: String = row.try_get("data_type").unwrap_or_default();
            let is_nullable: bool = row.try_get("is_nullable").unwrap_or(true);
            let is_pk: bool = row.try_get("is_pk").unwrap_or(false);
            let fk_table: Option<String> = row.try_get("foreign_table_name").unwrap_or(None);
            let fk_col: Option<String> = row.try_get("foreign_column_name").unwrap_or(None);
            let fk = match (fk_table, fk_col) {
                (Some(t), Some(c)) if !t.is_empty() => Some(ForeignKey { table: t, column: c }),
                _ => None,
            };
            schema.push(ColumnSchema { name, type_name, is_pk, is_nullable, fk });
        }
        Ok(schema)
    }
}

fn pg_value(row: &PgRow, index: usize) -> Value {
    use chrono::{DateTime, NaiveDate, NaiveDateTime, NaiveTime, Utc};
    let raw = row.try_get_raw(index).unwrap();
    if raw.is_null() {
        return Value::Null;
    }
    let tn = raw.type_info().name().to_string();
    let marker = || Value::Text(format!("<?{tn}>"));
    match tn.as_str() {
        "BOOL"           => row.try_get::<bool, _>(index).map(Value::Bool).unwrap_or_else(|_| marker()),
        // Each integer width requires its own Rust type in sqlx
        "INT2"           => row.try_get::<i16, _>(index).map(|v| Value::Int(v as i64)).unwrap_or_else(|_| marker()),
        "INT4"           => row.try_get::<i32, _>(index).map(|v| Value::Int(v as i64)).unwrap_or_else(|_| marker()),
        "INT8"           => row.try_get::<i64, _>(index).map(Value::Int).unwrap_or_else(|_| marker()),
        "FLOAT4"         => row.try_get::<f32, _>(index).map(|v| Value::Float(v as f64)).unwrap_or_else(|_| marker()),
        "FLOAT8"         => row.try_get::<f64, _>(index).map(Value::Float).unwrap_or_else(|_| marker()),
        "NUMERIC"        => row.try_get::<bigdecimal::BigDecimal, _>(index)
                               .map(|d| Value::Text(crate::db::types::format_decimal(d)))
                               .unwrap_or_else(|_| marker()),
        "BYTEA"          => row.try_get::<Vec<u8>, _>(index).map(Value::Bytes).unwrap_or_else(|_| marker()),
        // Dates and times
        "DATE"           => row.try_get::<NaiveDate, _>(index).map(|d| Value::Text(d.to_string())).unwrap_or_else(|_| marker()),
        "TIME"           => row.try_get::<NaiveTime, _>(index).map(|t| Value::Text(t.to_string())).unwrap_or_else(|_| marker()),
        "TIMESTAMP"      => row.try_get::<NaiveDateTime, _>(index).map(|dt| Value::Text(dt.to_string())).unwrap_or_else(|_| marker()),
        "TIMESTAMPTZ"    => row.try_get::<DateTime<Utc>, _>(index).map(|dt| Value::Text(dt.to_rfc3339())).unwrap_or_else(|_| marker()),
        // UUID
        "UUID"           => row.try_get::<uuid::Uuid, _>(index).map(|u| Value::Text(u.to_string())).unwrap_or_else(|_| marker()),
        // JSON / JSONB — decoded via serde_json then serialised back to a compact string
        "JSON" | "JSONB" => row.try_get::<serde_json::Value, _>(index).map(|v| Value::Text(v.to_string())).unwrap_or_else(|_| marker()),
        // Arrays: OID names start with '_', e.g. _TEXT, _INT4, _UUID
        // Try Vec<String> first (text arrays), then Vec<i64> (int arrays), else marker
        s if s.starts_with('_') => {
            if let Ok(v) = row.try_get::<Vec<String>, _>(index) {
                return Value::Text(format!("[{}]", v.join(", ")));
            }
            if let Ok(v) = row.try_get::<Vec<i64>, _>(index) {
                return Value::Text(format!("[{}]", v.iter().map(|n| n.to_string()).collect::<Vec<_>>().join(", ")));
            }
            if let Ok(v) = row.try_get::<Vec<bool>, _>(index) {
                return Value::Text(format!("[{}]", v.iter().map(|b| b.to_string()).collect::<Vec<_>>().join(", ")));
            }
            marker()
        }
        // Text-like types (TEXT, VARCHAR, CHAR, BPCHAR, NAME, XML, INTERVAL, INET, CIDR, MACADDR…)
        _ => row.try_get::<String, _>(index).map(Value::Text).unwrap_or_else(|_| marker()),
    }
}