// rook_parser/ast.rs

//! Abstract Syntax Tree (AST) structures for SQL statements.
2use std::fmt;
3
4/// Represents a complete SQL statement
5#[derive(Debug, Clone, PartialEq)]
6pub enum Statement {
7    Select(SelectStatement),
8    Insert(InsertStatement),
9    Update(UpdateStatement),
10    Delete(DeleteStatement),
11    Create(CreateStatement),
12    Drop(DropStatement),
13    Alter(AlterStatement),
14}
15
16impl fmt::Display for Statement {
17    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
18        match self {
19            Statement::Select(s) => write!(f, "SELECT: {}", s),
20            Statement::Insert(s) => write!(f, "INSERT: {}", s),
21            Statement::Update(s) => write!(f, "UPDATE: {}", s),
22            Statement::Delete(s) => write!(f, "DELETE: {}", s),
23            Statement::Create(s) => write!(f, "CREATE: {}", s),
24            Statement::Drop(s) => write!(f, "DROP: {}", s),
25            Statement::Alter(s) => write!(f, "ALTER: {}", s),
26        }
27    }
28}
29
30/// SELECT statement
31#[derive(Debug, Clone, PartialEq)]
32pub struct SelectStatement {
33    pub distinct: bool,
34    pub select_list: Vec<SelectItem>,
35    pub from_clause: Option<FromClause>,
36    pub where_clause: Option<Expression>,
37    pub group_by_clause: Option<Vec<Expression>>,
38    pub having_clause: Option<Expression>,
39    pub order_by_clause: Option<Vec<OrderByItem>>,
40    pub limit_clause: Option<LimitClause>,
41}
42
43impl fmt::Display for SelectStatement {
44    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
45        write!(f, "SELECT{}", if self.distinct { " DISTINCT" } else { "" })?;
46        write!(f, " [{}]", self.select_list.iter()
47            .map(|s| s.to_string())
48            .collect::<Vec<_>>()
49            .join(", "))?;
50        
51        if let Some(from) = &self.from_clause {
52            write!(f, " FROM {}", from)?;
53        }
54        if let Some(where_expr) = &self.where_clause {
55            write!(f, " WHERE {}", where_expr)?;
56        }
57        if let Some(group_by) = &self.group_by_clause {
58            write!(f, " GROUP BY [{}]", group_by.iter()
59                .map(|e| e.to_string())
60                .collect::<Vec<_>>()
61                .join(", "))?;
62        }
63        if let Some(having) = &self.having_clause {
64            write!(f, " HAVING {}", having)?;
65        }
66        if let Some(order_by) = &self.order_by_clause {
67            write!(f, " ORDER BY [{}]", order_by.iter()
68                .map(|o| o.to_string())
69                .collect::<Vec<_>>()
70                .join(", "))?;
71        }
72        if let Some(limit) = &self.limit_clause {
73            write!(f, " {}", limit)?;
74        }
75        Ok(())
76    }
77}
78
79/// SELECT list item
80#[derive(Debug, Clone, PartialEq)]
81pub enum SelectItem {
82    /// All columns (*)
83    AllColumns,
84    /// Column name with optional alias
85    Column {
86        expr: Expression,
87        alias: Option<String>,
88    },
89}
90
91impl fmt::Display for SelectItem {
92    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
93        match self {
94            SelectItem::AllColumns => write!(f, "*"),
95            SelectItem::Column { expr, alias } => {
96                if let Some(a) = alias {
97                    write!(f, "{} AS {}", expr, a)
98                } else {
99                    write!(f, "{}", expr)
100                }
101            }
102        }
103    }
104}
105
106/// FROM clause with table and joins
107#[derive(Debug, Clone, PartialEq)]
108pub struct FromClause {
109    pub table: TableReference,
110    pub joins: Vec<Join>,
111}
112
113impl fmt::Display for FromClause {
114    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
115        write!(f, "{}", self.table)?;
116        for join in &self.joins {
117            write!(f, " {}", join)?;
118        }
119        Ok(())
120    }
121}
122
/// A table name with an optional alias.
#[derive(Debug, Clone, PartialEq)]
pub struct TableReference {
    /// Table name.
    pub name: String,
    /// Optional alias, rendered after the name without an `AS` keyword.
    pub alias: Option<String>,
}

