/// Binary arithmetic operator carried by [`Expr::BinaryOp`].
///
/// `Expr::display_name` renders these as `+`, `-`, `*`, `/` and `%`.
#[derive(Debug, Clone, PartialEq)]
pub enum ArithOp {
    Add,
    Sub,
    Mul,
    Div,
    /// Remainder operator, rendered as `%`.
    Mod,
}
13
/// Scalar expression tree used in projections, conditions and ORDER BY keys.
#[derive(Debug, Clone, PartialEq)]
pub enum Expr {
    /// Constant value.
    Literal(SqlValue),
    /// Reference to a column by name.
    Column(String),
    /// `left op right` arithmetic.
    BinaryOp { left: Box<Expr>, op: ArithOp, right: Box<Expr> },
    /// Arithmetic negation of the inner expression.
    UnaryMinus(Box<Expr>),
    /// CASE expression: each `whens` entry pairs a condition with the result
    /// expression for that branch; `else_expr` is the optional ELSE branch.
    Case { whens: Vec<(WhereClause, Box<Expr>)>, else_expr: Option<Box<Expr>> },
    /// Rendered as `DATE_ADD(date, days)`.
    DateAdd { date: Box<Expr>, days: Box<Expr> },
    /// Rendered as `DATEDIFF(left, right)`.
    DateDiff { left: Box<Expr>, right: Box<Expr> },
    /// Rendered as `CURRENT_DATE`.
    CurrentDate,
    /// Rendered as `CURRENT_TIMESTAMP`.
    CurrentTimestamp,
    /// Aggregate call. `arg` is the textual argument used when rendering the
    /// name (e.g. `SUM(qty)`); `arg_expr` presumably holds the parsed argument
    /// when it is more than a plain name — TODO confirm against the parser.
    Aggregate { func: AggFunc, arg: String, arg_expr: Option<Box<Expr>> },
}
27
28impl Expr {
29 pub fn as_column(&self) -> Option<&str> {
30 if let Expr::Column(name) = self { Some(name) } else { None }
31 }
32
33 pub fn display_name(&self) -> String {
34 match self {
35 Expr::Literal(SqlValue::Int(n)) => n.to_string(),
36 Expr::Literal(SqlValue::Float(f)) => f.to_string(),
37 Expr::Literal(SqlValue::String(s)) => format!("'{}'", s),
38 Expr::Literal(SqlValue::Null) => "NULL".to_string(),
39 Expr::Literal(SqlValue::List(_)) => "list".to_string(),
40 Expr::Column(name) => name.clone(),
41 Expr::BinaryOp { left, op, right } => {
42 let op_str = match op {
43 ArithOp::Add => "+",
44 ArithOp::Sub => "-",
45 ArithOp::Mul => "*",
46 ArithOp::Div => "/",
47 ArithOp::Mod => "%",
48 };
49 format!("{} {} {}", left.display_name(), op_str, right.display_name())
50 }
51 Expr::UnaryMinus(inner) => format!("-{}", inner.display_name()),
52 Expr::Case { .. } => "CASE".to_string(),
53 Expr::DateAdd { date, days } => format!("DATE_ADD({}, {})", date.display_name(), days.display_name()),
54 Expr::DateDiff { left, right } => format!("DATEDIFF({}, {})", left.display_name(), right.display_name()),
55 Expr::CurrentDate => "CURRENT_DATE".to_string(),
56 Expr::CurrentTimestamp => "CURRENT_TIMESTAMP".to_string(),
57 Expr::Aggregate { func, arg, .. } => {
58 let func_name = match func {
59 AggFunc::Count => "COUNT",
60 AggFunc::Sum => "SUM",
61 AggFunc::Avg => "AVG",
62 AggFunc::Min => "MIN",
63 AggFunc::Max => "MAX",
64 };
65 format!("{}({})", func_name, arg)
66 }
67 }
68 }
69
70 pub fn contains_aggregate(&self) -> bool {
71 match self {
72 Expr::Aggregate { .. } => true,
73 Expr::BinaryOp { left, right, .. } => {
74 left.contains_aggregate() || right.contains_aggregate()
75 }
76 Expr::UnaryMinus(inner) => inner.contains_aggregate(),
77 Expr::Case { whens, else_expr } => {
78 whens.iter().any(|(_, e)| e.contains_aggregate())
79 || else_expr.as_ref().map_or(false, |e| e.contains_aggregate())
80 }
81 _ => false,
82 }
83 }
84}
85
/// One key of an ORDER BY list (stored in `SelectQuery::order_by`).
#[derive(Debug, Clone, PartialEq)]
pub struct OrderSpec {
    /// Name of the sort key.
    pub column: String,
    /// Parsed key expression, presumably set only when the key is more than a
    /// plain column name — TODO confirm against the parser.
    pub expr: Option<Expr>,
    /// Descending-order flag (presumably set by `DESC`).
    pub descending: bool,
}
92
/// Comparison operator of a [`Comparison`].
///
/// `where_clause_to_sql` renders these as `=`, `!=`, `<`, `>`, `<=`, `>=`,
/// `LIKE`, `NOT LIKE`, `IN`, `IS NULL` and `IS NOT NULL`; the two null tests
/// are serialized without a right-hand operand.
#[derive(Debug, Clone, PartialEq)]
pub enum CmpOp {
    Eq,
    Ne,
    Lt,
    Gt,
    Le,
    Ge,
    Like,
    NotLike,
    /// Membership test; its operand is rendered as a parenthesized list.
    In,
    IsNull,
    IsNotNull,
}
107
/// Logical connective of a [`BoolOp`], rendered as `AND` / `OR`.
#[derive(Debug, Clone, PartialEq)]
pub enum BoolOpKind {
    And,
    Or,
}
113
/// A single predicate in a condition tree.
#[derive(Debug, Clone, PartialEq)]
pub struct Comparison {
    /// Left-hand column name, used when `left_expr` is absent.
    pub column: String,
    /// Comparison operator.
    pub op: CmpOp,
    /// Literal right-hand operand; `None` for the null tests and for
    /// expression-vs-expression comparisons.
    pub value: Option<SqlValue>,
    /// Left-hand operand as an expression, when it is more than a plain column
    /// (TODO confirm exactly when the parser sets this).
    pub left_expr: Option<Expr>,
    /// Right-hand operand as an expression, when it is more than a literal
    /// (TODO confirm exactly when the parser sets this).
    pub right_expr: Option<Expr>,
}
122
/// Binary logical node joining two sub-conditions with `AND` or `OR`.
#[derive(Debug, Clone, PartialEq)]
pub struct BoolOp {
    pub op: BoolOpKind,
    pub left: Box<WhereClause>,
    pub right: Box<WhereClause>,
}
129
/// Condition tree used by WHERE, HAVING, JOIN conditions and CASE branches.
#[derive(Debug, Clone, PartialEq)]
pub enum WhereClause {
    /// Leaf predicate.
    Comparison(Comparison),
    /// `AND` / `OR` of two sub-trees.
    BoolOp(BoolOp),
}
135
/// Literal value appearing in queries and statements.
#[derive(Debug, Clone, PartialEq)]
pub enum SqlValue {
    /// Text literal, rendered single-quoted.
    String(String),
    Int(i64),
    Float(f64),
    Null,
    /// Value list, rendered as `(v1, v2, ...)`; presumably the operand of `IN`.
    List(Vec<SqlValue>),
}
144
/// Kind of join; presumably `INNER JOIN` and `LEFT JOIN` — TODO confirm.
#[derive(Debug, Clone, PartialEq)]
pub enum JoinType {
    Inner,
    Left,
}
150
/// One join attached to a [`SelectQuery`].
#[derive(Debug, Clone, PartialEq)]
pub struct JoinClause {
    pub join_type: JoinType,
    /// Name of the joined table.
    pub table: String,
    /// Optional alias for the joined table.
    pub alias: Option<String>,
    /// Join condition (presumably the `ON` clause).
    pub condition: WhereClause,
}
158
/// Aggregate function, rendered as `COUNT`, `SUM`, `AVG`, `MIN` or `MAX`.
#[derive(Debug, Clone, PartialEq)]
pub enum AggFunc {
    Count,
    Sum,
    Avg,
    Min,
    Max,
}
167
/// One item of a SELECT projection list.
#[derive(Debug, Clone, PartialEq)]
pub enum SelectExpr {
    /// Plain column, output under its own name.
    Column(String),
    /// Aggregate call; output as `alias` if given, else `FUNC(arg)`.
    /// `arg_expr` presumably holds a parsed argument expression — TODO confirm.
    Aggregate { func: AggFunc, arg: String, arg_expr: Option<Expr>, alias: Option<String> },
    /// General expression; output as `alias` if given, else `Expr::display_name`.
    Expr { expr: Expr, alias: Option<String> },
}
174
175impl SelectExpr {
176 pub fn output_name(&self) -> String {
177 match self {
178 SelectExpr::Column(name) => name.clone(),
179 SelectExpr::Aggregate { func, arg, alias, .. } => {
180 if let Some(a) = alias {
181 a.clone()
182 } else {
183 let func_name = match func {
184 AggFunc::Count => "COUNT",
185 AggFunc::Sum => "SUM",
186 AggFunc::Avg => "AVG",
187 AggFunc::Min => "MIN",
188 AggFunc::Max => "MAX",
189 };
190 format!("{}({})", func_name, arg)
191 }
192 }
193 SelectExpr::Expr { expr, alias } => {
194 alias.clone().unwrap_or_else(|| expr.display_name())
195 }
196 }
197 }
198
199 pub fn is_aggregate(&self) -> bool {
200 match self {
201 SelectExpr::Aggregate { .. } => true,
202 SelectExpr::Expr { expr, .. } => expr.contains_aggregate(),
203 _ => false,
204 }
205 }
206}
207
/// A named sub-query attached to a [`SelectQuery`] via its `ctes` list
/// (presumably a `WITH name AS (query)` clause — TODO confirm).
#[derive(Debug, Clone, PartialEq)]
pub struct CteClause {
    pub name: String,
    pub query: Box<SelectQuery>,
}
213
/// Parsed SELECT statement.
#[derive(Debug, Clone, PartialEq)]
pub struct SelectQuery {
    /// Projection list (`*` or named items).
    pub columns: ColumnList,
    /// Primary table name; this is what `Statement::table_name` returns.
    pub table: String,
    /// Optional alias for the primary table.
    pub table_alias: Option<String>,
    /// Nested query, presumably a derived table used in place of `table` —
    /// TODO confirm against the parser/executor.
    pub subquery: Option<Box<SelectQuery>>,
    /// Joins in the order they appear.
    pub joins: Vec<JoinClause>,
    pub where_clause: Option<WhereClause>,
    /// GROUP BY column names.
    pub group_by: Option<Vec<String>>,
    pub having: Option<WhereClause>,
    pub order_by: Option<Vec<OrderSpec>>,
    pub limit: Option<i64>,
    /// Named sub-queries attached to this query (see [`CteClause`]).
    pub ctes: Vec<CteClause>,
}
228
/// Projection list of a SELECT.
#[derive(Debug, Clone, PartialEq)]
pub enum ColumnList {
    /// Every column (presumably `SELECT *`).
    All,
    /// Explicit projection items.
    Named(Vec<SelectExpr>),
}
234
/// Parsed INSERT statement for a single row.
#[derive(Debug, Clone, PartialEq)]
pub struct InsertQuery {
    pub table: String,
    /// Target column names; presumably paired positionally with `values` —
    /// TODO confirm.
    pub columns: Vec<String>,
    pub values: Vec<SqlValue>,
}
241
/// Parsed UPDATE statement.
#[derive(Debug, Clone, PartialEq)]
pub struct UpdateQuery {
    pub table: String,
    /// `(column, new value)` pairs from the SET list.
    pub assignments: Vec<(String, SqlValue)>,
    /// Row filter; `None` presumably means every row — TODO confirm.
    pub where_clause: Option<WhereClause>,
}
248
/// Modifier of a DELETE statement. Presumably controls how dependent rows are
/// handled (CASCADE / RESTRICT); the exact semantics live in the executor —
/// TODO confirm.
#[derive(Debug, Clone, PartialEq)]
pub enum DeleteMode {
    Default,
    Cascade,
    Restrict,
}
255
/// Parsed DELETE statement.
#[derive(Debug, Clone, PartialEq)]
pub struct DeleteQuery {
    pub table: String,
    /// Row filter; `None` presumably means every row — TODO confirm.
    pub where_clause: Option<WhereClause>,
    pub mode: DeleteMode,
}
262
/// ALTER statement renaming field `old_name` of `table` to `new_name`.
#[derive(Debug, Clone, PartialEq)]
pub struct AlterRenameFieldQuery {
    pub table: String,
    pub old_name: String,
    pub new_name: String,
}
269
/// ALTER statement dropping `field_name` from `table`.
#[derive(Debug, Clone, PartialEq)]
pub struct AlterDropFieldQuery {
    pub table: String,
    pub field_name: String,
}
275
/// ALTER statement merging the `sources` fields of `table` into the field
/// `into`; how values are combined is decided by the executor — TODO confirm.
#[derive(Debug, Clone, PartialEq)]
pub struct AlterMergeFieldsQuery {
    pub table: String,
    pub sources: Vec<String>,
    pub into: String,
}
282
/// CREATE VIEW statement.
#[derive(Debug, Clone, PartialEq)]
pub struct CreateViewQuery {
    pub view_name: String,
    /// Optional explicit column names for the view.
    pub columns: Option<Vec<String>>,
    /// The query defining the view.
    pub query: Box<SelectQuery>,
}
289
/// DROP VIEW statement.
#[derive(Debug, Clone, PartialEq)]
pub struct DropViewQuery {
    pub view_name: String,
}
294
/// Any parsed top-level statement.
#[derive(Debug, Clone, PartialEq)]
pub enum Statement {
    Select(SelectQuery),
    Insert(InsertQuery),
    Update(UpdateQuery),
    Delete(DeleteQuery),
    /// Rename a field.
    AlterRename(AlterRenameFieldQuery),
    /// Drop a field.
    AlterDrop(AlterDropFieldQuery),
    /// Merge several fields into one.
    AlterMerge(AlterMergeFieldsQuery),
    CreateView(CreateViewQuery),
    DropView(DropViewQuery),
}
307
308impl Statement {
309 pub fn table_name(&self) -> &str {
310 match self {
311 Statement::Select(q) => &q.table,
312 Statement::Insert(q) => &q.table,
313 Statement::Update(q) => &q.table,
314 Statement::Delete(q) => &q.table,
315 Statement::AlterRename(q) => &q.table,
316 Statement::AlterDrop(q) => &q.table,
317 Statement::AlterMerge(q) => &q.table,
318 Statement::CreateView(q) => &q.view_name,
319 Statement::DropView(q) => &q.view_name,
320 }
321 }
322}
323
324pub fn where_clause_to_sql(clause: &WhereClause) -> String {
325 match clause {
326 WhereClause::BoolOp(bop) => {
327 let left = where_clause_to_sql(&bop.left);
328 let right = where_clause_to_sql(&bop.right);
329 let op = match bop.op {
330 BoolOpKind::And => "AND",
331 BoolOpKind::Or => "OR",
332 };
333 format!("{} {} {}", left, op, right)
334 }
335 WhereClause::Comparison(cmp) => {
336 let op_str = match cmp.op {
337 CmpOp::Eq => "=",
338 CmpOp::Ne => "!=",
339 CmpOp::Lt => "<",
340 CmpOp::Gt => ">",
341 CmpOp::Le => "<=",
342 CmpOp::Ge => ">=",
343 CmpOp::Like => "LIKE",
344 CmpOp::NotLike => "NOT LIKE",
345 CmpOp::In => "IN",
346 CmpOp::IsNull => "IS NULL",
347 CmpOp::IsNotNull => "IS NOT NULL",
348 };
349 if matches!(cmp.op, CmpOp::IsNull | CmpOp::IsNotNull) {
350 if let Some(ref expr) = cmp.left_expr {
351 return format!("{} {}", expr.display_name(), op_str);
352 }
353 return format!("{} {}", cmp.column, op_str);
354 }
355 if let (Some(ref left), Some(ref right)) = (&cmp.left_expr, &cmp.right_expr) {
356 return format!("{} {} {}", left.display_name(), op_str, right.display_name());
357 }
358 match &cmp.value {
359 Some(SqlValue::String(s)) => format!("{} {} '{}'", cmp.column, op_str, s),
360 Some(SqlValue::Int(n)) => format!("{} {} {}", cmp.column, op_str, n),
361 Some(SqlValue::Float(f)) => format!("{} {} {}", cmp.column, op_str, f),
362 Some(SqlValue::Null) => format!("{} {} NULL", cmp.column, op_str),
363 Some(SqlValue::List(items)) => {
364 let vals: Vec<String> = items.iter().map(|v| match v {
365 SqlValue::String(s) => format!("'{}'", s),
366 SqlValue::Int(n) => n.to_string(),
367 SqlValue::Float(f) => f.to_string(),
368 _ => "NULL".to_string(),
369 }).collect();
370 format!("{} {} ({})", cmp.column, op_str, vals.join(", "))
371 }
372 None => format!("{} {}", cmp.column, op_str),
373 }
374 }
375 }
376}