// sqlx_gen/introspect/postgres.rs

1use std::collections::HashMap;
2
3use crate::error::Result;
4use sqlx::PgPool;
5
6use super::{ColumnInfo, CompositeTypeInfo, DomainInfo, EnumInfo, SchemaInfo, TableInfo};
7
8pub async fn introspect(
9    pool: &PgPool,
10    schemas: &[String],
11    include_views: bool,
12) -> Result<SchemaInfo> {
13    let tables = fetch_tables(pool, schemas).await?;
14    let mut views = if include_views {
15        fetch_views(pool, schemas).await?
16    } else {
17        Vec::new()
18    };
19
20    if !views.is_empty() {
21        let nullability_info = fetch_view_column_nullability(pool, schemas).await?;
22        resolve_view_nullability(&mut views, &nullability_info);
23    }
24
25    let enums = fetch_enums(pool, schemas).await?;
26    let composite_types = fetch_composite_types(pool, schemas).await?;
27    let domains = fetch_domains(pool, schemas).await?;
28
29    Ok(SchemaInfo {
30        tables,
31        views,
32        enums,
33        composite_types,
34        domains,
35    })
36}
37
38async fn fetch_tables(pool: &PgPool, schemas: &[String]) -> Result<Vec<TableInfo>> {
39    let rows = sqlx::query_as::<_, (String, String, String, String, String, String, i32, bool, Option<String>)>(
40        r#"
41        SELECT
42            c.table_schema,
43            c.table_name,
44            c.column_name,
45            c.data_type,
46            COALESCE(c.udt_name, c.data_type) as udt_name,
47            c.is_nullable,
48            c.ordinal_position,
49            CASE WHEN kcu.column_name IS NOT NULL THEN true ELSE false END AS is_primary_key,
50            c.column_default
51        FROM information_schema.columns c
52        JOIN information_schema.tables t
53            ON t.table_schema = c.table_schema
54            AND t.table_name = c.table_name
55            AND t.table_type = 'BASE TABLE'
56        LEFT JOIN information_schema.table_constraints tc
57            ON tc.table_schema = c.table_schema
58            AND tc.table_name = c.table_name
59            AND tc.constraint_type = 'PRIMARY KEY'
60        LEFT JOIN information_schema.key_column_usage kcu
61            ON kcu.constraint_name = tc.constraint_name
62            AND kcu.constraint_schema = tc.constraint_schema
63            AND kcu.column_name = c.column_name
64        WHERE c.table_schema = ANY($1)
65        ORDER BY c.table_schema, c.table_name, c.ordinal_position
66        "#,
67    )
68    .bind(schemas)
69    .fetch_all(pool)
70    .await?;
71
72    let mut tables: Vec<TableInfo> = Vec::new();
73    let mut current_key: Option<(String, String)> = None;
74
75    for (schema, table, col_name, data_type, udt_name, nullable, ordinal, is_pk, column_default) in rows {
76        let key = (schema.clone(), table.clone());
77        if current_key.as_ref() != Some(&key) {
78            current_key = Some(key);
79            tables.push(TableInfo {
80                schema_name: schema.clone(),
81                name: table.clone(),
82                columns: Vec::new(),
83            });
84        }
85        tables.last_mut().unwrap().columns.push(ColumnInfo {
86            name: col_name,
87            data_type,
88            udt_name,
89            is_nullable: nullable == "YES",
90            is_primary_key: is_pk,
91            ordinal_position: ordinal,
92            schema_name: schema,
93            column_default,
94        });
95    }
96
97    Ok(tables)
98}
99
100async fn fetch_views(pool: &PgPool, schemas: &[String]) -> Result<Vec<TableInfo>> {
101    let rows = sqlx::query_as::<_, (String, String, String, String, String, String, i32, Option<String>)>(
102        r#"
103        SELECT
104            c.table_schema,
105            c.table_name,
106            c.column_name,
107            c.data_type,
108            COALESCE(c.udt_name, c.data_type) as udt_name,
109            c.is_nullable,
110            c.ordinal_position,
111            c.column_default
112        FROM information_schema.columns c
113        JOIN information_schema.tables t
114            ON t.table_schema = c.table_schema
115            AND t.table_name = c.table_name
116            AND t.table_type = 'VIEW'
117        WHERE c.table_schema = ANY($1)
118        ORDER BY c.table_schema, c.table_name, c.ordinal_position
119        "#,
120    )
121    .bind(schemas)
122    .fetch_all(pool)
123    .await?;
124
125    let mut views: Vec<TableInfo> = Vec::new();
126    let mut current_key: Option<(String, String)> = None;
127
128    for (schema, table, col_name, data_type, udt_name, nullable, ordinal, column_default) in rows {
129        let key = (schema.clone(), table.clone());
130        if current_key.as_ref() != Some(&key) {
131            current_key = Some(key);
132            views.push(TableInfo {
133                schema_name: schema.clone(),
134                name: table.clone(),
135                columns: Vec::new(),
136            });
137        }
138        views.last_mut().unwrap().columns.push(ColumnInfo {
139            name: col_name,
140            data_type,
141            udt_name,
142            is_nullable: nullable == "YES",
143            is_primary_key: false,
144            ordinal_position: ordinal,
145            schema_name: schema,
146            column_default,
147        });
148    }
149
150    Ok(views)
151}
152
/// One "view depends on a source column" fact produced by
/// `fetch_view_column_nullability`.
///
/// The same `(view_schema, view_name, source_column_name)` may appear more than
/// once, e.g. when several source relations of the view have a column with that
/// name but different NOT NULL flags. `resolve_view_nullability` only marks a
/// matching view column NOT NULL when every such fact has
/// `source_not_null == true`.
#[derive(Debug, Clone, PartialEq, Eq)]
struct ViewColumnNullability {
    /// Schema that contains the view.
    view_schema: String,
    /// Name of the view.
    view_name: String,
    /// Name of a column, in some source relation, that the view depends on.
    source_column_name: String,
    /// Whether that source column has a NOT NULL constraint (`attnotnull`).
    source_not_null: bool,
}
159
160async fn fetch_view_column_nullability(
161    pool: &PgPool,
162    schemas: &[String],
163) -> Result<Vec<ViewColumnNullability>> {
164    let rows = sqlx::query_as::<_, (String, String, String, bool)>(
165        r#"
166        SELECT DISTINCT
167            v_ns.nspname AS view_schema,
168            v.relname AS view_name,
169            src_attr.attname AS source_column_name,
170            src_attr.attnotnull AS source_not_null
171        FROM pg_class v
172        JOIN pg_namespace v_ns ON v_ns.oid = v.relnamespace
173        JOIN pg_rewrite rw ON rw.ev_class = v.oid
174        JOIN pg_depend d ON d.objid = rw.oid
175            AND d.classid = 'pg_rewrite'::regclass
176            AND d.refobjsubid > 0
177            AND d.deptype = 'n'
178        JOIN pg_attribute src_attr ON src_attr.attrelid = d.refobjid
179            AND src_attr.attnum = d.refobjsubid
180            AND NOT src_attr.attisdropped
181        WHERE v_ns.nspname = ANY($1)
182          AND v.relkind = 'v'
183        "#,
184    )
185    .bind(schemas)
186    .fetch_all(pool)
187    .await?;
188
189    Ok(rows
190        .into_iter()
191        .map(
192            |(view_schema, view_name, source_column_name, source_not_null)| {
193                ViewColumnNullability {
194                    view_schema,
195                    view_name,
196                    source_column_name,
197                    source_not_null,
198                }
199            },
200        )
201        .collect())
202}
203
204fn resolve_view_nullability(
205    views: &mut [TableInfo],
206    nullability_info: &[ViewColumnNullability],
207) {
208    // Build lookup: (view_schema, view_name, column_name) -> Vec<is_not_null>
209    let mut lookup: HashMap<(&str, &str, &str), Vec<bool>> = HashMap::new();
210    for info in nullability_info {
211        lookup
212            .entry((&info.view_schema, &info.view_name, &info.source_column_name))
213            .or_default()
214            .push(info.source_not_null);
215    }
216
217    for view in views.iter_mut() {
218        for col in view.columns.iter_mut() {
219            if let Some(not_null_flags) = lookup.get(&(
220                view.schema_name.as_str(),
221                view.name.as_str(),
222                col.name.as_str(),
223            )) {
224                // Only mark as non-nullable if ALL source columns are NOT NULL
225                if !not_null_flags.is_empty() && not_null_flags.iter().all(|&nn| nn) {
226                    col.is_nullable = false;
227                }
228            }
229        }
230    }
231}
232
233async fn fetch_enums(pool: &PgPool, schemas: &[String]) -> Result<Vec<EnumInfo>> {
234    let rows = sqlx::query_as::<_, (String, String, String)>(
235        r#"
236        SELECT
237            n.nspname AS schema_name,
238            t.typname AS enum_name,
239            e.enumlabel AS variant
240        FROM pg_catalog.pg_type t
241        JOIN pg_catalog.pg_enum e ON e.enumtypid = t.oid
242        JOIN pg_catalog.pg_namespace n ON n.oid = t.typnamespace
243        WHERE n.nspname = ANY($1)
244        ORDER BY n.nspname, t.typname, e.enumsortorder
245        "#,
246    )
247    .bind(schemas)
248    .fetch_all(pool)
249    .await?;
250
251    let mut enums: Vec<EnumInfo> = Vec::new();
252    let mut current_key: Option<(String, String)> = None;
253
254    for (schema, name, variant) in rows {
255        let key = (schema.clone(), name.clone());
256        if current_key.as_ref() != Some(&key) {
257            current_key = Some(key);
258            enums.push(EnumInfo {
259                schema_name: schema,
260                name,
261                variants: Vec::new(),
262                default_variant: None,
263            });
264        }
265        enums.last_mut().unwrap().variants.push(variant);
266    }
267
268    Ok(enums)
269}
270
271async fn fetch_composite_types(
272    pool: &PgPool,
273    schemas: &[String],
274) -> Result<Vec<CompositeTypeInfo>> {
275    let rows = sqlx::query_as::<_, (String, String, String, String, String, i32)>(
276        r#"
277        SELECT
278            n.nspname AS schema_name,
279            t.typname AS type_name,
280            a.attname AS field_name,
281            COALESCE(ft.typname, '') AS field_type,
282            CASE WHEN a.attnotnull THEN 'NO' ELSE 'YES' END AS is_nullable,
283            a.attnum AS ordinal
284        FROM pg_catalog.pg_type t
285        JOIN pg_catalog.pg_namespace n ON n.oid = t.typnamespace
286        JOIN pg_catalog.pg_class c ON c.oid = t.typrelid
287        JOIN pg_catalog.pg_attribute a ON a.attrelid = c.oid AND a.attnum > 0 AND NOT a.attisdropped
288        JOIN pg_catalog.pg_type ft ON ft.oid = a.atttypid
289        WHERE t.typtype = 'c'
290            AND n.nspname = ANY($1)
291            AND NOT EXISTS (
292                SELECT 1 FROM information_schema.tables it
293                WHERE it.table_schema = n.nspname AND it.table_name = t.typname
294            )
295        ORDER BY n.nspname, t.typname, a.attnum
296        "#,
297    )
298    .bind(schemas)
299    .fetch_all(pool)
300    .await?;
301
302    let mut composites: Vec<CompositeTypeInfo> = Vec::new();
303    let mut current_key: Option<(String, String)> = None;
304
305    for (schema, type_name, field_name, field_type, nullable, ordinal) in rows {
306        let key = (schema.clone(), type_name.clone());
307        if current_key.as_ref() != Some(&key) {
308            current_key = Some(key);
309            composites.push(CompositeTypeInfo {
310                schema_name: schema.clone(),
311                name: type_name,
312                fields: Vec::new(),
313            });
314        }
315        composites.last_mut().unwrap().fields.push(ColumnInfo {
316            name: field_name,
317            data_type: field_type.clone(),
318            udt_name: field_type,
319            is_nullable: nullable == "YES",
320            is_primary_key: false,
321            ordinal_position: ordinal,
322            schema_name: schema,
323            column_default: None,
324        });
325    }
326
327    Ok(composites)
328}
329
330async fn fetch_domains(pool: &PgPool, schemas: &[String]) -> Result<Vec<DomainInfo>> {
331    let rows = sqlx::query_as::<_, (String, String, String)>(
332        r#"
333        SELECT
334            n.nspname AS schema_name,
335            t.typname AS domain_name,
336            bt.typname AS base_type
337        FROM pg_catalog.pg_type t
338        JOIN pg_catalog.pg_namespace n ON n.oid = t.typnamespace
339        JOIN pg_catalog.pg_type bt ON bt.oid = t.typbasetype
340        WHERE t.typtype = 'd'
341            AND n.nspname = ANY($1)
342        ORDER BY n.nspname, t.typname
343        "#,
344    )
345    .bind(schemas)
346    .fetch_all(pool)
347    .await?;
348
349    Ok(rows
350        .into_iter()
351        .map(|(schema, name, base_type)| DomainInfo {
352            schema_name: schema,
353            name,
354            base_type,
355        })
356        .collect())
357}
358
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a view whose columns are all `text`, nullable and not part of a key.
    fn make_view(schema: &str, name: &str, columns: Vec<&str>) -> TableInfo {
        TableInfo {
            schema_name: schema.to_string(),
            name: name.to_string(),
            columns: columns
                .into_iter()
                .enumerate()
                .map(|(i, col)| ColumnInfo {
                    name: col.to_string(),
                    data_type: "text".to_string(),
                    udt_name: "text".to_string(),
                    is_nullable: true,
                    is_primary_key: false,
                    ordinal_position: i as i32,
                    schema_name: schema.to_string(),
                    column_default: None,
                })
                .collect(),
        }
    }

    /// Builds one source-column nullability fact for `resolve_view_nullability`.
    fn make_nullability(
        view_schema: &str,
        view_name: &str,
        source_column: &str,
        not_null: bool,
    ) -> ViewColumnNullability {
        ViewColumnNullability {
            view_schema: view_schema.to_string(),
            view_name: view_name.to_string(),
            source_column_name: source_column.to_string(),
            source_not_null: not_null,
        }
    }

    #[test]
    fn test_resolve_not_null_column() {
        let mut views = vec![make_view("public", "my_view", vec!["id", "name"])];
        let info = vec![
            make_nullability("public", "my_view", "id", true),
            make_nullability("public", "my_view", "name", true),
        ];
        resolve_view_nullability(&mut views, &info);
        assert!(!views[0].columns[0].is_nullable);
        assert!(!views[0].columns[1].is_nullable);
    }

    #[test]
    fn test_resolve_mixed_sources() {
        let mut views = vec![make_view("public", "my_view", vec!["id"])];
        let info = vec![
            make_nullability("public", "my_view", "id", true),
            make_nullability("public", "my_view", "id", false),
        ];
        resolve_view_nullability(&mut views, &info);
        assert!(views[0].columns[0].is_nullable);
    }

    #[test]
    fn test_resolve_no_match_stays_nullable() {
        let mut views = vec![make_view("public", "my_view", vec!["computed_col"])];
        let info = vec![make_nullability("public", "my_view", "id", true)];
        resolve_view_nullability(&mut views, &info);
        assert!(views[0].columns[0].is_nullable);
    }

    #[test]
    fn test_resolve_empty_info() {
        let mut views = vec![make_view("public", "my_view", vec!["id"])];
        resolve_view_nullability(&mut views, &[]);
        assert!(views[0].columns[0].is_nullable);
    }

    #[test]
    fn test_resolve_cross_schema() {
        let mut views = vec![
            make_view("public", "v1", vec!["id"]),
            make_view("auth", "v2", vec!["id"]),
        ];
        let info = vec![
            make_nullability("public", "v1", "id", true),
            make_nullability("auth", "v2", "id", false),
        ];
        resolve_view_nullability(&mut views, &info);
        assert!(!views[0].columns[0].is_nullable);
        assert!(views[1].columns[0].is_nullable);
    }

    #[test]
    fn test_resolve_same_schema_different_views() {
        // A fact for one view must not leak into another view of the same schema
        // that happens to have an identically named column.
        let mut views = vec![
            make_view("public", "v1", vec!["id"]),
            make_view("public", "v2", vec!["id"]),
        ];
        let info = vec![make_nullability("public", "v1", "id", true)];
        resolve_view_nullability(&mut views, &info);
        assert!(!views[0].columns[0].is_nullable);
        assert!(views[1].columns[0].is_nullable);
    }

    #[test]
    fn test_resolve_never_marks_column_nullable() {
        // The resolver only tightens nullability; a nullable source must not
        // flip an already non-nullable column back to nullable.
        let mut views = vec![make_view("public", "my_view", vec!["id"])];
        views[0].columns[0].is_nullable = false;
        let info = vec![make_nullability("public", "my_view", "id", false)];
        resolve_view_nullability(&mut views, &info);
        assert!(!views[0].columns[0].is_nullable);
    }
}