impl fmt::Display for TableReference {
    /// Renders `name` or `name alias`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(&self.name)?;
        if let Some(alias) = self.alias.as_deref() {
            write!(f, " {}", alias)?;
        }
        Ok(())
    }
}
139
/// The kind of a join: the keyword that precedes `JOIN`.
#[derive(Debug, Clone, PartialEq)]
pub enum JoinType {
    Inner,
    Left,
    Right,
    Full,
    Cross,
}

impl fmt::Display for JoinType {
    /// Renders the upper-case keyword, e.g. `INNER`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let keyword = match self {
            JoinType::Inner => "INNER",
            JoinType::Left => "LEFT",
            JoinType::Right => "RIGHT",
            JoinType::Full => "FULL",
            JoinType::Cross => "CROSS",
        };
        f.write_str(keyword)
    }
}
161
162#[derive(Debug, Clone, PartialEq)]
163pub struct Join {
164    pub join_type: JoinType,
165    pub table: TableReference,
166    pub on_condition: Option<Expression>,
167}
168
169impl fmt::Display for Join {
170    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
171        write!(f, "{} JOIN {}", self.join_type, self.table)?;
172        if let Some(cond) = &self.on_condition {
173            write!(f, " ON {}", cond)?;
174        }
175        Ok(())
176    }
177}
178
179/// ORDER BY item
180#[derive(Debug, Clone, PartialEq)]
181pub struct OrderByItem {
182    pub expr: Expression,
183    pub direction: SortDirection,
184}
185
186impl fmt::Display for OrderByItem {
187    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
188        write!(f, "{} {}", self.expr, self.direction)
189    }
190}
191
/// Sort direction of an `ORDER BY` item.
#[derive(Debug, Clone, PartialEq)]
pub enum SortDirection {
    Asc,
    Desc,
}

impl fmt::Display for SortDirection {
    /// Renders `ASC` or `DESC`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(match self {
            SortDirection::Asc => "ASC",
            SortDirection::Desc => "DESC",
        })
    }
}
206
/// `LIMIT n [OFFSET m]`.
#[derive(Debug, Clone, PartialEq)]
pub struct LimitClause {
    /// Row limit.
    pub limit: i64,
    /// Optional row offset.
    pub offset: Option<i64>,
}

impl fmt::Display for LimitClause {
    /// Renders `LIMIT n` or `LIMIT n OFFSET m`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self.offset {
            Some(offset) => write!(f, "LIMIT {} OFFSET {}", self.limit, offset),
            None => write!(f, "LIMIT {}", self.limit),
        }
    }
}
223
224/// Expressions
225#[derive(Debug, Clone, PartialEq)]
226pub enum Expression {
227    // Literals
228    Number(String),
229    String(String),
230    
231    // Compound expressions
232    BinaryOp {
233        left: Box<Expression>,
234        op: BinaryOperator,
235        right: Box<Expression>,
236    },
237    UnaryOp {
238        op: UnaryOperator,
239        expr: Box<Expression>,
240    },
241    FunctionCall {
242        name: String,
243        args: Vec<Expression>,
244    },
245    Column {
246        table: Option<String>,
247        name: String,
248    },
249    /// (expression)
250    Parenthesized(Box<Expression>),
251    
252    // Special
253    Star,
254    Null,
255}
256
257impl fmt::Display for Expression {
258    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
259        match self {
260            Expression::Number(n) => write!(f, "{}", n),
261            Expression::String(s) => write!(f, "'{}'", s),
262            Expression::Star => write!(f, "*"),
263            Expression::Null => write!(f, "NULL"),
264            Expression::BinaryOp { left, op, right } => {
265                write!(f, "({} {} {})", left, op, right)
266            }
267            Expression::UnaryOp { op, expr } => {
268                write!(f, "{} {}", op, expr)
269            }
270            Expression::FunctionCall { name, args } => {
271                write!(f, "{}({})", name, args.iter()
272                    .map(|a| a.to_string())
273                    .collect::<Vec<_>>()
274                    .join(", "))
275            }
276            Expression::Column { table, name } => {
277                if let Some(t) = table {
278                    write!(f, "{}.{}", t, name)
279                } else {
280                    write!(f, "{}", name)
281                }
282            }
283            Expression::Parenthesized(expr) => {
284                write!(f, "({})", expr)
285            }
286        }
287    }
288}
289
/// Operators that combine two expressions.
#[derive(Debug, Clone, PartialEq)]
pub enum BinaryOperator {
    // Comparison
    Equal,
    NotEqual,
    LessThan,
    LessThanOrEqual,
    GreaterThan,
    GreaterThanOrEqual,

