1use crate::error::ParseError;
4use itertools::Itertools;
5use logos::{Logos, SpannedIter};
6use std::{fmt::Display, num::NonZeroU8, sync::LazyLock};
7use wgsl_types::idents::RESERVED_WORDS;
8
/// Byte-offset range of a token within the source string.
type Span = std::ops::Range<usize>;
10
/// Logos callback for `>`, `>=`, `>>` and `>>=`.
///
/// If the current paren/bracket nesting depth equals the depth recorded for
/// the innermost pending template list, the leading `>` closes that template:
/// the matched token is rewritten to [`Token::TemplateArgsEnd`] and the rest
/// of the matched symbol (`=`, `>`, or `>=`) is stashed in `extras.lookahead`
/// so the `Lexer` wrapper can emit it as its own token. A `>>` whose second
/// `>` also closes a pending template at the same depth closes two template
/// lists at once. Otherwise the token is returned unchanged.
fn maybe_template_end(
    lex: &mut logos::Lexer<Token>,
    current: Token,
    lookahead: Option<Token>,
) -> Token {
    if let Some(depth) = lex.extras.template_depths.last() {
        if lex.extras.depth == *depth {
            // The leading `>` closes the innermost pending template list.
            lex.extras.template_depths.pop();
            if let Some(depth) = lex.extras.template_depths.last() {
                // For `>>`: the trailing `>` may close a second template list
                // opened at the same nesting depth.
                if lex.extras.depth == *depth && lookahead == Some(Token::SymGreaterThan) {
                    lex.extras.template_depths.pop();
                    lex.extras.lookahead = Some(Token::TemplateArgsEnd);
                } else {
                    lex.extras.lookahead = lookahead;
                }
            } else {
                lex.extras.lookahead = lookahead;
            }
            return Token::TemplateArgsEnd;
        }
    }

    current
}
38
39fn maybe_fail_template(lex: &mut logos::Lexer<Token>) -> bool {
42 if let Some(depth) = lex.extras.template_depths.last() {
43 if lex.extras.depth == *depth {
44 return false;
45 }
46 }
47 true
48}
49
/// Logos callback for `(` and `[`: records one more level of nesting, used by
/// the template-list disambiguation to match depths.
fn incr_depth(lex: &mut logos::Lexer<Token>) {
    lex.extras.depth += 1;
}
53
/// Logos callback for `)` and `]`: leaves one level of nesting.
/// NOTE(review): unbalanced closers drive `depth` negative rather than
/// erroring — presumably rejected later by the parser; confirm.
fn decr_depth(lex: &mut logos::Lexer<Token>) {
    lex.extras.depth -= 1;
}
57
/// `lexical` number format for decimal literals (all builder defaults).
const DEC_FORMAT: u128 = lexical::NumberFormatBuilder::new().build();
62
/// `lexical` number format for hexadecimal literals: base-16 mantissa with a
/// `0x` prefix; the exponent (`p`) scales by powers of 16 and its digits are
/// written in decimal, matching WGSL hex-float syntax.
const HEX_FORMAT: u128 = lexical::NumberFormatBuilder::new()
    .mantissa_radix(16)
    .base_prefix(NonZeroU8::new(b'x'))
    .exponent_base(NonZeroU8::new(16))
    .exponent_radix(NonZeroU8::new(10))
    .build();
