1use std::collections::HashMap;
2
3use crate::error::Result;
4use sqlx::MySqlPool;
5
6use super::{ColumnInfo, EnumInfo, SchemaInfo, TableInfo};
7
8pub async fn introspect(
9 pool: &MySqlPool,
10 schemas: &[String],
11 include_views: bool,
12) -> Result<SchemaInfo> {
13 let tables = fetch_tables(pool, schemas).await?;
14 let mut views = if include_views {
15 fetch_views(pool, schemas).await?
16 } else {
17 Vec::new()
18 };
19
20 if !views.is_empty() {
21 let sources = fetch_view_column_sources(pool, schemas).await?;
22 resolve_view_nullability(&mut views, &sources, &tables);
23 resolve_view_primary_keys(&mut views, &sources, &tables);
24 }
25
26 let enums = extract_enums(&tables);
27
28 Ok(SchemaInfo {
29 tables,
30 views,
31 enums,
32 composite_types: Vec::new(),
33 domains: Vec::new(),
34 })
35}
36
37async fn fetch_tables(pool: &MySqlPool, schemas: &[String]) -> Result<Vec<TableInfo>> {
38 let placeholders: Vec<String> = (0..schemas.len()).map(|_| "?".to_string()).collect();
40 let query = format!(
41 r#"
42 SELECT
43 c.TABLE_SCHEMA,
44 c.TABLE_NAME,
45 c.COLUMN_NAME,
46 c.DATA_TYPE,
47 c.COLUMN_TYPE,
48 c.IS_NULLABLE,
49 c.ORDINAL_POSITION,
50 c.COLUMN_KEY
51 FROM information_schema.COLUMNS c
52 JOIN information_schema.TABLES t
53 ON t.TABLE_SCHEMA = c.TABLE_SCHEMA
54 AND t.TABLE_NAME = c.TABLE_NAME
55 AND t.TABLE_TYPE = 'BASE TABLE'
56 WHERE c.TABLE_SCHEMA IN ({})
57 ORDER BY c.TABLE_SCHEMA, c.TABLE_NAME, c.ORDINAL_POSITION
58 "#,
59 placeholders.join(",")
60 );
61
62 let mut q = sqlx::query_as::<
63 _,
64 (
65 Vec<u8>,
66 Vec<u8>,
67 Vec<u8>,
68 Vec<u8>,
69 Vec<u8>,
70 Vec<u8>,
71 u32,
72 Vec<u8>,
73 ),
74 >(&query);
75 for schema in schemas {
76 q = q.bind(schema);
77 }
78 let rows = q.fetch_all(pool).await?;
79
80 let mut tables: Vec<TableInfo> = Vec::new();
81 let mut current_key: Option<(String, String)> = None;
82
83 for (schema, table, col_name, data_type, column_type, nullable, ordinal, column_key) in rows {
84 let schema = utf8_field(schema, "TABLE_SCHEMA")?;
85 let table = utf8_field(table, "TABLE_NAME")?;
86 let col_name = utf8_field(col_name, "COLUMN_NAME")?;
87 let data_type = utf8_field(data_type, "DATA_TYPE")?;
88 let column_type = utf8_field(column_type, "COLUMN_TYPE")?;
89 let nullable = utf8_field(nullable, "IS_NULLABLE")?;
90 let column_key = utf8_field(column_key, "COLUMN_KEY")?;
91
92 let key = (schema.clone(), table.clone());
93 if current_key.as_ref() != Some(&key) {
94 current_key = Some(key);
95 tables.push(TableInfo {
96 schema_name: schema.clone(),
97 name: table.clone(),
98 columns: Vec::new(),
99 });
100 }
101 let last = tables.last_mut().ok_or_else(|| {
102 crate::error::Error::Config(
103 "Internal sqlx-gen bug: tables vector empty after push".to_string(),
104 )
105 })?;
106 last.columns.push(ColumnInfo {
107 name: col_name,
108 data_type,
109 udt_name: column_type,
110 udt_schema: None,
111 is_nullable: nullable == "YES",
112 is_primary_key: column_key == "PRI",
113 ordinal_position: ordinal as i32,
114 schema_name: schema,
115 column_default: None,
116 });
117 }
118
119 Ok(tables)
120}
121
122fn utf8_field(bytes: Vec<u8>, field: &str) -> Result<String> {
125 String::from_utf8(bytes).map_err(|_| {
126 crate::error::Error::Config(format!(
127 "Database returned non-UTF8 bytes for MySQL information_schema field '{}'. \
128 sqlx-gen requires UTF-8 metadata.",
129 field
130 ))
131 })
132}
133
134async fn fetch_views(pool: &MySqlPool, schemas: &[String]) -> Result<Vec<TableInfo>> {
135 let placeholders: Vec<String> = (0..schemas.len()).map(|_| "?".to_string()).collect();
136 let query = format!(
137 r#"
138 SELECT
139 c.TABLE_SCHEMA,
140 c.TABLE_NAME,
141 c.COLUMN_NAME,
142 c.DATA_TYPE,
143 c.COLUMN_TYPE,
144 c.IS_NULLABLE,
145 c.ORDINAL_POSITION
146 FROM information_schema.COLUMNS c
147 JOIN information_schema.TABLES t
148 ON t.TABLE_SCHEMA = c.TABLE_SCHEMA
149 AND t.TABLE_NAME = c.TABLE_NAME
150 AND t.TABLE_TYPE = 'VIEW'
151 WHERE c.TABLE_SCHEMA IN ({})
152 ORDER BY c.TABLE_SCHEMA, c.TABLE_NAME, c.ORDINAL_POSITION
153 "#,
154 placeholders.join(",")
155 );
156
157 let mut q = sqlx::query_as::<_, (String, String, String, String, String, String, u32)>(&query);
158 for schema in schemas {
159 q = q.bind(schema);
160 }
161 let rows = q.fetch_all(pool).await?;
162
163 let mut views: Vec<TableInfo> = Vec::new();
164 let mut current_key: Option<(String, String)> = None;
165
166 for (schema, table, col_name, data_type, column_type, nullable, ordinal) in rows {
167 let key = (schema.clone(), table.clone());
168 if current_key.as_ref() != Some(&key) {
169 current_key = Some(key);
170 views.push(TableInfo {
171 schema_name: schema.clone(),
172 name: table.clone(),
173 columns: Vec::new(),
174 });
175 }
176 let last = views.last_mut().ok_or_else(|| {
177 crate::error::Error::Config(
178 "Internal sqlx-gen bug: views vector empty after push".to_string(),
179 )
180 })?;
181 last.columns.push(ColumnInfo {
182 name: col_name,
183 data_type,
184 udt_name: column_type,
185 udt_schema: None,
186 is_nullable: nullable == "YES",
187 is_primary_key: false,
188 ordinal_position: ordinal as i32,
189 schema_name: schema,
190 column_default: None,
191 });
192 }
193
194 Ok(views)
195}
196
/// A table column used by a view, as reported by `INFORMATION_SCHEMA.VIEW_COLUMN_USAGE`.
///
/// The row does not say which view column (if any) the table column feeds; the view
/// resolvers in this module pair them up by column name.
#[derive(Debug, Clone, PartialEq, Eq)]
struct ViewColumnSource {
    /// Schema that contains the view.
    view_schema: String,
    /// Name of the view.
    view_name: String,
    /// Schema of the referenced table.
    table_schema: String,
    /// Name of the referenced table.
    table_name: String,
    /// Name of the referenced table column.
    column_name: String,
}
204
205async fn fetch_view_column_sources(
206 pool: &MySqlPool,
207 schemas: &[String],
208) -> Result<Vec<ViewColumnSource>> {
209 let placeholders: Vec<String> = (0..schemas.len()).map(|_| "?".to_string()).collect();
210 let query = format!(
211 r#"
212 SELECT
213 vcu.VIEW_SCHEMA,
214 vcu.VIEW_NAME,
215 vcu.TABLE_SCHEMA,
216 vcu.TABLE_NAME,
217 vcu.COLUMN_NAME
218 FROM INFORMATION_SCHEMA.VIEW_COLUMN_USAGE vcu
219 WHERE vcu.VIEW_SCHEMA IN ({})
220 "#,
221 placeholders.join(",")
222 );
223
224 let mut q = sqlx::query_as::<_, (String, String, String, String, String)>(&query);
225 for schema in schemas {
226 q = q.bind(schema);
227 }
228
229 match q.fetch_all(pool).await {
230 Ok(rows) => Ok(rows
231 .into_iter()
232 .map(
233 |(view_schema, view_name, table_schema, table_name, column_name)| {
234 ViewColumnSource {
235 view_schema,
236 view_name,
237 table_schema,
238 table_name,
239 column_name,
240 }
241 },
242 )
243 .collect()),
244 Err(_) => {
245 Ok(Vec::new())
247 }
248 }
249}
250
251fn resolve_view_nullability(
252 views: &mut [TableInfo],
253 sources: &[ViewColumnSource],
254 tables: &[TableInfo],
255) {
256 let mut table_lookup: HashMap<(&str, &str, &str), bool> = HashMap::new();
258 for table in tables {
259 for col in &table.columns {
260 table_lookup.insert(
261 (&table.schema_name, &table.name, &col.name),
262 col.is_nullable,
263 );
264 }
265 }
266
267 let mut view_lookup: HashMap<(&str, &str, &str), Vec<bool>> = HashMap::new();
269 for src in sources {
270 if let Some(&is_nullable) = table_lookup.get(&(
271 src.table_schema.as_str(),
272 src.table_name.as_str(),
273 src.column_name.as_str(),
274 )) {
275 view_lookup
276 .entry((&src.view_schema, &src.view_name, &src.column_name))
277 .or_default()
278 .push(is_nullable);
279 }
280 }
281
282 for view in views.iter_mut() {
283 for col in view.columns.iter_mut() {
284 if let Some(nullable_flags) = view_lookup.get(&(
285 view.schema_name.as_str(),
286 view.name.as_str(),
287 col.name.as_str(),
288 )) {
289 if !nullable_flags.is_empty() && nullable_flags.iter().all(|&n| !n) {
291 col.is_nullable = false;
292 }
293 }
294 }
295 }
296}
297
298fn resolve_view_primary_keys(
299 views: &mut [TableInfo],
300 sources: &[ViewColumnSource],
301 tables: &[TableInfo],
302) {
303 let mut table_lookup: HashMap<(&str, &str, &str), bool> = HashMap::new();
305 for table in tables {
306 for col in &table.columns {
307 table_lookup.insert(
308 (&table.schema_name, &table.name, &col.name),
309 col.is_primary_key,
310 );
311 }
312 }
313
314 let mut view_lookup: HashMap<(&str, &str, &str), Vec<bool>> = HashMap::new();
316 for src in sources {
317 if let Some(&is_pk) = table_lookup.get(&(
318 src.table_schema.as_str(),
319 src.table_name.as_str(),
320 src.column_name.as_str(),
321 )) {
322 view_lookup
323 .entry((&src.view_schema, &src.view_name, &src.column_name))
324 .or_default()
325 .push(is_pk);
326 }
327 }
328
329 for view in views.iter_mut() {
330 for col in view.columns.iter_mut() {
331 if let Some(pk_flags) = view_lookup.get(&(
332 view.schema_name.as_str(),
333 view.name.as_str(),
334 col.name.as_str(),
335 )) {
336 if !pk_flags.is_empty() && pk_flags.iter().all(|&pk| pk) {
338 col.is_primary_key = true;
339 }
340 }
341 }
342 }
343}
344
345fn extract_enums(tables: &[TableInfo]) -> Vec<EnumInfo> {
349 let mut enums = Vec::new();
350
351 for table in tables {
352 for col in &table.columns {
353 if col.udt_name.starts_with("enum(") {
354 let variants = parse_enum_variants(&col.udt_name);
355 if !variants.is_empty() {
356 let enum_name = format!("{}_{}", table.name, col.name);
357 enums.push(EnumInfo {
358 schema_name: table.schema_name.clone(),
359 name: enum_name,
360 variants,
361 default_variant: None,
362 });
363 }
364 }
365 }
366 }
367
368 enums
369}
370
/// Extracts the variant names from a MySQL column type such as `enum('a','b')`.
///
/// Each value is read as a single-quoted SQL string literal:
/// - commas inside a value are kept;
/// - a doubled quote (`''`) yields one `'`;
/// - the backslash escapes `\\`, `\n`, `\r` and `\0` are decoded, and any other escaped
///   character is taken literally.
///
/// Whitespace between values is ignored. An unquoted value is taken up to the next comma
/// and trimmed. Empty values are dropped.
///
/// Returns an empty vector when `column_type` is not of the form `enum(...)`. The prefix
/// match is case-sensitive, so only lowercase `enum(` is recognised.
fn parse_enum_variants(column_type: &str) -> Vec<String> {
    let inner = match column_type
        .strip_prefix("enum(")
        .and_then(|s| s.strip_suffix(')'))
    {
        Some(inner) => inner,
        None => return Vec::new(),
    };

    let mut chars = inner.chars().peekable();
    let mut variants = Vec::new();

    while chars.peek().is_some() {
        // Skip separator whitespace in front of the value.
        while chars.next_if(|c| c.is_whitespace()).is_some() {}

        let mut value = String::new();
        if chars.next_if_eq(&'\'').is_some() {
            while let Some(c) = chars.next() {
                match c {
                    '\'' => {
                        // `''` is an escaped quote; a lone `'` closes the literal.
                        if chars.next_if_eq(&'\'').is_some() {
                            value.push('\'');
                        } else {
                            break;
                        }
                    }
                    '\\' => match chars.next() {
                        Some('0') => value.push('\0'),
                        Some('n') => value.push('\n'),
                        Some('r') => value.push('\r'),
                        Some(other) => value.push(other),
                        None => value.push('\\'),
                    },
                    other => value.push(other),
                }
            }
            // Discard anything between the closing quote and the next separator.
            for c in chars.by_ref() {
                if c == ',' {
                    break;
                }
            }
        } else {
            for c in chars.by_ref() {
                if c == ',' {
                    break;
                }
                value.push(c);
            }
            value = value.trim().to_string();
        }

        if !value.is_empty() {
            variants.push(value);
        }
    }

    variants
}
385
#[cfg(test)]
mod tests {
    use super::*;

