1use async_trait::async_trait;
2use sqlx::{PgPool, Row};
3use crate::error::ReflectError;
4use crate::executor::{Executor, TableInfo};
5use crate::metadata::{Column, ForeignKey, Index, PrimaryKey, SqlType};
6
7pub struct PostgresExecutor {
8 pool: PgPool,
9}
10
11impl PostgresExecutor {
12 pub fn new(pool: PgPool) -> Self {
13 Self { pool }
14 }
15}
16
17#[async_trait]
18impl Executor for PostgresExecutor {
19 async fn fetch_tables(
20 &self,
21 schema: Option<&str>,
22 include_views: bool,
23 ) -> Result<Vec<TableInfo>, ReflectError> {
24 let schema_name = schema.unwrap_or("public");
25 let table_types = if include_views {
26 "('BASE TABLE', 'VIEW')"
27 } else {
28 "('BASE TABLE')"
29 };
30
31 let query_str = format!(
32 "{} {}",
33 "SELECT table_name, table_type FROM information_schema.tables WHERE table_schema = $1 AND table_type IN",
34 table_types
35 );
36
37 let rows = sqlx::query(&query_str)
38 .bind(schema_name)
39 .fetch_all(&self.pool)
40 .await?;
41
42 let mut infos = Vec::new();
43 for row in rows {
44 let name: String = row.try_get("table_name")?;
45 let ttype: String = row.try_get("table_type")?;
46 infos.push(TableInfo {
47 name,
48 is_view: ttype == "VIEW",
49 });
50 }
51 Ok(infos)
52 }
53
54 async fn fetch_columns(
55 &self,
56 table: &str,
57 schema: Option<&str>,
58 ) -> Result<Vec<Column>, ReflectError> {
59 let schema_name = schema.unwrap_or("public");
60
61 let rows = sqlx::query(
63 "SELECT column_name, data_type, is_nullable, column_default, character_maximum_length
64 FROM information_schema.columns
65 WHERE table_schema = $1 AND table_name = $2
66 ORDER BY ordinal_position"
67 )
68 .bind(schema_name)
69 .bind(table)
70 .fetch_all(&self.pool)
71 .await?;
72
73 let mut columns = Vec::new();
74 for row in rows {
75 let name: String = row.try_get("column_name")?;
76 let data_type_str: String = row.try_get("data_type")?;
77 let is_nullable_str: String = row.try_get("is_nullable")?;
78 let default: Option<String> = row.try_get("column_default")?;
79 let char_max: Option<i32> = row.try_get("character_maximum_length")?;
80
81 let nullable = is_nullable_str == "YES";
82 let dt = data_type_str.to_uppercase();
83
84 let data_type = match dt.as_str() {
85 "INTEGER" | "INT" | "INT4" => SqlType::Integer,
86 "BIGINT" | "INT8" => SqlType::BigInt,
87 "REAL" | "FLOAT4" => SqlType::Float,
88 "DOUBLE PRECISION" | "FLOAT8" => SqlType::Double,
89 "BOOLEAN" | "BOOL" => SqlType::Boolean,
90 "TEXT" => SqlType::Text,
91 "CHARACTER VARYING" | "VARCHAR" => {
92 SqlType::Varchar(char_max.map(|v| v as u32))
93 }
94 "DATE" => SqlType::Date,
95 "TIMESTAMP WITHOUT TIME ZONE" | "TIMESTAMP WITH TIME ZONE" => SqlType::Timestamp,
96 "JSON" | "JSONB" => SqlType::Json,
97 "UUID" => SqlType::Uuid,
98 _ => SqlType::Custom(dt.clone()),
99 };
100
101 columns.push(Column {
102 name,
103 data_type,
104 nullable,
105 default,
106 });
107 }
108
109 Ok(columns)
110 }
111
112 async fn fetch_primary_key(
113 &self,
114 table: &str,
115 schema: Option<&str>,
116 ) -> Result<Option<PrimaryKey>, ReflectError> {
117 let schema_name = schema.unwrap_or("public");
118 let rows = sqlx::query(
119 "SELECT a.attname AS column_name
120 FROM pg_index i
121 JOIN pg_attribute a ON a.attrelid = i.indrelid AND a.attnum = ANY(i.indkey)
122 JOIN pg_class t ON t.oid = i.indrelid
123 JOIN pg_namespace n ON n.oid = t.relnamespace
124 WHERE n.nspname = $1 AND t.relname = $2 AND i.indisprimary"
125 )
126 .bind(schema_name)
127 .bind(table)
128 .fetch_all(&self.pool)
129 .await?;
130
131 if rows.is_empty() {
132 return Ok(None);
133 }
134
135 let mut columns = Vec::new();
136 for row in rows {
137 let col: String = row.try_get("column_name")?;
138 columns.push(col);
139 }
140
141 Ok(Some(PrimaryKey { columns }))
142 }
143
144 async fn fetch_foreign_keys(
145 &self,
146 table: &str,
147 schema: Option<&str>,
148 ) -> Result<Vec<ForeignKey>, ReflectError> {
149 let schema_name = schema.unwrap_or("public");
150
151 let query = "
152 SELECT
153 kcu.column_name,
154 ccu.table_name AS foreign_table_name,
155 ccu.column_name AS foreign_column_name
156 FROM
157 information_schema.table_constraints AS tc
158 JOIN information_schema.key_column_usage AS kcu
159 ON tc.constraint_name = kcu.constraint_name
160 AND tc.table_schema = kcu.table_schema
161 JOIN information_schema.constraint_column_usage AS ccu
162 ON ccu.constraint_name = tc.constraint_name
163 AND ccu.table_schema = tc.table_schema
164 WHERE tc.constraint_type = 'FOREIGN KEY' AND tc.table_name = $1 AND tc.table_schema = $2;
165 ";
166
167 let rows = sqlx::query(query)
168 .bind(table)
169 .bind(schema_name)
170 .fetch_all(&self.pool)
171 .await?;
172
173 let mut fks = Vec::new();
174 for row in rows {
175 let column: String = row.try_get("column_name")?;
176 let referenced_table: String = row.try_get("foreign_table_name")?;
177 let referenced_column: String = row.try_get("foreign_column_name")?;
178
179 fks.push(ForeignKey {
180 column,
181 referenced_table,
182 referenced_column,
183 });
184 }
185
186 Ok(fks)
187 }
188
189 async fn fetch_indexes(
190 &self,
191 table: &str,
192 schema: Option<&str>,
193 ) -> Result<Vec<Index>, ReflectError> {
194 let schema_name = schema.unwrap_or("public");
195 let rows = sqlx::query(
196 "SELECT
197 i.relname AS index_name,
198 ix.indisunique AS is_unique,
199 a.attname AS column_name
200 FROM pg_index ix
201 JOIN pg_class t ON t.oid = ix.indrelid
202 JOIN pg_class i ON i.oid = ix.indexrelid
203 JOIN pg_namespace n ON n.oid = t.relnamespace
204 JOIN pg_attribute a ON a.attrelid = t.oid AND a.attnum = ANY(ix.indkey)
205 WHERE n.nspname = $1 AND t.relname = $2 AND NOT ix.indisprimary"
206 )
207 .bind(schema_name)
208 .bind(table)
209 .fetch_all(&self.pool)
210 .await?;
211
212 let mut idx_map: std::collections::HashMap<String, Index> = std::collections::HashMap::new();
213
214 for row in rows {
215 let idx_name: String = row.try_get("index_name")?;
216 let is_unique: bool = row.try_get("is_unique")?;
217 let col_name: String = row.try_get("column_name")?;
218
219 let eg = idx_map.entry(idx_name.clone()).or_insert(Index {
220 name: idx_name,
221 columns: Vec::new(),
222 unique: is_unique,
223 });
224 eg.columns.push(col_name);
225 }
226
227 Ok(idx_map.into_values().collect())
228 }
229}