70
/// Float-parsing options for hex floats: WGSL hex floats use `p`/`P` as the
/// exponent character (e.g. `0x1.8p2`) instead of the decimal `e`.
static FLOAT_HEX_OPTIONS: LazyLock<lexical::parse_float_options::Options> = LazyLock::new(|| {
    lexical::parse_float_options::OptionsBuilder::new()
        .exponent(b'p')
        .decimal_point(b'.')
        .build()
        .unwrap()
});
78
79fn parse_dec_abstract_int(lex: &mut logos::Lexer<Token>) -> Option<i64> {
80 let options = &lexical::parse_integer_options::STANDARD;
81 let str = lex.slice();
82 lexical::parse_with_options::<i64, _, DEC_FORMAT>(str, options).ok()
83}
84
85fn parse_hex_abstract_int(lex: &mut logos::Lexer<Token>) -> Option<i64> {
86 let options = &lexical::parse_integer_options::STANDARD;
87 let str = lex.slice();
88 lexical::parse_with_options::<i64, _, HEX_FORMAT>(str, options).ok()
89}
90
91fn parse_dec_i32(lex: &mut logos::Lexer<Token>) -> Option<i32> {
92 let options = &lexical::parse_integer_options::STANDARD;
93 let str = lex.slice();
94 let str = &str[..str.len() - 1];
95 lexical::parse_with_options::<i32, _, DEC_FORMAT>(str, options).ok()
96}
97
98fn parse_hex_i32(lex: &mut logos::Lexer<Token>) -> Option<i32> {
99 let options = &lexical::parse_integer_options::STANDARD;
100 let str = lex.slice();
101 let str = &str[..str.len() - 1];
102 lexical::parse_with_options::<i32, _, HEX_FORMAT>(str, options).ok()
103}
104
105fn parse_dec_u32(lex: &mut logos::Lexer<Token>) -> Option<u32> {
106 let options = &lexical::parse_integer_options::STANDARD;
107 let str = lex.slice();
108 let str = &str[..str.len() - 1];
109 lexical::parse_with_options::<u32, _, DEC_FORMAT>(str, options).ok()
110}
111
112fn parse_hex_u32(lex: &mut logos::Lexer<Token>) -> Option<u32> {
113 let options = &lexical::parse_integer_options::STANDARD;
114 let str = lex.slice();
115 let str = &str[..str.len() - 1];
116 lexical::parse_with_options::<u32, _, HEX_FORMAT>(str, options).ok()
117}
118
119fn parse_dec_abs_float(lex: &mut logos::Lexer<Token>) -> Option<f64> {
120 let options = &lexical::parse_float_options::STANDARD;
121 let str = lex.slice();
122 lexical::parse_with_options::<f64, _, DEC_FORMAT>(str, options).ok()
123}
124
125fn parse_hex_abs_float(lex: &mut logos::Lexer<Token>) -> Option<f64> {
126 let str = lex.slice();
127 lexical::parse_with_options::<f64, _, HEX_FORMAT>(str, &FLOAT_HEX_OPTIONS).ok()
128}
129
130fn parse_dec_f32(lex: &mut logos::Lexer<Token>) -> Option<f32> {
131 let options = &lexical::parse_float_options::STANDARD;
132 let str = lex.slice();
133 let str = &str[..str.len() - 1];
134 lexical::parse_with_options::<f32, _, DEC_FORMAT>(str, options).ok()
135}
136
137fn parse_hex_f32(lex: &mut logos::Lexer<Token>) -> Option<f32> {
138 let str = lex.slice();
139 let options = &lexical::parse_float_options::STANDARD;
141 let str = &str[..str.len() - 1];
142 lexical::parse_with_options::<f32, _, HEX_FORMAT>(str, options).ok()
143}
144
145fn parse_dec_f16(lex: &mut logos::Lexer<Token>) -> Option<f32> {
146 let options = &lexical::parse_float_options::STANDARD;
147 let str = lex.slice();
148 let str = &str[..str.len() - 1];
149 lexical::parse_with_options::<f32, _, DEC_FORMAT>(str, options).ok()
150}
151
152fn parse_hex_f16(lex: &mut logos::Lexer<Token>) -> Option<f32> {
153 let str = lex.slice();
154 let str = &str[..str.len() - 1];
155 lexical::parse_with_options::<f32, _, HEX_FORMAT>(str, &FLOAT_HEX_OPTIONS).ok()
156}
157
/// Parses a decimal `i64` literal (naga extension); drops the `li` suffix.
#[cfg(feature = "naga-ext")]
fn parse_dec_i64(lex: &mut logos::Lexer<Token>) -> Option<i64> {
    let slice = lex.slice();
    let digits = &slice[..slice.len() - 2]; // strip the `li` suffix
    lexical::parse_with_options::<i64, _, DEC_FORMAT>(
        digits,
        &lexical::parse_integer_options::STANDARD,
    )
    .ok()
}

/// Parses a hexadecimal `i64` literal (naga extension); drops the `li` suffix.
#[cfg(feature = "naga-ext")]
fn parse_hex_i64(lex: &mut logos::Lexer<Token>) -> Option<i64> {
    let slice = lex.slice();
    let digits = &slice[..slice.len() - 2]; // strip the `li` suffix
    lexical::parse_with_options::<i64, _, HEX_FORMAT>(
        digits,
        &lexical::parse_integer_options::STANDARD,
    )
    .ok()
}

/// Parses a decimal `u64` literal (naga extension); drops the `lu` suffix.
#[cfg(feature = "naga-ext")]
fn parse_dec_u64(lex: &mut logos::Lexer<Token>) -> Option<u64> {
    let slice = lex.slice();
    let digits = &slice[..slice.len() - 2]; // strip the `lu` suffix
    lexical::parse_with_options::<u64, _, DEC_FORMAT>(
        digits,
        &lexical::parse_integer_options::STANDARD,
    )
    .ok()
}

