axum_sql_viewer/database/postgres.rs

1//! PostgreSQL database provider implementation
2
3use crate::database::traits::{DatabaseError, DatabaseProvider};
4use crate::schema::{
5    ColumnInfo, CountResponse, ForeignKey, IndexInfo, QueryResult, RowQuery, RowsResponse,
6    SortOrder, TableInfo, TableSchema,
7};
8use async_trait::async_trait;
9use sqlx::{postgres::PgRow, Column, PgPool, Row, TypeInfo};
10use std::collections::HashMap;
11
12/// PostgreSQL database provider
13pub struct PostgresProvider {
14    pool: PgPool,
15}
16
17impl PostgresProvider {
18    /// Create a new PostgreSQL provider
19    ///
20    /// # Arguments
21    ///
22    /// * `pool` - PostgreSQL connection pool
23    pub fn new(pool: PgPool) -> Self {
24        Self { pool }
25    }
26
27    /// Quote an identifier to prevent SQL injection
28    fn quote_identifier(identifier: &str) -> String {
29        format!("\"{}\"", identifier.replace("\"", "\"\""))
30    }
31
32    /// Convert a PostgreSQL row to a JSON object
33    fn row_to_json(row: &PgRow) -> Result<serde_json::Value, DatabaseError> {
34        let mut map = serde_json::Map::new();
35
36        for column in row.columns() {
37            let column_name = column.name();
38            let type_info = column.type_info();
39            let type_name = type_info.name();
40
41            let value: serde_json::Value = match type_name {
42                "BOOL" => {
43                    let val: Option<bool> = row.try_get(column_name)?;
44                    val.map(serde_json::Value::Bool).unwrap_or(serde_json::Value::Null)
45                }
46                "INT2" | "SMALLINT" | "SMALLSERIAL" => {
47                    let val: Option<i16> = row.try_get(column_name)?;
48                    val.map(|v| serde_json::Value::Number(v.into())).unwrap_or(serde_json::Value::Null)
49                }
50                "INT4" | "INT" | "INTEGER" | "SERIAL" => {
51                    let val: Option<i32> = row.try_get(column_name)?;
52                    val.map(|v| serde_json::Value::Number(v.into())).unwrap_or(serde_json::Value::Null)
53                }
54                "INT8" | "BIGINT" | "BIGSERIAL" => {
55                    let val: Option<i64> = row.try_get(column_name)?;
56                    val.map(|v| serde_json::Value::Number(v.into())).unwrap_or(serde_json::Value::Null)
57                }
58                "FLOAT4" | "REAL" => {
59                    let val: Option<f32> = row.try_get(column_name)?;
60                    val.and_then(|v| serde_json::Number::from_f64(v as f64))
61                        .map(serde_json::Value::Number)
62                        .unwrap_or(serde_json::Value::Null)
63                }
64                "FLOAT8" | "DOUBLE PRECISION" => {
65                    let val: Option<f64> = row.try_get(column_name)?;
66                    val.and_then(serde_json::Number::from_f64)
67                        .map(serde_json::Value::Number)
68                        .unwrap_or(serde_json::Value::Null)
69                }
70                "TEXT" | "VARCHAR" | "CHAR" | "NAME" | "BPCHAR" => {
71                    let val: Option<String> = row.try_get(column_name)?;
72                    val.map(serde_json::Value::String).unwrap_or(serde_json::Value::Null)
73                }
74                "BYTEA" => {
75                    let val: Option<Vec<u8>> = row.try_get(column_name)?;
76                    val.map(|bytes| {
77                        serde_json::Value::String(format!("[BLOB: {} bytes]", bytes.len()))
78                    }).unwrap_or(serde_json::Value::Null)
79                }
80                "TIMESTAMP" | "TIMESTAMPTZ" | "TIMESTAMP WITHOUT TIME ZONE" | "TIMESTAMP WITH TIME ZONE"
81                | "DATE" | "TIME" | "TIME WITHOUT TIME ZONE" => {
82                    // Try to get as string representation
83                    let val: Option<String> = row.try_get(column_name).ok().flatten();
84                    val.map(serde_json::Value::String)
85                        .unwrap_or(serde_json::Value::Null)
86                }
87                "JSON" | "JSONB" => {
88                    let val: Option<serde_json::Value> = row.try_get(column_name)?;
89                    val.unwrap_or(serde_json::Value::Null)
90                }
91                "UUID" => {
92                    // Try to get as string representation
93                    let val: Option<String> = row.try_get(column_name).ok().flatten();
94                    val.map(serde_json::Value::String)
95                        .unwrap_or(serde_json::Value::Null)
96                }
97                "NUMERIC" | "DECIMAL" => {
98                    // Try to get as string to preserve precision
99                    let val: Option<String> = row.try_get(column_name).ok().flatten();
100                    val.map(serde_json::Value::String)
101                        .unwrap_or(serde_json::Value::Null)
102                }
103                _ => {
104                    // Fallback: try to get as string
105                    let val: Option<String> = row.try_get(column_name).ok().flatten();
106                    val.map(serde_json::Value::String).unwrap_or(serde_json::Value::Null)
107                }
108            };
109
110            map.insert(column_name.to_string(), value);
111        }
112
113        Ok(serde_json::Value::Object(map))
114    }
115
116    /// Build a WHERE clause from filters
117    fn build_where_clause(filters: &HashMap<String, String>, parameter_offset: i32) -> (String, Vec<String>) {
118        if filters.is_empty() {
119            return (String::new(), vec![]);
120        }
121
122        let mut conditions = Vec::new();
123        let mut values = Vec::new();
124        let mut param_index = parameter_offset;
125
126        for (column, filter_value) in filters {
127            let quoted_column = Self::quote_identifier(column);
128
129            if filter_value.contains('%') {
130                conditions.push(format!("{} LIKE ${}", quoted_column, param_index));
131            } else {
132                conditions.push(format!("{} = ${}", quoted_column, param_index));
133            }
134
135            values.push(filter_value.clone());
136            param_index += 1;
137        }
138
139        let where_clause = format!(" WHERE {}", conditions.join(" AND "));
140        (where_clause, values)
141    }
142}
143
144#[async_trait]
145impl DatabaseProvider for PostgresProvider {
146    async fn list_tables(&self) -> Result<Vec<TableInfo>, DatabaseError> {
147        let query = r#"
148            SELECT table_name
149            FROM information_schema.tables
150            WHERE table_schema = 'public'
151              AND table_type = 'BASE TABLE'
152            ORDER BY table_name
153        "#;
154
155        let rows = sqlx::query(query)
156            .fetch_all(&self.pool)
157            .await?;
158
159        let mut tables = Vec::new();
160        for row in rows {
161            let name: String = row.try_get("table_name")?;
162
163            // Get row count for each table
164            let count_query = format!(
165                "SELECT COUNT(*) as count FROM {}",
166                Self::quote_identifier(&name)
167            );
168            let row_count: Option<u64> = sqlx::query_scalar(&count_query)
169                .fetch_one(&self.pool)
170                .await
171                .ok()
172                .map(|count: i64| count as u64);
173
174            tables.push(TableInfo { name, row_count });
175        }
176
177        Ok(tables)
178    }
179
180    async fn get_table_schema(&self, table: &str) -> Result<TableSchema, DatabaseError> {
181        // Get column information
182        let column_query = r#"
183            SELECT
184                column_name,
185                data_type,
186                is_nullable,
187                column_default,
188                udt_name
189            FROM information_schema.columns
190            WHERE table_schema = 'public'
191              AND table_name = $1
192            ORDER BY ordinal_position
193        "#;
194
195        let column_rows = sqlx::query(column_query)
196            .bind(table)
197            .fetch_all(&self.pool)
198            .await?;
199
200        if column_rows.is_empty() {
201            return Err(DatabaseError::TableNotFound(table.to_string()));
202        }
203
204        // Get primary key columns
205        let pk_query = r#"
206            SELECT kcu.column_name
207            FROM information_schema.table_constraints tc
208            JOIN information_schema.key_column_usage kcu
209              ON tc.constraint_name = kcu.constraint_name
210              AND tc.table_schema = kcu.table_schema
211            WHERE tc.table_schema = 'public'
212              AND tc.table_name = $1
213              AND tc.constraint_type = 'PRIMARY KEY'
214            ORDER BY kcu.ordinal_position
215        "#;
216
217        let pk_rows = sqlx::query(pk_query)
218            .bind(table)
219            .fetch_all(&self.pool)
220            .await?;
221
222        let primary_key_columns: Vec<String> = pk_rows
223            .iter()
224            .map(|row| row.try_get::<String, _>("column_name"))
225            .collect::<Result<Vec<_>, _>>()?;
226
227        let primary_key = if primary_key_columns.is_empty() {
228            None
229        } else {
230            Some(primary_key_columns.clone())
231        };
232
233        // Get foreign keys
234        let fk_query = r#"
235            SELECT
236                kcu.column_name,
237                ccu.table_name AS references_table,
238                ccu.column_name AS references_column
239            FROM information_schema.table_constraints tc
240            JOIN information_schema.key_column_usage kcu
241              ON tc.constraint_name = kcu.constraint_name
242              AND tc.table_schema = kcu.table_schema
243            JOIN information_schema.constraint_column_usage ccu
244              ON ccu.constraint_name = tc.constraint_name
245              AND ccu.table_schema = tc.table_schema
246            WHERE tc.table_schema = 'public'
247              AND tc.table_name = $1
248              AND tc.constraint_type = 'FOREIGN KEY'
249        "#;
250
251        let fk_rows = sqlx::query(fk_query)
252            .bind(table)
253            .fetch_all(&self.pool)
254            .await?;
255
256        let foreign_keys: Vec<ForeignKey> = fk_rows
257            .iter()
258            .map(|row| {
259                Ok(ForeignKey {
260                    column: row.try_get("column_name")?,
261                    references_table: row.try_get("references_table")?,
262                    references_column: row.try_get("references_column")?,
263                })
264            })
265            .collect::<Result<Vec<_>, sqlx::Error>>()?;
266
267        // Get indexes
268        let index_query = r#"
269            SELECT
270                i.indexname AS index_name,
271                i.indexdef AS index_definition
272            FROM pg_indexes i
273            WHERE i.schemaname = 'public'
274              AND i.tablename = $1
275              AND i.indexname NOT IN (
276                SELECT constraint_name
277                FROM information_schema.table_constraints
278                WHERE table_schema = 'public'
279                  AND table_name = $1
280                  AND constraint_type = 'PRIMARY KEY'
281              )
282        "#;
283
284        let index_rows = sqlx::query(index_query)
285            .bind(table)
286            .fetch_all(&self.pool)
287            .await?;
288
289        let indexes: Vec<IndexInfo> = index_rows
290            .iter()
291            .map(|row| {
292                let index_name: String = row.try_get("index_name")?;
293                let index_definition: String = row.try_get("index_definition")?;
294
295                // Parse column names from index definition (simplified)
296                // This is a basic implementation - could be enhanced
297                let columns = vec![]; // Would need proper parsing of index_definition
298
299                let unique = index_definition.to_uppercase().contains("UNIQUE");
300
301                Ok(IndexInfo {
302                    name: index_name,
303                    columns,
304                    unique,
305                })
306            })
307            .collect::<Result<Vec<_>, sqlx::Error>>()?;
308
309        // Build column info
310        let columns: Vec<ColumnInfo> = column_rows
311            .iter()
312            .map(|row| {
313                let column_name: String = row.try_get("column_name")?;
314                let data_type: String = row.try_get("data_type")?;
315                let is_nullable: String = row.try_get("is_nullable")?;
316                let column_default: Option<String> = row.try_get("column_default")?;
317
318                Ok(ColumnInfo {
319                    name: column_name.clone(),
320                    data_type,
321                    nullable: is_nullable == "YES",
322                    default_value: column_default,
323                    is_primary_key: primary_key_columns.contains(&column_name),
324                })
325            })
326            .collect::<Result<Vec<_>, sqlx::Error>>()?;
327
328        Ok(TableSchema {
329            name: table.to_string(),
330            columns,
331            primary_key,
332            foreign_keys,
333            indexes,
334        })
335    }
336
337    async fn get_rows(&self, table: &str, query: RowQuery) -> Result<RowsResponse, DatabaseError> {
338        // Validate table exists and get columns
339        let schema = self.get_table_schema(table).await?;
340        let column_names: Vec<String> = schema.columns.iter().map(|c| c.name.clone()).collect();
341
342        // Build base query
343        let quoted_table = Self::quote_identifier(table);
344        let mut sql = format!("SELECT * FROM {}", quoted_table);
345
346        // Add WHERE clause for filters
347        let (where_clause, filter_values) = Self::build_where_clause(&query.filters, 1);
348        sql.push_str(&where_clause);
349
350        // Add ORDER BY clause
351        if let Some(sort_column) = &query.sort_by {
352            // Validate sort column exists
353            if !column_names.contains(sort_column) {
354                return Err(DatabaseError::InvalidColumn(sort_column.clone()));
355            }
356
357            let quoted_sort = Self::quote_identifier(sort_column);
358            let sort_direction = match query.sort_order {
359                Some(SortOrder::Descending) => "DESC",
360                _ => "ASC",
361            };
362            sql.push_str(&format!(" ORDER BY {} {}", quoted_sort, sort_direction));
363        }
364
365        // Add LIMIT and OFFSET
366        let limit = query.limit.min(500); // Cap at 500 as per spec
367        sql.push_str(&format!(" LIMIT {} OFFSET {}", limit, query.offset));
368
369        // Execute query
370        let mut query_builder = sqlx::query(&sql);
371        for value in &filter_values {
372            query_builder = query_builder.bind(value);
373        }
374
375        let rows = query_builder.fetch_all(&self.pool).await?;
376
377        // Convert rows to JSON
378        let json_rows: Vec<serde_json::Value> = rows
379            .iter()
380            .map(Self::row_to_json)
381            .collect::<Result<Vec<_>, _>>()?;
382
383        // Get total count
384        let count_result = self.count_rows(table, &query).await?;
385        let total = count_result.count;
386
387        let has_more = query.offset + (json_rows.len() as u64) < total;
388
389        Ok(RowsResponse {
390            rows: json_rows,
391            columns: column_names,
392            total,
393            offset: query.offset,
394            limit,
395            has_more,
396        })
397    }
398
399    async fn count_rows(&self, table: &str, query: &RowQuery) -> Result<CountResponse, DatabaseError> {
400        let quoted_table = Self::quote_identifier(table);
401        let mut sql = format!("SELECT COUNT(*) as count FROM {}", quoted_table);
402
403        // Add WHERE clause for filters
404        let (where_clause, filter_values) = Self::build_where_clause(&query.filters, 1);
405        sql.push_str(&where_clause);
406
407        // Execute query
408        let mut query_builder = sqlx::query(&sql);
409        for value in &filter_values {
410            query_builder = query_builder.bind(value);
411        }
412
413        let row = query_builder.fetch_one(&self.pool).await?;
414        let count: i64 = row.try_get("count")?;
415
416        Ok(CountResponse {
417            count: count as u64,
418        })
419    }
420
421    async fn execute_query(&self, sql: &str) -> Result<QueryResult, DatabaseError> {
422        let start_time = std::time::Instant::now();
423
424        // Try to execute as a query that returns rows (SELECT)
425        let result = sqlx::query(sql).fetch_all(&self.pool).await;
426
427        let execution_time_milliseconds = start_time.elapsed().as_millis() as u64;
428
429        match result {
430            Ok(rows) => {
431                if rows.is_empty() {
432                    // Could be a DML query (INSERT/UPDATE/DELETE) or SELECT with no results
433                    // Try to get affected rows count
434                    Ok(QueryResult {
435                        columns: vec![],
436                        rows: vec![],
437                        affected_rows: 0,
438                        execution_time_milliseconds,
439                        error: None,
440                    })
441                } else {
442                    // SELECT query with results
443                    let columns: Vec<String> = rows[0]
444                        .columns()
445                        .iter()
446                        .map(|col| col.name().to_string())
447                        .collect();
448
449                    let json_rows: Vec<serde_json::Value> = rows
450                        .iter()
451                        .map(Self::row_to_json)
452                        .collect::<Result<Vec<_>, _>>()?;
453
454                    // Apply row limit
455                    let max_rows = 10000;
456                    if json_rows.len() > max_rows {
457                        return Err(DatabaseError::TooManyRows(max_rows as u64));
458                    }
459
460                    Ok(QueryResult {
461                        columns,
462                        rows: json_rows,
463                        affected_rows: 0,
464                        execution_time_milliseconds,
465                        error: None,
466                    })
467                }
468            }
469            Err(error) => {
470                // Return error in result
471                Ok(QueryResult {
472                    columns: vec![],
473                    rows: vec![],
474                    affected_rows: 0,
475                    execution_time_milliseconds,
476                    error: Some(error.to_string()),
477                })
478            }
479        }
480    }
481}