use anyhow::{Error, Result};
use itertools::Itertools;
use sql::{Column, Schema, Table, schema};
use sqlx::PgConnection;
use std::collections::{HashMap, HashSet};
use std::str::FromStr;
6
/// Builds a value by introspecting a PostgreSQL schema over a live connection.
// `async fn` in a public trait triggers the `async_fn_in_trait` lint because callers cannot
// require the returned future to be `Send`; the lint is deliberately silenced here.
#[allow(async_fn_in_trait)]
pub trait FromPostgres: Sized {
    /// Reads the catalog for `schema_name` through `conn` and constructs `Self` from it.
    ///
    /// # Errors
    /// Returns `Err` when querying or interpreting the catalog fails.
    async fn try_from_postgres(conn: &mut PgConnection, schema_name: &str) -> Result<Self>;
}
11
12#[derive(sqlx::FromRow)]
13pub struct SchemaColumn {
14 pub table_name: String,
15 pub column_name: String,
16 #[allow(dead_code)]
17 pub ordinal_position: i32,
18 pub is_nullable: String,
19 pub data_type: String,
20 pub numeric_precision: Option<i32>,
21 pub numeric_scale: Option<i32>,
22 pub inner_type: Option<String>,
23}
24
25pub async fn query_schema_columns(
26 conn: &mut PgConnection,
27 schema_name: &str,
28) -> Result<Vec<SchemaColumn>> {
29 let s = include_str!("sql/query_columns.sql");
30 let result = sqlx::query_as::<_, SchemaColumn>(s)
31 .bind(schema_name)
32 .fetch_all(conn)
33 .await?;
34 Ok(result)
35}
36
/// One row of `sql/query_tables.sql`: a table and the schema it belongs to.
#[derive(sqlx::FromRow)]
struct TableSchema {
    /// Schema containing the table.
    // Decoded from the row but never read here, hence the `dead_code` allowance.
    #[allow(dead_code)]
    pub table_schema: String,
    /// Name of the table.
    pub table_name: String,
}
43
44pub async fn query_table_names(conn: &mut PgConnection, schema_name: &str) -> Result<Vec<String>> {
45 let s = include_str!("sql/query_tables.sql");
46 let result = sqlx::query_as::<_, TableSchema>(s)
47 .bind(schema_name)
48 .fetch_all(conn)
49 .await?;
50 Ok(result.into_iter().map(|t| t.table_name).collect())
51}
52
/// One row of `sql/query_constraints.sql`: one column pair of a foreign-key constraint.
///
/// `table_name`.`column_name` is the referencing column; `foreign_table_name`.
/// `foreign_column_name` is the column it references.
#[derive(Debug, sqlx::FromRow)]
pub struct ForeignKey {
    /// Schema of the referencing table.
    pub table_schema: String,
    /// Name of the constraint.
    pub constraint_name: String,
    /// Table that declares the constraint.
    pub table_name: String,
    /// Column (in `table_name`) that holds the reference.
    pub column_name: String,
    /// Schema of the referenced table.
    pub foreign_table_schema: String,
    /// Table being referenced.
    pub foreign_table_name: String,
    /// Column (in `foreign_table_name`) being referenced.
    pub foreign_column_name: String,
}
63
64pub async fn query_constraints(
65 conn: &mut PgConnection,
66 schema_name: &str,
67) -> Result<Vec<ForeignKey>> {
68 let s = include_str!("sql/query_constraints.sql");
69 Ok(sqlx::query_as::<_, ForeignKey>(s)
70 .bind(schema_name)
71 .fetch_all(conn)
72 .await?)
73}
74
/// One row of `sql/query_indices.sql`: an index defined in the schema.
// NOTE(review): the field names match PostgreSQL's `pg_indexes` view; presumably the query
// selects from it — confirm against the SQL file.
#[derive(Debug, sqlx::FromRow)]
pub struct Index {
    /// Schema the index row belongs to.
    pub schemaname: String,
    /// Name of the indexed table; used to attach the index to its `Table`.
    pub tablename: String,
    /// Name of the index.
    pub indexname: String,
    /// Index definition text (presumably the `CREATE INDEX` statement).
    pub indexdef: String,
}
82
83pub async fn query_indices(conn: &mut PgConnection, schema_name: &str) -> Result<Vec<Index>> {
84 let s = include_str!("sql/query_indices.sql");
86 Ok(sqlx::query_as::<_, Index>(s)
87 .bind(schema_name)
88 .fetch_all(conn)
89 .await?)
90}
91
92#[derive(sqlx::FromRow)]
93pub struct Function {
94 pub routine_schema: String,
95 pub routine_name: String,
96 pub routine_type: String,
97 pub data_type: Option<String>,
98 pub routine_definition: Option<String>,
99}
100
101pub async fn query_functions(conn: &mut PgConnection, schema_name: &str) -> Result<Vec<Function>> {
102 let s = include_str!("sql/query_functions.sql");
103 Ok(sqlx::query_as::<_, Function>(s)
104 .bind(schema_name)
105 .fetch_all(conn)
106 .await?)
107}
108
109#[derive(sqlx::FromRow)]
110pub struct Trigger {
111 pub trigger_schema: String,
112 pub trigger_name: String,
113 pub event_manipulation: String,
114 pub event_object_table: String,
115 pub action_timing: String,
116 pub action_statement: String,
117}
118
119pub async fn query_triggers(conn: &mut PgConnection, schema_name: &str) -> Result<Vec<Trigger>> {
120 let s = include_str!("sql/query_triggers.sql");
121 Ok(sqlx::query_as::<_, Trigger>(s)
122 .bind(schema_name)
123 .fetch_all(conn)
124 .await?)
125}
126
127impl TryInto<Column> for SchemaColumn {
128 type Error = Error;
129
130 fn try_into(self) -> std::result::Result<Column, Self::Error> {
131 use schema::Type::*;
132 let nullable = self.is_nullable == "YES";
133 let typ = match self.data_type.as_str() {
134 "ARRAY" => {
135 let inner = schema::Type::from_str(
136 &self
137 .inner_type
138 .expect("Encounterd ARRAY with no inner type."),
139 )?;
140 Array(Box::new(inner))
141 }
142 "numeric" if self.numeric_precision.is_some() && self.numeric_scale.is_some() => {
143 Numeric(
144 self.numeric_precision.unwrap() as u8,
145 self.numeric_scale.unwrap() as u8,
146 )
147 }
148 z => schema::Type::from_str(z)?,
149 };
150 Ok(Column {
151 name: self.column_name.clone(),
152 typ,
153 nullable,
154 primary_key: false,
155 default: None,
156 constraint: None,
157 })
158 }
159}
160
161impl FromPostgres for Schema {
162 async fn try_from_postgres(conn: &mut PgConnection, schema_name: &str) -> Result<Schema> {
163 let column_schemas = query_schema_columns(conn, schema_name).await?;
164 let mut tables = column_schemas
165 .into_iter()
166 .chunk_by(|c| c.table_name.clone())
167 .into_iter()
168 .map(|(table_name, group)| {
169 let columns = group
170 .map(|c: SchemaColumn| c.try_into())
171 .collect::<Result<Vec<_>, Error>>()?;
172 Ok(Table {
173 schema: Some(schema_name.to_string()),
174 name: table_name,
175 columns,
176 indexes: vec![],
177 })
178 })
179 .collect::<Result<Vec<_>, Error>>()?;
180 let mut it_tables = tables.iter_mut().peekable();
181 let indices = query_indices(conn, schema_name).await?;
182 for index in indices {
183 while &index.tablename != &it_tables.peek().unwrap().name {
184 it_tables.next();
185 }
186 let t = it_tables.peek_mut().unwrap();
187 t.indexes.push(sql::Index {
188 name: index.indexname,
189 columns: Vec::new(),
190 });
191 }
192
193 let constraints = query_constraints(conn, schema_name).await?;
194 let mut it_tables = tables.iter_mut().peekable();
195 for fk in constraints {
196 while &fk.table_name != &it_tables.peek().unwrap().name {
197 it_tables.next();
198 }
199 let table = it_tables.peek_mut().unwrap();
200 let column = table
201 .columns
202 .iter_mut()
203 .find(|c| c.name == fk.column_name)
204 .expect("Constraint for unknown column.");
205 column.constraint = Some(schema::Constraint::ForeignKey(schema::ForeignKey {
206 table: fk.foreign_table_name,
207 columns: vec![fk.foreign_column_name],
208 }));
209 }
210
211 let table_names = query_table_names(conn, schema_name).await?;
213 let mut tables_it = tables.iter().peekable();
214 let mut empty_tables = Vec::new();
215 'outer: for name in table_names {
216 while let Some(table) = tables_it.peek() {
217 if &name == &table.name {
218 tables_it.next();
219 continue 'outer;
220 }
221 }
222 empty_tables.push(Table {
223 schema: Some(schema_name.to_string()),
224 name,
225 columns: vec![],
226 indexes: vec![],
227 })
228 }
229 Ok(Schema { tables })
230 }
231}
232
#[cfg(test)]
mod test {
    use super::*;

    /// Builds a NOT NULL column `foo.bar` at position 1 with the given type metadata and
    /// no array element type.
    fn schema_column(
        data_type: &str,
        numeric_precision: Option<i32>,
        numeric_scale: Option<i32>,
    ) -> SchemaColumn {
        SchemaColumn {
            table_name: "foo".to_string(),
            column_name: "bar".to_string(),
            ordinal_position: 1,
            is_nullable: "NO".to_string(),
            data_type: data_type.to_string(),
            numeric_precision,
            numeric_scale,
            inner_type: None,
        }
    }

    /// `numeric` with precision and scale maps to `Numeric(precision, scale)`.
    #[test]
    fn test_numeric() {
        let column: Column = schema_column("numeric", Some(10), Some(2))
            .try_into()
            .unwrap();
        assert_eq!(column.typ, schema::Type::Numeric(10, 2));
    }

    /// `integer` is parsed by name even when precision/scale are present.
    #[test]
    fn test_integer() {
        let column: Column = schema_column("integer", Some(32), Some(0))
            .try_into()
            .unwrap();
        assert_eq!(column.typ, schema::Type::I32);
    }
}