Skip to main content

oak_hlsl/lexer/
mod.rs

1use crate::{kind::HlslSyntaxKind, language::HlslLanguage};
2use oak_core::{Lexer, LexerCache, LexerState, OakError, lexer::LexOutput, source::Source};
3
/// Lexer state specialised to [`HlslLanguage`]; `S` is the source type being lexed.
type State<'a, S> = LexerState<'a, S, HlslLanguage>;
5
6pub struct HlslLexer<'config> {
7    _config: &'config HlslLanguage,
8}
9
10impl<'config> Clone for HlslLexer<'config> {
11    fn clone(&self) -> Self {
12        Self { _config: self._config }
13    }
14}
15
16impl<'config> HlslLexer<'config> {
17    pub fn new(config: &'config HlslLanguage) -> Self {
18        Self { _config: config }
19    }
20
21    /// 主要的词法分析循环
22    fn run<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> Result<(), OakError> {
23        while state.not_at_end() {
24            let safe_point = state.get_position();
25
26            // 尝试各种词法规则
27            if self.skip_whitespace(state) {
28                continue;
29            }
30
31            if self.lex_newline(state) {
32                continue;
33            }
34
35            if self.lex_comment(state) {
36                continue;
37            }
38
39            if self.lex_preprocessor(state) {
40                continue;
41            }
42
43            if self.lex_string(state) {
44                continue;
45            }
46
47            if self.lex_number(state) {
48                continue;
49            }
50
51            if self.lex_identifier_or_keyword(state) {
52                continue;
53            }
54
55            if self.lex_operator_or_delimiter(state) {
56                continue;
57            }
58
59            // 如果所有规则都不匹配,跳过当前字符并标记为错误
60            let start_pos = state.get_position();
61            if let Some(ch) = state.peek() {
62                state.advance(ch.len_utf8());
63                state.add_token(HlslSyntaxKind::Error, start_pos, state.get_position());
64            }
65
66            state.advance_if_dead_lock(safe_point);
67        }
68
69        Ok(())
70    }
71
72    /// 跳过空白字符
73    fn skip_whitespace<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> bool {
74        let start_pos = state.get_position();
75
76        while let Some(ch) = state.peek() {
77            if ch == ' ' || ch == '\t' {
78                state.advance(ch.len_utf8());
79            }
80            else {
81                break;
82            }
83        }
84
85        if state.get_position() > start_pos {
86            state.add_token(HlslSyntaxKind::Whitespace, start_pos, state.get_position());
87            true
88        }
89        else {
90            false
91        }
92    }
93
94    /// 处理换行
95    fn lex_newline<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> bool {
96        let start_pos = state.get_position();
97
98        if let Some('\n') = state.peek() {
99            state.advance(1);
100            state.add_token(HlslSyntaxKind::Newline, start_pos, state.get_position());
101            true
102        }
103        else if let Some('\r') = state.peek() {
104            state.advance(1);
105            if let Some('\n') = state.peek() {
106                state.advance(1);
107            }
108            state.add_token(HlslSyntaxKind::Newline, start_pos, state.get_position());
109            true
110        }
111        else {
112            false
113        }
114    }
115
116    /// 处理注释
117    fn lex_comment<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> bool {
118        let start_pos = state.get_position();
119
120        // 单行注释 //
121        if let Some('/') = state.peek() {
122            if let Some('/') = state.peek_next_n(1) {
123                state.advance(2);
124                while let Some(ch) = state.peek() {
125                    if ch == '\n' || ch == '\r' {
126                        break;
127                    }
128                    state.advance(ch.len_utf8());
129                }
130                state.add_token(HlslSyntaxKind::Comment, start_pos, state.get_position());
131                return true;
132            }
133        }
134
135        // 多行注释 /* ... */
136        if let Some('/') = state.peek() {
137            if let Some('*') = state.peek_next_n(1) {
138                state.advance(2);
139                while state.not_at_end() {
140                    if let Some('*') = state.peek() {
141                        if let Some('/') = state.peek_next_n(1) {
142                            state.advance(2);
143                            break;
144                        }
145                    }
146                    if let Some(ch) = state.peek() {
147                        state.advance(ch.len_utf8());
148                    }
149                }
150                state.add_token(HlslSyntaxKind::Comment, start_pos, state.get_position());
151                return true;
152            }
153        }
154
155        false
156    }
157
158    /// 处理预处理器指令
159    fn lex_preprocessor<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> bool {
160        let start_pos = state.get_position();
161
162        if let Some('#') = state.peek() {
163            state.advance(1);
164
165            // 跳过空白
166            while let Some(ch) = state.peek() {
167                if ch == ' ' || ch == '\t' {
168                    state.advance(1);
169                }
170                else {
171                    break;
172                }
173            }
174
175            // 读取指令名称
176            let directive_start = state.get_position();
177            while let Some(ch) = state.peek() {
178                if ch.is_ascii_alphabetic() || ch == '_' {
179                    state.advance(1);
180                }
181                else {
182                    break;
183                }
184            }
185
186            if state.get_position() > directive_start {
187                let directive = state.get_text_in((directive_start..state.get_position()).into()).to_string();
188
189                // 读取指令的其余部分直到行尾
190                while let Some(ch) = state.peek() {
191                    if ch == '\n' || ch == '\r' {
192                        break;
193                    }
194                    state.advance(ch.len_utf8());
195                }
196
197                let token_kind = match directive.as_str() {
198                    "include" => HlslSyntaxKind::Include,
199                    "define" => HlslSyntaxKind::Define,
200                    "undef" => HlslSyntaxKind::Undef,
201                    "if" => HlslSyntaxKind::If_,
202                    "ifdef" => HlslSyntaxKind::Ifdef,
203                    "ifndef" => HlslSyntaxKind::Ifndef,
204                    "else" => HlslSyntaxKind::Else_,
205                    "elif" => HlslSyntaxKind::Elif,
206                    "endif" => HlslSyntaxKind::Endif,
207                    "line" => HlslSyntaxKind::Line,
208                    "error" => HlslSyntaxKind::Error,
209                    "pragma" => HlslSyntaxKind::Pragma,
210                    _ => HlslSyntaxKind::Hash,
211                };
212
213                state.add_token(token_kind, start_pos, state.get_position());
214                return true;
215            }
216            else {
217                // 只是一个 # 符号
218                state.add_token(HlslSyntaxKind::Hash, start_pos, state.get_position());
219                return true;
220            }
221        }
222
223        false
224    }
225
226    /// 处理字符串字面量
227    fn lex_string<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> bool {
228        let start_pos = state.get_position();
229
230        if let Some('"') = state.peek() {
231            state.advance(1);
232            let mut escaped = false;
233
234            while let Some(ch) = state.peek() {
235                if escaped {
236                    escaped = false;
237                }
238                else if ch == '\\' {
239                    escaped = true;
240                }
241                else if ch == '"' {
242                    state.advance(1);
243                    break;
244                }
245                else if ch == '\n' || ch == '\r' {
246                    break; // 字符串不能跨行
247                }
248                state.advance(ch.len_utf8());
249            }
250
251            state.add_token(HlslSyntaxKind::StringLiteral, start_pos, state.get_position());
252            return true;
253        }
254
255        false
256    }
257
258    /// 处理数字字面量
259    fn lex_number<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> bool {
260        let start_pos = state.get_position();
261
262        if let Some(ch) = state.peek() {
263            if ch.is_ascii_digit() || (ch == '.' && state.peek_next_n(1).map_or(false, |c| c.is_ascii_digit())) {
264                // 处理十六进制数
265                if ch == '0' && state.peek_next_n(1) == Some('x') {
266                    state.advance(2);
267                    while let Some(ch) = state.peek() {
268                        if ch.is_ascii_hexdigit() {
269                            state.advance(1);
270                        }
271                        else {
272                            break;
273                        }
274                    }
275                }
276                else {
277                    // 整数部分
278                    while let Some(ch) = state.peek() {
279                        if ch.is_ascii_digit() {
280                            state.advance(1);
281                        }
282                        else {
283                            break;
284                        }
285                    }
286
287                    // 小数点和小数部分
288                    if let Some('.') = state.peek() {
289                        if state.peek_next_n(1).map_or(false, |c| c.is_ascii_digit()) {
290                            state.advance(1);
291                            while let Some(ch) = state.peek() {
292                                if ch.is_ascii_digit() {
293                                    state.advance(1);
294                                }
295                                else {
296                                    break;
297                                }
298                            }
299                        }
300                    }
301
302                    // 指数部分
303                    if let Some(e_char) = state.peek() {
304                        if e_char == 'e' || e_char == 'E' {
305                            let saved_pos = state.get_position();
306                            state.advance(1);
307
308                            // 可选的符号
309                            if let Some(sign) = state.peek() {
310                                if sign == '+' || sign == '-' {
311                                    state.advance(1);
312                                }
313                            }
314
315                            // 指数数字
316                            let exp_start = state.get_position();
317                            while let Some(ch) = state.peek() {
318                                if ch.is_ascii_digit() {
319                                    state.advance(1);
320                                }
321                                else {
322                                    break;
323                                }
324                            }
325
326                            if state.get_position() == exp_start {
327                                // 没有有效的指数,回退
328                                state.set_position(saved_pos);
329                            }
330                        }
331                    }
332                }
333
334                // 处理后缀 (f, h, l, u 等)
335                if let Some(suffix) = state.peek() {
336                    if suffix == 'f' || suffix == 'F' || suffix == 'h' || suffix == 'H' || suffix == 'l' || suffix == 'L' || suffix == 'u' || suffix == 'U' {
337                        state.advance(1);
338                    }
339                }
340
341                state.add_token(HlslSyntaxKind::NumberLiteral, start_pos, state.get_position());
342                return true;
343            }
344        }
345
346        false
347    }
348
349    /// 处理标识符和关键字
350    fn lex_identifier_or_keyword<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> bool {
351        let start_pos = state.get_position();
352
353        if let Some(ch) = state.peek() {
354            if ch.is_ascii_alphabetic() || ch == '_' {
355                while let Some(ch) = state.peek() {
356                    if ch.is_ascii_alphanumeric() || ch == '_' {
357                        state.advance(ch.len_utf8());
358                    }
359                    else {
360                        break;
361                    }
362                }
363
364                let text = state.get_text_in((start_pos..state.get_position()).into());
365                let token_kind = match text.as_ref() {
366                    // 基本数据类型
367                    "bool" => HlslSyntaxKind::Bool,
368                    "int" => HlslSyntaxKind::Int,
369                    "uint" => HlslSyntaxKind::Uint,
370                    "half" => HlslSyntaxKind::Half,
371                    "float" => HlslSyntaxKind::Float,
372                    "double" => HlslSyntaxKind::Double,
373                    "min16float" => HlslSyntaxKind::Min16float,
374                    "min10float" => HlslSyntaxKind::Min10float,
375                    "min16int" => HlslSyntaxKind::Min16int,
376                    "min12int" => HlslSyntaxKind::Min12int,
377                    "min16uint" => HlslSyntaxKind::Min16uint,
378
379                    // 向量类型
380                    "bool2" => HlslSyntaxKind::Bool2,
381                    "bool3" => HlslSyntaxKind::Bool3,
382                    "bool4" => HlslSyntaxKind::Bool4,
383                    "int2" => HlslSyntaxKind::Int2,
384                    "int3" => HlslSyntaxKind::Int3,
385                    "int4" => HlslSyntaxKind::Int4,
386                    "uint2" => HlslSyntaxKind::Uint2,
387                    "uint3" => HlslSyntaxKind::Uint3,
388                    "uint4" => HlslSyntaxKind::Uint4,
389                    "half2" => HlslSyntaxKind::Half2,
390                    "half3" => HlslSyntaxKind::Half3,
391                    "half4" => HlslSyntaxKind::Half4,
392                    "float2" => HlslSyntaxKind::Float2,
393                    "float3" => HlslSyntaxKind::Float3,
394                    "float4" => HlslSyntaxKind::Float4,
395                    "double2" => HlslSyntaxKind::Double2,
396                    "double3" => HlslSyntaxKind::Double3,
397                    "double4" => HlslSyntaxKind::Double4,
398
399                    // 矩阵类型
400                    "float2x2" => HlslSyntaxKind::Float2x2,
401                    "float2x3" => HlslSyntaxKind::Float2x3,
402                    "float2x4" => HlslSyntaxKind::Float2x4,
403                    "float3x2" => HlslSyntaxKind::Float3x2,
404                    "float3x3" => HlslSyntaxKind::Float3x3,
405                    "float3x4" => HlslSyntaxKind::Float3x4,
406                    "float4x2" => HlslSyntaxKind::Float4x2,
407                    "float4x3" => HlslSyntaxKind::Float4x3,
408                    "float4x4" => HlslSyntaxKind::Float4x4,
409                    "double2x2" => HlslSyntaxKind::Double2x2,
410                    "double2x3" => HlslSyntaxKind::Double2x3,
411                    "double2x4" => HlslSyntaxKind::Double2x4,
412                    "double3x2" => HlslSyntaxKind::Double3x2,
413                    "double3x3" => HlslSyntaxKind::Double3x3,
414                    "double3x4" => HlslSyntaxKind::Double3x4,
415                    "double4x2" => HlslSyntaxKind::Double4x2,
416                    "double4x3" => HlslSyntaxKind::Double4x3,
417                    "double4x4" => HlslSyntaxKind::Double4x4,
418
419                    // 纹理类型
420                    "Texture1D" => HlslSyntaxKind::Texture1D,
421                    "Texture1DArray" => HlslSyntaxKind::Texture1DArray,
422                    "Texture2D" => HlslSyntaxKind::Texture2D,
423                    "Texture2DArray" => HlslSyntaxKind::Texture2DArray,
424                    "Texture2DMS" => HlslSyntaxKind::Texture2DMS,
425                    "Texture2DMSArray" => HlslSyntaxKind::Texture2DMSArray,
426                    "Texture3D" => HlslSyntaxKind::Texture3D,
427                    "TextureCube" => HlslSyntaxKind::TextureCube,
428                    "TextureCubeArray" => HlslSyntaxKind::TextureCubeArray,
429
430                    // 采样器类型
431                    "sampler" => HlslSyntaxKind::Sampler,
432                    "SamplerState" => HlslSyntaxKind::SamplerState,
433                    "SamplerComparisonState" => HlslSyntaxKind::SamplerComparisonState,
434
435                    // 缓冲区类型
436                    "Buffer" => HlslSyntaxKind::Buffer,
437                    "StructuredBuffer" => HlslSyntaxKind::StructuredBuffer,
438                    "ByteAddressBuffer" => HlslSyntaxKind::ByteAddressBuffer,
439                    "RWBuffer" => HlslSyntaxKind::RWBuffer,
440                    "RWStructuredBuffer" => HlslSyntaxKind::RWStructuredBuffer,
441                    "RWByteAddressBuffer" => HlslSyntaxKind::RWByteAddressBuffer,
442                    "AppendStructuredBuffer" => HlslSyntaxKind::AppendStructuredBuffer,
443                    "ConsumeStructuredBuffer" => HlslSyntaxKind::ConsumeStructuredBuffer,
444
445                    // 控制流关键字
446                    "if" => HlslSyntaxKind::If,
447                    "else" => HlslSyntaxKind::Else,
448                    "for" => HlslSyntaxKind::For,
449                    "while" => HlslSyntaxKind::While,
450                    "do" => HlslSyntaxKind::Do,
451                    "switch" => HlslSyntaxKind::Switch,
452                    "case" => HlslSyntaxKind::Case,
453                    "default" => HlslSyntaxKind::Default,
454                    "break" => HlslSyntaxKind::Break,
455                    "continue" => HlslSyntaxKind::Continue,
456                    "return" => HlslSyntaxKind::Return,
457                    "discard" => HlslSyntaxKind::Discard,
458
459                    // 函数和变量修饰符
460                    "static" => HlslSyntaxKind::Static,
461                    "const" => HlslSyntaxKind::Const,
462                    "volatile" => HlslSyntaxKind::Volatile,
463                    "extern" => HlslSyntaxKind::Extern,
464                    "shared" => HlslSyntaxKind::Shared,
465                    "groupshared" => HlslSyntaxKind::Groupshared,
466                    "uniform" => HlslSyntaxKind::Uniform,
467                    "in" => HlslSyntaxKind::In,
468                    "out" => HlslSyntaxKind::Out,
469                    "inout" => HlslSyntaxKind::Inout,
470                    "inline" => HlslSyntaxKind::Inline,
471                    "target" => HlslSyntaxKind::Target,
472
473                    // 语义修饰符
474                    "register" => HlslSyntaxKind::Register,
475                    "packoffset" => HlslSyntaxKind::Packoffset,
476
477                    // 着色器类型
478                    "struct" => HlslSyntaxKind::Struct,
479                    "cbuffer" => HlslSyntaxKind::Cbuffer,
480                    "tbuffer" => HlslSyntaxKind::Tbuffer,
481                    "interface" => HlslSyntaxKind::Interface,
482                    "class" => HlslSyntaxKind::Class,
483
484                    // 布尔字面量
485                    "true" | "false" => HlslSyntaxKind::BooleanLiteral,
486
487                    _ => HlslSyntaxKind::Identifier,
488                };
489
490                state.add_token(token_kind, start_pos, state.get_position());
491                return true;
492            }
493        }
494
495        false
496    }
497
498    /// 处理运算符和分隔符
499    fn lex_operator_or_delimiter<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> bool {
500        let start_pos = state.get_position();
501
502        if let Some(ch) = state.peek() {
503            let token_kind = match ch {
504                '+' => {
505                    state.advance(1);
506                    if let Some('=') = state.peek() {
507                        state.advance(1);
508                        HlslSyntaxKind::PlusAssign
509                    }
510                    else if let Some('+') = state.peek() {
511                        state.advance(1);
512                        HlslSyntaxKind::Increment
513                    }
514                    else {
515                        HlslSyntaxKind::Plus
516                    }
517                }
518                '-' => {
519                    state.advance(1);
520                    if let Some('=') = state.peek() {
521                        state.advance(1);
522                        HlslSyntaxKind::MinusAssign
523                    }
524                    else if let Some('-') = state.peek() {
525                        state.advance(1);
526                        HlslSyntaxKind::Decrement
527                    }
528                    else if let Some('>') = state.peek() {
529                        state.advance(1);
530                        HlslSyntaxKind::Arrow
531                    }
532                    else {
533                        HlslSyntaxKind::Minus
534                    }
535                }
536                '*' => {
537                    state.advance(1);
538                    if let Some('=') = state.peek() {
539                        state.advance(1);
540                        HlslSyntaxKind::MultiplyAssign
541                    }
542                    else {
543                        HlslSyntaxKind::Multiply
544                    }
545                }
546                '/' => {
547                    state.advance(1);
548                    if let Some('=') = state.peek() {
549                        state.advance(1);
550                        HlslSyntaxKind::DivideAssign
551                    }
552                    else {
553                        HlslSyntaxKind::Divide
554                    }
555                }
556                '%' => {
557                    state.advance(1);
558                    if let Some('=') = state.peek() {
559                        state.advance(1);
560                        HlslSyntaxKind::ModuloAssign
561                    }
562                    else {
563                        HlslSyntaxKind::Modulo
564                    }
565                }
566                '=' => {
567                    state.advance(1);
568                    if let Some('=') = state.peek() {
569                        state.advance(1);
570                        HlslSyntaxKind::Equal
571                    }
572                    else {
573                        HlslSyntaxKind::Assign
574                    }
575                }
576                '!' => {
577                    state.advance(1);
578                    if let Some('=') = state.peek() {
579                        state.advance(1);
580                        HlslSyntaxKind::NotEqual
581                    }
582                    else {
583                        HlslSyntaxKind::LogicalNot
584                    }
585                }
586                '<' => {
587                    state.advance(1);
588                    if let Some('=') = state.peek() {
589                        state.advance(1);
590                        HlslSyntaxKind::LessEqual
591                    }
592                    else if let Some('<') = state.peek() {
593                        state.advance(1);
594                        if let Some('=') = state.peek() {
595                            state.advance(1);
596                            HlslSyntaxKind::LeftShiftAssign
597                        }
598                        else {
599                            HlslSyntaxKind::LeftShift
600                        }
601                    }
602                    else {
603                        HlslSyntaxKind::Less
604                    }
605                }
606                '>' => {
607                    state.advance(1);
608                    if let Some('=') = state.peek() {
609                        state.advance(1);
610                        HlslSyntaxKind::GreaterEqual
611                    }
612                    else if let Some('>') = state.peek() {
613                        state.advance(1);
614                        if let Some('=') = state.peek() {
615                            state.advance(1);
616                            HlslSyntaxKind::RightShiftAssign
617                        }
618                        else {
619                            HlslSyntaxKind::RightShift
620                        }
621                    }
622                    else {
623                        HlslSyntaxKind::Greater
624                    }
625                }
626                '&' => {
627                    state.advance(1);
628                    if let Some('&') = state.peek() {
629                        state.advance(1);
630                        HlslSyntaxKind::LogicalAnd
631                    }
632                    else if let Some('=') = state.peek() {
633                        state.advance(1);
634                        HlslSyntaxKind::BitwiseAndAssign
635                    }
636                    else {
637                        HlslSyntaxKind::BitwiseAnd
638                    }
639                }
640                '|' => {
641                    state.advance(1);
642                    if let Some('|') = state.peek() {
643                        state.advance(1);
644                        HlslSyntaxKind::LogicalOr
645                    }
646                    else if let Some('=') = state.peek() {
647                        state.advance(1);
648                        HlslSyntaxKind::BitwiseOrAssign
649                    }
650                    else {
651                        HlslSyntaxKind::BitwiseOr
652                    }
653                }
654                '^' => {
655                    state.advance(1);
656                    if let Some('=') = state.peek() {
657                        state.advance(1);
658                        HlslSyntaxKind::BitwiseXorAssign
659                    }
660                    else {
661                        HlslSyntaxKind::BitwiseXor
662                    }
663                }
664                '~' => {
665                    state.advance(1);
666                    HlslSyntaxKind::BitwiseNot
667                }
668                '?' => {
669                    state.advance(1);
670                    HlslSyntaxKind::Conditional
671                }
672                '.' => {
673                    state.advance(1);
674                    HlslSyntaxKind::Dot
675                }
676                ':' => {
677                    state.advance(1);
678                    if let Some(':') = state.peek() {
679                        state.advance(1);
680                        HlslSyntaxKind::DoubleColon
681                    }
682                    else {
683                        HlslSyntaxKind::Colon
684                    }
685                }
686                ';' => {
687                    state.advance(1);
688                    HlslSyntaxKind::Semicolon
689                }
690                ',' => {
691                    state.advance(1);
692                    HlslSyntaxKind::Comma
693                }
694                '(' => {
695                    state.advance(1);
696                    HlslSyntaxKind::LeftParen
697                }
698                ')' => {
699                    state.advance(1);
700                    HlslSyntaxKind::RightParen
701                }
702                '[' => {
703                    state.advance(1);
704                    HlslSyntaxKind::LeftBracket
705                }
706                ']' => {
707                    state.advance(1);
708                    HlslSyntaxKind::RightBracket
709                }
710                '{' => {
711                    state.advance(1);
712                    HlslSyntaxKind::LeftBrace
713                }
714                '}' => {
715                    state.advance(1);
716                    HlslSyntaxKind::RightBrace
717                }
718                '\\' => {
719                    state.advance(1);
720                    HlslSyntaxKind::Backslash
721                }
722                _ => return false,
723            };
724
725            state.add_token(token_kind, start_pos, state.get_position());
726            true
727        }
728        else {
729            false
730        }
731    }
732}
733
734impl<'config> Lexer<HlslLanguage> for HlslLexer<'config> {
735    fn lex<'a, S: Source + ?Sized>(&self, source: &S, _edits: &[oak_core::TextEdit], cache: &'a mut impl LexerCache<HlslLanguage>) -> LexOutput<HlslLanguage> {
736        let mut state = State::new(source);
737        let result = self.run(&mut state);
738        if result.is_ok() {
739            state.add_eof();
740        }
741        state.finish_with_cache(result, cache)
742    }
743}