wgsl_parse/
lexer.rs

1//! Prefer using [`crate::parse_str`]. You shouldn't need to manipulate the lexer.
2
3use crate::error::ParseError;
4use itertools::Itertools;
5use logos::{Logos, SpannedIter};
6use std::{fmt::Display, num::NonZeroU8, sync::LazyLock};
7use wgsl_types::idents::RESERVED_WORDS;
8
// Byte range into the source string, as produced by `logos`.
type Span = std::ops::Range<usize>;
10
// Callback for tokens starting with `>`: decides whether this `>` closes a template list.
//
// `current` is the token to emit when this is NOT a template end. `lookahead` is the
// remainder of the multi-char token (e.g. the second `>` of `>>`) to stash in the lexer
// state and re-emit as the next token when the first `>` is consumed as a template end.
fn maybe_template_end(
    lex: &mut logos::Lexer<Token>,
    current: Token,
    lookahead: Option<Token>,
) -> Token {
    if let Some(depth) = lex.extras.template_depths.last() {
        // if found a ">" on the same nesting level as the opening "<", it is a template end.
        if lex.extras.depth == *depth {
            lex.extras.template_depths.pop();
            // if lookahead is GreaterThan, we may have a second closing template.
            // note that >>= can never be (TemplateEnd, TemplateEnd, Equal).
            if let Some(depth) = lex.extras.template_depths.last() {
                if lex.extras.depth == *depth && lookahead == Some(Token::SymGreaterThan) {
                    lex.extras.template_depths.pop();
                    lex.extras.lookahead = Some(Token::TemplateArgsEnd);
                } else {
                    lex.extras.lookahead = lookahead;
                }
            } else {
                lex.extras.lookahead = lookahead;
            }
            return Token::TemplateArgsEnd;
        }
    }

    current
}
38
39// operators && and || have lower precedence than < and >.
40// therefore, this is not a template: a < b || c > d
41fn maybe_fail_template(lex: &mut logos::Lexer<Token>) -> bool {
42    if let Some(depth) = lex.extras.template_depths.last() {
43        if lex.extras.depth == *depth {
44            return false;
45        }
46    }
47    true
48}
49
50fn incr_depth(lex: &mut logos::Lexer<Token>) {
51    lex.extras.depth += 1;
52}
53
54fn decr_depth(lex: &mut logos::Lexer<Token>) {
55    lex.extras.depth -= 1;
56}
57
// TODO: get rid of crate `lexical`

// `lexical` number format for decimal literals.
// don't have to be super strict, the lexer regex already did the heavy lifting
const DEC_FORMAT: u128 = lexical::NumberFormatBuilder::new().build();

// `lexical` number format for hexadecimal literals (`0x`/`0X` prefix, base-16 mantissa,
// exponent digits written in decimal).
// don't have to be super strict, the lexer regex already did the heavy lifting
// NOTE(review): WGSL hex floats scale the mantissa by 2^exponent (`0xApB` = A * 2^B);
// `exponent_base(16)` looks suspicious — confirm against lexical's format docs.
const HEX_FORMAT: u128 = lexical::NumberFormatBuilder::new()
    .mantissa_radix(16)
    .base_prefix(NonZeroU8::new(b'x'))
    .exponent_base(NonZeroU8::new(16))
    .exponent_radix(NonZeroU8::new(10))
    .build();
