Skip to main content

reflect_db/
postgres.rs

1use async_trait::async_trait;
2use sqlx::{PgPool, Row};
3use crate::error::ReflectError;
4use crate::executor::{Executor, TableInfo};
5use crate::metadata::{Column, ForeignKey, Index, PrimaryKey, SqlType};
6
7pub struct PostgresExecutor {
8    pool: PgPool,
9}
10
11impl PostgresExecutor {
12    pub fn new(pool: PgPool) -> Self {
13        Self { pool }
14    }
15}
16
17#[async_trait]
18impl Executor for PostgresExecutor {
19    async fn fetch_tables(
20        &self,
21        schema: Option<&str>,
22        include_views: bool,
23    ) -> Result<Vec<TableInfo>, ReflectError> {
24        let schema_name = schema.unwrap_or("public");
25        let table_types = if include_views {
26            "('BASE TABLE', 'VIEW')"
27        } else {
28            "('BASE TABLE')"
29        };
30
31        let query_str = format!(
32            "{} {}",
33            "SELECT table_name, table_type FROM information_schema.tables WHERE table_schema = $1 AND table_type IN",
34            table_types
35        );
36
37        let rows = sqlx::query(&query_str)
38            .bind(schema_name)
39            .fetch_all(&self.pool)
40            .await?;
41
42        let mut infos = Vec::new();
43        for row in rows {
44            let name: String = row.try_get("table_name")?;
45            let ttype: String = row.try_get("table_type")?;
46            infos.push(TableInfo {
47                name,
48                is_view: ttype == "VIEW",
49            });
50        }
51        Ok(infos)
52    }
53
54    async fn fetch_columns(
55        &self,
56        table: &str,
57        schema: Option<&str>,
58    ) -> Result<Vec<Column>, ReflectError> {
59        let schema_name = schema.unwrap_or("public");
60        
61        // This is a simplified fetch querying information_schema
62        let rows = sqlx::query(
63            "SELECT column_name, data_type, is_nullable, column_default, character_maximum_length 
64             FROM information_schema.columns 
65             WHERE table_schema = $1 AND table_name = $2 
66             ORDER BY ordinal_position"
67        )
68        .bind(schema_name)
69        .bind(table)
70        .fetch_all(&self.pool)
71        .await?;
72
73        let mut columns = Vec::new();
74        for row in rows {
75            let name: String = row.try_get("column_name")?;
76            let data_type_str: String = row.try_get("data_type")?;
77            let is_nullable_str: String = row.try_get("is_nullable")?;
78            let default: Option<String> = row.try_get("column_default")?;
79            let char_max: Option<i32> = row.try_get("character_maximum_length")?;
80
81            let nullable = is_nullable_str == "YES";
82            let dt = data_type_str.to_uppercase();
83
84            let data_type = match dt.as_str() {
85                "INTEGER" | "INT" | "INT4" => SqlType::Integer,
86                "BIGINT" | "INT8" => SqlType::BigInt,
87                "REAL" | "FLOAT4" => SqlType::Float,
88                "DOUBLE PRECISION" | "FLOAT8" => SqlType::Double,
89                "BOOLEAN" | "BOOL" => SqlType::Boolean,
90                "TEXT" => SqlType::Text,
91                "CHARACTER VARYING" | "VARCHAR" => {
92                    SqlType::Varchar(char_max.map(|v| v as u32))
93                }
94                "DATE" => SqlType::Date,
95                "TIMESTAMP WITHOUT TIME ZONE" | "TIMESTAMP WITH TIME ZONE" => SqlType::Timestamp,
96                "JSON" | "JSONB" => SqlType::Json,
97                "UUID" => SqlType::Uuid,
98                _ => SqlType::Custom(dt.clone()),
99            };
100
101            columns.push(Column {
102                name,
103                data_type,
104                nullable,
105                default,
106            });
107        }
108
109        Ok(columns)
110    }
111
112    async fn fetch_primary_key(
113        &self,
114        table: &str,
115        schema: Option<&str>,
116    ) -> Result<Option<PrimaryKey>, ReflectError> {
117        let schema_name = schema.unwrap_or("public");
118        let rows = sqlx::query(
119            "SELECT a.attname AS column_name
120             FROM pg_index i
121             JOIN pg_attribute a ON a.attrelid = i.indrelid AND a.attnum = ANY(i.indkey)
122             JOIN pg_class t ON t.oid = i.indrelid
123             JOIN pg_namespace n ON n.oid = t.relnamespace
124             WHERE n.nspname = $1 AND t.relname = $2 AND i.indisprimary"
125        )
126        .bind(schema_name)
127        .bind(table)
128        .fetch_all(&self.pool)
129        .await?;
130
131        if rows.is_empty() {
132            return Ok(None);
133        }
134
135        let mut columns = Vec::new();
136        for row in rows {
137            let col: String = row.try_get("column_name")?;
138            columns.push(col);
139        }
140
141        Ok(Some(PrimaryKey { columns }))
142    }
143
144    async fn fetch_foreign_keys(
145        &self,
146        table: &str,
147        schema: Option<&str>,
148    ) -> Result<Vec<ForeignKey>, ReflectError> {
149        let schema_name = schema.unwrap_or("public");
150        
151        let query = "
152            SELECT
153                kcu.column_name,
154                ccu.table_name AS foreign_table_name,
155                ccu.column_name AS foreign_column_name
156            FROM 
157                information_schema.table_constraints AS tc 
158                JOIN information_schema.key_column_usage AS kcu
159                  ON tc.constraint_name = kcu.constraint_name
160                  AND tc.table_schema = kcu.table_schema
161                JOIN information_schema.constraint_column_usage AS ccu
162                  ON ccu.constraint_name = tc.constraint_name
163                  AND ccu.table_schema = tc.table_schema
164            WHERE tc.constraint_type = 'FOREIGN KEY' AND tc.table_name = $1 AND tc.table_schema = $2;
165        ";
166
167        let rows = sqlx::query(query)
168            .bind(table)
169            .bind(schema_name)
170            .fetch_all(&self.pool)
171            .await?;
172
173        let mut fks = Vec::new();
174        for row in rows {
175            let column: String = row.try_get("column_name")?;
176            let referenced_table: String = row.try_get("foreign_table_name")?;
177            let referenced_column: String = row.try_get("foreign_column_name")?;
178
179            fks.push(ForeignKey {
180                column,
181                referenced_table,
182                referenced_column,
183            });
184        }
185
186        Ok(fks)
187    }
188
189    async fn fetch_indexes(
190        &self,
191        table: &str,
192        schema: Option<&str>,
193    ) -> Result<Vec<Index>, ReflectError> {
194        let schema_name = schema.unwrap_or("public");
195        let rows = sqlx::query(
196            "SELECT
197                i.relname AS index_name,
198                ix.indisunique AS is_unique,
199                a.attname AS column_name
200             FROM pg_index ix
201             JOIN pg_class t ON t.oid = ix.indrelid
202             JOIN pg_class i ON i.oid = ix.indexrelid
203             JOIN pg_namespace n ON n.oid = t.relnamespace
204             JOIN pg_attribute a ON a.attrelid = t.oid AND a.attnum = ANY(ix.indkey)
205             WHERE n.nspname = $1 AND t.relname = $2 AND NOT ix.indisprimary"
206        )
207        .bind(schema_name)
208        .bind(table)
209        .fetch_all(&self.pool)
210        .await?;
211
212        let mut idx_map: std::collections::HashMap<String, Index> = std::collections::HashMap::new();
213
214        for row in rows {
215            let idx_name: String = row.try_get("index_name")?;
216            let is_unique: bool = row.try_get("is_unique")?;
217            let col_name: String = row.try_get("column_name")?;
218
219            let eg = idx_map.entry(idx_name.clone()).or_insert(Index {
220                name: idx_name,
221                columns: Vec::new(),
222                unique: is_unique,
223            });
224            eg.columns.push(col_name);
225        }
226
227        Ok(idx_map.into_values().collect())
228    }
229}