use anyhow::{Error, Result};
use itertools::Itertools;
use sql::{Column, Expr, Generated, GenerationTime, GenerationValue, Schema, Table, schema};
use sqlx::PgConnection;
use std::collections::HashMap;
use std::str::FromStr;
6
7#[allow(async_fn_in_trait)]
8pub trait FromPostgres: Sized {
9 async fn try_from_postgres(conn: &mut PgConnection, schema_name: &str) -> Result<Self>;
10}
11
12#[derive(sqlx::FromRow)]
13pub struct ColumnRow {
14 pub table_name: String,
15 pub column_name: String,
16 #[allow(dead_code)]
17 pub ordinal_position: i32,
18 pub is_nullable: String,
19 pub data_type: String,
20 pub numeric_precision: Option<i32>,
21 pub numeric_scale: Option<i32>,
22 pub inner_type: Option<String>,
23 pub primary_key: bool,
24 pub generation_time: Option<String>,
25 pub generation_expression: Option<String>,
26 pub identity_generation: Option<String>,
27}
28
29pub async fn query_schema_columns(
30 conn: &mut PgConnection,
31 schema_name: &str,
32) -> Result<Vec<ColumnRow>> {
33 let s = include_str!("sql/query_columns.sql");
34 let result = sqlx::query_as::<_, ColumnRow>(s)
35 .bind(schema_name)
36 .fetch_all(conn)
37 .await?;
38 Ok(result)
39}
40
41#[derive(sqlx::FromRow)]
42struct TableSchema {
43 #[allow(dead_code)]
44 pub table_schema: String,
45 pub table_name: String,
46}
47
48pub async fn query_table_names(conn: &mut PgConnection, schema_name: &str) -> Result<Vec<String>> {
49 let s = include_str!("sql/query_tables.sql");
50 let result = sqlx::query_as::<_, TableSchema>(s)
51 .bind(schema_name)
52 .fetch_all(conn)
53 .await?;
54 Ok(result.into_iter().map(|t| t.table_name).collect())
55}
56
/// One row of `sql/query_constraints.sql`.
///
/// `try_from_postgres` treats each row as a link from `table_name.column_name` to
/// `foreign_table_name.foreign_column_name`. The `*_schema` fields are not read in this file.
#[derive(Debug, sqlx::FromRow)]
pub struct ForeignKey {
    pub table_schema: String,
    /// Constraint name (used only in diagnostics here).
    pub constraint_name: String,
    /// Table holding the referencing column.
    pub table_name: String,
    /// Referencing column.
    pub column_name: String,
    pub foreign_table_schema: String,
    /// Referenced table.
    pub foreign_table_name: String,
    /// Referenced column.
    pub foreign_column_name: String,
}
67
68pub async fn query_constraints(
69 conn: &mut PgConnection,
70 schema_name: &str,
71) -> Result<Vec<ForeignKey>> {
72 let s = include_str!("sql/query_constraints.sql");
73 Ok(sqlx::query_as::<_, ForeignKey>(s)
74 .bind(schema_name)
75 .fetch_all(conn)
76 .await?)
77}
78
/// One row of `sql/query_indices.sql`, decoded by `query_indices`.
#[derive(Debug, sqlx::FromRow)]
pub struct Index {
    /// Schema containing the index.
    pub schema: String,
    /// Table the index is defined on.
    pub table: String,
    /// Index name.
    pub name: String,
    /// NOTE(review): presumably the full index definition (e.g. `CREATE INDEX ...`); confirm
    /// against the SQL.
    pub statement: String,
    /// Whether the index is unique.
    pub unique: bool,
    /// NOTE(review): presumably the index access method (e.g. btree); confirm against the SQL.
    pub kind: String,
    /// Indexed columns (decoded from a text array).
    pub columns: Vec<String>,
}
90
91pub async fn query_indices(conn: &mut PgConnection, schema_name: &str) -> Result<Vec<Index>> {
92 let s = include_str!("sql/query_indices.sql");
94 Ok(sqlx::query_as::<_, Index>(s)
95 .bind(schema_name)
96 .fetch_all(conn)
97 .await?)
98}
99
/// One row of `sql/query_functions.sql`, decoded by `query_functions`.
///
/// NOTE(review): the field names mirror `information_schema.routines` columns; presumably the
/// SQL selects from there — confirm.
#[derive(Debug, sqlx::FromRow)]
pub struct Function {
    pub routine_schema: String,
    pub routine_name: String,
    pub routine_type: String,
    /// Nullable in the query result.
    pub data_type: Option<String>,
    /// Nullable in the query result.
    pub routine_definition: Option<String>,
}
108
109pub async fn query_functions(conn: &mut PgConnection, schema_name: &str) -> Result<Vec<Function>> {
110 let s = include_str!("sql/query_functions.sql");
111 Ok(sqlx::query_as::<_, Function>(s)
112 .bind(schema_name)
113 .fetch_all(conn)
114 .await?)
115}
116
117#[derive(sqlx::FromRow)]
118pub struct Trigger {
119 pub trigger_schema: String,
120 pub trigger_name: String,
121 pub event_manipulation: String,
122 pub event_object_table: String,
123 pub action_timing: String,
124 pub action_statement: String,
125}
126
127pub async fn query_triggers(conn: &mut PgConnection, schema_name: &str) -> Result<Vec<Trigger>> {
128 let s = include_str!("sql/query_triggers.sql");
129 Ok(sqlx::query_as::<_, Trigger>(s)
130 .bind(schema_name)
131 .fetch_all(conn)
132 .await?)
133}
134
135fn parse_generated(
136 time: Option<String>,
137 expr: Option<String>,
138 identity: Option<String>,
139) -> Option<Generated> {
140 let time = time?;
141 let time = match time.as_str() {
142 "ALWAYS" => GenerationTime::Always,
143 "BY DEFAULT" => GenerationTime::ByDefault,
144 _ => return None,
145 };
146 let value = if let Some(identity) = identity
147 && identity == "YES"
148 {
149 GenerationValue::Identity
150 } else if let Some(expr) = expr {
151 GenerationValue::Expr(Expr::Raw(expr))
152 } else {
153 return None;
154 };
155 Some(Generated { time, value })
156}
157
158impl TryInto<Column> for ColumnRow {
159 type Error = Error;
160
161 fn try_into(self) -> std::result::Result<Column, Self::Error> {
162 use schema::Type::*;
163 let nullable = self.is_nullable == "YES";
164 let typ = match self.data_type.as_str() {
165 "ARRAY" => {
166 let inner = schema::Type::from_str(
167 &self
168 .inner_type
169 .expect("Encounterd ARRAY with no inner type."),
170 )?;
171 Array(Box::new(inner))
172 }
173 "numeric" if self.numeric_precision.is_some() && self.numeric_scale.is_some() => {
174 Numeric(
175 self.numeric_precision.unwrap() as u8,
176 self.numeric_scale.unwrap() as u8,
177 )
178 }
179 z => schema::Type::from_str(z)?,
180 };
181 let generated = parse_generated(
182 self.generation_time,
183 self.generation_expression,
184 self.identity_generation,
185 );
186 Ok(Column {
187 name: self.column_name.clone(),
188 typ,
189 nullable,
190 primary_key: false,
191 default: None,
192 constraint: None,
193 generated,
194 })
195 }
196}
197
198impl FromPostgres for Schema {
199 async fn try_from_postgres(conn: &mut PgConnection, schema: &str) -> Result<Schema> {
200 let column_schemas = query_schema_columns(conn, schema).await?;
201 let mut tables = column_schemas
202 .into_iter()
203 .chunk_by(|c| c.table_name.clone())
204 .into_iter()
205 .map(|(table_name, group)| {
206 let columns = group
207 .map(|c: ColumnRow| c.try_into())
208 .collect::<Result<Vec<_>, Error>>()?;
209 Ok(Table {
210 schema: Some(schema.to_string()),
211 name: table_name,
212 columns,
213 })
214 })
215 .collect::<Result<Vec<_>, Error>>()?;
216 let constraints = query_constraints(conn, schema).await?;
217 let mut it_tables = tables.iter_mut().peekable();
218 for fk in constraints {
219 while &fk.table_name != &it_tables.peek().unwrap().name {
220 it_tables.next();
221 }
222 let table = it_tables.peek_mut().unwrap();
223 let column = table
224 .columns
225 .iter_mut()
226 .find(|c| c.name == fk.column_name)
227 .expect("Constraint for unknown column.");
228 column.constraint = Some(schema::Constraint::ForeignKey(schema::ForeignKey {
229 table: fk.foreign_table_name,
230 columns: vec![fk.foreign_column_name],
231 }));
232 }
233
234 let table_names = query_table_names(conn, schema).await?;
236 let mut tables_it = tables.iter().peekable();
237 let mut empty_tables = Vec::new();
238 'outer: for name in table_names {
239 while let Some(table) = tables_it.peek() {
240 if &name == &table.name {
241 tables_it.next();
242 continue 'outer;
243 }
244 }
245 empty_tables.push(Table {
246 schema: Some(schema.to_string()),
247 name,
248 columns: vec![],
249 })
250 }
251 Ok(Schema { tables })
252 }
253}
254
#[cfg(test)]
mod test {
    use super::*;

