Skip to main content

rdbi_codegen/parser/
schema_parser.rs

1//! SQL schema parser using sqlparser-rs
2
3use sqlparser::ast::{
4    ColumnOption, DataType, EnumMember, Expr, ForeignKeyConstraint, Ident, IndexColumn,
5    IndexConstraint, ObjectName, PrimaryKeyConstraint, Statement, TableConstraint,
6    UniqueConstraint,
7};
8use sqlparser::dialect::MySqlDialect;
9use sqlparser::parser::Parser;
10
11use super::metadata::*;
12use crate::error::Result;
13
14/// Parse a SQL schema string into table metadata
15pub fn parse_schema(sql: &str) -> Result<Vec<TableMetadata>> {
16    let dialect = MySqlDialect {};
17    let statements = Parser::parse_sql(&dialect, sql)?;
18
19    let mut tables = Vec::new();
20
21    for stmt in statements {
22        if let Statement::CreateTable(create_table) = stmt {
23            let table = extract_table_metadata(&create_table)?;
24            tables.push(table);
25        }
26    }
27
28    Ok(tables)
29}
30
31/// Extract table metadata from a CREATE TABLE statement
32fn extract_table_metadata(create: &sqlparser::ast::CreateTable) -> Result<TableMetadata> {
33    let name = extract_table_name(&create.name);
34
35    let mut columns = Vec::new();
36    let mut indexes = Vec::new();
37    let mut foreign_keys = Vec::new();
38    let mut primary_key = None;
39
40    // Extract columns with their options
41    for col_def in &create.columns {
42        let (column, col_pk, col_unique) = extract_column_metadata(col_def)?;
43
44        // Handle column-level PRIMARY KEY
45        if col_pk {
46            primary_key = Some(PrimaryKey {
47                columns: vec![column.name.clone()],
48            });
49        }
50
51        // Handle column-level UNIQUE
52        if col_unique {
53            indexes.push(IndexMetadata {
54                name: format!("{}_unique", column.name),
55                columns: vec![column.name.clone()],
56                unique: true,
57            });
58        }
59
60        columns.push(column);
61    }
62
63    // Extract table-level constraints
64    for constraint in &create.constraints {
65        match constraint {
66            TableConstraint::PrimaryKey(PrimaryKeyConstraint {
67                columns: pk_cols, ..
68            }) => {
69                primary_key = Some(PrimaryKey {
70                    columns: pk_cols
71                        .iter()
72                        .map(extract_ident_from_index_column)
73                        .collect(),
74                });
75                // Mark PK columns as non-nullable
76                for pk_col in pk_cols {
77                    let col_name = extract_ident_from_index_column(pk_col);
78                    if let Some(col) = columns.iter_mut().find(|c| c.name == col_name) {
79                        col.nullable = false;
80                    }
81                }
82            }
83            TableConstraint::Unique(UniqueConstraint {
84                columns: uniq_cols,
85                name,
86                ..
87            }) => {
88                let idx_name = name.as_ref().map(extract_ident).unwrap_or_else(|| {
89                    let first_col = extract_ident_from_index_column(&uniq_cols[0]);
90                    format!("{}_unique", first_col)
91                });
92                indexes.push(IndexMetadata {
93                    name: idx_name,
94                    columns: uniq_cols
95                        .iter()
96                        .map(extract_ident_from_index_column)
97                        .collect(),
98                    unique: true,
99                });
100            }
101            TableConstraint::Index(IndexConstraint {
102                columns: idx_cols,
103                name,
104                ..
105            }) => {
106                let idx_name = name.as_ref().map(extract_ident).unwrap_or_else(|| {
107                    let first_col = extract_ident_from_index_column(&idx_cols[0]);
108                    format!("idx_{}", first_col)
109                });
110                indexes.push(IndexMetadata {
111                    name: idx_name,
112                    columns: idx_cols
113                        .iter()
114                        .map(extract_ident_from_index_column)
115                        .collect(),
116                    unique: false,
117                });
118            }
119            TableConstraint::ForeignKey(ForeignKeyConstraint {
120                columns,
121                foreign_table,
122                referred_columns,
123                ..
124            }) => {
125                for (col, ref_col) in columns.iter().zip(referred_columns.iter()) {
126                    foreign_keys.push(ForeignKeyMetadata {
127                        column_name: extract_ident(col),
128                        referenced_table: extract_table_name(foreign_table),
129                        referenced_column: extract_ident(ref_col),
130                    });
131                }
132            }
133            _ => {}
134        }
135    }
136
137    Ok(TableMetadata {
138        name,
139        comment: None, // sqlparser doesn't expose table comments directly
140        columns,
141        indexes,
142        foreign_keys,
143        primary_key,
144    })
145}
146
147/// Extract column metadata from a column definition
148fn extract_column_metadata(
149    col_def: &sqlparser::ast::ColumnDef,
150) -> Result<(ColumnMetadata, bool, bool)> {
151    let name = extract_ident(&col_def.name);
152    let data_type = format!("{}", col_def.data_type);
153    let enum_values = extract_enum_values(&col_def.data_type);
154    let is_unsigned = data_type.to_uppercase().contains("UNSIGNED");
155
156    let mut nullable = true; // Default to nullable
157    let mut default_value = None;
158    let mut is_auto_increment = false;
159    let mut col_is_primary = false;
160    let mut col_is_unique = false;
161    let mut comment = None;
162
163    for option in &col_def.options {
164        match &option.option {
165            ColumnOption::NotNull => {
166                nullable = false;
167            }
168            ColumnOption::Null => {
169                nullable = true;
170            }
171            ColumnOption::Default(expr) => {
172                default_value = Some(format!("{}", expr));
173            }
174            ColumnOption::PrimaryKey(_) => {
175                col_is_primary = true;
176                nullable = false;
177            }
178            ColumnOption::Unique(_) => {
179                col_is_unique = true;
180            }
181            ColumnOption::Comment(c) => {
182                comment = Some(c.clone());
183            }
184            ColumnOption::DialectSpecific(tokens) => {
185                // Check for AUTO_INCREMENT in MySQL-specific options
186                let token_str = tokens
187                    .iter()
188                    .map(|t| t.to_string())
189                    .collect::<Vec<_>>()
190                    .join(" ")
191                    .to_uppercase();
192                if token_str.contains("AUTO_INCREMENT") {
193                    is_auto_increment = true;
194                }
195            }
196            _ => {}
197        }
198    }
199
200    let column = ColumnMetadata {
201        name,
202        data_type,
203        nullable,
204        default_value,
205        is_auto_increment,
206        is_unsigned,
207        enum_values,
208        comment,
209    };
210
211    Ok((column, col_is_primary, col_is_unique))
212}
213
214/// Extract enum values from a data type
215fn extract_enum_values(data_type: &DataType) -> Option<Vec<String>> {
216    match data_type {
217        DataType::Enum(members, _) => Some(
218            members
219                .iter()
220                .map(|m| match m {
221                    EnumMember::Name(s) => s.clone(),
222                    EnumMember::NamedValue(s, _) => s.clone(),
223                })
224                .collect(),
225        ),
226        _ => None,
227    }
228}
229
230/// Extract a simple string from an ObjectName
231fn extract_table_name(name: &ObjectName) -> String {
232    name.0
233        .last()
234        .and_then(|part| part.as_ident())
235        .map(|ident| ident.value.clone())
236        .unwrap_or_default()
237}
238
239/// Extract a string from an Ident, removing backticks if present
240fn extract_ident(ident: &Ident) -> String {
241    ident.value.clone()
242}
243
244/// Extract a column name string from an IndexColumn
245fn extract_ident_from_index_column(ic: &IndexColumn) -> String {
246    match &ic.column.expr {
247        Expr::Identifier(ident) => ident.value.clone(),
248        other => format!("{}", other),
249    }
250}
251
#[cfg(test)]
mod tests {
    use super::*;

