Skip to main content

sqlx_gen/introspect/
mysql.rs

1use std::collections::HashMap;
2
3use crate::error::Result;
4use sqlx::MySqlPool;
5
6use super::{ColumnInfo, EnumInfo, SchemaInfo, TableInfo};
7
8pub async fn introspect(
9    pool: &MySqlPool,
10    schemas: &[String],
11    include_views: bool,
12) -> Result<SchemaInfo> {
13    let tables = fetch_tables(pool, schemas).await?;
14    let mut views = if include_views {
15        fetch_views(pool, schemas).await?
16    } else {
17        Vec::new()
18    };
19
20    if !views.is_empty() {
21        let sources = fetch_view_column_sources(pool, schemas).await?;
22        resolve_view_nullability(&mut views, &sources, &tables);
23        resolve_view_primary_keys(&mut views, &sources, &tables);
24    }
25
26    let enums = extract_enums(&tables);
27
28    Ok(SchemaInfo {
29        tables,
30        views,
31        enums,
32        composite_types: Vec::new(),
33        domains: Vec::new(),
34    })
35}
36
37async fn fetch_tables(pool: &MySqlPool, schemas: &[String]) -> Result<Vec<TableInfo>> {
38    // MySQL doesn't support binding arrays directly, so we build placeholders
39    let placeholders: Vec<String> = (0..schemas.len()).map(|_| "?".to_string()).collect();
40    let query = format!(
41        r#"
42        SELECT
43            c.TABLE_SCHEMA,
44            c.TABLE_NAME,
45            c.COLUMN_NAME,
46            c.DATA_TYPE,
47            c.COLUMN_TYPE,
48            c.IS_NULLABLE,
49            c.ORDINAL_POSITION,
50            c.COLUMN_KEY
51        FROM information_schema.COLUMNS c
52        JOIN information_schema.TABLES t
53            ON t.TABLE_SCHEMA = c.TABLE_SCHEMA
54            AND t.TABLE_NAME = c.TABLE_NAME
55            AND t.TABLE_TYPE = 'BASE TABLE'
56        WHERE c.TABLE_SCHEMA IN ({})
57        ORDER BY c.TABLE_SCHEMA, c.TABLE_NAME, c.ORDINAL_POSITION
58        "#,
59        placeholders.join(",")
60    );
61
62    let mut q = sqlx::query_as::<_, (String, String, String, String, String, String, u32, String)>(&query);
63    for schema in schemas {
64        q = q.bind(schema);
65    }
66    let rows = q.fetch_all(pool).await?;
67
68    let mut tables: Vec<TableInfo> = Vec::new();
69    let mut current_key: Option<(String, String)> = None;
70
71    for (schema, table, col_name, data_type, column_type, nullable, ordinal, column_key) in rows {
72        let key = (schema.clone(), table.clone());
73        if current_key.as_ref() != Some(&key) {
74            current_key = Some(key);
75            tables.push(TableInfo {
76                schema_name: schema.clone(),
77                name: table.clone(),
78                columns: Vec::new(),
79            });
80        }
81        tables.last_mut().unwrap().columns.push(ColumnInfo {
82            name: col_name,
83            data_type,
84            udt_name: column_type,
85            is_nullable: nullable == "YES",
86            is_primary_key: column_key == "PRI",
87            ordinal_position: ordinal as i32,
88            schema_name: schema,
89            column_default: None,
90        });
91    }
92
93    Ok(tables)
94}
95
96async fn fetch_views(pool: &MySqlPool, schemas: &[String]) -> Result<Vec<TableInfo>> {
97    let placeholders: Vec<String> = (0..schemas.len()).map(|_| "?".to_string()).collect();
98    let query = format!(
99        r#"
100        SELECT
101            c.TABLE_SCHEMA,
102            c.TABLE_NAME,
103            c.COLUMN_NAME,
104            c.DATA_TYPE,
105            c.COLUMN_TYPE,
106            c.IS_NULLABLE,
107            c.ORDINAL_POSITION
108        FROM information_schema.COLUMNS c
109        JOIN information_schema.TABLES t
110            ON t.TABLE_SCHEMA = c.TABLE_SCHEMA
111            AND t.TABLE_NAME = c.TABLE_NAME
112            AND t.TABLE_TYPE = 'VIEW'
113        WHERE c.TABLE_SCHEMA IN ({})
114        ORDER BY c.TABLE_SCHEMA, c.TABLE_NAME, c.ORDINAL_POSITION
115        "#,
116        placeholders.join(",")
117    );
118
119    let mut q = sqlx::query_as::<_, (String, String, String, String, String, String, u32)>(&query);
120    for schema in schemas {
121        q = q.bind(schema);
122    }
123    let rows = q.fetch_all(pool).await?;
124
125    let mut views: Vec<TableInfo> = Vec::new();
126    let mut current_key: Option<(String, String)> = None;
127
128    for (schema, table, col_name, data_type, column_type, nullable, ordinal) in rows {
129        let key = (schema.clone(), table.clone());
130        if current_key.as_ref() != Some(&key) {
131            current_key = Some(key);
132            views.push(TableInfo {
133                schema_name: schema.clone(),
134                name: table.clone(),
135                columns: Vec::new(),
136            });
137        }
138        views.last_mut().unwrap().columns.push(ColumnInfo {
139            name: col_name,
140            data_type,
141            udt_name: column_type,
142            is_nullable: nullable == "YES",
143            is_primary_key: false,
144            ordinal_position: ordinal as i32,
145            schema_name: schema,
146            column_default: None,
147        });
148    }
149
150    Ok(views)
151}
152
/// One row of `INFORMATION_SCHEMA.VIEW_COLUMN_USAGE`: a base-table column
/// that a view references.
#[derive(Debug, Clone, PartialEq, Eq)]
struct ViewColumnSource {
    /// Schema containing the view.
    view_schema: String,
    /// Name of the view.
    view_name: String,
    /// Schema of the referenced base table.
    table_schema: String,
    /// Name of the referenced base table.
    table_name: String,
    /// Name of the referenced base-table column.
    column_name: String,
}
160
161async fn fetch_view_column_sources(
162    pool: &MySqlPool,
163    schemas: &[String],
164) -> Result<Vec<ViewColumnSource>> {
165    let placeholders: Vec<String> = (0..schemas.len()).map(|_| "?".to_string()).collect();
166    let query = format!(
167        r#"
168        SELECT
169            vcu.VIEW_SCHEMA,
170            vcu.VIEW_NAME,
171            vcu.TABLE_SCHEMA,
172            vcu.TABLE_NAME,
173            vcu.COLUMN_NAME
174        FROM INFORMATION_SCHEMA.VIEW_COLUMN_USAGE vcu
175        WHERE vcu.VIEW_SCHEMA IN ({})
176        "#,
177        placeholders.join(",")
178    );
179
180    let mut q = sqlx::query_as::<_, (String, String, String, String, String)>(&query);
181    for schema in schemas {
182        q = q.bind(schema);
183    }
184
185    match q.fetch_all(pool).await {
186        Ok(rows) => Ok(rows
187            .into_iter()
188            .map(
189                |(view_schema, view_name, table_schema, table_name, column_name)| {
190                    ViewColumnSource {
191                        view_schema,
192                        view_name,
193                        table_schema,
194                        table_name,
195                        column_name,
196                    }
197                },
198            )
199            .collect()),
200        Err(_) => {
201            // VIEW_COLUMN_USAGE may not exist on older MySQL versions
202            Ok(Vec::new())
203        }
204    }
205}
206
207fn resolve_view_nullability(
208    views: &mut [TableInfo],
209    sources: &[ViewColumnSource],
210    tables: &[TableInfo],
211) {
212    // Build table column lookup: (schema, table, column) -> is_nullable
213    let mut table_lookup: HashMap<(&str, &str, &str), bool> = HashMap::new();
214    for table in tables {
215        for col in &table.columns {
216            table_lookup.insert(
217                (&table.schema_name, &table.name, &col.name),
218                col.is_nullable,
219            );
220        }
221    }
222
223    // Build view column source lookup: (view_schema, view_name, column_name) -> Vec<is_nullable>
224    let mut view_lookup: HashMap<(&str, &str, &str), Vec<bool>> = HashMap::new();
225    for src in sources {
226        if let Some(&is_nullable) =
227            table_lookup.get(&(src.table_schema.as_str(), src.table_name.as_str(), src.column_name.as_str()))
228        {
229            view_lookup
230                .entry((&src.view_schema, &src.view_name, &src.column_name))
231                .or_default()
232                .push(is_nullable);
233        }
234    }
235
236    for view in views.iter_mut() {
237        for col in view.columns.iter_mut() {
238            if let Some(nullable_flags) = view_lookup.get(&(
239                view.schema_name.as_str(),
240                view.name.as_str(),
241                col.name.as_str(),
242            )) {
243                // Only mark as non-nullable if ALL sources are NOT nullable
244                if !nullable_flags.is_empty() && nullable_flags.iter().all(|&n| !n) {
245                    col.is_nullable = false;
246                }
247            }
248        }
249    }
250}
251
252fn resolve_view_primary_keys(
253    views: &mut [TableInfo],
254    sources: &[ViewColumnSource],
255    tables: &[TableInfo],
256) {
257    // Build table column lookup: (schema, table, column) -> is_primary_key
258    let mut table_lookup: HashMap<(&str, &str, &str), bool> = HashMap::new();
259    for table in tables {
260        for col in &table.columns {
261            table_lookup.insert(
262                (&table.schema_name, &table.name, &col.name),
263                col.is_primary_key,
264            );
265        }
266    }
267
268    // Build view column source lookup: (view_schema, view_name, column_name) -> Vec<is_pk>
269    let mut view_lookup: HashMap<(&str, &str, &str), Vec<bool>> = HashMap::new();
270    for src in sources {
271        if let Some(&is_pk) =
272            table_lookup.get(&(src.table_schema.as_str(), src.table_name.as_str(), src.column_name.as_str()))
273        {
274            view_lookup
275                .entry((&src.view_schema, &src.view_name, &src.column_name))
276                .or_default()
277                .push(is_pk);
278        }
279    }
280
281    for view in views.iter_mut() {
282        for col in view.columns.iter_mut() {
283            if let Some(pk_flags) = view_lookup.get(&(
284                view.schema_name.as_str(),
285                view.name.as_str(),
286                col.name.as_str(),
287            )) {
288                // Only mark as PK if ALL sources are PKs
289                if !pk_flags.is_empty() && pk_flags.iter().all(|&pk| pk) {
290                    col.is_primary_key = true;
291                }
292            }
293        }
294    }
295}
296
297/// Extract inline ENUMs from column types.
298/// MySQL ENUM('a','b','c') in COLUMN_TYPE gets extracted to an EnumInfo
299/// keyed by table_name + column_name.
300fn extract_enums(tables: &[TableInfo]) -> Vec<EnumInfo> {
301    let mut enums = Vec::new();
302
303    for table in tables {
304        for col in &table.columns {
305            if col.udt_name.starts_with("enum(") {
306                let variants = parse_enum_variants(&col.udt_name);
307                if !variants.is_empty() {
308                    let enum_name = format!("{}_{}", table.name, col.name);
309                    enums.push(EnumInfo {
310                        schema_name: table.schema_name.clone(),
311                        name: enum_name,
312                        variants,
313                        default_variant: None,
314                    });
315                }
316            }
317        }
318    }
319
320    enums
321}
322
/// Parses the variant list out of a MySQL `COLUMN_TYPE` such as
/// `enum('a','b','c')`, returning `["a", "b", "c"]`.
///
/// Values are read as SQL single-quoted literals, which gives these rules:
/// - Commas, spaces and `)` inside a value are kept.
/// - A doubled quote (`''`) decodes to `'`.
/// - Backslash escapes (`\\`, `\0`, `\n`, `\r`, `\t`, `\Z`) are decoded.
/// - Empty variants are dropped.
/// - Unquoted tokens are accepted leniently, split on commas and trimmed.
///
/// Anything that is not lowercase `enum(...)`, including `ENUM(...)`, yields
/// an empty list.
fn parse_enum_variants(column_type: &str) -> Vec<String> {
    let inner = match column_type
        .strip_prefix("enum(")
        .and_then(|s| s.strip_suffix(')'))
    {
        Some(inner) => inner,
        None => return Vec::new(),
    };

    let mut variants = Vec::new();
    let mut chars = inner.chars().peekable();
    loop {
        // Skip the separators (commas and whitespace) between values.
        while chars.peek().map_or(false, |&c| c == ',' || c.is_whitespace()) {
            chars.next();
        }

        let value = match chars.next() {
            None => break,
            Some('\'') => {
                let mut value = String::new();
                while let Some(c) = chars.next() {
                    match c {
                        // `''` inside a quoted literal is an escaped single quote.
                        '\'' if chars.peek() == Some(&'\'') => {
                            chars.next();
                            value.push('\'');
                        }
                        '\'' => break,
                        '\\' => match chars.next() {
                            Some('0') => value.push('\0'),
                            Some('n') => value.push('\n'),
                            Some('r') => value.push('\r'),
                            Some('t') => value.push('\t'),
                            Some('Z') => value.push('\u{1a}'),
                            Some(other) => value.push(other),
                            None => value.push('\\'),
                        },
                        other => value.push(other),
                    }
                }
                value
            }
            // MySQL always quotes enum values; unquoted input is tolerated by
            // reading up to the next comma and trimming.
            Some(first) => {
                let mut token = String::new();
                token.push(first);
                while let Some(&c) = chars.peek() {
                    if c == ',' {
                        break;
                    }
                    token.push(c);
                    chars.next();
                }
                token.trim().to_string()
            }
        };

        if !value.is_empty() {
            variants.push(value);
        }
    }
    variants
}
337
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a table in schema `test_db` with the given columns.
    fn make_table(name: &str, columns: Vec<ColumnInfo>) -> TableInfo {
        TableInfo {
            schema_name: "test_db".to_string(),
            name: name.to_string(),
            columns,
        }
    }

    /// Builds a NOT NULL, non-PK `varchar` column in schema `test_db` whose
    /// `COLUMN_TYPE` (stored in `udt_name`) is `udt_name`.
    fn make_col(name: &str, udt_name: &str) -> ColumnInfo {
        ColumnInfo {
            name: name.to_string(),
            data_type: "varchar".to_string(),
            udt_name: udt_name.to_string(),
            is_nullable: false,
            is_primary_key: false,
            ordinal_position: 0,
            schema_name: "test_db".to_string(),
            column_default: None,
        }
    }

    // ========== parse_enum_variants ==========

    #[test]
    fn test_parse_simple() {
        assert_eq!(
            parse_enum_variants("enum('a','b','c')"),
            vec!["a", "b", "c"]
        );
    }

    #[test]
    fn test_parse_single_variant() {
        assert_eq!(parse_enum_variants("enum('only')"), vec!["only"]);
    }

    #[test]
    fn test_parse_with_spaces() {
        assert_eq!(
            parse_enum_variants("enum( 'a' , 'b' )"),
            vec!["a", "b"]
        );
    }

    #[test]
    fn test_parse_empty_parens() {
        let result = parse_enum_variants("enum()");
        assert!(result.is_empty());
    }

    #[test]
    fn test_parse_varchar_not_enum() {
        let result = parse_enum_variants("varchar(255)");
        assert!(result.is_empty());
    }

    #[test]
    fn test_parse_int_not_enum() {
        let result = parse_enum_variants("int");
        assert!(result.is_empty());
    }

    #[test]
    fn test_parse_with_spaces_in_value() {
        assert_eq!(
            parse_enum_variants("enum('with space','no')"),
            vec!["with space", "no"]
        );
    }

    #[test]
    fn test_parse_empty_variant_filtered() {
        let result = parse_enum_variants("enum('a','','c')");
        assert_eq!(result, vec!["a", "c"]);
    }

    #[test]
    fn test_parse_uppercase_enum_not_matched() {
        // "ENUM(" doesn't match "enum(" prefix
        let result = parse_enum_variants("ENUM('a','b')");
        assert!(result.is_empty());
    }

    // ========== extract_enums ==========

    #[test]
    fn test_extract_from_enum_column() {
        let tables = vec![make_table(
            "users",
            vec![make_col("status", "enum('active','inactive')")],
        )];
        let enums = extract_enums(&tables);
        assert_eq!(enums.len(), 1);
        assert_eq!(enums[0].variants, vec!["active", "inactive"]);
    }

    #[test]
    fn test_extract_enum_name_format() {
        let tables = vec![make_table(
            "users",
            vec![make_col("status", "enum('a')")],
        )];
        let enums = extract_enums(&tables);
        assert_eq!(enums[0].name, "users_status");
    }

    #[test]
    fn test_extract_enum_keeps_table_schema() {
        let tables = vec![make_table(
            "users",
            vec![make_col("status", "enum('a')")],
        )];
        let enums = extract_enums(&tables);
        assert_eq!(enums[0].schema_name, "test_db");
    }

    #[test]
    fn test_extract_no_enums() {
        let tables = vec![make_table(
            "users",
            vec![make_col("id", "int"), make_col("name", "varchar(255)")],
        )];
        let enums = extract_enums(&tables);
        assert!(enums.is_empty());
    }

    #[test]
    fn test_extract_enum_without_variants_skipped() {
        let tables = vec![make_table("users", vec![make_col("status", "enum()")])];
        let enums = extract_enums(&tables);
        assert!(enums.is_empty());
    }

    #[test]
    fn test_extract_two_enum_columns_same_table() {
        let tables = vec![make_table(
            "users",
            vec![
                make_col("status", "enum('active','inactive')"),
                make_col("role", "enum('admin','user')"),
            ],
        )];
        let enums = extract_enums(&tables);
        assert_eq!(enums.len(), 2);
        assert_eq!(enums[0].name, "users_status");
        assert_eq!(enums[1].name, "users_role");
    }

    #[test]
    fn test_extract_enums_from_multiple_tables() {
        let tables = vec![
            make_table("users", vec![make_col("status", "enum('a')")]),
            make_table("posts", vec![make_col("state", "enum('b')")]),
        ];
        let enums = extract_enums(&tables);
        assert_eq!(enums.len(), 2);
    }

    #[test]
    fn test_extract_non_enum_column_ignored() {
        let tables = vec![make_table(
            "users",
            vec![
                make_col("id", "int(11)"),
                make_col("status", "enum('a')"),
            ],
        )];
        let enums = extract_enums(&tables);
        assert_eq!(enums.len(), 1);
    }

    // ========== resolve_view_nullability ==========

    /// Builds a view whose columns all start out nullable and non-PK.
    fn make_view(schema: &str, name: &str, columns: Vec<&str>) -> TableInfo {
        TableInfo {
            schema_name: schema.to_string(),
            name: name.to_string(),
            columns: columns
                .into_iter()
                .enumerate()
                .map(|(i, col)| ColumnInfo {
                    name: col.to_string(),
                    data_type: "varchar".to_string(),
                    udt_name: "varchar(255)".to_string(),
                    is_nullable: true,
                    is_primary_key: false,
                    ordinal_position: i as i32,
                    schema_name: schema.to_string(),
                    column_default: None,
                })
                .collect(),
        }
    }

    /// Builds a base table from `(column, is_nullable)` pairs; no column is a PK.
    fn make_table_with_nullability(
        schema: &str,
        name: &str,
        columns: Vec<(&str, bool)>,
    ) -> TableInfo {
        TableInfo {
            schema_name: schema.to_string(),
            name: name.to_string(),
            columns: columns
                .into_iter()
                .enumerate()
                .map(|(i, (col, nullable))| ColumnInfo {
                    name: col.to_string(),
                    data_type: "varchar".to_string(),
                    udt_name: "varchar(255)".to_string(),
                    is_nullable: nullable,
                    is_primary_key: false,
                    ordinal_position: i as i32,
                    schema_name: schema.to_string(),
                    column_default: None,
                })
                .collect(),
        }
    }

    /// Builds a `VIEW_COLUMN_USAGE`-style row linking a view to a base column.
    fn make_source(
        view_schema: &str,
        view_name: &str,
        table_schema: &str,
        table_name: &str,
        column_name: &str,
    ) -> ViewColumnSource {
        ViewColumnSource {
            view_schema: view_schema.to_string(),
            view_name: view_name.to_string(),
            table_schema: table_schema.to_string(),
            table_name: table_name.to_string(),
            column_name: column_name.to_string(),
        }
    }

    #[test]
    fn test_resolve_not_null_column() {
        let tables = vec![make_table_with_nullability(
            "db",
            "users",
            vec![("id", false), ("name", false)],
        )];
        let mut views = vec![make_view("db", "my_view", vec!["id", "name"])];
        let sources = vec![
            make_source("db", "my_view", "db", "users", "id"),
            make_source("db", "my_view", "db", "users", "name"),
        ];
        resolve_view_nullability(&mut views, &sources, &tables);
        assert!(!views[0].columns[0].is_nullable);
        assert!(!views[0].columns[1].is_nullable);
    }

    #[test]
    fn test_resolve_nullable_source() {
        let tables = vec![make_table_with_nullability(
            "db",
            "users",
            vec![("id", false), ("name", true)],
        )];
        let mut views = vec![make_view("db", "my_view", vec!["id", "name"])];
        let sources = vec![
            make_source("db", "my_view", "db", "users", "id"),
            make_source("db", "my_view", "db", "users", "name"),
        ];
        resolve_view_nullability(&mut views, &sources, &tables);
        assert!(!views[0].columns[0].is_nullable);
        assert!(views[0].columns[1].is_nullable);
    }

    #[test]
    fn test_resolve_stays_nullable_if_any_source_nullable() {
        let tables = vec![
            make_table_with_nullability("db", "users", vec![("id", false)]),
            make_table_with_nullability("db", "archived_users", vec![("id", true)]),
        ];
        let mut views = vec![make_view("db", "all_users", vec!["id"])];
        let sources = vec![
            make_source("db", "all_users", "db", "users", "id"),
            make_source("db", "all_users", "db", "archived_users", "id"),
        ];
        resolve_view_nullability(&mut views, &sources, &tables);
        assert!(views[0].columns[0].is_nullable);
    }

    #[test]
    fn test_resolve_not_null_when_all_sources_not_null() {
        let tables = vec![
            make_table_with_nullability("db", "users", vec![("id", false)]),
            make_table_with_nullability("db", "archived_users", vec![("id", false)]),
        ];
        let mut views = vec![make_view("db", "all_users", vec!["id"])];
        let sources = vec![
            make_source("db", "all_users", "db", "users", "id"),
            make_source("db", "all_users", "db", "archived_users", "id"),
        ];
        resolve_view_nullability(&mut views, &sources, &tables);
        assert!(!views[0].columns[0].is_nullable);
    }

    #[test]
    fn test_resolve_sources_only_affect_their_own_view() {
        let tables = vec![make_table_with_nullability("db", "users", vec![("id", false)])];
        let mut views = vec![
            make_view("db", "view_a", vec!["id"]),
            make_view("db", "view_b", vec!["id"]),
        ];
        let sources = vec![make_source("db", "view_a", "db", "users", "id")];
        resolve_view_nullability(&mut views, &sources, &tables);
        assert!(!views[0].columns[0].is_nullable);
        assert!(views[1].columns[0].is_nullable);
    }

    #[test]
    fn test_resolve_no_match_stays_nullable() {
        let tables = vec![make_table_with_nullability(
            "db",
            "users",
            vec![("id", false)],
        )];
        let mut views = vec![make_view("db", "my_view", vec!["computed"])];
        let sources = vec![];
        resolve_view_nullability(&mut views, &sources, &tables);
        assert!(views[0].columns[0].is_nullable);
    }

    #[test]
    fn test_resolve_empty_sources() {
        let tables = vec![];
        let mut views = vec![make_view("db", "my_view", vec!["id"])];
        resolve_view_nullability(&mut views, &[], &tables);
        assert!(views[0].columns[0].is_nullable);
    }

    // ========== resolve_view_primary_keys ==========

    /// Builds a base table from `(column, is_primary_key)` pairs; all NOT NULL.
    fn make_table_with_pk(
        schema: &str,
        name: &str,
        columns: Vec<(&str, bool)>,
    ) -> TableInfo {
        TableInfo {
            schema_name: schema.to_string(),
            name: name.to_string(),
            columns: columns
                .into_iter()
                .enumerate()
                .map(|(i, (col, is_pk))| ColumnInfo {
                    name: col.to_string(),
                    data_type: "varchar".to_string(),
                    udt_name: "varchar(255)".to_string(),
                    is_nullable: false,
                    is_primary_key: is_pk,
                    ordinal_position: i as i32,
                    schema_name: schema.to_string(),
                    column_default: None,
                })
                .collect(),
        }
    }

    #[test]
    fn test_resolve_pk_column() {
        let tables = vec![make_table_with_pk("db", "users", vec![("id", true), ("name", false)])];
        let mut views = vec![make_view("db", "my_view", vec!["id", "name"])];
        let sources = vec![
            make_source("db", "my_view", "db", "users", "id"),
            make_source("db", "my_view", "db", "users", "name"),
        ];
        resolve_view_primary_keys(&mut views, &sources, &tables);
        assert!(views[0].columns[0].is_primary_key);
        assert!(!views[0].columns[1].is_primary_key);
    }

    #[test]
    fn test_resolve_pk_requires_all_sources_pk() {
        let tables = vec![
            make_table_with_pk("db", "users", vec![("id", true)]),
            make_table_with_pk("db", "audit_log", vec![("id", false)]),
        ];
        let mut views = vec![make_view("db", "my_view", vec!["id"])];
        let sources = vec![
            make_source("db", "my_view", "db", "users", "id"),
            make_source("db", "my_view", "db", "audit_log", "id"),
        ];
        resolve_view_primary_keys(&mut views, &sources, &tables);
        assert!(!views[0].columns[0].is_primary_key);
    }

    #[test]
    fn test_resolve_pk_source_table_unknown() {
        let tables = vec![make_table_with_pk("db", "users", vec![("id", true)])];
        let mut views = vec![make_view("db", "my_view", vec!["id"])];
        let sources = vec![make_source("db", "my_view", "other_db", "users", "id")];
        resolve_view_primary_keys(&mut views, &sources, &tables);
        assert!(!views[0].columns[0].is_primary_key);
    }

    #[test]
    fn test_resolve_pk_no_sources() {
        let tables = vec![make_table_with_pk("db", "users", vec![("id", true)])];
        let mut views = vec![make_view("db", "my_view", vec!["id"])];
        resolve_view_primary_keys(&mut views, &[], &tables);
        assert!(!views[0].columns[0].is_primary_key);
    }

    #[test]
    fn test_resolve_pk_no_match() {
        let tables = vec![make_table_with_pk("db", "users", vec![("id", true)])];
        let mut views = vec![make_view("db", "my_view", vec!["computed"])];
        let sources = vec![];
        resolve_view_primary_keys(&mut views, &sources, &tables);
        assert!(!views[0].columns[0].is_primary_key);
    }
}