Skip to main content

sqlx_gen/introspect/
mysql.rs

1use std::collections::HashMap;
2
3use crate::error::Result;
4use sqlx::MySqlPool;
5
6use super::{ColumnInfo, EnumInfo, SchemaInfo, TableInfo};
7
8pub async fn introspect(
9    pool: &MySqlPool,
10    schemas: &[String],
11    include_views: bool,
12) -> Result<SchemaInfo> {
13    let tables = fetch_tables(pool, schemas).await?;
14    let mut views = if include_views {
15        fetch_views(pool, schemas).await?
16    } else {
17        Vec::new()
18    };
19
20    if !views.is_empty() {
21        let sources = fetch_view_column_sources(pool, schemas).await?;
22        resolve_view_nullability(&mut views, &sources, &tables);
23        resolve_view_primary_keys(&mut views, &sources, &tables);
24    }
25
26    let enums = extract_enums(&tables);
27
28    Ok(SchemaInfo {
29        tables,
30        views,
31        enums,
32        composite_types: Vec::new(),
33        domains: Vec::new(),
34    })
35}
36
37async fn fetch_tables(pool: &MySqlPool, schemas: &[String]) -> Result<Vec<TableInfo>> {
38    // MySQL doesn't support binding arrays directly, so we build placeholders
39    let placeholders: Vec<String> = (0..schemas.len()).map(|_| "?".to_string()).collect();
40    let query = format!(
41        r#"
42        SELECT
43            c.TABLE_SCHEMA,
44            c.TABLE_NAME,
45            c.COLUMN_NAME,
46            c.DATA_TYPE,
47            c.COLUMN_TYPE,
48            c.IS_NULLABLE,
49            c.ORDINAL_POSITION,
50            c.COLUMN_KEY
51        FROM information_schema.COLUMNS c
52        JOIN information_schema.TABLES t
53            ON t.TABLE_SCHEMA = c.TABLE_SCHEMA
54            AND t.TABLE_NAME = c.TABLE_NAME
55            AND t.TABLE_TYPE = 'BASE TABLE'
56        WHERE c.TABLE_SCHEMA IN ({})
57        ORDER BY c.TABLE_SCHEMA, c.TABLE_NAME, c.ORDINAL_POSITION
58        "#,
59        placeholders.join(",")
60    );
61
62    let mut q = sqlx::query_as::<_, (Vec<u8>, Vec<u8>, Vec<u8>, Vec<u8>, Vec<u8>, Vec<u8>, u32, Vec<u8>)>(&query);
63    for schema in schemas {
64        q = q.bind(schema);
65    }
66    let rows = q.fetch_all(pool).await?;
67
68    let mut tables: Vec<TableInfo> = Vec::new();
69    let mut current_key: Option<(String, String)> = None;
70
71    for (schema, table, col_name, data_type, column_type, nullable, ordinal, column_key) in rows {
72        let schema = String::from_utf8(schema).expect("Could not convert schema name from UTF8 bytes");
73        let table = String::from_utf8(table).expect("Could not convert schema name from UTF8 bytes");
74        let col_name = String::from_utf8(col_name).expect("Could not convert col_name name from UTF8 bytes");
75        let data_type = String::from_utf8(data_type).expect("Could not convert data_type name from UTF8 bytes");
76        let column_type = String::from_utf8(column_type).expect("Could not convert column_type name from UTF8 bytes");
77        let nullable = String::from_utf8(nullable).expect("Could not convert nullable name from UTF8 bytes");
78        let column_key = String::from_utf8(column_key).expect("Could not convert column_key name from UTF8 bytes");
79
80        let key = (schema.clone(), table.clone());
81        if current_key.as_ref() != Some(&key) {
82            current_key = Some(key);
83            tables.push(TableInfo {
84                schema_name: schema.clone(),
85                name: table.clone(),
86                columns: Vec::new(),
87            });
88        }
89        tables.last_mut().unwrap().columns.push(ColumnInfo {
90            name: col_name,
91            data_type,
92            udt_name: column_type,
93            is_nullable: nullable == "YES",
94            is_primary_key: column_key == "PRI",
95            ordinal_position: ordinal as i32,
96            schema_name: schema,
97            column_default: None,
98        });
99    }
100
101    Ok(tables)
102}
103
104async fn fetch_views(pool: &MySqlPool, schemas: &[String]) -> Result<Vec<TableInfo>> {
105    let placeholders: Vec<String> = (0..schemas.len()).map(|_| "?".to_string()).collect();
106    let query = format!(
107        r#"
108        SELECT
109            c.TABLE_SCHEMA,
110            c.TABLE_NAME,
111            c.COLUMN_NAME,
112            c.DATA_TYPE,
113            c.COLUMN_TYPE,
114            c.IS_NULLABLE,
115            c.ORDINAL_POSITION
116        FROM information_schema.COLUMNS c
117        JOIN information_schema.TABLES t
118            ON t.TABLE_SCHEMA = c.TABLE_SCHEMA
119            AND t.TABLE_NAME = c.TABLE_NAME
120            AND t.TABLE_TYPE = 'VIEW'
121        WHERE c.TABLE_SCHEMA IN ({})
122        ORDER BY c.TABLE_SCHEMA, c.TABLE_NAME, c.ORDINAL_POSITION
123        "#,
124        placeholders.join(",")
125    );
126
127    let mut q = sqlx::query_as::<_, (String, String, String, String, String, String, u32)>(&query);
128    for schema in schemas {
129        q = q.bind(schema);
130    }
131    let rows = q.fetch_all(pool).await?;
132
133    let mut views: Vec<TableInfo> = Vec::new();
134    let mut current_key: Option<(String, String)> = None;
135
136    for (schema, table, col_name, data_type, column_type, nullable, ordinal) in rows {
137        let key = (schema.clone(), table.clone());
138        if current_key.as_ref() != Some(&key) {
139            current_key = Some(key);
140            views.push(TableInfo {
141                schema_name: schema.clone(),
142                name: table.clone(),
143                columns: Vec::new(),
144            });
145        }
146        views.last_mut().unwrap().columns.push(ColumnInfo {
147            name: col_name,
148            data_type,
149            udt_name: column_type,
150            is_nullable: nullable == "YES",
151            is_primary_key: false,
152            ordinal_position: ordinal as i32,
153            schema_name: schema,
154            column_default: None,
155        });
156    }
157
158    Ok(views)
159}
160
/// One row of `INFORMATION_SCHEMA.VIEW_COLUMN_USAGE`: the base-table column
/// `table_schema`.`table_name`.`column_name` is referenced by the view
/// `view_schema`.`view_name`.
#[derive(Debug, Clone, PartialEq, Eq)]
struct ViewColumnSource {
    /// Schema containing the view.
    view_schema: String,
    /// Name of the view.
    view_name: String,
    /// Schema of the referenced base table.
    table_schema: String,
    /// Name of the referenced base table.
    table_name: String,
    /// Referenced column of the base table.
    column_name: String,
}
168
169async fn fetch_view_column_sources(
170    pool: &MySqlPool,
171    schemas: &[String],
172) -> Result<Vec<ViewColumnSource>> {
173    let placeholders: Vec<String> = (0..schemas.len()).map(|_| "?".to_string()).collect();
174    let query = format!(
175        r#"
176        SELECT
177            vcu.VIEW_SCHEMA,
178            vcu.VIEW_NAME,
179            vcu.TABLE_SCHEMA,
180            vcu.TABLE_NAME,
181            vcu.COLUMN_NAME
182        FROM INFORMATION_SCHEMA.VIEW_COLUMN_USAGE vcu
183        WHERE vcu.VIEW_SCHEMA IN ({})
184        "#,
185        placeholders.join(",")
186    );
187
188    let mut q = sqlx::query_as::<_, (String, String, String, String, String)>(&query);
189    for schema in schemas {
190        q = q.bind(schema);
191    }
192
193    match q.fetch_all(pool).await {
194        Ok(rows) => Ok(rows
195            .into_iter()
196            .map(
197                |(view_schema, view_name, table_schema, table_name, column_name)| {
198                    ViewColumnSource {
199                        view_schema,
200                        view_name,
201                        table_schema,
202                        table_name,
203                        column_name,
204                    }
205                },
206            )
207            .collect()),
208        Err(_) => {
209            // VIEW_COLUMN_USAGE may not exist on older MySQL versions
210            Ok(Vec::new())
211        }
212    }
213}
214
215fn resolve_view_nullability(
216    views: &mut [TableInfo],
217    sources: &[ViewColumnSource],
218    tables: &[TableInfo],
219) {
220    // Build table column lookup: (schema, table, column) -> is_nullable
221    let mut table_lookup: HashMap<(&str, &str, &str), bool> = HashMap::new();
222    for table in tables {
223        for col in &table.columns {
224            table_lookup.insert(
225                (&table.schema_name, &table.name, &col.name),
226                col.is_nullable,
227            );
228        }
229    }
230
231    // Build view column source lookup: (view_schema, view_name, column_name) -> Vec<is_nullable>
232    let mut view_lookup: HashMap<(&str, &str, &str), Vec<bool>> = HashMap::new();
233    for src in sources {
234        if let Some(&is_nullable) =
235            table_lookup.get(&(src.table_schema.as_str(), src.table_name.as_str(), src.column_name.as_str()))
236        {
237            view_lookup
238                .entry((&src.view_schema, &src.view_name, &src.column_name))
239                .or_default()
240                .push(is_nullable);
241        }
242    }
243
244    for view in views.iter_mut() {
245        for col in view.columns.iter_mut() {
246            if let Some(nullable_flags) = view_lookup.get(&(
247                view.schema_name.as_str(),
248                view.name.as_str(),
249                col.name.as_str(),
250            )) {
251                // Only mark as non-nullable if ALL sources are NOT nullable
252                if !nullable_flags.is_empty() && nullable_flags.iter().all(|&n| !n) {
253                    col.is_nullable = false;
254                }
255            }
256        }
257    }
258}
259
260fn resolve_view_primary_keys(
261    views: &mut [TableInfo],
262    sources: &[ViewColumnSource],
263    tables: &[TableInfo],
264) {
265    // Build table column lookup: (schema, table, column) -> is_primary_key
266    let mut table_lookup: HashMap<(&str, &str, &str), bool> = HashMap::new();
267    for table in tables {
268        for col in &table.columns {
269            table_lookup.insert(
270                (&table.schema_name, &table.name, &col.name),
271                col.is_primary_key,
272            );
273        }
274    }
275
276    // Build view column source lookup: (view_schema, view_name, column_name) -> Vec<is_pk>
277    let mut view_lookup: HashMap<(&str, &str, &str), Vec<bool>> = HashMap::new();
278    for src in sources {
279        if let Some(&is_pk) =
280            table_lookup.get(&(src.table_schema.as_str(), src.table_name.as_str(), src.column_name.as_str()))
281        {
282            view_lookup
283                .entry((&src.view_schema, &src.view_name, &src.column_name))
284                .or_default()
285                .push(is_pk);
286        }
287    }
288
289    for view in views.iter_mut() {
290        for col in view.columns.iter_mut() {
291            if let Some(pk_flags) = view_lookup.get(&(
292                view.schema_name.as_str(),
293                view.name.as_str(),
294                col.name.as_str(),
295            )) {
296                // Only mark as PK if ALL sources are PKs
297                if !pk_flags.is_empty() && pk_flags.iter().all(|&pk| pk) {
298                    col.is_primary_key = true;
299                }
300            }
301        }
302    }
303}
304
305/// Extract inline ENUMs from column types.
306/// MySQL ENUM('a','b','c') in COLUMN_TYPE gets extracted to an EnumInfo
307/// keyed by table_name + column_name.
308fn extract_enums(tables: &[TableInfo]) -> Vec<EnumInfo> {
309    let mut enums = Vec::new();
310
311    for table in tables {
312        for col in &table.columns {
313            if col.udt_name.starts_with("enum(") {
314                let variants = parse_enum_variants(&col.udt_name);
315                if !variants.is_empty() {
316                    let enum_name = format!("{}_{}", table.name, col.name);
317                    enums.push(EnumInfo {
318                        schema_name: table.schema_name.clone(),
319                        name: enum_name,
320                        variants,
321                        default_variant: None,
322                    });
323                }
324            }
325        }
326    }
327
328    enums
329}
330
/// Parses the variant list of a MySQL `COLUMN_TYPE` such as `enum('a','b','c')`.
///
/// Variants are single-quoted literals separated by commas. A comma inside a
/// literal is part of the value (`enum('a,b','c')` → `["a,b", "c"]`), and a
/// doubled quote inside a literal stands for one `'` (`'it''s'` → `it's`).
/// Whitespace outside literals is ignored and empty variants are dropped.
/// Input that does not start with lowercase `enum(` and end with `)` yields an
/// empty list.
fn parse_enum_variants(column_type: &str) -> Vec<String> {
    // Moves the accumulated variant into `variants`, skipping empty ones.
    fn push_variant(variants: &mut Vec<String>, current: &mut String) {
        let variant = std::mem::take(current);
        if !variant.is_empty() {
            variants.push(variant);
        }
    }

    let inner = match column_type
        .strip_prefix("enum(")
        .and_then(|s| s.strip_suffix(')'))
    {
        Some(inner) => inner,
        None => return Vec::new(),
    };

    let mut variants = Vec::new();
    let mut current = String::new();
    let mut in_literal = false;
    let mut chars = inner.chars().peekable();

    while let Some(c) = chars.next() {
        if in_literal {
            if c == '\'' {
                // `''` is an escaped quote; a lone `'` closes the literal.
                if chars.peek() == Some(&'\'') {
                    chars.next();
                    current.push('\'');
                } else {
                    in_literal = false;
                }
            } else {
                current.push(c);
            }
        } else {
            match c {
                '\'' => in_literal = true,
                ',' => push_variant(&mut variants, &mut current),
                c if c.is_whitespace() => {}
                // MySQL always quotes variants; keep any unquoted text rather than drop it.
                c => current.push(c),
            }
        }
    }
    push_variant(&mut variants, &mut current);

    variants
}
345
#[cfg(test)]
mod tests {
    use super::*;