    // ---------------------------------------------------------------------
    // Fixtures
    // ---------------------------------------------------------------------

    /// Builds a `varchar` column with every flag spelled out.
    fn column(
        schema: &str,
        name: &str,
        udt_name: &str,
        ordinal: usize,
        nullable: bool,
        primary_key: bool,
    ) -> ColumnInfo {
        ColumnInfo {
            name: name.to_string(),
            data_type: "varchar".to_string(),
            udt_name: udt_name.to_string(),
            udt_schema: None,
            is_nullable: nullable,
            is_primary_key: primary_key,
            ordinal_position: ordinal as i32,
            schema_name: schema.to_string(),
            column_default: None,
        }
    }

    /// Wraps `columns` in a relation called `name` inside `schema`.
    fn relation(schema: &str, name: &str, columns: Vec<ColumnInfo>) -> TableInfo {
        TableInfo {
            schema_name: schema.to_string(),
            name: name.to_string(),
            columns,
        }
    }

    /// A `test_db` table holding `columns` as given.
    fn make_table(name: &str, columns: Vec<ColumnInfo>) -> TableInfo {
        relation("test_db", name, columns)
    }

    /// A NOT NULL, non-key `test_db` column at position 0 with column type `udt_name`.
    fn make_col(name: &str, udt_name: &str) -> ColumnInfo {
        column("test_db", name, udt_name, 0, false, false)
    }

