// chryso_core/ast/mod.rs

1#[derive(Debug, Clone)]
2pub enum Statement {
3    With(WithStatement),
4    Select(SelectStatement),
5    SetOp {
6        left: Box<Statement>,
7        op: SetOperator,
8        right: Box<Statement>,
9    },
10    Explain(Box<Statement>),
11    CreateTable(CreateTableStatement),
12    DropTable(DropTableStatement),
13    Truncate(TruncateStatement),
14    Analyze(AnalyzeStatement),
15    Insert(InsertStatement),
16    Update(UpdateStatement),
17    Delete(DeleteStatement),
18}
19
20#[derive(Debug, Clone)]
21pub struct CreateTableStatement {
22    pub name: String,
23    pub if_not_exists: bool,
24    pub columns: Vec<ColumnDef>,
25}
26
/// One column entry of a [`CreateTableStatement`].
#[derive(Debug, Clone)]
pub struct ColumnDef {
    /// Column name.
    pub name: String,
    /// Data type kept as an uninterpreted string.
    pub data_type: String,
}

/// Payload of [`Statement::DropTable`].
#[derive(Debug, Clone)]
pub struct DropTableStatement {
    /// Name of the table, stored as a plain string.
    pub name: String,
    /// Presumably set when the statement carried an `IF EXISTS` clause — confirm in the parser.
    pub if_exists: bool,
}

/// Payload of [`Statement::Truncate`].
#[derive(Debug, Clone)]
pub struct TruncateStatement {
    /// Name of the target table.
    pub table: String,
}

/// Payload of [`Statement::Analyze`].
#[derive(Debug, Clone)]
pub struct AnalyzeStatement {
    /// Name of the target table.
    pub table: String,
}

49#[derive(Debug, Clone)]
50pub struct InsertStatement {
51    pub table: String,
52    pub columns: Vec<String>,
53    pub source: InsertSource,
54    pub returning: Vec<SelectItem>,
55}
56
57#[derive(Debug, Clone)]
58pub enum InsertSource {
59    Values(Vec<Vec<Expr>>),
60    Query(Box<Statement>),
61    DefaultValues,
62}
63
64#[derive(Debug, Clone)]
65pub struct UpdateStatement {
66    pub table: String,
67    pub assignments: Vec<Assignment>,
68    pub selection: Option<Expr>,
69    pub returning: Vec<SelectItem>,
70}
71
72#[derive(Debug, Clone)]
73pub struct Assignment {
74    pub column: String,
75    pub value: Expr,
76}
77
78#[derive(Debug, Clone)]
79pub struct DeleteStatement {
80    pub table: String,
81    pub selection: Option<Expr>,
82    pub returning: Vec<SelectItem>,
83}
84
85#[derive(Debug, Clone)]
86pub struct SelectStatement {
87    pub distinct: bool,
88    pub distinct_on: Vec<Expr>,
89    pub projection: Vec<SelectItem>,
90    pub from: Option<TableRef>,
91    pub selection: Option<Expr>,
92    pub group_by: Vec<Expr>,
93    pub having: Option<Expr>,
94    pub order_by: Vec<OrderByExpr>,
95    pub limit: Option<u64>,
96    pub offset: Option<u64>,
97}
98
99#[derive(Debug, Clone)]
100pub struct WithStatement {
101    pub ctes: Vec<Cte>,
102    pub recursive: bool,
103    pub statement: Box<Statement>,
104}
105
106#[derive(Debug, Clone)]
107pub struct Cte {
108    pub name: String,
109    pub columns: Vec<String>,
110    pub query: Box<Statement>,
111}
112
/// Operator of a [`Statement::SetOp`]; each set operation has a plain and an `All` form.
#[derive(Debug, Clone, Copy)]
pub enum SetOperator {
    Union,
    UnionAll,
    Intersect,
    IntersectAll,
    Except,
    ExceptAll,
}

123#[derive(Debug, Clone)]
124pub struct TableRef {
125    pub factor: TableFactor,
126    pub alias: Option<String>,
127    pub column_aliases: Vec<String>,
128    pub joins: Vec<Join>,
129}
130
131#[derive(Debug, Clone)]
132pub enum TableFactor {
133    Table { name: String },
134    Derived { query: Box<Statement> },
135}
136
137#[derive(Debug, Clone)]
138pub struct Join {
139    pub join_type: JoinType,
140    pub right: TableRef,
141    pub on: Expr,
142}
143
/// Kind of a [`Join`].
#[derive(Debug, Clone, Copy)]
pub enum JoinType {
    Inner,
    Left,
    Right,
    Full,
}

152#[derive(Debug, Clone)]
153pub struct SelectItem {
154    pub expr: Expr,
155    pub alias: Option<String>,
156}
157
158#[derive(Debug, Clone)]
159pub struct OrderByExpr {
160    pub expr: Expr,
161    pub asc: bool,
162    pub nulls_first: Option<bool>,
163}
164
165#[derive(Debug, Clone)]
166pub enum Expr {
167    Identifier(String),
168    Literal(Literal),
169    BinaryOp {
170        left: Box<Expr>,
171        op: BinaryOperator,
172        right: Box<Expr>,
173    },
174    IsNull {
175        expr: Box<Expr>,
176        negated: bool,
177    },
178    UnaryOp {
179        op: UnaryOperator,
180        expr: Box<Expr>,
181    },
182    FunctionCall {
183        name: String,
184        args: Vec<Expr>,
185    },
186    WindowFunction {
187        function: Box<Expr>,
188        spec: WindowSpec,
189    },
190    Subquery(Box<SelectStatement>),
191    Exists(Box<SelectStatement>),
192    InSubquery {
193        expr: Box<Expr>,
194        subquery: Box<SelectStatement>,
195    },
196    Case {
197        operand: Option<Box<Expr>>,
198        when_then: Vec<(Expr, Expr)>,
199        else_expr: Option<Box<Expr>>,
200    },
201    Wildcard,
202}
203
/// Constant value carried by [`Expr::Literal`].
#[derive(Debug, Clone)]
pub enum Literal {
    /// Text value, stored without the surrounding quotes.
    String(String),
    /// Numeric value; every number is represented as an `f64`.
    Number(f64),
    /// Boolean value.
    Bool(bool),
}

