axum_sql_viewer/database/
postgres.rs1use crate::database::traits::{DatabaseError, DatabaseProvider};
4use crate::schema::{
5 ColumnInfo, CountResponse, ForeignKey, IndexInfo, QueryResult, RowQuery, RowsResponse,
6 SortOrder, TableInfo, TableSchema,
7};
8use async_trait::async_trait;
9use sqlx::{postgres::PgRow, Column, PgPool, Row, TypeInfo};
10use std::collections::HashMap;
11
12pub struct PostgresProvider {
14 pool: PgPool,
15}
16
17impl PostgresProvider {
18 pub fn new(pool: PgPool) -> Self {
24 Self { pool }
25 }
26
27 fn quote_identifier(identifier: &str) -> String {
29 format!("\"{}\"", identifier.replace("\"", "\"\""))
30 }
31
32 fn row_to_json(row: &PgRow) -> Result<serde_json::Value, DatabaseError> {
34 let mut map = serde_json::Map::new();
35
36 for column in row.columns() {
37 let column_name = column.name();
38 let type_info = column.type_info();
39 let type_name = type_info.name();
40
41 let value: serde_json::Value = match type_name {
42 "BOOL" => {
43 let val: Option<bool> = row.try_get(column_name)?;
44 val.map(serde_json::Value::Bool).unwrap_or(serde_json::Value::Null)
45 }
46 "INT2" | "SMALLINT" | "SMALLSERIAL" => {
47 let val: Option<i16> = row.try_get(column_name)?;
48 val.map(|v| serde_json::Value::Number(v.into())).unwrap_or(serde_json::Value::Null)
49 }
50 "INT4" | "INT" | "INTEGER" | "SERIAL" => {
51 let val: Option<i32> = row.try_get(column_name)?;
52 val.map(|v| serde_json::Value::Number(v.into())).unwrap_or(serde_json::Value::Null)
53 }
54 "INT8" | "BIGINT" | "BIGSERIAL" => {
55 let val: Option<i64> = row.try_get(column_name)?;
56 val.map(|v| serde_json::Value::Number(v.into())).unwrap_or(serde_json::Value::Null)
57 }
58 "FLOAT4" | "REAL" => {
59 let val: Option<f32> = row.try_get(column_name)?;
60 val.and_then(|v| serde_json::Number::from_f64(v as f64))
61 .map(serde_json::Value::Number)
62 .unwrap_or(serde_json::Value::Null)
63 }
64 "FLOAT8" | "DOUBLE PRECISION" => {
65 let val: Option<f64> = row.try_get(column_name)?;
66 val.and_then(serde_json::Number::from_f64)
67 .map(serde_json::Value::Number)
68 .unwrap_or(serde_json::Value::Null)
69 }
70 "TEXT" | "VARCHAR" | "CHAR" | "NAME" | "BPCHAR" => {
71 let val: Option<String> = row.try_get(column_name)?;
72 val.map(serde_json::Value::String).unwrap_or(serde_json::Value::Null)
73 }
74 "BYTEA" => {
75 let val: Option<Vec<u8>> = row.try_get(column_name)?;
76 val.map(|bytes| {
77 serde_json::Value::String(format!("[BLOB: {} bytes]", bytes.len()))
78 }).unwrap_or(serde_json::Value::Null)
79 }
80 "TIMESTAMP" | "TIMESTAMPTZ" | "TIMESTAMP WITHOUT TIME ZONE" | "TIMESTAMP WITH TIME ZONE"
81 | "DATE" | "TIME" | "TIME WITHOUT TIME ZONE" => {
82 let val: Option<String> = row.try_get(column_name).ok().flatten();
84 val.map(serde_json::Value::String)
85 .unwrap_or(serde_json::Value::Null)
86 }
87 "JSON" | "JSONB" => {
88 let val: Option<serde_json::Value> = row.try_get(column_name)?;
89 val.unwrap_or(serde_json::Value::Null)
90 }
91 "UUID" => {
92 let val: Option<String> = row.try_get(column_name).ok().flatten();
94 val.map(serde_json::Value::String)
95 .unwrap_or(serde_json::Value::Null)
96 }
97 "NUMERIC" | "DECIMAL" => {
98 let val: Option<String> = row.try_get(column_name).ok().flatten();
100 val.map(serde_json::Value::String)
101 .unwrap_or(serde_json::Value::Null)
102 }
103 _ => {
104 let val: Option<String> = row.try_get(column_name).ok().flatten();
106 val.map(serde_json::Value::String).unwrap_or(serde_json::Value::Null)
107 }
108 };
109
110 map.insert(column_name.to_string(), value);
111 }
112
113 Ok(serde_json::Value::Object(map))
114 }
115
116 fn build_where_clause(filters: &HashMap<String, String>, parameter_offset: i32) -> (String, Vec<String>) {
118 if filters.is_empty() {
119 return (String::new(), vec![]);
120 }
121
122 let mut conditions = Vec::new();
123 let mut values = Vec::new();
124 let mut param_index = parameter_offset;
125
126 for (column, filter_value) in filters {
127 let quoted_column = Self::quote_identifier(column);
128
129 if filter_value.contains('%') {
130 conditions.push(format!("{} LIKE ${}", quoted_column, param_index));
131 } else {
132 conditions.push(format!("{} = ${}", quoted_column, param_index));
133 }
134
135 values.push(filter_value.clone());
136 param_index += 1;
137 }
138
139 let where_clause = format!(" WHERE {}", conditions.join(" AND "));
140 (where_clause, values)
141 }
142}
143
144#[async_trait]
145impl DatabaseProvider for PostgresProvider {
146 async fn list_tables(&self) -> Result<Vec<TableInfo>, DatabaseError> {
147 let query = r#"
148 SELECT table_name
149 FROM information_schema.tables
150 WHERE table_schema = 'public'
151 AND table_type = 'BASE TABLE'
152 ORDER BY table_name
153 "#;
154
155 let rows = sqlx::query(query)
156 .fetch_all(&self.pool)
157 .await?;
158
159 let mut tables = Vec::new();
160 for row in rows {
161 let name: String = row.try_get("table_name")?;
162
163 let count_query = format!(
165 "SELECT COUNT(*) as count FROM {}",
166 Self::quote_identifier(&name)
167 );
168 let row_count: Option<u64> = sqlx::query_scalar(&count_query)
169 .fetch_one(&self.pool)
170 .await
171 .ok()
172 .map(|count: i64| count as u64);
173
174 tables.push(TableInfo { name, row_count });
175 }
176
177 Ok(tables)
178 }
179
180 async fn get_table_schema(&self, table: &str) -> Result<TableSchema, DatabaseError> {
181 let column_query = r#"
183 SELECT
184 column_name,
185 data_type,
186 is_nullable,
187 column_default,
188 udt_name
189 FROM information_schema.columns
190 WHERE table_schema = 'public'
191 AND table_name = $1
192 ORDER BY ordinal_position
193 "#;
194
195 let column_rows = sqlx::query(column_query)
196 .bind(table)
197 .fetch_all(&self.pool)
198 .await?;
199
200 if column_rows.is_empty() {
201 return Err(DatabaseError::TableNotFound(table.to_string()));
202 }
203
204 let pk_query = r#"
206 SELECT kcu.column_name
207 FROM information_schema.table_constraints tc
208 JOIN information_schema.key_column_usage kcu
209 ON tc.constraint_name = kcu.constraint_name
210 AND tc.table_schema = kcu.table_schema
211 WHERE tc.table_schema = 'public'
212 AND tc.table_name = $1
213 AND tc.constraint_type = 'PRIMARY KEY'
214 ORDER BY kcu.ordinal_position
215 "#;
216
217 let pk_rows = sqlx::query(pk_query)
218 .bind(table)
219 .fetch_all(&self.pool)
220 .await?;
221
222 let primary_key_columns: Vec<String> = pk_rows
223 .iter()
224 .map(|row| row.try_get::<String, _>("column_name"))
225 .collect::<Result<Vec<_>, _>>()?;
226
227 let primary_key = if primary_key_columns.is_empty() {
228 None
229 } else {
230 Some(primary_key_columns.clone())
231 };
232
233 let fk_query = r#"
235 SELECT
236 kcu.column_name,
237 ccu.table_name AS references_table,
238 ccu.column_name AS references_column
239 FROM information_schema.table_constraints tc
240 JOIN information_schema.key_column_usage kcu
241 ON tc.constraint_name = kcu.constraint_name
242 AND tc.table_schema = kcu.table_schema
243 JOIN information_schema.constraint_column_usage ccu
244 ON ccu.constraint_name = tc.constraint_name
245 AND ccu.table_schema = tc.table_schema
246 WHERE tc.table_schema = 'public'
247 AND tc.table_name = $1
248 AND tc.constraint_type = 'FOREIGN KEY'
249 "#;
250
251 let fk_rows = sqlx::query(fk_query)
252 .bind(table)
253 .fetch_all(&self.pool)
254 .await?;
255
256 let foreign_keys: Vec<ForeignKey> = fk_rows
257 .iter()
258 .map(|row| {
259 Ok(ForeignKey {
260 column: row.try_get("column_name")?,
261 references_table: row.try_get("references_table")?,
262 references_column: row.try_get("references_column")?,
263 })
264 })
265 .collect::<Result<Vec<_>, sqlx::Error>>()?;
266
267 let index_query = r#"
269 SELECT
270 i.indexname AS index_name,
271 i.indexdef AS index_definition
272 FROM pg_indexes i
273 WHERE i.schemaname = 'public'
274 AND i.tablename = $1
275 AND i.indexname NOT IN (
276 SELECT constraint_name
277 FROM information_schema.table_constraints
278 WHERE table_schema = 'public'
279 AND table_name = $1
280 AND constraint_type = 'PRIMARY KEY'
281 )
282 "#;
283
284 let index_rows = sqlx::query(index_query)
285 .bind(table)
286 .fetch_all(&self.pool)
287 .await?;
288
289 let indexes: Vec<IndexInfo> = index_rows
290 .iter()
291 .map(|row| {
292 let index_name: String = row.try_get("index_name")?;
293 let index_definition: String = row.try_get("index_definition")?;
294
295 let columns = vec![]; let unique = index_definition.to_uppercase().contains("UNIQUE");
300
301 Ok(IndexInfo {
302 name: index_name,
303 columns,
304 unique,
305 })
306 })
307 .collect::<Result<Vec<_>, sqlx::Error>>()?;
308
309 let columns: Vec<ColumnInfo> = column_rows
311 .iter()
312 .map(|row| {
313 let column_name: String = row.try_get("column_name")?;
314 let data_type: String = row.try_get("data_type")?;
315 let is_nullable: String = row.try_get("is_nullable")?;
316 let column_default: Option<String> = row.try_get("column_default")?;
317
318 Ok(ColumnInfo {
319 name: column_name.clone(),
320 data_type,
321 nullable: is_nullable == "YES",
322 default_value: column_default,
323 is_primary_key: primary_key_columns.contains(&column_name),
324 })
325 })
326 .collect::<Result<Vec<_>, sqlx::Error>>()?;
327
328 Ok(TableSchema {
329 name: table.to_string(),
330 columns,
331 primary_key,
332 foreign_keys,
333 indexes,
334 })
335 }
336
337 async fn get_rows(&self, table: &str, query: RowQuery) -> Result<RowsResponse, DatabaseError> {
338 let schema = self.get_table_schema(table).await?;
340 let column_names: Vec<String> = schema.columns.iter().map(|c| c.name.clone()).collect();
341
342 let quoted_table = Self::quote_identifier(table);
344 let mut sql = format!("SELECT * FROM {}", quoted_table);
345
346 let (where_clause, filter_values) = Self::build_where_clause(&query.filters, 1);
348 sql.push_str(&where_clause);
349
350 if let Some(sort_column) = &query.sort_by {
352 if !column_names.contains(sort_column) {
354 return Err(DatabaseError::InvalidColumn(sort_column.clone()));
355 }
356
357 let quoted_sort = Self::quote_identifier(sort_column);
358 let sort_direction = match query.sort_order {
359 Some(SortOrder::Descending) => "DESC",
360 _ => "ASC",
361 };
362 sql.push_str(&format!(" ORDER BY {} {}", quoted_sort, sort_direction));
363 }
364
365 let limit = query.limit.min(500); sql.push_str(&format!(" LIMIT {} OFFSET {}", limit, query.offset));
368
369 let mut query_builder = sqlx::query(&sql);
371 for value in &filter_values {
372 query_builder = query_builder.bind(value);
373 }
374
375 let rows = query_builder.fetch_all(&self.pool).await?;
376
377 let json_rows: Vec<serde_json::Value> = rows
379 .iter()
380 .map(Self::row_to_json)
381 .collect::<Result<Vec<_>, _>>()?;
382
383 let count_result = self.count_rows(table, &query).await?;
385 let total = count_result.count;
386
387 let has_more = query.offset + (json_rows.len() as u64) < total;
388
389 Ok(RowsResponse {
390 rows: json_rows,
391 columns: column_names,
392 total,
393 offset: query.offset,
394 limit,
395 has_more,
396 })
397 }
398
399 async fn count_rows(&self, table: &str, query: &RowQuery) -> Result<CountResponse, DatabaseError> {
400 let quoted_table = Self::quote_identifier(table);
401 let mut sql = format!("SELECT COUNT(*) as count FROM {}", quoted_table);
402
403 let (where_clause, filter_values) = Self::build_where_clause(&query.filters, 1);
405 sql.push_str(&where_clause);
406
407 let mut query_builder = sqlx::query(&sql);
409 for value in &filter_values {
410 query_builder = query_builder.bind(value);
411 }
412
413 let row = query_builder.fetch_one(&self.pool).await?;
414 let count: i64 = row.try_get("count")?;
415
416 Ok(CountResponse {
417 count: count as u64,
418 })
419 }
420
421 async fn execute_query(&self, sql: &str) -> Result<QueryResult, DatabaseError> {
422 let start_time = std::time::Instant::now();
423
424 let result = sqlx::query(sql).fetch_all(&self.pool).await;
426
427 let execution_time_milliseconds = start_time.elapsed().as_millis() as u64;
428
429 match result {
430 Ok(rows) => {
431 if rows.is_empty() {
432 Ok(QueryResult {
435 columns: vec![],
436 rows: vec![],
437 affected_rows: 0,
438 execution_time_milliseconds,
439 error: None,
440 })
441 } else {
442 let columns: Vec<String> = rows[0]
444 .columns()
445 .iter()
446 .map(|col| col.name().to_string())
447 .collect();
448
449 let json_rows: Vec<serde_json::Value> = rows
450 .iter()
451 .map(Self::row_to_json)
452 .collect::<Result<Vec<_>, _>>()?;
453
454 let max_rows = 10000;
456 if json_rows.len() > max_rows {
457 return Err(DatabaseError::TooManyRows(max_rows as u64));
458 }
459
460 Ok(QueryResult {
461 columns,
462 rows: json_rows,
463 affected_rows: 0,
464 execution_time_milliseconds,
465 error: None,
466 })
467 }
468 }
469 Err(error) => {
470 Ok(QueryResult {
472 columns: vec![],
473 rows: vec![],
474 affected_rows: 0,
475 execution_time_milliseconds,
476 error: Some(error.to_string()),
477 })
478 }
479 }
480 }
481}