Skip to main content

alopex_sql/parser/
dml.rs

1use crate::ast::span::Spanned;
2use crate::ast::{
3    Assignment, Delete, Expr, Insert, LITERAL_TABLE, OrderByExpr, Select, SelectItem, TableRef,
4    Update,
5};
6use crate::error::{ParserError, Result};
7use crate::tokenizer::keyword::Keyword;
8use crate::tokenizer::token::{Token, Word};
9
10use super::Parser;
11
12impl<'a> Parser<'a> {
13    pub fn parse_select(&mut self) -> Result<Select> {
14        let start_span = self.expect_keyword("SELECT", Keyword::SELECT)?;
15        let distinct = self.consume_keyword(Keyword::DISTINCT);
16
17        let projection = self.parse_projection_list()?;
18        let mut end_span = projection
19            .last()
20            .map(|item| item.span())
21            .unwrap_or(start_span);
22
23        let from = if self.consume_keyword(Keyword::FROM) {
24            let from = self.parse_table_ref()?;
25            end_span = from.span;
26            from
27        } else {
28            validate_literal_projection(&projection)?;
29            TableRef {
30                name: LITERAL_TABLE.to_string(),
31                alias: None,
32                span: end_span,
33            }
34        };
35
36        let selection = if self.consume_keyword(Keyword::WHERE) {
37            let expr = self.parse_expr()?;
38            end_span = end_span.union(&expr.span());
39            Some(expr)
40        } else {
41            None
42        };
43
44        let group_by = self.parse_group_by()?;
45        if let Some(items) = &group_by
46            && let Some(last) = items.last()
47        {
48            end_span = end_span.union(&last.span());
49        }
50
51        if matches!(
52            self.peek().token,
53            Token::Word(Word {
54                keyword: Keyword::HAVING,
55                ..
56            })
57        ) && group_by.is_none()
58        {
59            let tok = self.peek().clone();
60            return Err(ParserError::ExpectedToken {
61                line: tok.span.start.line,
62                column: tok.span.start.column,
63                expected: "GROUP BY".to_string(),
64                found: "HAVING".to_string(),
65            });
66        }
67
68        let having = self.parse_having()?;
69        if let Some(expr) = &having {
70            end_span = end_span.union(&expr.span());
71        }
72
73        let order_by = if self.consume_keyword(Keyword::ORDER) {
74            self.expect_keyword("BY", Keyword::BY)?;
75            let items = self.parse_order_by()?;
76            if let Some(last) = items.last() {
77                end_span = end_span.union(&last.span);
78            }
79            items
80        } else {
81            Vec::new()
82        };
83
84        let mut limit = None;
85        let mut offset = None;
86        if self.consume_keyword(Keyword::LIMIT) {
87            let lim = self.parse_expr()?;
88            end_span = end_span.union(&lim.span());
89            limit = Some(lim);
90
91            if self.consume_keyword(Keyword::OFFSET) {
92                let off = self.parse_expr()?;
93                end_span = end_span.union(&off.span());
94                offset = Some(off);
95            }
96        }
97
98        let span = start_span.union(&end_span);
99        Ok(Select {
100            distinct,
101            projection,
102            from,
103            selection,
104            group_by,
105            having,
106            order_by,
107            limit,
108            offset,
109            span,
110        })
111    }
112
113    fn parse_projection_list(&mut self) -> Result<Vec<SelectItem>> {
114        let mut items = Vec::new();
115        loop {
116            items.push(self.parse_select_item()?);
117            if matches!(self.peek().token, Token::Comma) {
118                self.advance();
119                continue;
120            }
121            break;
122        }
123        Ok(items)
124    }
125
126    fn parse_select_item(&mut self) -> Result<SelectItem> {
127        let tok = self.peek().clone();
128        if matches!(tok.token, Token::Mul) {
129            self.advance();
130            return Ok(SelectItem::Wildcard { span: tok.span });
131        }
132
133        let expr = self.parse_expr()?;
134        let mut span = expr.span();
135        let mut alias = None;
136
137        if self.consume_keyword(Keyword::AS) {
138            let (name, alias_span) = self.parse_identifier()?;
139            span = span.union(&alias_span);
140            alias = Some(name);
141        } else if let Token::Word(Word {
142            keyword: Keyword::NoKeyword,
143            ..
144        }) = &self.peek().token
145        {
146            let alias_tok = self.advance();
147            if let Token::Word(Word { value, .. }) = alias_tok.token {
148                span = span.union(&alias_tok.span);
149                alias = Some(value);
150            }
151        }
152
153        Ok(SelectItem::Expr { expr, alias, span })
154    }
155
156    fn parse_table_ref(&mut self) -> Result<TableRef> {
157        let (name, name_span) = self.parse_identifier()?;
158        let mut alias = None;
159        let mut span = name_span;
160
161        if self.consume_keyword(Keyword::AS) {
162            let (a, alias_span) = self.parse_identifier()?;
163            alias = Some(a);
164            span = span.union(&alias_span);
165        } else if let Token::Word(Word {
166            keyword: Keyword::NoKeyword,
167            ..
168        }) = &self.peek().token
169        {
170            let alias_tok = self.advance();
171            if let Token::Word(Word { value, .. }) = alias_tok.token {
172                span = span.union(&alias_tok.span);
173                alias = Some(value);
174            }
175        }
176
177        Ok(TableRef { name, alias, span })
178    }
179
180    fn parse_order_by(&mut self) -> Result<Vec<OrderByExpr>> {
181        let mut items = Vec::new();
182        loop {
183            let expr = self.parse_expr()?;
184            let mut span = expr.span();
185            let mut asc = None;
186            let mut nulls_first = None;
187
188            if let Token::Word(Word { keyword, .. }) = &self.peek().token {
189                match keyword {
190                    Keyword::ASC => {
191                        let s = self.advance().span;
192                        span = span.union(&s);
193                        asc = Some(true);
194                    }
195                    Keyword::DESC => {
196                        let s = self.advance().span;
197                        span = span.union(&s);
198                        asc = Some(false);
199                    }
200                    _ => {}
201                }
202            }
203
204            if let Token::Word(Word {
205                keyword: Keyword::NULLS,
206                ..
207            }) = &self.peek().token
208            {
209                let nulls_tok = self.advance();
210                let dir_tok = self.expect_token("FIRST or LAST", |t| {
211                    matches!(
212                        t,
213                        Token::Word(Word {
214                            keyword: Keyword::FIRST | Keyword::LAST,
215                            ..
216                        })
217                    )
218                })?;
219                nulls_first = Some(matches!(
220                    dir_tok.token,
221                    Token::Word(Word {
222                        keyword: Keyword::FIRST,
223                        ..
224                    })
225                ));
226                span = span.union(&nulls_tok.span).union(&dir_tok.span);
227            }
228
229            items.push(OrderByExpr {
230                expr,
231                asc,
232                nulls_first,
233                span,
234            });
235
236            if matches!(self.peek().token, Token::Comma) {
237                self.advance();
238                continue;
239            }
240            break;
241        }
242
243        Ok(items)
244    }
245
246    fn parse_group_by(&mut self) -> Result<Option<Vec<Expr>>> {
247        if !self.consume_keyword(Keyword::GROUP) {
248            return Ok(None);
249        }
250
251        self.expect_keyword("BY", Keyword::BY)?;
252        let mut items = Vec::new();
253        loop {
254            items.push(self.parse_expr()?);
255            if matches!(self.peek().token, Token::Comma) {
256                self.advance();
257                continue;
258            }
259            break;
260        }
261
262        Ok(Some(items))
263    }
264
265    fn parse_having(&mut self) -> Result<Option<Expr>> {
266        if !self.consume_keyword(Keyword::HAVING) {
267            return Ok(None);
268        }
269
270        let expr = self.parse_expr()?;
271        Ok(Some(expr))
272    }
273
274    pub fn parse_insert(&mut self) -> Result<Insert> {
275        let start_span = self.expect_keyword("INSERT", Keyword::INSERT)?;
276        self.expect_keyword("INTO", Keyword::INTO)?;
277        let (table, table_span) = self.parse_identifier()?;
278        let mut end_span = table_span;
279        let mut columns = None;
280
281        if matches!(self.peek().token, Token::LParen) {
282            self.advance();
283            let mut cols = Vec::new();
284            loop {
285                let (col, col_span) = self.parse_identifier()?;
286                end_span = end_span.union(&col_span);
287                cols.push(col);
288                if matches!(self.peek().token, Token::Comma) {
289                    self.advance();
290                    continue;
291                }
292                break;
293            }
294            let close = self
295                .expect_token("')'", |t| matches!(t, Token::RParen))?
296                .span;
297            end_span = end_span.union(&close);
298            columns = Some(cols);
299        }
300
301        self.expect_keyword("VALUES", Keyword::VALUES)?;
302        let mut values = Vec::new();
303        loop {
304            self.expect_token("'('", |t| matches!(t, Token::LParen))?;
305            let mut row = Vec::new();
306            row.push(self.parse_expr()?);
307            while matches!(self.peek().token, Token::Comma) {
308                self.advance();
309                row.push(self.parse_expr()?);
310            }
311            let row_end = self
312                .expect_token("')'", |t| matches!(t, Token::RParen))?
313                .span;
314            end_span = end_span.union(&row_end);
315            values.push(row);
316
317            if matches!(self.peek().token, Token::Comma) {
318                self.advance();
319                continue;
320            }
321            break;
322        }
323
324        let span = start_span.union(&end_span);
325        Ok(Insert {
326            table,
327            columns,
328            values,
329            span,
330        })
331    }
332
333    pub fn parse_update(&mut self) -> Result<Update> {
334        let start_span = self.expect_keyword("UPDATE", Keyword::UPDATE)?;
335        let (table, table_span) = self.parse_identifier()?;
336        self.expect_keyword("SET", Keyword::SET)?;
337
338        let mut assignments = Vec::new();
339        loop {
340            let (column, col_span) = self.parse_identifier()?;
341            self.expect_token("'='", |t| matches!(t, Token::Eq))?;
342            let value = self.parse_expr()?;
343            let span = col_span.union(&value.span());
344            assignments.push(Assignment {
345                column,
346                value,
347                span,
348            });
349
350            if matches!(self.peek().token, Token::Comma) {
351                self.advance();
352                continue;
353            }
354            break;
355        }
356
357        let mut end_span = assignments.last().map(|a| a.span).unwrap_or(table_span);
358
359        let selection = if self.consume_keyword(Keyword::WHERE) {
360            let expr = self.parse_expr()?;
361            end_span = end_span.union(&expr.span());
362            Some(expr)
363        } else {
364            None
365        };
366
367        let span = start_span.union(&end_span);
368        Ok(Update {
369            table,
370            assignments,
371            selection,
372            span,
373        })
374    }
375
376    pub fn parse_delete(&mut self) -> Result<Delete> {
377        let start_span = self.expect_keyword("DELETE", Keyword::DELETE)?;
378        self.expect_keyword("FROM", Keyword::FROM)?;
379        let (table, table_span) = self.parse_identifier()?;
380        let mut end_span = table_span;
381
382        let selection = if self.consume_keyword(Keyword::WHERE) {
383            let expr = self.parse_expr()?;
384            end_span = end_span.union(&expr.span());
385            Some(expr)
386        } else {
387            None
388        };
389
390        let span = start_span.union(&end_span);
391        Ok(Delete {
392            table,
393            selection,
394            span,
395        })
396    }
397}
398
399fn validate_literal_projection(items: &[SelectItem]) -> Result<()> {
400    for item in items {
401        match item {
402            SelectItem::Wildcard { span } => {
403                return Err(ParserError::UnexpectedToken {
404                    line: span.start.line,
405                    column: span.start.column,
406                    expected: "literal expression".to_string(),
407                    found: "*".to_string(),
408                });
409            }
410            SelectItem::Expr { expr, .. } => {
411                if expr_contains_column_ref(expr) {
412                    return Err(ParserError::UnexpectedToken {
413                        line: expr.span.start.line,
414                        column: expr.span.start.column,
415                        expected: "literal expression".to_string(),
416                        found: "column reference".to_string(),
417                    });
418                }
419            }
420        }
421    }
422    Ok(())
423}
424
425fn expr_contains_column_ref(expr: &crate::ast::Expr) -> bool {
426    use crate::ast::expr::ExprKind;
427    match &expr.kind {
428        ExprKind::ColumnRef { .. } => true,
429        ExprKind::BinaryOp { left, right, .. } => {
430            expr_contains_column_ref(left) || expr_contains_column_ref(right)
431        }
432        ExprKind::UnaryOp { operand, .. } => expr_contains_column_ref(operand),
433        ExprKind::FunctionCall { args, .. } => args.iter().any(expr_contains_column_ref),
434        ExprKind::Between {
435            expr, low, high, ..
436        } => {
437            expr_contains_column_ref(expr)
438                || expr_contains_column_ref(low)
439                || expr_contains_column_ref(high)
440        }
441        ExprKind::Like {
442            expr,
443            pattern,
444            escape,
445            ..
446        } => {
447            expr_contains_column_ref(expr)
448                || expr_contains_column_ref(pattern)
449                || escape
450                    .as_ref()
451                    .is_some_and(|expr| expr_contains_column_ref(expr))
452        }
453        ExprKind::InList { expr, list, .. } => {
454            expr_contains_column_ref(expr) || list.iter().any(expr_contains_column_ref)
455        }
456        ExprKind::IsNull { expr, .. } => expr_contains_column_ref(expr),
457        ExprKind::Literal(_) | ExprKind::VectorLiteral(_) => false,
458    }
459}
460
#[cfg(test)]
mod tests {
    use crate::ast::{LITERAL_TABLE, SelectItem};
    use crate::{AlopexDialect, Parser, ParserError, Tokenizer};