70
// Float parsing options for hex float literals, which use `p` (not `e`) as the
// exponent marker, e.g. `0x1.8p2`.
static FLOAT_HEX_OPTIONS: LazyLock<lexical::parse_float_options::Options> = LazyLock::new(|| {
    lexical::parse_float_options::OptionsBuilder::new()
        .exponent(b'p')
        .decimal_point(b'.')
        .build()
        .unwrap()
});
78
79fn parse_dec_abstract_int(lex: &mut logos::Lexer<Token>) -> Option<i64> {
80    let options = &lexical::parse_integer_options::STANDARD;
81    let str = lex.slice();
82    lexical::parse_with_options::<i64, _, DEC_FORMAT>(str, options).ok()
83}
84
85fn parse_hex_abstract_int(lex: &mut logos::Lexer<Token>) -> Option<i64> {
86    let options = &lexical::parse_integer_options::STANDARD;
87    let str = lex.slice();
88    lexical::parse_with_options::<i64, _, HEX_FORMAT>(str, options).ok()
89}
90
91fn parse_dec_i32(lex: &mut logos::Lexer<Token>) -> Option<i32> {
92    let options = &lexical::parse_integer_options::STANDARD;
93    let str = lex.slice();
94    let str = &str[..str.len() - 1];
95    lexical::parse_with_options::<i32, _, DEC_FORMAT>(str, options).ok()
96}
97
98fn parse_hex_i32(lex: &mut logos::Lexer<Token>) -> Option<i32> {
99    let options = &lexical::parse_integer_options::STANDARD;
100    let str = lex.slice();
101    let str = &str[..str.len() - 1];
102    lexical::parse_with_options::<i32, _, HEX_FORMAT>(str, options).ok()
103}
104
105fn parse_dec_u32(lex: &mut logos::Lexer<Token>) -> Option<u32> {
106    let options = &lexical::parse_integer_options::STANDARD;
107    let str = lex.slice();
108    let str = &str[..str.len() - 1];
109    lexical::parse_with_options::<u32, _, DEC_FORMAT>(str, options).ok()
110}
111
112fn parse_hex_u32(lex: &mut logos::Lexer<Token>) -> Option<u32> {
113    let options = &lexical::parse_integer_options::STANDARD;
114    let str = lex.slice();
115    let str = &str[..str.len() - 1];
116    lexical::parse_with_options::<u32, _, HEX_FORMAT>(str, options).ok()
117}
118
119fn parse_dec_abs_float(lex: &mut logos::Lexer<Token>) -> Option<f64> {
120    let options = &lexical::parse_float_options::STANDARD;
121    let str = lex.slice();
122    lexical::parse_with_options::<f64, _, DEC_FORMAT>(str, options).ok()
123}
124
125fn parse_hex_abs_float(lex: &mut logos::Lexer<Token>) -> Option<f64> {
126    let str = lex.slice();
127    lexical::parse_with_options::<f64, _, HEX_FORMAT>(str, &FLOAT_HEX_OPTIONS).ok()
128}
129
130fn parse_dec_f32(lex: &mut logos::Lexer<Token>) -> Option<f32> {
131    let options = &lexical::parse_float_options::STANDARD;
132    let str = lex.slice();
133    let str = &str[..str.len() - 1];
134    lexical::parse_with_options::<f32, _, DEC_FORMAT>(str, options).ok()
135}
136
137fn parse_hex_f32(lex: &mut logos::Lexer<Token>) -> Option<f32> {
138    let str = lex.slice();
139    // TODO
140    let options = &lexical::parse_float_options::STANDARD;
141    let str = &str[..str.len() - 1];
142    lexical::parse_with_options::<f32, _, HEX_FORMAT>(str, options).ok()
143}
144
145fn parse_dec_f16(lex: &mut logos::Lexer<Token>) -> Option<f32> {
146    let options = &lexical::parse_float_options::STANDARD;
147    let str = lex.slice();
148    let str = &str[..str.len() - 1];
149    lexical::parse_with_options::<f32, _, DEC_FORMAT>(str, options).ok()
150}
151
152fn parse_hex_f16(lex: &mut logos::Lexer<Token>) -> Option<f32> {
153    let str = lex.slice();
154    let str = &str[..str.len() - 1];
155    lexical::parse_with_options::<f32, _, HEX_FORMAT>(str, &FLOAT_HEX_OPTIONS).ok()
156}
157
// Parses a decimal `i64` literal (naga extension), e.g. `123li`.
#[cfg(feature = "naga-ext")]
fn parse_dec_i64(lex: &mut logos::Lexer<Token>) -> Option<i64> {
    // drop the trailing `li` suffix; the regex guarantees it is present.
    let digits = lex.slice().strip_suffix("li")?;
    lexical::parse_with_options::<i64, _, DEC_FORMAT>(digits, &lexical::parse_integer_options::STANDARD)
        .ok()
}
165
// Parses a hexadecimal `i64` literal (naga extension), e.g. `0x1Fli`.
#[cfg(feature = "naga-ext")]
fn parse_hex_i64(lex: &mut logos::Lexer<Token>) -> Option<i64> {
    // drop the trailing `li` suffix; the regex guarantees it is present.
    let digits = lex.slice().strip_suffix("li")?;
    lexical::parse_with_options::<i64, _, HEX_FORMAT>(digits, &lexical::parse_integer_options::STANDARD)
        .ok()
}
173
// Parses a decimal `u64` literal (naga extension), e.g. `123lu`.
#[cfg(feature = "naga-ext")]
fn parse_dec_u64(lex: &mut logos::Lexer<Token>) -> Option<u64> {
    // drop the trailing `lu` suffix; the regex guarantees it is present.
    let digits = lex.slice().strip_suffix("lu")?;
    lexical::parse_with_options::<u64, _, DEC_FORMAT>(digits, &lexical::parse_integer_options::STANDARD)
        .ok()
}
181
// Parses a hexadecimal `u64` literal (naga extension), e.g. `0x1Flu`.
#[cfg(feature = "naga-ext")]
fn parse_hex_u64(lex: &mut logos::Lexer<Token>) -> Option<u64> {
    // drop the trailing `lu` suffix; the regex guarantees it is present.
    let digits = lex.slice().strip_suffix("lu")?;
    lexical::parse_with_options::<u64, _, HEX_FORMAT>(digits, &lexical::parse_integer_options::STANDARD)
        .ok()
}
189
// Parses a decimal `f64` literal (naga extension), e.g. `1.5lf`.
#[cfg(feature = "naga-ext")]
fn parse_dec_f64(lex: &mut logos::Lexer<Token>) -> Option<f64> {
    // drop the trailing `lf` suffix; the regex guarantees it is present.
    let digits = lex.slice().strip_suffix("lf")?;
    lexical::parse_with_options::<f64, _, DEC_FORMAT>(digits, &lexical::parse_float_options::STANDARD)
        .ok()
}
197
/// Parses a hexadecimal `f64` literal (naga extension), e.g. `0x1.8p2lf`.
///
/// Hex float literals use `p`/`P` as the exponent marker (see the WGSL
/// `hex_float_literal` grammar), so the standard float options — which expect `e` —
/// would reject every literal with an exponent. Use [`FLOAT_HEX_OPTIONS`] instead,
/// consistently with [`parse_hex_f16`] and [`parse_hex_abs_float`].
#[cfg(feature = "naga-ext")]
fn parse_hex_f64(lex: &mut logos::Lexer<Token>) -> Option<f64> {
    let str = lex.slice();
    // drop the trailing `lf` suffix; the regex guarantees it is present.
    let str = &str[..str.len() - 2];
    lexical::parse_with_options::<f64, _, HEX_FORMAT>(str, &FLOAT_HEX_OPTIONS).ok()
}
206
207fn parse_line_comment(lex: &mut logos::Lexer<Token>) {
208    let rem = lex.remainder();
209    // see blankspace and line breaks: https://www.w3.org/TR/WGSL/#blankspace-and-line-breaks
210    let line_end = rem
211        .char_indices()
212        .find(|(_, c)| "\n\u{000B}\u{000C}\r\u{0085}\u{2028}\u{2029}".contains(*c))
213        .map(|(i, _)| i)
214        .unwrap_or(rem.len());
215    lex.bump(line_end);
216}
217
218fn parse_block_comment(lex: &mut logos::Lexer<Token>) {
219    let mut depth = 1;
220    while depth > 0 {
221        let rem = lex.remainder();
222        if rem.is_empty() {
223            break;
224        } else if rem.starts_with("/*") {
225            lex.bump(2);
226            depth += 1;
227        } else if rem.starts_with("*/") {
228            lex.bump(2);
229            depth -= 1;
230        } else {
231            let mut next_char = 1;
232            while !rem.is_char_boundary(next_char) {
233                next_char += 1;
234            }
235            lex.bump(next_char);
236        }
237    }
238}
239
240fn parse_ident(lex: &mut logos::Lexer<Token>) -> Token {
241    let ident = lex.slice().to_string();
242    if RESERVED_WORDS.iter().contains(&ident.as_str()) {
243        Token::ReservedWord(ident)
244    } else {
245        Token::Ident(ident)
246    }
247}
248
/// Mutable lexer state threaded through `logos` callbacks via `extras`.
#[derive(Default, Clone, Debug, PartialEq)]
pub struct LexerState {
    // current nesting level of `(` and `[` delimiters.
    depth: i32,
    // for each pending (unclosed) template-args `<`, the `depth` at which it was opened.
    template_depths: Vec<i32>,
    // extra token to emit before pulling the next one from the stream; set when a
    // multi-char token such as `>>` is split by template-end recognition.
    lookahead: Option<Token>,
}
255
// following the spec at this date: https://www.w3.org/TR/2024/WD-WGSL-20240731/
#[derive(Logos, Clone, Debug, PartialEq)]
#[logos(
    // see blankspace and line breaks: https://www.w3.org/TR/WGSL/#blankspace-and-line-breaks
    skip r"[\s\u0085\u200e\u200f\u2028\u2029]+", // blankspace
    extras = LexerState,
    error = ParseError)]
