// axum_sql_viewer/database/postgres.rs

//! PostgreSQL database provider implementation.

use crate::database::traits::{DatabaseError, DatabaseProvider};
use crate::schema::{
    ColumnInfo, CountResponse, ForeignKey, IndexInfo, QueryResult, RowQuery, RowsResponse,
    SortOrder, TableInfo, TableSchema,
};
use async_trait::async_trait;
use sqlx::{
    postgres::{PgRow, PgValueFormat, PgValueRef},
    Column, PgPool, Row, TypeInfo, ValueRef,
};
use std::collections::HashMap;

12/// PostgreSQL database provider
13pub struct PostgresProvider {
14    pool: PgPool,
15}
16
17impl PostgresProvider {
18    /// Create a new PostgreSQL provider
19    ///
20    /// # Arguments
21    ///
22    /// * `pool` - PostgreSQL connection pool
23    pub fn new(pool: PgPool) -> Self {
24        Self { pool }
25    }
26
27    /// Quote an identifier to prevent SQL injection
28    fn quote_identifier(identifier: &str) -> String {
29        format!("\"{}\"", identifier.replace("\"", "\"\""))
30    }
31
32    /// Convert a PostgreSQL row to a JSON object
33    fn row_to_json(row: &PgRow) -> Result<serde_json::Value, DatabaseError> {
34        let mut map = serde_json::Map::new();
35
36        for column in row.columns() {
37            let column_name = column.name();
38            let type_info = column.type_info();
39            let type_name = type_info.name();
40
41            let value: serde_json::Value = match type_name {
42                "BOOL" => {
43                    let val: Option<bool> = row.try_get(column_name)?;
44                    val.map(serde_json::Value::Bool).unwrap_or(serde_json::Value::Null)
45                }
46                "INT2" | "SMALLINT" | "SMALLSERIAL" => {
47                    let val: Option<i16> = row.try_get(column_name)?;
48                    val.map(|v| serde_json::Value::Number(v.into())).unwrap_or(serde_json::Value::Null)
49                }
50                "INT4" | "INT" | "INTEGER" | "SERIAL" => {
51                    let val: Option<i32> = row.try_get(column_name)?;
52                    val.map(|v| serde_json::Value::Number(v.into())).unwrap_or(serde_json::Value::Null)
53                }
54                "INT8" | "BIGINT" | "BIGSERIAL" => {
55                    let val: Option<i64> = row.try_get(column_name)?;
56                    val.map(|v| serde_json::Value::Number(v.into())).unwrap_or(serde_json::Value::Null)
57                }
58                "FLOAT4" | "REAL" => {
59                    let val: Option<f32> = row.try_get(column_name)?;
60                    val.and_then(|v| serde_json::Number::from_f64(v as f64))
61                        .map(serde_json::Value::Number)
62                        .unwrap_or(serde_json::Value::Null)
63                }
64                "FLOAT8" | "DOUBLE PRECISION" => {
65                    let val: Option<f64> = row.try_get(column_name)?;
66                    val.and_then(serde_json::Number::from_f64)
67                        .map(serde_json::Value::Number)
68                        .unwrap_or(serde_json::Value::Null)
69                }
70                "TEXT" | "VARCHAR" | "CHAR" | "NAME" | "BPCHAR" => {
71                    let val: Option<String> = row.try_get(column_name)?;
72                    val.map(serde_json::Value::String).unwrap_or(serde_json::Value::Null)
73                }
74                "BYTEA" => {
75                    let val: Option<Vec<u8>> = row.try_get(column_name)?;
76                    val.map(|bytes| {
77                        serde_json::Value::String(format!("[BLOB: {} bytes]", bytes.len()))
78                    }).unwrap_or(serde_json::Value::Null)
79                }
80                "TIMESTAMP" | "TIMESTAMPTZ" | "TIMESTAMP WITHOUT TIME ZONE" | "TIMESTAMP WITH TIME ZONE"
81                | "DATE" | "TIME" | "TIME WITHOUT TIME ZONE" => {
82                    // Try to get as string representation
83                    let val: Option<String> = row.try_get(column_name).ok().flatten();
84                    val.map(serde_json::Value::String)
85                        .unwrap_or(serde_json::Value::Null)
86                }
87                "JSON" | "JSONB" => {
88                    let val: Option<serde_json::Value> = row.try_get(column_name)?;
89                    val.unwrap_or(serde_json::Value::Null)
90                }
91                "UUID" => {
92                    // Try to get as string representation
93                    let val: Option<String> = row.try_get(column_name).ok().flatten();
94                    val.map(serde_json::Value::String)
95                        .unwrap_or(serde_json::Value::Null)
96                }
97                "NUMERIC" | "DECIMAL" => {
98                    // Try to get as string to preserve precision
99                    let val: Option<String> = row.try_get(column_name).ok().flatten();
100                    val.map(serde_json::Value::String)
101                        .unwrap_or(serde_json::Value::Null)
102                }
103                _ => {
104                    // Fallback: try to get as string
105                    let val: Option<String> = row.try_get(column_name).ok().flatten();
106                    val.map(serde_json::Value::String).unwrap_or(serde_json::Value::Null)
107                }
108            };
109
110            map.insert(column_name.to_string(), value);
111        }
112
113        Ok(serde_json::Value::Object(map))
114    }
115
116    /// Build a WHERE clause from filters
117    fn build_where_clause(filters: &HashMap<String, String>, parameter_offset: i32) -> (String, Vec<String>) {
118        if filters.is_empty() {
119            return (String::new(), vec![]);
120        }
121
122        let mut conditions = Vec::new();
123        let mut values = Vec::new();
124        let mut param_index = parameter_offset;
125
126        for (column, filter_value) in filters {
127            let quoted_column = Self::quote_identifier(column);
128
129            if filter_value.contains('%') {
130                conditions.push(format!("{} LIKE ${}", quoted_column, param_index));
131            } else {
132                conditions.push(format!("{} = ${}", quoted_column, param_index));
133            }
134
135            values.push(filter_value.clone());
136            param_index += 1;
137        }
138
139        let where_clause = format!(" WHERE {}", conditions.join(" AND "));
140        (where_clause, values)
141    }
142}
143
144#[async_trait]
145impl DatabaseProvider for PostgresProvider {
146    async fn list_tables(&self) -> Result<Vec<TableInfo>, DatabaseError> {
147        let query = r#"
148            SELECT table_name
149            FROM information_schema.tables
150            WHERE table_schema = 'public'
151              AND table_type = 'BASE TABLE'
152            ORDER BY table_name
153        "#;
154
155        let rows = sqlx::query(query)
156            .fetch_all(&self.pool)
157            .await?;
158
159        let tables = rows
160            .iter()
161            .map(|row| {
162                let name: String = row.try_get("table_name")?;
163                Ok(TableInfo {
164                    name,
165                    row_count: None,
166                })
167            })
168            .collect::<Result<Vec<_>, sqlx::Error>>()?;
169
170        Ok(tables)
171    }
172
173    async fn get_table_schema(&self, table: &str) -> Result<TableSchema, DatabaseError> {
174        // Get column information
175        let column_query = r#"
176            SELECT
177                column_name,
178                data_type,
179                is_nullable,
180                column_default,
181                udt_name
182            FROM information_schema.columns
183            WHERE table_schema = 'public'
184              AND table_name = $1
185            ORDER BY ordinal_position
186        "#;
187
188        let column_rows = sqlx::query(column_query)
189            .bind(table)
190            .fetch_all(&self.pool)
191            .await?;
192
193        if column_rows.is_empty() {
194            return Err(DatabaseError::TableNotFound(table.to_string()));
195        }
196
197        // Get primary key columns
198        let pk_query = r#"
199            SELECT kcu.column_name
200            FROM information_schema.table_constraints tc
201            JOIN information_schema.key_column_usage kcu
202              ON tc.constraint_name = kcu.constraint_name
203              AND tc.table_schema = kcu.table_schema
204            WHERE tc.table_schema = 'public'
205              AND tc.table_name = $1
206              AND tc.constraint_type = 'PRIMARY KEY'
207            ORDER BY kcu.ordinal_position
208        "#;
209
210        let pk_rows = sqlx::query(pk_query)
211            .bind(table)
212            .fetch_all(&self.pool)
213            .await?;
214
215        let primary_key_columns: Vec<String> = pk_rows
216            .iter()
217            .map(|row| row.try_get::<String, _>("column_name"))
218            .collect::<Result<Vec<_>, _>>()?;
219
220        let primary_key = if primary_key_columns.is_empty() {
221            None
222        } else {
223            Some(primary_key_columns.clone())
224        };
225
226        // Get foreign keys
227        let fk_query = r#"
228            SELECT
229                kcu.column_name,
230                ccu.table_name AS references_table,
231                ccu.column_name AS references_column
232            FROM information_schema.table_constraints tc
233            JOIN information_schema.key_column_usage kcu
234              ON tc.constraint_name = kcu.constraint_name
235              AND tc.table_schema = kcu.table_schema
236            JOIN information_schema.constraint_column_usage ccu
237              ON ccu.constraint_name = tc.constraint_name
238              AND ccu.table_schema = tc.table_schema
239            WHERE tc.table_schema = 'public'
240              AND tc.table_name = $1
241              AND tc.constraint_type = 'FOREIGN KEY'
242        "#;
243
244        let fk_rows = sqlx::query(fk_query)
245            .bind(table)
246            .fetch_all(&self.pool)
247            .await?;
248
249        let foreign_keys: Vec<ForeignKey> = fk_rows
250            .iter()
251            .map(|row| {
252                Ok(ForeignKey {
253                    column: row.try_get("column_name")?,
254                    references_table: row.try_get("references_table")?,
255                    references_column: row.try_get("references_column")?,
256                })
257            })
258            .collect::<Result<Vec<_>, sqlx::Error>>()?;
259
260        // Get indexes
261        let index_query = r#"
262            SELECT
263                i.indexname AS index_name,
264                i.indexdef AS index_definition
265            FROM pg_indexes i
266            WHERE i.schemaname = 'public'
267              AND i.tablename = $1
268              AND i.indexname NOT IN (
269                SELECT constraint_name
270                FROM information_schema.table_constraints
271                WHERE table_schema = 'public'
272                  AND table_name = $1
273                  AND constraint_type = 'PRIMARY KEY'
274              )
275        "#;
276
277        let index_rows = sqlx::query(index_query)
278            .bind(table)
279            .fetch_all(&self.pool)
280            .await?;
281
282        let indexes: Vec<IndexInfo> = index_rows
283            .iter()
284            .map(|row| {
285                let index_name: String = row.try_get("index_name")?;
286                let index_definition: String = row.try_get("index_definition")?;
287
288                // Parse column names from index definition (simplified)
289                // This is a basic implementation - could be enhanced
290                let columns = vec![]; // Would need proper parsing of index_definition
291
292                let unique = index_definition.to_uppercase().contains("UNIQUE");
293
294                Ok(IndexInfo {
295                    name: index_name,
296                    columns,
297                    unique,
298                })
299            })
300            .collect::<Result<Vec<_>, sqlx::Error>>()?;
301
302        // Build column info
303        let columns: Vec<ColumnInfo> = column_rows
304            .iter()
305            .map(|row| {
306                let column_name: String = row.try_get("column_name")?;
307                let data_type: String = row.try_get("data_type")?;
308                let is_nullable: String = row.try_get("is_nullable")?;
309                let column_default: Option<String> = row.try_get("column_default")?;
310
311                Ok(ColumnInfo {
312                    name: column_name.clone(),
313                    data_type,
314                    nullable: is_nullable == "YES",
315                    default_value: column_default,
316                    is_primary_key: primary_key_columns.contains(&column_name),
317                })
318            })
319            .collect::<Result<Vec<_>, sqlx::Error>>()?;
320
321        Ok(TableSchema {
322            name: table.to_string(),
323            columns,
324            primary_key,
325            foreign_keys,
326            indexes,
327        })
328    }
329
330    async fn get_rows(&self, table: &str, query: RowQuery) -> Result<RowsResponse, DatabaseError> {
331        // Validate table exists and get columns
332        let schema = self.get_table_schema(table).await?;
333        let column_names: Vec<String> = schema.columns.iter().map(|c| c.name.clone()).collect();
334
335        // Build base query
336        let quoted_table = Self::quote_identifier(table);
337        let mut sql = format!("SELECT * FROM {}", quoted_table);
338
339        // Add WHERE clause for filters
340        let (where_clause, filter_values) = Self::build_where_clause(&query.filters, 1);
341        sql.push_str(&where_clause);
342
343        // Add ORDER BY clause
344        if let Some(sort_column) = &query.sort_by {
345            // Validate sort column exists
346            if !column_names.contains(sort_column) {
347                return Err(DatabaseError::InvalidColumn(sort_column.clone()));
348            }
349
350            let quoted_sort = Self::quote_identifier(sort_column);
351            let sort_direction = match query.sort_order {
352                Some(SortOrder::Descending) => "DESC",
353                _ => "ASC",
354            };
355            sql.push_str(&format!(" ORDER BY {} {}", quoted_sort, sort_direction));
356        }
357
358        // Add LIMIT and OFFSET
359        let limit = query.limit.min(500); // Cap at 500 as per spec
360        sql.push_str(&format!(" LIMIT {} OFFSET {}", limit, query.offset));
361
362        // Execute query
363        let mut query_builder = sqlx::query(&sql);
364        for value in &filter_values {
365            query_builder = query_builder.bind(value);
366        }
367
368        let rows = query_builder.fetch_all(&self.pool).await?;
369
370        // Convert rows to JSON
371        let json_rows: Vec<serde_json::Value> = rows
372            .iter()
373            .map(Self::row_to_json)
374            .collect::<Result<Vec<_>, _>>()?;
375
376        // Get total count
377        let count_result = self.count_rows(table, &query).await?;
378        let total = count_result.count;
379
380        let has_more = query.offset + (json_rows.len() as u64) < total;
381
382        Ok(RowsResponse {
383            rows: json_rows,
384            columns: column_names,
385            total,
386            offset: query.offset,
387            limit,
388            has_more,
389        })
390    }
391
392    async fn count_rows(&self, table: &str, query: &RowQuery) -> Result<CountResponse, DatabaseError> {
393        let quoted_table = Self::quote_identifier(table);
394        let mut sql = format!("SELECT COUNT(*) as count FROM {}", quoted_table);
395
396        // Add WHERE clause for filters
397        let (where_clause, filter_values) = Self::build_where_clause(&query.filters, 1);
398        sql.push_str(&where_clause);
399
400        // Execute query
401        let mut query_builder = sqlx::query(&sql);
402        for value in &filter_values {
403            query_builder = query_builder.bind(value);
404        }
405
406        let row = query_builder.fetch_one(&self.pool).await?;
407        let count: i64 = row.try_get("count")?;
408
409        Ok(CountResponse {
410            count: count as u64,
411        })
412    }
413
414    async fn execute_query(&self, sql: &str) -> Result<QueryResult, DatabaseError> {
415        let start_time = std::time::Instant::now();
416
417        // Try to execute as a query that returns rows (SELECT)
418        let result = sqlx::query(sql).fetch_all(&self.pool).await;
419
420        let execution_time_milliseconds = start_time.elapsed().as_millis() as u64;
421
422        match result {
423            Ok(rows) => {
424                if rows.is_empty() {
425                    // Could be a DML query (INSERT/UPDATE/DELETE) or SELECT with no results
426                    // Try to get affected rows count
427                    Ok(QueryResult {
428                        columns: vec![],
429                        rows: vec![],
430                        affected_rows: 0,
431                        execution_time_milliseconds,
432                        error: None,
433                    })
434                } else {
435                    // SELECT query with results
436                    let columns: Vec<String> = rows[0]
437                        .columns()
438                        .iter()
439                        .map(|col| col.name().to_string())
440                        .collect();
441
442                    let json_rows: Vec<serde_json::Value> = rows
443                        .iter()
444                        .map(Self::row_to_json)
445                        .collect::<Result<Vec<_>, _>>()?;
446
447                    // Apply row limit
448                    let max_rows = 10000;
449                    if json_rows.len() > max_rows {
450                        return Err(DatabaseError::TooManyRows(max_rows as u64));
451                    }
452
453                    Ok(QueryResult {
454                        columns,
455                        rows: json_rows,
456                        affected_rows: 0,
457                        execution_time_milliseconds,
458                        error: None,
459                    })
460                }
461            }
462            Err(error) => {
463                // Return error in result
464                Ok(QueryResult {
465                    columns: vec![],
466                    rows: vec![],
467                    affected_rows: 0,
468                    execution_time_milliseconds,
469                    error: Some(error.to_string()),
470                })
471            }
472        }
473    }
474}