Skip to main content

aegis_query/
parser.rs

1//! Aegis Parser - SQL Parser
2//!
//! Wraps sqlparser-rs to parse SQL statements and convert the supported
//! subset of them to the internal Aegis AST representation. Parsing uses
//! sqlparser's `GenericDialect`.
//!
//! Key Features:
//! - SQL parsing via sqlparser-rs
//! - Conversion to typed Aegis AST
//! - Parse error reporting (sqlparser syntax errors and unsupported constructs)
//! - Generic SQL dialect (accepts syntax common to many SQL dialects)
12//!
13//! @version 0.1.0
14//! @author AutomataNexus Development Team
15
16use crate::ast::*;
17use aegis_common::{AegisError, DataType, Result};
18use sqlparser::ast as sp;
19use sqlparser::dialect::GenericDialect;
20use sqlparser::parser::Parser as SqlParser;
21
22// =============================================================================
23// Parser
24// =============================================================================
25
26/// SQL parser wrapping sqlparser-rs.
27pub struct Parser {
28    dialect: GenericDialect,
29}
30
31impl Parser {
32    pub fn new() -> Self {
33        Self {
34            dialect: GenericDialect {},
35        }
36    }
37
38    /// Parse a SQL string into statements.
39    pub fn parse(&self, sql: &str) -> Result<Vec<Statement>> {
40        let ast = SqlParser::parse_sql(&self.dialect, sql)
41            .map_err(|e| AegisError::Parse(e.to_string()))?;
42
43        ast.into_iter()
44            .map(|stmt| self.convert_statement(stmt))
45            .collect()
46    }
47
48    /// Parse a single statement.
49    pub fn parse_single(&self, sql: &str) -> Result<Statement> {
50        let statements = self.parse(sql)?;
51        if statements.len() != 1 {
52            return Err(AegisError::Parse(format!(
53                "Expected 1 statement, got {}",
54                statements.len()
55            )));
56        }
57        // Safe to use expect here: we verified statements.len() == 1 above
58        Ok(statements
59            .into_iter()
60            .next()
61            .expect("statements verified to have exactly 1 element"))
62    }
63
64    fn convert_statement(&self, stmt: sp::Statement) -> Result<Statement> {
65        match stmt {
66            sp::Statement::Query(query) => self.convert_query_to_statement(*query),
67            sp::Statement::Insert(insert) => {
68                let insert_stmt = self.convert_insert(insert)?;
69                Ok(Statement::Insert(insert_stmt))
70            }
71            sp::Statement::Update {
72                table,
73                assignments,
74                from: _,
75                selection,
76                returning: _,
77                ..
78            } => {
79                let update_stmt = self.convert_update(table, assignments, selection)?;
80                Ok(Statement::Update(update_stmt))
81            }
82            sp::Statement::Delete(delete) => {
83                let delete_stmt = self.convert_delete(delete)?;
84                Ok(Statement::Delete(delete_stmt))
85            }
86            sp::Statement::CreateTable(create) => {
87                let create_stmt = self.convert_create_table(create)?;
88                Ok(Statement::CreateTable(create_stmt))
89            }
90            sp::Statement::Drop {
91                object_type,
92                if_exists,
93                names,
94                ..
95            } => self.convert_drop(object_type, if_exists, names),
96            sp::Statement::CreateIndex(create) => {
97                let create_stmt = self.convert_create_index(create)?;
98                Ok(Statement::CreateIndex(create_stmt))
99            }
100            sp::Statement::AlterTable {
101                name, operations, ..
102            } => {
103                let alter_stmt = self.convert_alter_table(name, operations)?;
104                Ok(Statement::AlterTable(alter_stmt))
105            }
106            sp::Statement::StartTransaction { .. } => Ok(Statement::Begin),
107            sp::Statement::Commit { .. } => Ok(Statement::Commit),
108            sp::Statement::Rollback { .. } => Ok(Statement::Rollback),
109            _ => Err(AegisError::Parse(format!(
110                "Unsupported statement type: {:?}",
111                stmt
112            ))),
113        }
114    }
115
116    /// Convert a query that may be a set operation into a Statement.
117    /// Returns Statement::SetOperation for UNION/INTERSECT/EXCEPT,
118    /// or Statement::Select for plain SELECTs.
119    fn convert_query_to_statement(&self, query: sp::Query) -> Result<Statement> {
120        if matches!(query.body.as_ref(), sp::SetExpr::SetOperation { .. }) {
121            // Destructure the set operation
122            match *query.body {
123                sp::SetExpr::SetOperation {
124                    op,
125                    set_quantifier,
126                    left,
127                    right,
128                } => {
129                    let op_type = match (op, set_quantifier) {
130                        (sp::SetOperator::Union, sp::SetQuantifier::All) => {
131                            SetOperationType::UnionAll
132                        }
133                        (sp::SetOperator::Union, _) => SetOperationType::Union,
134                        (sp::SetOperator::Intersect, _) => SetOperationType::Intersect,
135                        (sp::SetOperator::Except, _) => SetOperationType::Except,
136                    };
137                    let make_query = |body: Box<sp::SetExpr>| sp::Query {
138                        body,
139                        order_by: None,
140                        limit: None,
141                        offset: None,
142                        fetch: None,
143                        with: None,
144                        limit_by: vec![],
145                        for_clause: None,
146                        settings: None,
147                        format_clause: None,
148                        locks: vec![],
149                    };
150                    let left_stmt = self.convert_query_to_statement(make_query(left))?;
151                    let right_stmt = self.convert_query_to_statement(make_query(right))?;
152                    Ok(Statement::SetOperation(SetOperationStatement {
153                        op: op_type,
154                        left: Box::new(left_stmt),
155                        right: Box::new(right_stmt),
156                    }))
157                }
158                _ => unreachable!(),
159            }
160        } else {
161            let select = self.convert_query(query)?;
162            Ok(Statement::Select(select))
163        }
164    }
165
166    fn convert_query(&self, query: sp::Query) -> Result<SelectStatement> {
167        let body = match *query.body {
168            sp::SetExpr::Select(select) => select,
169            _ => return Err(AegisError::Parse("Unsupported query type".to_string())),
170        };
171
172        let columns = body
173            .projection
174            .into_iter()
175            .map(|item| self.convert_select_item(item))
176            .collect::<Result<Vec<_>>>()?;
177
178        let from = if !body.from.is_empty() {
179            Some(self.convert_from(&body.from)?)
180        } else {
181            None
182        };
183
184        let where_clause = body
185            .selection
186            .map(|expr| self.convert_expr(expr))
187            .transpose()?;
188
189        let group_by = match body.group_by {
190            sp::GroupByExpr::Expressions(exprs, _) => exprs
191                .into_iter()
192                .map(|e| self.convert_expr(e))
193                .collect::<Result<Vec<_>>>()?,
194            sp::GroupByExpr::All(_) => Vec::new(),
195        };
196
197        let having = body.having.map(|e| self.convert_expr(e)).transpose()?;
198
199        let order_by = query
200            .order_by
201            .map(|ob| {
202                ob.exprs
203                    .into_iter()
204                    .map(|item| self.convert_order_by_item(item))
205                    .collect::<Result<Vec<_>>>()
206            })
207            .transpose()?
208            .unwrap_or_default();
209
210        let limit = query.limit.map(|e| self.extract_limit(e)).transpose()?;
211
212        let offset = query
213            .offset
214            .map(|o| self.extract_limit(o.value))
215            .transpose()?;
216
217        Ok(SelectStatement {
218            distinct: body.distinct.is_some(),
219            columns,
220            from,
221            where_clause,
222            group_by,
223            having,
224            order_by,
225            limit,
226            offset,
227        })
228    }
229
230    fn convert_select_item(&self, item: sp::SelectItem) -> Result<SelectColumn> {
231        match item {
232            sp::SelectItem::UnnamedExpr(expr) => Ok(SelectColumn::Expression {
233                expr: self.convert_expr(expr)?,
234                alias: None,
235            }),
236            sp::SelectItem::ExprWithAlias { expr, alias } => Ok(SelectColumn::Expression {
237                expr: self.convert_expr(expr)?,
238                alias: Some(alias.value),
239            }),
240            sp::SelectItem::Wildcard(_) => Ok(SelectColumn::AllColumns),
241            sp::SelectItem::QualifiedWildcard(name, _) => {
242                Ok(SelectColumn::TableAllColumns(name.to_string()))
243            }
244        }
245    }
246
247    fn convert_from(&self, from: &[sp::TableWithJoins]) -> Result<FromClause> {
248        let first = from
249            .first()
250            .ok_or_else(|| AegisError::Parse("Empty FROM".to_string()))?;
251
252        let source = self.convert_table_factor(&first.relation)?;
253
254        let mut joins = Vec::new();
255        for join in &first.joins {
256            joins.push(self.convert_join(join)?);
257        }
258
259        Ok(FromClause { source, joins })
260    }
261
262    fn convert_table_factor(&self, factor: &sp::TableFactor) -> Result<TableReference> {
263        match factor {
264            sp::TableFactor::Table { name, alias, .. } => Ok(TableReference::Table {
265                name: name.to_string(),
266                alias: alias.as_ref().map(|a| a.name.value.clone()),
267            }),
268            sp::TableFactor::Derived {
269                subquery, alias, ..
270            } => {
271                let alias_name = alias
272                    .as_ref()
273                    .map(|a| a.name.value.clone())
274                    .ok_or_else(|| AegisError::Parse("Subquery requires alias".to_string()))?;
275                Ok(TableReference::Subquery {
276                    query: Box::new(self.convert_query(*subquery.clone())?),
277                    alias: alias_name,
278                })
279            }
280            _ => Err(AegisError::Parse("Unsupported table factor".to_string())),
281        }
282    }
283
284    fn convert_join(&self, join: &sp::Join) -> Result<JoinClause> {
285        let join_type = match &join.join_operator {
286            sp::JoinOperator::Inner(_) => JoinType::Inner,
287            sp::JoinOperator::LeftOuter(_) => JoinType::Left,
288            sp::JoinOperator::RightOuter(_) => JoinType::Right,
289            sp::JoinOperator::FullOuter(_) => JoinType::Full,
290            sp::JoinOperator::CrossJoin => JoinType::Cross,
291            _ => return Err(AegisError::Parse("Unsupported join type".to_string())),
292        };
293
294        let condition = match &join.join_operator {
295            sp::JoinOperator::Inner(c)
296            | sp::JoinOperator::LeftOuter(c)
297            | sp::JoinOperator::RightOuter(c)
298            | sp::JoinOperator::FullOuter(c) => match c {
299                sp::JoinConstraint::On(expr) => Some(self.convert_expr(expr.clone())?),
300                sp::JoinConstraint::None => None,
301                _ => return Err(AegisError::Parse("Unsupported join constraint".to_string())),
302            },
303            sp::JoinOperator::CrossJoin => None,
304            _ => None,
305        };
306
307        Ok(JoinClause {
308            join_type,
309            table: self.convert_table_factor(&join.relation)?,
310            condition,
311        })
312    }
313
314    fn convert_order_by_item(&self, item: sp::OrderByExpr) -> Result<OrderByItem> {
315        Ok(OrderByItem {
316            expression: self.convert_expr(item.expr)?,
317            ascending: item.asc.unwrap_or(true),
318            nulls_first: item.nulls_first,
319        })
320    }
321
322    fn convert_expr(&self, expr: sp::Expr) -> Result<Expression> {
323        match expr {
324            sp::Expr::Identifier(ident) => Ok(Expression::Column(ColumnRef {
325                table: None,
326                column: ident.value,
327            })),
328            sp::Expr::CompoundIdentifier(idents) => {
329                if idents.len() == 2 {
330                    Ok(Expression::Column(ColumnRef {
331                        table: Some(idents[0].value.clone()),
332                        column: idents[1].value.clone(),
333                    }))
334                } else {
335                    Ok(Expression::Column(ColumnRef {
336                        table: None,
337                        column: idents.last().map(|i| i.value.clone()).unwrap_or_default(),
338                    }))
339                }
340            }
341            sp::Expr::Value(value) => self.convert_value(value),
342            sp::Expr::BinaryOp { left, op, right } => Ok(Expression::BinaryOp {
343                left: Box::new(self.convert_expr(*left)?),
344                op: self.convert_binary_op(op)?,
345                right: Box::new(self.convert_expr(*right)?),
346            }),
347            sp::Expr::UnaryOp { op, expr } => Ok(Expression::UnaryOp {
348                op: self.convert_unary_op(op)?,
349                expr: Box::new(self.convert_expr(*expr)?),
350            }),
351            sp::Expr::Function(func) => {
352                let args = match func.args {
353                    sp::FunctionArguments::List(list) => list
354                        .args
355                        .into_iter()
356                        .filter_map(|arg| match arg {
357                            sp::FunctionArg::Unnamed(sp::FunctionArgExpr::Expr(e)) => {
358                                Some(self.convert_expr(e))
359                            }
360                            _ => None,
361                        })
362                        .collect::<Result<Vec<_>>>()?,
363                    _ => Vec::new(),
364                };
365
366                Ok(Expression::Function {
367                    name: func.name.to_string(),
368                    args,
369                    distinct: false,
370                })
371            }
372            sp::Expr::Nested(expr) => self.convert_expr(*expr),
373            sp::Expr::IsNull(expr) => Ok(Expression::IsNull {
374                expr: Box::new(self.convert_expr(*expr)?),
375                negated: false,
376            }),
377            sp::Expr::IsNotNull(expr) => Ok(Expression::IsNull {
378                expr: Box::new(self.convert_expr(*expr)?),
379                negated: true,
380            }),
381            sp::Expr::InList {
382                expr,
383                list,
384                negated,
385            } => Ok(Expression::InList {
386                expr: Box::new(self.convert_expr(*expr)?),
387                list: list
388                    .into_iter()
389                    .map(|e| self.convert_expr(e))
390                    .collect::<Result<Vec<_>>>()?,
391                negated,
392            }),
393            sp::Expr::Between {
394                expr,
395                low,
396                high,
397                negated,
398            } => Ok(Expression::Between {
399                expr: Box::new(self.convert_expr(*expr)?),
400                low: Box::new(self.convert_expr(*low)?),
401                high: Box::new(self.convert_expr(*high)?),
402                negated,
403            }),
404            sp::Expr::Like {
405                expr,
406                pattern,
407                negated,
408                ..
409            } => Ok(Expression::Like {
410                expr: Box::new(self.convert_expr(*expr)?),
411                pattern: Box::new(self.convert_expr(*pattern)?),
412                negated,
413            }),
414            _ => Err(AegisError::Parse(format!(
415                "Unsupported expression: {:?}",
416                expr
417            ))),
418        }
419    }
420
421    fn convert_value(&self, value: sp::Value) -> Result<Expression> {
422        let literal = match value {
423            sp::Value::Null => Literal::Null,
424            sp::Value::Boolean(b) => Literal::Boolean(b),
425            sp::Value::Number(n, _) => {
426                if n.contains('.') {
427                    Literal::Float(
428                        n.parse()
429                            .map_err(|_| AegisError::Parse("Invalid float".to_string()))?,
430                    )
431                } else {
432                    Literal::Integer(
433                        n.parse()
434                            .map_err(|_| AegisError::Parse("Invalid integer".to_string()))?,
435                    )
436                }
437            }
438            sp::Value::SingleQuotedString(s) | sp::Value::DoubleQuotedString(s) => {
439                Literal::String(s)
440            }
441            sp::Value::Placeholder(s) if s.starts_with('$') => {
442                let idx: usize = s[1..]
443                    .parse()
444                    .map_err(|_| AegisError::Parse(format!("Invalid placeholder: {}", s)))?;
445                return Ok(Expression::Placeholder(idx));
446            }
447            _ => return Err(AegisError::Parse("Unsupported literal value".to_string())),
448        };
449        Ok(Expression::Literal(literal))
450    }
451
452    fn convert_binary_op(&self, op: sp::BinaryOperator) -> Result<BinaryOperator> {
453        match op {
454            sp::BinaryOperator::Plus => Ok(BinaryOperator::Add),
455            sp::BinaryOperator::Minus => Ok(BinaryOperator::Subtract),
456            sp::BinaryOperator::Multiply => Ok(BinaryOperator::Multiply),
457            sp::BinaryOperator::Divide => Ok(BinaryOperator::Divide),
458            sp::BinaryOperator::Modulo => Ok(BinaryOperator::Modulo),
459            sp::BinaryOperator::Eq => Ok(BinaryOperator::Equal),
460            sp::BinaryOperator::NotEq => Ok(BinaryOperator::NotEqual),
461            sp::BinaryOperator::Lt => Ok(BinaryOperator::LessThan),
462            sp::BinaryOperator::LtEq => Ok(BinaryOperator::LessThanOrEqual),
463            sp::BinaryOperator::Gt => Ok(BinaryOperator::GreaterThan),
464            sp::BinaryOperator::GtEq => Ok(BinaryOperator::GreaterThanOrEqual),
465            sp::BinaryOperator::And => Ok(BinaryOperator::And),
466            sp::BinaryOperator::Or => Ok(BinaryOperator::Or),
467            sp::BinaryOperator::StringConcat => Ok(BinaryOperator::Concat),
468            _ => Err(AegisError::Parse(format!("Unsupported operator: {:?}", op))),
469        }
470    }
471
472    fn convert_unary_op(&self, op: sp::UnaryOperator) -> Result<UnaryOperator> {
473        match op {
474            sp::UnaryOperator::Not => Ok(UnaryOperator::Not),
475            sp::UnaryOperator::Minus => Ok(UnaryOperator::Negative),
476            sp::UnaryOperator::Plus => Ok(UnaryOperator::Positive),
477            _ => Err(AegisError::Parse(format!(
478                "Unsupported unary operator: {:?}",
479                op
480            ))),
481        }
482    }
483
484    fn convert_insert(&self, insert: sp::Insert) -> Result<InsertStatement> {
485        let table = insert.table_name.to_string();
486
487        let columns = if insert.columns.is_empty() {
488            None
489        } else {
490            Some(insert.columns.into_iter().map(|c| c.value).collect())
491        };
492
493        let source = match insert.source {
494            Some(query) => match *query.body {
495                sp::SetExpr::Values(values) => {
496                    let rows = values
497                        .rows
498                        .into_iter()
499                        .map(|row| {
500                            row.into_iter()
501                                .map(|e| self.convert_expr(e))
502                                .collect::<Result<Vec<_>>>()
503                        })
504                        .collect::<Result<Vec<_>>>()?;
505                    InsertSource::Values(rows)
506                }
507                _ => InsertSource::Query(Box::new(self.convert_query(*query)?)),
508            },
509            None => return Err(AegisError::Parse("INSERT missing values".to_string())),
510        };
511
512        Ok(InsertStatement {
513            table,
514            columns,
515            source,
516        })
517    }
518
519    fn convert_update(
520        &self,
521        table: sp::TableWithJoins,
522        assignments: Vec<sp::Assignment>,
523        selection: Option<sp::Expr>,
524    ) -> Result<UpdateStatement> {
525        let table_name = match &table.relation {
526            sp::TableFactor::Table { name, .. } => name.to_string(),
527            _ => return Err(AegisError::Parse("UPDATE requires table name".to_string())),
528        };
529
530        let assigns = assignments
531            .into_iter()
532            .map(|a| {
533                let column = match a.target {
534                    sp::AssignmentTarget::ColumnName(names) => names
535                        .0
536                        .into_iter()
537                        .map(|i| i.value)
538                        .collect::<Vec<_>>()
539                        .join("."),
540                    sp::AssignmentTarget::Tuple(cols) => cols
541                        .into_iter()
542                        .map(|c| c.to_string())
543                        .collect::<Vec<_>>()
544                        .join(", "),
545                };
546                Ok(Assignment {
547                    column,
548                    value: self.convert_expr(a.value)?,
549                })
550            })
551            .collect::<Result<Vec<_>>>()?;
552
553        let where_clause = selection.map(|e| self.convert_expr(e)).transpose()?;
554
555        Ok(UpdateStatement {
556            table: table_name,
557            assignments: assigns,
558            where_clause,
559        })
560    }
561
562    fn convert_delete(&self, delete: sp::Delete) -> Result<DeleteStatement> {
563        let table = match delete.from {
564            sp::FromTable::WithFromKeyword(tables) => tables
565                .first()
566                .map(|t| match &t.relation {
567                    sp::TableFactor::Table { name, .. } => name.to_string(),
568                    _ => String::new(),
569                })
570                .ok_or_else(|| AegisError::Parse("DELETE missing table".to_string()))?,
571            sp::FromTable::WithoutKeyword(tables) => tables
572                .first()
573                .map(|t| match &t.relation {
574                    sp::TableFactor::Table { name, .. } => name.to_string(),
575                    _ => String::new(),
576                })
577                .ok_or_else(|| AegisError::Parse("DELETE missing table".to_string()))?,
578        };
579
580        let where_clause = delete.selection.map(|e| self.convert_expr(e)).transpose()?;
581
582        Ok(DeleteStatement {
583            table,
584            where_clause,
585        })
586    }
587
588    fn convert_create_table(&self, create: sp::CreateTable) -> Result<CreateTableStatement> {
589        let columns = create
590            .columns
591            .into_iter()
592            .map(|col| {
593                Ok(ColumnDefinition {
594                    name: col.name.value,
595                    data_type: self.convert_data_type(&col.data_type)?,
596                    nullable: !col
597                        .options
598                        .iter()
599                        .any(|o| matches!(o.option, sp::ColumnOption::NotNull)),
600                    default: col
601                        .options
602                        .iter()
603                        .find_map(|o| match &o.option {
604                            sp::ColumnOption::Default(e) => Some(self.convert_expr(e.clone())),
605                            _ => None,
606                        })
607                        .transpose()?,
608                    constraints: Vec::new(),
609                })
610            })
611            .collect::<Result<Vec<_>>>()?;
612
613        Ok(CreateTableStatement {
614            name: create.name.to_string(),
615            columns,
616            constraints: Vec::new(),
617            if_not_exists: create.if_not_exists,
618        })
619    }
620
621    fn convert_drop(
622        &self,
623        object_type: sp::ObjectType,
624        if_exists: bool,
625        names: Vec<sp::ObjectName>,
626    ) -> Result<Statement> {
627        let name = names
628            .first()
629            .map(|n| n.to_string())
630            .ok_or_else(|| AegisError::Parse("DROP missing name".to_string()))?;
631
632        match object_type {
633            sp::ObjectType::Table => {
634                Ok(Statement::DropTable(DropTableStatement { name, if_exists }))
635            }
636            sp::ObjectType::Index => {
637                Ok(Statement::DropIndex(DropIndexStatement { name, if_exists }))
638            }
639            _ => Err(AegisError::Parse(format!(
640                "Unsupported DROP object type: {:?}",
641                object_type
642            ))),
643        }
644    }
645
646    fn convert_create_index(&self, create: sp::CreateIndex) -> Result<CreateIndexStatement> {
647        let name = create
648            .name
649            .map(|n| n.to_string())
650            .ok_or_else(|| AegisError::Parse("CREATE INDEX missing name".to_string()))?;
651
652        let table = create.table_name.to_string();
653
654        let columns = create
655            .columns
656            .into_iter()
657            .map(|c| c.expr.to_string())
658            .collect();
659
660        Ok(CreateIndexStatement {
661            name,
662            table,
663            columns,
664            unique: create.unique,
665            if_not_exists: create.if_not_exists,
666        })
667    }
668
669    fn convert_alter_table(
670        &self,
671        name: sp::ObjectName,
672        operations: Vec<sp::AlterTableOperation>,
673    ) -> Result<AlterTableStatement> {
674        let ops = operations
675            .into_iter()
676            .map(|op| self.convert_alter_operation(op))
677            .collect::<Result<Vec<_>>>()?;
678
679        Ok(AlterTableStatement {
680            name: name.to_string(),
681            operations: ops,
682        })
683    }
684
685    fn convert_alter_operation(&self, op: sp::AlterTableOperation) -> Result<AlterTableOperation> {
686        match op {
687            sp::AlterTableOperation::AddColumn { column_def, .. } => {
688                let col = ColumnDefinition {
689                    name: column_def.name.value.clone(),
690                    data_type: self.convert_data_type(&column_def.data_type)?,
691                    nullable: !column_def
692                        .options
693                        .iter()
694                        .any(|o| matches!(o.option, sp::ColumnOption::NotNull)),
695                    default: column_def
696                        .options
697                        .iter()
698                        .find_map(|o| match &o.option {
699                            sp::ColumnOption::Default(e) => Some(self.convert_expr(e.clone())),
700                            _ => None,
701                        })
702                        .transpose()?,
703                    constraints: Vec::new(),
704                };
705                Ok(AlterTableOperation::AddColumn(col))
706            }
707            sp::AlterTableOperation::DropColumn {
708                column_name,
709                if_exists,
710                ..
711            } => Ok(AlterTableOperation::DropColumn {
712                name: column_name.value,
713                if_exists,
714            }),
715            sp::AlterTableOperation::RenameColumn {
716                old_column_name,
717                new_column_name,
718            } => Ok(AlterTableOperation::RenameColumn {
719                old_name: old_column_name.value,
720                new_name: new_column_name.value,
721            }),
722            sp::AlterTableOperation::RenameTable { table_name } => {
723                Ok(AlterTableOperation::RenameTable {
724                    new_name: table_name.to_string(),
725                })
726            }
727            sp::AlterTableOperation::AlterColumn { column_name, op } => match op {
728                sp::AlterColumnOperation::SetDataType { data_type, .. } => {
729                    Ok(AlterTableOperation::AlterColumn {
730                        name: column_name.value,
731                        data_type: Some(self.convert_data_type(&data_type)?),
732                        set_not_null: None,
733                        set_default: None,
734                    })
735                }
736                sp::AlterColumnOperation::SetNotNull => Ok(AlterTableOperation::AlterColumn {
737                    name: column_name.value,
738                    data_type: None,
739                    set_not_null: Some(true),
740                    set_default: None,
741                }),
742                sp::AlterColumnOperation::DropNotNull => Ok(AlterTableOperation::AlterColumn {
743                    name: column_name.value,
744                    data_type: None,
745                    set_not_null: Some(false),
746                    set_default: None,
747                }),
748                sp::AlterColumnOperation::SetDefault { value } => {
749                    Ok(AlterTableOperation::AlterColumn {
750                        name: column_name.value,
751                        data_type: None,
752                        set_not_null: None,
753                        set_default: Some(Some(self.convert_expr(value)?)),
754                    })
755                }
756                sp::AlterColumnOperation::DropDefault => Ok(AlterTableOperation::AlterColumn {
757                    name: column_name.value,
758                    data_type: None,
759                    set_not_null: None,
760                    set_default: Some(None),
761                }),
762                _ => Err(AegisError::Parse(format!(
763                    "Unsupported ALTER COLUMN operation: {:?}",
764                    op
765                ))),
766            },
767            sp::AlterTableOperation::DropConstraint { name, .. } => {
768                Ok(AlterTableOperation::DropConstraint { name: name.value })
769            }
770            _ => Err(AegisError::Parse(format!(
771                "Unsupported ALTER TABLE operation: {:?}",
772                op
773            ))),
774        }
775    }
776
777    fn convert_data_type(&self, dt: &sp::DataType) -> Result<DataType> {
778        match dt {
779            sp::DataType::Boolean => Ok(DataType::Boolean),
780            sp::DataType::TinyInt(_) => Ok(DataType::TinyInt),
781            sp::DataType::SmallInt(_) => Ok(DataType::SmallInt),
782            sp::DataType::Int(_) | sp::DataType::Integer(_) => Ok(DataType::Integer),
783            sp::DataType::BigInt(_) => Ok(DataType::BigInt),
784            sp::DataType::Real => Ok(DataType::Float),
785            sp::DataType::Float(_) | sp::DataType::Double | sp::DataType::DoublePrecision => {
786                Ok(DataType::Double)
787            }
788            sp::DataType::Decimal(info) | sp::DataType::Numeric(info) => {
789                let (precision, scale) = match info {
790                    sp::ExactNumberInfo::PrecisionAndScale(p, s) => (*p as u8, *s as u8),
791                    sp::ExactNumberInfo::Precision(p) => (*p as u8, 0),
792                    sp::ExactNumberInfo::None => (18, 0),
793                };
794                Ok(DataType::Decimal { precision, scale })
795            }
796            sp::DataType::Char(len) => {
797                let size = len
798                    .as_ref()
799                    .and_then(|l| match l {
800                        sp::CharacterLength::IntegerLength { length, .. } => Some(*length as u16),
801                        sp::CharacterLength::Max => None,
802                    })
803                    .unwrap_or(1);
804                Ok(DataType::Char(size))
805            }
806            sp::DataType::Varchar(len) => {
807                let size = len
808                    .as_ref()
809                    .and_then(|l| match l {
810                        sp::CharacterLength::IntegerLength { length, .. } => Some(*length as u16),
811                        sp::CharacterLength::Max => None,
812                    })
813                    .unwrap_or(255);
814                Ok(DataType::Varchar(size))
815            }
816            sp::DataType::Text => Ok(DataType::Text),
817            sp::DataType::Blob(_) => Ok(DataType::Blob),
818            sp::DataType::Date => Ok(DataType::Date),
819            sp::DataType::Time(..) => Ok(DataType::Time),
820            sp::DataType::Timestamp(..) => Ok(DataType::Timestamp),
821            sp::DataType::JSON => Ok(DataType::Json),
822            sp::DataType::Uuid => Ok(DataType::Uuid),
823            _ => Err(AegisError::Parse(format!(
824                "Unsupported data type: {:?}",
825                dt
826            ))),
827        }
828    }
829
830    fn extract_limit(&self, expr: sp::Expr) -> Result<u64> {
831        match expr {
832            sp::Expr::Value(sp::Value::Number(n, _)) => n
833                .parse()
834                .map_err(|_| AegisError::Parse("Invalid LIMIT value".to_string())),
835            _ => Err(AegisError::Parse("LIMIT must be a number".to_string())),
836        }
837    }
838}
839
840impl Default for Parser {
841    fn default() -> Self {
842        Self::new()
843    }
844}
845
846// =============================================================================
847// Tests
848// =============================================================================
849
#[cfg(test)]
mod tests {
    use super::*;

