Skip to main content

oak_sql/parser/
mod.rs

1use crate::{SqlLanguage, kind::SqlSyntaxKind};
2use oak_core::{
3    GreenNode, OakError, Parser, ParserState, TextEdit, TokenType,
4    parser::{
5        ParseCache, ParseOutput, parse_with_lexer,
6        pratt::{Associativity, Pratt, PrattParser, binary},
7    },
8    source::Source,
9};
10
11/// SQL 解析器
12pub struct SqlParser<'config> {
13    pub(crate) config: &'config SqlLanguage,
14}
15
16type State<'a, S> = ParserState<'a, SqlLanguage, S>;
17
18impl<'config> Pratt<SqlLanguage> for SqlParser<'config> {
19    fn primary<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> &'a GreenNode<'a, SqlLanguage> {
20        use crate::kind::SqlSyntaxKind::*;
21        let cp = state.checkpoint();
22        match state.peek_kind() {
23            Some(Identifier_) => {
24                state.bump();
25                state.finish_at(cp, Identifier.into())
26            }
27            Some(NumberLiteral) | Some(StringLiteral) | Some(BooleanLiteral) | Some(NullLiteral) => {
28                state.bump();
29                state.finish_at(cp, Expression.into())
30            }
31            Some(LeftParen) => {
32                state.bump();
33                PrattParser::parse(state, 0, self);
34                state.expect(RightParen).ok();
35                state.finish_at(cp, Expression.into())
36            }
37            _ => {
38                state.bump();
39                state.finish_at(cp, ErrorNode.into())
40            }
41        }
42    }
43
44    fn prefix<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> &'a GreenNode<'a, SqlLanguage> {
45        self.primary(state)
46    }
47
48    fn infix<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>, left: &'a GreenNode<'a, SqlLanguage>, min_precedence: u8) -> Option<&'a GreenNode<'a, SqlLanguage>> {
49        use crate::kind::SqlSyntaxKind::*;
50        let kind = state.peek_kind()?;
51
52        let (prec, assoc) = match kind {
53            Or => (1, Associativity::Left),
54            And => (2, Associativity::Left),
55            Equal | NotEqual | Less | Greater | LessEqual | GreaterEqual | Like | In | Between | Is => (3, Associativity::Left),
56            Plus | Minus => (10, Associativity::Left),
57            Star | Slash | Percent => (11, Associativity::Left),
58            _ => return None,
59        };
60
61        if prec < min_precedence {
62            return None;
63        }
64
65        Some(binary(state, left, kind, prec, assoc, Expression.into(), |s, p| PrattParser::parse(s, p, self)))
66    }
67}
68
69impl<'config> SqlParser<'config> {
70    pub fn new(config: &'config SqlLanguage) -> Self {
71        Self { config }
72    }
73
74    pub(crate) fn parse_root_internal<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> Result<&'a GreenNode<'a, SqlLanguage>, OakError> {
75        let cp = state.checkpoint();
76        while state.not_at_end() {
77            if state.current().map(|t| t.kind.is_ignored()).unwrap_or(false) {
78                state.advance();
79                continue;
80            }
81            self.parse_statement(state)?;
82        }
83        Ok(state.finish_at(cp, SqlSyntaxKind::Root.into()))
84    }
85
86    fn parse_statement<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> Result<(), OakError> {
87        use crate::kind::SqlSyntaxKind::*;
88        match state.peek_kind() {
89            Some(Select) => self.parse_select(state)?,
90            Some(Insert) => self.parse_insert(state)?,
91            Some(Update) => self.parse_update(state)?,
92            Some(Delete) => self.parse_delete(state)?,
93            Some(Create) => self.parse_create(state)?,
94            Some(Drop) => self.parse_drop(state)?,
95            Some(Alter) => self.parse_alter(state)?,
96            _ => {
97                state.advance_until(Semicolon);
98                state.eat(Semicolon);
99            }
100        }
101        Ok(())
102    }
103
104    fn parse_select<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> Result<(), OakError> {
105        use crate::kind::SqlSyntaxKind::*;
106        let cp = state.checkpoint();
107        state.expect(Select).ok();
108
109        // Parse Select Items
110        while state.not_at_end() && state.peek_kind() != Some(From) {
111            PrattParser::parse(state, 0, self);
112            state.eat(Comma);
113        }
114
115        if state.eat(From) {
116            state.expect(Identifier_).ok(); // TableName
117
118            // Parse JOIN clauses
119            while let Some(kind) = state.peek_kind() {
120                if matches!(kind, Join | Inner | Left | Right | Full) {
121                    let join_cp = state.checkpoint();
122                    if kind != Join {
123                        state.bump(); // Inner, Left, etc.
124                        state.eat(Outer);
125                    }
126                    state.expect(Join).ok();
127                    state.expect(Identifier_).ok(); // Joined TableName
128                    if state.eat(On) {
129                        PrattParser::parse(state, 0, self); // Join condition
130                    }
131                    state.finish_at(join_cp, JoinClause.into());
132                }
133                else {
134                    break;
135                }
136            }
137        }
138
139        if state.eat(Where) {
140            PrattParser::parse(state, 0, self);
141        }
142
143        if state.eat(Group) {
144            let group_cp = state.checkpoint();
145            state.expect(By).ok();
146            while state.not_at_end() {
147                PrattParser::parse(state, 0, self);
148                if !state.eat(Comma) {
149                    break;
150                }
151            }
152            state.finish_at(group_cp, GroupByClause.into());
153        }
154
155        if state.eat(Having) {
156            let having_cp = state.checkpoint();
157            PrattParser::parse(state, 0, self);
158            state.finish_at(having_cp, HavingClause.into());
159        }
160
161        if state.eat(Order) {
162            let order_cp = state.checkpoint();
163            state.expect(By).ok();
164            while state.not_at_end() {
165                PrattParser::parse(state, 0, self);
166                if state.eat(Asc) || state.eat(Desc) {
167                    // Handled
168                }
169                if !state.eat(Comma) {
170                    break;
171                }
172            }
173            state.finish_at(order_cp, OrderByClause.into());
174        }
175
176        if state.eat(Limit) {
177            let limit_cp = state.checkpoint();
178            state.expect(NumberLiteral).ok();
179            if state.eat(Offset) {
180                state.expect(NumberLiteral).ok();
181            }
182            state.finish_at(limit_cp, LimitClause.into());
183        }
184        else if state.eat(Offset) {
185            let offset_cp = state.checkpoint();
186            state.expect(NumberLiteral).ok();
187            state.finish_at(offset_cp, LimitClause.into());
188        }
189
190        state.eat(Semicolon);
191        state.finish_at(cp, SelectStatement.into());
192        Ok(())
193    }
194
195    fn parse_insert<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> Result<(), OakError> {
196        use crate::kind::SqlSyntaxKind::*;
197        let cp = state.checkpoint();
198        state.expect(Insert).ok();
199        state.eat(Into);
200        state.expect(Identifier_).ok(); // TableName
201
202        if state.eat(Values) {
203            if state.eat(LeftParen) {
204                while state.not_at_end() && state.peek_kind() != Some(RightParen) {
205                    PrattParser::parse(state, 0, self);
206                    state.eat(Comma);
207                }
208                state.expect(RightParen).ok();
209            }
210        }
211
212        state.eat(Semicolon);
213        state.finish_at(cp, InsertStatement.into());
214        Ok(())
215    }
216
217    fn parse_update<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> Result<(), OakError> {
218        use crate::kind::SqlSyntaxKind::*;
219        let cp = state.checkpoint();
220        state.expect(Update).ok();
221        state.expect(Identifier_).ok(); // TableName
222
223        if state.eat(Set) {
224            while state.not_at_end() && state.peek_kind() != Some(Where) && state.peek_kind() != Some(Semicolon) {
225                state.expect(Identifier_).ok(); // Column
226                state.expect(Equal).ok();
227                PrattParser::parse(state, 0, self);
228                if !state.eat(Comma) {
229                    break;
230                }
231            }
232        }
233
234        if state.eat(Where) {
235            PrattParser::parse(state, 0, self);
236        }
237
238        state.eat(Semicolon);
239        state.finish_at(cp, UpdateStatement.into());
240        Ok(())
241    }
242
243    fn parse_delete<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> Result<(), OakError> {
244        use crate::kind::SqlSyntaxKind::*;
245        let cp = state.checkpoint();
246        state.expect(Delete).ok();
247        state.eat(From);
248        state.expect(Identifier_).ok(); // TableName
249
250        if state.eat(Where) {
251            PrattParser::parse(state, 0, self);
252        }
253
254        state.eat(Semicolon);
255        state.finish_at(cp, DeleteStatement.into());
256        Ok(())
257    }
258
259    fn parse_create<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> Result<(), OakError> {
260        let cp = state.checkpoint();
261        state.bump(); // create
262        state.advance_until(crate::kind::SqlSyntaxKind::Semicolon);
263        state.eat(crate::kind::SqlSyntaxKind::Semicolon);
264        state.finish_at(cp, crate::kind::SqlSyntaxKind::CreateStatement.into());
265        Ok(())
266    }
267
268    fn parse_drop<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> Result<(), OakError> {
269        let cp = state.checkpoint();
270        state.bump(); // drop
271        state.advance_until(crate::kind::SqlSyntaxKind::Semicolon);
272        state.eat(crate::kind::SqlSyntaxKind::Semicolon);
273        state.finish_at(cp, crate::kind::SqlSyntaxKind::DropStatement.into());
274        Ok(())
275    }
276
277    fn parse_alter<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> Result<(), OakError> {
278        let cp = state.checkpoint();
279        state.bump(); // alter
280        state.advance_until(crate::kind::SqlSyntaxKind::Semicolon);
281        state.eat(crate::kind::SqlSyntaxKind::Semicolon);
282        state.finish_at(cp, crate::kind::SqlSyntaxKind::AlterStatement.into());
283        Ok(())
284    }
285}
286
287impl<'config> Parser<SqlLanguage> for SqlParser<'config> {
288    fn parse<'a, S: Source + ?Sized>(&self, text: &'a S, edits: &[TextEdit], cache: &'a mut impl ParseCache<SqlLanguage>) -> ParseOutput<'a, SqlLanguage> {
289        let lexer = crate::lexer::SqlLexer::new(&self.config);
290        parse_with_lexer(&lexer, text, edits, cache, |state| self.parse_root_internal(state))
291    }
292}