1use codespan_reporting::{diagnostic::Diagnostic, files::SimpleFiles, term::termcolor::Buffer};
4use either::Either;
5use logos::Span;
6use std::path::Path;
7
8use crate::{lexer::Token, parser::Parser};
9
/// A parsed PTX module: its header information plus the module-level directives.
///
/// Built by [`Module::from_ptx`] (from source text) or [`Module::from_ptx_path`] (from a file).
#[derive(Debug)]
pub struct Module {
    /// PTX ISA version declared in the module header.
    pub version: Version,
    /// Target specification declared in the module header.
    pub target: Target,
    /// Address width declared in the module header.
    pub address_size: AddressSize,
    /// Module-level directives (functions, declarations, variables, ...); presumably kept in
    /// source order — confirm against the parser.
    pub directives: Vec<Directive>,
}
18
/// One element of a function body ([`Function::body`]) or of a nested block.
#[derive(Debug)]
pub enum Statement {
    /// An executable instruction.
    Instruction(Instruction),
    /// A nested group of statements (presumably a `{ ... }` scope in the source).
    Block(Vec<Statement>),
    /// A directive appearing inside a body, e.g. a local variable declaration.
    Directive(Directive),
    /// A label definition; holds the label's name.
    Label(String),
}
28
/// A single instruction: optional guard predicate, opcode with its modifiers, and operands.
#[derive(Debug)]
pub struct Instruction {
    /// Guard predicate controlling the instruction, if one was written.
    pub predicate: Option<Predicate>,
    /// The operation together with its type and modifier suffixes.
    pub opcode: Opcode,
    /// The instruction's operands as parsed.
    pub operands: Vec<Operand>,
    /// Source range covered by the instruction, usable for diagnostics.
    pub span: Span,
}
38
/// A directive, either at module level ([`Module::directives`]) or inside a body.
#[derive(Debug)]
pub enum Directive {
    /// A function definition.
    Function(Function),
    /// A function declaration; presumably a prototype whose `body` is empty — confirm.
    FunctionDecl(Function),
    /// A variable declaration.
    Variable(VariableDecl),
    /// A `.loc` directive; its operands are not retained in the AST.
    Loc,
    /// A `.section` directive; its contents are not retained in the AST.
    Section,
    /// A `.pragma` directive with its string argument.
    Pragma(String),
    /// A `.file` directive: file index and file name.
    File(usize, String),
}
49
/// A PTX ISA version, `major.minor`.
///
/// Ordering compares `major` first and then `minor`, so versions can be compared directly
/// (e.g. `version >= Version { major: 7, minor: 0 }`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Version {
    /// Major version number.
    pub major: u32,
    /// Minor version number.
    pub minor: u32,
}
55
/// The module's target specification, as written in the source (presumably the operand of the
/// `.target` directive — confirm against the parser).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Target(pub String);
58
/// Address width declared by the module.
///
/// Derives the value traits (`Clone`, `Copy`, `PartialEq`, `Eq`, `Hash`) like
/// [`PredicateOp`] does, so callers can copy and compare it freely.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AddressSize {
    /// 32-bit addresses.
    Bits32,
    /// 64-bit addresses.
    Bits64,
}
64
/// A variable declaration, at module level or inside a body.
#[derive(Debug)]
pub struct VariableDecl {
    /// Declared variable name.
    pub name: String,
    /// State-space qualifier, if one was written.
    pub state_space: Option<StateSpace>,
    /// Linking directive, if one was written.
    pub linkage: Option<LinkingDirective>,
    /// Element type.
    pub ty: Type,
    /// Alignment value, if one was written.
    pub alignment: Option<u32>,
    /// Vector width, if declared as a vector (presumably from `.v2`/`.v4` — confirm).
    pub vector: Option<u32>,
    /// Array length, if declared as an array; only a single dimension is representable.
    pub array: Option<u32>,
    /// Initializer, if one was written.
    pub init: Option<Operand>,
    /// Source range covered by the declaration.
    pub span: Span,
}
77
/// A PTX state space; each variant corresponds to the qualifier of the same name.
///
/// Derives the value traits like [`PredicateOp`], so it can be copied and compared.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum StateSpace {
    /// `.reg`
    Reg,
    /// `.sreg`
    SReg,
    /// `.const`
    Const,
    /// `.global`
    Global,
    /// `.local`
    Local,
    /// `.param`
    Param,
    /// `.shared`
    Shared,
    /// `.tex`
    Tex,
}
98
/// A linking directive; each variant corresponds to the directive of the same name.
///
/// Derives the value traits like [`PredicateOp`], so it can be copied and compared.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum LinkingDirective {
    /// `.extern`
    Extern,
    /// `.visible`
    Visible,
    /// `.weak`
    Weak,
    /// `.common`
    Common,
}
110
/// A fundamental PTX type; each variant corresponds to the type suffix of the same name
/// (e.g. `U32` is `.u32`, `F16x2` is `.f16x2`, `Pred` is `.pred`).
///
/// Derives the value traits like [`PredicateOp`], so types can be copied, compared, and used
/// as map keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Type {
    // Untyped bit-size types.
    B8,
    B16,
    B32,
    B64,
    // Signed integers.
    S8,
    S16,
    S32,
    S64,
    // Unsigned integers.
    U8,
    U16,
    U32,
    U64,
    // Floating point.
    F16,
    F16x2,
    F32,
    F64,
    Tf32,
    Bf16,
    // Predicate.
    Pred,
    // 128-bit untyped.
    B128,
}
134
/// A function definition or declaration.
#[derive(Debug)]
pub struct Function {
    /// Linking directive, if one was written.
    pub link_directive: Option<LinkingDirective>,
    /// Whether the function is an entry point (presumably `.entry` vs `.func` — confirm).
    pub entry: bool,
    /// Function name.
    pub name: String,
    /// Input parameters.
    pub parameters: Vec<Parameter>,
    /// Return parameters.
    pub return_params: Vec<Parameter>,
    /// Statements making up the body.
    pub body: Vec<Statement>,
}
144
/// A function parameter (input or return).
#[derive(Debug)]
pub struct Parameter {
    /// Parameter name.
    pub name: String,
    /// Parameter type.
    pub ty: Type,
    /// Array length, if declared as an array.
    pub array: Option<u32>,
    /// Whether the parameter was marked as a pointer (presumably `.ptr` — confirm).
    pub ptr: bool,
    /// State-space qualifier, if one was written.
    pub state_space: Option<StateSpace>,
    /// Alignment value, if one was written.
    pub alignment: Option<u32>,
}
154
/// The operation performed by an [`Instruction`], together with its type and modifiers.
///
/// Boolean fields such as `ftz`, `approx`, `saturate`/`sat` presumably record whether the
/// matching modifier suffix was present; `Option` fields hold the modifier when one was given.
/// Variants without payload (e.g. `Cp`, `Bra`, `Membar`) keep no modifiers in the AST.
#[derive(Debug)]
pub enum Opcode {
    Add {
        ty: Type,
        saturate: bool,
        // NOTE(review): unlike the other opcodes, this records only whether a rounding
        // modifier was present, not which one.
        rnd: bool,
        ftz: bool,
    },
    Sub {
        ty: Type,
        ftz: bool,
    },
    Clz(Type),
    Not(Type),
    Vote(Type),
    PopC(Type),
    Exit,
    Call(CallInst),
    // Carries no data.
    CallPrototype {},
    Abs {
        ty: Type,
        ftz: bool,
    },
    Ex2 {
        ty: Type,
        ftz: bool,
    },
    Rem(Type),
    Fma {
        rounding: Option<FloatRoundingMode>,
        ty: Type,
        ftz: bool,
    },
    Mad {
        mode: Option<MulMode>,
        rounding: Option<FloatRoundingMode>,
        ty: Type,
        ftz: bool,
        sat: bool,
    },
    Rcp {
        ty: Type,
        ftz: bool,
        approx: bool,
        rounding: Option<FloatRoundingMode>,
    },
    Sqrt {
        ty: Type,
        ftz: bool,
        approx: bool,
        rounding: Option<FloatRoundingMode>,
    },
    RSqrt {
        ty: Type,
        ftz: bool,
        approx: bool,
        rounding: Option<FloatRoundingMode>,
    },
    Sin {
        ty: Type,
        approx: bool,
        ftz: bool,
    },
    Cos {
        ty: Type,
        approx: bool,
        ftz: bool,
    },
    Neg {
        ty: Type,
        ftz: bool,
    },
    Max(Type),
    Min(Type),
    Shfl(Type),
    Bar {
        thread: u32,
        sync: bool,
    },
    Cvt {
        from: Type,
        to: Type,
        // Float or integer rounding mode, whichever was written.
        rounding: Option<Either<FloatRoundingMode, IntegerRoundingMode>>,
        saturate: bool,
    },
    Cvta {
        // Presumably `true` for the `.to` form (generic -> `state_space`) — confirm.
        to: bool,
        state_space: StateSpace,
        size: Type,
    },
    Cp,
    Ret,
    Mov(Type),
    Lg2 {
        ty: Type,
        ftz: bool,
    },
    Mul {
        ty: Type,
        mode: Option<MulMode>,
        rounding: Option<FloatRoundingMode>,
        ftz: bool,
        saturate: bool,
    },
    Mul24 {
        ty: Type,
        mode: Option<MulMode>,
    },
    Div {
        approx: bool,
        full: bool,
        rounding: Option<FloatRoundingMode>,
        ty: Type,
        ftz: bool,
    },
    Shl(Type),
    Shr(Type),
    Shf {
        direction: ShfDirection,
        mode: ShfMode,
    },
    // NOTE(review): these three appear to overlap with `Setp(PredicateOp, Type)`; confirm
    // whether the parser still produces them.
    SetpLt,
    SetpGt,
    SetpEq,
    Ld(Type),
    Ldu(Type),
    Tex(Type, Type),
    Slct {
        ftz: bool,
        dtype: Type,
        stype: Type,
    },
    Atom(Type),
    LdMatrix {
        shape: Shape2,
        ty: Type,
        // Number of matrices loaded (presumably from `.x1`/`.x2`/`.x4` — confirm).
        xnum: u32,
        shared: bool,
    },
    Mma {
        shape: Shape3,
        atype: Type,
        btype: Type,
        ctype: Type,
        dtype: Type,
    },
    St(Type),
    And(Type),
    Or(Type),
    XOr(Type),
    Set {
        cmp_op: PredicateOp,
        ftz: bool,
        dtype: Type,
        stype: Type,
    },
    Setp(PredicateOp, Type),
    Selp(Type),
    Bfe(Type),
    Bra,
    Membar,
    IsSpaceP(StateSpace),
}
330
/// Operands and modifiers of a `call` instruction ([`Opcode::Call`]).
#[derive(Debug)]
pub struct CallInst {
    /// Whether the call was marked uniform (presumably `.uni` — confirm).
    pub is_uniform: bool,
    /// Operand receiving the return value, if any.
    pub return_operand: Option<Operand>,
    /// Name of the called function.
    pub function: String,
    /// Call arguments.
    pub arguments: Vec<Operand>,
    /// Name of the prototype, if given (presumably for indirect calls — confirm).
    pub fproto: Option<String>,
}
339
/// An instruction operand.
#[derive(Debug)]
pub enum Operand {
    /// A plain register.
    Register(Register),
    /// A register plus a constant offset.
    RegisterOffset(Register, i64),
    /// A numeric literal.
    Constant(Constant),
    /// A memory address expression.
    Address(AddressOperand),
    /// A vector of operands.
    Vector(VectorOperand),
    /// A reference to a label by name.
    Label(String),
    /// A placeholder operand; presumably the `_` sink operand — confirm against the parser.
    PlaceHolder,
}
356
/// A register reference, carrying the source range where it appears.
#[derive(Debug)]
pub enum Register {
    /// A predefined special register.
    Special(SpecialReg, Span),
    /// A user-named register.
    Identifier(String, Span),
}
362
/// A predefined special register.
///
/// Each dimensioned register has a whole-vector variant (e.g. `ThreadId`) and per-component
/// variants (`ThreadIdX`, `ThreadIdY`, `ThreadIdZ`). The PTX names noted below are inferred
/// from the variant names — confirm against the lexer.
///
/// Derives the value traits like [`PredicateOp`], so it can be copied and compared.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SpecialReg {
    /// Stack pointer.
    StackPtr,
    /// Clock counter (presumably `%clock`).
    Clock,

    /// Thread index within the block (presumably `%tid`).
    ThreadId,
    ThreadIdX,
    ThreadIdY,
    ThreadIdZ,
    /// Block dimensions (presumably `%ntid`).
    BlockDim,
    BlockDimX,
    BlockDimY,
    BlockDimZ,
    /// Block index within the grid (presumably `%ctaid`).
    BlockIdx,
    BlockIdxX,
    BlockIdxY,
    BlockIdxZ,
    /// Grid dimensions (presumably `%nctaid`).
    GridDim,
    GridDimX,
    GridDimY,
    GridDimZ,
}
385
/// A numeric literal operand.
///
/// Derives `Clone`, `Copy` and `PartialEq`. `Eq` is deliberately omitted because `f64` is not
/// `Eq`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Constant {
    /// An integer literal.
    Integer(i64),
    /// A floating-point literal.
    Float(f64),
}
391
/// A memory address expression. The source forms shown are presumed from the variant names —
/// confirm against the parser.
#[derive(Debug)]
pub enum AddressOperand {
    /// A named address, presumably `[name]`.
    Address(String),
    /// A named address with a constant offset, presumably `[name+offset]`.
    AddressOffset(String, i64),
    /// An immediate address, presumably `[imm]`.
    Immediate(u32),
    /// A name indexed by a constant, presumably `name[index]`.
    ArrayIndex(String, usize),
    /// A name followed by a list of operands.
    List(String, Vec<Operand>),
}
405
/// A vector operand (presumably `{a, b, ...}` in the source — confirm).
#[derive(Debug)]
pub struct VectorOperand {
    /// The vector's elements.
    pub elements: Vec<Operand>,
}
410
/// A guard predicate on an [`Instruction`].
#[derive(Debug)]
pub struct Predicate {
    /// The predicate register tested.
    pub register: Register,
    /// Whether the predicate is negated.
    pub negated: bool,
}
416
/// A floating-point rounding modifier.
///
/// Derives the value traits like [`PredicateOp`], so it can be copied and compared.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum FloatRoundingMode {
    /// `.rn`: round to nearest, ties to even.
    Rn,
    /// `.rna`: round to nearest, ties away from zero.
    Rna,
    /// `.rz`: round towards zero.
    Rz,
    /// `.rm`: round towards negative infinity.
    Rm,
    /// `.rp`: round towards positive infinity.
    Rp,
}
430
/// A round-to-integer modifier (used by conversions, see [`Opcode::Cvt`]).
///
/// Derives the value traits like [`PredicateOp`], so it can be copied and compared.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum IntegerRoundingMode {
    /// `.rni`: round to nearest integer, ties to even.
    Rni,
    /// `.rzi`: round towards zero.
    Rzi,
    /// `.rmi`: round towards negative infinity.
    Rmi,
    /// `.rpi`: round towards positive infinity.
    Rpi,
}
442
/// Which part of a multiplication result is kept (`mul`, `mad`, `mul24`).
///
/// Derives the value traits like [`PredicateOp`], so it can be copied and compared.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MulMode {
    /// `.hi`: upper half of the product.
    Hi,
    /// `.lo`: lower half of the product.
    Lo,
    /// `.wide`: the full double-width product.
    Wide,
}
449
/// Funnel-shift direction ([`Opcode::Shf`]).
///
/// Derives the value traits like [`PredicateOp`], so it can be copied and compared.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ShfDirection {
    /// Shift left.
    Left,
    /// Shift right.
    Right,
}
455
/// Funnel-shift amount handling ([`Opcode::Shf`]).
///
/// Derives the value traits like [`PredicateOp`], so it can be copied and compared.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ShfMode {
    /// `.wrap`
    Wrap,
    /// `.clamp`
    Clamp,
}
461
/// Comparison operator used by `set`/`setp` ([`Opcode::Set`], [`Opcode::Setp`]).
///
/// NOTE(review): if the `*Unsigned` variants are produced from PTX `equ`/`neu`/`ltu`/`leu`/
/// `gtu`/`geu`, those are *unordered floating-point* comparisons in PTX, not unsigned ones
/// (unsigned integer comparisons are `lo`/`ls`/`hi`/`hs`). Confirm against the parser.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PredicateOp {
    LessThan,
    LessEqual,
    GreaterThan,
    GreaterEqual,
    Equal,
    NotEqual,
    EqualUnsigned,
    NotEqualUnsigned,
    LessThanUnsigned,
    LessEqualUnsigned,
    GreaterThanUnsigned,
    GreaterEqualUnsigned,
}
477
/// Matrix shape `MxNxK` for [`Opcode::Mma`]; e.g. `M16N8K8` is `.m16n8k8`.
///
/// Derives the value traits like [`PredicateOp`], so it can be copied and compared.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Shape3 {
    M16N8K8,
    M16N8K16,
    M16N8K32,
}
484
/// Matrix shape `MxN` for [`Opcode::LdMatrix`]; e.g. `M8N8` is `.m8n8`.
///
/// Derives the value traits like [`PredicateOp`], so it can be copied and compared.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Shape2 {
    M8N8,
    M8N16,
    M16N8,
    M16N16,
}
492
493impl Module {
494 pub fn from_ptx(source: &str, file_id: usize) -> Result<Self, Diagnostic<usize>> {
495 let mut parser = Parser::new(file_id, source);
496 parser.parse()
497 }
498
499 pub fn from_ptx_path(path: &Path) -> Result<Self, String> {
500 let mut files = SimpleFiles::new();
501 let content = std::fs::read_to_string(path).expect("failed to read file");
502 let file_id = files.add(
503 path.to_str().expect("failed to convert path to str"),
504 &content,
505 );
506 Self::from_ptx(&content, file_id).map_err(|diagnostic| emit_string(diagnostic, files))
507 }
508}
509
510fn emit_string(diagnostic: Diagnostic<usize>, files: SimpleFiles<&str, &String>) -> String {
511 let mut buffer = Buffer::ansi();
512 let config = codespan_reporting::term::Config::default();
513 codespan_reporting::term::emit(&mut buffer, &config, &files, &diagnostic).unwrap();
514 String::from_utf8(buffer.into_inner()).unwrap()
515}
516
517impl From<Token> for MulMode {
518 fn from(token: Token) -> Self {
519 match token {
520 Token::Hi => MulMode::Hi,
521 Token::Lo => MulMode::Lo,
522 Token::Wide => MulMode::Wide,
523 _ => unreachable!(),
524 }
525 }
526}
527
528impl From<Token> for FloatRoundingMode {
529 fn from(token: Token) -> Self {
530 match token {
531 Token::Rn => FloatRoundingMode::Rn,
532 Token::Rna => FloatRoundingMode::Rna,
533 Token::Rz => FloatRoundingMode::Rz,
534 Token::Rm => FloatRoundingMode::Rm,
535 Token::Rp => FloatRoundingMode::Rp,
536 _ => unreachable!(),
537 }
538 }
539}
540
541impl From<Token> for IntegerRoundingMode {
542 fn from(token: Token) -> Self {
543 match token {
544 Token::Rni => IntegerRoundingMode::Rni,
545 Token::Rzi => IntegerRoundingMode::Rzi,
546 Token::Rmi => IntegerRoundingMode::Rmi,
547 Token::Rpi => IntegerRoundingMode::Rpi,
548 _ => unreachable!(),
549 }
550 }
551}