1use crate::{translation_error, Result};
7
/// A parsed PTX module: the header directives plus every top-level directive
/// recognised by `parse_ptx`.
#[derive(Debug, Clone)]
pub struct PtxModule {
    /// Text following the `.version` directive (e.g. `"7.8"`).
    pub version: String,
    /// Text following the `.target` directive (e.g. `"sm_80"`).
    pub target: String,
    /// Value of `.address_size`; `parse_ptx` starts from 64 and keeps it when the
    /// directive is missing or its value does not parse.
    pub address_size: u32,
    /// Functions and module-scope variables in the order they appear in the source.
    pub directives: Vec<PtxDirective>,
}
20
/// A top-level item of a PTX module.
#[derive(Debug, Clone)]
pub enum PtxDirective {
    /// A kernel introduced by `.entry`.
    Entry(PtxFunction),
    /// A device function introduced by `.func`.
    Function(PtxFunction),
    /// A module-scope variable in the `.global` state space.
    GlobalVar(PtxVariable),
    /// A module-scope variable in the `.const` state space.
    ConstVar(PtxVariable),
    /// A module-scope variable in the `.shared` state space.
    SharedVar(PtxVariable),
}
35
/// A parsed `.entry` or `.func` definition.
#[derive(Debug, Clone)]
pub struct PtxFunction {
    /// Function name as written in the header (not demangled or cleaned).
    pub name: String,
    /// `.param` declarations found between the header and the opening `{`.
    pub params: Vec<PtxVariable>,
    /// `.reg` declarations found inside the body.
    pub registers: Vec<PtxRegDecl>,
    /// `.local` / `.shared` variables declared inside the body.
    pub locals: Vec<PtxVariable>,
    /// Labels and instructions of the body, in source order.
    pub body: Vec<PtxStatement>,
    /// `true` for `.entry` kernels, `false` for `.func` device functions.
    pub is_entry: bool,
}
52
/// A `.reg` declaration such as `.reg .f32 %f<4>;`.
#[derive(Debug, Clone)]
pub struct PtxRegDecl {
    /// Register type taken from the declaration (B32 when it is not recognised).
    pub reg_type: PtxType,
    /// Declared register names; for the `%r<N>` form this holds the base `%r`.
    pub names: Vec<String>,
    /// `Some(N)` for the parameterised `%r<N>` form, `None` for plain names.
    pub count: Option<u32>,
}
63
/// A declared variable or parameter in some PTX state space.
#[derive(Debug, Clone)]
pub struct PtxVariable {
    /// Variable name as written in the declaration.
    pub name: String,
    /// Declared element type (B32 when the parser cannot recognise one).
    pub var_type: PtxType,
    /// State space the declaration belongs to.
    pub space: PtxSpace,
    /// Element count for array declarations (`name[N]`), when known.
    pub array_size: Option<u32>,
    /// Value of an `.align N` qualifier, when known.
    pub alignment: Option<u32>,
}
78
/// PTX fundamental types, named after their `.type` suffix.
///
/// `Pred` is the predicate type; `B*` are untyped bit-size types, `S*` signed
/// integers, `U*` unsigned integers and `F*` floating-point types.
#[derive(Debug, Clone, PartialEq)]
pub enum PtxType {
    Pred,
    B8, B16, B32, B64,
    S8, S16, S32, S64,
    U8, U16, U32, U64,
    F16, F32, F64,
}
88
/// PTX state spaces a declaration can live in (`.reg`, `.param`, `.local`,
/// `.shared`, `.global`, `.const`).
#[derive(Debug, Clone, PartialEq)]
pub enum PtxSpace {
    Reg,
    Param,
    Local,
    Shared,
    Global,
    Const,
}
99
/// One line of a function body.
#[derive(Debug, Clone)]
pub enum PtxStatement {
    /// A label definition (`NAME:`), stored without the trailing colon.
    Label(String),
    /// An executable instruction.
    Instruction(PtxInstruction),
}
108
/// A decoded PTX instruction such as `@!%p1 ld.global.f32 %f1, [%rd1];`.
#[derive(Debug, Clone)]
pub struct PtxInstruction {
    /// Optional `@p` / `@!p` guard.
    pub predicate: Option<PtxPredicate>,
    /// Base opcode: the text before the first `.` (e.g. `ld`).
    pub opcode: String,
    /// First dot-suffix that names a type (e.g. `f32`).
    pub type_suffix: Option<PtxType>,
    /// Remaining dot-suffixes that are not types (e.g. `global`, `rn`).
    pub modifiers: Vec<String>,
    /// Comma-separated operands, in order.
    pub operands: Vec<PtxOperand>,
}
123
/// A single instruction operand.
#[derive(Debug, Clone)]
pub enum PtxOperand {
    /// A `%`-prefixed register; the stored text keeps the `%`.
    Register(String),
    /// A recognised special register such as `%tid.x`; the stored text keeps the `%`.
    SpecialReg(String),
    /// An integer immediate.
    ImmInt(i64),
    /// A floating-point immediate.
    ImmFloat(f64),
    /// A bare identifier (branch target, symbol); also the fallback for any
    /// token that matches no other operand form.
    Label(String),
    /// A memory operand `[base]` or `[base+offset]`.
    Address { base: String, offset: Option<i64> },
    /// A brace-enclosed register list `{a, b, ...}`.
    Vector(Vec<String>),
}
142
/// Guard predicate of an instruction (`@p` or `@!p`).
#[derive(Debug, Clone)]
pub struct PtxPredicate {
    /// Predicate register name, without the leading `@` or `!`.
    pub register: String,
    /// `true` when the guard is negated (`@!p`).
    pub negated: bool,
}
151
152pub fn parse_ptx(input: &str) -> Result<PtxModule> {
154 let mut module = PtxModule {
155 version: String::new(),
156 target: String::new(),
157 address_size: 64,
158 directives: Vec::new(),
159 };
160
161 let lines: Vec<&str> = input.lines().map(|l| l.trim()).collect();
162 let mut i = 0;
163
164 while i < lines.len() {
165 let line = lines[i];
166
167 if line.is_empty() || line.starts_with("//") {
169 i += 1;
170 continue;
171 }
172
173 if line.starts_with(".version") {
174 module.version = extract_value(line, ".version");
175 } else if line.starts_with(".target") {
176 module.target = extract_value(line, ".target");
177 } else if line.starts_with(".address_size") {
178 module.address_size = extract_value(line, ".address_size")
179 .parse()
180 .unwrap_or(64);
181 } else if line.contains(".entry") || line.contains(".func") {
182 let is_entry = line.contains(".entry");
183 let (func, end_idx) = parse_function(&lines, i, is_entry)?;
184 let directive = if is_entry {
185 PtxDirective::Entry(func)
186 } else {
187 PtxDirective::Function(func)
188 };
189 module.directives.push(directive);
190 i = end_idx;
191 } else if line.starts_with(".global") {
192 if let Some(var) = parse_variable(line, PtxSpace::Global) {
193 module.directives.push(PtxDirective::GlobalVar(var));
194 }
195 } else if line.starts_with(".const") {
196 if let Some(var) = parse_variable(line, PtxSpace::Const) {
197 module.directives.push(PtxDirective::ConstVar(var));
198 }
199 } else if line.starts_with(".shared") {
200 if let Some(var) = parse_variable(line, PtxSpace::Shared) {
201 module.directives.push(PtxDirective::SharedVar(var));
202 }
203 }
204
205 i += 1;
206 }
207
208 Ok(module)
209}
210
211pub fn ptx_to_ast(module: &PtxModule) -> Result<crate::parser::ast::Ast> {
213 use crate::parser::ast::*;
214
215 let mut items = Vec::new();
216
217 for directive in &module.directives {
218 match directive {
219 PtxDirective::Entry(func) => {
220 let params = func.params.iter().map(|p| Parameter {
221 name: clean_name(&p.name),
222 ty: ptx_type_to_ast(&p.var_type),
223 qualifiers: vec![],
224 }).collect();
225
226 let body = ptx_body_to_ast(&func.body)?;
227
228 items.push(Item::Kernel(KernelDef {
229 name: clean_name(&func.name),
230 params,
231 body,
232 attributes: vec![],
233 }));
234 }
235 PtxDirective::Function(func) => {
236 let params = func.params.iter().map(|p| Parameter {
237 name: clean_name(&p.name),
238 ty: ptx_type_to_ast(&p.var_type),
239 qualifiers: vec![],
240 }).collect();
241
242 let body = ptx_body_to_ast(&func.body)?;
243
244 items.push(Item::DeviceFunction(FunctionDef {
245 name: clean_name(&func.name),
246 return_type: Type::Void,
247 params,
248 body,
249 qualifiers: vec![FunctionQualifier::Device],
250 }));
251 }
252 PtxDirective::GlobalVar(var) => {
253 items.push(Item::GlobalVar(GlobalVar {
254 name: clean_name(&var.name),
255 ty: ptx_type_to_ast(&var.var_type),
256 storage: StorageClass::Global,
257 init: None,
258 }));
259 }
260 PtxDirective::ConstVar(var) => {
261 items.push(Item::GlobalVar(GlobalVar {
262 name: clean_name(&var.name),
263 ty: ptx_type_to_ast(&var.var_type),
264 storage: StorageClass::Constant,
265 init: None,
266 }));
267 }
268 PtxDirective::SharedVar(var) => {
269 items.push(Item::GlobalVar(GlobalVar {
270 name: clean_name(&var.name),
271 ty: ptx_type_to_ast(&var.var_type),
272 storage: StorageClass::Shared,
273 init: None,
274 }));
275 }
276 }
277 }
278
279 Ok(Ast { items })
280}
281
/// Returns the argument of a header directive.
///
/// Strips every leading repetition of `prefix`, then surrounding whitespace and any
/// trailing `;` characters. For example, `".version 7.8;"` with prefix `".version"`
/// yields `"7.8"`.
fn extract_value(line: &str, prefix: &str) -> String {
    let after_prefix = line.trim_start_matches(prefix).trim();
    let without_terminator = after_prefix.trim_end_matches(';');
    without_terminator.trim().to_owned()
}
291
/// Turns a PTX symbol into an AST identifier.
///
/// Removes all leading `%` register sigils, then all leading underscores that
/// follow them.
fn clean_name(name: &str) -> String {
    let without_sigil = name.trim_start_matches('%');
    String::from(without_sigil.trim_start_matches('_'))
}
295
296fn parse_type(s: &str) -> Option<PtxType> {
297 match s.trim_start_matches('.') {
298 "pred" => Some(PtxType::Pred),
299 "b8" => Some(PtxType::B8), "b16" => Some(PtxType::B16),
300 "b32" => Some(PtxType::B32), "b64" => Some(PtxType::B64),
301 "s8" => Some(PtxType::S8), "s16" => Some(PtxType::S16),
302 "s32" => Some(PtxType::S32), "s64" => Some(PtxType::S64),
303 "u8" => Some(PtxType::U8), "u16" => Some(PtxType::U16),
304 "u32" => Some(PtxType::U32), "u64" => Some(PtxType::U64),
305 "f16" => Some(PtxType::F16), "f32" => Some(PtxType::F32),
306 "f64" => Some(PtxType::F64),
307 _ => None,
308 }
309}
310
311fn parse_variable(line: &str, space: PtxSpace) -> Option<PtxVariable> {
312 let tokens: Vec<&str> = line.split_whitespace().collect();
313 if tokens.len() < 3 { return None; }
314
315 let var_type = parse_type(tokens[1]).unwrap_or(PtxType::B32);
316 let name = tokens.last()?.trim_end_matches(';').to_string();
317
318 Some(PtxVariable {
319 name,
320 var_type,
321 space,
322 array_size: None,
323 alignment: None,
324 })
325}
326
327fn parse_function(lines: &[&str], start: usize, is_entry: bool) -> Result<(PtxFunction, usize)> {
328 let header = lines[start];
329 let name = extract_func_name(header);
330
331 let mut func = PtxFunction {
332 name,
333 params: Vec::new(),
334 registers: Vec::new(),
335 locals: Vec::new(),
336 body: Vec::new(),
337 is_entry,
338 };
339
340 let mut i = start + 1;
341 let mut in_body = false;
342 let mut brace_depth = if header.contains('{') { 1 } else { 0 };
343
344 if brace_depth > 0 { in_body = true; }
345
346 while i < lines.len() {
347 let line = lines[i];
348
349 if !in_body {
350 if line.contains('{') {
351 in_body = true;
352 brace_depth += line.matches('{').count();
353 brace_depth -= line.matches('}').count();
354 if brace_depth == 0 { return Ok((func, i)); }
355 i += 1;
356 continue;
357 }
358 if line.contains(".param") {
359 if let Some(var) = parse_variable(line, PtxSpace::Param) {
360 func.params.push(var);
361 }
362 }
363 } else {
364 brace_depth += line.matches('{').count();
365 brace_depth -= line.matches('}').count();
366
367 if brace_depth == 0 {
368 return Ok((func, i));
369 }
370
371 if line.contains(".reg") {
372 let tokens: Vec<&str> = line.split_whitespace().collect();
373 if tokens.len() >= 3 {
374 let reg_type = parse_type(tokens[1]).unwrap_or(PtxType::B32);
375 let name_part = tokens[2].trim_end_matches(';');
376 let (names, count) = if name_part.contains('<') {
378 let parts: Vec<&str> = name_part.split('<').collect();
379 let base = parts[0].to_string();
380 let cnt: u32 = parts.get(1)
381 .and_then(|s| s.trim_end_matches('>').parse().ok())
382 .unwrap_or(1);
383 (vec![base], Some(cnt))
384 } else {
385 (vec![name_part.to_string()], None)
386 };
387 func.registers.push(PtxRegDecl { reg_type, names, count });
388 }
389 } else if line.contains(".local") || line.contains(".shared") {
390 let space = if line.contains(".shared") { PtxSpace::Shared } else { PtxSpace::Local };
391 if let Some(var) = parse_variable(line, space) {
392 func.locals.push(var);
393 }
394 } else if !line.is_empty() && !line.starts_with("//") {
395 if let Some(stmt) = parse_statement(line) {
396 func.body.push(stmt);
397 }
398 }
399 }
400
401 i += 1;
402 }
403
404 Ok((func, lines.len() - 1))
405}
406
/// Extracts the function name from an `.entry` / `.func` header line.
///
/// Handles linkage prefixes (`.visible .entry foo(`) and `.func` headers that list
/// return parameters before the name (`.func (.param .b32 func_retval0) foo(`). The
/// name ends at whitespace, `(`, `{` or `;`. Returns `"unknown"` when no keyword or
/// name is present.
fn extract_func_name(line: &str) -> String {
    // Use whichever keyword occurs first in the line.
    let keyword = [".entry", ".func"]
        .iter()
        .filter_map(|kw| line.find(kw).map(|pos| (pos, kw.len())))
        .min();
    let (pos, len) = match keyword {
        Some(found) => found,
        None => return "unknown".to_string(),
    };

    let mut rest = line[pos + len..].trim_start();
    // Skip a `.func` return-parameter list; otherwise its `(` ends the name
    // immediately and yields an empty string.
    if rest.starts_with('(') {
        rest = match rest.find(')') {
            Some(close) => rest[close + 1..].trim_start(),
            None => "",
        };
    }

    let name = rest
        .split(|c: char| c.is_whitespace() || c == '(' || c == '{' || c == ';')
        .next()
        .unwrap_or("");
    if name.is_empty() {
        "unknown".to_string()
    } else {
        name.to_string()
    }
}
425
426fn parse_statement(line: &str) -> Option<PtxStatement> {
427 let trimmed = line.trim().trim_end_matches(';').trim();
428
429 if trimmed.ends_with(':') && !trimmed.starts_with('@') {
431 return Some(PtxStatement::Label(trimmed.trim_end_matches(':').to_string()));
432 }
433
434 let (predicate, rest) = if trimmed.starts_with('@') {
436 let parts: Vec<&str> = trimmed.splitn(2, char::is_whitespace).collect();
437 let pred_str = &parts[0][1..]; let negated = pred_str.starts_with('!');
439 let reg = if negated { &pred_str[1..] } else { pred_str }.to_string();
440 let rest = parts.get(1).unwrap_or(&"").trim();
441 (Some(PtxPredicate { register: reg, negated }), rest.to_string())
442 } else {
443 (None, trimmed.to_string())
444 };
445
446 let tokens: Vec<&str> = rest.split_whitespace().collect();
447 if tokens.is_empty() { return None; }
448
449 let opcode_full = tokens[0];
450 let opcode_parts: Vec<&str> = opcode_full.split('.').collect();
451 let opcode = opcode_parts[0].to_string();
452
453 let type_suffix = opcode_parts.iter().skip(1).find_map(|p| parse_type(p));
454 let modifiers: Vec<String> = opcode_parts.iter().skip(1)
455 .filter(|p| parse_type(p).is_none())
456 .map(|s| s.to_string())
457 .collect();
458
459 let operand_str = tokens[1..].join(" ");
460 let operands = parse_operands(&operand_str);
461
462 Some(PtxStatement::Instruction(PtxInstruction {
463 predicate,
464 opcode,
465 type_suffix,
466 modifiers,
467 operands,
468 }))
469}
470
471fn parse_operands(s: &str) -> Vec<PtxOperand> {
472 if s.is_empty() { return vec![]; }
473
474 s.split(',')
475 .map(|part| {
476 let t = part.trim();
477 if t.starts_with('%') {
478 let name = t.trim_start_matches('%');
479 if name.contains("tid.") || name.contains("ctaid.") || name.contains("ntid.")
480 || name.contains("nctaid.") || name == "laneid" || name == "warpid"
481 {
482 PtxOperand::SpecialReg(t.to_string())
483 } else {
484 PtxOperand::Register(t.to_string())
485 }
486 } else if t.starts_with('[') && t.ends_with(']') {
487 let inner = &t[1..t.len()-1];
488 if let Some(plus) = inner.find('+') {
489 let base = inner[..plus].trim().to_string();
490 let offset = inner[plus+1..].trim().parse().ok();
491 PtxOperand::Address { base, offset }
492 } else {
493 PtxOperand::Address { base: inner.trim().to_string(), offset: None }
494 }
495 } else if t.starts_with('{') {
496 let inner = t.trim_matches(|c| c == '{' || c == '}');
497 let regs: Vec<String> = inner.split(',').map(|r| r.trim().to_string()).collect();
498 PtxOperand::Vector(regs)
499 } else if let Ok(v) = t.parse::<i64>() {
500 PtxOperand::ImmInt(v)
501 } else if let Ok(v) = t.parse::<f64>() {
502 PtxOperand::ImmFloat(v)
503 } else {
504 PtxOperand::Label(t.to_string())
505 }
506 })
507 .collect()
508}
509
510fn ptx_type_to_ast(ty: &PtxType) -> crate::parser::ast::Type {
511 use crate::parser::ast::{Type, IntType, FloatType};
512 match ty {
513 PtxType::Pred => Type::Bool,
514 PtxType::B8 | PtxType::U8 => Type::Int(IntType::U8),
515 PtxType::B16 | PtxType::U16 => Type::Int(IntType::U16),
516 PtxType::B32 | PtxType::U32 => Type::Int(IntType::U32),
517 PtxType::B64 | PtxType::U64 => Type::Int(IntType::U64),
518 PtxType::S8 => Type::Int(IntType::I8),
519 PtxType::S16 => Type::Int(IntType::I16),
520 PtxType::S32 => Type::Int(IntType::I32),
521 PtxType::S64 => Type::Int(IntType::I64),
522 PtxType::F16 => Type::Float(FloatType::F16),
523 PtxType::F32 => Type::Float(FloatType::F32),
524 PtxType::F64 => Type::Float(FloatType::F64),
525 }
526}
527
528fn ptx_body_to_ast(stmts: &[PtxStatement]) -> Result<crate::parser::ast::Block> {
529 use crate::parser::ast::*;
530
531 let mut statements = Vec::new();
532
533 for stmt in stmts {
534 match stmt {
535 PtxStatement::Label(_) => {
536 }
538 PtxStatement::Instruction(inst) => {
539 let ast_stmt = ptx_instruction_to_ast(inst)?;
540 if let Some(s) = ast_stmt {
541 statements.push(s);
542 }
543 }
544 }
545 }
546
547 Ok(Block { statements })
548}
549
550fn ptx_instruction_to_ast(inst: &PtxInstruction) -> Result<Option<crate::parser::ast::Statement>> {
551 use crate::parser::ast::*;
552
553 match inst.opcode.as_str() {
554 "ret" => Ok(Some(Statement::Return(None))),
555 "bar" => Ok(Some(Statement::SyncThreads)),
556 "add" | "sub" | "mul" | "div" | "rem" | "and" | "or" | "xor" | "shl" | "shr" => {
557 if inst.operands.len() >= 3 {
558 let dst = operand_to_var(&inst.operands[0]);
559 let lhs = operand_to_expr(&inst.operands[1]);
560 let rhs = operand_to_expr(&inst.operands[2]);
561 let op = match inst.opcode.as_str() {
562 "add" => BinaryOp::Add, "sub" => BinaryOp::Sub,
563 "mul" => BinaryOp::Mul, "div" => BinaryOp::Div,
564 "rem" => BinaryOp::Mod, "and" => BinaryOp::And,
565 "or" => BinaryOp::Or, "xor" => BinaryOp::Xor,
566 "shl" => BinaryOp::Shl, "shr" => BinaryOp::Shr,
567 _ => BinaryOp::Add,
568 };
569 Ok(Some(Statement::Expr(Expression::Binary {
570 op: BinaryOp::Assign,
571 left: Box::new(Expression::Var(dst)),
572 right: Box::new(Expression::Binary {
573 op,
574 left: Box::new(lhs),
575 right: Box::new(rhs),
576 }),
577 })))
578 } else {
579 Ok(None)
580 }
581 }
582 "mov" => {
583 if inst.operands.len() >= 2 {
584 let dst = operand_to_var(&inst.operands[0]);
585 let src = operand_to_expr(&inst.operands[1]);
586 Ok(Some(Statement::Expr(Expression::Binary {
587 op: BinaryOp::Assign,
588 left: Box::new(Expression::Var(dst)),
589 right: Box::new(src),
590 })))
591 } else {
592 Ok(None)
593 }
594 }
595 "ld" => {
596 if inst.operands.len() >= 2 {
597 let dst = operand_to_var(&inst.operands[0]);
598 let src = operand_to_expr(&inst.operands[1]);
599 Ok(Some(Statement::Expr(Expression::Binary {
600 op: BinaryOp::Assign,
601 left: Box::new(Expression::Var(dst)),
602 right: Box::new(src),
603 })))
604 } else {
605 Ok(None)
606 }
607 }
608 "st" => {
609 if inst.operands.len() >= 2 {
610 let dst = operand_to_expr(&inst.operands[0]);
611 let src = operand_to_expr(&inst.operands[1]);
612 Ok(Some(Statement::Expr(Expression::Binary {
613 op: BinaryOp::Assign,
614 left: Box::new(dst),
615 right: Box::new(src),
616 })))
617 } else {
618 Ok(None)
619 }
620 }
621 _ => Ok(None), }
623}
624
625fn operand_to_var(op: &PtxOperand) -> String {
626 match op {
627 PtxOperand::Register(r) => clean_name(r),
628 PtxOperand::SpecialReg(r) => clean_name(r),
629 PtxOperand::Label(l) => l.clone(),
630 _ => "unknown".to_string(),
631 }
632}
633
634fn operand_to_expr(op: &PtxOperand) -> crate::parser::ast::Expression {
635 use crate::parser::ast::*;
636 match op {
637 PtxOperand::Register(r) => Expression::Var(clean_name(r)),
638 PtxOperand::SpecialReg(r) => {
639 let name = r.trim_start_matches('%');
640 match name {
641 "tid.x" => Expression::ThreadIdx(Dimension::X),
642 "tid.y" => Expression::ThreadIdx(Dimension::Y),
643 "tid.z" => Expression::ThreadIdx(Dimension::Z),
644 "ctaid.x" => Expression::BlockIdx(Dimension::X),
645 "ctaid.y" => Expression::BlockIdx(Dimension::Y),
646 "ctaid.z" => Expression::BlockIdx(Dimension::Z),
647 "ntid.x" => Expression::BlockDim(Dimension::X),
648 "ntid.y" => Expression::BlockDim(Dimension::Y),
649 "ntid.z" => Expression::BlockDim(Dimension::Z),
650 "nctaid.x" => Expression::GridDim(Dimension::X),
651 "nctaid.y" => Expression::GridDim(Dimension::Y),
652 "nctaid.z" => Expression::GridDim(Dimension::Z),
653 _ => Expression::Var(name.to_string()),
654 }
655 }
656 PtxOperand::ImmInt(v) => Expression::Literal(Literal::Int(*v)),
657 PtxOperand::ImmFloat(v) => Expression::Literal(Literal::Float(*v)),
658 PtxOperand::Address { base, offset } => {
659 let base_expr = Expression::Var(clean_name(base));
660 match offset {
661 Some(off) => Expression::Index {
662 array: Box::new(base_expr),
663 index: Box::new(Expression::Literal(Literal::Int(*off))),
664 },
665 None => base_expr,
666 }
667 }
668 PtxOperand::Label(l) => Expression::Var(l.clone()),
669 PtxOperand::Vector(regs) => {
670 Expression::Var(clean_name(regs.first().map(|s| s.as_str()).unwrap_or("v0")))
672 }
673 }
674}
675
/// Unit tests for the PTX parser and its lowering to the AST.
#[cfg(test)]
mod tests {
    use super::*;

