1mod ast;
6mod error;
7mod lexer;
8pub mod types;
9
10pub use ast::{
11 Directive, FunctionDef, GlobalDecl, Instruction, KernelDef, Operand, Param, Predicate,
12 PtxModule, RegisterDecl, SharedMemDecl, SourceLocation, Statement,
13};
14pub use error::ParseError;
15pub use lexer::{Lexer, Token, TokenKind};
16pub use types::{AddressSpace, Modifier, Opcode, PtxType, SmTarget};
17
/// Picks a value based on which substring group first matches a piece of text.
///
/// `match_contains!(text, default, [p1, p2] => v1, [p3] => v2, ...)` evaluates
/// `text` once and walks the groups in the order written. The first group in
/// which *any* pattern satisfies `text.contains(pattern)` selects its value;
/// when no group matches, the result is `default`.
macro_rules! match_contains {
    ($text:expr, $default:expr, $( [ $( $pattern:expr ),+ ] => $value:expr ),+ $(,)?) => {{
        let __haystack = $text;
        // A guard-only `match` keeps the first-match-wins ordering of the groups.
        match () {
            $(
                () if $( __haystack.contains($pattern) )||+ => $value,
            )+
            () => $default,
        }
    }};
}
38
/// Maps a string onto a value through exact, literal comparisons.
///
/// `match_str_lookup!(text, default, "a" => v1, "b" => v2, ...)` expands to a
/// plain `match` on `text` with one arm per key and a `_ => default` fallback.
///
/// Keys are spliced into pattern position, so they are captured with the
/// `literal` fragment specifier: anything that is not a literal is rejected at
/// the macro call site instead of producing an invalid pattern.
macro_rules! match_str_lookup {
    ($text:expr, $default:expr, $( $lit:literal => $value:expr ),+ $(,)?) => {
        match $text {
            $( $lit => $value, )+
            _ => $default,
        }
    };
}
51
52pub struct Parser<'a> {
54 lexer: Lexer<'a>,
55 current: Token,
56 peek: Token,
57}
58
59impl<'a> Parser<'a> {
60 pub fn new(source: &'a str) -> Result<Self, ParseError> {
62 let mut lexer = Lexer::new(source);
63 let current = lexer.next_token()?;
64 let peek = lexer.next_token()?;
65 Ok(Self {
66 lexer,
67 current,
68 peek,
69 })
70 }
71
72 pub fn parse(&mut self) -> Result<PtxModule, ParseError> {
74 let mut module = PtxModule::default();
75
76 while self.current.kind != TokenKind::Eof {
77 match self.current.kind {
78 TokenKind::Directive => {
79 self.parse_directive(&mut module)?;
80 }
81 TokenKind::Entry | TokenKind::Func => {
82 let kernel = self.parse_kernel()?;
83 module.kernels.push(kernel);
84 }
85 _ => {
86 self.advance()?;
87 }
88 }
89 }
90
91 Ok(module)
92 }
93
94 fn advance(&mut self) -> Result<(), ParseError> {
95 self.current = std::mem::replace(&mut self.peek, self.lexer.next_token()?);
96 Ok(())
97 }
98
99 fn parse_directive(&mut self, module: &mut PtxModule) -> Result<(), ParseError> {
100 let directive_text = self.current.text.clone();
101
102 if directive_text.starts_with(".version") {
103 module.version = self.parse_version_directive()?;
104 } else if directive_text.starts_with(".target") {
105 module.target = self.parse_target_directive()?;
106 } else if directive_text.starts_with(".address_size") {
107 module.address_size = self.parse_address_size_directive()?;
108 }
109
110 self.advance()?;
111 Ok(())
112 }
113
114 fn parse_version_directive(&self) -> Result<(u8, u8), ParseError> {
115 let text = &self.current.text;
117 if let Some(rest) = text.strip_prefix(".version") {
118 let version_str = rest.trim();
119 let parts: Vec<&str> = version_str.split('.').collect();
120 if parts.len() >= 2 {
121 let major = parts[0].parse().unwrap_or(0);
122 let minor = parts[1].parse().unwrap_or(0);
123 return Ok((major, minor));
124 }
125 }
126 Ok((0, 0))
127 }
128
129 fn parse_target_directive(&self) -> Result<SmTarget, ParseError> {
130 let text = &self.current.text;
131 Ok(match_contains!(text, SmTarget::Unknown,
132 ["sm_70"] => SmTarget::Sm70,
133 ["sm_75"] => SmTarget::Sm75,
134 ["sm_80"] => SmTarget::Sm80,
135 ["sm_86"] => SmTarget::Sm86,
136 ["sm_89"] => SmTarget::Sm89,
137 ["sm_90"] => SmTarget::Sm90,
138 ))
139 }
140
141 fn parse_address_size_directive(&self) -> Result<u8, ParseError> {
142 let text = &self.current.text;
143 Ok(match_contains!(text, 0,
144 ["64"] => 64,
145 ["32"] => 32,
146 ))
147 }
148
149 fn parse_kernel(&mut self) -> Result<KernelDef, ParseError> {
150 let is_entry = self.current.kind == TokenKind::Entry;
151 self.advance()?;
152
153 let name = if self.current.kind == TokenKind::Identifier {
155 let n = self.current.text.clone();
156 self.advance()?;
157 n
158 } else {
159 return Err(ParseError::UnexpectedToken {
160 expected: "kernel name".into(),
161 found: format!("{:?}", self.current.kind),
162 location: self.current.location.clone(),
163 });
164 };
165
166 let mut kernel = KernelDef {
167 name,
168 is_entry,
169 params: Vec::new(),
170 registers: Vec::new(),
171 shared_mem: Vec::new(),
172 body: Vec::new(),
173 };
174
175 while self.current.kind != TokenKind::Eof {
177 if self.current.kind == TokenKind::Entry || self.current.kind == TokenKind::Func {
178 break;
179 }
180
181 match self.current.kind {
182 TokenKind::Reg => {
183 let reg = self.parse_register_decl()?;
184 kernel.registers.push(reg);
185 }
186 TokenKind::Shared => {
187 let shared = self.parse_shared_mem_decl()?;
188 kernel.shared_mem.push(shared);
189 }
190 TokenKind::Instruction => {
191 let instr = self.parse_instruction()?;
192 kernel.body.push(Statement::Instruction(instr));
193 }
194 TokenKind::Label => {
195 let label = self.current.text.trim_end_matches(':').to_string();
196 kernel.body.push(Statement::Label(label));
197 self.advance()?;
198 }
199 _ => {
200 self.advance()?;
201 }
202 }
203 }
204
205 Ok(kernel)
206 }
207
208 fn parse_register_decl(&mut self) -> Result<RegisterDecl, ParseError> {
209 let text = self.current.text.clone();
211 self.advance()?;
212
213 let (ty, name) = self.extract_reg_type_and_name(&text);
215
216 Ok(RegisterDecl { name, ty })
217 }
218
219 fn extract_reg_type_and_name(&self, text: &str) -> (PtxType, String) {
220 let ty = match_contains!(text, PtxType::B32,
221 [".b64", ".u64"] => PtxType::B64,
222 [".b32", ".u32"] => PtxType::U32,
223 [".f32"] => PtxType::F32,
224 [".f64"] => PtxType::F64,
225 [".pred"] => PtxType::Pred,
226 );
227
228 let name = text
230 .split_whitespace()
231 .find(|s| s.starts_with('%'))
232 .map(|s| s.trim_end_matches(|c: char| !c.is_alphanumeric() && c != '%' && c != '_'))
233 .unwrap_or("%unknown")
234 .to_string();
235
236 (ty, name)
237 }
238
239 fn parse_shared_mem_decl(&mut self) -> Result<SharedMemDecl, ParseError> {
240 let text = self.current.text.clone();
241 self.advance()?;
242
243 let name = text
245 .split_whitespace()
246 .find(|s| !s.starts_with('.'))
247 .unwrap_or("unknown")
248 .to_string();
249
250 let size = text
251 .split('[')
252 .nth(1)
253 .and_then(|s| s.split(']').next())
254 .and_then(|s| s.parse().ok())
255 .unwrap_or(0);
256
257 Ok(SharedMemDecl {
258 name,
259 size,
260 ty: PtxType::B8,
261 })
262 }
263
264 fn parse_instruction(&mut self) -> Result<Instruction, ParseError> {
265 let text = self.current.text.clone();
266 let location = self.current.location.clone();
267 self.advance()?;
268
269 let (opcode, modifiers) = self.parse_opcode(&text);
270 let operands = self.parse_operands(&text);
271
272 Ok(Instruction {
273 opcode,
274 modifiers,
275 operands,
276 predicate: None,
277 location,
278 })
279 }
280
281 fn parse_opcode(&self, text: &str) -> (Opcode, Vec<Modifier>) {
282 let parts: Vec<&str> = text
283 .split_whitespace()
284 .next()
285 .unwrap_or("")
286 .split('.')
287 .collect();
288
289 let opcode = match parts.first().copied() {
290 Some(s) => match_str_lookup!(s, Opcode::Unknown,
291 "ld" => Opcode::Ld,
292 "st" => Opcode::St,
293 "mov" => Opcode::Mov,
294 "add" => Opcode::Add,
295 "sub" => Opcode::Sub,
296 "mul" => Opcode::Mul,
297 "mad" => Opcode::Mad,
298 "fma" => Opcode::Fma,
299 "cvta" => Opcode::Cvta,
300 "cvt" => Opcode::Cvt,
301 "setp" => Opcode::Setp,
302 "bra" => Opcode::Bra,
303 "bar" => Opcode::Bar,
304 "atom" => Opcode::Atom,
305 "ret" => Opcode::Ret,
306 "exit" => Opcode::Exit,
307 "and" => Opcode::And,
308 "or" => Opcode::Or,
309 "xor" => Opcode::Xor,
310 "shl" => Opcode::Shl,
311 "shr" => Opcode::Shr,
312 "membar" => Opcode::MemBar,
313 ),
314 None => Opcode::Unknown,
315 };
316
317 let modifiers = parts
318 .iter()
319 .skip(1)
320 .map(|&m| self.parse_modifier(m))
321 .collect();
322
323 (opcode, modifiers)
324 }
325
326 fn parse_modifier(&self, s: &str) -> Modifier {
327 match_str_lookup!(s, Modifier::Other(s.to_string()),
328 "shared" => Modifier::Shared,
329 "global" => Modifier::Global,
330 "local" => Modifier::Local,
331 "const" => Modifier::Const,
332 "param" => Modifier::Param,
333 "u32" => Modifier::U32,
334 "u64" => Modifier::U64,
335 "s32" => Modifier::S32,
336 "s64" => Modifier::S64,
337 "f32" => Modifier::F32,
338 "f64" => Modifier::F64,
339 "b32" => Modifier::B32,
340 "b64" => Modifier::B64,
341 "sync" => Modifier::Sync,
342 "cta" => Modifier::Cta,
343 "gl" => Modifier::Gl,
344 "add" => Modifier::AtomicAdd,
345 "cas" => Modifier::AtomicCas,
346 )
347 }
348
349 fn parse_operands(&self, text: &str) -> Vec<Operand> {
350 let mut operands = Vec::new();
352
353 let operand_part = text
355 .split_whitespace()
356 .skip(1)
357 .collect::<Vec<_>>()
358 .join(" ");
359
360 for part in operand_part.split(',') {
361 let trimmed = part.trim();
362 if trimmed.is_empty() {
363 continue;
364 }
365
366 let operand = if trimmed.starts_with('%') {
367 Operand::Register(trimmed.to_string())
368 } else if trimmed.starts_with('[') {
369 Operand::Memory(trimmed.to_string())
370 } else if trimmed.parse::<i64>().is_ok() {
371 Operand::Immediate(trimmed.parse().unwrap_or(0))
372 } else {
373 Operand::Label(trimmed.to_string())
374 };
375 operands.push(operand);
376 }
377
378 operands
379 }
380}
381
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs the full parser over `ptx`, panicking if either stage fails.
    fn parse_ok(ptx: &str) -> PtxModule {
        Parser::new(ptx)
            .expect("parser creation should succeed")
            .parse()
            .expect("parsing should succeed")
    }

    /// F001: a module with a `.version` directive reports a non-zero major version.
    #[test]
    fn f001_version_directive_present() {
        let module = parse_ok(
            r#"
            .version 8.0
            .target sm_70
            .address_size 64
        "#,
        );
        assert!(module.version.0 > 0, "F001: Missing .version directive");
    }

    /// F001: a module without a `.version` directive reports `(0, 0)`.
    #[test]
    fn f001_version_directive_missing() {
        let module = parse_ok(
            r#"
            .target sm_70
            .address_size 64
        "#,
        );
        assert_eq!(module.version, (0, 0), "Should detect missing version");
    }

    /// F002: a recognised `.target` directive does not resolve to `Unknown`.
    #[test]
    fn f002_target_directive_present() {
        let module = parse_ok(
            r#"
            .version 8.0
            .target sm_70
            .address_size 64
        "#,
        );
        assert_ne!(
            module.target,
            SmTarget::Unknown,
            "F002: Missing .target directive"
        );
    }

    /// F003: the parsed address size is one of the two legal widths.
    #[test]
    fn f003_address_size_valid() {
        let module = parse_ok(
            r#"
            .version 8.0
            .target sm_70
            .address_size 64
        "#,
        );
        let width = module.address_size;
        assert!(
            width == 32 || width == 64,
            "F003: address_size must be 32 or 64"
        );
    }

    /// A small but complete kernel yields the expected header and a kernel entry.
    #[test]
    fn parse_simple_kernel() {
        let module = parse_ok(
            r#"
            .version 8.0
            .target sm_70
            .address_size 64

            .entry test_kernel(
                .param .u64 param0
            )
            {
                .reg .u32 %r<10>;
                .reg .u64 %rd<10>;

                mov.u32 %r0, 0;
                ret;
            }
        "#,
        );

        assert_eq!(module.version, (8, 0));
        assert_eq!(module.target, SmTarget::Sm70);
        assert_eq!(module.address_size, 64);
        assert!(
            !module.kernels.is_empty(),
            "Should have at least one kernel"
        );
    }

    /// Load and store instructions are both recognised inside a kernel body.
    #[test]
    fn parse_ld_st_instructions() {
        let module = parse_ok(
            r#"
            .version 8.0
            .target sm_70
            .address_size 64

            .entry test(
            )
            {
                .reg .u32 %r<10>;

                ld.shared.u32 %r0, [%r1];
                st.global.u32 [%r2], %r0;
                ret;
            }
        "#,
        );

        let kernel = &module.kernels[0];
        let mut saw_ld = false;
        let mut saw_st = false;
        for statement in &kernel.body {
            if let Statement::Instruction(instr) = statement {
                saw_ld |= matches!(instr.opcode, Opcode::Ld);
                saw_st |= matches!(instr.opcode, Opcode::St);
            }
        }

        assert!(saw_ld);
        assert!(saw_st);
    }
}