1use std::collections::HashMap;
2
3use crate::error::Result;
4use sqlx::PgPool;
5
6use super::{ColumnInfo, CompositeTypeInfo, DomainInfo, EnumInfo, SchemaInfo, TableInfo};
7
8pub async fn introspect(
9 pool: &PgPool,
10 schemas: &[String],
11 include_views: bool,
12) -> Result<SchemaInfo> {
13 let tables = fetch_tables(pool, schemas).await?;
14 let mut views = if include_views {
15 fetch_views(pool, schemas).await?
16 } else {
17 Vec::new()
18 };
19
20 if !views.is_empty() {
21 let nullability_info = fetch_view_column_nullability(pool, schemas).await?;
22 resolve_view_nullability(&mut views, &nullability_info);
23
24 let pk_info = fetch_view_column_primary_keys(pool, schemas).await?;
25 resolve_view_primary_keys(&mut views, &pk_info);
26 }
27
28 let enums = fetch_enums(pool, schemas).await?;
29 let composite_types = fetch_composite_types(pool, schemas).await?;
30 let domains = fetch_domains(pool, schemas).await?;
31
32 Ok(SchemaInfo {
33 tables,
34 views,
35 enums,
36 composite_types,
37 domains,
38 })
39}
40
41async fn fetch_tables(pool: &PgPool, schemas: &[String]) -> Result<Vec<TableInfo>> {
42 let rows = sqlx::query_as::<_, (String, String, String, String, String, String, i32, bool, Option<String>)>(
43 r#"
44 SELECT
45 c.table_schema,
46 c.table_name,
47 c.column_name,
48 c.data_type,
49 COALESCE(c.udt_name, c.data_type) as udt_name,
50 c.is_nullable,
51 c.ordinal_position,
52 CASE WHEN kcu.column_name IS NOT NULL THEN true ELSE false END AS is_primary_key,
53 c.column_default
54 FROM information_schema.columns c
55 JOIN information_schema.tables t
56 ON t.table_schema = c.table_schema
57 AND t.table_name = c.table_name
58 AND t.table_type = 'BASE TABLE'
59 LEFT JOIN information_schema.table_constraints tc
60 ON tc.table_schema = c.table_schema
61 AND tc.table_name = c.table_name
62 AND tc.constraint_type = 'PRIMARY KEY'
63 LEFT JOIN information_schema.key_column_usage kcu
64 ON kcu.constraint_name = tc.constraint_name
65 AND kcu.constraint_schema = tc.constraint_schema
66 AND kcu.column_name = c.column_name
67 WHERE c.table_schema = ANY($1)
68 ORDER BY c.table_schema, c.table_name, c.ordinal_position
69 "#,
70 )
71 .bind(schemas)
72 .fetch_all(pool)
73 .await?;
74
75 let mut tables: Vec<TableInfo> = Vec::new();
76 let mut current_key: Option<(String, String)> = None;
77
78 for (schema, table, col_name, data_type, udt_name, nullable, ordinal, is_pk, column_default) in rows {
79 let key = (schema.clone(), table.clone());
80 if current_key.as_ref() != Some(&key) {
81 current_key = Some(key);
82 tables.push(TableInfo {
83 schema_name: schema.clone(),
84 name: table.clone(),
85 columns: Vec::new(),
86 });
87 }
88 tables.last_mut().unwrap().columns.push(ColumnInfo {
89 name: col_name,
90 data_type,
91 udt_name,
92 is_nullable: nullable == "YES",
93 is_primary_key: is_pk,
94 ordinal_position: ordinal,
95 schema_name: schema,
96 column_default,
97 });
98 }
99
100 Ok(tables)
101}
102
103async fn fetch_views(pool: &PgPool, schemas: &[String]) -> Result<Vec<TableInfo>> {
104 let rows = sqlx::query_as::<_, (String, String, String, String, String, String, i32, Option<String>)>(
105 r#"
106 SELECT
107 c.table_schema,
108 c.table_name,
109 c.column_name,
110 c.data_type,
111 COALESCE(c.udt_name, c.data_type) as udt_name,
112 c.is_nullable,
113 c.ordinal_position,
114 c.column_default
115 FROM information_schema.columns c
116 JOIN information_schema.tables t
117 ON t.table_schema = c.table_schema
118 AND t.table_name = c.table_name
119 AND t.table_type = 'VIEW'
120 WHERE c.table_schema = ANY($1)
121 ORDER BY c.table_schema, c.table_name, c.ordinal_position
122 "#,
123 )
124 .bind(schemas)
125 .fetch_all(pool)
126 .await?;
127
128 let mut views: Vec<TableInfo> = Vec::new();
129 let mut current_key: Option<(String, String)> = None;
130
131 for (schema, table, col_name, data_type, udt_name, nullable, ordinal, column_default) in rows {
132 let key = (schema.clone(), table.clone());
133 if current_key.as_ref() != Some(&key) {
134 current_key = Some(key);
135 views.push(TableInfo {
136 schema_name: schema.clone(),
137 name: table.clone(),
138 columns: Vec::new(),
139 });
140 }
141 views.last_mut().unwrap().columns.push(ColumnInfo {
142 name: col_name,
143 data_type,
144 udt_name,
145 is_nullable: nullable == "YES",
146 is_primary_key: false,
147 ordinal_position: ordinal,
148 schema_name: schema,
149 column_default,
150 });
151 }
152
153 Ok(views)
154}
155
/// One source column referenced by a view's defining query, paired with that
/// source column's `NOT NULL` flag.
#[derive(Debug, Clone, PartialEq, Eq)]
struct ViewColumnNullability {
    /// Schema containing the view.
    view_schema: String,
    /// Name of the view.
    view_name: String,
    /// Name of the referenced column in its source relation (which is not
    /// necessarily the name that column carries in the view).
    source_column_name: String,
    /// Whether the source column is declared `NOT NULL`.
    source_not_null: bool,
}
162
163async fn fetch_view_column_nullability(
164 pool: &PgPool,
165 schemas: &[String],
166) -> Result<Vec<ViewColumnNullability>> {
167 let rows = sqlx::query_as::<_, (String, String, String, bool)>(
168 r#"
169 SELECT DISTINCT
170 v_ns.nspname AS view_schema,
171 v.relname AS view_name,
172 src_attr.attname AS source_column_name,
173 src_attr.attnotnull AS source_not_null
174 FROM pg_class v
175 JOIN pg_namespace v_ns ON v_ns.oid = v.relnamespace
176 JOIN pg_rewrite rw ON rw.ev_class = v.oid
177 JOIN pg_depend d ON d.objid = rw.oid
178 AND d.classid = 'pg_rewrite'::regclass
179 AND d.refobjsubid > 0
180 AND d.deptype = 'n'
181 JOIN pg_attribute src_attr ON src_attr.attrelid = d.refobjid
182 AND src_attr.attnum = d.refobjsubid
183 AND NOT src_attr.attisdropped
184 WHERE v_ns.nspname = ANY($1)
185 AND v.relkind = 'v'
186 "#,
187 )
188 .bind(schemas)
189 .fetch_all(pool)
190 .await?;
191
192 Ok(rows
193 .into_iter()
194 .map(
195 |(view_schema, view_name, source_column_name, source_not_null)| {
196 ViewColumnNullability {
197 view_schema,
198 view_name,
199 source_column_name,
200 source_not_null,
201 }
202 },
203 )
204 .collect())
205}
206
207fn resolve_view_nullability(
208 views: &mut [TableInfo],
209 nullability_info: &[ViewColumnNullability],
210) {
211 let mut lookup: HashMap<(&str, &str, &str), Vec<bool>> = HashMap::new();
213 for info in nullability_info {
214 lookup
215 .entry((&info.view_schema, &info.view_name, &info.source_column_name))
216 .or_default()
217 .push(info.source_not_null);
218 }
219
220 for view in views.iter_mut() {
221 for col in view.columns.iter_mut() {
222 if let Some(not_null_flags) = lookup.get(&(
223 view.schema_name.as_str(),
224 view.name.as_str(),
225 col.name.as_str(),
226 )) {
227 if !not_null_flags.is_empty() && not_null_flags.iter().all(|&nn| nn) {
229 col.is_nullable = false;
230 }
231 }
232 }
233 }
234}
235
/// One source column referenced by a view's defining query, paired with
/// whether that source column belongs to its relation's primary key.
#[derive(Debug, Clone, PartialEq, Eq)]
struct ViewColumnPrimaryKey {
    /// Schema containing the view.
    view_schema: String,
    /// Name of the view.
    view_name: String,
    /// Name of the referenced column in its source relation (which is not
    /// necessarily the name that column carries in the view).
    source_column_name: String,
    /// Whether the source column is part of its relation's primary key.
    source_is_pk: bool,
}
242
243async fn fetch_view_column_primary_keys(
244 pool: &PgPool,
245 schemas: &[String],
246) -> Result<Vec<ViewColumnPrimaryKey>> {
247 let rows = sqlx::query_as::<_, (String, String, String, bool)>(
248 r#"
249 SELECT DISTINCT
250 v_ns.nspname AS view_schema,
251 v.relname AS view_name,
252 src_attr.attname AS source_column_name,
253 COALESCE(
254 EXISTS (
255 SELECT 1
256 FROM pg_constraint con
257 WHERE con.conrelid = src_attr.attrelid
258 AND con.contype = 'p'
259 AND src_attr.attnum = ANY(con.conkey)
260 ),
261 false
262 ) AS source_is_pk
263 FROM pg_class v
264 JOIN pg_namespace v_ns ON v_ns.oid = v.relnamespace
265 JOIN pg_rewrite rw ON rw.ev_class = v.oid
266 JOIN pg_depend d ON d.objid = rw.oid
267 AND d.classid = 'pg_rewrite'::regclass
268 AND d.refobjsubid > 0
269 AND d.deptype = 'n'
270 JOIN pg_attribute src_attr ON src_attr.attrelid = d.refobjid
271 AND src_attr.attnum = d.refobjsubid
272 AND NOT src_attr.attisdropped
273 WHERE v_ns.nspname = ANY($1)
274 AND v.relkind = 'v'
275 "#,
276 )
277 .bind(schemas)
278 .fetch_all(pool)
279 .await?;
280
281 Ok(rows
282 .into_iter()
283 .map(
284 |(view_schema, view_name, source_column_name, source_is_pk)| ViewColumnPrimaryKey {
285 view_schema,
286 view_name,
287 source_column_name,
288 source_is_pk,
289 },
290 )
291 .collect())
292}
293
294fn resolve_view_primary_keys(
295 views: &mut [TableInfo],
296 pk_info: &[ViewColumnPrimaryKey],
297) {
298 let mut lookup: HashMap<(&str, &str, &str), Vec<bool>> = HashMap::new();
300 for info in pk_info {
301 lookup
302 .entry((&info.view_schema, &info.view_name, &info.source_column_name))
303 .or_default()
304 .push(info.source_is_pk);
305 }
306
307 for view in views.iter_mut() {
308 for col in view.columns.iter_mut() {
309 if let Some(pk_flags) = lookup.get(&(
310 view.schema_name.as_str(),
311 view.name.as_str(),
312 col.name.as_str(),
313 )) {
314 if !pk_flags.is_empty() && pk_flags.iter().all(|&pk| pk) {
316 col.is_primary_key = true;
317 }
318 }
319 }
320 }
321}
322
323async fn fetch_enums(pool: &PgPool, schemas: &[String]) -> Result<Vec<EnumInfo>> {
324 let rows = sqlx::query_as::<_, (String, String, String)>(
325 r#"
326 SELECT
327 n.nspname AS schema_name,
328 t.typname AS enum_name,
329 e.enumlabel AS variant
330 FROM pg_catalog.pg_type t
331 JOIN pg_catalog.pg_enum e ON e.enumtypid = t.oid
332 JOIN pg_catalog.pg_namespace n ON n.oid = t.typnamespace
333 WHERE n.nspname = ANY($1)
334 ORDER BY n.nspname, t.typname, e.enumsortorder
335 "#,
336 )
337 .bind(schemas)
338 .fetch_all(pool)
339 .await?;
340
341 let mut enums: Vec<EnumInfo> = Vec::new();
342 let mut current_key: Option<(String, String)> = None;
343
344 for (schema, name, variant) in rows {
345 let key = (schema.clone(), name.clone());
346 if current_key.as_ref() != Some(&key) {
347 current_key = Some(key);
348 enums.push(EnumInfo {
349 schema_name: schema,
350 name,
351 variants: Vec::new(),
352 default_variant: None,
353 });
354 }
355 enums.last_mut().unwrap().variants.push(variant);
356 }
357
358 Ok(enums)
359}
360
361async fn fetch_composite_types(
362 pool: &PgPool,
363 schemas: &[String],
364) -> Result<Vec<CompositeTypeInfo>> {
365 let rows = sqlx::query_as::<_, (String, String, String, String, String, i32)>(
366 r#"
367 SELECT
368 n.nspname AS schema_name,
369 t.typname AS type_name,
370 a.attname AS field_name,
371 COALESCE(ft.typname, '') AS field_type,
372 CASE WHEN a.attnotnull THEN 'NO' ELSE 'YES' END AS is_nullable,
373 a.attnum AS ordinal
374 FROM pg_catalog.pg_type t
375 JOIN pg_catalog.pg_namespace n ON n.oid = t.typnamespace
376 JOIN pg_catalog.pg_class c ON c.oid = t.typrelid
377 JOIN pg_catalog.pg_attribute a ON a.attrelid = c.oid AND a.attnum > 0 AND NOT a.attisdropped
378 JOIN pg_catalog.pg_type ft ON ft.oid = a.atttypid
379 WHERE t.typtype = 'c'
380 AND n.nspname = ANY($1)
381 AND NOT EXISTS (
382 SELECT 1 FROM information_schema.tables it
383 WHERE it.table_schema = n.nspname AND it.table_name = t.typname
384 )
385 ORDER BY n.nspname, t.typname, a.attnum
386 "#,
387 )
388 .bind(schemas)
389 .fetch_all(pool)
390 .await?;
391
392 let mut composites: Vec<CompositeTypeInfo> = Vec::new();
393 let mut current_key: Option<(String, String)> = None;
394
395 for (schema, type_name, field_name, field_type, nullable, ordinal) in rows {
396 let key = (schema.clone(), type_name.clone());
397 if current_key.as_ref() != Some(&key) {
398 current_key = Some(key);
399 composites.push(CompositeTypeInfo {
400 schema_name: schema.clone(),
401 name: type_name,
402 fields: Vec::new(),
403 });
404 }
405 composites.last_mut().unwrap().fields.push(ColumnInfo {
406 name: field_name,
407 data_type: field_type.clone(),
408 udt_name: field_type,
409 is_nullable: nullable == "YES",
410 is_primary_key: false,
411 ordinal_position: ordinal,
412 schema_name: schema,
413 column_default: None,
414 });
415 }
416
417 Ok(composites)
418}
419
420async fn fetch_domains(pool: &PgPool, schemas: &[String]) -> Result<Vec<DomainInfo>> {
421 let rows = sqlx::query_as::<_, (String, String, String)>(
422 r#"
423 SELECT
424 n.nspname AS schema_name,
425 t.typname AS domain_name,
426 bt.typname AS base_type
427 FROM pg_catalog.pg_type t
428 JOIN pg_catalog.pg_namespace n ON n.oid = t.typnamespace
429 JOIN pg_catalog.pg_type bt ON bt.oid = t.typbasetype
430 WHERE t.typtype = 'd'
431 AND n.nspname = ANY($1)
432 ORDER BY n.nspname, t.typname
433 "#,
434 )
435 .bind(schemas)
436 .fetch_all(pool)
437 .await?;
438
439 Ok(rows
440 .into_iter()
441 .map(|(schema, name, base_type)| DomainInfo {
442 schema_name: schema,
443 name,
444 base_type,
445 })
446 .collect())
447}
448
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a view whose columns are all nullable, non-key `text` columns,
    /// numbered from zero in the order given.
    fn view(schema: &str, name: &str, column_names: &[&str]) -> TableInfo {
        let columns = column_names
            .iter()
            .zip(0..)
            .map(|(&column, ordinal_position)| ColumnInfo {
                name: column.to_string(),
                data_type: "text".to_string(),
                udt_name: "text".to_string(),
                is_nullable: true,
                is_primary_key: false,
                ordinal_position,
                schema_name: schema.to_string(),
                column_default: None,
            })
            .collect();
        TableInfo {
            schema_name: schema.to_string(),
            name: name.to_string(),
            columns,
        }
    }

    fn nullability(
        schema: &str,
        view_name: &str,
        source_column: &str,
        not_null: bool,
    ) -> ViewColumnNullability {
        ViewColumnNullability {
            view_schema: schema.to_string(),
            view_name: view_name.to_string(),
            source_column_name: source_column.to_string(),
            source_not_null: not_null,
        }
    }

    fn primary_key(
        schema: &str,
        view_name: &str,
        source_column: &str,
        is_pk: bool,
    ) -> ViewColumnPrimaryKey {
        ViewColumnPrimaryKey {
            view_schema: schema.to_string(),
            view_name: view_name.to_string(),
            source_column_name: source_column.to_string(),
            source_is_pk: is_pk,
        }
    }

    /// The `is_nullable` flag of every column, in column order.
    fn nullable_flags(table: &TableInfo) -> Vec<bool> {
        table.columns.iter().map(|c| c.is_nullable).collect()
    }

    /// The `is_primary_key` flag of every column, in column order.
    fn pk_flags(table: &TableInfo) -> Vec<bool> {
        table.columns.iter().map(|c| c.is_primary_key).collect()
    }

    #[test]
    fn test_resolve_not_null_column() {
        let mut views = vec![view("public", "my_view", &["id", "name"])];
        let sources = [
            nullability("public", "my_view", "id", true),
            nullability("public", "my_view", "name", true),
        ];
        resolve_view_nullability(&mut views, &sources);
        assert_eq!(nullable_flags(&views[0]), vec![false, false]);
    }

    #[test]
    fn test_resolve_mixed_sources() {
        let mut views = vec![view("public", "my_view", &["id"])];
        let sources = [
            nullability("public", "my_view", "id", true),
            nullability("public", "my_view", "id", false),
        ];
        resolve_view_nullability(&mut views, &sources);
        assert_eq!(nullable_flags(&views[0]), vec![true]);
    }

    #[test]
    fn test_resolve_no_match_stays_nullable() {
        let mut views = vec![view("public", "my_view", &["computed_col"])];
        resolve_view_nullability(&mut views, &[nullability("public", "my_view", "id", true)]);
        assert_eq!(nullable_flags(&views[0]), vec![true]);
    }

    #[test]
    fn test_resolve_empty_info() {
        let mut views = vec![view("public", "my_view", &["id"])];
        resolve_view_nullability(&mut views, &[]);
        assert_eq!(nullable_flags(&views[0]), vec![true]);
    }

    #[test]
    fn test_resolve_cross_schema() {
        let mut views = vec![view("public", "v1", &["id"]), view("auth", "v2", &["id"])];
        let sources = [
            nullability("public", "v1", "id", true),
            nullability("auth", "v2", "id", false),
        ];
        resolve_view_nullability(&mut views, &sources);
        assert_eq!(nullable_flags(&views[0]), vec![false]);
        assert_eq!(nullable_flags(&views[1]), vec![true]);
    }

    #[test]
    fn test_resolve_pk_column() {
        let mut views = vec![view("public", "my_view", &["id", "name"])];
        let sources = [
            primary_key("public", "my_view", "id", true),
            primary_key("public", "my_view", "name", false),
        ];
        resolve_view_primary_keys(&mut views, &sources);
        assert_eq!(pk_flags(&views[0]), vec![true, false]);
    }

    #[test]
    fn test_resolve_pk_mixed_sources() {
        let mut views = vec![view("public", "my_view", &["id"])];
        let sources = [
            primary_key("public", "my_view", "id", true),
            primary_key("public", "my_view", "id", false),
        ];
        resolve_view_primary_keys(&mut views, &sources);
        assert_eq!(pk_flags(&views[0]), vec![false]);
    }

    #[test]
    fn test_resolve_pk_no_match() {
        let mut views = vec![view("public", "my_view", &["computed_col"])];
        resolve_view_primary_keys(&mut views, &[primary_key("public", "my_view", "id", true)]);
        assert_eq!(pk_flags(&views[0]), vec![false]);
    }

    #[test]
    fn test_resolve_pk_empty_info() {
        let mut views = vec![view("public", "my_view", &["id"])];
        resolve_view_primary_keys(&mut views, &[]);
        assert_eq!(pk_flags(&views[0]), vec![false]);
    }

    #[test]
    fn test_resolve_pk_cross_schema() {
        let mut views = vec![view("public", "v1", &["id"]), view("auth", "v2", &["id"])];
        let sources = [
            primary_key("public", "v1", "id", true),
            primary_key("auth", "v2", "id", false),
        ];
        resolve_view_primary_keys(&mut views, &sources);
        assert_eq!(pk_flags(&views[0]), vec![true]);
        assert_eq!(pk_flags(&views[1]), vec![false]);
    }
}