use logos::Logos;
#[derive(Logos, Debug, PartialEq, Clone)]
#[logos(skip r"[ \t\r\n\f]+")]
pub enum Token {
#[token("__global__")]
Global,
#[token("__device__")]
Device,
#[token("__host__")]
Host,
#[token("__shared__")]
Shared,
#[token("__constant__")]
Constant,
#[token("extern")]
Extern,
#[token("void")]
Void,
#[token("int")]
Int,
#[token("unsigned")]
Unsigned,
#[token("float")]
Float,
#[token("double")]
Double,
#[token("char")]
Char,
#[token("short")]
Short,
#[token("long")]
Long,
#[token("bool")]
Bool,
#[token("const")]
Const,
#[token("volatile")]
Volatile,
#[token("__restrict__")]
Restrict,
#[token("restrict")]
RestrictC,
#[token("if")]
If,
#[token("else")]
Else,
#[token("for")]
For,
#[token("while")]
While,
#[token("do")]
Do,
#[token("return")]
Return,
#[token("break")]
Break,
#[token("continue")]
Continue,
#[token("struct")]
Struct,
#[token("typedef")]
Typedef,
#[token("sizeof")]
Sizeof,
#[token("register")]
Register,
#[token("static")]
Static,
#[token("inline")]
Inline,
#[token("__inline__")]
InlineAlt,
#[token("__forceinline__")]
ForceInline,
#[token("threadIdx")]
ThreadIdx,
#[token("blockIdx")]
BlockIdx,
#[token("blockDim")]
BlockDim,
#[token("gridDim")]
GridDim,
#[token("__syncthreads")]
SyncThreads,
#[token("float2")]
Float2,
#[token("float3")]
Float3,
#[token("float4")]
Float4,
#[token("int2")]
Int2,
#[token("int3")]
Int3,
#[token("int4")]
Int4,
#[token("double2")]
Double2,
#[token("double3")]
Double3,
#[token("double4")]
Double4,
#[token("dim3")]
Dim3,
#[regex(r"0[xX][0-9a-fA-F]+[uUlL]*", |lex| lex.slice().to_string())]
HexLiteral(String),
#[regex(r"[0-9]+\.[0-9]*([eE][+-]?[0-9]+)?[fF]?", |lex| lex.slice().to_string())]
FloatLiteral(String),
#[regex(r"\.[0-9]+([eE][+-]?[0-9]+)?[fF]?", |lex| lex.slice().to_string())]
FloatLiteralDot(String),
#[regex(r"[0-9]+[eE][+-]?[0-9]+[fF]?", |lex| lex.slice().to_string())]
FloatLiteralExp(String),
#[regex(r"[0-9]+[fF]", |lex| lex.slice().to_string())]
FloatLiteralSuffix(String),
#[regex(r"[0-9]+[uUlL]*", priority = 2, callback = |lex| lex.slice().to_string())]
IntLiteral(String),
#[regex(r#""([^"\\]|\\.)*""#, |lex| lex.slice().to_string())]
StringLiteral(String),
#[regex(r"'([^'\\]|\\.)'", |lex| lex.slice().to_string())]
CharLiteral(String),
#[regex(r"[a-zA-Z_][a-zA-Z0-9_]*", priority = 1, callback = |lex| lex.slice().to_string())]
Ident(String),
#[token("+=")]
PlusAssign,
#[token("-=")]
MinusAssign,
#[token("*=")]
StarAssign,
#[token("/=")]
SlashAssign,
#[token("%=")]
PercentAssign,
#[token("&=")]
AmpAssign,
#[token("|=")]
PipeAssign,
#[token("^=")]
CaretAssign,
#[token("<<=")]
ShlAssign,
#[token(">>=")]
ShrAssign,
#[token("++")]
PlusPlus,
#[token("--")]
MinusMinus,
#[token("&&")]
AmpAmp,
#[token("||")]
PipePipe,
#[token("==")]
EqEq,
#[token("!=")]
BangEq,
#[token("<=")]
LtEq,
#[token(">=")]
GtEq,
#[token("<<")]
Shl,
#[token(">>")]
Shr,
#[token("->")]
Arrow,
#[token("+")]
Plus,
#[token("-")]
Minus,
#[token("*")]
Star,
#[token("/")]
Slash,
#[token("%")]
Percent,
#[token("&")]
Amp,
#[token("|")]
Pipe,
#[token("^")]
Caret,
#[token("~")]
Tilde,
#[token("!")]
Bang,
#[token("=")]
Eq,
#[token("<")]
Lt,
#[token(">")]
Gt,
#[token("?")]
Question,
#[token(":")]
Colon,
#[token("(")]
LParen,
#[token(")")]
RParen,
#[token("{")]
LBrace,
#[token("}")]
RBrace,
#[token("[")]
LBracket,
#[token("]")]
RBracket,
#[token(";")]
Semi,
#[token(",")]
Comma,
#[token(".")]
Dot,
#[regex(r#"#include\s*[<"][^>"\n]+[>"]"#, |lex| lex.slice().to_string())]
Include(String),
#[regex(r"#define\s+[^\n]+", |lex| lex.slice().to_string())]
Define(String),
#[regex(r"#(pragma|ifdef|ifndef|endif|if|elif|else|undef|error|warning)[^\n]*", |lex| lex.slice().to_string())]
Preprocessor(String),
#[regex(r"//[^\n]*")]
LineComment,
#[regex(r"/\*([^*]|\*[^/])*\*/")]
BlockComment,
}
/// A lexed [`Token`] bundled with where it came from in the input.
#[derive(Debug, Clone)]
pub struct SpannedToken {
/// The token kind; literal, identifier and directive variants carry their
/// own copy of the matched text.
pub token: Token,
/// Byte range of the token within the source string it was lexed from.
pub span: std::ops::Range<usize>,
/// Owned copy of the source text covered by `span`.
pub text: String,
}
/// Lexes `source` into a vector of [`SpannedToken`]s.
///
/// Line and block comment tokens are left out of the result. Input that the
/// lexer reports as an error is silently dropped, so the output contains
/// only successfully recognised, non-comment tokens in source order.
pub fn tokenize(source: &str) -> Vec<SpannedToken> {
    Token::lexer(source)
        .spanned()
        .filter_map(|(outcome, range)| {
            // Lexer errors are ignored on purpose: best-effort tokenization.
            let kind = outcome.ok()?;
            if matches!(kind, Token::LineComment | Token::BlockComment) {
                return None;
            }
            let text = source[range.clone()].to_string();
            Some(SpannedToken {
                token: kind,
                span: range,
                text,
            })
        })
        .collect()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Reports whether any lexed entry has exactly the token `wanted`.
    fn has_kind(stream: &[SpannedToken], wanted: &Token) -> bool {
        stream.iter().any(|entry| &entry.token == wanted)
    }

    /// A kernel signature yields the `__global__` and `void` keyword tokens.
    #[test]
    fn test_basic_tokenize() {
        let lexed = tokenize("__global__ void vectorAdd(float* a, int n) { }");
        assert!(has_kind(&lexed, &Token::Global));
        assert!(has_kind(&lexed, &Token::Void));
    }

    /// Neither comment kind survives into the token stream.
    #[test]
    fn test_comments_stripped() {
        let lexed = tokenize("int x; // comment\n/* block */ float y;");
        assert!(!has_kind(&lexed, &Token::LineComment));
        assert!(!has_kind(&lexed, &Token::BlockComment));
    }
}