/// Operator of an [`Expr::BinaryOp`]: comparisons, logical connectives and arithmetic.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum BinaryOperator {
    Eq,
    NotEq,
    Lt,
    LtEq,
    Gt,
    GtEq,
    And,
    Or,
    Add,
    Sub,
    Mul,
    Div,
}

/// Operator of an [`Expr::UnaryOp`].
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum UnaryOperator {
    /// Logical negation, rendered as `not <expr>`.
    Not,
    /// Arithmetic negation, rendered as `-<expr>`.
    Neg,
}

233#[derive(Debug, Clone)]
234pub struct WindowSpec {
235    pub partition_by: Vec<Expr>,
236    pub order_by: Vec<OrderByExpr>,
237}
238
239impl Expr {
240    pub fn to_sql(&self) -> String {
241        match self {
242            Expr::Identifier(name) => name.clone(),
243            Expr::Literal(Literal::String(value)) => format!("'{}'", value),
244            Expr::Literal(Literal::Number(value)) => value.to_string(),
245            Expr::Literal(Literal::Bool(value)) => {
246                if *value { "true".to_string() } else { "false".to_string() }
247            }
248            Expr::BinaryOp { left, op, right } => {
249                let op_str = match op {
250                    BinaryOperator::Eq => "=",
251                    BinaryOperator::NotEq => "!=",
252                    BinaryOperator::Lt => "<",
253                    BinaryOperator::LtEq => "<=",
254                    BinaryOperator::Gt => ">",
255                    BinaryOperator::GtEq => ">=",
256                    BinaryOperator::And => "and",
257                    BinaryOperator::Or => "or",
258                    BinaryOperator::Add => "+",
259                    BinaryOperator::Sub => "-",
260                    BinaryOperator::Mul => "*",
261                    BinaryOperator::Div => "/",
262                };
263                format!("{} {} {}", left.to_sql(), op_str, right.to_sql())
264            }
265            Expr::IsNull { expr, negated } => {
266                if *negated {
267                    format!("{} is not null", expr.to_sql())
268                } else {
269                    format!("{} is null", expr.to_sql())
270                }
271            }
272            Expr::UnaryOp { op, expr } => match op {
273                UnaryOperator::Not => format!("not {}", expr.to_sql()),
274                UnaryOperator::Neg => format!("-{}", expr.to_sql()),
275            },
276            Expr::FunctionCall { name, args } => {
277                let args_sql = args.iter().map(|arg| arg.to_sql()).collect::<Vec<_>>();
278                format!("{}({})", name, args_sql.join(", "))
279            }
280            Expr::WindowFunction { function, spec } => {
281                let mut clauses = Vec::new();
282                if !spec.partition_by.is_empty() {
283                    let partition = spec
284                        .partition_by
285                        .iter()
286                        .map(|expr| expr.to_sql())
287                        .collect::<Vec<_>>()
288                        .join(", ");
289                    clauses.push(format!("partition by {partition}"));
290                }
291                if !spec.order_by.is_empty() {
292                    let order = spec
293                        .order_by
294                        .iter()
295                        .map(|item| {
296                            let dir = if item.asc { "asc" } else { "desc" };
297                            let mut rendered = format!("{} {dir}", item.expr.to_sql());
298                            if let Some(nulls_first) = item.nulls_first {
299                                if nulls_first {
300                                    rendered.push_str(" nulls first");
301                                } else {
302                                    rendered.push_str(" nulls last");
303                                }
304                            }
305                            rendered
306                        })
307                        .collect::<Vec<_>>()
308                        .join(", ");
309                    clauses.push(format!("order by {order}"));
310                }
311                format!("{} over ({})", function.to_sql(), clauses.join(" "))
312            }
313            Expr::Exists(select) => format!("exists ({})", select_to_sql(select)),
314            Expr::InSubquery { expr, subquery } => {
315                format!("{} in ({})", expr.to_sql(), select_to_sql(subquery))
316            }
317            Expr::Subquery(select) => format!("({})", select_to_sql(select)),
318            Expr::Case {
319                operand,
320                when_then,
321                else_expr,
322            } => {
323                let mut output = String::from("case");
324                if let Some(expr) = operand {
325                    output.push(' ');
326                    output.push_str(&expr.to_sql());
327                }
328                for (when_expr, then_expr) in when_then {
329                    output.push_str(" when ");
330                    output.push_str(&when_expr.to_sql());
331                    output.push_str(" then ");
332                    output.push_str(&then_expr.to_sql());
333                }
334                if let Some(expr) = else_expr {
335                    output.push_str(" else ");
336                    output.push_str(&expr.to_sql());
337                }
338                output.push_str(" end");
339                output
340            }
341            Expr::Wildcard => "*".to_string(),
342        }
343    }
344
345    pub fn structural_eq(&self, other: &Expr) -> bool {
346        const FLOAT_EPSILON: f64 = 1e-9;
347        match (self, other) {
348            (Expr::Identifier(left), Expr::Identifier(right)) => left == right,
349            (Expr::Literal(left), Expr::Literal(right)) => match (left, right) {
350                (Literal::String(left), Literal::String(right)) => left == right,
351                (Literal::Number(left), Literal::Number(right)) => {
352                    if left.is_nan() || right.is_nan() {
353                        false
354                    } else if left.is_infinite() || right.is_infinite() {
355                        left == right
356                    } else {
357                        (left - right).abs() <= FLOAT_EPSILON
358                    }
359                }
360                (Literal::Bool(left), Literal::Bool(right)) => left == right,
361                _ => false,
362            },
363            (Expr::UnaryOp { op: left_op, expr: left }, Expr::UnaryOp { op: right_op, expr: right }) => {
364                left_op == right_op && left.structural_eq(right)
365            }
366            (
367                Expr::BinaryOp { left: left_lhs, op: left_op, right: left_rhs },
368                Expr::BinaryOp { left: right_lhs, op: right_op, right: right_rhs },
369            ) => left_op == right_op
370                && left_lhs.structural_eq(right_lhs)
371                && left_rhs.structural_eq(right_rhs),
372            (
373                Expr::IsNull { expr: left, negated: left_negated },
374                Expr::IsNull { expr: right, negated: right_negated },
375            ) => left_negated == right_negated && left.structural_eq(right),
376            (
377                Expr::FunctionCall { name: left_name, args: left_args },
378                Expr::FunctionCall { name: right_name, args: right_args },
379            ) => left_name == right_name
380                && left_args.len() == right_args.len()
381                && left_args
382                    .iter()
383                    .zip(right_args.iter())
384                    .all(|(left, right)| left.structural_eq(right)),
385            (
386                Expr::WindowFunction { function: left_func, spec: left_spec },
387                Expr::WindowFunction { function: right_func, spec: right_spec },
388            ) => left_func.structural_eq(right_func)
389                && left_spec.partition_by.len() == right_spec.partition_by.len()
390                && left_spec
391                    .partition_by
392                    .iter()
393                    .zip(right_spec.partition_by.iter())
394                    .all(|(left, right)| left.structural_eq(right))
395                && left_spec.order_by.len() == right_spec.order_by.len()
396                && left_spec
397                    .order_by
398                    .iter()
399                    .zip(right_spec.order_by.iter())
400                    .all(|(left, right)| {
401                        left.asc == right.asc
402                            && left.nulls_first == right.nulls_first
403                            && left.expr.structural_eq(&right.expr)
404                    }),
405            (Expr::Subquery(left), Expr::Subquery(right)) => select_to_sql(left) == select_to_sql(right),
406            (Expr::Exists(left), Expr::Exists(right)) => select_to_sql(left) == select_to_sql(right),
407            (
408                Expr::InSubquery { expr: left_expr, subquery: left_subquery },
409                Expr::InSubquery { expr: right_expr, subquery: right_subquery },
410            ) => left_expr.structural_eq(right_expr)
411                && select_to_sql(left_subquery) == select_to_sql(right_subquery),
412            (
413                Expr::Case { operand: left_operand, when_then: left_when_then, else_expr: left_else },
414                Expr::Case { operand: right_operand, when_then: right_when_then, else_expr: right_else },
415            ) => left_operand
416                .as_ref()
417                .zip(right_operand.as_ref())
418                .map(|(left, right)| left.structural_eq(right))
419                .unwrap_or(left_operand.is_none() && right_operand.is_none())
420                && left_when_then.len() == right_when_then.len()
421                && left_when_then.iter().zip(right_when_then.iter()).all(|(left, right)| {
422                    left.0.structural_eq(&right.0) && left.1.structural_eq(&right.1)
423                })
424                && left_else
425                    .as_ref()
426                    .zip(right_else.as_ref())
427                    .map(|(left, right)| left.structural_eq(right))
428                    .unwrap_or(left_else.is_none() && right_else.is_none()),
429            (Expr::Wildcard, Expr::Wildcard) => true,
430            _ => false,
431        }
432    }
433
434    pub fn normalize(&self) -> Expr {
435        let normalized = match self {
436            Expr::BinaryOp { left, op, right } => {
437                let left_norm = left.normalize();
438                let right_norm = right.normalize();
439                if matches!(op, BinaryOperator::And | BinaryOperator::Or) {
440                    if left_norm.to_sql() > right_norm.to_sql() {
441                        return Expr::BinaryOp {
442                            left: Box::new(right_norm),
443                            op: *op,
444                            right: Box::new(left_norm),
445                        };
446                    }
447                }
448                Expr::BinaryOp {
449                    left: Box::new(left_norm),
450                    op: *op,
451                    right: Box::new(right_norm),
452                }
453            }
454            Expr::IsNull { expr, negated } => Expr::IsNull {
455                expr: Box::new(expr.normalize()),
456                negated: *negated,
457            },
458            Expr::UnaryOp { op, expr } => Expr::UnaryOp {
459                op: *op,
460                expr: Box::new(expr.normalize()),
461            },
462            Expr::FunctionCall { name, args } => Expr::FunctionCall {
463                name: name.clone(),
464                args: args.iter().map(|arg| arg.normalize()).collect(),
465            },
466            Expr::Case {
467                operand,
468                when_then,
469                else_expr,
470            } => Expr::Case {
471                operand: operand.as_ref().map(|expr| Box::new(expr.normalize())),
472                when_then: when_then
473                    .iter()
474                    .map(|(when_expr, then_expr)| (when_expr.normalize(), then_expr.normalize()))
475                    .collect(),
476                else_expr: else_expr.as_ref().map(|expr| Box::new(expr.normalize())),
477            },
478            Expr::WindowFunction { function, spec } => Expr::WindowFunction {
479                function: Box::new(function.normalize()),
480                spec: spec.clone(),
481            },
482            Expr::Exists(select) => Expr::Exists(Box::new(normalize_select_inner(select))),
483            Expr::InSubquery { expr, subquery } => Expr::InSubquery {
484                expr: Box::new(expr.normalize()),
485                subquery: Box::new(normalize_select_inner(subquery)),
486            },
487            Expr::Subquery(select) => Expr::Subquery(Box::new(normalize_select_inner(select))),
488            other => other.clone(),
489        };
490        rewrite_strong_expr(normalized)
491    }
492}
493
494fn rewrite_strong_expr(expr: Expr) -> Expr {
495    match expr {
496        Expr::UnaryOp {
497            op: UnaryOperator::Not,
498            expr,
499        } => match *expr {
500            Expr::Literal(Literal::Bool(value)) => Expr::Literal(Literal::Bool(!value)),
501            Expr::UnaryOp {
502                op: UnaryOperator::Not,
503                expr,
504            } => *expr,
505            Expr::IsNull { expr, negated } => Expr::IsNull {
506                expr,
507                negated: !negated,
508            },
509            Expr::BinaryOp { left, op, right } => match op {
510                BinaryOperator::Eq => Expr::BinaryOp {
511                    left,
512                    op: BinaryOperator::NotEq,
513                    right,
514                },
515                BinaryOperator::NotEq => Expr::BinaryOp {
516                    left,
517                    op: BinaryOperator::Eq,
518                    right,
519                },
520                BinaryOperator::Lt => Expr::BinaryOp {
521                    left,
522                    op: BinaryOperator::GtEq,
523                    right,
524                },
525                BinaryOperator::LtEq => Expr::BinaryOp {
526                    left,
527                    op: BinaryOperator::Gt,
528                    right,
529                },
530                BinaryOperator::Gt => Expr::BinaryOp {
531                    left,
532                    op: BinaryOperator::LtEq,
533                    right,
534                },
535                BinaryOperator::GtEq => Expr::BinaryOp {
536                    left,
537                    op: BinaryOperator::Lt,
538                    right,
539                },
540                BinaryOperator::And => Expr::BinaryOp {
541                    left: Box::new(Expr::UnaryOp {
542                        op: UnaryOperator::Not,
543                        expr: left,
544                    }),
545                    op: BinaryOperator::Or,
546                    right: Box::new(Expr::UnaryOp {
547                        op: UnaryOperator::Not,
548                        expr: right,
549                    }),
550                },
551                BinaryOperator::Or => Expr::BinaryOp {
552                    left: Box::new(Expr::UnaryOp {
553                        op: UnaryOperator::Not,
554                        expr: left,
555                    }),
556                    op: BinaryOperator::And,
557                    right: Box::new(Expr::UnaryOp {
558                        op: UnaryOperator::Not,
559                        expr: right,
560                    }),
561                },
562                _ => Expr::UnaryOp {
563                    op: UnaryOperator::Not,
564                    expr: Box::new(Expr::BinaryOp { left, op, right }),
565                },
566            },
567            other => Expr::UnaryOp {
568                op: UnaryOperator::Not,
569                expr: Box::new(other),
570            },
571        },
572        Expr::UnaryOp {
573            op: UnaryOperator::Neg,
574            expr,
575        } => match *expr {
576            Expr::Literal(Literal::Number(value)) => {
577                Expr::Literal(Literal::Number(-value))
578            }
579            other => Expr::UnaryOp {
580                op: UnaryOperator::Neg,
581                expr: Box::new(other),
582            },
583        },
584        Expr::BinaryOp { left, op, right } => {
585            if matches!(op, BinaryOperator::And | BinaryOperator::Or) && left.structural_eq(&right) {
586                return *left;
587            }
588            let same_expr = left.structural_eq(&right);
589            if same_expr {
590                return match op {
591                    BinaryOperator::Eq | BinaryOperator::LtEq | BinaryOperator::GtEq => {
592                        Expr::Literal(Literal::Bool(true))
593                    }
594                    BinaryOperator::NotEq | BinaryOperator::Lt | BinaryOperator::Gt => {
595                        Expr::Literal(Literal::Bool(false))
596                    }
597                    _ => Expr::BinaryOp { left, op, right },
598                };
599            }
600            match (*left, op, *right) {
601            (Expr::Literal(Literal::Bool(a)), BinaryOperator::And, Expr::Literal(Literal::Bool(b))) => {
602                Expr::Literal(Literal::Bool(a && b))
603            }
604            (Expr::Literal(Literal::Bool(a)), BinaryOperator::Or, Expr::Literal(Literal::Bool(b))) => {
605                Expr::Literal(Literal::Bool(a || b))
606            }
607            (Expr::Literal(Literal::Bool(a)), BinaryOperator::And, other) => {
608                if a { other } else { Expr::Literal(Literal::Bool(false)) }
609            }
610            (other, BinaryOperator::And, Expr::Literal(Literal::Bool(b))) => {
611                if b { other } else { Expr::Literal(Literal::Bool(false)) }
612            }
613            (Expr::Literal(Literal::Bool(a)), BinaryOperator::Or, other) => {
614                if a { Expr::Literal(Literal::Bool(true)) } else { other }
615            }
616            (other, BinaryOperator::Or, Expr::Literal(Literal::Bool(b))) => {
617                if b { Expr::Literal(Literal::Bool(true)) } else { other }
618            }
619            (Expr::Literal(Literal::Number(a)), BinaryOperator::Eq, Expr::Literal(Literal::Number(b))) => {
620                Expr::Literal(Literal::Bool(a == b))
621            }
622            (Expr::Literal(Literal::Number(a)), BinaryOperator::NotEq, Expr::Literal(Literal::Number(b))) => {
623                Expr::Literal(Literal::Bool(a != b))
624            }
625            (Expr::Literal(Literal::Number(a)), BinaryOperator::Lt, Expr::Literal(Literal::Number(b))) => {
626                Expr::Literal(Literal::Bool(a < b))
627            }
628            (Expr::Literal(Literal::Number(a)), BinaryOperator::LtEq, Expr::Literal(Literal::Number(b))) => {
629                Expr::Literal(Literal::Bool(a <= b))
630            }
631            (Expr::Literal(Literal::Number(a)), BinaryOperator::Gt, Expr::Literal(Literal::Number(b))) => {
632                Expr::Literal(Literal::Bool(a > b))
633            }
634            (Expr::Literal(Literal::Number(a)), BinaryOperator::GtEq, Expr::Literal(Literal::Number(b))) => {
635                Expr::Literal(Literal::Bool(a >= b))
636            }
637            (Expr::Literal(Literal::String(a)), BinaryOperator::Eq, Expr::Literal(Literal::String(b))) => {
638                Expr::Literal(Literal::Bool(a == b))
639            }
640            (Expr::Literal(Literal::String(a)), BinaryOperator::NotEq, Expr::Literal(Literal::String(b))) => {
641                Expr::Literal(Literal::Bool(a != b))
642            }
643            (Expr::Literal(Literal::Bool(a)), BinaryOperator::Eq, Expr::Literal(Literal::Bool(b))) => {
644                Expr::Literal(Literal::Bool(a == b))
645            }
646            (Expr::Literal(Literal::Bool(a)), BinaryOperator::NotEq, Expr::Literal(Literal::Bool(b))) => {
647                Expr::Literal(Literal::Bool(a != b))
648            }
649            (Expr::Literal(Literal::Number(a)), BinaryOperator::Add, Expr::Literal(Literal::Number(b))) => {
650                Expr::Literal(Literal::Number(a + b))
651            }
652            (Expr::Literal(Literal::Number(a)), BinaryOperator::Sub, Expr::Literal(Literal::Number(b))) => {
653                Expr::Literal(Literal::Number(a - b))
654            }
655            (Expr::Literal(Literal::Number(a)), BinaryOperator::Mul, Expr::Literal(Literal::Number(b))) => {
656                Expr::Literal(Literal::Number(a * b))
657            }
658            (Expr::Literal(Literal::Number(a)), BinaryOperator::Div, Expr::Literal(Literal::Number(b))) => {
659                if b == 0.0 {
660                    Expr::BinaryOp {
661                        left: Box::new(Expr::Literal(Literal::Number(a))),
662                        op: BinaryOperator::Div,
663                        right: Box::new(Expr::Literal(Literal::Number(b))),
664                    }
665                } else {
666                    Expr::Literal(Literal::Number(a / b))
667                }
668            }
669            (left, op, right) => Expr::BinaryOp {
670                left: Box::new(left),
671                op,
672                right: Box::new(right),
673            },
674            }
675        }
676        other => other,
677    }
678}
679
680pub fn normalize_statement(statement: &Statement) -> Statement {
681    match statement {
682        Statement::With(stmt) => Statement::With(WithStatement {
683            ctes: stmt
684                .ctes
685                .iter()
686                .map(|cte| Cte {
687                    name: cte.name.clone(),
688                    columns: cte.columns.clone(),
689                    query: Box::new(normalize_statement(&cte.query)),
690                })
691                .collect(),
692            recursive: stmt.recursive,
693            statement: Box::new(normalize_statement(&stmt.statement)),
694        }),
695        Statement::Select(select) => Statement::Select(normalize_select(select)),
696        Statement::SetOp { left, op, right } => Statement::SetOp {
697            left: Box::new(normalize_statement(left)),
698            op: *op,
699            right: Box::new(normalize_statement(right)),
700        },
701        Statement::Explain(inner) => Statement::Explain(Box::new(normalize_statement(inner))),
702        Statement::CreateTable(stmt) => Statement::CreateTable(stmt.clone()),
703        Statement::DropTable(stmt) => Statement::DropTable(stmt.clone()),
704        Statement::Truncate(stmt) => Statement::Truncate(stmt.clone()),
705        Statement::Analyze(stmt) => Statement::Analyze(stmt.clone()),
706        Statement::Insert(stmt) => Statement::Insert(InsertStatement {
707            table: stmt.table.clone(),
708            columns: stmt.columns.clone(),
709            source: match &stmt.source {
710                InsertSource::Values(values) => InsertSource::Values(
711                    values
712                        .iter()
713                        .map(|row| row.iter().map(|expr| expr.normalize()).collect())
714                        .collect(),
715                ),
716                InsertSource::Query(statement) => {
717                    InsertSource::Query(Box::new(normalize_statement(statement)))
718                }
719                InsertSource::DefaultValues => InsertSource::DefaultValues,
720            },
721            returning: stmt
722                .returning
723                .iter()
724                .map(|item| SelectItem {
725                    expr: item.expr.normalize(),
726                    alias: item.alias.clone(),
727                })
728                .collect(),
729        }),
730        Statement::Update(stmt) => Statement::Update(UpdateStatement {
731            table: stmt.table.clone(),
732            assignments: stmt
733                .assignments
734                .iter()
735                .map(|assign| Assignment {
736                    column: assign.column.clone(),
737                    value: assign.value.normalize(),
738                })
739                .collect(),
740            selection: stmt.selection.as_ref().map(|expr| expr.normalize()),
741            returning: stmt
742                .returning
743                .iter()
744                .map(|item| SelectItem {
745                    expr: item.expr.normalize(),
746                    alias: item.alias.clone(),
747                })
748                .collect(),
749        }),
750        Statement::Delete(stmt) => Statement::Delete(DeleteStatement {
751            table: stmt.table.clone(),
752            selection: stmt.selection.as_ref().map(|expr| expr.normalize()),
753            returning: stmt
754                .returning
755                .iter()
756                .map(|item| SelectItem {
757                    expr: item.expr.normalize(),
758                    alias: item.alias.clone(),
759                })
760                .collect(),
761        }),
762    }
763}
764
765fn normalize_select(select: &SelectStatement) -> SelectStatement {
766    SelectStatement {
767        distinct: select.distinct,
768        distinct_on: select.distinct_on.iter().map(|expr| expr.normalize()).collect(),
769        projection: select
770            .projection
771            .iter()
772            .map(|item| SelectItem {
773                expr: item.expr.normalize(),
774                alias: item.alias.clone(),
775            })
776            .collect(),
777        from: select.from.as_ref().map(normalize_table_ref),
778        selection: select.selection.as_ref().map(|expr| expr.normalize()),
779        group_by: select.group_by.iter().map(|expr| expr.normalize()).collect(),
780        having: select.having.as_ref().map(|expr| expr.normalize()),
781        order_by: select
782            .order_by
783            .iter()
784            .map(|order| OrderByExpr {
785                expr: order.expr.normalize(),
786                asc: order.asc,
787                nulls_first: order.nulls_first,
788            })
789            .collect(),
790        limit: select.limit,
791        offset: select.offset,
792    }
793}
794
795fn normalize_select_inner(select: &SelectStatement) -> SelectStatement {
796    normalize_select(select)
797}
798
799fn normalize_table_ref(table: &TableRef) -> TableRef {
800    TableRef {
801        factor: match &table.factor {
802            TableFactor::Table { name } => TableFactor::Table { name: name.clone() },
803            TableFactor::Derived { query } => {
804                TableFactor::Derived { query: Box::new(normalize_statement(query)) }
805            }
806        },
807        alias: table.alias.clone(),
808        column_aliases: table.column_aliases.clone(),
809        joins: table
810            .joins
811            .iter()
812            .map(|join| Join {
813                join_type: join.join_type,
814                right: normalize_table_ref(&join.right),
815                on: join.on.normalize(),
816            })
817            .collect(),
818    }
819}
820
821fn select_to_sql(select: &SelectStatement) -> String {
822    let mut output = String::from("select ");
823    if select.distinct {
824        output.push_str("distinct ");
825        if !select.distinct_on.is_empty() {
826            let distinct_on = select
827                .distinct_on
828                .iter()
829                .map(|expr| expr.to_sql())
830                .collect::<Vec<_>>()
831                .join(", ");
832            output.push_str("on (");
833            output.push_str(&distinct_on);
834            output.push_str(") ");
835        }
836    }
837    let projection = select
838        .projection
839        .iter()
840        .map(|item| item.expr.to_sql())
841        .collect::<Vec<_>>()
842        .join(", ");
843    output.push_str(&projection);
844    if let Some(from) = &select.from {
845        output.push_str(" from ");
846        output.push_str(&table_ref_to_sql(from));
847    }
848    if let Some(selection) = &select.selection {
849        output.push_str(" where ");
850        output.push_str(&selection.to_sql());
851    }
852    if !select.group_by.is_empty() {
853        let group_by = select
854            .group_by
855            .iter()
856            .map(|expr| expr.to_sql())
857            .collect::<Vec<_>>()
858            .join(", ");
859        output.push_str(" group by ");
860        output.push_str(&group_by);
861    }
862    if let Some(having) = &select.having {
863        output.push_str(" having ");
864        output.push_str(&having.to_sql());
865    }
866    if !select.order_by.is_empty() {
867        let order_by = select
868            .order_by
869            .iter()
870            .map(|item| {
871                let mut rendered = item.expr.to_sql();
872                rendered.push(' ');
873                rendered.push_str(if item.asc { "asc" } else { "desc" });
874                if let Some(nulls_first) = item.nulls_first {
875                    rendered.push_str(" nulls ");
876                    rendered.push_str(if nulls_first { "first" } else { "last" });
877                }
878                rendered
879            })
880            .collect::<Vec<_>>()
881            .join(", ");
882        output.push_str(" order by ");
883        output.push_str(&order_by);
884    }
885    if let Some(limit) = select.limit {
886        output.push_str(" limit ");
887        output.push_str(&limit.to_string());
888    }
889    if let Some(offset) = select.offset {
890        output.push_str(" offset ");
891        output.push_str(&offset.to_string());
892    }
893    output
894}
895
896fn table_ref_to_sql(table: &TableRef) -> String {
897    let mut output = match &table.factor {
898        TableFactor::Table { name } => name.clone(),
899        TableFactor::Derived { query } => format!("({})", statement_to_sql(query)),
900    };
901    if let Some(alias) = &table.alias {
902        output.push_str(" as ");
903        output.push_str(alias);
904        if !table.column_aliases.is_empty() {
905            output.push_str(" (");
906            output.push_str(&table.column_aliases.join(", "));
907            output.push(')');
908        }
909    }
910    for join in &table.joins {
911        let join_type = match join.join_type {
912            JoinType::Inner => "join",
913            JoinType::Left => "left join",
914            JoinType::Right => "right join",
915            JoinType::Full => "full join",
916        };
917        output.push(' ');
918        output.push_str(join_type);
919        output.push(' ');
920        output.push_str(&table_ref_to_sql(&join.right));
921        output.push_str(" on ");
922        output.push_str(&join.on.to_sql());
923    }
924    output
925}
926
927pub fn statement_to_sql(statement: &Statement) -> String {
928    match statement {
929        Statement::Select(select) => select_to_sql(select),
930        Statement::SetOp { left, op, right } => {
931            let left_sql = statement_to_sql(left);
932            let right_sql = statement_to_sql(right);
933            let op_str = match op {
934                SetOperator::Union => "union",
935                SetOperator::UnionAll => "union all",
936                SetOperator::Intersect => "intersect",
937                SetOperator::IntersectAll => "intersect all",
938                SetOperator::Except => "except",
939                SetOperator::ExceptAll => "except all",
940            };
941            format!("{left_sql} {op_str} {right_sql}")
942        }
943        Statement::With(with_stmt) => {
944            let ctes = with_stmt
945                .ctes
946                .iter()
947                .map(|cte| {
948                    let mut name = cte.name.clone();
949                    if !cte.columns.is_empty() {
950                        let cols = cte.columns.join(", ");
951                        name.push_str(" (");
952                        name.push_str(&cols);
953                        name.push(')');
954                    }
955                    format!("{name} as ({})", statement_to_sql(&cte.query))
956                })
957                .collect::<Vec<_>>()
958                .join(", ");
959            let keyword = if with_stmt.recursive {
960                "with recursive"
961            } else {
962                "with"
963            };
964            format!("{keyword} {ctes} {}", statement_to_sql(&with_stmt.statement))
965        }
966        Statement::Explain(inner) => format!("explain {}", statement_to_sql(inner)),
967        Statement::CreateTable(stmt) => {
968            if stmt.if_not_exists {
969                if stmt.columns.is_empty() {
970                    format!("create table if not exists {}", stmt.name)
971                } else {
972                    let columns = stmt
973                        .columns
974                        .iter()
975                        .map(|col| format!("{} {}", col.name, col.data_type))
976                        .collect::<Vec<_>>()
977                        .join(", ");
978                    format!("create table if not exists {} ({})", stmt.name, columns)
979                }
980            } else {
981                if stmt.columns.is_empty() {
982                    format!("create table {}", stmt.name)
983                } else {
984                    let columns = stmt
985                        .columns
986                        .iter()
987                        .map(|col| format!("{} {}", col.name, col.data_type))
988                        .collect::<Vec<_>>()
989                        .join(", ");
990                    format!("create table {} ({})", stmt.name, columns)
991                }
992            }
993        }
994        Statement::DropTable(stmt) => {
995            if stmt.if_exists {
996                format!("drop table if exists {}", stmt.name)
997            } else {
998                format!("drop table {}", stmt.name)
999            }
1000        }
1001        Statement::Truncate(stmt) => format!("truncate table {}", stmt.table),
1002        Statement::Analyze(stmt) => format!("analyze {}", stmt.table),
1003        Statement::Insert(stmt) => {
1004            let mut output = format!("insert into {}", stmt.table);
1005            if !stmt.columns.is_empty() {
1006                output.push_str(" (");
1007                output.push_str(&stmt.columns.join(", "));
1008                output.push(')');
1009            }
1010            match &stmt.source {
1011                InsertSource::DefaultValues => {
1012                    output.push_str(" default values");
1013                }
1014                InsertSource::Values(values) => {
1015                    let rows = values
1016                        .iter()
1017                        .map(|row| {
1018                            let values = row
1019                                .iter()
1020                                .map(|expr| expr.to_sql())
1021                                .collect::<Vec<_>>()
1022                                .join(", ");
1023                            format!("({values})")
1024                        })
1025                        .collect::<Vec<_>>()
1026                        .join(", ");
1027                    output.push_str(" values ");
1028                    output.push_str(&rows);
1029                }
1030                InsertSource::Query(statement) => {
1031                    output.push(' ');
1032                    output.push_str(&statement_to_sql(statement));
1033                }
1034            }
1035            if !stmt.returning.is_empty() {
1036                let returning = stmt
1037                    .returning
1038                    .iter()
1039                    .map(|item| item.expr.to_sql())
1040                    .collect::<Vec<_>>()
1041                    .join(", ");
1042                output.push_str(" returning ");
1043                output.push_str(&returning);
1044            }
1045            output
1046        }
1047        Statement::Update(stmt) => {
1048            let mut output = format!("update {} set ", stmt.table);
1049            let assignments = stmt
1050                .assignments
1051                .iter()
1052                .map(|assign| format!("{} = {}", assign.column, assign.value.to_sql()))
1053                .collect::<Vec<_>>()
1054                .join(", ");
1055            output.push_str(&assignments);
1056            if let Some(selection) = &stmt.selection {
1057                output.push_str(" where ");
1058                output.push_str(&selection.to_sql());
1059            }
1060            if !stmt.returning.is_empty() {
1061                let returning = stmt
1062                    .returning
1063                    .iter()
1064                    .map(|item| item.expr.to_sql())
1065                    .collect::<Vec<_>>()
1066                    .join(", ");
1067                output.push_str(" returning ");
1068                output.push_str(&returning);
1069            }
1070            output
1071        }
1072        Statement::Delete(stmt) => {
1073            let mut output = format!("delete from {}", stmt.table);
1074            if let Some(selection) = &stmt.selection {
1075                output.push_str(" where ");
1076                output.push_str(&selection.to_sql());
1077            }
1078            if !stmt.returning.is_empty() {
1079                let returning = stmt
1080                    .returning
1081                    .iter()
1082                    .map(|item| item.expr.to_sql())
1083                    .collect::<Vec<_>>()
1084                    .join(", ");
1085                output.push_str(" returning ");
1086                output.push_str(&returning);
1087            }
1088            output
1089        }
1090    }
1091}
1092
#[cfg(test)]
mod tests {
    use super::{BinaryOperator, Expr, Literal, UnaryOperator};