    /// A view whose `varchar(255)` columns all start out nullable and non-key.
    fn make_view(schema: &str, name: &str, columns: Vec<&str>) -> TableInfo {
        let cols = columns
            .into_iter()
            .enumerate()
            .map(|(pos, col)| column(schema, col, "varchar(255)", pos, true, false))
            .collect();
        relation(schema, name, cols)
    }

    /// A table of non-key `varchar(255)` columns with the given nullability.
    fn make_table_with_nullability(
        schema: &str,
        name: &str,
        columns: Vec<(&str, bool)>,
    ) -> TableInfo {
        let cols = columns
            .into_iter()
            .enumerate()
            .map(|(pos, (col, nullable))| {
                column(schema, col, "varchar(255)", pos, nullable, false)
            })
            .collect();
        relation(schema, name, cols)
    }

    /// A table of NOT NULL `varchar(255)` columns with the given primary-key flags.
    fn make_table_with_pk(schema: &str, name: &str, columns: Vec<(&str, bool)>) -> TableInfo {
        let cols = columns
            .into_iter()
            .enumerate()
            .map(|(pos, (col, is_pk))| column(schema, col, "varchar(255)", pos, false, is_pk))
            .collect();
        relation(schema, name, cols)
    }

    /// A usage row linking `view_schema.view_name` to `table_schema.table_name.column_name`.
    fn make_source(
        view_schema: &str,
        view_name: &str,
        table_schema: &str,
        table_name: &str,
        column_name: &str,
    ) -> ViewColumnSource {
        ViewColumnSource {
            view_schema: view_schema.to_string(),
            view_name: view_name.to_string(),
            table_schema: table_schema.to_string(),
            table_name: table_name.to_string(),
            column_name: column_name.to_string(),
        }
    }