/// Parses a hexadecimal `u64` literal (naga extension); drops the `lu` suffix.
#[cfg(feature = "naga-ext")]
fn parse_hex_u64(lex: &mut logos::Lexer<Token>) -> Option<u64> {
    let slice = lex.slice();
    let digits = &slice[..slice.len() - 2]; // strip the `lu` suffix
    lexical::parse_with_options::<u64, _, HEX_FORMAT>(
        digits,
        &lexical::parse_integer_options::STANDARD,
    )
    .ok()
}
189
/// Parses a decimal `f64` literal (naga extension); drops the `lf` suffix.
#[cfg(feature = "naga-ext")]
fn parse_dec_f64(lex: &mut logos::Lexer<Token>) -> Option<f64> {
    let options = &lexical::parse_float_options::STANDARD;
    let str = lex.slice();
    let str = &str[..str.len() - 2];
    lexical::parse_with_options::<f64, _, DEC_FORMAT>(str, options).ok()
}

/// Parses a hexadecimal `f64` literal (naga extension); drops the `lf` suffix.
///
/// Bug fix: this previously used `parse_float_options::STANDARD` (exponent
/// `e`) with `HEX_FORMAT`, while every hex-f64 regex requires a `p`/`P`
/// exponent, so hex `f64` literals could never parse. Use the shared
/// `FLOAT_HEX_OPTIONS`, consistent with `parse_hex_abs_float`/`parse_hex_f16`.
#[cfg(feature = "naga-ext")]
fn parse_hex_f64(lex: &mut logos::Lexer<Token>) -> Option<f64> {
    let str = lex.slice();
    let str = &str[..str.len() - 2];
    lexical::parse_with_options::<f64, _, HEX_FORMAT>(str, &FLOAT_HEX_OPTIONS).ok()
}
206
207fn parse_line_comment(lex: &mut logos::Lexer<Token>) {
208 let rem = lex.remainder();
209 let line_end = rem
211 .char_indices()
212 .find(|(_, c)| "\n\u{000B}\u{000C}\r\u{0085}\u{2028}\u{2029}".contains(*c))
213 .map(|(i, _)| i)
214 .unwrap_or(rem.len());
215 lex.bump(line_end);
216}
217
218fn parse_block_comment(lex: &mut logos::Lexer<Token>) {
219 let mut depth = 1;
220 while depth > 0 {
221 let rem = lex.remainder();
222 if rem.is_empty() {
223 break;
224 } else if rem.starts_with("/*") {
225 lex.bump(2);
226 depth += 1;
227 } else if rem.starts_with("*/") {
228 lex.bump(2);
229 depth -= 1;
230 } else {
231 let mut next_char = 1;
232 while !rem.is_char_boundary(next_char) {
233 next_char += 1;
234 }
235 lex.bump(next_char);
236 }
237 }
238}
239
240fn parse_ident(lex: &mut logos::Lexer<Token>) -> Token {
241 let ident = lex.slice().to_string();
242 if RESERVED_WORDS.iter().contains(&ident.as_str()) {
243 Token::ReservedWord(ident)
244 } else {
245 Token::Ident(ident)
246 }
247}
248
/// Mutable lexer state threaded through the logos callbacks via `extras`.
#[derive(Default, Clone, Debug, PartialEq)]
pub struct LexerState {
    // Current `(` / `[` nesting depth (see `incr_depth`/`decr_depth`).
    depth: i32,
    // For each unclosed candidate template list, the nesting depth at which
    // its opening `<` was seen; `maybe_template_end` pops entries.
    template_depths: Vec<i32>,
    // Token synthesized by `maybe_template_end` (the remainder of a split
    // `>=`/`>>`/`>>=`) that must be emitted before advancing the stream.
    lookahead: Option<Token>,
}
255
/// All lexical tokens of WGSL (plus feature-gated extensions).
///
/// `TemplateArgsStart` is never matched directly: the `Lexer` wrapper promotes
/// a `<` to it after running template-list recognition, and the
/// `maybe_template_end` callback promotes `>`-family tokens to
/// `TemplateArgsEnd` using the depth counters in `LexerState`.
#[derive(Logos, Clone, Debug, PartialEq)]
#[logos(
    skip r"[\s\u0085\u200e\u200f\u2028\u2029]+", extras = LexerState,
    error = ParseError)]