    /// Tokenizes `sql` with the Alopex dialect and hands a fresh parser to `run`.
    fn with_parser<T>(sql: &str, run: impl FnOnce(&mut Parser<'_>) -> T) -> T {
        let dialect = AlopexDialect;
        let tokens = Tokenizer::new(&dialect, sql).tokenize().unwrap();
        let mut parser = Parser::new(&dialect, tokens);
        run(&mut parser)
    }

    fn parse_select(sql: &str) -> crate::ast::dml::Select {
        with_parser(sql, |parser| parser.parse_select().unwrap())
    }

    fn parse_select_err(sql: &str) -> ParserError {
        with_parser(sql, |parser| parser.parse_select().unwrap_err())
    }

    #[test]
    fn parse_group_by_single_column() {
        let select = parse_select("SELECT id, COUNT(*) FROM users GROUP BY id");
        assert_eq!(select.group_by.as_ref().map(Vec::len), Some(1));
        assert!(select.having.is_none());
    }

    #[test]
    fn parse_group_by_multi_column() {
        let select = parse_select("SELECT id, name FROM users GROUP BY id, name");
        assert_eq!(select.group_by.as_ref().map(Vec::len), Some(2));
    }

    #[test]
    fn parse_having_with_aggregate_condition() {
        let select = parse_select("SELECT id, COUNT(*) FROM users GROUP BY id HAVING COUNT(*) > 5");
        assert_eq!(select.group_by.as_ref().map(Vec::len), Some(1));
        assert!(select.having.is_some());
    }

    #[test]
    fn parse_group_by_without_having() {
        let select = parse_select("SELECT id FROM users GROUP BY id");
        assert!(select.having.is_none());
    }

    #[test]
    fn parse_group_by_with_order_by() {
        let select = parse_select("SELECT id, COUNT(*) FROM users GROUP BY id ORDER BY id DESC");
        assert_eq!(select.group_by.as_ref().map(Vec::len), Some(1));
        assert_eq!(select.order_by.len(), 1);
    }

    #[test]
    fn parse_group_by_requires_expression() {
        let err = parse_select_err("SELECT id FROM users GROUP BY");
        match err {
            ParserError::UnexpectedToken { expected, .. } => {
                assert_eq!(expected, "expression");
            }
            other => panic!("unexpected error {:?}", other),
        }
    }

    #[test]
    fn parse_having_without_group_by() {
        let err = parse_select_err("SELECT COUNT(*) FROM users HAVING COUNT(*) > 1");
        match err {
            ParserError::ExpectedToken {
                expected, found, ..
            } => {
                assert_eq!(expected, "GROUP BY");
                assert_eq!(found, "HAVING");
            }
            other => panic!("unexpected error {:?}", other),
        }
    }

    #[test]
    fn parse_select_without_from_targets_literal_table() {
        let select = parse_select("SELECT 1 + 2");
        assert_eq!(select.from.name, LITERAL_TABLE);
        assert!(select.from.alias.is_none());
    }

    #[test]
    fn parse_select_without_from_rejects_wildcard() {
        match parse_select_err("SELECT *") {
            ParserError::UnexpectedToken {
                expected, found, ..
            } => {
                assert_eq!(expected, "literal expression");
                assert_eq!(found, "*");
            }
            other => panic!("unexpected error {:?}", other),
        }
    }

    #[test]
    fn parse_select_without_from_rejects_column_reference() {
        match parse_select_err("SELECT id") {
            ParserError::UnexpectedToken {
                expected, found, ..
            } => {
                assert_eq!(expected, "literal expression");
                assert_eq!(found, "column reference");
            }
            other => panic!("unexpected error {:?}", other),
        }
    }

    #[test]
    fn parse_select_and_table_aliases() {
        let select = parse_select("SELECT id AS uid, email e FROM users u");
        let aliases: Vec<Option<&str>> = select
            .projection
            .iter()
            .map(|item| match item {
                SelectItem::Expr { alias, .. } => alias.as_deref(),
                SelectItem::Wildcard { .. } => None,
            })
            .collect();
        assert_eq!(aliases, vec![Some("uid"), Some("e")]);
        assert_eq!(select.from.name, "users");
        assert_eq!(select.from.alias.as_deref(), Some("u"));
    }

    #[test]
    fn parse_order_by_direction_and_nulls() {
        let select = parse_select("SELECT id FROM users ORDER BY id DESC NULLS FIRST, email");
        assert_eq!(select.order_by.len(), 2);
        assert_eq!(select.order_by[0].asc, Some(false));
        assert_eq!(select.order_by[0].nulls_first, Some(true));
        assert_eq!(select.order_by[1].asc, None);
        assert_eq!(select.order_by[1].nulls_first, None);
    }

    #[test]
    fn parse_limit_with_and_without_offset() {
        let with_offset = parse_select("SELECT id FROM users LIMIT 10 OFFSET 5");
        assert!(with_offset.limit.is_some());
        assert!(with_offset.offset.is_some());

        let limit_only = parse_select("SELECT id FROM users LIMIT 10");
        assert!(limit_only.limit.is_some());
        assert!(limit_only.offset.is_none());
    }

    #[test]
    fn parse_insert_with_columns_and_multiple_rows() {
        let insert = with_parser(
            "INSERT INTO users (id, email) VALUES (1, 2), (3, 4)",
            |parser| parser.parse_insert().unwrap(),
        );
        assert_eq!(insert.table, "users");
        assert_eq!(
            insert.columns,
            Some(vec!["id".to_string(), "email".to_string()])
        );
        assert_eq!(insert.values.len(), 2);
        assert!(insert.values.iter().all(|row| row.len() == 2));
    }

    #[test]
    fn parse_insert_without_column_list() {
        let insert = with_parser("INSERT INTO users VALUES (1)", |parser| {
            parser.parse_insert().unwrap()
        });
        assert!(insert.columns.is_none());
        assert_eq!(insert.values.len(), 1);
    }

    #[test]
    fn parse_update_with_assignments_and_where() {
        let update = with_parser("UPDATE users SET id = 2, email = 3 WHERE id = 1", |parser| {
            parser.parse_update().unwrap()
        });
        assert_eq!(update.table, "users");
        let columns: Vec<&str> = update
            .assignments
            .iter()
            .map(|assignment| assignment.column.as_str())
            .collect();
        assert_eq!(columns, vec!["id", "email"]);
        assert!(update.selection.is_some());
    }

    #[test]
    fn parse_delete_with_and_without_where() {
        let filtered = with_parser("DELETE FROM users WHERE id = 1", |parser| {
            parser.parse_delete().unwrap()
        });
        assert_eq!(filtered.table, "users");
        assert!(filtered.selection.is_some());

        let everything = with_parser("DELETE FROM users", |parser| {
            parser.parse_delete().unwrap()
        });
        assert!(everything.selection.is_none());
    }
}