    /// Parse `sql` as exactly one statement, panicking if parsing fails.
    fn parse_one(sql: &str) -> Statement {
        Parser::new().parse_single(sql).unwrap()
    }

    /// Assert that `sql` parses to a set operation whose operator is `expected`.
    fn assert_set_op(sql: &str, expected: SetOperationType) {
        let Statement::SetOperation(set_op) = parse_one(sql) else {
            panic!("Expected SetOperation statement");
        };
        assert_eq!(set_op.op, expected);
    }

    #[test]
    fn test_parse_simple_select() {
        let Statement::Select(select) = parse_one("SELECT id, name FROM users") else {
            panic!("Expected SELECT statement");
        };
        assert_eq!(select.columns.len(), 2);
        assert!(select.from.is_some());
    }

    #[test]
    fn test_parse_select_with_where() {
        let Statement::Select(select) = parse_one("SELECT * FROM users WHERE age > 18") else {
            panic!("Expected SELECT statement");
        };
        assert!(select.where_clause.is_some());
    }

    #[test]
    fn test_parse_select_with_join() {
        let sql = "SELECT u.name, o.total FROM users u JOIN orders o ON u.id = o.user_id";
        let Statement::Select(select) = parse_one(sql) else {
            panic!("Expected SELECT statement");
        };
        let from = select.from.unwrap();
        let joins = &from.joins;
        assert_eq!(joins.len(), 1);
        assert_eq!(joins[0].join_type, JoinType::Inner);
    }