pub enum Token {
    // --- trivia (filtered out by `Lexer`, see `is_trivia`) ---
    #[token("//", parse_line_comment)]
    LineComment,
    #[token("/*", parse_block_comment, priority = 2)]
    BlockComment,
    // Identifier-shaped text; `parse_ident` returns `Ident`/`ReservedWord`
    // instead, so `Ignored` itself is never yielded (hence the
    // `unreachable!` in `Display`).
    #[regex(
        r#"([_\p{XID_Start}][\p{XID_Continue}]+)|([\p{XID_Start}])"#,
        parse_ident,
        priority = 1
    )]
    Ignored,
    // --- symbols ---
    #[token("&")]
    SymAnd,
    #[token("&&", maybe_fail_template)]
    SymAndAnd,
    #[token("->")]
    SymArrow,
    #[token("@")]
    SymAttr,
    #[token("/")]
    SymForwardSlash,
    #[token("!")]
    SymBang,
    #[token("[", incr_depth)]
    SymBracketLeft,
    #[token("]", decr_depth)]
    SymBracketRight,
    #[token("{")]
    SymBraceLeft,
    #[token("}")]
    SymBraceRight,
    #[token(":")]
    SymColon,
    #[token(",")]
    SymComma,
    #[token("=")]
    SymEqual,
    #[token("==")]
    SymEqualEqual,
    #[token("!=")]
    SymNotEqual,
    // `>`-family tokens may instead close a pending template list.
    #[token(">", |lex| maybe_template_end(lex, Token::SymGreaterThan, None))]
    SymGreaterThan,
    #[token(">=", |lex| maybe_template_end(lex, Token::SymGreaterThanEqual, Some(Token::SymEqual)))]
    SymGreaterThanEqual,
    #[token(">>", |lex| maybe_template_end(lex, Token::SymShiftRight, Some(Token::SymGreaterThan)))]
    SymShiftRight,
    #[token("<")]
    SymLessThan,
    #[token("<=")]
    SymLessThanEqual,
    #[token("<<")]
    SymShiftLeft,
    #[token("%")]
    SymModulo,
    #[token("-")]
    SymMinus,
    #[token("--")]
    SymMinusMinus,
    #[token(".")]
    SymPeriod,
    #[token("+")]
    SymPlus,
    #[token("++")]
    SymPlusPlus,
    #[token("|")]
    SymOr,
    #[token("||", maybe_fail_template)]
    SymOrOr,
    #[token("(", incr_depth)]
    SymParenLeft,
    #[token(")", decr_depth)]
    SymParenRight,
    #[token(";")]
    SymSemicolon,
    #[token("*")]
    SymStar,
    #[token("~")]
    SymTilde,
    #[token("_")]
    SymUnderscore,
    #[token("^")]
    SymXor,
    #[token("+=")]
    SymPlusEqual,
    #[token("-=")]
    SymMinusEqual,
    #[token("*=")]
    SymTimesEqual,
    #[token("/=")]
    SymDivisionEqual,
    #[token("%=")]
    SymModuloEqual,
    #[token("&=")]
    SymAndEqual,
    #[token("|=")]
    SymOrEqual,
    #[token("^=")]
    SymXorEqual,
    #[token(">>=", |lex| maybe_template_end(lex, Token::SymShiftRightAssign, Some(Token::SymGreaterThanEqual)))]
    SymShiftRightAssign,
    #[token("<<=")]
    SymShiftLeftAssign,

    // --- keywords ---
    #[token("alias")]
    KwAlias,
    #[token("break")]
    KwBreak,
    #[token("case")]
    KwCase,
    #[token("const", priority = 2)]
    KwConst,
    #[token("const_assert")]
    KwConstAssert,
    #[token("continue")]
    KwContinue,
    #[token("continuing")]
    KwContinuing,
    #[token("default")]
    KwDefault,
    #[token("diagnostic")]
    KwDiagnostic,
    #[token("discard")]
    KwDiscard,
    #[token("else")]
    KwElse,
    #[token("enable")]
    KwEnable,
    #[token("false")]
    KwFalse,
    #[token("fn")]
    KwFn,
    #[token("for")]
    KwFor,
    #[token("if")]
    KwIf,
    #[token("let")]
    KwLet,
    #[token("loop")]
    KwLoop,
    #[token("override")]
    KwOverride,
    #[token("requires")]
    KwRequires,
    #[token("return")]
    KwReturn,
    #[token("struct")]
    KwStruct,
    #[token("switch")]
    KwSwitch,
    #[token("true")]
    KwTrue,
    #[token("var")]
    KwVar,
    #[token("while")]
    KwWhile,

    // Produced by `parse_ident` (no logos rule of their own).
    Ident(String),
    ReservedWord(String),

    // --- numeric literals (suffix determines the type) ---
    #[regex(r#"0|[1-9]\d*"#, parse_dec_abstract_int)]
    #[regex(r#"0[xX][\da-fA-F]+"#, parse_hex_abstract_int)]
    AbstractInt(i64),
    #[regex(r#"(\d+\.\d*|\.\d+)([eE][+-]?\d+)?"#, parse_dec_abs_float)]
    #[regex(r#"\d+[eE][+-]?\d+"#, parse_dec_abs_float)]
    #[regex(r#"0[xX][\da-fA-F]+\.[\da-fA-F]*([pP][+-]?\d+)?"#, parse_hex_abs_float)]
    #[regex(r#"0[xX]\.[\da-fA-F]+([pP][+-]?\d+)?"#, parse_hex_abs_float)]
    #[regex(r#"0[xX][\da-fA-F]+[pP][+-]?\d+"#, parse_hex_abs_float)]
    AbstractFloat(f64),
    #[regex(r#"(0|[1-9]\d*)i"#, parse_dec_i32)]
    #[regex(r#"0[xX][\da-fA-F]+i"#, parse_hex_i32)]
    I32(i32),
    #[regex(r#"(0|[1-9]\d*)u"#, parse_dec_u32)]
    #[regex(r#"0[xX][\da-fA-F]+u"#, parse_hex_u32)]
    U32(u32),
    #[regex(r#"(\d+\.\d*|\.\d+)([eE][+-]?\d+)?f"#, parse_dec_f32)]
    #[regex(r#"\d+([eE][+-]?\d+)?f"#, parse_dec_f32)]
    #[regex(r#"0[xX][\da-fA-F]+\.[\da-fA-F]*[pP][+-]?\d+f"#, parse_hex_f32)]
    #[regex(r#"0[xX]\.[\da-fA-F]+[pP][+-]?\d+f"#, parse_hex_f32)]
    #[regex(r#"0[xX][\da-fA-F]+[pP][+-]?\d+f"#, parse_hex_f32)]
    F32(f32),
    // f16 values are carried as f32.
    #[regex(r#"(\d+\.\d*|\.\d+)([eE][+-]?\d+)?h"#, parse_dec_f16)]
    #[regex(r#"\d+([eE][+-]?\d+)?h"#, parse_dec_f16)]
    #[regex(r#"0[xX][\da-fA-F]+\.[\da-fA-F]*[pP][+-]?\d+h"#, parse_hex_f16)]
    #[regex(r#"0[xX]\.[\da-fA-F]+[pP][+-]?\d+h"#, parse_hex_f16)]
    #[regex(r#"0[xX][\da-fA-F]+[pP][+-]?\d+h"#, parse_hex_f16)]
    F16(f32),
    // --- naga extension literals (`li`/`lu`/`lf` suffixes) ---
    #[cfg(feature = "naga-ext")]
    #[regex(r#"(0|[1-9]\d*)li"#, parse_dec_i64)]
    #[regex(r#"0[xX][\da-fA-F]+li"#, parse_hex_i64)]
    I64(i64),
    #[cfg(feature = "naga-ext")]
    #[regex(r#"(0|[1-9]\d*)lu"#, parse_dec_u64)]
    #[regex(r#"0[xX][\da-fA-F]+lu"#, parse_hex_u64)]
    U64(u64),
    #[cfg(feature = "naga-ext")]
    #[regex(r#"(\d+\.\d*|\.\d+)([eE][+-]?\d+)?lf"#, parse_dec_f64)]
    #[regex(r#"\d+([eE][+-]?\d+)?lf"#, parse_dec_f64)]
    #[regex(r#"0[xX][\da-fA-F]+\.[\da-fA-F]*[pP][+-]?\d+lf"#, parse_hex_f64)]
    #[regex(r#"0[xX]\.[\da-fA-F]+[pP][+-]?\d+lf"#, parse_hex_f64)]
    #[regex(r#"0[xX][\da-fA-F]+[pP][+-]?\d+lf"#, parse_hex_f64)]
    F64(f64),
    // Synthesized by the `Lexer` wrapper / `maybe_template_end`.
    TemplateArgsStart,
    TemplateArgsEnd,

    // --- import-extension tokens ---
    #[cfg(feature = "imports")]
    #[token("::")]
    SymColonColon,
    #[cfg(feature = "imports")]
    #[token("self")]
    KwSelf,
    #[cfg(feature = "imports")]
    #[token("super")]
    KwSuper,
    #[cfg(feature = "imports")]
    #[token("package")]
    KwPackage,
    #[cfg(feature = "imports")]
    #[token("as")]
    KwAs,
    #[cfg(feature = "imports")]
    #[token("import")]
    KwImport,
}
506
507impl Token {
508 pub fn is_trivia(&self) -> bool {
509 matches!(
510 self,
511 Token::LineComment | Token::BlockComment | Token::Ignored
512 )
513 }
514
515 #[allow(unused)]
516 pub fn is_symbol(&self) -> bool {
517 matches!(
518 self,
519 Token::SymAnd
520 | Token::SymAndAnd
521 | Token::SymArrow
522 | Token::SymAttr
523 | Token::SymForwardSlash
524 | Token::SymBang
525 | Token::SymBracketLeft
526 | Token::SymBracketRight
527 | Token::SymBraceLeft
528 | Token::SymBraceRight
529 | Token::SymColon
530 | Token::SymComma
531 | Token::SymEqual
532 | Token::SymEqualEqual
533 | Token::SymNotEqual
534 | Token::SymGreaterThan
535 | Token::SymGreaterThanEqual
536 | Token::SymShiftRight
537 | Token::SymLessThan
538 | Token::SymLessThanEqual
539 | Token::SymShiftLeft
540 | Token::SymModulo
541 | Token::SymMinus
542 | Token::SymMinusMinus
543 | Token::SymPeriod
544 | Token::SymPlus
545 | Token::SymPlusPlus
546 | Token::SymOr
547 | Token::SymOrOr
548 | Token::SymParenLeft
549 | Token::SymParenRight
550 | Token::SymSemicolon
551 | Token::SymStar
552 | Token::SymTilde
553 | Token::SymUnderscore
554 | Token::SymXor
555 | Token::SymPlusEqual
556 | Token::SymMinusEqual
557 | Token::SymTimesEqual
558 | Token::SymDivisionEqual
559 | Token::SymModuloEqual
560 | Token::SymAndEqual
561 | Token::SymOrEqual
562 | Token::SymXorEqual
563 | Token::SymShiftRightAssign
564 | Token::SymShiftLeftAssign
565 )
566 }
567
568 #[allow(unused)]
569 pub fn is_keyword(&self) -> bool {
570 matches!(
571 self,
572 Token::KwAlias
573 | Token::KwBreak
574 | Token::KwCase
575 | Token::KwConst
576 | Token::KwConstAssert
577 | Token::KwContinue
578 | Token::KwContinuing
579 | Token::KwDefault
580 | Token::KwDiagnostic
581 | Token::KwDiscard
582 | Token::KwElse
583 | Token::KwEnable
584 | Token::KwFalse
585 | Token::KwFn
586 | Token::KwFor
587 | Token::KwIf
588 | Token::KwLet
589 | Token::KwLoop
590 | Token::KwOverride
591 | Token::KwRequires
592 | Token::KwReturn
593 | Token::KwStruct
594 | Token::KwSwitch
595 | Token::KwTrue
596 | Token::KwVar
597 | Token::KwWhile
598 )
599 }
600
601 #[allow(unused)]
602 pub fn is_numeric_literal(&self) -> bool {
603 matches!(
604 self,
605 Token::AbstractInt(_)
606 | Token::AbstractFloat(_)
607 | Token::I32(_)
608 | Token::U32(_)
609 | Token::F32(_)
610 | Token::F16(_)
611 )
612 }
613}
614
impl Display for Token {
    /// Human-readable rendering of a token, used in parser diagnostics.
    /// Symbols/keywords print as their source text; literals include their
    /// suffix; synthesized tokens get descriptive phrases.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Token::LineComment => f.write_str("// line comment"),
            Token::BlockComment => f.write_str("/* block comment */"),
            // `Ignored` is always rewritten to Ident/ReservedWord by
            // `parse_ident`, so it can never reach Display.
            Token::Ignored => unreachable!(),
            Token::SymAnd => f.write_str("&"),
            Token::SymAndAnd => f.write_str("&&"),
            Token::SymArrow => f.write_str("->"),
            Token::SymAttr => f.write_str("@"),
            Token::SymForwardSlash => f.write_str("/"),
            Token::SymBang => f.write_str("!"),
            Token::SymBracketLeft => f.write_str("["),
            Token::SymBracketRight => f.write_str("]"),
            Token::SymBraceLeft => f.write_str("{"),
            Token::SymBraceRight => f.write_str("}"),
            Token::SymColon => f.write_str(":"),
            Token::SymComma => f.write_str(","),
            Token::SymEqual => f.write_str("="),
            Token::SymEqualEqual => f.write_str("=="),
            Token::SymNotEqual => f.write_str("!="),
            Token::SymGreaterThan => f.write_str(">"),
            Token::SymGreaterThanEqual => f.write_str(">="),
            Token::SymShiftRight => f.write_str(">>"),
            Token::SymLessThan => f.write_str("<"),
            Token::SymLessThanEqual => f.write_str("<="),
            Token::SymShiftLeft => f.write_str("<<"),
            Token::SymModulo => f.write_str("%"),
            Token::SymMinus => f.write_str("-"),
            Token::SymMinusMinus => f.write_str("--"),
            Token::SymPeriod => f.write_str("."),
            Token::SymPlus => f.write_str("+"),
            Token::SymPlusPlus => f.write_str("++"),
            Token::SymOr => f.write_str("|"),
            Token::SymOrOr => f.write_str("||"),
            Token::SymParenLeft => f.write_str("("),
            Token::SymParenRight => f.write_str(")"),
            Token::SymSemicolon => f.write_str(";"),
            Token::SymStar => f.write_str("*"),
            Token::SymTilde => f.write_str("~"),
            Token::SymUnderscore => f.write_str("_"),
            Token::SymXor => f.write_str("^"),
            Token::SymPlusEqual => f.write_str("+="),
            Token::SymMinusEqual => f.write_str("-="),
            Token::SymTimesEqual => f.write_str("*="),
            Token::SymDivisionEqual => f.write_str("/="),
            Token::SymModuloEqual => f.write_str("%="),
            Token::SymAndEqual => f.write_str("&="),
            Token::SymOrEqual => f.write_str("|="),
            Token::SymXorEqual => f.write_str("^="),
            Token::SymShiftRightAssign => f.write_str(">>="),
            Token::SymShiftLeftAssign => f.write_str("<<="),
            Token::KwAlias => f.write_str("alias"),
            Token::KwBreak => f.write_str("break"),
            Token::KwCase => f.write_str("case"),
            Token::KwConst => f.write_str("const"),
            Token::KwConstAssert => f.write_str("const_assert"),
            Token::KwContinue => f.write_str("continue"),
            Token::KwContinuing => f.write_str("continuing"),
            Token::KwDefault => f.write_str("default"),
            Token::KwDiagnostic => f.write_str("diagnostic"),
            Token::KwDiscard => f.write_str("discard"),
            Token::KwElse => f.write_str("else"),
            Token::KwEnable => f.write_str("enable"),
            Token::KwFalse => f.write_str("false"),
            Token::KwFn => f.write_str("fn"),
            Token::KwFor => f.write_str("for"),
            Token::KwIf => f.write_str("if"),
            Token::KwLet => f.write_str("let"),
            Token::KwLoop => f.write_str("loop"),
            Token::KwOverride => f.write_str("override"),
            Token::KwRequires => f.write_str("requires"),
            Token::KwReturn => f.write_str("return"),
            Token::KwStruct => f.write_str("struct"),
            Token::KwSwitch => f.write_str("switch"),
            Token::KwTrue => f.write_str("true"),
            Token::KwVar => f.write_str("var"),
            Token::KwWhile => f.write_str("while"),
            Token::Ident(s) => write!(f, "identifier `{s}`"),
            Token::ReservedWord(s) => write!(f, "reserved word `{s}`"),
            Token::AbstractInt(n) => write!(f, "{n}"),
            Token::AbstractFloat(n) => write!(f, "{n}"),
            Token::I32(n) => write!(f, "{n}i"),
            Token::U32(n) => write!(f, "{n}u"),
            Token::F32(n) => write!(f, "{n}f"),
            Token::F16(n) => write!(f, "{n}h"),
            #[cfg(feature = "naga-ext")]
            Token::I64(n) => write!(f, "{n}li"),
            #[cfg(feature = "naga-ext")]
            Token::U64(n) => write!(f, "{n}lu"),
            #[cfg(feature = "naga-ext")]
            Token::F64(n) => write!(f, "{n}lf"),
            Token::TemplateArgsStart => f.write_str("start of template"),
            Token::TemplateArgsEnd => f.write_str("end of template"),
            #[cfg(feature = "imports")]
            Token::SymColonColon => write!(f, "::"),
            #[cfg(feature = "imports")]
            Token::KwSelf => write!(f, "self"),
            #[cfg(feature = "imports")]
            Token::KwSuper => write!(f, "super"),
            #[cfg(feature = "imports")]
            Token::KwPackage => write!(f, "package"),
            #[cfg(feature = "imports")]
            Token::KwAs => write!(f, "as"),
            #[cfg(feature = "imports")]
            Token::KwImport => write!(f, "import"),
        }
    }
}
725
/// LALRPOP-style spanned item: `Ok((start, token, end))` or a spanned error.
type Spanned<Tok, Loc, ParseError> = Result<(Loc, Tok, Loc), (Loc, ParseError, Loc)>;
/// One lexed token (or error) with its byte span; `None` at end of input.
type NextToken = Option<(Result<Token, ParseError>, Span)>;
728
/// Token stream over a WGSL source string with one token of lookahead,
/// performing template-list disambiguation on top of the raw logos stream.
#[derive(Clone)]
pub struct Lexer<'s> {
    // Full source text, needed to re-lex from a `<` during template
    // recognition.
    source: &'s str,
    // Underlying logos token stream (with spans and `LexerState` extras).
    token_stream: SpannedIter<'s, Token>,
    // One-token lookahead buffer (next non-trivia token, pre-fetched).
    next_token: NextToken,
    // True while this lexer is being used by `recognize_template_list`.
    recognizing_template: bool,
    // Number of template lists currently open during recognition.
    opened_templates: u32,
}
737
impl<'s> Lexer<'s> {
    /// Creates a lexer over `source`, pre-fetching the first non-trivia token
    /// into the lookahead buffer.
    pub fn new(source: &'s str) -> Self {
        let mut token_stream = Token::lexer_with_extras(source, LexerState::default()).spanned();
        let next_token =
            token_stream.find(|(tok, _)| tok.as_ref().is_ok_and(|tok| !tok.is_trivia()));

        Self {
            source,
            token_stream,
            next_token,
            recognizing_template: false,
            opened_templates: 0,
        }
    }

