1use logos::Logos;
4
5#[derive(Logos, Debug, PartialEq, Clone)]
7#[logos(skip r"[ \t\r\n\f]+")]
8pub enum Token {
9 #[token("__global__")]
11 Global,
12 #[token("__device__")]
13 Device,
14 #[token("__host__")]
15 Host,
16 #[token("__shared__")]
17 Shared,
18 #[token("__constant__")]
19 Constant,
20 #[token("extern")]
21 Extern,
22 #[token("void")]
23 Void,
24 #[token("int")]
25 Int,
26 #[token("unsigned")]
27 Unsigned,
28 #[token("float")]
29 Float,
30 #[token("double")]
31 Double,
32 #[token("char")]
33 Char,
34 #[token("short")]
35 Short,
36 #[token("long")]
37 Long,
38 #[token("bool")]
39 Bool,
40 #[token("const")]
41 Const,
42 #[token("volatile")]
43 Volatile,
44 #[token("__restrict__")]
45 Restrict,
46 #[token("restrict")]
47 RestrictC,
48 #[token("if")]
49 If,
50 #[token("else")]
51 Else,
52 #[token("for")]
53 For,
54 #[token("while")]
55 While,
56 #[token("do")]
57 Do,
58 #[token("return")]
59 Return,
60 #[token("break")]
61 Break,
62 #[token("continue")]
63 Continue,
64 #[token("struct")]
65 Struct,
66 #[token("typedef")]
67 Typedef,
68 #[token("sizeof")]
69 Sizeof,
70 #[token("register")]
71 Register,
72 #[token("static")]
73 Static,
74 #[token("inline")]
75 Inline,
76 #[token("__inline__")]
77 InlineAlt,
78 #[token("__forceinline__")]
79 ForceInline,
80
81 #[token("threadIdx")]
83 ThreadIdx,
84 #[token("blockIdx")]
85 BlockIdx,
86 #[token("blockDim")]
87 BlockDim,
88 #[token("gridDim")]
89 GridDim,
90 #[token("__syncthreads")]
91 SyncThreads,
92
93 #[token("float2")]
95 Float2,
96 #[token("float3")]
97 Float3,
98 #[token("float4")]
99 Float4,
100 #[token("int2")]
101 Int2,
102 #[token("int3")]
103 Int3,
104 #[token("int4")]
105 Int4,
106 #[token("double2")]
107 Double2,
108 #[token("double3")]
109 Double3,
110 #[token("double4")]
111 Double4,
112 #[token("dim3")]
113 Dim3,
114
115 #[regex(r"0[xX][0-9a-fA-F]+[uUlL]*", |lex| lex.slice().to_string())]
117 HexLiteral(String),
118 #[regex(r"[0-9]+\.[0-9]*([eE][+-]?[0-9]+)?[fF]?", |lex| lex.slice().to_string())]
119 FloatLiteral(String),
120 #[regex(r"\.[0-9]+([eE][+-]?[0-9]+)?[fF]?", |lex| lex.slice().to_string())]
121 FloatLiteralDot(String),
122 #[regex(r"[0-9]+[eE][+-]?[0-9]+[fF]?", |lex| lex.slice().to_string())]
123 FloatLiteralExp(String),
124 #[regex(r"[0-9]+[fF]", |lex| lex.slice().to_string())]
125 FloatLiteralSuffix(String),
126 #[regex(r"[0-9]+[uUlL]*", priority = 2, callback = |lex| lex.slice().to_string())]
127 IntLiteral(String),
128 #[regex(r#""([^"\\]|\\.)*""#, |lex| lex.slice().to_string())]
129 StringLiteral(String),
130 #[regex(r"'([^'\\]|\\.)'", |lex| lex.slice().to_string())]
131 CharLiteral(String),
132
133 #[regex(r"[a-zA-Z_][a-zA-Z0-9_]*", priority = 1, callback = |lex| lex.slice().to_string())]
135 Ident(String),
136
137 #[token("+=")]
139 PlusAssign,
140 #[token("-=")]
141 MinusAssign,
142 #[token("*=")]
143 StarAssign,
144 #[token("/=")]
145 SlashAssign,
146 #[token("%=")]
147 PercentAssign,
148 #[token("&=")]
149 AmpAssign,
150 #[token("|=")]
151 PipeAssign,
152 #[token("^=")]
153 CaretAssign,
154 #[token("<<=")]
155 ShlAssign,
156 #[token(">>=")]
157 ShrAssign,
158 #[token("++")]
159 PlusPlus,
160 #[token("--")]
161 MinusMinus,
162 #[token("&&")]
163 AmpAmp,
164 #[token("||")]
165 PipePipe,
166 #[token("==")]
167 EqEq,
168 #[token("!=")]
169 BangEq,
170 #[token("<=")]
171 LtEq,
172 #[token(">=")]
173 GtEq,
174 #[token("<<")]
175 Shl,
176 #[token(">>")]
177 Shr,
178 #[token("->")]
179 Arrow,
180 #[token("+")]
181 Plus,
182 #[token("-")]
183 Minus,
184 #[token("*")]
185 Star,
186 #[token("/")]
187 Slash,
188 #[token("%")]
189 Percent,
190 #[token("&")]
191 Amp,
192 #[token("|")]
193 Pipe,
194 #[token("^")]
195 Caret,
196 #[token("~")]
197 Tilde,
198 #[token("!")]
199 Bang,
200 #[token("=")]
201 Eq,
202 #[token("<")]
203 Lt,
204 #[token(">")]
205 Gt,
206 #[token("?")]
207 Question,
208 #[token(":")]
209 Colon,
210
211 #[token("(")]
213 LParen,
214 #[token(")")]
215 RParen,
216 #[token("{")]
217 LBrace,
218 #[token("}")]
219 RBrace,
220 #[token("[")]
221 LBracket,
222 #[token("]")]
223 RBracket,
224 #[token(";")]
225 Semi,
226 #[token(",")]
227 Comma,
228 #[token(".")]
229 Dot,
230
231 #[regex(r#"#include\s*[<"][^>"\n]+[>"]"#, |lex| lex.slice().to_string())]
233 Include(String),
234 #[regex(r"#define\s+[^\n]+", |lex| lex.slice().to_string())]
235 Define(String),
236 #[regex(r"#(pragma|ifdef|ifndef|endif|if|elif|else|undef|error|warning)[^\n]*", |lex| lex.slice().to_string())]
237 Preprocessor(String),
238 #[regex(r"//[^\n]*")]
239 LineComment,
240 #[regex(r"/\*([^*]|\*[^/])*\*/")]
241 BlockComment,
242}
243
244#[derive(Debug, Clone)]
246pub struct SpannedToken {
247 pub token: Token,
248 pub span: std::ops::Range<usize>,
249 pub text: String,
250}
251
252pub fn tokenize(source: &str) -> Vec<SpannedToken> {
254 let mut tokens = Vec::new();
255 let lex = Token::lexer(source);
256 for (result, span) in lex.spanned() {
257 match result {
258 Ok(tok) => {
259 if matches!(tok, Token::LineComment | Token::BlockComment) {
261 continue;
262 }
263 tokens.push(SpannedToken {
264 token: tok,
265 span: span.clone(),
266 text: source[span].to_string(),
267 });
268 }
269 Err(_) => {
270 }
272 }
273 }
274 tokens
275}
276
#[cfg(test)]
mod tests {
    use super::*;

