sql_sqlx/
lib.rs

1use anyhow::{Error, Result};
2use itertools::Itertools;
3use sql::{Column, Expr, Generated, GenerationTime, GenerationValue, Schema, Table, schema};
4use sqlx::PgConnection;
5use std::str::FromStr;
6
7#[allow(async_fn_in_trait)]
8pub trait FromPostgres: Sized {
9    async fn try_from_postgres(conn: &mut PgConnection, schema_name: &str) -> Result<Self>;
10}
11
12#[derive(sqlx::FromRow)]
13pub struct ColumnRow {
14    pub table_name: String,
15    pub column_name: String,
16    #[allow(dead_code)]
17    pub ordinal_position: i32,
18    pub is_nullable: String,
19    pub data_type: String,
20    pub numeric_precision: Option<i32>,
21    pub numeric_scale: Option<i32>,
22    pub inner_type: Option<String>,
23    pub primary_key: bool,
24    pub generation_time: Option<String>,
25    pub generation_expression: Option<String>,
26    pub identity_generation: Option<String>,
27}
28
29pub async fn query_schema_columns(
30    conn: &mut PgConnection,
31    schema_name: &str,
32) -> Result<Vec<ColumnRow>> {
33    let s = include_str!("sql/query_columns.sql");
34    let result = sqlx::query_as::<_, ColumnRow>(s)
35        .bind(schema_name)
36        .fetch_all(conn)
37        .await?;
38    Ok(result)
39}
40
41#[derive(sqlx::FromRow)]
42struct TableSchema {
43    #[allow(dead_code)]
44    pub table_schema: String,
45    pub table_name: String,
46}
47
48pub async fn query_table_names(conn: &mut PgConnection, schema_name: &str) -> Result<Vec<String>> {
49    let s = include_str!("sql/query_tables.sql");
50    let result = sqlx::query_as::<_, TableSchema>(s)
51        .bind(schema_name)
52        .fetch_all(conn)
53        .await?;
54    Ok(result.into_iter().map(|t| t.table_name).collect())
55}
56
57#[derive(Debug, sqlx::FromRow)]
58pub struct ForeignKey {
59    pub table_schema: String,
60    pub constraint_name: String,
61    pub table_name: String,
62    pub column_name: String,
63    pub foreign_table_schema: String,
64    pub foreign_table_name: String,
65    pub foreign_column_name: String,
66}
67
68pub async fn query_constraints(
69    conn: &mut PgConnection,
70    schema_name: &str,
71) -> Result<Vec<ForeignKey>> {
72    let s = include_str!("sql/query_constraints.sql");
73    Ok(sqlx::query_as::<_, ForeignKey>(s)
74        .bind(schema_name)
75        .fetch_all(conn)
76        .await?)
77}
78
79#[derive(Debug, sqlx::FromRow)]
80pub struct Index {
81    pub schema: String,
82    pub table: String,
83    pub name: String,
84    // text of the index definition
85    pub statement: String,
86    pub unique: bool,
87    pub kind: String,
88    pub columns: Vec<String>,
89}
90
91pub async fn query_indices(conn: &mut PgConnection, schema_name: &str) -> Result<Vec<Index>> {
92    // because of pg_tables join, this only returns indices for tables, not views/mat views
93    let s = include_str!("sql/query_indices.sql");
94    Ok(sqlx::query_as::<_, Index>(s)
95        .bind(schema_name)
96        .fetch_all(conn)
97        .await?)
98}
99
100#[derive(Debug, sqlx::FromRow)]
101pub struct Function {
102    pub routine_schema: String,
103    pub routine_name: String,
104    pub routine_type: String,
105    pub data_type: Option<String>,
106    pub routine_definition: Option<String>,
107}
108
109pub async fn query_functions(conn: &mut PgConnection, schema_name: &str) -> Result<Vec<Function>> {
110    let s = include_str!("sql/query_functions.sql");
111    Ok(sqlx::query_as::<_, Function>(s)
112        .bind(schema_name)
113        .fetch_all(conn)
114        .await?)
115}
116
117#[derive(sqlx::FromRow)]
118pub struct Trigger {
119    pub trigger_schema: String,
120    pub trigger_name: String,
121    pub event_manipulation: String,
122    pub event_object_table: String,
123    pub action_timing: String,
124    pub action_statement: String,
125}
126
127pub async fn query_triggers(conn: &mut PgConnection, schema_name: &str) -> Result<Vec<Trigger>> {
128    let s = include_str!("sql/query_triggers.sql");
129    Ok(sqlx::query_as::<_, Trigger>(s)
130        .bind(schema_name)
131        .fetch_all(conn)
132        .await?)
133}
134
135fn parse_generated(
136    time: Option<String>,
137    expr: Option<String>,
138    identity: Option<String>,
139) -> Option<Generated> {
140    let time = time?;
141    let time = match time.as_str() {
142        "ALWAYS" => GenerationTime::Always,
143        "BY DEFAULT" => GenerationTime::ByDefault,
144        _ => return None,
145    };
146    let value = if let Some(identity) = identity
147        && identity == "YES"
148    {
149        GenerationValue::Identity
150    } else if let Some(expr) = expr {
151        GenerationValue::Expr(Expr::Raw(expr))
152    } else {
153        return None;
154    };
155    Some(Generated { time, value })
156}
157
158impl TryInto<Column> for ColumnRow {
159    type Error = Error;
160
161    fn try_into(self) -> std::result::Result<Column, Self::Error> {
162        use schema::Type::*;
163        let nullable = self.is_nullable == "YES";
164        let typ = match self.data_type.as_str() {
165            "ARRAY" => {
166                let inner = schema::Type::from_str(
167                    &self
168                        .inner_type
169                        .expect("Encounterd ARRAY with no inner type."),
170                )?;
171                Array(Box::new(inner))
172            }
173            "numeric" if self.numeric_precision.is_some() && self.numeric_scale.is_some() => {
174                Numeric(
175                    self.numeric_precision.unwrap() as u8,
176                    self.numeric_scale.unwrap() as u8,
177                )
178            }
179            z => schema::Type::from_str(z)?,
180        };
181        let generated = parse_generated(
182            self.generation_time,
183            self.generation_expression,
184            self.identity_generation,
185        );
186        Ok(Column {
187            name: self.column_name.clone(),
188            typ,
189            nullable,
190            primary_key: false,
191            default: None,
192            constraint: None,
193            generated,
194        })
195    }
196}
197
198impl FromPostgres for Schema {
199    async fn try_from_postgres(conn: &mut PgConnection, schema: &str) -> Result<Schema> {
200        let column_schemas = query_schema_columns(conn, schema).await?;
201        let mut tables = column_schemas
202            .into_iter()
203            .chunk_by(|c| c.table_name.clone())
204            .into_iter()
205            .map(|(table_name, group)| {
206                let columns = group
207                    .map(|c: ColumnRow| c.try_into())
208                    .collect::<Result<Vec<_>, Error>>()?;
209                Ok(Table {
210                    schema: Some(schema.to_string()),
211                    name: table_name,
212                    columns,
213                })
214            })
215            .collect::<Result<Vec<_>, Error>>()?;
216        let constraints = query_constraints(conn, schema).await?;
217        let mut it_tables = tables.iter_mut().peekable();
218        for fk in constraints {
219            while &fk.table_name != &it_tables.peek().unwrap().name {
220                it_tables.next();
221            }
222            let table = it_tables.peek_mut().unwrap();
223            let column = table
224                .columns
225                .iter_mut()
226                .find(|c| c.name == fk.column_name)
227                .expect("Constraint for unknown column.");
228            column.constraint = Some(schema::Constraint::ForeignKey(schema::ForeignKey {
229                table: fk.foreign_table_name,
230                columns: vec![fk.foreign_column_name],
231            }));
232        }
233
234        // Degenerate case but you can have tables with no columns...
235        let table_names = query_table_names(conn, schema).await?;
236        let mut tables_it = tables.iter().peekable();
237        let mut empty_tables = Vec::new();
238        'outer: for name in table_names {
239            while let Some(table) = tables_it.peek() {
240                if &name == &table.name {
241                    tables_it.next();
242                    continue 'outer;
243                }
244            }
245            empty_tables.push(Table {
246                schema: Some(schema.to_string()),
247                name,
248                columns: vec![],
249            })
250        }
251        Ok(Schema { tables })
252    }
253}
254
#[cfg(test)]
mod test {
    use super::*;