    /// Header directives populate the module metadata fields.
    #[test]
    fn test_parse_version() {
        let ptx = ".version 7.8\n.target sm_80\n.address_size 64\n";
        let module = parse_ptx(ptx).unwrap();
        assert_eq!(module.version, "7.8");
        assert_eq!(module.target, "sm_80");
        assert_eq!(module.address_size, 64);
    }

    /// A multi-line `.entry` with one parameter per line yields a single kernel
    /// directive with its name, parameters and a non-empty body.
    #[test]
    fn test_parse_entry_function() {
        let ptx = r#"
.version 7.8
.target sm_80
.address_size 64

.visible .entry vectorAdd(
    .param .u64 a,
    .param .u64 b,
    .param .u64 c
)
{
    .reg .f32 %f<4>;
    .reg .u32 %r<4>;
    mov.u32 %r0, %tid.x;
    add.f32 %f2, %f0, %f1;
    ret;
}
"#;
        let module = parse_ptx(ptx).unwrap();
        assert_eq!(module.directives.len(), 1);
        match &module.directives[0] {
            PtxDirective::Entry(func) => {
                assert_eq!(func.name, "vectorAdd");
                assert_eq!(func.params.len(), 3);
                assert!(func.is_entry);
                assert!(!func.body.is_empty());
            }
            other => panic!("Expected entry directive, got {:?}", other),
        }
    }