    // ========== fixtures ==========

    /// Builds a column with `data_type = "varchar"`; every fixture below goes
    /// through here so the `ColumnInfo` field list is spelled out only once.
    fn column(
        schema: &str,
        name: &str,
        udt_name: &str,
        ordinal: usize,
        is_nullable: bool,
        is_primary_key: bool,
    ) -> ColumnInfo {
        ColumnInfo {
            name: name.to_string(),
            data_type: "varchar".to_string(),
            udt_name: udt_name.to_string(),
            is_nullable,
            is_primary_key,
            ordinal_position: ordinal as i32,
            schema_name: schema.to_string(),
            column_default: None,
        }
    }

    /// Table (or view) built from explicit columns.
    fn table(schema: &str, name: &str, columns: Vec<ColumnInfo>) -> TableInfo {
        TableInfo {
            schema_name: schema.to_string(),
            name: name.to_string(),
            columns,
        }
    }

    /// `varchar(255)` columns at ordinals 0.., given as (name, nullable, primary key).
    fn varchar_table(schema: &str, name: &str, columns: Vec<(&str, bool, bool)>) -> TableInfo {
        let columns = columns
            .into_iter()
            .enumerate()
            .map(|(i, (col, nullable, pk))| column(schema, col, "varchar(255)", i, nullable, pk))
            .collect();
        table(schema, name, columns)
    }