    // ---------------------------------------------------------------------
    // parse_enum_variants
    // ---------------------------------------------------------------------

    #[test]
    fn test_parse_simple() {
        let parsed = parse_enum_variants("enum('a','b','c')");
        assert_eq!(parsed, vec!["a", "b", "c"]);
    }

    #[test]
    fn test_parse_single_variant() {
        let parsed = parse_enum_variants("enum('only')");
        assert_eq!(parsed, vec!["only"]);
    }

    #[test]
    fn test_parse_with_spaces() {
        let parsed = parse_enum_variants("enum( 'a' , 'b' )");
        assert_eq!(parsed, vec!["a", "b"]);
    }

    #[test]
    fn test_parse_empty_parens() {
        assert!(parse_enum_variants("enum()").is_empty());
    }

    #[test]
    fn test_parse_varchar_not_enum() {
        assert!(parse_enum_variants("varchar(255)").is_empty());
    }

    #[test]
    fn test_parse_int_not_enum() {
        assert!(parse_enum_variants("int").is_empty());
    }

    #[test]
    fn test_parse_with_spaces_in_value() {
        let parsed = parse_enum_variants("enum('with space','no')");
        assert_eq!(parsed, vec!["with space", "no"]);
    }

    #[test]
    fn test_parse_empty_variant_filtered() {
        let parsed = parse_enum_variants("enum('a','','c')");
        assert_eq!(parsed, vec!["a", "c"]);
    }