    /// Builds an identifier expression named `name`.
    fn ident(name: &str) -> Expr {
        Expr::Identifier(name.to_string())
    }

    /// Builds the binary expression `left <op> right`.
    fn binary(left: Expr, op: BinaryOperator, right: Expr) -> Expr {
        Expr::BinaryOp {
            left: Box::new(left),
            op,
            right: Box::new(right),
        }
    }

    /// Wraps `expr` in a logical `not`.
    fn negate(expr: Expr) -> Expr {
        Expr::UnaryOp {
            op: UnaryOperator::Not,
            expr: Box::new(expr),
        }
    }

    #[test]
    fn normalize_commutative_predicate() {
        // `b and a` is expected to come back with `a` on the left.
        let normalized = binary(ident("b"), BinaryOperator::And, ident("a")).normalize();
        match normalized {
            Expr::BinaryOp { left, right, .. } => {
                assert!(matches!(left.as_ref(), Expr::Identifier(name) if name == "a"));
                assert!(matches!(right.as_ref(), Expr::Identifier(name) if name == "b"));
            }
            _ => panic!("expected binary op"),
        }
    }

    #[test]
    fn normalize_not_comparison() {
        // `not (a = b)` is expected to become a `NotEq` comparison.
        let normalized = negate(binary(ident("a"), BinaryOperator::Eq, ident("b"))).normalize();
        match normalized {
            Expr::BinaryOp { op, .. } => assert!(matches!(op, BinaryOperator::NotEq)),
            _ => panic!("expected binary op"),
        }
    }