    // Logical
    And,
    Or,

    // Arithmetic
    Plus,
    Minus,
    Multiply,
    Divide,
    Modulo,

    // Pattern matching and predicates
    Like,
    In,
    Between,
    Is,
}

impl fmt::Display for BinaryOperator {
    /// Renders the operator's SQL symbol or keyword (`!=` for `NotEqual`).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let text = match self {
            BinaryOperator::Equal => "=",
            BinaryOperator::NotEqual => "!=",
            BinaryOperator::LessThan => "<",
            BinaryOperator::LessThanOrEqual => "<=",
            BinaryOperator::GreaterThan => ">",
            BinaryOperator::GreaterThanOrEqual => ">=",
            BinaryOperator::And => "AND",
            BinaryOperator::Or => "OR",
            BinaryOperator::Plus => "+",
            BinaryOperator::Minus => "-",
            BinaryOperator::Multiply => "*",
            BinaryOperator::Divide => "/",
            BinaryOperator::Modulo => "%",
            BinaryOperator::Like => "LIKE",
            BinaryOperator::In => "IN",
            BinaryOperator::Between => "BETWEEN",
            BinaryOperator::Is => "IS",
        };
        f.write_str(text)
    }
}
341
/// Operators that apply to a single expression.
#[derive(Debug, Clone, PartialEq)]
pub enum UnaryOperator {
    Not,
    Minus,
    Plus,
}

