Skip to main content

sqlx_gen/introspect/
postgres.rs

1use std::collections::HashMap;
2
3use crate::error::Result;
4use sqlx::PgPool;
5
6use super::{ColumnInfo, CompositeTypeInfo, DomainInfo, EnumInfo, SchemaInfo, TableInfo};
7
8pub async fn introspect(
9    pool: &PgPool,
10    schemas: &[String],
11    include_views: bool,
12) -> Result<SchemaInfo> {
13    let tables = fetch_tables(pool, schemas).await?;
14    let mut views = if include_views {
15        fetch_views(pool, schemas).await?
16    } else {
17        Vec::new()
18    };
19
20    if !views.is_empty() {
21        let nullability_info = fetch_view_column_nullability(pool, schemas).await?;
22        resolve_view_nullability(&mut views, &nullability_info);
23
24        let pk_info = fetch_view_column_primary_keys(pool, schemas).await?;
25        resolve_view_primary_keys(&mut views, &pk_info);
26    }
27
28    let enums = fetch_enums(pool, schemas).await?;
29    let composite_types = fetch_composite_types(pool, schemas).await?;
30    let domains = fetch_domains(pool, schemas).await?;
31
32    Ok(SchemaInfo {
33        tables,
34        views,
35        enums,
36        composite_types,
37        domains,
38    })
39}
40
41async fn fetch_tables(pool: &PgPool, schemas: &[String]) -> Result<Vec<TableInfo>> {
42    let rows = sqlx::query_as::<_, (String, String, String, String, String, String, i32, bool, Option<String>)>(
43        r#"
44        SELECT
45            c.table_schema,
46            c.table_name,
47            c.column_name,
48            c.data_type,
49            COALESCE(c.udt_name, c.data_type) as udt_name,
50            c.is_nullable,
51            c.ordinal_position,
52            CASE WHEN kcu.column_name IS NOT NULL THEN true ELSE false END AS is_primary_key,
53            c.column_default
54        FROM information_schema.columns c
55        JOIN information_schema.tables t
56            ON t.table_schema = c.table_schema
57            AND t.table_name = c.table_name
58            AND t.table_type = 'BASE TABLE'
59        LEFT JOIN information_schema.table_constraints tc
60            ON tc.table_schema = c.table_schema
61            AND tc.table_name = c.table_name
62            AND tc.constraint_type = 'PRIMARY KEY'
63        LEFT JOIN information_schema.key_column_usage kcu
64            ON kcu.constraint_name = tc.constraint_name
65            AND kcu.constraint_schema = tc.constraint_schema
66            AND kcu.column_name = c.column_name
67        WHERE c.table_schema = ANY($1)
68        ORDER BY c.table_schema, c.table_name, c.ordinal_position
69        "#,
70    )
71    .bind(schemas)
72    .fetch_all(pool)
73    .await?;
74
75    let mut tables: Vec<TableInfo> = Vec::new();
76    let mut current_key: Option<(String, String)> = None;
77
78    for (schema, table, col_name, data_type, udt_name, nullable, ordinal, is_pk, column_default) in rows {
79        let key = (schema.clone(), table.clone());
80        if current_key.as_ref() != Some(&key) {
81            current_key = Some(key);
82            tables.push(TableInfo {
83                schema_name: schema.clone(),
84                name: table.clone(),
85                columns: Vec::new(),
86            });
87        }
88        tables.last_mut().unwrap().columns.push(ColumnInfo {
89            name: col_name,
90            data_type,
91            udt_name,
92            is_nullable: nullable == "YES",
93            is_primary_key: is_pk,
94            ordinal_position: ordinal,
95            schema_name: schema,
96            column_default,
97        });
98    }
99
100    Ok(tables)
101}
102
103async fn fetch_views(pool: &PgPool, schemas: &[String]) -> Result<Vec<TableInfo>> {
104    let rows = sqlx::query_as::<_, (String, String, String, String, String, String, i32, Option<String>)>(
105        r#"
106        SELECT
107            c.table_schema,
108            c.table_name,
109            c.column_name,
110            c.data_type,
111            COALESCE(c.udt_name, c.data_type) as udt_name,
112            c.is_nullable,
113            c.ordinal_position,
114            c.column_default
115        FROM information_schema.columns c
116        JOIN information_schema.tables t
117            ON t.table_schema = c.table_schema
118            AND t.table_name = c.table_name
119            AND t.table_type = 'VIEW'
120        WHERE c.table_schema = ANY($1)
121        ORDER BY c.table_schema, c.table_name, c.ordinal_position
122        "#,
123    )
124    .bind(schemas)
125    .fetch_all(pool)
126    .await?;
127
128    let mut views: Vec<TableInfo> = Vec::new();
129    let mut current_key: Option<(String, String)> = None;
130
131    for (schema, table, col_name, data_type, udt_name, nullable, ordinal, column_default) in rows {
132        let key = (schema.clone(), table.clone());
133        if current_key.as_ref() != Some(&key) {
134            current_key = Some(key);
135            views.push(TableInfo {
136                schema_name: schema.clone(),
137                name: table.clone(),
138                columns: Vec::new(),
139            });
140        }
141        views.last_mut().unwrap().columns.push(ColumnInfo {
142            name: col_name,
143            data_type,
144            udt_name,
145            is_nullable: nullable == "YES",
146            is_primary_key: false,
147            ordinal_position: ordinal,
148            schema_name: schema,
149            column_default,
150        });
151    }
152
153    Ok(views)
154}
155
/// One dependency row linking a view to a source column it references, together
/// with that source column's `NOT NULL` flag (see `fetch_view_column_nullability`).
#[derive(Debug, Clone, PartialEq, Eq)]
struct ViewColumnNullability {
    /// Schema of the dependent view.
    view_schema: String,
    /// Name of the dependent view.
    view_name: String,
    /// Name of the referenced source column; `resolve_view_nullability` matches it
    /// against view column names by name only.
    source_column_name: String,
    /// `true` when the source column is declared `NOT NULL` (`pg_attribute.attnotnull`).
    source_not_null: bool,
}
162
163async fn fetch_view_column_nullability(
164    pool: &PgPool,
165    schemas: &[String],
166) -> Result<Vec<ViewColumnNullability>> {
167    let rows = sqlx::query_as::<_, (String, String, String, bool)>(
168        r#"
169        SELECT DISTINCT
170            v_ns.nspname AS view_schema,
171            v.relname AS view_name,
172            src_attr.attname AS source_column_name,
173            src_attr.attnotnull AS source_not_null
174        FROM pg_class v
175        JOIN pg_namespace v_ns ON v_ns.oid = v.relnamespace
176        JOIN pg_rewrite rw ON rw.ev_class = v.oid
177        JOIN pg_depend d ON d.objid = rw.oid
178            AND d.classid = 'pg_rewrite'::regclass
179            AND d.refobjsubid > 0
180            AND d.deptype = 'n'
181        JOIN pg_attribute src_attr ON src_attr.attrelid = d.refobjid
182            AND src_attr.attnum = d.refobjsubid
183            AND NOT src_attr.attisdropped
184        WHERE v_ns.nspname = ANY($1)
185          AND v.relkind = 'v'
186        "#,
187    )
188    .bind(schemas)
189    .fetch_all(pool)
190    .await?;
191
192    Ok(rows
193        .into_iter()
194        .map(
195            |(view_schema, view_name, source_column_name, source_not_null)| {
196                ViewColumnNullability {
197                    view_schema,
198                    view_name,
199                    source_column_name,
200                    source_not_null,
201                }
202            },
203        )
204        .collect())
205}
206
207fn resolve_view_nullability(
208    views: &mut [TableInfo],
209    nullability_info: &[ViewColumnNullability],
210) {
211    // Build lookup: (view_schema, view_name, column_name) -> Vec<is_not_null>
212    let mut lookup: HashMap<(&str, &str, &str), Vec<bool>> = HashMap::new();
213    for info in nullability_info {
214        lookup
215            .entry((&info.view_schema, &info.view_name, &info.source_column_name))
216            .or_default()
217            .push(info.source_not_null);
218    }
219
220    for view in views.iter_mut() {
221        for col in view.columns.iter_mut() {
222            if let Some(not_null_flags) = lookup.get(&(
223                view.schema_name.as_str(),
224                view.name.as_str(),
225                col.name.as_str(),
226            )) {
227                // Only mark as non-nullable if ALL source columns are NOT NULL
228                if !not_null_flags.is_empty() && not_null_flags.iter().all(|&nn| nn) {
229                    col.is_nullable = false;
230                }
231            }
232        }
233    }
234}
235
/// One dependency row linking a view to a source column it references, together
/// with whether that column belongs to its table's primary key
/// (see `fetch_view_column_primary_keys`).
#[derive(Debug, Clone, PartialEq, Eq)]
struct ViewColumnPrimaryKey {
    /// Schema of the dependent view.
    view_schema: String,
    /// Name of the dependent view.
    view_name: String,
    /// Name of the referenced source column; `resolve_view_primary_keys` matches it
    /// against view column names by name only.
    source_column_name: String,
    /// `true` when the source column is part of a `PRIMARY KEY` constraint.
    source_is_pk: bool,
}
242
243async fn fetch_view_column_primary_keys(
244    pool: &PgPool,
245    schemas: &[String],
246) -> Result<Vec<ViewColumnPrimaryKey>> {
247    let rows = sqlx::query_as::<_, (String, String, String, bool)>(
248        r#"
249        SELECT DISTINCT
250            v_ns.nspname AS view_schema,
251            v.relname AS view_name,
252            src_attr.attname AS source_column_name,
253            COALESCE(
254                EXISTS (
255                    SELECT 1
256                    FROM pg_constraint con
257                    WHERE con.conrelid = src_attr.attrelid
258                      AND con.contype = 'p'
259                      AND src_attr.attnum = ANY(con.conkey)
260                ),
261                false
262            ) AS source_is_pk
263        FROM pg_class v
264        JOIN pg_namespace v_ns ON v_ns.oid = v.relnamespace
265        JOIN pg_rewrite rw ON rw.ev_class = v.oid
266        JOIN pg_depend d ON d.objid = rw.oid
267            AND d.classid = 'pg_rewrite'::regclass
268            AND d.refobjsubid > 0
269            AND d.deptype = 'n'
270        JOIN pg_attribute src_attr ON src_attr.attrelid = d.refobjid
271            AND src_attr.attnum = d.refobjsubid
272            AND NOT src_attr.attisdropped
273        WHERE v_ns.nspname = ANY($1)
274          AND v.relkind = 'v'
275        "#,
276    )
277    .bind(schemas)
278    .fetch_all(pool)
279    .await?;
280
281    Ok(rows
282        .into_iter()
283        .map(
284            |(view_schema, view_name, source_column_name, source_is_pk)| ViewColumnPrimaryKey {
285                view_schema,
286                view_name,
287                source_column_name,
288                source_is_pk,
289            },
290        )
291        .collect())
292}
293
294fn resolve_view_primary_keys(
295    views: &mut [TableInfo],
296    pk_info: &[ViewColumnPrimaryKey],
297) {
298    // Build lookup: (view_schema, view_name, column_name) -> Vec<is_pk>
299    let mut lookup: HashMap<(&str, &str, &str), Vec<bool>> = HashMap::new();
300    for info in pk_info {
301        lookup
302            .entry((&info.view_schema, &info.view_name, &info.source_column_name))
303            .or_default()
304            .push(info.source_is_pk);
305    }
306
307    for view in views.iter_mut() {
308        for col in view.columns.iter_mut() {
309            if let Some(pk_flags) = lookup.get(&(
310                view.schema_name.as_str(),
311                view.name.as_str(),
312                col.name.as_str(),
313            )) {
314                // Only mark as PK if ALL source columns are PKs
315                if !pk_flags.is_empty() && pk_flags.iter().all(|&pk| pk) {
316                    col.is_primary_key = true;
317                }
318            }
319        }
320    }
321}
322
323async fn fetch_enums(pool: &PgPool, schemas: &[String]) -> Result<Vec<EnumInfo>> {
324    let rows = sqlx::query_as::<_, (String, String, String)>(
325        r#"
326        SELECT
327            n.nspname AS schema_name,
328            t.typname AS enum_name,
329            e.enumlabel AS variant
330        FROM pg_catalog.pg_type t
331        JOIN pg_catalog.pg_enum e ON e.enumtypid = t.oid
332        JOIN pg_catalog.pg_namespace n ON n.oid = t.typnamespace
333        WHERE n.nspname = ANY($1)
334        ORDER BY n.nspname, t.typname, e.enumsortorder
335        "#,
336    )
337    .bind(schemas)
338    .fetch_all(pool)
339    .await?;
340
341    let mut enums: Vec<EnumInfo> = Vec::new();
342    let mut current_key: Option<(String, String)> = None;
343
344    for (schema, name, variant) in rows {
345        let key = (schema.clone(), name.clone());
346        if current_key.as_ref() != Some(&key) {
347            current_key = Some(key);
348            enums.push(EnumInfo {
349                schema_name: schema,
350                name,
351                variants: Vec::new(),
352                default_variant: None,
353            });
354        }
355        enums.last_mut().unwrap().variants.push(variant);
356    }
357
358    Ok(enums)
359}
360
361async fn fetch_composite_types(
362    pool: &PgPool,
363    schemas: &[String],
364) -> Result<Vec<CompositeTypeInfo>> {
365    let rows = sqlx::query_as::<_, (String, String, String, String, String, i32)>(
366        r#"
367        SELECT
368            n.nspname AS schema_name,
369            t.typname AS type_name,
370            a.attname AS field_name,
371            COALESCE(ft.typname, '') AS field_type,
372            CASE WHEN a.attnotnull THEN 'NO' ELSE 'YES' END AS is_nullable,
373            a.attnum AS ordinal
374        FROM pg_catalog.pg_type t
375        JOIN pg_catalog.pg_namespace n ON n.oid = t.typnamespace
376        JOIN pg_catalog.pg_class c ON c.oid = t.typrelid
377        JOIN pg_catalog.pg_attribute a ON a.attrelid = c.oid AND a.attnum > 0 AND NOT a.attisdropped
378        JOIN pg_catalog.pg_type ft ON ft.oid = a.atttypid
379        WHERE t.typtype = 'c'
380            AND n.nspname = ANY($1)
381            AND NOT EXISTS (
382                SELECT 1 FROM information_schema.tables it
383                WHERE it.table_schema = n.nspname AND it.table_name = t.typname
384            )
385        ORDER BY n.nspname, t.typname, a.attnum
386        "#,
387    )
388    .bind(schemas)
389    .fetch_all(pool)
390    .await?;
391
392    let mut composites: Vec<CompositeTypeInfo> = Vec::new();
393    let mut current_key: Option<(String, String)> = None;
394
395    for (schema, type_name, field_name, field_type, nullable, ordinal) in rows {
396        let key = (schema.clone(), type_name.clone());
397        if current_key.as_ref() != Some(&key) {
398            current_key = Some(key);
399            composites.push(CompositeTypeInfo {
400                schema_name: schema.clone(),
401                name: type_name,
402                fields: Vec::new(),
403            });
404        }
405        composites.last_mut().unwrap().fields.push(ColumnInfo {
406            name: field_name,
407            data_type: field_type.clone(),
408            udt_name: field_type,
409            is_nullable: nullable == "YES",
410            is_primary_key: false,
411            ordinal_position: ordinal,
412            schema_name: schema,
413            column_default: None,
414        });
415    }
416
417    Ok(composites)
418}
419
420async fn fetch_domains(pool: &PgPool, schemas: &[String]) -> Result<Vec<DomainInfo>> {
421    let rows = sqlx::query_as::<_, (String, String, String)>(
422        r#"
423        SELECT
424            n.nspname AS schema_name,
425            t.typname AS domain_name,
426            bt.typname AS base_type
427        FROM pg_catalog.pg_type t
428        JOIN pg_catalog.pg_namespace n ON n.oid = t.typnamespace
429        JOIN pg_catalog.pg_type bt ON bt.oid = t.typbasetype
430        WHERE t.typtype = 'd'
431            AND n.nspname = ANY($1)
432        ORDER BY n.nspname, t.typname
433        "#,
434    )
435    .bind(schemas)
436    .fetch_all(pool)
437    .await?;
438
439    Ok(rows
440        .into_iter()
441        .map(|(schema, name, base_type)| DomainInfo {
442            schema_name: schema,
443            name,
444            base_type,
445        })
446        .collect())
447}
448
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a view whose columns are all nullable, non-PK `text` columns,
    /// with ordinal positions numbered from 0 in the order given.
    fn make_view(schema: &str, name: &str, columns: Vec<&str>) -> TableInfo {
        let columns = columns
            .into_iter()
            .zip(0..)
            .map(|(col, ordinal_position)| ColumnInfo {
                name: col.to_string(),
                data_type: "text".to_string(),
                udt_name: "text".to_string(),
                is_nullable: true,
                is_primary_key: false,
                ordinal_position,
                schema_name: schema.to_string(),
                column_default: None,
            })
            .collect();
        TableInfo {
            schema_name: schema.to_string(),
            name: name.to_string(),
            columns,
        }
    }

    /// Nullability of each column of `view`, in column order.
    fn nullable_flags(view: &TableInfo) -> Vec<bool> {
        view.columns.iter().map(|c| c.is_nullable).collect()
    }

    /// Primary-key flag of each column of `view`, in column order.
    fn pk_flags(view: &TableInfo) -> Vec<bool> {
        view.columns.iter().map(|c| c.is_primary_key).collect()
    }

    fn make_nullability(
        view_schema: &str,
        view_name: &str,
        source_column: &str,
        not_null: bool,
    ) -> ViewColumnNullability {
        ViewColumnNullability {
            view_schema: view_schema.to_string(),
            view_name: view_name.to_string(),
            source_column_name: source_column.to_string(),
            source_not_null: not_null,
        }
    }

    // --- resolve_view_nullability tests ---

    #[test]
    fn test_resolve_not_null_column() {
        let mut views = vec![make_view("public", "my_view", vec!["id", "name"])];
        let info = [
            make_nullability("public", "my_view", "id", true),
            make_nullability("public", "my_view", "name", true),
        ];
        resolve_view_nullability(&mut views, &info);
        assert_eq!(nullable_flags(&views[0]), [false, false]);
    }

    #[test]
    fn test_resolve_mixed_sources() {
        // One nullable source is enough to keep the view column nullable.
        let mut views = vec![make_view("public", "my_view", vec!["id"])];
        let info = [
            make_nullability("public", "my_view", "id", true),
            make_nullability("public", "my_view", "id", false),
        ];
        resolve_view_nullability(&mut views, &info);
        assert_eq!(nullable_flags(&views[0]), [true]);
    }

    #[test]
    fn test_resolve_no_match_stays_nullable() {
        let mut views = vec![make_view("public", "my_view", vec!["computed_col"])];
        let info = [make_nullability("public", "my_view", "id", true)];
        resolve_view_nullability(&mut views, &info);
        assert_eq!(nullable_flags(&views[0]), [true]);
    }

    #[test]
    fn test_resolve_empty_info() {
        let mut views = vec![make_view("public", "my_view", vec!["id"])];
        resolve_view_nullability(&mut views, &[]);
        assert_eq!(nullable_flags(&views[0]), [true]);
    }

    #[test]
    fn test_resolve_cross_schema() {
        let mut views = vec![
            make_view("public", "v1", vec!["id"]),
            make_view("auth", "v2", vec!["id"]),
        ];
        let info = [
            make_nullability("public", "v1", "id", true),
            make_nullability("auth", "v2", "id", false),
        ];
        resolve_view_nullability(&mut views, &info);
        assert_eq!(nullable_flags(&views[0]), [false]);
        assert_eq!(nullable_flags(&views[1]), [true]);
    }

    // --- resolve_view_primary_keys tests ---

    fn make_pk_info(
        view_schema: &str,
        view_name: &str,
        source_column: &str,
        is_pk: bool,
    ) -> ViewColumnPrimaryKey {
        ViewColumnPrimaryKey {
            view_schema: view_schema.to_string(),
            view_name: view_name.to_string(),
            source_column_name: source_column.to_string(),
            source_is_pk: is_pk,
        }
    }

    #[test]
    fn test_resolve_pk_column() {
        let mut views = vec![make_view("public", "my_view", vec!["id", "name"])];
        let info = [
            make_pk_info("public", "my_view", "id", true),
            make_pk_info("public", "my_view", "name", false),
        ];
        resolve_view_primary_keys(&mut views, &info);
        assert_eq!(pk_flags(&views[0]), [true, false]);
    }

    #[test]
    fn test_resolve_pk_mixed_sources() {
        // One non-PK source is enough to keep the view column out of the key.
        let mut views = vec![make_view("public", "my_view", vec!["id"])];
        let info = [
            make_pk_info("public", "my_view", "id", true),
            make_pk_info("public", "my_view", "id", false),
        ];
        resolve_view_primary_keys(&mut views, &info);
        assert_eq!(pk_flags(&views[0]), [false]);
    }

    #[test]
    fn test_resolve_pk_no_match() {
        let mut views = vec![make_view("public", "my_view", vec!["computed_col"])];
        let info = [make_pk_info("public", "my_view", "id", true)];
        resolve_view_primary_keys(&mut views, &info);
        assert_eq!(pk_flags(&views[0]), [false]);
    }

    #[test]
    fn test_resolve_pk_empty_info() {
        let mut views = vec![make_view("public", "my_view", vec!["id"])];
        resolve_view_primary_keys(&mut views, &[]);
        assert_eq!(pk_flags(&views[0]), [false]);
    }

    #[test]
    fn test_resolve_pk_cross_schema() {
        let mut views = vec![
            make_view("public", "v1", vec!["id"]),
            make_view("auth", "v2", vec!["id"]),
        ];
        let info = [
            make_pk_info("public", "v1", "id", true),
            make_pk_info("auth", "v2", "id", false),
        ];
        resolve_view_primary_keys(&mut views, &info);
        assert_eq!(pk_flags(&views[0]), [true]);
        assert_eq!(pk_flags(&views[1]), [false]);
    }
}