    /// Dotted type suffixes map to `PtxType`; unknown names map to `None`.
    #[test]
    fn test_parse_type() {
        assert_eq!(parse_type(".f32"), Some(PtxType::F32));
        assert_eq!(parse_type(".s64"), Some(PtxType::S64));
        assert_eq!(parse_type(".pred"), Some(PtxType::Pred));
        assert_eq!(parse_type(".b16"), Some(PtxType::B16));
        assert_eq!(parse_type(".invalid"), None);
    }

    /// The opcode is split into base name and type suffix; operands are counted.
    #[test]
    fn test_parse_instruction_basic() {
        let stmt = parse_statement("add.f32 %f2, %f0, %f1;").unwrap();
        match stmt {
            PtxStatement::Instruction(inst) => {
                assert_eq!(inst.opcode, "add");
                assert_eq!(inst.type_suffix, Some(PtxType::F32));
                assert_eq!(inst.operands.len(), 3);
            }
            other => panic!("Expected instruction, got {:?}", other),
        }
    }

    /// An `@p` guard is captured as a non-negated predicate.
    #[test]
    fn test_parse_predicated_instruction() {
        let stmt = parse_statement("@p0 bra LOOP;").unwrap();
        match stmt {
            PtxStatement::Instruction(inst) => {
                assert!(inst.predicate.is_some());
                let pred = inst.predicate.unwrap();
                assert_eq!(pred.register, "p0");
                assert!(!pred.negated);
                assert_eq!(inst.opcode, "bra");
            }
            other => panic!("Expected instruction, got {:?}", other),
        }
    }

