Skip to main content

cuda_rust_wasm/parser/
lexer.rs

1//! CUDA lexer using logos for tokenization
2
3use logos::Logos;
4
5/// Token types for CUDA source code
6#[derive(Logos, Debug, PartialEq, Clone)]
7#[logos(skip r"[ \t\r\n\f]+")]
8pub enum Token {
9    // ── Keywords ──────────────────────────────────────────────
10    #[token("__global__")]
11    Global,
12    #[token("__device__")]
13    Device,
14    #[token("__host__")]
15    Host,
16    #[token("__shared__")]
17    Shared,
18    #[token("__constant__")]
19    Constant,
20    #[token("extern")]
21    Extern,
22    #[token("void")]
23    Void,
24    #[token("int")]
25    Int,
26    #[token("unsigned")]
27    Unsigned,
28    #[token("float")]
29    Float,
30    #[token("double")]
31    Double,
32    #[token("char")]
33    Char,
34    #[token("short")]
35    Short,
36    #[token("long")]
37    Long,
38    #[token("bool")]
39    Bool,
40    #[token("const")]
41    Const,
42    #[token("volatile")]
43    Volatile,
44    #[token("__restrict__")]
45    Restrict,
46    #[token("restrict")]
47    RestrictC,
48    #[token("if")]
49    If,
50    #[token("else")]
51    Else,
52    #[token("for")]
53    For,
54    #[token("while")]
55    While,
56    #[token("do")]
57    Do,
58    #[token("return")]
59    Return,
60    #[token("break")]
61    Break,
62    #[token("continue")]
63    Continue,
64    #[token("struct")]
65    Struct,
66    #[token("typedef")]
67    Typedef,
68    #[token("sizeof")]
69    Sizeof,
70    #[token("register")]
71    Register,
72    #[token("static")]
73    Static,
74    #[token("inline")]
75    Inline,
76    #[token("__inline__")]
77    InlineAlt,
78    #[token("__forceinline__")]
79    ForceInline,
80
81    // ── CUDA builtins ────────────────────────────────────────
82    #[token("threadIdx")]
83    ThreadIdx,
84    #[token("blockIdx")]
85    BlockIdx,
86    #[token("blockDim")]
87    BlockDim,
88    #[token("gridDim")]
89    GridDim,
90    #[token("__syncthreads")]
91    SyncThreads,
92
93    // ── Vector types ─────────────────────────────────────────
94    #[token("float2")]
95    Float2,
96    #[token("float3")]
97    Float3,
98    #[token("float4")]
99    Float4,
100    #[token("int2")]
101    Int2,
102    #[token("int3")]
103    Int3,
104    #[token("int4")]
105    Int4,
106    #[token("double2")]
107    Double2,
108    #[token("double3")]
109    Double3,
110    #[token("double4")]
111    Double4,
112    #[token("dim3")]
113    Dim3,
114
115    // ── Literals ─────────────────────────────────────────────
116    #[regex(r"0[xX][0-9a-fA-F]+[uUlL]*", |lex| lex.slice().to_string())]
117    HexLiteral(String),
118    #[regex(r"[0-9]+\.[0-9]*([eE][+-]?[0-9]+)?[fF]?", |lex| lex.slice().to_string())]
119    FloatLiteral(String),
120    #[regex(r"\.[0-9]+([eE][+-]?[0-9]+)?[fF]?", |lex| lex.slice().to_string())]
121    FloatLiteralDot(String),
122    #[regex(r"[0-9]+[eE][+-]?[0-9]+[fF]?", |lex| lex.slice().to_string())]
123    FloatLiteralExp(String),
124    #[regex(r"[0-9]+[fF]", |lex| lex.slice().to_string())]
125    FloatLiteralSuffix(String),
126    #[regex(r"[0-9]+[uUlL]*", priority = 2, callback = |lex| lex.slice().to_string())]
127    IntLiteral(String),
128    #[regex(r#""([^"\\]|\\.)*""#, |lex| lex.slice().to_string())]
129    StringLiteral(String),
130    #[regex(r"'([^'\\]|\\.)'", |lex| lex.slice().to_string())]
131    CharLiteral(String),
132
133    // ── Identifiers ──────────────────────────────────────────
134    #[regex(r"[a-zA-Z_][a-zA-Z0-9_]*", priority = 1, callback = |lex| lex.slice().to_string())]
135    Ident(String),
136
137    // ── Operators ────────────────────────────────────────────
138    #[token("+=")]
139    PlusAssign,
140    #[token("-=")]
141    MinusAssign,
142    #[token("*=")]
143    StarAssign,
144    #[token("/=")]
145    SlashAssign,
146    #[token("%=")]
147    PercentAssign,
148    #[token("&=")]
149    AmpAssign,
150    #[token("|=")]
151    PipeAssign,
152    #[token("^=")]
153    CaretAssign,
154    #[token("<<=")]
155    ShlAssign,
156    #[token(">>=")]
157    ShrAssign,
158    #[token("++")]
159    PlusPlus,
160    #[token("--")]
161    MinusMinus,
162    #[token("&&")]
163    AmpAmp,
164    #[token("||")]
165    PipePipe,
166    #[token("==")]
167    EqEq,
168    #[token("!=")]
169    BangEq,
170    #[token("<=")]
171    LtEq,
172    #[token(">=")]
173    GtEq,
174    #[token("<<")]
175    Shl,
176    #[token(">>")]
177    Shr,
178    #[token("->")]
179    Arrow,
180    #[token("+")]
181    Plus,
182    #[token("-")]
183    Minus,
184    #[token("*")]
185    Star,
186    #[token("/")]
187    Slash,
188    #[token("%")]
189    Percent,
190    #[token("&")]
191    Amp,
192    #[token("|")]
193    Pipe,
194    #[token("^")]
195    Caret,
196    #[token("~")]
197    Tilde,
198    #[token("!")]
199    Bang,
200    #[token("=")]
201    Eq,
202    #[token("<")]
203    Lt,
204    #[token(">")]
205    Gt,
206    #[token("?")]
207    Question,
208    #[token(":")]
209    Colon,
210
211    // ── Delimiters ───────────────────────────────────────────
212    #[token("(")]
213    LParen,
214    #[token(")")]
215    RParen,
216    #[token("{")]
217    LBrace,
218    #[token("}")]
219    RBrace,
220    #[token("[")]
221    LBracket,
222    #[token("]")]
223    RBracket,
224    #[token(";")]
225    Semi,
226    #[token(",")]
227    Comma,
228    #[token(".")]
229    Dot,
230
231    // ── Preprocessor & comments ──────────────────────────────
232    #[regex(r#"#include\s*[<"][^>"\n]+[>"]"#, |lex| lex.slice().to_string())]
233    Include(String),
234    #[regex(r"#define\s+[^\n]+", |lex| lex.slice().to_string())]
235    Define(String),
236    #[regex(r"#(pragma|ifdef|ifndef|endif|if|elif|else|undef|error|warning)[^\n]*", |lex| lex.slice().to_string())]
237    Preprocessor(String),
238    #[regex(r"//[^\n]*")]
239    LineComment,
240    #[regex(r"/\*([^*]|\*[^/])*\*/")]
241    BlockComment,
242}
243
244/// A token with its span information
245#[derive(Debug, Clone)]
246pub struct SpannedToken {
247    pub token: Token,
248    pub span: std::ops::Range<usize>,
249    pub text: String,
250}
251
252/// Tokenize CUDA source code, stripping comments and returning spanned tokens
253pub fn tokenize(source: &str) -> Vec<SpannedToken> {
254    let mut tokens = Vec::new();
255    let lex = Token::lexer(source);
256    for (result, span) in lex.spanned() {
257        match result {
258            Ok(tok) => {
259                // Skip comments
260                if matches!(tok, Token::LineComment | Token::BlockComment) {
261                    continue;
262                }
263                tokens.push(SpannedToken {
264                    token: tok,
265                    span: span.clone(),
266                    text: source[span].to_string(),
267                });
268            }
269            Err(_) => {
270                // Skip unrecognized bytes
271            }
272        }
273    }
274    tokens
275}
276
#[cfg(test)]
mod tests {
    use super::*;

    /// True when `wanted` holds for the kind of at least one lexed token.
    fn contains(lexed: &[SpannedToken], wanted: impl Fn(&Token) -> bool) -> bool {
        lexed.iter().any(|entry| wanted(&entry.token))
    }

    /// A kernel signature yields the `__global__` and `void` keyword tokens.
    #[test]
    fn test_basic_tokenize() {
        let lexed = tokenize("__global__ void vectorAdd(float* a, int n) { }");
        assert!(contains(&lexed, |kind| matches!(kind, Token::Global)));
        assert!(contains(&lexed, |kind| matches!(kind, Token::Void)));
    }

    /// Neither comment variant appears in the output of `tokenize`.
    #[test]
    fn test_comments_stripped() {
        let lexed = tokenize("int x; // comment\n/* block */ float y;");
        assert!(!contains(&lexed, |kind| matches!(kind, Token::LineComment)));
        assert!(!contains(&lexed, |kind| matches!(kind, Token::BlockComment)));
    }
}
296}