    /// Table in `test_db` built from explicit columns.
    fn make_table(name: &str, columns: Vec<ColumnInfo>) -> TableInfo {
        table("test_db", name, columns)
    }

    /// NOT NULL, non-key column in `test_db` with the given `COLUMN_TYPE`.
    fn make_col(name: &str, udt_name: &str) -> ColumnInfo {
        column("test_db", name, udt_name, 0, false, false)
    }

    /// View whose columns all start out nullable and non-key.
    fn make_view(schema: &str, name: &str, columns: Vec<&str>) -> TableInfo {
        varchar_table(
            schema,
            name,
            columns.into_iter().map(|c| (c, true, false)).collect(),
        )
    }

    /// Table of non-key columns given as (name, nullable).
    fn make_table_with_nullability(
        schema: &str,
        name: &str,
        columns: Vec<(&str, bool)>,
    ) -> TableInfo {
        varchar_table(
            schema,
            name,
            columns
                .into_iter()
                .map(|(c, nullable)| (c, nullable, false))
                .collect(),
        )
    }

    /// Table of NOT NULL columns given as (name, primary key).
    fn make_table_with_pk(schema: &str, name: &str, columns: Vec<(&str, bool)>) -> TableInfo {
        varchar_table(
            schema,
            name,
            columns.into_iter().map(|(c, pk)| (c, false, pk)).collect(),
        )
    }