    /// An `@!p` guard is captured as a negated predicate without the `!`.
    #[test]
    fn test_parse_negated_predicate() {
        let stmt = parse_statement("@!p1 ret;").unwrap();
        match stmt {
            PtxStatement::Instruction(inst) => {
                let pred = inst.predicate.unwrap();
                assert_eq!(pred.register, "p1");
                assert!(pred.negated);
            }
            other => panic!("Expected instruction, got {:?}", other),
        }
    }

    /// A trailing colon marks a label, stored without the colon.
    #[test]
    fn test_parse_label() {
        let stmt = parse_statement("LOOP:").unwrap();
        match stmt {
            PtxStatement::Label(name) => assert_eq!(name, "LOOP"),
            other => panic!("Expected label, got {:?}", other),
        }
    }

    /// `%tid.x` is classified as a special register and keeps its `%` prefix.
    #[test]
    fn test_parse_special_registers() {
        let operands = parse_operands("%tid.x, %ctaid.y");
        assert_eq!(operands.len(), 2);
        match &operands[0] {
            PtxOperand::SpecialReg(r) => assert_eq!(r, "%tid.x"),
            other => panic!("Expected special register, got {:?}", other),
        }
    }

    /// `[base+offset]` is split into its base register and integer offset.
    #[test]
    fn test_parse_memory_address() {
        let operands = parse_operands("[%r0+4]");
        match &operands[0] {
            PtxOperand::Address { base, offset } => {
                assert_eq!(base, "%r0");
                assert_eq!(*offset, Some(4));
            }
            other => panic!("Expected address, got {:?}", other),
        }
    }

