sql_sqlx/
lib.rs

1use anyhow::{Error, Result};
2use itertools::Itertools;
3use sql::{Column, Schema, Table, schema};
4use sqlx::PgConnection;
5use std::str::FromStr;
6
/// Constructs a value of the implementing type from a named schema, using a
/// live PostgreSQL connection.
// `async fn` in a public trait triggers the `async_fn_in_trait` lint because
// callers cannot add auto-trait bounds (such as `Send`) to the returned
// future. The lint is silenced here instead of desugaring the method.
#[allow(async_fn_in_trait)]
pub trait FromPostgres: Sized {
    /// Builds `Self` for the schema called `schema_name`, issuing whatever
    /// queries the implementation needs on `conn`.
    ///
    /// # Errors
    ///
    /// Returns an error when the implementation cannot build the value, for
    /// example because a query on `conn` fails.
    async fn try_from_postgres(conn: &mut PgConnection, schema_name: &str) -> Result<Self>;
}
11
/// One row produced by `sql/query_columns.sql`, describing a single column.
///
/// `sqlx::FromRow` maps result columns to fields by name, so the field names
/// must match the column names returned by that query.
// NOTE(review): the field names look like `information_schema.columns` —
// confirm against the SQL file.
#[derive(sqlx::FromRow)]
struct SchemaColumn {
    /// Name of the table the column belongs to; used to group rows into tables.
    pub table_name: String,
    /// Name of the column itself.
    pub column_name: String,
    // Fetched by the query but not read by the conversion code in this file.
    #[allow(dead_code)]
    pub ordinal_position: i32,
    /// The string `"YES"` marks the column as nullable; any other value is
    /// treated as not nullable.
    pub is_nullable: String,
    /// Type name of the column. `"ARRAY"` and `"numeric"` are handled
    /// specially during conversion; everything else is parsed as a type name.
    pub data_type: String,
    /// Precision, used only for `"numeric"` columns when the scale is also present.
    pub numeric_precision: Option<i32>,
    /// Scale, used only for `"numeric"` columns when the precision is also present.
    pub numeric_scale: Option<i32>,
    /// Element type name; required when `data_type` is `"ARRAY"`.
    pub inner_type: Option<String>,
}
24
25async fn query_schema_columns(
26    conn: &mut PgConnection,
27    schema_name: &str,
28) -> Result<Vec<SchemaColumn>> {
29    let s = include_str!("sql/query_columns.sql");
30    let result = sqlx::query_as::<_, SchemaColumn>(s)
31        .bind(schema_name)
32        .fetch_all(conn)
33        .await?;
34    Ok(result)
35}
36
/// One row produced by `sql/query_tables.sql`, identifying a table.
///
/// `sqlx::FromRow` maps result columns to fields by name, so the field names
/// must match the column names returned by that query.
#[derive(sqlx::FromRow)]
struct TableSchema {
    // Fetched by the query but never read: callers in this file only use
    // `table_name`.
    #[allow(dead_code)]
    pub table_schema: String,
    /// Name of the table.
    pub table_name: String,
}
43
44async fn query_table_names(conn: &mut PgConnection, schema_name: &str) -> Result<Vec<String>> {
45    let s = include_str!("sql/query_tables.sql");
46    let result = sqlx::query_as::<_, TableSchema>(s)
47        .bind(schema_name)
48        .fetch_all(conn)
49        .await?;
50    Ok(result.into_iter().map(|t| t.table_name).collect())
51}
52
/// One row produced by `sql/query_constraints.sql`: a single column of a
/// foreign-key constraint together with the column it references.
///
/// `sqlx::FromRow` maps result columns to fields by name, so the field names
/// must match the column names returned by that query.
#[derive(sqlx::FromRow)]
// Every field is populated from the query result, but not every field is read
// by the code in this file, hence the blanket allow.
#[allow(dead_code)]
struct ForeignKey {
    /// Schema of the table that owns the constraint.
    pub table_schema: String,
    /// Name of the constraint.
    pub constraint_name: String,
    /// Table that owns the constraint.
    pub table_name: String,
    /// Column of `table_name` that the constraint applies to.
    pub column_name: String,
    /// Schema of the referenced table.
    pub foreign_table_schema: String,
    /// Referenced table.
    pub foreign_table_name: String,
    /// Referenced column.
    pub foreign_column_name: String,
}
64
65async fn query_constraints(conn: &mut PgConnection, schema_name: &str) -> Result<Vec<ForeignKey>> {
66    let s = include_str!("sql/query_constraints.sql");
67    Ok(sqlx::query_as::<_, ForeignKey>(s)
68        .bind(schema_name)
69        .fetch_all(conn)
70        .await?)
71}
72
73#[derive(sqlx::FromRow)]
74pub struct Index {
75    pub schemaname: String,
76    pub tablename: String,
77    pub indexname: String,
78    pub indexdef: String,
79}
80
81pub async fn query_indices(conn: &mut PgConnection, schema_name: &str) -> Result<Vec<Index>> {
82    let s = include_str!("sql/query_indices.sql");
83    Ok(sqlx::query_as::<_, Index>(s)
84        .bind(schema_name)
85        .fetch_all(conn)
86        .await?)
87}
88
89#[derive(sqlx::FromRow)]
90pub struct Function {
91    pub routine_schema: String,
92    pub routine_name: String,
93    pub routine_type: String,
94    pub data_type: Option<String>,
95    pub routine_definition: Option<String>,
96}
97
98pub async fn query_functions(conn: &mut PgConnection, schema_name: &str) -> Result<Vec<Function>> {
99    let s = include_str!("sql/query_functions.sql");
100    Ok(sqlx::query_as::<_, Function>(s)
101        .bind(schema_name)
102        .fetch_all(conn)
103        .await?)
104}
105
106#[derive(sqlx::FromRow)]
107pub struct Trigger {
108    pub trigger_schema: String,
109    pub trigger_name: String,
110    pub event_manipulation: String,
111    pub event_object_table: String,
112    pub action_timing: String,
113    pub action_statement: String,
114}
115
116pub async fn query_triggers(conn: &mut PgConnection, schema_name: &str) -> Result<Vec<Trigger>> {
117    let s = include_str!("sql/query_triggers.sql");
118    Ok(sqlx::query_as::<_, Trigger>(s)
119        .bind(schema_name)
120        .fetch_all(conn)
121        .await?)
122}
123
124impl TryInto<Column> for SchemaColumn {
125    type Error = Error;
126
127    fn try_into(self) -> std::result::Result<Column, Self::Error> {
128        use schema::Type::*;
129        let nullable = self.is_nullable == "YES";
130        let typ = match self.data_type.as_str() {
131            "ARRAY" => {
132                let inner = schema::Type::from_str(
133                    &self
134                        .inner_type
135                        .expect("Encounterd ARRAY with no inner type."),
136                )?;
137                Array(Box::new(inner))
138            }
139            "numeric" if self.numeric_precision.is_some() && self.numeric_scale.is_some() => {
140                Numeric(
141                    self.numeric_precision.unwrap() as u8,
142                    self.numeric_scale.unwrap() as u8,
143                )
144            }
145            z => schema::Type::from_str(z)?,
146        };
147        Ok(Column {
148            name: self.column_name.clone(),
149            typ,
150            nullable,
151            primary_key: false,
152            default: None,
153            constraint: None,
154        })
155    }
156}
157
158impl FromPostgres for Schema {
159    async fn try_from_postgres(conn: &mut PgConnection, schema_name: &str) -> Result<Schema> {
160        let column_schemas = query_schema_columns(conn, schema_name).await?;
161        let mut tables = column_schemas
162            .into_iter()
163            .chunk_by(|c| c.table_name.clone())
164            .into_iter()
165            .map(|(table_name, group)| {
166                let columns = group
167                    .map(|c: SchemaColumn| c.try_into())
168                    .collect::<Result<Vec<_>, Error>>()?;
169                Ok(Table {
170                    schema: Some(schema_name.to_string()),
171                    name: table_name,
172                    columns,
173                    indexes: vec![],
174                })
175            })
176            .collect::<Result<Vec<_>, Error>>()?;
177
178        let constraints = query_constraints(conn, schema_name).await?;
179        for fk in constraints {
180            let table = tables
181                .iter_mut()
182                .find(|t| t.name == fk.table_name)
183                .expect("Constraint for unknown table.");
184            let column = table
185                .columns
186                .iter_mut()
187                .find(|c| c.name == fk.column_name)
188                .expect("Constraint for unknown column.");
189            column.constraint = Some(schema::Constraint::ForeignKey(schema::ForeignKey {
190                table: fk.foreign_table_name,
191                columns: vec![fk.foreign_column_name],
192            }));
193        }
194
195        // Degenerate case but you can have tables with no columns...
196        let table_names = query_table_names(conn, schema_name).await?;
197        for name in table_names {
198            if tables.iter().any(|t| t.name == name) {
199                continue;
200            }
201            tables.push(Table {
202                schema: Some(schema_name.to_string()),
203                name,
204                columns: vec![],
205                indexes: vec![],
206            })
207        }
208        Ok(Schema { tables })
209    }
210}
211
#[cfg(test)]
mod test {
    use super::*;

    /// Builds a non-nullable `foo.bar` column row with the given type name,
    /// precision and scale, and no array element type.
    fn row(data_type: &str, precision: i32, scale: i32) -> SchemaColumn {
        SchemaColumn {
            table_name: "foo".to_string(),
            column_name: "bar".to_string(),
            ordinal_position: 1,
            is_nullable: "NO".to_string(),
            data_type: data_type.to_string(),
            numeric_precision: Some(precision),
            numeric_scale: Some(scale),
            inner_type: None,
        }
    }

    /// A `numeric` row with precision and scale keeps both in the type.
    #[test]
    fn test_numeric() {
        let converted: Column = row("numeric", 10, 2).try_into().unwrap();
        assert_eq!(converted.typ, schema::Type::Numeric(10, 2));
    }

    /// An `integer` row maps to `I32` even though precision and scale are set.
    #[test]
    fn test_integer() {
        let converted: Column = row("integer", 32, 0).try_into().unwrap();
        assert_eq!(converted.typ, schema::Type::I32);
    }
}