sql_sqlx/
lib.rs

use anyhow::{Error, Result};
use itertools::Itertools;
use sql::{Column, Schema, Table, schema};
use sqlx::PgConnection;
use std::collections::HashMap;
use std::str::FromStr;
6
7#[allow(async_fn_in_trait)]
8pub trait FromPostgres: Sized {
9    async fn try_from_postgres(conn: &mut PgConnection, schema_name: &str) -> Result<Self>;
10}
11
12#[derive(sqlx::FromRow)]
13pub struct SchemaColumn {
14    pub table_name: String,
15    pub column_name: String,
16    #[allow(dead_code)]
17    pub ordinal_position: i32,
18    pub is_nullable: String,
19    pub data_type: String,
20    pub numeric_precision: Option<i32>,
21    pub numeric_scale: Option<i32>,
22    pub inner_type: Option<String>,
23}
24
25pub async fn query_schema_columns(
26    conn: &mut PgConnection,
27    schema_name: &str,
28) -> Result<Vec<SchemaColumn>> {
29    let s = include_str!("sql/query_columns.sql");
30    let result = sqlx::query_as::<_, SchemaColumn>(s)
31        .bind(schema_name)
32        .fetch_all(conn)
33        .await?;
34    Ok(result)
35}
36
37#[derive(sqlx::FromRow)]
38struct TableSchema {
39    #[allow(dead_code)]
40    pub table_schema: String,
41    pub table_name: String,
42}
43
44pub async fn query_table_names(conn: &mut PgConnection, schema_name: &str) -> Result<Vec<String>> {
45    let s = include_str!("sql/query_tables.sql");
46    let result = sqlx::query_as::<_, TableSchema>(s)
47        .bind(schema_name)
48        .fetch_all(conn)
49        .await?;
50    Ok(result.into_iter().map(|t| t.table_name).collect())
51}
52
/// One row returned by `sql/query_constraints.sql`: a single column of a
/// foreign-key constraint together with the column it references.
///
/// Field order is kept as-is because it determines the `Debug` output.
#[derive(Debug, sqlx::FromRow)]
pub struct ForeignKey {
    /// Schema of the referencing (constrained) table.
    pub table_schema: String,
    /// Name of the constraint.
    pub constraint_name: String,
    /// Referencing table.
    pub table_name: String,
    /// Referencing column within `table_name`.
    pub column_name: String,
    /// Schema of the referenced table. NOTE(review): not used when building
    /// `schema::ForeignKey`, so cross-schema references lose their schema.
    pub foreign_table_schema: String,
    /// Referenced table.
    pub foreign_table_name: String,
    /// Referenced column within `foreign_table_name`.
    pub foreign_column_name: String,
}
63
64pub async fn query_constraints(
65    conn: &mut PgConnection,
66    schema_name: &str,
67) -> Result<Vec<ForeignKey>> {
68    let s = include_str!("sql/query_constraints.sql");
69    Ok(sqlx::query_as::<_, ForeignKey>(s)
70        .bind(schema_name)
71        .fetch_all(conn)
72        .await?)
73}
74
/// One row returned by `sql/query_indices.sql`: an index defined on a table.
///
/// NOTE(review): the field names match the columns of Postgres's `pg_indexes`
/// view, which is presumably where the query reads them from. Confirm against
/// the SQL file.
#[derive(Debug, sqlx::FromRow)]
pub struct Index {
    /// Schema containing the indexed table.
    pub schemaname: String,
    /// Table the index is defined on.
    pub tablename: String,
    /// Name of the index.
    pub indexname: String,
    /// Index definition text as returned by the query.
    pub indexdef: String,
}
82
83pub async fn query_indices(conn: &mut PgConnection, schema_name: &str) -> Result<Vec<Index>> {
84    // because of pg_tables join, this only returns indices for tables, not views/mat views
85    let s = include_str!("sql/query_indices.sql");
86    Ok(sqlx::query_as::<_, Index>(s)
87        .bind(schema_name)
88        .fetch_all(conn)
89        .await?)
90}
91
92#[derive(sqlx::FromRow)]
93pub struct Function {
94    pub routine_schema: String,
95    pub routine_name: String,
96    pub routine_type: String,
97    pub data_type: Option<String>,
98    pub routine_definition: Option<String>,
99}
100
101pub async fn query_functions(conn: &mut PgConnection, schema_name: &str) -> Result<Vec<Function>> {
102    let s = include_str!("sql/query_functions.sql");
103    Ok(sqlx::query_as::<_, Function>(s)
104        .bind(schema_name)
105        .fetch_all(conn)
106        .await?)
107}
108
109#[derive(sqlx::FromRow)]
110pub struct Trigger {
111    pub trigger_schema: String,
112    pub trigger_name: String,
113    pub event_manipulation: String,
114    pub event_object_table: String,
115    pub action_timing: String,
116    pub action_statement: String,
117}
118
119pub async fn query_triggers(conn: &mut PgConnection, schema_name: &str) -> Result<Vec<Trigger>> {
120    let s = include_str!("sql/query_triggers.sql");
121    Ok(sqlx::query_as::<_, Trigger>(s)
122        .bind(schema_name)
123        .fetch_all(conn)
124        .await?)
125}
126
127impl TryInto<Column> for SchemaColumn {
128    type Error = Error;
129
130    fn try_into(self) -> std::result::Result<Column, Self::Error> {
131        use schema::Type::*;
132        let nullable = self.is_nullable == "YES";
133        let typ = match self.data_type.as_str() {
134            "ARRAY" => {
135                let inner = schema::Type::from_str(
136                    &self
137                        .inner_type
138                        .expect("Encounterd ARRAY with no inner type."),
139                )?;
140                Array(Box::new(inner))
141            }
142            "numeric" if self.numeric_precision.is_some() && self.numeric_scale.is_some() => {
143                Numeric(
144                    self.numeric_precision.unwrap() as u8,
145                    self.numeric_scale.unwrap() as u8,
146                )
147            }
148            z => schema::Type::from_str(z)?,
149        };
150        Ok(Column {
151            name: self.column_name.clone(),
152            typ,
153            nullable,
154            primary_key: false,
155            default: None,
156            constraint: None,
157        })
158    }
159}
160
161impl FromPostgres for Schema {
162    async fn try_from_postgres(conn: &mut PgConnection, schema_name: &str) -> Result<Schema> {
163        let column_schemas = query_schema_columns(conn, schema_name).await?;
164        let mut tables = column_schemas
165            .into_iter()
166            .chunk_by(|c| c.table_name.clone())
167            .into_iter()
168            .map(|(table_name, group)| {
169                let columns = group
170                    .map(|c: SchemaColumn| c.try_into())
171                    .collect::<Result<Vec<_>, Error>>()?;
172                Ok(Table {
173                    schema: Some(schema_name.to_string()),
174                    name: table_name,
175                    columns,
176                    indexes: vec![],
177                })
178            })
179            .collect::<Result<Vec<_>, Error>>()?;
180        let mut it_tables = tables.iter_mut().peekable();
181        let indices = query_indices(conn, schema_name).await?;
182        for index in indices {
183            while &index.tablename != &it_tables.peek().unwrap().name {
184                it_tables.next();
185            }
186            let t = it_tables.peek_mut().unwrap();
187            t.indexes.push(sql::Index {
188                name: index.indexname,
189                columns: Vec::new(),
190            });
191        }
192
193        let constraints = query_constraints(conn, schema_name).await?;
194        let mut it_tables = tables.iter_mut().peekable();
195        for fk in constraints {
196            while &fk.table_name != &it_tables.peek().unwrap().name {
197                it_tables.next();
198            }
199            let table = it_tables.peek_mut().unwrap();
200            let column = table
201                .columns
202                .iter_mut()
203                .find(|c| c.name == fk.column_name)
204                .expect("Constraint for unknown column.");
205            column.constraint = Some(schema::Constraint::ForeignKey(schema::ForeignKey {
206                table: fk.foreign_table_name,
207                columns: vec![fk.foreign_column_name],
208            }));
209        }
210
211        // Degenerate case but you can have tables with no columns...
212        let table_names = query_table_names(conn, schema_name).await?;
213        let mut tables_it = tables.iter().peekable();
214        let mut empty_tables = Vec::new();
215        'outer: for name in table_names {
216            while let Some(table) = tables_it.peek() {
217                if &name == &table.name {
218                    tables_it.next();
219                    continue 'outer;
220                }
221            }
222            empty_tables.push(Table {
223                schema: Some(schema_name.to_string()),
224                name,
225                columns: vec![],
226                indexes: vec![],
227            })
228        }
229        Ok(Schema { tables })
230    }
231}
232
#[cfg(test)]
mod test {
    use super::*;

    /// Builds a non-nullable `foo.bar` column row carrying the given type
    /// metadata and no array element type.
    fn column_row(data_type: &str, precision: Option<i32>, scale: Option<i32>) -> SchemaColumn {
        SchemaColumn {
            table_name: "foo".to_string(),
            column_name: "bar".to_string(),
            ordinal_position: 1,
            is_nullable: "NO".to_string(),
            data_type: data_type.to_string(),
            numeric_precision: precision,
            numeric_scale: scale,
            inner_type: None,
        }
    }

    /// A `numeric` column with precision and scale keeps both modifiers.
    #[test]
    fn test_numeric() {
        let column: Column = column_row("numeric", Some(10), Some(2)).try_into().unwrap();
        assert_eq!(column.typ, schema::Type::Numeric(10, 2));
    }

    /// An `integer` column ignores its numeric metadata and maps to `I32`.
    #[test]
    fn test_integer() {
        let column: Column = column_row("integer", Some(32), Some(0)).try_into().unwrap();
        assert_eq!(column.typ, schema::Type::I32);
    }
}