    /// A bare decimal literal becomes an integer immediate.
    #[test]
    fn test_parse_immediate() {
        let operands = parse_operands("42");
        match &operands[0] {
            PtxOperand::ImmInt(v) => assert_eq!(*v, 42),
            other => panic!("Expected immediate int, got {:?}", other),
        }
    }

    /// A module-scope `.global` declaration becomes a `GlobalVar` directive.
    #[test]
    fn test_parse_global_variable() {
        let ptx = ".version 7.8\n.target sm_80\n.address_size 64\n.global .f32 result;\n";
        let module = parse_ptx(ptx).unwrap();
        assert_eq!(module.directives.len(), 1);
        match &module.directives[0] {
            PtxDirective::GlobalVar(var) => {
                assert_eq!(var.var_type, PtxType::F32);
                assert_eq!(var.space, PtxSpace::Global);
            }
            other => panic!("Expected global var, got {:?}", other),
        }
    }

    /// End to end: an `.entry` function lowers to a kernel item of the same name.
    #[test]
    fn test_ptx_to_ast() {
        let ptx = r#"
.version 7.8
.target sm_80
.address_size 64

.visible .entry simple(
    .param .u64 data
)
{
    .reg .u32 %r<2>;
    mov.u32 %r0, %tid.x;
    ret;
}
"#;
        let module = parse_ptx(ptx).unwrap();
        let ast = ptx_to_ast(&module).unwrap();
        assert_eq!(ast.items.len(), 1);
        match &ast.items[0] {
            crate::parser::ast::Item::Kernel(k) => {
                assert_eq!(k.name, "simple");
            }
            other => panic!("Expected kernel, got {:?}", other),
        }
    }

    /// Spot-checks the PTX → AST type mapping for float, signed int and predicate.
    #[test]
    fn test_ptx_type_conversion() {
        use crate::parser::ast::{Type, IntType, FloatType};
        assert!(matches!(ptx_type_to_ast(&PtxType::F32), Type::Float(FloatType::F32)));
        assert!(matches!(ptx_type_to_ast(&PtxType::S32), Type::Int(IntType::I32)));
        assert!(matches!(ptx_type_to_ast(&PtxType::Pred), Type::Bool));
    }
}