// Source: cuda_rust_wasm/parser/ptx_parser.rs

//! PTX (Parallel Thread Execution) parser.
//!
//! Parses NVIDIA PTX assembly into a structured representation and converts
//! it to the common CUDA AST for downstream transpilation to WGSL/Rust.

6use crate::{translation_error, Result};
7
8/// PTX module — top-level compilation unit
9#[derive(Debug, Clone)]
10pub struct PtxModule {
11    /// PTX ISA version (e.g., "7.8")
12    pub version: String,
13    /// Target architecture (e.g., "sm_80")
14    pub target: String,
15    /// Address size in bits (32 or 64)
16    pub address_size: u32,
17    /// Top-level directives (functions, variables, etc.)
18    pub directives: Vec<PtxDirective>,
19}
20
21/// PTX top-level directive
22#[derive(Debug, Clone)]
23pub enum PtxDirective {
24    /// Kernel entry point (.entry)
25    Entry(PtxFunction),
26    /// Device function (.func)
27    Function(PtxFunction),
28    /// Global variable (.global)
29    GlobalVar(PtxVariable),
30    /// Constant variable (.const)
31    ConstVar(PtxVariable),
32    /// Shared variable (.shared)
33    SharedVar(PtxVariable),
34}
35
36/// PTX function (entry or helper)
37#[derive(Debug, Clone)]
38pub struct PtxFunction {
39    /// Function name
40    pub name: String,
41    /// Parameters (.param declarations)
42    pub params: Vec<PtxVariable>,
43    /// Register declarations (.reg)
44    pub registers: Vec<PtxRegDecl>,
45    /// Local variable declarations
46    pub locals: Vec<PtxVariable>,
47    /// Instruction body
48    pub body: Vec<PtxStatement>,
49    /// Whether this is an entry point
50    pub is_entry: bool,
51}
52
53/// PTX register declaration
54#[derive(Debug, Clone)]
55pub struct PtxRegDecl {
56    /// Register type
57    pub reg_type: PtxType,
58    /// Register names
59    pub names: Vec<String>,
60    /// Number of registers (for array decls like .reg .f32 %f<32>)
61    pub count: Option<u32>,
62}
63
64/// PTX variable
65#[derive(Debug, Clone)]
66pub struct PtxVariable {
67    /// Variable name
68    pub name: String,
69    /// Type
70    pub var_type: PtxType,
71    /// Storage space
72    pub space: PtxSpace,
73    /// Array size (number of elements, if any)
74    pub array_size: Option<u32>,
75    /// Alignment
76    pub alignment: Option<u32>,
77}
78
79/// PTX data types
80#[derive(Debug, Clone, PartialEq)]
81pub enum PtxType {
82    Pred,
83    B8, B16, B32, B64,
84    S8, S16, S32, S64,
85    U8, U16, U32, U64,
86    F16, F32, F64,
87}
88
89/// PTX address spaces
90#[derive(Debug, Clone, PartialEq)]
91pub enum PtxSpace {
92    Reg,
93    Param,
94    Local,
95    Shared,
96    Global,
97    Const,
98}
99
100/// PTX statement (instruction or label)
101#[derive(Debug, Clone)]
102pub enum PtxStatement {
103    /// Label target
104    Label(String),
105    /// Instruction (possibly predicated)
106    Instruction(PtxInstruction),
107}
108
109/// PTX instruction
110#[derive(Debug, Clone)]
111pub struct PtxInstruction {
112    /// Optional predicate guard (@p or @!p)
113    pub predicate: Option<PtxPredicate>,
114    /// Opcode (e.g., "add", "ld", "st", "setp", "bra")
115    pub opcode: String,
116    /// Type suffix (e.g., ".f32", ".s32")
117    pub type_suffix: Option<PtxType>,
118    /// Modifier suffixes (e.g., ".rn", ".uni", ".wide", ".lu")
119    pub modifiers: Vec<String>,
120    /// Operands
121    pub operands: Vec<PtxOperand>,
122}
123
124/// PTX operand
125#[derive(Debug, Clone)]
126pub enum PtxOperand {
127    /// Register (%r0, %f1, %p0)
128    Register(String),
129    /// Special register (%tid.x, %ctaid.y, %ntid.z, %laneid, %warpid)
130    SpecialReg(String),
131    /// Immediate integer
132    ImmInt(i64),
133    /// Immediate float
134    ImmFloat(f64),
135    /// Label reference
136    Label(String),
137    /// Memory address [%r0], [%r0+4], [name]
138    Address { base: String, offset: Option<i64> },
139    /// Vector operand {%r0, %r1, %r2, %r3}
140    Vector(Vec<String>),
141}
142
143/// PTX predicate guard
144#[derive(Debug, Clone)]
145pub struct PtxPredicate {
146    /// Register name (e.g., "p0")
147    pub register: String,
148    /// Whether negated (@!p)
149    pub negated: bool,
150}
151
152/// Parse a PTX source string into a PtxModule
153pub fn parse_ptx(input: &str) -> Result<PtxModule> {
154    let mut module = PtxModule {
155        version: String::new(),
156        target: String::new(),
157        address_size: 64,
158        directives: Vec::new(),
159    };
160
161    let lines: Vec<&str> = input.lines().map(|l| l.trim()).collect();
162    let mut i = 0;
163
164    while i < lines.len() {
165        let line = lines[i];
166
167        // Skip empty lines and comments
168        if line.is_empty() || line.starts_with("//") {
169            i += 1;
170            continue;
171        }
172
173        if line.starts_with(".version") {
174            module.version = extract_value(line, ".version");
175        } else if line.starts_with(".target") {
176            module.target = extract_value(line, ".target");
177        } else if line.starts_with(".address_size") {
178            module.address_size = extract_value(line, ".address_size")
179                .parse()
180                .unwrap_or(64);
181        } else if line.contains(".entry") || line.contains(".func") {
182            let is_entry = line.contains(".entry");
183            let (func, end_idx) = parse_function(&lines, i, is_entry)?;
184            let directive = if is_entry {
185                PtxDirective::Entry(func)
186            } else {
187                PtxDirective::Function(func)
188            };
189            module.directives.push(directive);
190            i = end_idx;
191        } else if line.starts_with(".global") {
192            if let Some(var) = parse_variable(line, PtxSpace::Global) {
193                module.directives.push(PtxDirective::GlobalVar(var));
194            }
195        } else if line.starts_with(".const") {
196            if let Some(var) = parse_variable(line, PtxSpace::Const) {
197                module.directives.push(PtxDirective::ConstVar(var));
198            }
199        } else if line.starts_with(".shared") {
200            if let Some(var) = parse_variable(line, PtxSpace::Shared) {
201                module.directives.push(PtxDirective::SharedVar(var));
202            }
203        }
204
205        i += 1;
206    }
207
208    Ok(module)
209}
210
211/// Convert a PtxModule to the common CUDA AST for downstream transpilation
212pub fn ptx_to_ast(module: &PtxModule) -> Result<crate::parser::ast::Ast> {
213    use crate::parser::ast::*;
214
215    let mut items = Vec::new();
216
217    for directive in &module.directives {
218        match directive {
219            PtxDirective::Entry(func) => {
220                let params = func.params.iter().map(|p| Parameter {
221                    name: clean_name(&p.name),
222                    ty: ptx_type_to_ast(&p.var_type),
223                    qualifiers: vec![],
224                }).collect();
225
226                let body = ptx_body_to_ast(&func.body)?;
227
228                items.push(Item::Kernel(KernelDef {
229                    name: clean_name(&func.name),
230                    params,
231                    body,
232                    attributes: vec![],
233                }));
234            }
235            PtxDirective::Function(func) => {
236                let params = func.params.iter().map(|p| Parameter {
237                    name: clean_name(&p.name),
238                    ty: ptx_type_to_ast(&p.var_type),
239                    qualifiers: vec![],
240                }).collect();
241
242                let body = ptx_body_to_ast(&func.body)?;
243
244                items.push(Item::DeviceFunction(FunctionDef {
245                    name: clean_name(&func.name),
246                    return_type: Type::Void,
247                    params,
248                    body,
249                    qualifiers: vec![FunctionQualifier::Device],
250                }));
251            }
252            PtxDirective::GlobalVar(var) => {
253                items.push(Item::GlobalVar(GlobalVar {
254                    name: clean_name(&var.name),
255                    ty: ptx_type_to_ast(&var.var_type),
256                    storage: StorageClass::Global,
257                    init: None,
258                }));
259            }
260            PtxDirective::ConstVar(var) => {
261                items.push(Item::GlobalVar(GlobalVar {
262                    name: clean_name(&var.name),
263                    ty: ptx_type_to_ast(&var.var_type),
264                    storage: StorageClass::Constant,
265                    init: None,
266                }));
267            }
268            PtxDirective::SharedVar(var) => {
269                items.push(Item::GlobalVar(GlobalVar {
270                    name: clean_name(&var.name),
271                    ty: ptx_type_to_ast(&var.var_type),
272                    storage: StorageClass::Shared,
273                    init: None,
274                }));
275            }
276        }
277    }
278
279    Ok(Ast { items })
280}
281
282// --- Internal helpers ---
283
284fn extract_value(line: &str, prefix: &str) -> String {
285    line.trim_start_matches(prefix)
286        .trim()
287        .trim_end_matches(';')
288        .trim()
289        .to_string()
290}
291
292fn clean_name(name: &str) -> String {
293    name.trim_start_matches('%').trim_start_matches('_').to_string()
294}
295
296fn parse_type(s: &str) -> Option<PtxType> {
297    match s.trim_start_matches('.') {
298        "pred" => Some(PtxType::Pred),
299        "b8" => Some(PtxType::B8), "b16" => Some(PtxType::B16),
300        "b32" => Some(PtxType::B32), "b64" => Some(PtxType::B64),
301        "s8" => Some(PtxType::S8), "s16" => Some(PtxType::S16),
302        "s32" => Some(PtxType::S32), "s64" => Some(PtxType::S64),
303        "u8" => Some(PtxType::U8), "u16" => Some(PtxType::U16),
304        "u32" => Some(PtxType::U32), "u64" => Some(PtxType::U64),
305        "f16" => Some(PtxType::F16), "f32" => Some(PtxType::F32),
306        "f64" => Some(PtxType::F64),
307        _ => None,
308    }
309}
310
311fn parse_variable(line: &str, space: PtxSpace) -> Option<PtxVariable> {
312    let tokens: Vec<&str> = line.split_whitespace().collect();
313    if tokens.len() < 3 { return None; }
314
315    let var_type = parse_type(tokens[1]).unwrap_or(PtxType::B32);
316    let name = tokens.last()?.trim_end_matches(';').to_string();
317
318    Some(PtxVariable {
319        name,
320        var_type,
321        space,
322        array_size: None,
323        alignment: None,
324    })
325}
326
327fn parse_function(lines: &[&str], start: usize, is_entry: bool) -> Result<(PtxFunction, usize)> {
328    let header = lines[start];
329    let name = extract_func_name(header);
330
331    let mut func = PtxFunction {
332        name,
333        params: Vec::new(),
334        registers: Vec::new(),
335        locals: Vec::new(),
336        body: Vec::new(),
337        is_entry,
338    };
339
340    let mut i = start + 1;
341    let mut in_body = false;
342    let mut brace_depth = if header.contains('{') { 1 } else { 0 };
343
344    if brace_depth > 0 { in_body = true; }
345
346    while i < lines.len() {
347        let line = lines[i];
348
349        if !in_body {
350            if line.contains('{') {
351                in_body = true;
352                brace_depth += line.matches('{').count();
353                brace_depth -= line.matches('}').count();
354                if brace_depth == 0 { return Ok((func, i)); }
355                i += 1;
356                continue;
357            }
358            if line.contains(".param") {
359                if let Some(var) = parse_variable(line, PtxSpace::Param) {
360                    func.params.push(var);
361                }
362            }
363        } else {
364            brace_depth += line.matches('{').count();
365            brace_depth -= line.matches('}').count();
366
367            if brace_depth == 0 {
368                return Ok((func, i));
369            }
370
371            if line.contains(".reg") {
372                let tokens: Vec<&str> = line.split_whitespace().collect();
373                if tokens.len() >= 3 {
374                    let reg_type = parse_type(tokens[1]).unwrap_or(PtxType::B32);
375                    let name_part = tokens[2].trim_end_matches(';');
376                    // Handle array syntax %f<32>
377                    let (names, count) = if name_part.contains('<') {
378                        let parts: Vec<&str> = name_part.split('<').collect();
379                        let base = parts[0].to_string();
380                        let cnt: u32 = parts.get(1)
381                            .and_then(|s| s.trim_end_matches('>').parse().ok())
382                            .unwrap_or(1);
383                        (vec![base], Some(cnt))
384                    } else {
385                        (vec![name_part.to_string()], None)
386                    };
387                    func.registers.push(PtxRegDecl { reg_type, names, count });
388                }
389            } else if line.contains(".local") || line.contains(".shared") {
390                let space = if line.contains(".shared") { PtxSpace::Shared } else { PtxSpace::Local };
391                if let Some(var) = parse_variable(line, space) {
392                    func.locals.push(var);
393                }
394            } else if !line.is_empty() && !line.starts_with("//") {
395                if let Some(stmt) = parse_statement(line) {
396                    func.body.push(stmt);
397                }
398            }
399        }
400
401        i += 1;
402    }
403
404    Ok((func, lines.len() - 1))
405}
406
407fn extract_func_name(line: &str) -> String {
408    // Look for name after .entry or .func
409    let after_keyword = line
410        .replace(".visible", "")
411        .replace(".entry", "|")
412        .replace(".func", "|");
413    let parts: Vec<&str> = after_keyword.split('|').collect();
414    if parts.len() > 1 {
415        let name_part = parts[1].trim();
416        name_part
417            .split(|c: char| c.is_whitespace() || c == '(' || c == '{')
418            .next()
419            .unwrap_or("unknown")
420            .to_string()
421    } else {
422        "unknown".to_string()
423    }
424}
425
426fn parse_statement(line: &str) -> Option<PtxStatement> {
427    let trimmed = line.trim().trim_end_matches(';').trim();
428
429    // Label
430    if trimmed.ends_with(':') && !trimmed.starts_with('@') {
431        return Some(PtxStatement::Label(trimmed.trim_end_matches(':').to_string()));
432    }
433
434    // Instruction (possibly predicated)
435    let (predicate, rest) = if trimmed.starts_with('@') {
436        let parts: Vec<&str> = trimmed.splitn(2, char::is_whitespace).collect();
437        let pred_str = &parts[0][1..]; // skip @
438        let negated = pred_str.starts_with('!');
439        let reg = if negated { &pred_str[1..] } else { pred_str }.to_string();
440        let rest = parts.get(1).unwrap_or(&"").trim();
441        (Some(PtxPredicate { register: reg, negated }), rest.to_string())
442    } else {
443        (None, trimmed.to_string())
444    };
445
446    let tokens: Vec<&str> = rest.split_whitespace().collect();
447    if tokens.is_empty() { return None; }
448
449    let opcode_full = tokens[0];
450    let opcode_parts: Vec<&str> = opcode_full.split('.').collect();
451    let opcode = opcode_parts[0].to_string();
452
453    let type_suffix = opcode_parts.iter().skip(1).find_map(|p| parse_type(p));
454    let modifiers: Vec<String> = opcode_parts.iter().skip(1)
455        .filter(|p| parse_type(p).is_none())
456        .map(|s| s.to_string())
457        .collect();
458
459    let operand_str = tokens[1..].join(" ");
460    let operands = parse_operands(&operand_str);
461
462    Some(PtxStatement::Instruction(PtxInstruction {
463        predicate,
464        opcode,
465        type_suffix,
466        modifiers,
467        operands,
468    }))
469}
470
471fn parse_operands(s: &str) -> Vec<PtxOperand> {
472    if s.is_empty() { return vec![]; }
473
474    s.split(',')
475        .map(|part| {
476            let t = part.trim();
477            if t.starts_with('%') {
478                let name = t.trim_start_matches('%');
479                if name.contains("tid.") || name.contains("ctaid.") || name.contains("ntid.")
480                    || name.contains("nctaid.") || name == "laneid" || name == "warpid"
481                {
482                    PtxOperand::SpecialReg(t.to_string())
483                } else {
484                    PtxOperand::Register(t.to_string())
485                }
486            } else if t.starts_with('[') && t.ends_with(']') {
487                let inner = &t[1..t.len()-1];
488                if let Some(plus) = inner.find('+') {
489                    let base = inner[..plus].trim().to_string();
490                    let offset = inner[plus+1..].trim().parse().ok();
491                    PtxOperand::Address { base, offset }
492                } else {
493                    PtxOperand::Address { base: inner.trim().to_string(), offset: None }
494                }
495            } else if t.starts_with('{') {
496                let inner = t.trim_matches(|c| c == '{' || c == '}');
497                let regs: Vec<String> = inner.split(',').map(|r| r.trim().to_string()).collect();
498                PtxOperand::Vector(regs)
499            } else if let Ok(v) = t.parse::<i64>() {
500                PtxOperand::ImmInt(v)
501            } else if let Ok(v) = t.parse::<f64>() {
502                PtxOperand::ImmFloat(v)
503            } else {
504                PtxOperand::Label(t.to_string())
505            }
506        })
507        .collect()
508}
509
510fn ptx_type_to_ast(ty: &PtxType) -> crate::parser::ast::Type {
511    use crate::parser::ast::{Type, IntType, FloatType};
512    match ty {
513        PtxType::Pred => Type::Bool,
514        PtxType::B8 | PtxType::U8 => Type::Int(IntType::U8),
515        PtxType::B16 | PtxType::U16 => Type::Int(IntType::U16),
516        PtxType::B32 | PtxType::U32 => Type::Int(IntType::U32),
517        PtxType::B64 | PtxType::U64 => Type::Int(IntType::U64),
518        PtxType::S8 => Type::Int(IntType::I8),
519        PtxType::S16 => Type::Int(IntType::I16),
520        PtxType::S32 => Type::Int(IntType::I32),
521        PtxType::S64 => Type::Int(IntType::I64),
522        PtxType::F16 => Type::Float(FloatType::F16),
523        PtxType::F32 => Type::Float(FloatType::F32),
524        PtxType::F64 => Type::Float(FloatType::F64),
525    }
526}
527
528fn ptx_body_to_ast(stmts: &[PtxStatement]) -> Result<crate::parser::ast::Block> {
529    use crate::parser::ast::*;
530
531    let mut statements = Vec::new();
532
533    for stmt in stmts {
534        match stmt {
535            PtxStatement::Label(_) => {
536                // Labels are used for control flow; skip in high-level AST
537            }
538            PtxStatement::Instruction(inst) => {
539                let ast_stmt = ptx_instruction_to_ast(inst)?;
540                if let Some(s) = ast_stmt {
541                    statements.push(s);
542                }
543            }
544        }
545    }
546
547    Ok(Block { statements })
548}
549
550fn ptx_instruction_to_ast(inst: &PtxInstruction) -> Result<Option<crate::parser::ast::Statement>> {
551    use crate::parser::ast::*;
552
553    match inst.opcode.as_str() {
554        "ret" => Ok(Some(Statement::Return(None))),
555        "bar" => Ok(Some(Statement::SyncThreads)),
556        "add" | "sub" | "mul" | "div" | "rem" | "and" | "or" | "xor" | "shl" | "shr" => {
557            if inst.operands.len() >= 3 {
558                let dst = operand_to_var(&inst.operands[0]);
559                let lhs = operand_to_expr(&inst.operands[1]);
560                let rhs = operand_to_expr(&inst.operands[2]);
561                let op = match inst.opcode.as_str() {
562                    "add" => BinaryOp::Add, "sub" => BinaryOp::Sub,
563                    "mul" => BinaryOp::Mul, "div" => BinaryOp::Div,
564                    "rem" => BinaryOp::Mod, "and" => BinaryOp::And,
565                    "or" => BinaryOp::Or, "xor" => BinaryOp::Xor,
566                    "shl" => BinaryOp::Shl, "shr" => BinaryOp::Shr,
567                    _ => BinaryOp::Add,
568                };
569                Ok(Some(Statement::Expr(Expression::Binary {
570                    op: BinaryOp::Assign,
571                    left: Box::new(Expression::Var(dst)),
572                    right: Box::new(Expression::Binary {
573                        op,
574                        left: Box::new(lhs),
575                        right: Box::new(rhs),
576                    }),
577                })))
578            } else {
579                Ok(None)
580            }
581        }
582        "mov" => {
583            if inst.operands.len() >= 2 {
584                let dst = operand_to_var(&inst.operands[0]);
585                let src = operand_to_expr(&inst.operands[1]);
586                Ok(Some(Statement::Expr(Expression::Binary {
587                    op: BinaryOp::Assign,
588                    left: Box::new(Expression::Var(dst)),
589                    right: Box::new(src),
590                })))
591            } else {
592                Ok(None)
593            }
594        }
595        "ld" => {
596            if inst.operands.len() >= 2 {
597                let dst = operand_to_var(&inst.operands[0]);
598                let src = operand_to_expr(&inst.operands[1]);
599                Ok(Some(Statement::Expr(Expression::Binary {
600                    op: BinaryOp::Assign,
601                    left: Box::new(Expression::Var(dst)),
602                    right: Box::new(src),
603                })))
604            } else {
605                Ok(None)
606            }
607        }
608        "st" => {
609            if inst.operands.len() >= 2 {
610                let dst = operand_to_expr(&inst.operands[0]);
611                let src = operand_to_expr(&inst.operands[1]);
612                Ok(Some(Statement::Expr(Expression::Binary {
613                    op: BinaryOp::Assign,
614                    left: Box::new(dst),
615                    right: Box::new(src),
616                })))
617            } else {
618                Ok(None)
619            }
620        }
621        _ => Ok(None), // Skip unhandled opcodes
622    }
623}
624
625fn operand_to_var(op: &PtxOperand) -> String {
626    match op {
627        PtxOperand::Register(r) => clean_name(r),
628        PtxOperand::SpecialReg(r) => clean_name(r),
629        PtxOperand::Label(l) => l.clone(),
630        _ => "unknown".to_string(),
631    }
632}
633
634fn operand_to_expr(op: &PtxOperand) -> crate::parser::ast::Expression {
635    use crate::parser::ast::*;
636    match op {
637        PtxOperand::Register(r) => Expression::Var(clean_name(r)),
638        PtxOperand::SpecialReg(r) => {
639            let name = r.trim_start_matches('%');
640            match name {
641                "tid.x" => Expression::ThreadIdx(Dimension::X),
642                "tid.y" => Expression::ThreadIdx(Dimension::Y),
643                "tid.z" => Expression::ThreadIdx(Dimension::Z),
644                "ctaid.x" => Expression::BlockIdx(Dimension::X),
645                "ctaid.y" => Expression::BlockIdx(Dimension::Y),
646                "ctaid.z" => Expression::BlockIdx(Dimension::Z),
647                "ntid.x" => Expression::BlockDim(Dimension::X),
648                "ntid.y" => Expression::BlockDim(Dimension::Y),
649                "ntid.z" => Expression::BlockDim(Dimension::Z),
650                "nctaid.x" => Expression::GridDim(Dimension::X),
651                "nctaid.y" => Expression::GridDim(Dimension::Y),
652                "nctaid.z" => Expression::GridDim(Dimension::Z),
653                _ => Expression::Var(name.to_string()),
654            }
655        }
656        PtxOperand::ImmInt(v) => Expression::Literal(Literal::Int(*v)),
657        PtxOperand::ImmFloat(v) => Expression::Literal(Literal::Float(*v)),
658        PtxOperand::Address { base, offset } => {
659            let base_expr = Expression::Var(clean_name(base));
660            match offset {
661                Some(off) => Expression::Index {
662                    array: Box::new(base_expr),
663                    index: Box::new(Expression::Literal(Literal::Int(*off))),
664                },
665                None => base_expr,
666            }
667        }
668        PtxOperand::Label(l) => Expression::Var(l.clone()),
669        PtxOperand::Vector(regs) => {
670            // Return first register as a simple expression
671            Expression::Var(clean_name(regs.first().map(|s| s.as_str()).unwrap_or("v0")))
672        }
673    }
674}
675
676#[cfg(test)]
677mod tests {
678    use super::*;
679
680    #[test]
681    fn test_parse_version() {
682        let ptx = ".version 7.8\n.target sm_80\n.address_size 64\n";
683        let module = parse_ptx(ptx).unwrap();
684        assert_eq!(module.version, "7.8");
685        assert_eq!(module.target, "sm_80");
686        assert_eq!(module.address_size, 64);
687    }
688
689    #[test]
690    fn test_parse_entry_function() {
691        let ptx = r#"
692.version 7.8
693.target sm_80
694.address_size 64
695
696.visible .entry vectorAdd(
697    .param .u64 a,
698    .param .u64 b,
699    .param .u64 c
700)
701{
702    .reg .f32 %f<4>;
703    .reg .u32 %r<4>;
704    mov.u32 %r0, %tid.x;
705    add.f32 %f2, %f0, %f1;
706    ret;
707}
708"#;
709        let module = parse_ptx(ptx).unwrap();
710        assert_eq!(module.directives.len(), 1);
711        match &module.directives[0] {
712            PtxDirective::Entry(func) => {
713                assert_eq!(func.name, "vectorAdd");
714                assert_eq!(func.params.len(), 3);
715                assert!(func.is_entry);
716                assert!(!func.body.is_empty());
717            }
718            other => panic!("Expected entry directive, got {:?}", other),
719        }
720    }
721
722    #[test]
723    fn test_parse_type() {
724        assert_eq!(parse_type(".f32"), Some(PtxType::F32));
725        assert_eq!(parse_type(".s64"), Some(PtxType::S64));
726        assert_eq!(parse_type(".pred"), Some(PtxType::Pred));
727        assert_eq!(parse_type(".b16"), Some(PtxType::B16));
728        assert_eq!(parse_type(".invalid"), None);
729    }
730
731    #[test]
732    fn test_parse_instruction_basic() {
733        let stmt = parse_statement("add.f32 %f2, %f0, %f1;").unwrap();
734        match stmt {
735            PtxStatement::Instruction(inst) => {
736                assert_eq!(inst.opcode, "add");
737                assert_eq!(inst.type_suffix, Some(PtxType::F32));
738                assert_eq!(inst.operands.len(), 3);
739            }
740            other => panic!("Expected instruction, got {:?}", other),
741        }
742    }
743
744    #[test]
745    fn test_parse_predicated_instruction() {
746        let stmt = parse_statement("@p0 bra LOOP;").unwrap();
747        match stmt {
748            PtxStatement::Instruction(inst) => {
749                assert!(inst.predicate.is_some());
750                let pred = inst.predicate.unwrap();
751                assert_eq!(pred.register, "p0");
752                assert!(!pred.negated);
753                assert_eq!(inst.opcode, "bra");
754            }
755            other => panic!("Expected instruction, got {:?}", other),
756        }
757    }
758
759    #[test]
760    fn test_parse_negated_predicate() {
761        let stmt = parse_statement("@!p1 ret;").unwrap();
762        match stmt {
763            PtxStatement::Instruction(inst) => {
764                let pred = inst.predicate.unwrap();
765                assert_eq!(pred.register, "p1");
766                assert!(pred.negated);
767            }
768            other => panic!("Expected instruction, got {:?}", other),
769        }
770    }
771
772    #[test]
773    fn test_parse_label() {
774        let stmt = parse_statement("LOOP:").unwrap();
775        match stmt {
776            PtxStatement::Label(name) => assert_eq!(name, "LOOP"),
777            other => panic!("Expected label, got {:?}", other),
778        }
779    }
780
781    #[test]
782    fn test_parse_special_registers() {
783        let operands = parse_operands("%tid.x, %ctaid.y");
784        assert_eq!(operands.len(), 2);
785        match &operands[0] {
786            PtxOperand::SpecialReg(r) => assert_eq!(r, "%tid.x"),
787            other => panic!("Expected special register, got {:?}", other),
788        }
789    }
790
791    #[test]
792    fn test_parse_memory_address() {
793        let operands = parse_operands("[%r0+4]");
794        match &operands[0] {
795            PtxOperand::Address { base, offset } => {
796                assert_eq!(base, "%r0");
797                assert_eq!(*offset, Some(4));
798            }
799            other => panic!("Expected address, got {:?}", other),
800        }
801    }
802
803    #[test]
804    fn test_parse_immediate() {
805        let operands = parse_operands("42");
806        match &operands[0] {
807            PtxOperand::ImmInt(v) => assert_eq!(*v, 42),
808            other => panic!("Expected immediate int, got {:?}", other),
809        }
810    }
811
812    #[test]
813    fn test_parse_global_variable() {
814        let ptx = ".version 7.8\n.target sm_80\n.address_size 64\n.global .f32 result;\n";
815        let module = parse_ptx(ptx).unwrap();
816        assert_eq!(module.directives.len(), 1);
817        match &module.directives[0] {
818            PtxDirective::GlobalVar(var) => {
819                assert_eq!(var.var_type, PtxType::F32);
820                assert_eq!(var.space, PtxSpace::Global);
821            }
822            other => panic!("Expected global var, got {:?}", other),
823        }
824    }
825
826    #[test]
827    fn test_ptx_to_ast() {
828        let ptx = r#"
829.version 7.8
830.target sm_80
831.address_size 64
832
833.visible .entry simple(
834    .param .u64 data
835)
836{
837    .reg .u32 %r<2>;
838    mov.u32 %r0, %tid.x;
839    ret;
840}
841"#;
842        let module = parse_ptx(ptx).unwrap();
843        let ast = ptx_to_ast(&module).unwrap();
844        assert_eq!(ast.items.len(), 1);
845        match &ast.items[0] {
846            crate::parser::ast::Item::Kernel(k) => {
847                assert_eq!(k.name, "simple");
848            }
849            other => panic!("Expected kernel, got {:?}", other),
850        }
851    }
852
853    #[test]
854    fn test_ptx_type_conversion() {
855        use crate::parser::ast::{Type, IntType, FloatType};
856        assert!(matches!(ptx_type_to_ast(&PtxType::F32), Type::Float(FloatType::F32)));
857        assert!(matches!(ptx_type_to_ast(&PtxType::S32), Type::Int(IntType::I32)));
858        assert!(matches!(ptx_type_to_ast(&PtxType::Pred), Type::Bool));
859    }
860}