    fn make_source(
        view_schema: &str,
        view_name: &str,
        table_schema: &str,
        table_name: &str,
        column_name: &str,
    ) -> ViewColumnSource {
        ViewColumnSource {
            view_schema: view_schema.to_string(),
            view_name: view_name.to_string(),
            table_schema: table_schema.to_string(),
            table_name: table_name.to_string(),
            column_name: column_name.to_string(),
        }
    }

    // ========== parse_enum_variants ==========

    #[test]
    fn test_parse_simple() {
        assert_eq!(
            parse_enum_variants("enum('a','b','c')"),
            vec!["a", "b", "c"]
        );
    }

    #[test]
    fn test_parse_single_variant() {
        assert_eq!(parse_enum_variants("enum('only')"), vec!["only"]);
    }

    #[test]
    fn test_parse_with_spaces() {
        assert_eq!(
            parse_enum_variants("enum( 'a' , 'b' )"),
            vec!["a", "b"]
        );
    }

    #[test]
    fn test_parse_empty_parens() {
        let result = parse_enum_variants("enum()");
        assert!(result.is_empty());
    }

    #[test]
    fn test_parse_varchar_not_enum() {
        let result = parse_enum_variants("varchar(255)");
        assert!(result.is_empty());
    }

    #[test]
    fn test_parse_int_not_enum() {
        let result = parse_enum_variants("int");
        assert!(result.is_empty());
    }

