axum_sql_viewer/database/
postgres.rs1use crate::database::traits::{DatabaseError, DatabaseProvider};
4use crate::schema::{
5 ColumnInfo, CountResponse, ForeignKey, IndexInfo, QueryResult, RowQuery, RowsResponse,
6 SortOrder, TableInfo, TableSchema,
7};
8use async_trait::async_trait;
9use sqlx::{postgres::PgRow, Column, PgPool, Row, TypeInfo};
10use std::collections::HashMap;
11
/// Database provider backed by a PostgreSQL connection pool.
///
/// All queries issued by this type (schema inspection, row browsing and raw
/// SQL execution) run against the wrapped [`PgPool`].
pub struct PostgresProvider {
    /// sqlx connection pool every query of this provider is executed on.
    pool: PgPool,
}
16
17impl PostgresProvider {
18 pub fn new(pool: PgPool) -> Self {
24 Self { pool }
25 }
26
/// Wraps `identifier` in double quotes so it can be spliced into SQL text.
///
/// Every embedded double quote is doubled, which is how a quote character
/// is escaped inside a quoted SQL identifier.
fn quote_identifier(identifier: &str) -> String {
    let mut quoted = String::with_capacity(identifier.len() + 2);
    quoted.push('"');
    for character in identifier.chars() {
        if character == '"' {
            // Escape by doubling.
            quoted.push('"');
        }
        quoted.push(character);
    }
    quoted.push('"');
    quoted
}
31
32 fn row_to_json(row: &PgRow) -> Result<serde_json::Value, DatabaseError> {
34 let mut map = serde_json::Map::new();
35
36 for column in row.columns() {
37 let column_name = column.name();
38 let type_info = column.type_info();
39 let type_name = type_info.name();
40
41 let value: serde_json::Value = match type_name {
42 "BOOL" => {
43 let val: Option<bool> = row.try_get(column_name)?;
44 val.map(serde_json::Value::Bool).unwrap_or(serde_json::Value::Null)
45 }
46 "INT2" | "SMALLINT" | "SMALLSERIAL" => {
47 let val: Option<i16> = row.try_get(column_name)?;
48 val.map(|v| serde_json::Value::Number(v.into())).unwrap_or(serde_json::Value::Null)
49 }
50 "INT4" | "INT" | "INTEGER" | "SERIAL" => {
51 let val: Option<i32> = row.try_get(column_name)?;
52 val.map(|v| serde_json::Value::Number(v.into())).unwrap_or(serde_json::Value::Null)
53 }
54 "INT8" | "BIGINT" | "BIGSERIAL" => {
55 let val: Option<i64> = row.try_get(column_name)?;
56 val.map(|v| serde_json::Value::Number(v.into())).unwrap_or(serde_json::Value::Null)
57 }
58 "FLOAT4" | "REAL" => {
59 let val: Option<f32> = row.try_get(column_name)?;
60 val.and_then(|v| serde_json::Number::from_f64(v as f64))
61 .map(serde_json::Value::Number)
62 .unwrap_or(serde_json::Value::Null)
63 }
64 "FLOAT8" | "DOUBLE PRECISION" => {
65 let val: Option<f64> = row.try_get(column_name)?;
66 val.and_then(serde_json::Number::from_f64)
67 .map(serde_json::Value::Number)
68 .unwrap_or(serde_json::Value::Null)
69 }
70 "TEXT" | "VARCHAR" | "CHAR" | "NAME" | "BPCHAR" => {
71 let val: Option<String> = row.try_get(column_name)?;
72 val.map(serde_json::Value::String).unwrap_or(serde_json::Value::Null)
73 }
74 "BYTEA" => {
75 let val: Option<Vec<u8>> = row.try_get(column_name)?;
76 val.map(|bytes| {
77 serde_json::Value::String(format!("[BLOB: {} bytes]", bytes.len()))
78 }).unwrap_or(serde_json::Value::Null)
79 }
80 "TIMESTAMP" | "TIMESTAMPTZ" | "TIMESTAMP WITHOUT TIME ZONE" | "TIMESTAMP WITH TIME ZONE"
81 | "DATE" | "TIME" | "TIME WITHOUT TIME ZONE" => {
82 let val: Option<String> = row.try_get(column_name).ok().flatten();
84 val.map(serde_json::Value::String)
85 .unwrap_or(serde_json::Value::Null)
86 }
87 "JSON" | "JSONB" => {
88 let val: Option<serde_json::Value> = row.try_get(column_name)?;
89 val.unwrap_or(serde_json::Value::Null)
90 }
91 "UUID" => {
92 let val: Option<String> = row.try_get(column_name).ok().flatten();
94 val.map(serde_json::Value::String)
95 .unwrap_or(serde_json::Value::Null)
96 }
97 "NUMERIC" | "DECIMAL" => {
98 let val: Option<String> = row.try_get(column_name).ok().flatten();
100 val.map(serde_json::Value::String)
101 .unwrap_or(serde_json::Value::Null)
102 }
103 _ => {
104 let val: Option<String> = row.try_get(column_name).ok().flatten();
106 val.map(serde_json::Value::String).unwrap_or(serde_json::Value::Null)
107 }
108 };
109
110 map.insert(column_name.to_string(), value);
111 }
112
113 Ok(serde_json::Value::Object(map))
114 }
115
116 fn build_where_clause(filters: &HashMap<String, String>, parameter_offset: i32) -> (String, Vec<String>) {
118 if filters.is_empty() {
119 return (String::new(), vec![]);
120 }
121
122 let mut conditions = Vec::new();
123 let mut values = Vec::new();
124 let mut param_index = parameter_offset;
125
126 for (column, filter_value) in filters {
127 let quoted_column = Self::quote_identifier(column);
128
129 if filter_value.contains('%') {
130 conditions.push(format!("{} LIKE ${}", quoted_column, param_index));
131 } else {
132 conditions.push(format!("{} = ${}", quoted_column, param_index));
133 }
134
135 values.push(filter_value.clone());
136 param_index += 1;
137 }
138
139 let where_clause = format!(" WHERE {}", conditions.join(" AND "));
140 (where_clause, values)
141 }
142}
143
144#[async_trait]
145impl DatabaseProvider for PostgresProvider {
146 async fn list_tables(&self) -> Result<Vec<TableInfo>, DatabaseError> {
147 let query = r#"
148 SELECT table_name
149 FROM information_schema.tables
150 WHERE table_schema = 'public'
151 AND table_type = 'BASE TABLE'
152 ORDER BY table_name
153 "#;
154
155 let rows = sqlx::query(query)
156 .fetch_all(&self.pool)
157 .await?;
158
159 let tables = rows
160 .iter()
161 .map(|row| {
162 let name: String = row.try_get("table_name")?;
163 Ok(TableInfo {
164 name,
165 row_count: None,
166 })
167 })
168 .collect::<Result<Vec<_>, sqlx::Error>>()?;
169
170 Ok(tables)
171 }
172
173 async fn get_table_schema(&self, table: &str) -> Result<TableSchema, DatabaseError> {
174 let column_query = r#"
176 SELECT
177 column_name,
178 data_type,
179 is_nullable,
180 column_default,
181 udt_name
182 FROM information_schema.columns
183 WHERE table_schema = 'public'
184 AND table_name = $1
185 ORDER BY ordinal_position
186 "#;
187
188 let column_rows = sqlx::query(column_query)
189 .bind(table)
190 .fetch_all(&self.pool)
191 .await?;
192
193 if column_rows.is_empty() {
194 return Err(DatabaseError::TableNotFound(table.to_string()));
195 }
196
197 let pk_query = r#"
199 SELECT kcu.column_name
200 FROM information_schema.table_constraints tc
201 JOIN information_schema.key_column_usage kcu
202 ON tc.constraint_name = kcu.constraint_name
203 AND tc.table_schema = kcu.table_schema
204 WHERE tc.table_schema = 'public'
205 AND tc.table_name = $1
206 AND tc.constraint_type = 'PRIMARY KEY'
207 ORDER BY kcu.ordinal_position
208 "#;
209
210 let pk_rows = sqlx::query(pk_query)
211 .bind(table)
212 .fetch_all(&self.pool)
213 .await?;
214
215 let primary_key_columns: Vec<String> = pk_rows
216 .iter()
217 .map(|row| row.try_get::<String, _>("column_name"))
218 .collect::<Result<Vec<_>, _>>()?;
219
220 let primary_key = if primary_key_columns.is_empty() {
221 None
222 } else {
223 Some(primary_key_columns.clone())
224 };
225
226 let fk_query = r#"
228 SELECT
229 kcu.column_name,
230 ccu.table_name AS references_table,
231 ccu.column_name AS references_column
232 FROM information_schema.table_constraints tc
233 JOIN information_schema.key_column_usage kcu
234 ON tc.constraint_name = kcu.constraint_name
235 AND tc.table_schema = kcu.table_schema
236 JOIN information_schema.constraint_column_usage ccu
237 ON ccu.constraint_name = tc.constraint_name
238 AND ccu.table_schema = tc.table_schema
239 WHERE tc.table_schema = 'public'
240 AND tc.table_name = $1
241 AND tc.constraint_type = 'FOREIGN KEY'
242 "#;
243
244 let fk_rows = sqlx::query(fk_query)
245 .bind(table)
246 .fetch_all(&self.pool)
247 .await?;
248
249 let foreign_keys: Vec<ForeignKey> = fk_rows
250 .iter()
251 .map(|row| {
252 Ok(ForeignKey {
253 column: row.try_get("column_name")?,
254 references_table: row.try_get("references_table")?,
255 references_column: row.try_get("references_column")?,
256 })
257 })
258 .collect::<Result<Vec<_>, sqlx::Error>>()?;
259
260 let index_query = r#"
262 SELECT
263 i.indexname AS index_name,
264 i.indexdef AS index_definition
265 FROM pg_indexes i
266 WHERE i.schemaname = 'public'
267 AND i.tablename = $1
268 AND i.indexname NOT IN (
269 SELECT constraint_name
270 FROM information_schema.table_constraints
271 WHERE table_schema = 'public'
272 AND table_name = $1
273 AND constraint_type = 'PRIMARY KEY'
274 )
275 "#;
276
277 let index_rows = sqlx::query(index_query)
278 .bind(table)
279 .fetch_all(&self.pool)
280 .await?;
281
282 let indexes: Vec<IndexInfo> = index_rows
283 .iter()
284 .map(|row| {
285 let index_name: String = row.try_get("index_name")?;
286 let index_definition: String = row.try_get("index_definition")?;
287
288 let columns = vec![]; let unique = index_definition.to_uppercase().contains("UNIQUE");
293
294 Ok(IndexInfo {
295 name: index_name,
296 columns,
297 unique,
298 })
299 })
300 .collect::<Result<Vec<_>, sqlx::Error>>()?;
301
302 let columns: Vec<ColumnInfo> = column_rows
304 .iter()
305 .map(|row| {
306 let column_name: String = row.try_get("column_name")?;
307 let data_type: String = row.try_get("data_type")?;
308 let is_nullable: String = row.try_get("is_nullable")?;
309 let column_default: Option<String> = row.try_get("column_default")?;
310
311 Ok(ColumnInfo {
312 name: column_name.clone(),
313 data_type,
314 nullable: is_nullable == "YES",
315 default_value: column_default,
316 is_primary_key: primary_key_columns.contains(&column_name),
317 })
318 })
319 .collect::<Result<Vec<_>, sqlx::Error>>()?;
320
321 Ok(TableSchema {
322 name: table.to_string(),
323 columns,
324 primary_key,
325 foreign_keys,
326 indexes,
327 })
328 }
329
330 async fn get_rows(&self, table: &str, query: RowQuery) -> Result<RowsResponse, DatabaseError> {
331 let schema = self.get_table_schema(table).await?;
333 let column_names: Vec<String> = schema.columns.iter().map(|c| c.name.clone()).collect();
334
335 let quoted_table = Self::quote_identifier(table);
337 let mut sql = format!("SELECT * FROM {}", quoted_table);
338
339 let (where_clause, filter_values) = Self::build_where_clause(&query.filters, 1);
341 sql.push_str(&where_clause);
342
343 if let Some(sort_column) = &query.sort_by {
345 if !column_names.contains(sort_column) {
347 return Err(DatabaseError::InvalidColumn(sort_column.clone()));
348 }
349
350 let quoted_sort = Self::quote_identifier(sort_column);
351 let sort_direction = match query.sort_order {
352 Some(SortOrder::Descending) => "DESC",
353 _ => "ASC",
354 };
355 sql.push_str(&format!(" ORDER BY {} {}", quoted_sort, sort_direction));
356 }
357
358 let limit = query.limit.min(500); sql.push_str(&format!(" LIMIT {} OFFSET {}", limit, query.offset));
361
362 let mut query_builder = sqlx::query(&sql);
364 for value in &filter_values {
365 query_builder = query_builder.bind(value);
366 }
367
368 let rows = query_builder.fetch_all(&self.pool).await?;
369
370 let json_rows: Vec<serde_json::Value> = rows
372 .iter()
373 .map(Self::row_to_json)
374 .collect::<Result<Vec<_>, _>>()?;
375
376 let count_result = self.count_rows(table, &query).await?;
378 let total = count_result.count;
379
380 let has_more = query.offset + (json_rows.len() as u64) < total;
381
382 Ok(RowsResponse {
383 rows: json_rows,
384 columns: column_names,
385 total,
386 offset: query.offset,
387 limit,
388 has_more,
389 })
390 }
391
392 async fn count_rows(&self, table: &str, query: &RowQuery) -> Result<CountResponse, DatabaseError> {
393 let quoted_table = Self::quote_identifier(table);
394 let mut sql = format!("SELECT COUNT(*) as count FROM {}", quoted_table);
395
396 let (where_clause, filter_values) = Self::build_where_clause(&query.filters, 1);
398 sql.push_str(&where_clause);
399
400 let mut query_builder = sqlx::query(&sql);
402 for value in &filter_values {
403 query_builder = query_builder.bind(value);
404 }
405
406 let row = query_builder.fetch_one(&self.pool).await?;
407 let count: i64 = row.try_get("count")?;
408
409 Ok(CountResponse {
410 count: count as u64,
411 })
412 }
413
414 async fn execute_query(&self, sql: &str) -> Result<QueryResult, DatabaseError> {
415 let start_time = std::time::Instant::now();
416
417 let result = sqlx::query(sql).fetch_all(&self.pool).await;
419
420 let execution_time_milliseconds = start_time.elapsed().as_millis() as u64;
421
422 match result {
423 Ok(rows) => {
424 if rows.is_empty() {
425 Ok(QueryResult {
428 columns: vec![],
429 rows: vec![],
430 affected_rows: 0,
431 execution_time_milliseconds,
432 error: None,
433 })
434 } else {
435 let columns: Vec<String> = rows[0]
437 .columns()
438 .iter()
439 .map(|col| col.name().to_string())
440 .collect();
441
442 let json_rows: Vec<serde_json::Value> = rows
443 .iter()
444 .map(Self::row_to_json)
445 .collect::<Result<Vec<_>, _>>()?;
446
447 let max_rows = 10000;
449 if json_rows.len() > max_rows {
450 return Err(DatabaseError::TooManyRows(max_rows as u64));
451 }
452
453 Ok(QueryResult {
454 columns,
455 rows: json_rows,
456 affected_rows: 0,
457 execution_time_milliseconds,
458 error: None,
459 })
460 }
461 }
462 Err(error) => {
463 Ok(QueryResult {
465 columns: vec![],
466 rows: vec![],
467 affected_rows: 0,
468 execution_time_milliseconds,
469 error: Some(error.to_string()),
470 })
471 }
472 }
473 }
474}