restq/
ast.rs

1pub mod ddl;
2pub mod dml;
3mod expr;
4mod operator;
5pub mod parser;
6mod table;
7mod value;
8
9use crate::Error;
10pub use ddl::{AlterTable, DropTable, Foreign, TableDef};
11pub use dml::{BulkDelete, BulkUpdate, Delete, Insert, Update};
12pub use expr::{BinaryOperation, Expr, ExprRename};
13pub use operator::Operator;
14use serde::{Deserialize, Serialize};
15use sqlparser::ast as sql;
16use std::fmt;
17pub use table::{FromTable, JoinType, TableError, TableLookup, TableName};
18pub use value::Value;
19
20#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
21pub enum Statement {
22    Select(Select),
23    Insert(Insert),
24    Update(Update),
25    BulkUpdate(BulkUpdate),
26    Delete(Delete),
27    BulkDelete(BulkDelete),
28    Create(TableDef),
29    DropTable(DropTable),
30    AlterTable(AlterTable),
31}
32
33#[derive(Debug, PartialEq, Default, Clone, Serialize, Deserialize)]
34pub struct Select {
35    pub from_table: FromTable,
36    pub filter: Option<Expr>,
37    pub group_by: Option<Vec<Expr>>,
38    pub having: Option<Expr>,
39    pub projection: Option<Vec<ExprRename>>, // column selection
40    pub order_by: Option<Vec<Order>>,
41    pub range: Option<Range>,
42}
43
44#[derive(
45    Debug,
46    PartialEq,
47    Default,
48    Clone,
49    PartialOrd,
50    Hash,
51    Eq,
52    Ord,
53    Serialize,
54    Deserialize,
55)]
56pub struct ColumnName {
57    pub name: String,
58}
59
60#[derive(Debug, PartialEq, Default, Clone, Serialize, Deserialize)]
61pub struct Function {
62    pub name: String,
63    pub params: Vec<Expr>,
64}
65
66#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
67pub struct Order {
68    pub expr: Expr,
69    pub direction: Option<Direction>,
70}
71
72#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
73pub enum Direction {
74    Asc,
75    Desc,
76}
77
78#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
79pub enum Range {
80    Page(Page),
81    Limit(Limit),
82}
83
84impl Range {
85    pub(crate) fn limit(&self) -> i64 {
86        match self {
87            Range::Page(page) => page.page_size,
88            Range::Limit(limit) => limit.limit,
89        }
90    }
91
92    pub(crate) fn offset(&self) -> Option<i64> {
93        match self {
94            Range::Page(page) => Some((page.page - 1) * page.page_size),
95            Range::Limit(limit) => limit.offset,
96        }
97    }
98}
99
100#[derive(Debug, PartialEq, Default, Clone, Serialize, Deserialize)]
101pub struct Page {
102    pub page: i64,
103    pub page_size: i64,
104}
105
106#[derive(Debug, PartialEq, Default, Clone, Serialize, Deserialize)]
107pub struct Limit {
108    pub limit: i64,
109    pub offset: Option<i64>,
110}
111
112impl Statement {
113    pub fn into_sql_statement(
114        &self,
115        table_lookup: Option<&TableLookup>,
116    ) -> Result<sql::Statement, Error> {
117        match self {
118            Statement::Select(select) => {
119                select.into_sql_statement(table_lookup)
120            }
121            Statement::Insert(insert) => {
122                insert.into_sql_statement(table_lookup)
123            }
124            Statement::Update(update) => update.into_sql_statement(),
125            Statement::Delete(delete) => delete.into_sql_statement(),
126            Statement::BulkUpdate(_update) => todo!(),
127            Statement::BulkDelete(_delete) => todo!(),
128            Statement::Create(create) => {
129                Ok(create.into_sql_statement(table_lookup)?)
130            }
131            Statement::DropTable(drop_table) => {
132                Ok(drop_table.into_sql_statement()?)
133            }
134            Statement::AlterTable(alter_table) => {
135                let mut statements =
136                    alter_table.into_sql_statements(table_lookup)?;
137                if statements.len() == 1 {
138                    Ok(statements.remove(0))
139                } else {
140                    Err(Error::MoreThanOneStatement)
141                }
142            }
143        }
144    }
145}
146
147impl Into<Statement> for Select {
148    fn into(self) -> Statement {
149        Statement::Select(self)
150    }
151}
152
153impl Select {
154    pub fn set_page(&mut self, page: i64, page_size: i64) {
155        self.range = Some(Range::Page(Page { page, page_size }));
156    }
157
158    pub fn get_page(&self) -> Option<i64> {
159        if let Some(Range::Page(page)) = &self.range {
160            Some(page.page)
161        } else {
162            None
163        }
164    }
165
166    pub fn get_page_size(&self) -> Option<i64> {
167        if let Some(Range::Page(page)) = &self.range {
168            Some(page.page_size)
169        } else {
170            None
171        }
172    }
173
174    pub fn add_simple_filter(
175        &mut self,
176        column: ColumnName,
177        operator: Operator,
178        search_key: &str,
179    ) {
180        let simple_filter = Expr::BinaryOperation(Box::new(BinaryOperation {
181            left: Expr::Column(column),
182            operator,
183            right: Expr::Value(Value::String(search_key.to_string())),
184        }));
185
186        //TODO: need to deal with existing filters
187        self.filter = Some(simple_filter);
188    }
189
190    pub fn into_sql_select(
191        &self,
192        table_lookup: Option<&TableLookup>,
193    ) -> Result<sql::Select, Error> {
194        let select = sql::Select {
195            distinct: None,
196            projection: if let Some(projection) = self.projection.as_ref() {
197                projection
198                    .iter()
199                    .map(|proj| {
200                        if let Some(rename) = &proj.rename {
201                            sql::SelectItem::ExprWithAlias {
202                                expr: Into::into(&proj.expr),
203                                alias: sql::Ident::new(rename),
204                            }
205                        } else {
206                            sql::SelectItem::UnnamedExpr(Into::into(&proj.expr))
207                        }
208                    })
209                    .collect::<Vec<_>>()
210            } else {
211                vec![sql::SelectItem::Wildcard(
212                    sql::WildcardAdditionalOptions::default(),
213                )]
214            },
215            from: vec![self.from_table.into_table_with_joins(table_lookup)?],
216            selection: self.filter.as_ref().map(|expr| Into::into(expr)),
217            group_by: sql::GroupByExpr::Expressions(match &self.group_by {
218                Some(group_by) => {
219                    group_by.iter().map(|expr| Into::into(expr)).collect()
220                }
221                None => vec![],
222            }),
223            having: self.having.as_ref().map(|expr| Into::into(expr)),
224            cluster_by: vec![],
225            distribute_by: vec![],
226            sort_by: vec![],
227            lateral_views: vec![],
228            named_window: vec![],
229            qualify: None,
230            value_table_mode: None,
231            into: None,
232            top: None,
233        };
234        Ok(select)
235    }
236
237    /// convert the restq ast representation into sql-ast ast representation
238    pub fn into_sql_query(
239        &self,
240        table_lookup: Option<&TableLookup>,
241    ) -> Result<sql::Query, Error> {
242        let query = sql::Query {
243            with: None,
244            body: Box::new(sql::SetExpr::Select(Box::new(
245                self.into_sql_select(table_lookup)?,
246            ))),
247            order_by: match &self.order_by {
248                Some(order_by) => {
249                    order_by.iter().map(|expr| Into::into(expr)).collect()
250                }
251                None => vec![],
252            },
253            limit: self.range.as_ref().map(|range| {
254                sql::Expr::Value(sql::Value::Number(
255                    range.limit().to_string(),
256                    false,
257                ))
258            }),
259            offset: match &self.range {
260                Some(range) => range.offset().map(|offset| sql::Offset {
261                    value: sql::Expr::Value(sql::Value::Number(
262                        offset.to_string(),
263                        false,
264                    )),
265                    rows: sql::OffsetRows::None,
266                }),
267                None => None,
268            },
269            limit_by: vec![],
270            fetch: None,
271            locks: vec![],
272            for_clause: None,
273        };
274
275        Ok(query)
276    }
277
278    /// convert this ast into a sql-ast Statement representation
279    pub fn into_sql_statement(
280        &self,
281        table_lookup: Option<&TableLookup>,
282    ) -> Result<sql::Statement, Error> {
283        Ok(sql::Statement::Query(Box::new(
284            self.into_sql_query(table_lookup)?,
285        )))
286    }
287}
288
289/// This converts the restq ast into a string for use in the url
290impl fmt::Display for Select {
291    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
292        self.from_table.fmt(f)?;
293        if let Some(projection) = &self.projection {
294            write!(f, "{{")?;
295            for (i, exprr) in projection.iter().enumerate() {
296                if i > 0 {
297                    write!(f, ",")?;
298                }
299                exprr.fmt(f)?;
300            }
301            write!(f, "}}")?;
302        }
303
304        if let Some(filter) = &self.filter {
305            write!(f, "?")?;
306            filter.fmt(f)?;
307        }
308
309        if let Some(group_by) = &self.group_by {
310            write!(f, "&group_by=")?;
311            for (i, expr) in group_by.iter().enumerate() {
312                if i > 0 {
313                    write!(f, ",")?;
314                }
315                expr.fmt(f)?;
316            }
317        }
318
319        if let Some(having) = &self.having {
320            write!(f, "&having=")?;
321            having.fmt(f)?;
322        }
323        if let Some(order_by) = &self.order_by {
324            write!(f, "&order_by=")?;
325            for (i, ord) in order_by.iter().enumerate() {
326                if i > 0 {
327                    write!(f, ",")?;
328                }
329                ord.fmt(f)?;
330            }
331        }
332        if let Some(range) = &self.range {
333            write!(f, "&")?;
334            range.fmt(f)?;
335        }
336
337        Ok(())
338    }
339}
340
341impl Default for Direction {
342    fn default() -> Self {
343        Direction::Asc
344    }
345}
346
347impl Into<sql::Function> for &Function {
348    fn into(self) -> sql::Function {
349        sql::Function {
350            name: sql::ObjectName(vec![sql::Ident::new(&self.name)]),
351            args: self
352                .params
353                .iter()
354                .map(|expr| {
355                    sql::FunctionArg::Unnamed(sql::FunctionArgExpr::Expr(
356                        Into::into(expr),
357                    ))
358                })
359                .collect(),
360            over: None,
361            filter: None,
362            null_treatment: None,
363            distinct: false,
364            special: false,
365            order_by: vec![],
366        }
367    }
368}
369
370impl fmt::Display for Function {
371    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
372        write!(f, "{}(", self.name)?;
373        for (i, param) in self.params.iter().enumerate() {
374            if i > 0 {
375                write!(f, ",")?;
376            }
377            write!(f, "{}", param)?;
378        }
379        write!(f, ")")
380    }
381}
382
383impl Into<sql::Ident> for &ColumnName {
384    fn into(self) -> sql::Ident {
385        sql::Ident::new(&self.name)
386    }
387}
388
389impl fmt::Display for ColumnName {
390    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
391        write!(f, "{}", self.name)
392    }
393}
394
395impl Into<sql::OrderByExpr> for &Order {
396    fn into(self) -> sql::OrderByExpr {
397        sql::OrderByExpr {
398            expr: Into::into(&self.expr),
399            asc: self.direction.as_ref().map(|direction| match direction {
400                Direction::Asc => true,
401                Direction::Desc => false,
402            }),
403            nulls_first: None,
404        }
405    }
406}
407
408impl fmt::Display for Order {
409    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
410        self.expr.fmt(f)?;
411        if let Some(direction) = &self.direction {
412            write!(f, ".")?;
413            direction.fmt(f)?;
414        }
415        Ok(())
416    }
417}
418
419impl fmt::Display for Direction {
420    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
421        match self {
422            Direction::Asc => write!(f, "asc"),
423            Direction::Desc => write!(f, "desc"),
424        }
425    }
426}
427
428impl fmt::Display for Range {
429    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
430        match self {
431            Range::Page(page) => page.fmt(f),
432            Range::Limit(limit) => limit.fmt(f),
433        }
434    }
435}
436
437impl fmt::Display for Page {
438    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
439        write!(f, "page={}&page_size={}", self.page, self.page_size)
440    }
441}
442
443impl fmt::Display for Limit {
444    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
445        write!(f, "limit={}", self.limit)?;
446        if let Some(offset) = &self.offset {
447            write!(f, "&offset={}", offset)?;
448        }
449        Ok(())
450    }
451}