    #[test]
    fn test_parse_with_spaces_in_value() {
        assert_eq!(
            parse_enum_variants("enum('with space','no')"),
            vec!["with space", "no"]
        );
    }

    #[test]
    fn test_parse_keeps_padding_inside_quotes() {
        assert_eq!(parse_enum_variants("enum(' padded ')"), vec![" padded "]);
    }

    #[test]
    fn test_parse_empty_variant_filtered() {
        let result = parse_enum_variants("enum('a','','c')");
        assert_eq!(result, vec!["a", "c"]);
    }

    #[test]
    fn test_parse_uppercase_enum_not_matched() {
        // "ENUM(" doesn't match "enum(" prefix
        let result = parse_enum_variants("ENUM('a','b')");
        assert!(result.is_empty());
    }

    // ========== extract_enums ==========

    #[test]
    fn test_extract_from_enum_column() {
        let tables = vec![make_table(
            "users",
            vec![make_col("status", "enum('active','inactive')")],
        )];
        let enums = extract_enums(&tables);
        assert_eq!(enums.len(), 1);
        assert_eq!(enums[0].variants, vec!["active", "inactive"]);
    }

    #[test]
    fn test_extract_enum_name_format() {
        let tables = vec![make_table(
            "users",
            vec![make_col("status", "enum('a')")],
        )];
        let enums = extract_enums(&tables);
        assert_eq!(enums[0].name, "users_status");
    }

