1use crate::{kind::HlslSyntaxKind, language::HlslLanguage};
2use oak_core::{Lexer, LexerCache, LexerState, OakError, lexer::LexOutput, source::Source};
3
4type State<'a, S> = LexerState<'a, S, HlslLanguage>;
5
6pub struct HlslLexer<'config> {
7 _config: &'config HlslLanguage,
8}
9
10impl<'config> Clone for HlslLexer<'config> {
11 fn clone(&self) -> Self {
12 Self { _config: self._config }
13 }
14}
15
16impl<'config> HlslLexer<'config> {
17 pub fn new(config: &'config HlslLanguage) -> Self {
18 Self { _config: config }
19 }
20
21 fn run<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> Result<(), OakError> {
23 while state.not_at_end() {
24 let safe_point = state.get_position();
25
26 if self.skip_whitespace(state) {
28 continue;
29 }
30
31 if self.lex_newline(state) {
32 continue;
33 }
34
35 if self.lex_comment(state) {
36 continue;
37 }
38
39 if self.lex_preprocessor(state) {
40 continue;
41 }
42
43 if self.lex_string(state) {
44 continue;
45 }
46
47 if self.lex_number(state) {
48 continue;
49 }
50
51 if self.lex_identifier_or_keyword(state) {
52 continue;
53 }
54
55 if self.lex_operator_or_delimiter(state) {
56 continue;
57 }
58
59 let start_pos = state.get_position();
61 if let Some(ch) = state.peek() {
62 state.advance(ch.len_utf8());
63 state.add_token(HlslSyntaxKind::Error, start_pos, state.get_position());
64 }
65
66 state.advance_if_dead_lock(safe_point);
67 }
68
69 Ok(())
70 }
71
72 fn skip_whitespace<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> bool {
74 let start_pos = state.get_position();
75
76 while let Some(ch) = state.peek() {
77 if ch == ' ' || ch == '\t' {
78 state.advance(ch.len_utf8());
79 }
80 else {
81 break;
82 }
83 }
84
85 if state.get_position() > start_pos {
86 state.add_token(HlslSyntaxKind::Whitespace, start_pos, state.get_position());
87 true
88 }
89 else {
90 false
91 }
92 }
93
94 fn lex_newline<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> bool {
96 let start_pos = state.get_position();
97
98 if let Some('\n') = state.peek() {
99 state.advance(1);
100 state.add_token(HlslSyntaxKind::Newline, start_pos, state.get_position());
101 true
102 }
103 else if let Some('\r') = state.peek() {
104 state.advance(1);
105 if let Some('\n') = state.peek() {
106 state.advance(1);
107 }
108 state.add_token(HlslSyntaxKind::Newline, start_pos, state.get_position());
109 true
110 }
111 else {
112 false
113 }
114 }
115
116 fn lex_comment<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> bool {
118 let start_pos = state.get_position();
119
120 if let Some('/') = state.peek() {
122 if let Some('/') = state.peek_next_n(1) {
123 state.advance(2);
124 while let Some(ch) = state.peek() {
125 if ch == '\n' || ch == '\r' {
126 break;
127 }
128 state.advance(ch.len_utf8());
129 }
130 state.add_token(HlslSyntaxKind::Comment, start_pos, state.get_position());
131 return true;
132 }
133 }
134
135 if let Some('/') = state.peek() {
137 if let Some('*') = state.peek_next_n(1) {
138 state.advance(2);
139 while state.not_at_end() {
140 if let Some('*') = state.peek() {
141 if let Some('/') = state.peek_next_n(1) {
142 state.advance(2);
143 break;
144 }
145 }
146 if let Some(ch) = state.peek() {
147 state.advance(ch.len_utf8());
148 }
149 }
150 state.add_token(HlslSyntaxKind::Comment, start_pos, state.get_position());
151 return true;
152 }
153 }
154
155 false
156 }
157
158 fn lex_preprocessor<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> bool {
160 let start_pos = state.get_position();
161
162 if let Some('#') = state.peek() {
163 state.advance(1);
164
165 while let Some(ch) = state.peek() {
167 if ch == ' ' || ch == '\t' {
168 state.advance(1);
169 }
170 else {
171 break;
172 }
173 }
174
175 let directive_start = state.get_position();
177 while let Some(ch) = state.peek() {
178 if ch.is_ascii_alphabetic() || ch == '_' {
179 state.advance(1);
180 }
181 else {
182 break;
183 }
184 }
185
186 if state.get_position() > directive_start {
187 let directive = state.get_text_in((directive_start..state.get_position()).into()).to_string();
188
189 while let Some(ch) = state.peek() {
191 if ch == '\n' || ch == '\r' {
192 break;
193 }
194 state.advance(ch.len_utf8());
195 }
196
197 let token_kind = match directive.as_str() {
198 "include" => HlslSyntaxKind::Include,
199 "define" => HlslSyntaxKind::Define,
200 "undef" => HlslSyntaxKind::Undef,
201 "if" => HlslSyntaxKind::If_,
202 "ifdef" => HlslSyntaxKind::Ifdef,
203 "ifndef" => HlslSyntaxKind::Ifndef,
204 "else" => HlslSyntaxKind::Else_,
205 "elif" => HlslSyntaxKind::Elif,
206 "endif" => HlslSyntaxKind::Endif,
207 "line" => HlslSyntaxKind::Line,
208 "error" => HlslSyntaxKind::Error,
209 "pragma" => HlslSyntaxKind::Pragma,
210 _ => HlslSyntaxKind::Hash,
211 };
212
213 state.add_token(token_kind, start_pos, state.get_position());
214 return true;
215 }
216 else {
217 state.add_token(HlslSyntaxKind::Hash, start_pos, state.get_position());
219 return true;
220 }
221 }
222
223 false
224 }
225
226 fn lex_string<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> bool {
228 let start_pos = state.get_position();
229
230 if let Some('"') = state.peek() {
231 state.advance(1);
232 let mut escaped = false;
233
234 while let Some(ch) = state.peek() {
235 if escaped {
236 escaped = false;
237 }
238 else if ch == '\\' {
239 escaped = true;
240 }
241 else if ch == '"' {
242 state.advance(1);
243 break;
244 }
245 else if ch == '\n' || ch == '\r' {
246 break; }
248 state.advance(ch.len_utf8());
249 }
250
251 state.add_token(HlslSyntaxKind::StringLiteral, start_pos, state.get_position());
252 return true;
253 }
254
255 false
256 }
257
258 fn lex_number<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> bool {
260 let start_pos = state.get_position();
261
262 if let Some(ch) = state.peek() {
263 if ch.is_ascii_digit() || (ch == '.' && state.peek_next_n(1).map_or(false, |c| c.is_ascii_digit())) {
264 if ch == '0' && state.peek_next_n(1) == Some('x') {
266 state.advance(2);
267 while let Some(ch) = state.peek() {
268 if ch.is_ascii_hexdigit() {
269 state.advance(1);
270 }
271 else {
272 break;
273 }
274 }
275 }
276 else {
277 while let Some(ch) = state.peek() {
279 if ch.is_ascii_digit() {
280 state.advance(1);
281 }
282 else {
283 break;
284 }
285 }
286
287 if let Some('.') = state.peek() {
289 if state.peek_next_n(1).map_or(false, |c| c.is_ascii_digit()) {
290 state.advance(1);
291 while let Some(ch) = state.peek() {
292 if ch.is_ascii_digit() {
293 state.advance(1);
294 }
295 else {
296 break;
297 }
298 }
299 }
300 }
301
302 if let Some(e_char) = state.peek() {
304 if e_char == 'e' || e_char == 'E' {
305 let saved_pos = state.get_position();
306 state.advance(1);
307
308 if let Some(sign) = state.peek() {
310 if sign == '+' || sign == '-' {
311 state.advance(1);
312 }
313 }
314
315 let exp_start = state.get_position();
317 while let Some(ch) = state.peek() {
318 if ch.is_ascii_digit() {
319 state.advance(1);
320 }
321 else {
322 break;
323 }
324 }
325
326 if state.get_position() == exp_start {
327 state.set_position(saved_pos);
329 }
330 }
331 }
332 }
333
334 if let Some(suffix) = state.peek() {
336 if suffix == 'f' || suffix == 'F' || suffix == 'h' || suffix == 'H' || suffix == 'l' || suffix == 'L' || suffix == 'u' || suffix == 'U' {
337 state.advance(1);
338 }
339 }
340
341 state.add_token(HlslSyntaxKind::NumberLiteral, start_pos, state.get_position());
342 return true;
343 }
344 }
345
346 false
347 }
348
349 fn lex_identifier_or_keyword<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> bool {
351 let start_pos = state.get_position();
352
353 if let Some(ch) = state.peek() {
354 if ch.is_ascii_alphabetic() || ch == '_' {
355 while let Some(ch) = state.peek() {
356 if ch.is_ascii_alphanumeric() || ch == '_' {
357 state.advance(ch.len_utf8());
358 }
359 else {
360 break;
361 }
362 }
363
364 let text = state.get_text_in((start_pos..state.get_position()).into());
365 let token_kind = match text.as_ref() {
366 "bool" => HlslSyntaxKind::Bool,
368 "int" => HlslSyntaxKind::Int,
369 "uint" => HlslSyntaxKind::Uint,
370 "half" => HlslSyntaxKind::Half,
371 "float" => HlslSyntaxKind::Float,
372 "double" => HlslSyntaxKind::Double,
373 "min16float" => HlslSyntaxKind::Min16float,
374 "min10float" => HlslSyntaxKind::Min10float,
375 "min16int" => HlslSyntaxKind::Min16int,
376 "min12int" => HlslSyntaxKind::Min12int,
377 "min16uint" => HlslSyntaxKind::Min16uint,
378
379 "bool2" => HlslSyntaxKind::Bool2,
381 "bool3" => HlslSyntaxKind::Bool3,
382 "bool4" => HlslSyntaxKind::Bool4,
383 "int2" => HlslSyntaxKind::Int2,
384 "int3" => HlslSyntaxKind::Int3,
385 "int4" => HlslSyntaxKind::Int4,
386 "uint2" => HlslSyntaxKind::Uint2,
387 "uint3" => HlslSyntaxKind::Uint3,
388 "uint4" => HlslSyntaxKind::Uint4,
389 "half2" => HlslSyntaxKind::Half2,
390 "half3" => HlslSyntaxKind::Half3,
391 "half4" => HlslSyntaxKind::Half4,
392 "float2" => HlslSyntaxKind::Float2,
393 "float3" => HlslSyntaxKind::Float3,
394 "float4" => HlslSyntaxKind::Float4,
395 "double2" => HlslSyntaxKind::Double2,
396 "double3" => HlslSyntaxKind::Double3,
397 "double4" => HlslSyntaxKind::Double4,
398
399 "float2x2" => HlslSyntaxKind::Float2x2,
401 "float2x3" => HlslSyntaxKind::Float2x3,
402 "float2x4" => HlslSyntaxKind::Float2x4,
403 "float3x2" => HlslSyntaxKind::Float3x2,
404 "float3x3" => HlslSyntaxKind::Float3x3,
405 "float3x4" => HlslSyntaxKind::Float3x4,
406 "float4x2" => HlslSyntaxKind::Float4x2,
407 "float4x3" => HlslSyntaxKind::Float4x3,
408 "float4x4" => HlslSyntaxKind::Float4x4,
409 "double2x2" => HlslSyntaxKind::Double2x2,
410 "double2x3" => HlslSyntaxKind::Double2x3,
411 "double2x4" => HlslSyntaxKind::Double2x4,
412 "double3x2" => HlslSyntaxKind::Double3x2,
413 "double3x3" => HlslSyntaxKind::Double3x3,
414 "double3x4" => HlslSyntaxKind::Double3x4,
415 "double4x2" => HlslSyntaxKind::Double4x2,
416 "double4x3" => HlslSyntaxKind::Double4x3,
417 "double4x4" => HlslSyntaxKind::Double4x4,
418
419 "Texture1D" => HlslSyntaxKind::Texture1D,
421 "Texture1DArray" => HlslSyntaxKind::Texture1DArray,
422 "Texture2D" => HlslSyntaxKind::Texture2D,
423 "Texture2DArray" => HlslSyntaxKind::Texture2DArray,
424 "Texture2DMS" => HlslSyntaxKind::Texture2DMS,
425 "Texture2DMSArray" => HlslSyntaxKind::Texture2DMSArray,
426 "Texture3D" => HlslSyntaxKind::Texture3D,
427 "TextureCube" => HlslSyntaxKind::TextureCube,
428 "TextureCubeArray" => HlslSyntaxKind::TextureCubeArray,
429
430 "sampler" => HlslSyntaxKind::Sampler,
432 "SamplerState" => HlslSyntaxKind::SamplerState,
433 "SamplerComparisonState" => HlslSyntaxKind::SamplerComparisonState,
434
435 "Buffer" => HlslSyntaxKind::Buffer,
437 "StructuredBuffer" => HlslSyntaxKind::StructuredBuffer,
438 "ByteAddressBuffer" => HlslSyntaxKind::ByteAddressBuffer,
439 "RWBuffer" => HlslSyntaxKind::RWBuffer,
440 "RWStructuredBuffer" => HlslSyntaxKind::RWStructuredBuffer,
441 "RWByteAddressBuffer" => HlslSyntaxKind::RWByteAddressBuffer,
442 "AppendStructuredBuffer" => HlslSyntaxKind::AppendStructuredBuffer,
443 "ConsumeStructuredBuffer" => HlslSyntaxKind::ConsumeStructuredBuffer,
444
445 "if" => HlslSyntaxKind::If,
447 "else" => HlslSyntaxKind::Else,
448 "for" => HlslSyntaxKind::For,
449 "while" => HlslSyntaxKind::While,
450 "do" => HlslSyntaxKind::Do,
451 "switch" => HlslSyntaxKind::Switch,
452 "case" => HlslSyntaxKind::Case,
453 "default" => HlslSyntaxKind::Default,
454 "break" => HlslSyntaxKind::Break,
455 "continue" => HlslSyntaxKind::Continue,
456 "return" => HlslSyntaxKind::Return,
457 "discard" => HlslSyntaxKind::Discard,
458
459 "static" => HlslSyntaxKind::Static,
461 "const" => HlslSyntaxKind::Const,
462 "volatile" => HlslSyntaxKind::Volatile,
463 "extern" => HlslSyntaxKind::Extern,
464 "shared" => HlslSyntaxKind::Shared,
465 "groupshared" => HlslSyntaxKind::Groupshared,
466 "uniform" => HlslSyntaxKind::Uniform,
467 "in" => HlslSyntaxKind::In,
468 "out" => HlslSyntaxKind::Out,
469 "inout" => HlslSyntaxKind::Inout,
470 "inline" => HlslSyntaxKind::Inline,
471 "target" => HlslSyntaxKind::Target,
472
473 "register" => HlslSyntaxKind::Register,
475 "packoffset" => HlslSyntaxKind::Packoffset,
476
477 "struct" => HlslSyntaxKind::Struct,
479 "cbuffer" => HlslSyntaxKind::Cbuffer,
480 "tbuffer" => HlslSyntaxKind::Tbuffer,
481 "interface" => HlslSyntaxKind::Interface,
482 "class" => HlslSyntaxKind::Class,
483
484 "true" | "false" => HlslSyntaxKind::BooleanLiteral,
486
487 _ => HlslSyntaxKind::Identifier,
488 };
489
490 state.add_token(token_kind, start_pos, state.get_position());
491 return true;
492 }
493 }
494
495 false
496 }
497
498 fn lex_operator_or_delimiter<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> bool {
500 let start_pos = state.get_position();
501
502 if let Some(ch) = state.peek() {
503 let token_kind = match ch {
504 '+' => {
505 state.advance(1);
506 if let Some('=') = state.peek() {
507 state.advance(1);
508 HlslSyntaxKind::PlusAssign
509 }
510 else if let Some('+') = state.peek() {
511 state.advance(1);
512 HlslSyntaxKind::Increment
513 }
514 else {
515 HlslSyntaxKind::Plus
516 }
517 }
518 '-' => {
519 state.advance(1);
520 if let Some('=') = state.peek() {
521 state.advance(1);
522 HlslSyntaxKind::MinusAssign
523 }
524 else if let Some('-') = state.peek() {
525 state.advance(1);
526 HlslSyntaxKind::Decrement
527 }
528 else if let Some('>') = state.peek() {
529 state.advance(1);
530 HlslSyntaxKind::Arrow
531 }
532 else {
533 HlslSyntaxKind::Minus
534 }
535 }
536 '*' => {
537 state.advance(1);
538 if let Some('=') = state.peek() {
539 state.advance(1);
540 HlslSyntaxKind::MultiplyAssign
541 }
542 else {
543 HlslSyntaxKind::Multiply
544 }
545 }
546 '/' => {
547 state.advance(1);
548 if let Some('=') = state.peek() {
549 state.advance(1);
550 HlslSyntaxKind::DivideAssign
551 }
552 else {
553 HlslSyntaxKind::Divide
554 }
555 }
556 '%' => {
557 state.advance(1);
558 if let Some('=') = state.peek() {
559 state.advance(1);
560 HlslSyntaxKind::ModuloAssign
561 }
562 else {
563 HlslSyntaxKind::Modulo
564 }
565 }
566 '=' => {
567 state.advance(1);
568 if let Some('=') = state.peek() {
569 state.advance(1);
570 HlslSyntaxKind::Equal
571 }
572 else {
573 HlslSyntaxKind::Assign
574 }
575 }
576 '!' => {
577 state.advance(1);
578 if let Some('=') = state.peek() {
579 state.advance(1);
580 HlslSyntaxKind::NotEqual
581 }
582 else {
583 HlslSyntaxKind::LogicalNot
584 }
585 }
586 '<' => {
587 state.advance(1);
588 if let Some('=') = state.peek() {
589 state.advance(1);
590 HlslSyntaxKind::LessEqual
591 }
592 else if let Some('<') = state.peek() {
593 state.advance(1);
594 if let Some('=') = state.peek() {
595 state.advance(1);
596 HlslSyntaxKind::LeftShiftAssign
597 }
598 else {
599 HlslSyntaxKind::LeftShift
600 }
601 }
602 else {
603 HlslSyntaxKind::Less
604 }
605 }
606 '>' => {
607 state.advance(1);
608 if let Some('=') = state.peek() {
609 state.advance(1);
610 HlslSyntaxKind::GreaterEqual
611 }
612 else if let Some('>') = state.peek() {
613 state.advance(1);
614 if let Some('=') = state.peek() {
615 state.advance(1);
616 HlslSyntaxKind::RightShiftAssign
617 }
618 else {
619 HlslSyntaxKind::RightShift
620 }
621 }
622 else {
623 HlslSyntaxKind::Greater
624 }
625 }
626 '&' => {
627 state.advance(1);
628 if let Some('&') = state.peek() {
629 state.advance(1);
630 HlslSyntaxKind::LogicalAnd
631 }
632 else if let Some('=') = state.peek() {
633 state.advance(1);
634 HlslSyntaxKind::BitwiseAndAssign
635 }
636 else {
637 HlslSyntaxKind::BitwiseAnd
638 }
639 }
640 '|' => {
641 state.advance(1);
642 if let Some('|') = state.peek() {
643 state.advance(1);
644 HlslSyntaxKind::LogicalOr
645 }
646 else if let Some('=') = state.peek() {
647 state.advance(1);
648 HlslSyntaxKind::BitwiseOrAssign
649 }
650 else {
651 HlslSyntaxKind::BitwiseOr
652 }
653 }
654 '^' => {
655 state.advance(1);
656 if let Some('=') = state.peek() {
657 state.advance(1);
658 HlslSyntaxKind::BitwiseXorAssign
659 }
660 else {
661 HlslSyntaxKind::BitwiseXor
662 }
663 }
664 '~' => {
665 state.advance(1);
666 HlslSyntaxKind::BitwiseNot
667 }
668 '?' => {
669 state.advance(1);
670 HlslSyntaxKind::Conditional
671 }
672 '.' => {
673 state.advance(1);
674 HlslSyntaxKind::Dot
675 }
676 ':' => {
677 state.advance(1);
678 if let Some(':') = state.peek() {
679 state.advance(1);
680 HlslSyntaxKind::DoubleColon
681 }
682 else {
683 HlslSyntaxKind::Colon
684 }
685 }
686 ';' => {
687 state.advance(1);
688 HlslSyntaxKind::Semicolon
689 }
690 ',' => {
691 state.advance(1);
692 HlslSyntaxKind::Comma
693 }
694 '(' => {
695 state.advance(1);
696 HlslSyntaxKind::LeftParen
697 }
698 ')' => {
699 state.advance(1);
700 HlslSyntaxKind::RightParen
701 }
702 '[' => {
703 state.advance(1);
704 HlslSyntaxKind::LeftBracket
705 }
706 ']' => {
707 state.advance(1);
708 HlslSyntaxKind::RightBracket
709 }
710 '{' => {
711 state.advance(1);
712 HlslSyntaxKind::LeftBrace
713 }
714 '}' => {
715 state.advance(1);
716 HlslSyntaxKind::RightBrace
717 }
718 '\\' => {
719 state.advance(1);
720 HlslSyntaxKind::Backslash
721 }
722 _ => return false,
723 };
724
725 state.add_token(token_kind, start_pos, state.get_position());
726 true
727 }
728 else {
729 false
730 }
731 }
732}
733
734impl<'config> Lexer<HlslLanguage> for HlslLexer<'config> {
735 fn lex<'a, S: Source + ?Sized>(&self, source: &S, _edits: &[oak_core::TextEdit], cache: &'a mut impl LexerCache<HlslLanguage>) -> LexOutput<HlslLanguage> {
736 let mut state = State::new(source);
737 let result = self.run(&mut state);
738 if result.is_ok() {
739 state.add_eof();
740 }
741 state.finish_with_cache(result, cache)
742 }
743}