    /// A table with an inline primary key is parsed into one table with all its columns.
    #[test]
    fn test_parse_simple_table() {
        let sql = r#"
            CREATE TABLE users (
                id BIGINT AUTO_INCREMENT PRIMARY KEY,
                username VARCHAR(255) NOT NULL,
                email VARCHAR(255) NOT NULL
            );
        "#;

        let tables = parse_schema(sql).unwrap();
        assert_eq!(tables.len(), 1);

        let users = &tables[0];
        assert_eq!(users.name, "users");
        assert_eq!(users.columns.len(), 3);
        assert!(users.primary_key.is_some());
    }

    /// Table-level `INDEX` and `UNIQUE INDEX` clauses both end up in `indexes`.
    #[test]
    fn test_parse_table_with_indexes() {
        let sql = r#"
            CREATE TABLE posts (
                id BIGINT AUTO_INCREMENT PRIMARY KEY,
                user_id BIGINT NOT NULL,
                title VARCHAR(255) NOT NULL,
                status ENUM('DRAFT', 'PUBLISHED', 'ARCHIVED') NOT NULL,
                INDEX idx_user (user_id),
                UNIQUE INDEX idx_title (title)
            );
        "#;

        let tables = parse_schema(sql).unwrap();
        assert_eq!(tables.len(), 1);
        let indexes = &tables[0].indexes;

        // Debug: print all indexes
        indexes.iter().for_each(|index| {
            eprintln!(
                "Index: {} unique={} cols={:?}",
                index.name, index.unique, index.columns
            )
        });

        // At least idx_user (non-unique) and idx_title (unique) must be present.
        assert!(indexes.len() >= 2);

        let user_index = indexes
            .iter()
            .find(|index| index.name == "idx_user")
            .expect("idx_user index should be present");
        assert!(!user_index.unique);

        // UNIQUE INDEX is parsed as a Unique constraint, so look it up by column.
        let title_index = indexes
            .iter()
            .find(|index| index.columns.iter().any(|column| column == "title"))
            .expect("an index covering `title` should be present");
        assert!(title_index.unique);
    }

