cuda_rust_wasm/parser/
ast.rs

//! Abstract Syntax Tree (AST) definitions for CUDA source code.
2
3use serde::{Deserialize, Serialize};
4
5/// Root AST node
6#[derive(Debug, Clone, Serialize, Deserialize)]
7pub struct Ast {
8    pub items: Vec<Item>,
9}
10
11/// Top-level items in CUDA code
12#[derive(Debug, Clone, Serialize, Deserialize)]
13pub enum Item {
14    /// Kernel function definition
15    Kernel(KernelDef),
16    /// Device function
17    DeviceFunction(FunctionDef),
18    /// Host function
19    HostFunction(FunctionDef),
20    /// Global variable
21    GlobalVar(GlobalVar),
22    /// Type definition
23    TypeDef(TypeDef),
24    /// Include directive
25    Include(String),
26}
27
28/// CUDA kernel definition
29#[derive(Debug, Clone, Serialize, Deserialize)]
30pub struct KernelDef {
31    pub name: String,
32    pub params: Vec<Parameter>,
33    pub body: Block,
34    pub attributes: Vec<KernelAttribute>,
35}
36
37/// Kernel attributes (launch bounds, etc.)
38#[derive(Debug, Clone, Serialize, Deserialize)]
39pub enum KernelAttribute {
40    LaunchBounds { max_threads: u32, min_blocks: Option<u32> },
41    MaxRegisters(u32),
42}
43
44/// Function definition
45#[derive(Debug, Clone, Serialize, Deserialize)]
46pub struct FunctionDef {
47    pub name: String,
48    pub return_type: Type,
49    pub params: Vec<Parameter>,
50    pub body: Block,
51    pub qualifiers: Vec<FunctionQualifier>,
52}
53
54/// Function qualifiers
55#[derive(Debug, Clone, Serialize, Deserialize)]
56pub enum FunctionQualifier {
57    Device,
58    Host,
59    Global,
60    Inline,
61    NoInline,
62}
63
64/// Function parameter
65#[derive(Debug, Clone, Serialize, Deserialize)]
66pub struct Parameter {
67    pub name: String,
68    pub ty: Type,
69    pub qualifiers: Vec<ParamQualifier>,
70}
71
72/// Parameter qualifiers
73#[derive(Debug, Clone, Serialize, Deserialize)]
74pub enum ParamQualifier {
75    Const,
76    Restrict,
77    Volatile,
78}
79
80/// CUDA types
81#[derive(Debug, Clone, Serialize, Deserialize)]
82pub enum Type {
83    /// Primitive types
84    Void,
85    Bool,
86    Int(IntType),
87    Float(FloatType),
88    /// Pointer type
89    Pointer(Box<Type>),
90    /// Array type
91    Array(Box<Type>, Option<usize>),
92    /// Vector types (float4, int2, etc.)
93    Vector(VectorType),
94    /// User-defined type
95    Named(String),
96    /// Texture type
97    Texture(TextureType),
98}
99
100/// Integer types
101#[derive(Debug, Clone, Serialize, Deserialize)]
102pub enum IntType {
103    I8,
104    I16,
105    I32,
106    I64,
107    U8,
108    U16,
109    U32,
110    U64,
111}
112
113/// Floating-point types
114#[derive(Debug, Clone, Serialize, Deserialize)]
115pub enum FloatType {
116    F16,
117    F32,
118    F64,
119}
120
121/// Vector types
122#[derive(Debug, Clone, Serialize, Deserialize)]
123pub struct VectorType {
124    pub element: Box<Type>,
125    pub size: u8, // 1, 2, 3, or 4
126}
127
128/// Texture types
129#[derive(Debug, Clone, Serialize, Deserialize)]
130pub struct TextureType {
131    pub dim: TextureDim,
132    pub element: Box<Type>,
133}
134
135/// Texture dimensions
136#[derive(Debug, Clone, Serialize, Deserialize)]
137pub enum TextureDim {
138    Tex1D,
139    Tex2D,
140    Tex3D,
141    TexCube,
142}
143
144/// Statement types
145#[derive(Debug, Clone, Serialize, Deserialize)]
146pub enum Statement {
147    /// Variable declaration
148    VarDecl {
149        name: String,
150        ty: Type,
151        init: Option<Expression>,
152        storage: StorageClass,
153    },
154    /// Expression statement
155    Expr(Expression),
156    /// Block statement
157    Block(Block),
158    /// If statement
159    If {
160        condition: Expression,
161        then_branch: Box<Statement>,
162        else_branch: Option<Box<Statement>>,
163    },
164    /// For loop
165    For {
166        init: Option<Box<Statement>>,
167        condition: Option<Expression>,
168        update: Option<Expression>,
169        body: Box<Statement>,
170    },
171    /// While loop
172    While {
173        condition: Expression,
174        body: Box<Statement>,
175    },
176    /// Return statement
177    Return(Option<Expression>),
178    /// Break statement
179    Break,
180    /// Continue statement
181    Continue,
182    /// Synchronization
183    SyncThreads,
184}
185
186/// Storage classes
187#[derive(Debug, Clone, Serialize, Deserialize)]
188pub enum StorageClass {
189    Auto,
190    Register,
191    Shared,
192    Global,
193    Constant,
194    Local,
195}
196
197/// Block of statements
198#[derive(Debug, Clone, Serialize, Deserialize)]
199pub struct Block {
200    pub statements: Vec<Statement>,
201}
202
203/// Expression types
204#[derive(Debug, Clone, Serialize, Deserialize)]
205pub enum Expression {
206    /// Literal values
207    Literal(Literal),
208    /// Variable reference
209    Var(String),
210    /// Binary operation
211    Binary {
212        op: BinaryOp,
213        left: Box<Expression>,
214        right: Box<Expression>,
215    },
216    /// Unary operation
217    Unary {
218        op: UnaryOp,
219        expr: Box<Expression>,
220    },
221    /// Function call
222    Call {
223        name: String,
224        args: Vec<Expression>,
225    },
226    /// Array access
227    Index {
228        array: Box<Expression>,
229        index: Box<Expression>,
230    },
231    /// Member access
232    Member {
233        object: Box<Expression>,
234        field: String,
235    },
236    /// Cast expression
237    Cast {
238        ty: Type,
239        expr: Box<Expression>,
240    },
241    /// Thread index access
242    ThreadIdx(Dimension),
243    /// Block index access
244    BlockIdx(Dimension),
245    /// Block dimension access
246    BlockDim(Dimension),
247    /// Grid dimension access
248    GridDim(Dimension),
249    /// Warp-level primitives
250    WarpPrimitive {
251        op: WarpOp,
252        args: Vec<Expression>,
253    },
254}
255
256/// Dimensions for thread/block indexing
257#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
258pub enum Dimension {
259    X,
260    Y,
261    Z,
262}
263
264/// Literal values
265#[derive(Debug, Clone, Serialize, Deserialize)]
266pub enum Literal {
267    Bool(bool),
268    Int(i64),
269    UInt(u64),
270    Float(f64),
271    String(String),
272}
273
274/// Binary operators
275#[derive(Debug, Clone, Serialize, Deserialize)]
276pub enum BinaryOp {
277    Add,
278    Sub,
279    Mul,
280    Div,
281    Mod,
282    And,
283    Or,
284    Xor,
285    Shl,
286    Shr,
287    Eq,
288    Ne,
289    Lt,
290    Le,
291    Gt,
292    Ge,
293    LogicalAnd,
294    LogicalOr,
295    Assign,
296}
297
298/// Unary operators
299#[derive(Debug, Clone, Serialize, Deserialize)]
300pub enum UnaryOp {
301    Not,
302    Neg,
303    BitNot,
304    PreInc,
305    PreDec,
306    PostInc,
307    PostDec,
308    Deref,
309    AddrOf,
310}
311
312/// Warp-level operations
313#[derive(Debug, Clone, Serialize, Deserialize)]
314pub enum WarpOp {
315    Shuffle,
316    ShuffleXor,
317    ShuffleUp,
318    ShuffleDown,
319    Vote,
320    Ballot,
321    ActiveMask,
322}
323
324/// Global variable definition
325#[derive(Debug, Clone, Serialize, Deserialize)]
326pub struct GlobalVar {
327    pub name: String,
328    pub ty: Type,
329    pub storage: StorageClass,
330    pub init: Option<Expression>,
331}
332
333/// Type definition
334#[derive(Debug, Clone, Serialize, Deserialize)]
335pub struct TypeDef {
336    pub name: String,
337    pub ty: Type,
338}