    #[test]
    fn normalize_double_not() {
        // `not not flag` is expected to collapse to an identifier.
        let normalized = negate(negate(ident("flag"))).normalize();
        assert!(matches!(normalized, Expr::Identifier(_)));
    }

    #[test]
    fn normalize_constant_fold() {
        // `2 * 4` is expected to fold into the numeric literal 8.
        let two = Expr::Literal(Literal::Number(2.0));
        let four = Expr::Literal(Literal::Number(4.0));
        let normalized = binary(two, BinaryOperator::Mul, four).normalize();
        let folded = match normalized {
            Expr::Literal(Literal::Number(value)) => value,
            other => panic!("expected literal, got {other:?}"),
        };
        assert_eq!(folded, 8.0);
    }

    #[test]
    fn normalize_boolean_identities() {
        // `flag and true` is expected to reduce to an identifier.
        let always_true = Expr::Literal(Literal::Bool(true));
        let normalized = binary(ident("flag"), BinaryOperator::And, always_true).normalize();
        assert!(matches!(normalized, Expr::Identifier(_)));
    }

    #[test]
    fn normalize_duplicate_and_or() {
        // `a or a` and `a and a` are both expected to reduce to `a`.
        for op in [BinaryOperator::Or, BinaryOperator::And] {
            let normalized = binary(ident("a"), op, ident("a")).normalize();
            assert!(matches!(normalized, Expr::Identifier(name) if name == "a"));
        }
    }

    #[test]
    fn normalize_self_comparison() {
        // Comparing `a` with itself is expected to fold to a boolean literal.
        let cases = [
            (BinaryOperator::Eq, true),
            (BinaryOperator::NotEq, false),
            (BinaryOperator::Lt, false),
            (BinaryOperator::LtEq, true),
        ];
        for (op, expected) in cases {
            let normalized = binary(ident("a"), op, ident("a")).normalize();
            assert!(matches!(
                normalized,
                Expr::Literal(Literal::Bool(value)) if value == expected
            ));
        }
    }

    #[test]
    fn normalize_not_literal() {
        // `not true` is expected to fold to `false`.
        let normalized = negate(Expr::Literal(Literal::Bool(true))).normalize();
        assert!(matches!(normalized, Expr::Literal(Literal::Bool(false))));
    }
}