    /// Builds a non-nullable, non-generated `foo.bar` column row with the given type info.
    fn column_row(
        data_type: &str,
        numeric_precision: Option<i32>,
        numeric_scale: Option<i32>,
    ) -> ColumnRow {
        ColumnRow {
            table_name: "foo".to_string(),
            column_name: "bar".to_string(),
            ordinal_position: 1,
            is_nullable: "NO".to_string(),
            data_type: data_type.to_string(),
            numeric_precision,
            numeric_scale,
            inner_type: None,
            primary_key: false,
            generation_time: None,
            generation_expression: None,
            identity_generation: None,
        }
    }

    #[test]
    fn test_numeric() {
        let column: Column = column_row("numeric", Some(10), Some(2)).try_into().unwrap();
        assert_eq!(column.typ, schema::Type::Numeric(10, 2));
    }

    #[test]
    fn test_integer() {
        let column: Column = column_row("integer", Some(32), Some(0)).try_into().unwrap();
        assert_eq!(column.typ, schema::Type::I32);
    }

    #[test]
    fn test_nullable_and_name() {
        let mut row = column_row("integer", Some(32), Some(0));
        row.is_nullable = "YES".to_string();
        let column: Column = row.try_into().unwrap();
        assert!(column.nullable);
        assert_eq!(column.name, "bar");
        assert!(column.generated.is_none());
    }

    #[test]
    fn test_parse_generated() {
        assert!(parse_generated(None, Some("a + b".to_string()), None).is_none());
        assert!(
            parse_generated(Some("NEVER".to_string()), Some("a + b".to_string()), None).is_none()
        );
        assert!(parse_generated(Some("ALWAYS".to_string()), None, None).is_none());
        assert!(matches!(
            parse_generated(Some("ALWAYS".to_string()), Some("a + b".to_string()), None),
            Some(Generated {
                time: GenerationTime::Always,
                value: GenerationValue::Expr(_),
            })
        ));
        assert!(matches!(
            parse_generated(Some("BY DEFAULT".to_string()), None, Some("YES".to_string())),
            Some(Generated {
                time: GenerationTime::ByDefault,
                value: GenerationValue::Identity,
            })
        ));
    }
}