pub enum Token {
    #[token("//", parse_line_comment)]
    LineComment,
    #[token("/*", parse_block_comment, priority = 2)]
    BlockComment,
    // the parse_ident function can return either Token::Ident or Token::ReservedWord.
    // Token::Ignored variant is never produced.
    // It serves as a placeholder for running parse_ident.
    #[regex(
        r#"([_\p{XID_Start}][\p{XID_Continue}]+)|([\p{XID_Start}])"#,
        parse_ident,
        priority = 1
    )]
    Ignored,
    // syntactic tokens
    // https://www.w3.org/TR/WGSL/#syntactic-tokens
    #[token("&")]
    SymAnd,
    // `&&` disproves a pending template `<` at the same depth (see maybe_fail_template).
    #[token("&&", maybe_fail_template)]
    SymAndAnd,
    #[token("->")]
    SymArrow,
    #[token("@")]
    SymAttr,
    #[token("/")]
    SymForwardSlash,
    #[token("!")]
    SymBang,
    #[token("[", incr_depth)]
    SymBracketLeft,
    #[token("]", decr_depth)]
    SymBracketRight,
    #[token("{")]
    SymBraceLeft,
    #[token("}")]
    SymBraceRight,
    #[token(":")]
    SymColon,
    #[token(",")]
    SymComma,
    #[token("=")]
    SymEqual,
    #[token("==")]
    SymEqualEqual,
    #[token("!=")]
    SymNotEqual,
    // tokens starting with `>` may instead close a template list (see maybe_template_end).
    #[token(">", |lex| maybe_template_end(lex, Token::SymGreaterThan, None))]
    SymGreaterThan,
    #[token(">=", |lex| maybe_template_end(lex, Token::SymGreaterThanEqual, Some(Token::SymEqual)))]
    SymGreaterThanEqual,
    #[token(">>", |lex| maybe_template_end(lex, Token::SymShiftRight, Some(Token::SymGreaterThan)))]
    SymShiftRight,
    #[token("<")]
    SymLessThan,
    #[token("<=")]
    SymLessThanEqual,
    #[token("<<")]
    SymShiftLeft,
    #[token("%")]
    SymModulo,
    #[token("-")]
    SymMinus,
    #[token("--")]
    SymMinusMinus,
    #[token(".")]
    SymPeriod,
    #[token("+")]
    SymPlus,
    #[token("++")]
    SymPlusPlus,
    #[token("|")]
    SymOr,
    // `||` disproves a pending template `<` at the same depth (see maybe_fail_template).
    #[token("||", maybe_fail_template)]
    SymOrOr,
    #[token("(", incr_depth)]
    SymParenLeft,
    #[token(")", decr_depth)]
    SymParenRight,
    #[token(";")]
    SymSemicolon,
    #[token("*")]
    SymStar,
    #[token("~")]
    SymTilde,
    #[token("_")]
    SymUnderscore,
    #[token("^")]
    SymXor,
    #[token("+=")]
    SymPlusEqual,
    #[token("-=")]
    SymMinusEqual,
    #[token("*=")]
    SymTimesEqual,
    #[token("/=")]
    SymDivisionEqual,
    #[token("%=")]
    SymModuloEqual,
    #[token("&=")]
    SymAndEqual,
    #[token("|=")]
    SymOrEqual,
    #[token("^=")]
    SymXorEqual,
    #[token(">>=", |lex| maybe_template_end(lex, Token::SymShiftRightAssign, Some(Token::SymGreaterThanEqual)))]
    SymShiftRightAssign,
    #[token("<<=")]
    SymShiftLeftAssign,

    // keywords
    // https://www.w3.org/TR/WGSL/#keyword-summary
    #[token("alias")]
    KwAlias,
    #[token("break")]
    KwBreak,
    #[token("case")]
    KwCase,
    #[token("const", priority = 2)]
    KwConst,
    #[token("const_assert")]
    KwConstAssert,
    #[token("continue")]
    KwContinue,
    #[token("continuing")]
    KwContinuing,
    #[token("default")]
    KwDefault,
    #[token("diagnostic")]
    KwDiagnostic,
    #[token("discard")]
    KwDiscard,
    #[token("else")]
    KwElse,
    #[token("enable")]
    KwEnable,
    #[token("false")]
    KwFalse,
    #[token("fn")]
    KwFn,
    #[token("for")]
    KwFor,
    #[token("if")]
    KwIf,
    #[token("let")]
    KwLet,
    #[token("loop")]
    KwLoop,
    #[token("override")]
    KwOverride,
    #[token("requires")]
    KwRequires,
    #[token("return")]
    KwReturn,
    #[token("struct")]
    KwStruct,
    #[token("switch")]
    KwSwitch,
    #[token("true")]
    KwTrue,
    #[token("var")]
    KwVar,
    #[token("while")]
    KwWhile,

    // Idents and ReservedWord tokens are parsed on the Ignored variant, because of a current
    // limitation of logos. See logos#295.
    Ident(String),
    // variant produced by parse_ident for reserved words.
    // Reserved words can be used in context-dependent words, e.g. attribute names and module names.
    ReservedWord(String),

    // literals
    // https://www.w3.org/TR/WGSL/#literals
    #[regex(r#"0|[1-9]\d*"#, parse_dec_abstract_int)]
    #[regex(r#"0[xX][\da-fA-F]+"#, parse_hex_abstract_int)]
    AbstractInt(i64),
    #[regex(r#"(\d+\.\d*|\.\d+)([eE][+-]?\d+)?"#, parse_dec_abs_float)]
    #[regex(r#"\d+[eE][+-]?\d+"#, parse_dec_abs_float)]
    #[regex(r#"0[xX][\da-fA-F]+\.[\da-fA-F]*([pP][+-]?\d+)?"#, parse_hex_abs_float)]
    #[regex(r#"0[xX]\.[\da-fA-F]+([pP][+-]?\d+)?"#, parse_hex_abs_float)]
    #[regex(r#"0[xX][\da-fA-F]+[pP][+-]?\d+"#, parse_hex_abs_float)]
    // hex
    AbstractFloat(f64),
    #[regex(r#"(0|[1-9]\d*)i"#, parse_dec_i32)]
    #[regex(r#"0[xX][\da-fA-F]+i"#, parse_hex_i32)]
    // hex
    I32(i32),
    #[regex(r#"(0|[1-9]\d*)u"#, parse_dec_u32)]
    #[regex(r#"0[xX][\da-fA-F]+u"#, parse_hex_u32)]
    // hex
    U32(u32),
    #[regex(r#"(\d+\.\d*|\.\d+)([eE][+-]?\d+)?f"#, parse_dec_f32)]
    #[regex(r#"\d+([eE][+-]?\d+)?f"#, parse_dec_f32)]
    #[regex(r#"0[xX][\da-fA-F]+\.[\da-fA-F]*[pP][+-]?\d+f"#, parse_hex_f32)]
    #[regex(r#"0[xX]\.[\da-fA-F]+[pP][+-]?\d+f"#, parse_hex_f32)]
    #[regex(r#"0[xX][\da-fA-F]+[pP][+-]?\d+f"#, parse_hex_f32)]
    F32(f32),
    // f16 values are stored as f32.
    #[regex(r#"(\d+\.\d*|\.\d+)([eE][+-]?\d+)?h"#, parse_dec_f16)]
    #[regex(r#"\d+([eE][+-]?\d+)?h"#, parse_dec_f16)]
    #[regex(r#"0[xX][\da-fA-F]+\.[\da-fA-F]*[pP][+-]?\d+h"#, parse_hex_f16)]
    #[regex(r#"0[xX]\.[\da-fA-F]+[pP][+-]?\d+h"#, parse_hex_f16)]
    #[regex(r#"0[xX][\da-fA-F]+[pP][+-]?\d+h"#, parse_hex_f16)]
    F16(f32),
    // extension: naga 64-bit literals (`li`, `lu`, `lf` suffixes)
    #[cfg(feature = "naga-ext")]
    #[regex(r#"(0|[1-9]\d*)li"#, parse_dec_i64)]
    #[regex(r#"0[xX][\da-fA-F]+li"#, parse_hex_i64)]
    // hex
    I64(i64),
    #[cfg(feature = "naga-ext")]
    #[regex(r#"(0|[1-9]\d*)lu"#, parse_dec_u64)]
    #[regex(r#"0[xX][\da-fA-F]+lu"#, parse_hex_u64)]
    // hex
    U64(u64),
    #[cfg(feature = "naga-ext")]
    #[regex(r#"(\d+\.\d*|\.\d+)([eE][+-]?\d+)?lf"#, parse_dec_f64)]
    #[regex(r#"\d+([eE][+-]?\d+)?lf"#, parse_dec_f64)]
    #[regex(r#"0[xX][\da-fA-F]+\.[\da-fA-F]*[pP][+-]?\d+lf"#, parse_hex_f64)]
    #[regex(r#"0[xX]\.[\da-fA-F]+[pP][+-]?\d+lf"#, parse_hex_f64)]
    #[regex(r#"0[xX][\da-fA-F]+[pP][+-]?\d+lf"#, parse_hex_f64)]
    F64(f64),
    // synthetic tokens produced by template-list disambiguation; never matched directly.
    TemplateArgsStart,
    TemplateArgsEnd,

    // extension: wesl-imports
    // https://github.com/wgsl-tooling-wg/wesl-spec/blob/imports-update/Imports.md
    // date: 2025-01-18, hash: 2db8e7f681087db6bdcd4a254963deb5c0159775
    #[cfg(feature = "imports")]
    #[token("::")]
    SymColonColon,
    #[cfg(feature = "imports")]
    #[token("self")]
    KwSelf,
    #[cfg(feature = "imports")]
    #[token("super")]
    KwSuper,
    #[cfg(feature = "imports")]
    #[token("package")]
    KwPackage,
    #[cfg(feature = "imports")]
    #[token("as")]
    KwAs,
    #[cfg(feature = "imports")]
    #[token("import")]
    KwImport,
}
506
507impl Token {
508    pub fn is_trivia(&self) -> bool {
509        matches!(
510            self,
511            Token::LineComment | Token::BlockComment | Token::Ignored
512        )
513    }
514
515    #[allow(unused)]
516    pub fn is_symbol(&self) -> bool {
517        matches!(
518            self,
519            Token::SymAnd
520                | Token::SymAndAnd
521                | Token::SymArrow
522                | Token::SymAttr
523                | Token::SymForwardSlash
524                | Token::SymBang
525                | Token::SymBracketLeft
526                | Token::SymBracketRight
527                | Token::SymBraceLeft
528                | Token::SymBraceRight
529                | Token::SymColon
530                | Token::SymComma
531                | Token::SymEqual
532                | Token::SymEqualEqual
533                | Token::SymNotEqual
534                | Token::SymGreaterThan
535                | Token::SymGreaterThanEqual
536                | Token::SymShiftRight
537                | Token::SymLessThan
538                | Token::SymLessThanEqual
539                | Token::SymShiftLeft
540                | Token::SymModulo
541                | Token::SymMinus
542                | Token::SymMinusMinus
543                | Token::SymPeriod
544                | Token::SymPlus
545                | Token::SymPlusPlus
546                | Token::SymOr
547                | Token::SymOrOr
548                | Token::SymParenLeft
549                | Token::SymParenRight
550                | Token::SymSemicolon
551                | Token::SymStar
552                | Token::SymTilde
553                | Token::SymUnderscore
554                | Token::SymXor
555                | Token::SymPlusEqual
556                | Token::SymMinusEqual
557                | Token::SymTimesEqual
558                | Token::SymDivisionEqual
559                | Token::SymModuloEqual
560                | Token::SymAndEqual
561                | Token::SymOrEqual
562                | Token::SymXorEqual
563                | Token::SymShiftRightAssign
564                | Token::SymShiftLeftAssign
565        )
566    }
567
568    #[allow(unused)]
569    pub fn is_keyword(&self) -> bool {
570        matches!(
571            self,
572            Token::KwAlias
573                | Token::KwBreak
574                | Token::KwCase
575                | Token::KwConst
576                | Token::KwConstAssert
577                | Token::KwContinue
578                | Token::KwContinuing
579                | Token::KwDefault
580                | Token::KwDiagnostic
581                | Token::KwDiscard
582                | Token::KwElse
583                | Token::KwEnable
584                | Token::KwFalse
585                | Token::KwFn
586                | Token::KwFor
587                | Token::KwIf
588                | Token::KwLet
589                | Token::KwLoop
590                | Token::KwOverride
591                | Token::KwRequires
592                | Token::KwReturn
593                | Token::KwStruct
594                | Token::KwSwitch
595                | Token::KwTrue
596                | Token::KwVar
597                | Token::KwWhile
598        )
599    }
600
601    #[allow(unused)]
602    pub fn is_numeric_literal(&self) -> bool {
603        matches!(
604            self,
605            Token::AbstractInt(_)
606                | Token::AbstractFloat(_)
607                | Token::I32(_)
608                | Token::U32(_)
609                | Token::F32(_)
610                | Token::F16(_)
611        )
612    }
613}
614
impl Display for Token {
    /// This display implementation is used for error messages.
    /// Symbols, keywords and literals are shown as they appear in source; structural
    /// tokens (comments, templates) get a descriptive placeholder.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Token::LineComment => f.write_str("// line comment"),
            Token::BlockComment => f.write_str("/* block comment */"),
            // Ignored is always replaced by Ident/ReservedWord in parse_ident.
            Token::Ignored => unreachable!(),
            Token::SymAnd => f.write_str("&"),
            Token::SymAndAnd => f.write_str("&&"),
            Token::SymArrow => f.write_str("->"),
            Token::SymAttr => f.write_str("@"),
            Token::SymForwardSlash => f.write_str("/"),
            Token::SymBang => f.write_str("!"),
            Token::SymBracketLeft => f.write_str("["),
            Token::SymBracketRight => f.write_str("]"),
            Token::SymBraceLeft => f.write_str("{"),
            Token::SymBraceRight => f.write_str("}"),
            Token::SymColon => f.write_str(":"),
            Token::SymComma => f.write_str(","),
            Token::SymEqual => f.write_str("="),
            Token::SymEqualEqual => f.write_str("=="),
            Token::SymNotEqual => f.write_str("!="),
            Token::SymGreaterThan => f.write_str(">"),
            Token::SymGreaterThanEqual => f.write_str(">="),
            Token::SymShiftRight => f.write_str(">>"),
            Token::SymLessThan => f.write_str("<"),
            Token::SymLessThanEqual => f.write_str("<="),
            Token::SymShiftLeft => f.write_str("<<"),
            Token::SymModulo => f.write_str("%"),
            Token::SymMinus => f.write_str("-"),
            Token::SymMinusMinus => f.write_str("--"),
            Token::SymPeriod => f.write_str("."),
            Token::SymPlus => f.write_str("+"),
            Token::SymPlusPlus => f.write_str("++"),
            Token::SymOr => f.write_str("|"),
            Token::SymOrOr => f.write_str("||"),
            Token::SymParenLeft => f.write_str("("),
            Token::SymParenRight => f.write_str(")"),
            Token::SymSemicolon => f.write_str(";"),
            Token::SymStar => f.write_str("*"),
            Token::SymTilde => f.write_str("~"),
            Token::SymUnderscore => f.write_str("_"),
            Token::SymXor => f.write_str("^"),
            Token::SymPlusEqual => f.write_str("+="),
            Token::SymMinusEqual => f.write_str("-="),
            Token::SymTimesEqual => f.write_str("*="),
            Token::SymDivisionEqual => f.write_str("/="),
            Token::SymModuloEqual => f.write_str("%="),
            Token::SymAndEqual => f.write_str("&="),
            Token::SymOrEqual => f.write_str("|="),
            Token::SymXorEqual => f.write_str("^="),
            Token::SymShiftRightAssign => f.write_str(">>="),
            Token::SymShiftLeftAssign => f.write_str("<<="),
            Token::KwAlias => f.write_str("alias"),
            Token::KwBreak => f.write_str("break"),
            Token::KwCase => f.write_str("case"),
            Token::KwConst => f.write_str("const"),
            Token::KwConstAssert => f.write_str("const_assert"),
            Token::KwContinue => f.write_str("continue"),
            Token::KwContinuing => f.write_str("continuing"),
            Token::KwDefault => f.write_str("default"),
            Token::KwDiagnostic => f.write_str("diagnostic"),
            Token::KwDiscard => f.write_str("discard"),
            Token::KwElse => f.write_str("else"),
            Token::KwEnable => f.write_str("enable"),
            Token::KwFalse => f.write_str("false"),
            Token::KwFn => f.write_str("fn"),
            Token::KwFor => f.write_str("for"),
            Token::KwIf => f.write_str("if"),
            Token::KwLet => f.write_str("let"),
            Token::KwLoop => f.write_str("loop"),
            Token::KwOverride => f.write_str("override"),
            Token::KwRequires => f.write_str("requires"),
            Token::KwReturn => f.write_str("return"),
            Token::KwStruct => f.write_str("struct"),
            Token::KwSwitch => f.write_str("switch"),
            Token::KwTrue => f.write_str("true"),
            Token::KwVar => f.write_str("var"),
            Token::KwWhile => f.write_str("while"),
            Token::Ident(s) => write!(f, "identifier `{s}`"),
            Token::ReservedWord(s) => write!(f, "reserved word `{s}`"),
            // literals are displayed with their source suffix.
            Token::AbstractInt(n) => write!(f, "{n}"),
            Token::AbstractFloat(n) => write!(f, "{n}"),
            Token::I32(n) => write!(f, "{n}i"),
            Token::U32(n) => write!(f, "{n}u"),
            Token::F32(n) => write!(f, "{n}f"),
            Token::F16(n) => write!(f, "{n}h"),
            #[cfg(feature = "naga-ext")]
            Token::I64(n) => write!(f, "{n}li"),
            #[cfg(feature = "naga-ext")]
            Token::U64(n) => write!(f, "{n}lu"),
            #[cfg(feature = "naga-ext")]
            Token::F64(n) => write!(f, "{n}lf"),
            Token::TemplateArgsStart => f.write_str("start of template"),
            Token::TemplateArgsEnd => f.write_str("end of template"),
            #[cfg(feature = "imports")]
            Token::SymColonColon => write!(f, "::"),
            #[cfg(feature = "imports")]
            Token::KwSelf => write!(f, "self"),
            #[cfg(feature = "imports")]
            Token::KwSuper => write!(f, "super"),
            #[cfg(feature = "imports")]
            Token::KwPackage => write!(f, "package"),
            #[cfg(feature = "imports")]
            Token::KwAs => write!(f, "as"),
            #[cfg(feature = "imports")]
            Token::KwImport => write!(f, "import"),
        }
    }
}
725
// (start, token, end) triple as consumed by the LALRPOP-style parser, or a spanned error.
type Spanned<Tok, Loc, ParseError> = Result<(Loc, Tok, Loc), (Loc, ParseError, Loc)>;
// next pending token and its byte span, if any.
type NextToken = Option<(Result<Token, ParseError>, Span)>;
728
/// WGSL lexer with template-list disambiguation.
///
/// Wraps the `logos`-generated token stream with one token of lookahead, which is
/// needed to re-tag `<` as [`Token::TemplateArgsStart`] when it begins a template list.
#[derive(Clone)]
pub struct Lexer<'s> {
    source: &'s str,
    token_stream: SpannedIter<'s, Token>,
    // 1-token lookahead buffer (the next non-trivia token), if any.
    next_token: NextToken,
    // true while running template-list recognition (see `recognize_template_list`);
    // makes the stream end right after the closing `TemplateArgsEnd`.
    recognizing_template: bool,
    // number of currently unclosed template lists during recognition.
    opened_templates: u32,
}
737
impl<'s> Lexer<'s> {
    /// Creates a lexer over `source`, priming the 1-token lookahead buffer with the
    /// first non-trivia token.
    pub fn new(source: &'s str) -> Self {
        let mut token_stream = Token::lexer_with_extras(source, LexerState::default()).spanned();
        // skip comments and the Ignored placeholder.
        let next_token =
            token_stream.find(|(tok, _)| tok.as_ref().is_ok_and(|tok| !tok.is_trivia()));

        Self {
            source,
            token_stream,
            next_token,
            recognizing_template: false,
            opened_templates: 0,
        }
    }