impl fmt::Display for UnaryOperator {
    /// Renders `NOT`, `-`, or `+`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(match self {
            UnaryOperator::Not => "NOT",
            UnaryOperator::Minus => "-",
            UnaryOperator::Plus => "+",
        })
    }
}
358
359/// INSERT statement
360#[derive(Debug, Clone, PartialEq)]
361pub struct InsertStatement {
362    pub table: String,
363    pub columns: Option<Vec<String>>,
364    pub values: Vec<Vec<Expression>>,
365}
366
367impl fmt::Display for InsertStatement {
368    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
369        write!(f, "INTO {} ", self.table)?;
370        if let Some(cols) = &self.columns {
371            write!(f, "({})", cols.join(", "))?;
372        }
373        write!(f, " VALUES")?;
374        for (idx, row) in self.values.iter().enumerate() {
375            if idx > 0 {
376                write!(f, ",")?;
377            }
378            write!(f, " ({})", row.iter()
379                .map(|v| v.to_string())
380                .collect::<Vec<_>>()
381                .join(", "))?;
382        }
383        Ok(())
384    }
385}
386
387/// UPDATE statement
388#[derive(Debug, Clone, PartialEq)]
389pub struct UpdateStatement {
390    pub table: String,
391    pub assignments: Vec<(String, Expression)>,
392    pub where_clause: Option<Expression>,
393}
394
395impl fmt::Display for UpdateStatement {
396    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
397        write!(f, "{} SET ", self.table)?;
398        write!(f, "{}", self.assignments.iter()
399            .map(|(col, expr)| format!("{} = {}", col, expr))
400            .collect::<Vec<_>>()
401            .join(", "))?;
402        if let Some(where_expr) = &self.where_clause {
403            write!(f, " WHERE {}", where_expr)?;
404        }
405        Ok(())
406    }
407}
408
409/// DELETE statement
410#[derive(Debug, Clone, PartialEq)]
411pub struct DeleteStatement {
412    pub table: String,
413    pub where_clause: Option<Expression>,
414}
415
416impl fmt::Display for DeleteStatement {
417    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
418        write!(f, "FROM {} ", self.table)?;
419        if let Some(where_expr) = &self.where_clause {
420            write!(f, "WHERE {}", where_expr)?;
421        }
422        Ok(())
423    }
424}
425
426/// CREATE TABLE statement
427#[derive(Debug, Clone, PartialEq)]
428pub struct CreateStatement {
429    pub name: String,
430    pub columns: Vec<ColumnDefinition>,
431}
432
433impl fmt::Display for CreateStatement {
434    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
435        write!(f, "TABLE {} (", self.name)?;
436        write!(f, "{})", self.columns.iter()
437            .map(|c| c.to_string())
438            .collect::<Vec<_>>()
439            .join(", "))?;
440        Ok(())
441    }
442}
443
/// A column in `CREATE TABLE` / `ALTER TABLE ... ADD COLUMN`.
#[derive(Debug, Clone, PartialEq)]
pub struct ColumnDefinition {
    /// Column name.
    pub name: String,
    /// Data type, kept as its textual form.
    pub data_type: String,
    /// Constraint fragments, rendered space-separated after the type.
    pub constraints: Vec<String>,
}

impl fmt::Display for ColumnDefinition {
    /// Renders `name type` followed by each constraint, each preceded by a space.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{} {}", self.name, self.data_type)?;
        self.constraints
            .iter()
            .try_for_each(|constraint| write!(f, " {}", constraint))
    }
}
460
/// `DROP <object> [IF EXISTS] name` statement.
#[derive(Debug, Clone, PartialEq)]
pub struct DropStatement {
    /// Kind of object being dropped, e.g. `TABLE` or `DATABASE`.
    pub object_type: String,
    /// Name of the object to drop.
    pub name: String,
    /// Whether `IF EXISTS` was specified.
    pub if_exists: bool,
}

impl fmt::Display for DropStatement {
    /// Renders `OBJECT name` or `OBJECT IF EXISTS name`.
    ///
    /// The `IF EXISTS` part carries its own leading space. When it is absent
    /// the output has a single space, not two (previously `TABLE  users`).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(&self.object_type)?;
        if self.if_exists {
            f.write_str(" IF EXISTS")?;
        }
        write!(f, " {}", self.name)
    }
}
478
479/// ALTER statement
480#[derive(Debug, Clone, PartialEq)]
481pub struct AlterStatement {
482    pub table: String,
483    pub action: AlterAction,
484}
485
486impl fmt::Display for AlterStatement {
487    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
488        write!(f, "TABLE {} {}", self.table, self.action)
489    }
490}
491
492#[derive(Debug, Clone, PartialEq)]
493pub enum AlterAction {
494    Add(ColumnDefinition),
495    Drop(String),
496    Rename { old_name: String, new_name: String },
497}
498
499impl fmt::Display for AlterAction {
500    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
501        match self {
502            AlterAction::Add(col) => write!(f, "ADD COLUMN {}", col),
503            AlterAction::Drop(name) => write!(f, "DROP COLUMN {}", name),
504            AlterAction::Rename { old_name, new_name } => {
505                write!(f, "RENAME COLUMN {} TO {}", old_name, new_name)
506            }
507        }
508    }
509}