    /// A small kernel signature lexes into the expected keyword and
    /// identifier tokens, and every token's `text` matches its `span`.
    #[test]
    fn test_basic_tokenize() {
        let src = "__global__ void vectorAdd(float* a, int n) { }";
        let tokens = tokenize(src);
        assert!(tokens.iter().any(|t| matches!(&t.token, Token::Global)));
        assert!(tokens.iter().any(|t| matches!(&t.token, Token::Void)));
        assert!(tokens
            .iter()
            .any(|t| matches!(&t.token, Token::Ident(name) if name == "vectorAdd")));
        // `text` must be exactly the slice of the source that `span` covers.
        assert!(tokens
            .iter()
            .all(|t| &src[t.span.clone()] == t.text.as_str()));
    }

    /// Comments are removed while the code around them is kept intact.
    #[test]
    fn test_comments_stripped() {
        let src = "int x; // comment\n/* block */ float y;";
        let tokens = tokenize(src);
        assert!(!tokens.iter().any(|t| matches!(&t.token, Token::LineComment)));
        assert!(!tokens.iter().any(|t| matches!(&t.token, Token::BlockComment)));

        // Guard against a vacuous pass: the non-comment tokens must all
        // survive, in order.
        let kinds: Vec<Token> = tokens.into_iter().map(|t| t.token).collect();
        assert_eq!(
            kinds,
            vec![
                Token::Int,
                Token::Ident("x".to_string()),
                Token::Semi,
                Token::Float,
                Token::Ident("y".to_string()),
                Token::Semi,
            ]
        );
    }
}