1use anyhow::{Error, Result};
2use itertools::Itertools;
3use sql::{Column, Schema, Table, schema};
4use sqlx::PgConnection;
5use std::str::FromStr;
6
7#[allow(async_fn_in_trait)]
8pub trait FromPostgres: Sized {
9 async fn try_from_postgres(conn: &mut PgConnection, schema_name: &str) -> Result<Self>;
10}
11
12#[derive(sqlx::FromRow)]
13struct SchemaColumn {
14 pub table_name: String,
15 pub column_name: String,
16 #[allow(dead_code)]
17 pub ordinal_position: i32,
18 pub is_nullable: String,
19 pub data_type: String,
20 pub numeric_precision: Option<i32>,
21 pub numeric_scale: Option<i32>,
22 pub inner_type: Option<String>,
23}
24
25async fn query_schema_columns(
26 conn: &mut PgConnection,
27 schema_name: &str,
28) -> Result<Vec<SchemaColumn>> {
29 let s = include_str!("sql/query_columns.sql");
30 let result = sqlx::query_as::<_, SchemaColumn>(s)
31 .bind(schema_name)
32 .fetch_all(conn)
33 .await?;
34 Ok(result)
35}
36
37#[derive(sqlx::FromRow)]
38struct TableSchema {
39 #[allow(dead_code)]
40 pub table_schema: String,
41 pub table_name: String,
42}
43
44async fn query_table_names(conn: &mut PgConnection, schema_name: &str) -> Result<Vec<String>> {
45 let s = include_str!("sql/query_tables.sql");
46 let result = sqlx::query_as::<_, TableSchema>(s)
47 .bind(schema_name)
48 .fetch_all(conn)
49 .await?;
50 Ok(result.into_iter().map(|t| t.table_name).collect())
51}
52
53#[derive(sqlx::FromRow)]
54#[allow(dead_code)]
55struct ForeignKey {
56 pub table_schema: String,
57 pub constraint_name: String,
58 pub table_name: String,
59 pub column_name: String,
60 pub foreign_table_schema: String,
61 pub foreign_table_name: String,
62 pub foreign_column_name: String,
63}
64
65async fn query_constraints(conn: &mut PgConnection, schema_name: &str) -> Result<Vec<ForeignKey>> {
66 let s = include_str!("sql/query_constraints.sql");
67 Ok(sqlx::query_as::<_, ForeignKey>(s)
68 .bind(schema_name)
69 .fetch_all(conn)
70 .await?)
71}
72
73#[derive(sqlx::FromRow)]
74pub struct Index {
75 pub schemaname: String,
76 pub tablename: String,
77 pub indexname: String,
78 pub indexdef: String,
79}
80
81pub async fn query_indices(conn: &mut PgConnection, schema_name: &str) -> Result<Vec<Index>> {
82 let s = include_str!("sql/query_indices.sql");
83 Ok(sqlx::query_as::<_, Index>(s)
84 .bind(schema_name)
85 .fetch_all(conn)
86 .await?)
87}
88
89#[derive(sqlx::FromRow)]
90pub struct Function {
91 pub routine_schema: String,
92 pub routine_name: String,
93 pub routine_type: String,
94 pub data_type: Option<String>,
95 pub routine_definition: Option<String>,
96}
97
98pub async fn query_functions(conn: &mut PgConnection, schema_name: &str) -> Result<Vec<Function>> {
99 let s = include_str!("sql/query_functions.sql");
100 Ok(sqlx::query_as::<_, Function>(s)
101 .bind(schema_name)
102 .fetch_all(conn)
103 .await?)
104}
105
106#[derive(sqlx::FromRow)]
107pub struct Trigger {
108 pub trigger_schema: String,
109 pub trigger_name: String,
110 pub event_manipulation: String,
111 pub event_object_table: String,
112 pub action_timing: String,
113 pub action_statement: String,
114}
115
116pub async fn query_triggers(conn: &mut PgConnection, schema_name: &str) -> Result<Vec<Trigger>> {
117 let s = include_str!("sql/query_triggers.sql");
118 Ok(sqlx::query_as::<_, Trigger>(s)
119 .bind(schema_name)
120 .fetch_all(conn)
121 .await?)
122}
123
124impl TryInto<Column> for SchemaColumn {
125 type Error = Error;
126
127 fn try_into(self) -> std::result::Result<Column, Self::Error> {
128 use schema::Type::*;
129 let nullable = self.is_nullable == "YES";
130 let typ = match self.data_type.as_str() {
131 "ARRAY" => {
132 let inner = schema::Type::from_str(
133 &self
134 .inner_type
135 .expect("Encounterd ARRAY with no inner type."),
136 )?;
137 Array(Box::new(inner))
138 }
139 "numeric" if self.numeric_precision.is_some() && self.numeric_scale.is_some() => {
140 Numeric(
141 self.numeric_precision.unwrap() as u8,
142 self.numeric_scale.unwrap() as u8,
143 )
144 }
145 z => schema::Type::from_str(z)?,
146 };
147 Ok(Column {
148 name: self.column_name.clone(),
149 typ,
150 nullable,
151 primary_key: false,
152 default: None,
153 constraint: None,
154 })
155 }
156}
157
158impl FromPostgres for Schema {
159 async fn try_from_postgres(conn: &mut PgConnection, schema_name: &str) -> Result<Schema> {
160 let column_schemas = query_schema_columns(conn, schema_name).await?;
161 let mut tables = column_schemas
162 .into_iter()
163 .chunk_by(|c| c.table_name.clone())
164 .into_iter()
165 .map(|(table_name, group)| {
166 let columns = group
167 .map(|c: SchemaColumn| c.try_into())
168 .collect::<Result<Vec<_>, Error>>()?;
169 Ok(Table {
170 schema: Some(schema_name.to_string()),
171 name: table_name,
172 columns,
173 indexes: vec![],
174 })
175 })
176 .collect::<Result<Vec<_>, Error>>()?;
177
178 let constraints = query_constraints(conn, schema_name).await?;
179 for fk in constraints {
180 let table = tables
181 .iter_mut()
182 .find(|t| t.name == fk.table_name)
183 .expect("Constraint for unknown table.");
184 let column = table
185 .columns
186 .iter_mut()
187 .find(|c| c.name == fk.column_name)
188 .expect("Constraint for unknown column.");
189 column.constraint = Some(schema::Constraint::ForeignKey(schema::ForeignKey {
190 table: fk.foreign_table_name,
191 columns: vec![fk.foreign_column_name],
192 }));
193 }
194
195 let table_names = query_table_names(conn, schema_name).await?;
197 for name in table_names {
198 if tables.iter().any(|t| t.name == name) {
199 continue;
200 }
201 tables.push(Table {
202 schema: Some(schema_name.to_string()),
203 name,
204 columns: vec![],
205 indexes: vec![],
206 })
207 }
208 Ok(Schema { tables })
209 }
210}
211
#[cfg(test)]
mod test {
    use super::*;

    /// Builds a non-nullable `foo.bar` column row with the given type metadata.
    fn column_row(data_type: &str, precision: i32, scale: i32) -> SchemaColumn {
        SchemaColumn {
            table_name: "foo".to_string(),
            column_name: "bar".to_string(),
            ordinal_position: 1,
            is_nullable: "NO".to_string(),
            data_type: data_type.to_string(),
            numeric_precision: Some(precision),
            numeric_scale: Some(scale),
            inner_type: None,
        }
    }

    /// Converts a row into a `Column`, panicking if the conversion fails.
    fn convert(row: SchemaColumn) -> Column {
        row.try_into().unwrap()
    }

    #[test]
    fn test_numeric() {
        assert_eq!(
            convert(column_row("numeric", 10, 2)).typ,
            schema::Type::Numeric(10, 2)
        );
    }

    #[test]
    fn test_integer() {
        assert_eq!(convert(column_row("integer", 32, 0)).typ, schema::Type::I32);
    }
}