    #[test]
    fn test_extract_enum_uses_table_schema() {
        let tables = vec![make_table(
            "users",
            vec![make_col("status", "enum('a')")],
        )];
        let enums = extract_enums(&tables);
        assert_eq!(enums[0].schema_name, "test_db");
        assert!(enums[0].default_variant.is_none());
    }

    #[test]
    fn test_extract_no_enums() {
        let tables = vec![make_table(
            "users",
            vec![make_col("id", "int"), make_col("name", "varchar(255)")],
        )];
        let enums = extract_enums(&tables);
        assert!(enums.is_empty());
    }

    #[test]
    fn test_extract_two_enum_columns_same_table() {
        let tables = vec![make_table(
            "users",
            vec![
                make_col("status", "enum('active','inactive')"),
                make_col("role", "enum('admin','user')"),
            ],
        )];
        let enums = extract_enums(&tables);
        assert_eq!(enums.len(), 2);
        assert_eq!(enums[0].name, "users_status");
        assert_eq!(enums[1].name, "users_role");
    }

    #[test]
    fn test_extract_enums_from_multiple_tables() {
        let tables = vec![
            make_table("users", vec![make_col("status", "enum('a')")]),
            make_table("posts", vec![make_col("state", "enum('b')")]),
        ];
        let enums = extract_enums(&tables);
        assert_eq!(enums.len(), 2);
    }

    #[test]
    fn test_extract_non_enum_column_ignored() {
        let tables = vec![make_table(
            "users",
            vec![
                make_col("id", "int(11)"),
                make_col("status", "enum('a')"),
            ],
        )];
        let enums = extract_enums(&tables);
        assert_eq!(enums.len(), 1);
    }

    // ========== resolve_view_nullability ==========

    #[test]
    fn test_resolve_not_null_column() {
        let tables = vec![make_table_with_nullability(
            "db",
            "users",
            vec![("id", false), ("name", false)],
        )];
        let mut views = vec![make_view("db", "my_view", vec!["id", "name"])];
        let sources = vec![
            make_source("db", "my_view", "db", "users", "id"),
            make_source("db", "my_view", "db", "users", "name"),
        ];
        resolve_view_nullability(&mut views, &sources, &tables);
        assert!(!views[0].columns[0].is_nullable);
        assert!(!views[0].columns[1].is_nullable);
    }

    #[test]
    fn test_resolve_nullable_source() {
        let tables = vec![make_table_with_nullability(
            "db",
            "users",
            vec![("id", false), ("name", true)],
        )];
        let mut views = vec![make_view("db", "my_view", vec!["id", "name"])];
        let sources = vec![
            make_source("db", "my_view", "db", "users", "id"),
            make_source("db", "my_view", "db", "users", "name"),
        ];
        resolve_view_nullability(&mut views, &sources, &tables);
        assert!(!views[0].columns[0].is_nullable);
        assert!(views[0].columns[1].is_nullable);
    }

    #[test]
    fn test_resolve_mixed_sources_stay_nullable() {
        // `id` is read from two tables and one of them allows NULL.
        let tables = vec![
            make_table_with_nullability("db", "users", vec![("id", false)]),
            make_table_with_nullability("db", "guests", vec![("id", true)]),
        ];
        let mut views = vec![make_view("db", "my_view", vec!["id"])];
        let sources = vec![
            make_source("db", "my_view", "db", "users", "id"),
            make_source("db", "my_view", "db", "guests", "id"),
        ];
        resolve_view_nullability(&mut views, &sources, &tables);
        assert!(views[0].columns[0].is_nullable);
    }