    // Takes the buffered token and the token after it. The second token comes either
    // from the lexer state's lookahead (when a multi-char token like `>>` was split by
    // template-end recognition) or from the underlying stream.
    fn take_two_tokens(&mut self) -> (NextToken, NextToken) {
        let mut tok1 = self.next_token.take();

        let lookahead = self.token_stream.extras.lookahead.take();
        let tok2 = match lookahead {
            Some(tok) => {
                let (_, span1) = tok1.as_mut().unwrap(); // safety: lookahead implies lexer looked at a `<` token
                // synthesize a span for the split-off token: the tail of the first
                // token's span (e.g. the second `>` of `>>`).
                let span2 = span1.start + 1..span1.end;
                Some((Ok(tok), span2))
            }
            None => self
                .token_stream
                .find(|(tok, _)| tok.as_ref().is_ok_and(|tok| !tok.is_trivia())),
        };

        (tok1, tok2)
    }

    // Advances by one token, re-tagging `<` as `TemplateArgsStart` when the upcoming
    // source is recognized as a template list.
    fn next_token(&mut self) -> NextToken {
        let (cur, mut next) = self.take_two_tokens();

        let (cur_tok, cur_span) = match cur {
            Some((Ok(tok), span)) => (tok, span),
            Some((Err(e), span)) => return Some((Err(e), span)),
            None => return None,
        };

        // a template list can only follow an identifier or a keyword.
        if let Some((Ok(next_tok), next_span)) = &mut next {
            if (matches!(cur_tok, Token::Ident(_)) || cur_tok.is_keyword())
                && *next_tok == Token::SymLessThan
            {
                let source = &self.source[next_span.start..];
                if recognize_template_list(source) {
                    *next_tok = Token::TemplateArgsStart;
                    let cur_depth = self.token_stream.extras.depth;
                    self.token_stream.extras.template_depths.push(cur_depth);
                    self.opened_templates += 1;
                }
            }
        }

        // if we finished recognition of a template
        if self.recognizing_template && cur_tok == Token::TemplateArgsEnd {
            self.opened_templates -= 1;
            if self.opened_templates == 0 {
                next = None; // push eof after end of template
            }
        }

        self.next_token = next;
        Some((Ok(cur_tok), cur_span))
    }
}
806
/// Returns `true` if the source starts with a valid template list.
///
/// ## Specification
///
/// [3.9. Template Lists](https://www.w3.org/TR/WGSL/#template-lists-sec)
///
/// Contrary to the specification [template list discovery algorithm], this function also
/// checks that the template is syntactically valid (syntax: [*template_list*]).
///
/// [template list discovery algorithm]: https://www.w3.org/TR/WGSL/#template-list-discovery
/// [*template_list*]: https://www.w3.org/TR/WGSL/#syntax-template_list
pub fn recognize_template_list(source: &str) -> bool {
    let mut lexer = Lexer::new(source);
    // the first token must be `<`; re-tag it as the template start.
    match lexer.next_token {
        Some((Ok(ref mut t), _)) if *t == Token::SymLessThan => *t = Token::TemplateArgsStart,
        _ => return false,
    };
    lexer.recognizing_template = true;
    lexer.opened_templates = 1;
    lexer.token_stream.extras.template_depths.push(0);
    crate::parser::recognize_template_list(lexer).is_ok()
}
829
/// Checks template-list recognition against examples from the WGSL specification,
/// plus a few extra positive/negative cases. (A duplicated empty-string assertion
/// was removed.)
#[test]
fn test_recognize_template() {
    // cases from the WGSL spec
    assert!(recognize_template_list("<i32,select(2,3,a>b)>"));
    assert!(!recognize_template_list("<d]>"));
    assert!(recognize_template_list("<B<<C>"));
    assert!(recognize_template_list("<B<=C>"));
    assert!(recognize_template_list("<(B>=C)>"));
    assert!(recognize_template_list("<(B!=C)>"));
    assert!(recognize_template_list("<(B==C)>"));
    // more cases
    assert!(recognize_template_list("<X>"));
    assert!(recognize_template_list("<X<Y>>"));
    assert!(recognize_template_list("<X<Y<Z>>>"));
    assert!(!recognize_template_list(""));
    assert!(!recognize_template_list("<>"));
    assert!(!recognize_template_list("<b || c>d"));
}
849
/// Iterator over spanned tokens, in the triple format expected by the parser.
pub trait TokenIterator: IntoIterator<Item = Spanned<Token, usize, ParseError>> {}

impl Iterator for Lexer<'_> {
    type Item = Spanned<Token, usize, ParseError>;

    // Converts the internal (token, span) pairs into (start, token, end) triples.
    fn next(&mut self) -> Option<Self::Item> {
        let tok = self.next_token();
        tok.map(|(tok, span)| match tok {
            Ok(tok) => Ok((span.start, tok, span.end)),
            Err(err) => Err((span.start, err, span.end)),
        })
    }
}

impl TokenIterator for Lexer<'_> {}