Skip to main content

oak_graphql/lexer/
mod.rs

1use crate::{kind::GraphQLSyntaxKind, language::GraphQLLanguage};
2use oak_core::{
3    Lexer, LexerCache, LexerState, OakError, TextEdit,
4    lexer::{CommentConfig, LexOutput, StringConfig, WhitespaceConfig},
5    source::Source,
6};
7use std::sync::LazyLock;
8
9type State<'a, S> = LexerState<'a, S, GraphQLLanguage>;
10
11static GRAPHQL_WHITESPACE: LazyLock<WhitespaceConfig> = LazyLock::new(|| WhitespaceConfig { unicode_whitespace: true });
12static GRAPHQL_COMMENT: LazyLock<CommentConfig> = LazyLock::new(|| CommentConfig { line_marker: "#", block_start: "", block_end: "", nested_blocks: false });
13static GRAPHQL_STRING: LazyLock<StringConfig> = LazyLock::new(|| StringConfig { quotes: &['"'], escape: Some('\\') });
14
15#[derive(Clone, Debug)]
16pub struct GraphQLLexer<'config> {
17    _config: &'config GraphQLLanguage,
18}
19
20impl<'config> Lexer<GraphQLLanguage> for GraphQLLexer<'config> {
21    fn lex<'a, S: Source + ?Sized>(&self, text: &S, _edits: &[TextEdit], cache: &'a mut impl LexerCache<GraphQLLanguage>) -> LexOutput<GraphQLLanguage> {
22        let mut state = LexerState::new(text);
23        let result = self.run(&mut state);
24        if result.is_ok() {
25            state.add_eof();
26        }
27        state.finish_with_cache(result, cache)
28    }
29}
30
31impl<'config> GraphQLLexer<'config> {
32    pub fn new(config: &'config GraphQLLanguage) -> Self {
33        Self { _config: config }
34    }
35
36    fn run<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> Result<(), OakError> {
37        while state.not_at_end() {
38            let safe_point = state.get_position();
39
40            if self.skip_whitespace(state) {
41                continue;
42            }
43
44            if self.skip_comment(state) {
45                continue;
46            }
47
48            if self.lex_string_literal(state) {
49                continue;
50            }
51
52            if self.lex_number_literal(state) {
53                continue;
54            }
55
56            if self.lex_identifier_or_keyword(state) {
57                continue;
58            }
59
60            if self.lex_operators(state) {
61                continue;
62            }
63
64            if self.lex_single_char_tokens(state) {
65                continue;
66            }
67
68            state.advance_if_dead_lock(safe_point);
69        }
70
71        Ok(())
72    }
73
74    /// 跳过空白字符
75    fn skip_whitespace<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> bool {
76        GRAPHQL_WHITESPACE.scan(state, GraphQLSyntaxKind::Whitespace)
77    }
78
79    /// 跳过注释
80    fn skip_comment<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> bool {
81        GRAPHQL_COMMENT.scan(state, GraphQLSyntaxKind::Comment, GraphQLSyntaxKind::Comment)
82    }
83
84    /// 词法分析字符串字面量
85    fn lex_string_literal<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> bool {
86        // 普通字符串 "..."
87        if GRAPHQL_STRING.scan(state, GraphQLSyntaxKind::StringLiteral) {
88            return true;
89        }
90
91        // 多行字符串 """..."""
92        if state.starts_with("\"\"\"") {
93            let start = state.get_position();
94            state.advance(3); // 跳过开始的 """
95
96            while state.not_at_end() {
97                if state.starts_with("\"\"\"") {
98                    state.advance(3); // 跳过结束的 """
99                    break;
100                }
101                if let Some(ch) = state.peek() {
102                    state.advance(ch.len_utf8());
103                }
104            }
105
106            let end = state.get_position();
107            state.add_token(GraphQLSyntaxKind::StringLiteral, start, end);
108            return true;
109        }
110
111        false
112    }
113
114    /// 词法分析数字字面量
115    fn lex_number_literal<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> bool {
116        let start = state.get_position();
117        let mut has_digits = false;
118        let mut is_float = false;
119
120        // 处理负号
121        if state.starts_with("-") {
122            state.advance(1);
123        }
124
125        // 处理整数部分
126        if state.starts_with("0") {
127            // 单独的 0
128            state.advance(1);
129            has_digits = true;
130        }
131        else {
132            // 非零开头的数字
133            while let Some(ch) = state.peek() {
134                if ch.is_ascii_digit() {
135                    state.advance(ch.len_utf8());
136                    has_digits = true;
137                }
138                else {
139                    break;
140                }
141            }
142        }
143
144        // 处理小数部分
145        if state.starts_with(".") && has_digits {
146            if let Some(next_ch) = state.peek_next_n(1) {
147                if next_ch.is_ascii_digit() {
148                    state.advance(1); // 跳过 .
149                    is_float = true;
150
151                    while let Some(ch) = state.peek() {
152                        if ch.is_ascii_digit() {
153                            state.advance(ch.len_utf8());
154                        }
155                        else {
156                            break;
157                        }
158                    }
159                }
160            }
161        }
162
163        // 处理指数部分
164        if (state.starts_with("e") || state.starts_with("E")) && has_digits {
165            state.advance(1);
166            is_float = true;
167
168            // 处理指数符号
169            if state.starts_with("+") || state.starts_with("-") {
170                state.advance(1);
171            }
172
173            // 处理指数数字
174            let mut exp_digits = false;
175            while let Some(ch) = state.peek() {
176                if ch.is_ascii_digit() {
177                    state.advance(ch.len_utf8());
178                    exp_digits = true;
179                }
180                else {
181                    break;
182                }
183            }
184            if !exp_digits {
185                return false;
186            }
187        }
188
189        if !has_digits {
190            return false;
191        }
192
193        let kind = if is_float { GraphQLSyntaxKind::FloatLiteral } else { GraphQLSyntaxKind::IntLiteral };
194        state.add_token(kind, start, state.get_position());
195        true
196    }
197
198    /// 词法分析标识符或关键字
199    fn lex_identifier_or_keyword<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> bool {
200        let start = state.get_position();
201
202        // 标识符必须以字母或下划线开始
203        if let Some(first_ch) = state.peek() {
204            if !first_ch.is_alphabetic() && first_ch != '_' {
205                return false;
206            }
207
208            state.advance(first_ch.len_utf8());
209
210            // 后续字符可以是字母、数字或下划线
211            while let Some(ch) = state.peek() {
212                if ch.is_alphanumeric() || ch == '_' {
213                    state.advance(ch.len_utf8());
214                }
215                else {
216                    break;
217                }
218            }
219
220            let end = state.get_position();
221            let text = state.get_text_in((start..end).into());
222            let kind = self.keyword_or_identifier(&text);
223            state.add_token(kind, start, end);
224            true
225        }
226        else {
227            false
228        }
229    }
230
231    /// 判断是关键字还是标识符
232    fn keyword_or_identifier(&self, text: &str) -> GraphQLSyntaxKind {
233        match text {
234            // 关键字
235            "query" => GraphQLSyntaxKind::QueryKeyword,
236            "mutation" => GraphQLSyntaxKind::MutationKeyword,
237            "subscription" => GraphQLSyntaxKind::SubscriptionKeyword,
238            "fragment" => GraphQLSyntaxKind::FragmentKeyword,
239            "on" => GraphQLSyntaxKind::OnKeyword,
240            "type" => GraphQLSyntaxKind::TypeKeyword,
241            "interface" => GraphQLSyntaxKind::InterfaceKeyword,
242            "union" => GraphQLSyntaxKind::UnionKeyword,
243            "scalar" => GraphQLSyntaxKind::ScalarKeyword,
244            "enum" => GraphQLSyntaxKind::EnumKeyword,
245            "input" => GraphQLSyntaxKind::InputKeyword,
246            "extend" => GraphQLSyntaxKind::ExtendKeyword,
247            "schema" => GraphQLSyntaxKind::SchemaKeyword,
248            "directive" => GraphQLSyntaxKind::DirectiveKeyword,
249            "implements" => GraphQLSyntaxKind::ImplementsKeyword,
250            "repeats" => GraphQLSyntaxKind::RepeatsKeyword,
251
252            // 特殊字面量
253            "true" | "false" => GraphQLSyntaxKind::BooleanLiteral,
254            "null" => GraphQLSyntaxKind::NullLiteral,
255
256            // 默认为名称
257            _ => GraphQLSyntaxKind::Name,
258        }
259    }
260
261    /// 词法分析操作符
262    fn lex_operators<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> bool {
263        let start = state.get_position();
264
265        // 三字符操作符
266        if state.starts_with("...") {
267            state.advance(3);
268            state.add_token(GraphQLSyntaxKind::Spread, start, state.get_position());
269            return true;
270        }
271
272        false
273    }
274
275    /// 词法分析单字符 token
276    fn lex_single_char_tokens<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> bool {
277        if let Some(ch) = state.peek() {
278            let start = state.get_position();
279            let kind = match ch {
280                '(' => Some(GraphQLSyntaxKind::LeftParen),
281                ')' => Some(GraphQLSyntaxKind::RightParen),
282                '[' => Some(GraphQLSyntaxKind::LeftBracket),
283                ']' => Some(GraphQLSyntaxKind::RightBracket),
284                '{' => Some(GraphQLSyntaxKind::LeftBrace),
285                '}' => Some(GraphQLSyntaxKind::RightBrace),
286                ',' => Some(GraphQLSyntaxKind::Comma),
287                ':' => Some(GraphQLSyntaxKind::Colon),
288                ';' => Some(GraphQLSyntaxKind::Semicolon),
289                '|' => Some(GraphQLSyntaxKind::Pipe),
290                '&' => Some(GraphQLSyntaxKind::Ampersand),
291                '=' => Some(GraphQLSyntaxKind::Equals),
292                '!' => Some(GraphQLSyntaxKind::Exclamation),
293                '@' => Some(GraphQLSyntaxKind::At),
294                '$' => Some(GraphQLSyntaxKind::Dollar),
295                _ => None,
296            };
297
298            if let Some(token_kind) = kind {
299                state.advance(ch.len_utf8());
300                let end = state.get_position();
301                state.add_token(token_kind, start, end);
302                true
303            }
304            else {
305                false
306            }
307        }
308        else {
309            false
310        }
311    }
312}