    #[test]
    fn test_parse_insert() {
        let sql = "INSERT INTO users (name, age) VALUES ('Alice', 25)";
        let Statement::Insert(insert) = parse_one(sql) else {
            panic!("Expected INSERT statement");
        };
        let expected_columns = vec!["name".to_string(), "age".to_string()];
        assert_eq!(insert.table, "users");
        assert_eq!(insert.columns, Some(expected_columns));
    }

    #[test]
    fn test_parse_update() {
        let sql = "UPDATE users SET age = 26 WHERE name = 'Alice'";
        let Statement::Update(update) = parse_one(sql) else {
            panic!("Expected UPDATE statement");
        };
        assert_eq!(update.table, "users");
        assert_eq!(update.assignments.len(), 1);
        assert!(update.where_clause.is_some());
    }

    #[test]
    fn test_parse_delete() {
        let Statement::Delete(delete) = parse_one("DELETE FROM users WHERE age < 18") else {
            panic!("Expected DELETE statement");
        };
        assert_eq!(delete.table, "users");
        assert!(delete.where_clause.is_some());
    }

    #[test]
    fn test_parse_create_table() {
        let sql = "CREATE TABLE users (
                    id INTEGER NOT NULL,
                    name VARCHAR(255),
                    age INTEGER
                )";
        let Statement::CreateTable(create) = parse_one(sql) else {
            panic!("Expected CREATE TABLE statement");
        };
        assert_eq!(create.name, "users");
        assert_eq!(create.columns.len(), 3);
        // `NOT NULL` on the first column only; the others default to nullable.
        assert!(!create.columns[0].nullable);
        assert!(create.columns[1].nullable);
    }

    #[test]
    fn test_parse_transaction_statements() {
        assert!(matches!(parse_one("BEGIN"), Statement::Begin));
        assert!(matches!(parse_one("COMMIT"), Statement::Commit));
        assert!(matches!(parse_one("ROLLBACK"), Statement::Rollback));
    }

    #[test]
    fn test_parse_union() {
        let sql = "SELECT id FROM users UNION SELECT id FROM orders";
        let Statement::SetOperation(set_op) = parse_one(sql) else {
            panic!("Expected SetOperation statement");
        };
        assert_eq!(set_op.op, SetOperationType::Union);
        assert!(matches!(*set_op.left, Statement::Select(_)));
        assert!(matches!(*set_op.right, Statement::Select(_)));
    }

    #[test]
    fn test_parse_union_all() {
        assert_set_op(
            "SELECT id FROM users UNION ALL SELECT id FROM orders",
            SetOperationType::UnionAll,
        );
    }

    #[test]
    fn test_parse_intersect() {
        assert_set_op(
            "SELECT id FROM users INTERSECT SELECT id FROM orders",
            SetOperationType::Intersect,
        );
    }

    #[test]
    fn test_parse_except() {
        assert_set_op(
            "SELECT id FROM users EXCEPT SELECT id FROM orders",
            SetOperationType::Except,
        );
    }
}
1055}