Skip to main content

oak_graphql/lexer/
mod.rs

1#![doc = include_str!("readme.md")]
2pub mod token_type;
3
4use crate::{language::GraphQLLanguage, lexer::token_type::GraphQLTokenType};
5use oak_core::{
6    Lexer, LexerCache, LexerState, OakError, TextEdit,
7    lexer::{CommentConfig, LexOutput, StringConfig, WhitespaceConfig},
8    source::Source,
9};
10use std::sync::LazyLock;
11
/// Lexer state specialised to the GraphQL language definition.
type State<'a, S> = LexerState<'a, S, GraphQLLanguage>;

/// Whitespace scanner configuration, with Unicode whitespace enabled.
static GRAPHQL_WHITESPACE: LazyLock<WhitespaceConfig> = LazyLock::new(|| WhitespaceConfig { unicode_whitespace: true });
/// Comment scanner configuration: `#` is the line-comment marker; both block markers are empty strings and nesting is off.
static GRAPHQL_COMMENT: LazyLock<CommentConfig> = LazyLock::new(|| CommentConfig { line_marker: "#", block_start: "", block_end: "", nested_blocks: false });
/// String scanner configuration: `"` is the only quote character and `\` is the escape character.
static GRAPHQL_STRING: LazyLock<StringConfig> = LazyLock::new(|| StringConfig { quotes: &['"'], escape: Some('\\') });
17
/// Lexer that turns GraphQL source text into a token stream.
///
/// Borrows a [`GraphQLLanguage`] configuration for the lifetime `'config`.
#[derive(Clone, Debug)]
pub struct GraphQLLexer<'config> {
    /// Borrowed language configuration; stored by `new` and not read by any method visible in this file.
    _config: &'config GraphQLLanguage,
}
22
23impl<'config> Lexer<GraphQLLanguage> for GraphQLLexer<'config> {
24    fn lex<'a, S: Source + ?Sized>(&self, text: &S, _edits: &[TextEdit], cache: &'a mut impl LexerCache<GraphQLLanguage>) -> LexOutput<GraphQLLanguage> {
25        let mut state = LexerState::new(text);
26        let result = self.run(&mut state);
27        if result.is_ok() {
28            state.add_eof();
29        }
30        state.finish_with_cache(result, cache)
31    }
32}
33
34impl<'config> GraphQLLexer<'config> {
    /// Creates a lexer that borrows the given language configuration.
    pub fn new(config: &'config GraphQLLanguage) -> Self {
        Self { _config: config }
    }
38
39    fn run<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> Result<(), OakError> {
40        while state.not_at_end() {
41            let safe_point = state.get_position();
42
43            if self.skip_whitespace(state) {
44                continue;
45            }
46
47            if self.skip_comment(state) {
48                continue;
49            }
50
51            if self.lex_string_literal(state) {
52                continue;
53            }
54
55            if self.lex_number_literal(state) {
56                continue;
57            }
58
59            if self.lex_identifier_or_keyword(state) {
60                continue;
61            }
62
63            if self.lex_operators(state) {
64                continue;
65            }
66
67            if self.lex_single_char_tokens(state) {
68                continue;
69            }
70
71            state.advance_if_dead_lock(safe_point);
72        }
73
74        Ok(())
75    }
76
    /// Scans whitespace at the current position.
    ///
    /// Delegates to the `GRAPHQL_WHITESPACE` configuration, passing
    /// `GraphQLTokenType::Whitespace` as the token kind, and returns the
    /// scanner's boolean result (`run` treats `true` as "input was handled").
    fn skip_whitespace<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> bool {
        GRAPHQL_WHITESPACE.scan(state, GraphQLTokenType::Whitespace)
    }
81
    /// Scans a comment at the current position.
    ///
    /// Delegates to the `GRAPHQL_COMMENT` configuration and returns the
    /// scanner's boolean result. `GraphQLTokenType::Comment` is passed for both
    /// token-kind arguments (presumably the line-comment and block-comment
    /// kinds — confirm against `CommentConfig::scan`).
    fn skip_comment<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> bool {
        GRAPHQL_COMMENT.scan(state, GraphQLTokenType::Comment, GraphQLTokenType::Comment)
    }
86
87    /// 词法分析字符串字面量
88    fn lex_string_literal<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> bool {
89        // 普通字符串 "..."
90        if GRAPHQL_STRING.scan(state, GraphQLTokenType::StringLiteral) {
91            return true;
92        }
93
94        // 多行字符串 """..."""
95        if state.starts_with("\"\"\"") {
96            let start = state.get_position();
97            state.advance(3); // 跳过开始的 """
98
99            while state.not_at_end() {
100                if state.starts_with("\"\"\"") {
101                    state.advance(3); // 跳过结束的 """
102                    break;
103                }
104                if let Some(ch) = state.peek() {
105                    state.advance(ch.len_utf8());
106                }
107            }
108
109            let end = state.get_position();
110            state.add_token(GraphQLTokenType::StringLiteral, start, end);
111            return true;
112        }
113
114        false
115    }
116
117    /// 词法分析数字字面量
118    fn lex_number_literal<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> bool {
119        let start = state.get_position();
120        let mut has_digits = false;
121        let mut is_float = false;
122
123        // 处理负号
124        if state.starts_with("-") {
125            state.advance(1);
126        }
127
128        // 处理整数部分
129        if state.starts_with("0") {
130            // 单独的 0
131            state.advance(1);
132            has_digits = true;
133        }
134        else {
135            // 非零开头的数字
136            while let Some(ch) = state.peek() {
137                if ch.is_ascii_digit() {
138                    state.advance(ch.len_utf8());
139                    has_digits = true;
140                }
141                else {
142                    break;
143                }
144            }
145        }
146
147        // 处理小数部分
148        if state.starts_with(".") && has_digits {
149            if let Some(next_ch) = state.peek_next_n(1) {
150                if next_ch.is_ascii_digit() {
151                    state.advance(1); // 跳过 .
152                    is_float = true;
153
154                    while let Some(ch) = state.peek() {
155                        if ch.is_ascii_digit() {
156                            state.advance(ch.len_utf8());
157                        }
158                        else {
159                            break;
160                        }
161                    }
162                }
163            }
164        }
165
166        // 处理指数部分
167        if (state.starts_with("e") || state.starts_with("E")) && has_digits {
168            state.advance(1);
169            is_float = true;
170
171            // 处理指数符号
172            if state.starts_with("+") || state.starts_with("-") {
173                state.advance(1);
174            }
175
176            // 处理指数数字
177            let mut exp_digits = false;
178            while let Some(ch) = state.peek() {
179                if ch.is_ascii_digit() {
180                    state.advance(ch.len_utf8());
181                    exp_digits = true;
182                }
183                else {
184                    break;
185                }
186            }
187            if !exp_digits {
188                return false;
189            }
190        }
191
192        if !has_digits {
193            return false;
194        }
195
196        let kind = if is_float { GraphQLTokenType::FloatLiteral } else { GraphQLTokenType::IntLiteral };
197        state.add_token(kind, start, state.get_position());
198        true
199    }
200
201    /// 词法分析标识符或关键字
202    fn lex_identifier_or_keyword<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> bool {
203        let start = state.get_position();
204
205        // 标识符必须以字母或下划线开始
206        if let Some(first_ch) = state.peek() {
207            if !first_ch.is_alphabetic() && first_ch != '_' {
208                return false;
209            }
210
211            state.advance(first_ch.len_utf8());
212
213            // 后续字符可以是字母、数字或下划线
214            while let Some(ch) = state.peek() {
215                if ch.is_alphanumeric() || ch == '_' {
216                    state.advance(ch.len_utf8());
217                }
218                else {
219                    break;
220                }
221            }
222
223            let end = state.get_position();
224            let text = state.get_text_in((start..end).into());
225            let kind = self.keyword_or_identifier(&text);
226            state.add_token(kind, start, end);
227            true
228        }
229        else {
230            false
231        }
232    }
233
234    /// 判断是关键字还是标识符
235    fn keyword_or_identifier(&self, text: &str) -> GraphQLTokenType {
236        match text {
237            // 关键字
238            "query" => GraphQLTokenType::QueryKeyword,
239            "mutation" => GraphQLTokenType::MutationKeyword,
240            "subscription" => GraphQLTokenType::SubscriptionKeyword,
241            "fragment" => GraphQLTokenType::FragmentKeyword,
242            "on" => GraphQLTokenType::OnKeyword,
243            "type" => GraphQLTokenType::TypeKeyword,
244            "interface" => GraphQLTokenType::InterfaceKeyword,
245            "union" => GraphQLTokenType::UnionKeyword,
246            "scalar" => GraphQLTokenType::ScalarKeyword,
247            "enum" => GraphQLTokenType::EnumKeyword,
248            "input" => GraphQLTokenType::InputKeyword,
249            "extend" => GraphQLTokenType::ExtendKeyword,
250            "schema" => GraphQLTokenType::SchemaKeyword,
251            "directive" => GraphQLTokenType::DirectiveKeyword,
252            "implements" => GraphQLTokenType::ImplementsKeyword,
253            "repeats" => GraphQLTokenType::RepeatsKeyword,
254
255            // 特殊字面量
256            "true" | "false" => GraphQLTokenType::BooleanLiteral,
257            "null" => GraphQLTokenType::NullLiteral,
258
259            // 默认为名称
260            _ => GraphQLTokenType::Name,
261        }
262    }
263
264    /// 词法分析操作符
265    fn lex_operators<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> bool {
266        let start = state.get_position();
267
268        // 三字符操作符
269        if state.starts_with("...") {
270            state.advance(3);
271            state.add_token(GraphQLTokenType::Spread, start, state.get_position());
272            return true;
273        }
274
275        false
276    }
277
278    /// 词法分析单字符 token
279    fn lex_single_char_tokens<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> bool {
280        if let Some(ch) = state.peek() {
281            let start = state.get_position();
282            let kind = match ch {
283                '(' => Some(GraphQLTokenType::LeftParen),
284                ')' => Some(GraphQLTokenType::RightParen),
285                '[' => Some(GraphQLTokenType::LeftBracket),
286                ']' => Some(GraphQLTokenType::RightBracket),
287                '{' => Some(GraphQLTokenType::LeftBrace),
288                '}' => Some(GraphQLTokenType::RightBrace),
289                ',' => Some(GraphQLTokenType::Comma),
290                ':' => Some(GraphQLTokenType::Colon),
291                ';' => Some(GraphQLTokenType::Semicolon),
292                '|' => Some(GraphQLTokenType::Pipe),
293                '&' => Some(GraphQLTokenType::Ampersand),
294                '=' => Some(GraphQLTokenType::Equals),
295                '!' => Some(GraphQLTokenType::Exclamation),
296                '@' => Some(GraphQLTokenType::At),
297                '$' => Some(GraphQLTokenType::Dollar),
298                _ => None,
299            };
300
301            if let Some(token_kind) = kind {
302                state.advance(ch.len_utf8());
303                let end = state.get_position();
304                state.add_token(token_kind, start, end);
305                true
306            }
307            else {
308                false
309            }
310        }
311        else {
312            false
313        }
314    }
315}