    #[test]
    fn test_parse_uppercase_enum_not_matched() {
        // The prefix match is case-sensitive: only lowercase `enum(` is recognised.
        assert!(parse_enum_variants("ENUM('a','b')").is_empty());
    }

    // ---------------------------------------------------------------------
    // extract_enums
    // ---------------------------------------------------------------------

    #[test]
    fn test_extract_from_enum_column() {
        let tables = [make_table(
            "users",
            vec![make_col("status", "enum('active','inactive')")],
        )];
        let enums = extract_enums(&tables);
        assert_eq!(enums.len(), 1);
        assert_eq!(enums[0].variants, vec!["active", "inactive"]);
    }

    #[test]
    fn test_extract_enum_name_format() {
        let tables = [make_table("users", vec![make_col("status", "enum('a')")])];
        let enums = extract_enums(&tables);
        assert_eq!(enums[0].name, "users_status");
    }

    #[test]
    fn test_extract_no_enums() {
        let tables = [make_table(
            "users",
            vec![make_col("id", "int"), make_col("name", "varchar(255)")],
        )];
        assert!(extract_enums(&tables).is_empty());
    }

    #[test]
    fn test_extract_two_enum_columns_same_table() {
        let tables = [make_table(
            "users",
            vec![
                make_col("status", "enum('active','inactive')"),
                make_col("role", "enum('admin','user')"),
            ],
        )];
        let enums = extract_enums(&tables);
        let names: Vec<&str> = enums.iter().map(|e| e.name.as_str()).collect();
        assert_eq!(names, vec!["users_status", "users_role"]);
    }