    /// Takes the buffered token plus the one after it. The second token comes
    /// either from the callback-synthesized `extras.lookahead` (the remainder
    /// of a split `>=`/`>>`/`>>=`) or from the underlying stream, skipping
    /// trivia.
    fn take_two_tokens(&mut self) -> (NextToken, NextToken) {
        let mut tok1 = self.next_token.take();

        let lookahead = self.token_stream.extras.lookahead.take();
        let tok2 = match lookahead {
            Some(tok) => {
                // A synthesized lookahead only exists right after a token that
                // produced it, so `tok1` must be present here.
                let (_, span1) = tok1.as_mut().unwrap();
                // The synthesized token occupies the tail of tok1's span,
                // starting one byte in (after the emitted `>`).
                let span2 = span1.start + 1..span1.end;
                Some((Ok(tok), span2))
            }
            None => self
                .token_stream
                .find(|(tok, _)| tok.as_ref().is_ok_and(|tok| !tok.is_trivia())),
        };

        (tok1, tok2)
    }

    /// Yields the next token, applying template-list disambiguation: an
    /// identifier/keyword followed by `<` triggers a trial parse, and on
    /// success the `<` is promoted to `TemplateArgsStart`.
    fn next_token(&mut self) -> NextToken {
        let (cur, mut next) = self.take_two_tokens();

        let (cur_tok, cur_span) = match cur {
            Some((Ok(tok), span)) => (tok, span),
            Some((Err(e), span)) => return Some((Err(e), span)),
            None => return None,
        };

        if let Some((Ok(next_tok), next_span)) = &mut next {
            if (matches!(cur_tok, Token::Ident(_)) || cur_tok.is_keyword())
                && *next_tok == Token::SymLessThan
            {
                // Trial-parse the source from the `<` to decide whether it
                // starts a template argument list.
                let source = &self.source[next_span.start..];
                if recognize_template_list(source) {
                    *next_tok = Token::TemplateArgsStart;
                    let cur_depth = self.token_stream.extras.depth;
                    self.token_stream.extras.template_depths.push(cur_depth);
                    self.opened_templates += 1;
                }
            }
        }

        // During trial parsing, stop the stream once the outermost template
        // candidate closes so the recursive parse terminates early.
        if self.recognizing_template && cur_tok == Token::TemplateArgsEnd {
            self.opened_templates -= 1;
            if self.opened_templates == 0 {
                next = None;
            }
        }

        self.next_token = next;
        Some((Ok(cur_tok), cur_span))
    }
}
806
807pub fn recognize_template_list(source: &str) -> bool {
819 let mut lexer = Lexer::new(source);
820 match lexer.next_token {
821 Some((Ok(ref mut t), _)) if *t == Token::SymLessThan => *t = Token::TemplateArgsStart,
822 _ => return false,
823 };
824 lexer.recognizing_template = true;
825 lexer.opened_templates = 1;
826 lexer.token_stream.extras.template_depths.push(0);
827 crate::parser::recognize_template_list(lexer).is_ok()
828}
829
#[test]
fn test_recognize_template() {
    // Inputs that must be recognized as template argument lists.
    let accepted = [
        "<i32,select(2,3,a>b)>",
        "<B<<C>",
        "<B<=C>",
        "<(B>=C)>",
        "<(B!=C)>",
        "<(B==C)>",
        "<X>",
        "<X<Y>>",
        "<X<Y<Z>>>",
    ];
    // Inputs that must be rejected.
    let rejected = ["<d]>", "", "<>", "<b || c>d"];
    for source in accepted {
        assert!(recognize_template_list(source), "expected template: {source}");
    }
    for source in rejected {
        assert!(
            !recognize_template_list(source),
            "expected non-template: {source}"
        );
    }
}
849
/// Marker trait for iterators producing spanned tokens — the parser's input.
pub trait TokenIterator: IntoIterator<Item = Spanned<Token, usize, ParseError>> {}
851
852impl Iterator for Lexer<'_> {
853 type Item = Spanned<Token, usize, ParseError>;
854
855 fn next(&mut self) -> Option<Self::Item> {
856 let tok = self.next_token();
857 tok.map(|(tok, span)| match tok {
858 Ok(tok) => Ok((span.start, tok, span.end)),
859 Err(err) => Err((span.start, err, span.end)),
860 })
861 }
862}
863
// `Lexer` is the canonical spanned-token source fed to the parser.
impl TokenIterator for Lexer<'_> {}