Skip to main content

oak_sql/parser/
mod.rs

1pub mod element_type;
2
3use crate::{SqlElementType, SqlLanguage};
4use oak_core::{
5    GreenNode, OakError, Parser, ParserState, TextEdit, TokenType,
6    parser::{
7        ParseCache, ParseOutput, parse_with_lexer,
8        pratt::{Associativity, Pratt, PrattParser, binary},
9    },
10    source::Source,
11};
12
/// SQL parser.
///
/// Borrows a [`SqlLanguage`] configuration for the lifetime `'config`.
pub struct SqlParser<'config> {
    /// Language configuration; handed to the lexer when a parse is started.
    pub(crate) config: &'config SqlLanguage,
}
17
/// Shorthand for the [`ParserState`] specialised to [`SqlLanguage`] over a source `S`.
type State<'a, S> = ParserState<'a, SqlLanguage, S>;
19
20impl<'config> Pratt<SqlLanguage> for SqlParser<'config> {
21    fn primary<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> &'a GreenNode<'a, SqlLanguage> {
22        use crate::lexer::SqlTokenType::*;
23        let cp = state.checkpoint();
24        match state.peek_kind() {
25            Some(Identifier_) => {
26                state.bump();
27                state.finish_at(cp, SqlElementType::Identifier)
28            }
29            Some(NumberLiteral) | Some(StringLiteral) | Some(BooleanLiteral) | Some(NullLiteral) => {
30                state.bump();
31                state.finish_at(cp, SqlElementType::Expression)
32            }
33            Some(LeftParen) => {
34                state.bump();
35                PrattParser::parse(state, 0, self);
36                state.expect(RightParen).ok();
37                state.finish_at(cp, SqlElementType::Expression)
38            }
39            _ => {
40                state.bump();
41                state.finish_at(cp, SqlElementType::ErrorNode)
42            }
43        }
44    }
45
46    fn prefix<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> &'a GreenNode<'a, SqlLanguage> {
47        self.primary(state)
48    }
49
50    fn infix<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>, left: &'a GreenNode<'a, SqlLanguage>, min_precedence: u8) -> Option<&'a GreenNode<'a, SqlLanguage>> {
51        use crate::lexer::SqlTokenType::*;
52        let kind = state.peek_kind()?;
53
54        let (prec, assoc) = match kind {
55            Or => (1, Associativity::Left),
56            And => (2, Associativity::Left),
57            Equal | NotEqual | Less | Greater | LessEqual | GreaterEqual | Like | In | Between | Is => (3, Associativity::Left),
58            Plus | Minus => (10, Associativity::Left),
59            Star | Slash | Percent => (11, Associativity::Left),
60            _ => return None,
61        };
62
63        if prec < min_precedence {
64            return None;
65        }
66
67        Some(binary(state, left, kind, prec, assoc, Expression.into(), |s, p| PrattParser::parse(s, p, self)))
68    }
69}
70
71impl<'config> SqlParser<'config> {
72    pub fn new(config: &'config SqlLanguage) -> Self {
73        Self { config }
74    }
75
76    pub(crate) fn parse_root_internal<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> Result<&'a GreenNode<'a, SqlLanguage>, OakError> {
77        let cp = state.checkpoint();
78        while state.not_at_end() {
79            if state.current().map(|t| t.kind.is_ignored()).unwrap_or(false) {
80                state.advance();
81                continue;
82            }
83            self.parse_statement(state)?
84        }
85        Ok(state.finish_at(cp, SqlElementType::Root))
86    }
87
88    fn parse_statement<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> Result<(), OakError> {
89        use crate::lexer::SqlTokenType::*;
90        match state.peek_kind() {
91            Some(Select) => self.parse_select(state)?,
92            Some(Insert) => self.parse_insert(state)?,
93            Some(Update) => self.parse_update(state)?,
94            Some(Delete) => self.parse_delete(state)?,
95            Some(Create) => self.parse_create(state)?,
96            Some(Drop) => self.parse_drop(state)?,
97            Some(Alter) => self.parse_alter(state)?,
98            _ => {
99                state.advance_until(Semicolon);
100                state.eat(Semicolon);
101            }
102        }
103        Ok(())
104    }
105
106    fn parse_select<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> Result<(), OakError> {
107        use crate::lexer::SqlTokenType::*;
108        let cp = state.checkpoint();
109        state.expect(Select).ok();
110
111        // Parse Select Items
112        while state.not_at_end() && state.peek_kind() != Some(From) {
113            PrattParser::parse(state, 0, self);
114            state.eat(Comma);
115        }
116
117        if state.eat(From) {
118            state.expect(Identifier_).ok(); // TableName
119
120            // Parse JOIN clauses
121            while let Some(kind) = state.peek_kind() {
122                if matches!(kind, Join | Inner | Left | Right | Full) {
123                    let join_cp = state.checkpoint();
124                    if kind != Join {
125                        state.bump(); // Inner, Left, etc.
126                        state.eat(Outer);
127                    }
128                    state.expect(Join).ok();
129                    state.expect(Identifier_).ok(); // Joined TableName
130                    if state.eat(On) {
131                        PrattParser::parse(state, 0, self); // Join condition
132                    }
133                    state.finish_at(join_cp, SqlElementType::JoinClause);
134                }
135                else {
136                    break;
137                }
138            }
139        }
140
141        if state.eat(Where) {
142            PrattParser::parse(state, 0, self);
143        }
144
145        if state.eat(Group) {
146            let group_cp = state.checkpoint();
147            state.expect(By).ok();
148            while state.not_at_end() {
149                PrattParser::parse(state, 0, self);
150                if !state.eat(Comma) {
151                    break;
152                }
153            }
154            state.finish_at(group_cp, SqlElementType::GroupByClause);
155        }
156
157        if state.eat(Having) {
158            let having_cp = state.checkpoint();
159            PrattParser::parse(state, 0, self);
160            state.finish_at(having_cp, SqlElementType::HavingClause);
161        }
162
163        if state.eat(Order) {
164            let order_cp = state.checkpoint();
165            state.expect(By).ok();
166            while state.not_at_end() {
167                PrattParser::parse(state, 0, self);
168                if state.eat(Asc) || state.eat(Desc) {
169                    // Handled
170                }
171                if !state.eat(Comma) {
172                    break;
173                }
174            }
175            state.finish_at(order_cp, SqlElementType::OrderByClause);
176        }
177
178        if state.eat(Limit) {
179            let limit_cp = state.checkpoint();
180            state.expect(NumberLiteral).ok();
181            if state.eat(Offset) {
182                state.expect(NumberLiteral).ok();
183            }
184            state.finish_at(limit_cp, SqlElementType::LimitClause);
185        }
186        else if state.eat(Offset) {
187            let offset_cp = state.checkpoint();
188            state.expect(NumberLiteral).ok();
189            state.finish_at(offset_cp, SqlElementType::LimitClause);
190        }
191
192        state.eat(Semicolon);
193        state.finish_at(cp, SqlElementType::SelectStatement);
194        Ok(())
195    }
196
197    fn parse_insert<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> Result<(), OakError> {
198        use crate::lexer::SqlTokenType::*;
199        let cp = state.checkpoint();
200        state.expect(Insert).ok();
201        state.eat(Into);
202        state.expect(Identifier_).ok(); // TableName
203
204        if state.eat(Values) {
205            if state.eat(LeftParen) {
206                while state.not_at_end() && state.peek_kind() != Some(RightParen) {
207                    PrattParser::parse(state, 0, self);
208                    state.eat(Comma);
209                }
210                state.expect(RightParen).ok();
211            }
212        }
213
214        state.eat(Semicolon);
215        state.finish_at(cp, SqlElementType::InsertStatement);
216        Ok(())
217    }
218
219    fn parse_update<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> Result<(), OakError> {
220        use crate::lexer::SqlTokenType::*;
221        let cp = state.checkpoint();
222        state.expect(Update).ok();
223        state.expect(Identifier_).ok(); // TableName
224
225        if state.eat(Set) {
226            while state.not_at_end() && state.peek_kind() != Some(Where) && state.peek_kind() != Some(Semicolon) {
227                state.expect(Identifier_).ok(); // Column
228                state.expect(Equal).ok();
229                PrattParser::parse(state, 0, self);
230                if !state.eat(Comma) {
231                    break;
232                }
233            }
234        }
235
236        if state.eat(Where) {
237            PrattParser::parse(state, 0, self);
238        }
239
240        state.eat(Semicolon);
241        state.finish_at(cp, SqlElementType::UpdateStatement);
242        Ok(())
243    }
244
245    fn parse_delete<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> Result<(), OakError> {
246        use crate::lexer::SqlTokenType::*;
247        let cp = state.checkpoint();
248        state.expect(Delete).ok();
249        state.eat(From);
250        state.expect(Identifier_).ok(); // TableName
251
252        if state.eat(Where) {
253            PrattParser::parse(state, 0, self);
254        }
255
256        state.eat(Semicolon);
257        state.finish_at(cp, SqlElementType::DeleteStatement);
258        Ok(())
259    }
260
261    fn parse_create<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> Result<(), OakError> {
262        use crate::lexer::SqlTokenType::*;
263        let cp = state.checkpoint();
264        state.bump(); // create
265        state.advance_until(Semicolon);
266        state.eat(Semicolon);
267        state.finish_at(cp, SqlElementType::CreateStatement);
268        Ok(())
269    }
270
271    fn parse_drop<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> Result<(), OakError> {
272        use crate::lexer::SqlTokenType::*;
273        let cp = state.checkpoint();
274        state.bump(); // drop
275        state.advance_until(Semicolon);
276        state.eat(Semicolon);
277        state.finish_at(cp, SqlElementType::DropStatement);
278        Ok(())
279    }
280
281    fn parse_alter<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> Result<(), OakError> {
282        use crate::lexer::SqlTokenType::*;
283        let cp = state.checkpoint();
284        state.bump(); // alter
285        state.advance_until(Semicolon);
286        state.eat(Semicolon);
287        state.finish_at(cp, SqlElementType::AlterStatement);
288        Ok(())
289    }
290}
291
292impl<'config> Parser<SqlLanguage> for SqlParser<'config> {
293    fn parse<'a, S: Source + ?Sized>(&self, text: &'a S, edits: &[TextEdit], cache: &'a mut impl ParseCache<SqlLanguage>) -> ParseOutput<'a, SqlLanguage> {
294        let lexer = crate::lexer::SqlLexer::new(&self.config);
295        parse_with_lexer(&lexer, text, edits, cache, |state| self.parse_root_internal(state))
296    }
297}