    #[test]
    fn test_extract_enums_from_multiple_tables() {
        let tables = [
            make_table("users", vec![make_col("status", "enum('a')")]),
            make_table("posts", vec![make_col("state", "enum('b')")]),
        ];
        assert_eq!(extract_enums(&tables).len(), 2);
    }

    #[test]
    fn test_extract_non_enum_column_ignored() {
        let tables = [make_table(
            "users",
            vec![make_col("id", "int(11)"), make_col("status", "enum('a')")],
        )];
        assert_eq!(extract_enums(&tables).len(), 1);
    }

    // ---------------------------------------------------------------------
    // resolve_view_nullability
    // ---------------------------------------------------------------------

    #[test]
    fn test_resolve_not_null_column() {
        let tables = [make_table_with_nullability(
            "db",
            "users",
            vec![("id", false), ("name", false)],
        )];
        let mut views = [make_view("db", "my_view", vec!["id", "name"])];
        let sources = [
            make_source("db", "my_view", "db", "users", "id"),
            make_source("db", "my_view", "db", "users", "name"),
        ];
        resolve_view_nullability(&mut views, &sources, &tables);
        let cols = &views[0].columns;
        assert!(!cols[0].is_nullable);
        assert!(!cols[1].is_nullable);
    }

    #[test]
    fn test_resolve_nullable_source() {
        let tables = [make_table_with_nullability(
            "db",
            "users",
            vec![("id", false), ("name", true)],
        )];
        let mut views = [make_view("db", "my_view", vec!["id", "name"])];
        let sources = [
            make_source("db", "my_view", "db", "users", "id"),
            make_source("db", "my_view", "db", "users", "name"),
        ];
        resolve_view_nullability(&mut views, &sources, &tables);
        let cols = &views[0].columns;
        assert!(!cols[0].is_nullable);
        assert!(cols[1].is_nullable);
    }

    #[test]
    fn test_resolve_no_match_stays_nullable() {
        let tables = [make_table_with_nullability("db", "users", vec![("id", false)])];
        let mut views = [make_view("db", "my_view", vec!["computed"])];
        resolve_view_nullability(&mut views, &[], &tables);
        assert!(views[0].columns[0].is_nullable);
    }

    #[test]
    fn test_resolve_empty_sources() {
        let tables: Vec<TableInfo> = Vec::new();
        let mut views = [make_view("db", "my_view", vec!["id"])];
        resolve_view_nullability(&mut views, &[], &tables);
        assert!(views[0].columns[0].is_nullable);
    }

    // ---------------------------------------------------------------------
    // resolve_view_primary_keys
    // ---------------------------------------------------------------------

    #[test]
    fn test_resolve_pk_column() {
        let tables = [make_table_with_pk(
            "db",
            "users",
            vec![("id", true), ("name", false)],
        )];
        let mut views = [make_view("db", "my_view", vec!["id", "name"])];
        let sources = [
            make_source("db", "my_view", "db", "users", "id"),
            make_source("db", "my_view", "db", "users", "name"),
        ];
        resolve_view_primary_keys(&mut views, &sources, &tables);
        let cols = &views[0].columns;
        assert!(cols[0].is_primary_key);
        assert!(!cols[1].is_primary_key);
    }

    #[test]
    fn test_resolve_pk_no_sources() {
        let tables = [make_table_with_pk("db", "users", vec![("id", true)])];
        let mut views = [make_view("db", "my_view", vec!["id"])];
        resolve_view_primary_keys(&mut views, &[], &tables);
        assert!(!views[0].columns[0].is_primary_key);
    }

    #[test]
    fn test_resolve_pk_no_match() {
        let tables = [make_table_with_pk("db", "users", vec![("id", true)])];
        let mut views = [make_view("db", "my_view", vec!["computed"])];
        resolve_view_primary_keys(&mut views, &[], &tables);
        assert!(!views[0].columns[0].is_primary_key);
    }
}