    #[test]
    fn test_resolve_source_in_unknown_table_ignored() {
        let tables = vec![make_table_with_nullability(
            "db",
            "users",
            vec![("id", false)],
        )];
        let mut views = vec![make_view("db", "my_view", vec!["id"])];
        let sources = vec![make_source("db", "my_view", "other_db", "users", "id")];
        resolve_view_nullability(&mut views, &sources, &tables);
        assert!(views[0].columns[0].is_nullable);
    }

    #[test]
    fn test_resolve_source_for_other_view_ignored() {
        let tables = vec![make_table_with_nullability(
            "db",
            "users",
            vec![("id", false)],
        )];
        let mut views = vec![make_view("db", "my_view", vec!["id"])];
        let sources = vec![make_source("db", "other_view", "db", "users", "id")];
        resolve_view_nullability(&mut views, &sources, &tables);
        assert!(views[0].columns[0].is_nullable);
    }

    #[test]
    fn test_resolve_no_match_stays_nullable() {
        let tables = vec![make_table_with_nullability(
            "db",
            "users",
            vec![("id", false)],
        )];
        let mut views = vec![make_view("db", "my_view", vec!["computed"])];
        let sources = vec![];
        resolve_view_nullability(&mut views, &sources, &tables);
        assert!(views[0].columns[0].is_nullable);
    }

    #[test]
    fn test_resolve_empty_sources() {
        let tables = vec![];
        let mut views = vec![make_view("db", "my_view", vec!["id"])];
        resolve_view_nullability(&mut views, &[], &tables);
        assert!(views[0].columns[0].is_nullable);
    }

    // ========== resolve_view_primary_keys ==========

    #[test]
    fn test_resolve_pk_column() {
        let tables = vec![make_table_with_pk("db", "users", vec![("id", true), ("name", false)])];
        let mut views = vec![make_view("db", "my_view", vec!["id", "name"])];
        let sources = vec![
            make_source("db", "my_view", "db", "users", "id"),
            make_source("db", "my_view", "db", "users", "name"),
        ];
        resolve_view_primary_keys(&mut views, &sources, &tables);
        assert!(views[0].columns[0].is_primary_key);
        assert!(!views[0].columns[1].is_primary_key);
    }

    #[test]
    fn test_resolve_pk_mixed_sources_not_pk() {
        // `id` is a key in one source table but not in the other.
        let tables = vec![
            make_table_with_pk("db", "users", vec![("id", true)]),
            make_table_with_pk("db", "audit", vec![("id", false)]),
        ];
        let mut views = vec![make_view("db", "my_view", vec!["id"])];
        let sources = vec![
            make_source("db", "my_view", "db", "users", "id"),
            make_source("db", "my_view", "db", "audit", "id"),
        ];
        resolve_view_primary_keys(&mut views, &sources, &tables);
        assert!(!views[0].columns[0].is_primary_key);
    }

    #[test]
    fn test_resolve_pk_no_sources() {
        let tables = vec![make_table_with_pk("db", "users", vec![("id", true)])];
        let mut views = vec![make_view("db", "my_view", vec!["id"])];
        resolve_view_primary_keys(&mut views, &[], &tables);
        assert!(!views[0].columns[0].is_primary_key);
    }

    #[test]
    fn test_resolve_pk_no_match() {
        let tables = vec![make_table_with_pk("db", "users", vec![("id", true)])];
        let mut views = vec![make_view("db", "my_view", vec!["computed"])];
        let sources = vec![];
        resolve_view_primary_keys(&mut views, &sources, &tables);
        assert!(!views[0].columns[0].is_primary_key);
    }
}