    /// `ENUM` columns expose their declared values.
    #[test]
    fn test_parse_enum_column() {
        let sql = r#"
            CREATE TABLE items (
                id BIGINT PRIMARY KEY,
                status ENUM('ACTIVE', 'INACTIVE', 'PENDING') NOT NULL
            );
        "#;

        let tables = parse_schema(sql).unwrap();
        let status = tables[0]
            .columns
            .iter()
            .find(|column| column.name == "status")
            .unwrap();

        let values = status
            .enum_values
            .as_ref()
            .expect("status should carry enum values");
        assert_eq!(values.len(), 3);
        assert!(values.iter().any(|value| value == "ACTIVE"));
    }

    /// A table-level `FOREIGN KEY` yields one entry per column pair.
    #[test]
    fn test_parse_foreign_key() {
        let sql = r#"
            CREATE TABLE orders (
                id BIGINT PRIMARY KEY,
                user_id BIGINT NOT NULL,
                FOREIGN KEY (user_id) REFERENCES users(id)
            );
        "#;

        let tables = parse_schema(sql).unwrap();
        let foreign_keys = &tables[0].foreign_keys;
        assert_eq!(foreign_keys.len(), 1);

        let fk = &foreign_keys[0];
        assert_eq!(fk.column_name, "user_id");
        assert_eq!(fk.referenced_table, "users");
        assert_eq!(fk.referenced_column, "id");
    }

    /// A table-level `PRIMARY KEY` keeps its columns in declaration order.
    #[test]
    fn test_parse_composite_primary_key() {
        let sql = r#"
            CREATE TABLE order_items (
                order_id BIGINT NOT NULL,
                product_id BIGINT NOT NULL,
                quantity INT NOT NULL,
                PRIMARY KEY (order_id, product_id)
            );
        "#;

        let tables = parse_schema(sql).unwrap();
        let pk = tables[0].primary_key.as_ref().unwrap();
        assert!(pk.is_composite());
        assert_eq!(pk.columns, ["order_id", "product_id"]);
    }
}