    /// Baseline non-null, non-generated column row with the given `data_type`.
    /// Individual tests override fields with struct-update syntax.
    fn row(data_type: &str) -> ColumnRow {
        ColumnRow {
            table_name: "foo".to_string(),
            column_name: "bar".to_string(),
            ordinal_position: 1,
            is_nullable: "NO".to_string(),
            data_type: data_type.to_string(),
            numeric_precision: None,
            numeric_scale: None,
            inner_type: None,
            primary_key: false,
            generation_time: None,
            generation_expression: None,
            identity_generation: None,
        }
    }

    #[test]
    fn test_numeric() {
        let c = ColumnRow {
            numeric_precision: Some(10),
            numeric_scale: Some(2),
            ..row("numeric")
        };
        let column: Column = c.try_into().unwrap();
        assert_eq!(column.typ, schema::Type::Numeric(10, 2));
    }

    #[test]
    fn test_integer() {
        let c = ColumnRow {
            numeric_precision: Some(32),
            numeric_scale: Some(0),
            ..row("integer")
        };
        let column: Column = c.try_into().unwrap();
        assert_eq!(column.typ, schema::Type::I32);
    }

    #[test]
    fn test_array() {
        let c = ColumnRow {
            inner_type: Some("integer".to_string()),
            ..row("ARRAY")
        };
        let column: Column = c.try_into().unwrap();
        assert_eq!(column.typ, schema::Type::Array(Box::new(schema::Type::I32)));
    }

    #[test]
    fn test_nullable() {
        let c = ColumnRow {
            is_nullable: "YES".to_string(),
            ..row("integer")
        };
        let column: Column = c.try_into().unwrap();
        assert!(column.nullable);
    }

    #[test]
    fn test_not_generated() {
        let column: Column = row("integer").try_into().unwrap();
        assert!(column.generated.is_none());
    }

    #[test]
    fn test_identity_generated() {
        let c = ColumnRow {
            generation_time: Some("BY DEFAULT".to_string()),
            identity_generation: Some("YES".to_string()),
            ..row("integer")
        };
        let column: Column = c.try_into().unwrap();
        assert!(matches!(
            column.generated,
            Some(Generated {
                time: GenerationTime::ByDefault,
                value: GenerationValue::Identity,
            })
        ));
    }

    #[test]
    fn test_expression_generated() {
        let c = ColumnRow {
            generation_time: Some("ALWAYS".to_string()),
            generation_expression: Some("a + b".to_string()),
            identity_generation: Some("NO".to_string()),
            ..row("integer")
        };
        let column: Column = c.try_into().unwrap();
        assert!(matches!(
            &column.generated,
            Some(Generated {
                time: GenerationTime::Always,
                value: GenerationValue::Expr(Expr::Raw(sql)),
            }) if sql.as_str() == "a + b"
        ));
    }
}