ptx_ir/
ir.rs

1//! Intermediate representation of PTX code.
2
3use codespan_reporting::{diagnostic::Diagnostic, files::SimpleFiles, term::termcolor::Buffer};
4use either::Either;
5use logos::Span;
6use std::path::Path;
7
8use crate::{lexer::Token, parser::Parser};
9
10/// PTX module, the top-level structure in a PTX file.
11#[derive(Debug)]
12pub struct Module {
13    pub version: Version,
14    pub target: Target,
15    pub address_size: AddressSize,
16    pub directives: Vec<Directive>,
17}
18
19/// A PTX statement is either a directive or an instruction.
20/// Statements begin with an optional label and end with a semicolon.
21#[derive(Debug)]
22pub enum Statement {
23    Instruction(Instruction),
24    Block(Vec<Statement>),
25    Directive(Directive),
26    Label(String),
27}
28
29/// PTX instructions generally have from zero to four operands,
30/// plus an optional guard predicate appearing after an @ symbol to the left of the opcode:
31#[derive(Debug)]
32pub struct Instruction {
33    pub predicate: Option<Predicate>,
34    pub opcode: Opcode,
35    pub operands: Vec<Operand>,
36    pub span: Span,
37}
38
39#[derive(Debug)]
40pub enum Directive {
41    Function(Function),
42    FunctionDecl(Function),
43    Variable(VariableDecl),
44    Loc,
45    Section,
46    Pragma(String),
47    File(usize, String),
48}
49
/// A version number split into its `major` and `minor` components.
#[derive(Debug)]
pub struct Version {
    /// Major component of the version.
    pub major: u32,
    /// Minor component of the version.
    pub minor: u32,
}
55
/// Target of a PTX module, kept as the raw string.
#[derive(Debug)]
pub struct Target(pub String);
58
/// Address size of a module: either 32 or 64 bits.
#[derive(Debug)]
pub enum AddressSize {
    /// 32-bit addresses.
    Bits32,
    /// 64-bit addresses.
    Bits64,
}
64
65#[derive(Debug)]
66pub struct VariableDecl {
67    pub name: String,
68    pub state_space: Option<StateSpace>,
69    pub linkage: Option<LinkingDirective>,
70    pub ty: Type,
71    pub alignment: Option<u32>,
72    pub vector: Option<u32>,
73    pub array: Option<u32>,
74    pub init: Option<Operand>,
75    pub span: Span,
76}
77
/// PTX state spaces.
#[derive(Debug)]
pub enum StateSpace {
    /// Registers; fast.
    Reg,
    /// Special registers: read-only, pre-defined and platform-specific.
    SReg,
    /// Shared, read-only memory.
    Const,
    /// Global memory, shared by all threads.
    Global,
    /// Local memory, private to each thread.
    Local,
    /// Kernel parameters (defined per-grid), or function or local parameters
    /// (defined per-thread).
    Param,
    /// Addressable memory defined per CTA, accessible to all threads in the
    /// cluster throughout the lifetime of the CTA that defines it.
    Shared,
    /// Global texture memory (deprecated).
    Tex,
}
98
/// PTX linking directives.
#[derive(Debug)]
pub enum LinkingDirective {
    /// External symbol declaration.
    Extern,
    /// Visible (externally) symbol declaration.
    Visible,
    /// Visible (externally) symbol declaration, marked weak: per the PTX ISA,
    /// weak symbols are only chosen after globally visible symbols during
    /// symbol resolution.
    Weak,
    /// Visible (externally) symbol declaration, marked common: per the PTX
    /// ISA, several object files may declare the same common symbol.
    Common,
}
110
/// Type specifiers known to this IR.
///
/// Variant names mirror the PTX type names: `B*` are untyped bit types,
/// `S*`/`U*` signed and unsigned integers, `F*` floating-point types, and
/// `Pred` is the predicate type. The number is the width in bits.
#[derive(Debug)]
pub enum Type {
    B8,
    B16,
    B32,
    B64,
    S8,
    S16,
    S32,
    S64,
    U8,
    U16,
    U32,
    U64,
    F16,
    /// Two packed `f16` values.
    F16x2,
    F32,
    F64,
    Tf32,
    Bf16,
    Pred,
    B128,
}
134
135#[derive(Debug)]
136pub struct Function {
137    pub link_directive: Option<LinkingDirective>,
138    pub entry: bool,
139    pub name: String,
140    pub parameters: Vec<Parameter>,
141    pub return_params: Vec<Parameter>,
142    pub body: Vec<Statement>,
143}
144
145#[derive(Debug)]
146pub struct Parameter {
147    pub name: String,
148    pub ty: Type,
149    pub array: Option<u32>,
150    pub ptr: bool,
151    pub state_space: Option<StateSpace>,
152    pub alignment: Option<u32>,
153}
154
155#[derive(Debug)]
156pub enum Opcode {
157    Add {
158        ty: Type,
159        saturate: bool,
160        rnd: bool,
161        ftz: bool,
162    },
163    Sub {
164        ty: Type,
165        ftz: bool,
166    },
167    /// Count leading zeros
168    Clz(Type),
169    Not(Type),
170    /// Vote across thread group, Deprecated
171    Vote(Type),
172    /// Population count.
173    PopC(Type),
174    Exit,
175    Call(CallInst),
176    // CallPrototype {
177    //     parameters: Vec<Parameter>,
178    //     return_params: Vec<Parameter>,
179    //     no_return: bool,
180    // },
181    CallPrototype {},
182    Abs {
183        ty: Type,
184        ftz: bool,
185    },
186    Ex2 {
187        ty: Type,
188        /// flush-to-zero
189        ftz: bool,
190    },
191    Rem(Type),
192    Fma {
193        rounding: Option<FloatRoundingMode>,
194        ty: Type,
195        ftz: bool,
196    },
197    Mad {
198        mode: Option<MulMode>,
199        rounding: Option<FloatRoundingMode>,
200        ty: Type,
201        ftz: bool,
202        sat: bool,
203    },
204    Rcp {
205        ty: Type,
206        ftz: bool,
207        approx: bool,
208        rounding: Option<FloatRoundingMode>,
209    },
210    Sqrt {
211        ty: Type,
212        ftz: bool,
213        approx: bool,
214        rounding: Option<FloatRoundingMode>,
215    },
216    RSqrt {
217        ty: Type,
218        ftz: bool,
219        approx: bool,
220        rounding: Option<FloatRoundingMode>,
221    },
222    Sin {
223        ty: Type,
224        approx: bool,
225        ftz: bool,
226    },
227    Cos {
228        ty: Type,
229        approx: bool,
230        ftz: bool,
231    },
232    Neg {
233        ty: Type,
234        ftz: bool,
235    },
236    Max(Type),
237    Min(Type),
238    Shfl(Type),
239    Bar {
240        thread: u32,
241        sync: bool,
242    },
243    Cvt {
244        from: Type,
245        to: Type,
246        rounding: Option<Either<FloatRoundingMode, IntegerRoundingMode>>,
247        saturate: bool,
248    },
249    Cvta {
250        to: bool,
251        state_space: StateSpace,
252        size: Type,
253    },
254    Cp,
255    Ret,
256    Mov(Type),
257    Lg2 {
258        ty: Type,
259        ftz: bool,
260    },
261    Mul {
262        ty: Type,
263        mode: Option<MulMode>,
264        rounding: Option<FloatRoundingMode>,
265        ftz: bool,
266        saturate: bool,
267    },
268    Mul24 {
269        ty: Type,
270        mode: Option<MulMode>,
271    },
272    Div {
273        approx: bool,
274        full: bool,
275        rounding: Option<FloatRoundingMode>,
276        ty: Type,
277        ftz: bool,
278    },
279    Shl(Type),
280    Shr(Type),
281    Shf {
282        direction: ShfDirection,
283        mode: ShfMode,
284    },
285    SetpLt,
286    SetpGt,
287    SetpEq,
288    Ld(Type),
289    Ldu(Type),
290    Tex(Type, Type),
291    /// Select one source operand, based on the sign of the third operand.
292    Slct {
293        ftz: bool,
294        dtype: Type,
295        stype: Type,
296    },
297    /// Atomic reduction operations for thread-to-thread communication.
298    Atom(Type),
299    LdMatrix {
300        shape: Shape2,
301        ty: Type,
302        xnum: u32,
303        shared: bool,
304    },
305    Mma {
306        shape: Shape3,
307        atype: Type,
308        btype: Type,
309        ctype: Type,
310        dtype: Type,
311    },
312    St(Type),
313    And(Type),
314    Or(Type),
315    XOr(Type),
316    Set {
317        cmp_op: PredicateOp,
318        ftz: bool,
319        dtype: Type,
320        stype: Type,
321    },
322    Setp(PredicateOp, Type),
323    Selp(Type),
324    Bfe(Type),
325    Bra,
326    Membar,
327    /// Query whether a generic address falls within a specified state space window
328    IsSpaceP(StateSpace),
329}
330
331#[derive(Debug)]
332pub struct CallInst {
333    pub is_uniform: bool,
334    pub return_operand: Option<Operand>,
335    pub function: String,
336    pub arguments: Vec<Operand>,
337    pub fproto: Option<String>,
338}
339
340/// Operands may be
341/// - register variables,
342/// - constant expressions,
343/// - address expressions,
344/// - label names.
345/// - place holders
346#[derive(Debug)]
347pub enum Operand {
348    Register(Register),
349    RegisterOffset(Register, i64),
350    Constant(Constant),
351    Address(AddressOperand),
352    Vector(VectorOperand),
353    Label(String),
354    PlaceHolder,
355}
356
357#[derive(Debug)]
358pub enum Register {
359    Special(SpecialReg, Span),
360    Identifier(String, Span),
361}
362
/// Special registers recognised by this IR.
///
/// Judging by the names, the grouped variants cover the thread index, block
/// dimensions, block index and grid dimensions, each as a whole and per
/// `X`/`Y`/`Z` component. The mapping to concrete PTX register names lives
/// outside this type.
#[derive(Debug)]
pub enum SpecialReg {
    StackPtr,
    Clock,

    ThreadId,
    ThreadIdX,
    ThreadIdY,
    ThreadIdZ,
    BlockDim,
    BlockDimX,
    BlockDimY,
    BlockDimZ,
    BlockIdx,
    BlockIdxX,
    BlockIdxY,
    BlockIdxZ,
    GridDim,
    GridDimX,
    GridDimY,
    GridDimZ,
}
385
/// A constant operand value.
#[derive(Debug)]
pub enum Constant {
    /// Integer constant, stored as `i64`.
    Integer(i64),
    /// Floating-point constant, stored as `f64`.
    Float(f64),
}
391
392#[derive(Debug)]
393pub enum AddressOperand {
394    /// the name of an addressable variable var.
395    Address(String),
396    /// a sum of register reg containing a byte address plus a constant integer byte offset (signed, 32-bit).
397    AddressOffset(String, i64),
398    /// an immediate absolute byte address (unsigned, 32-bit).
399    Immediate(u32),
400    /// an array element
401    ArrayIndex(String, usize),
402    /// List
403    List(String, Vec<Operand>),
404}
405
406#[derive(Debug)]
407pub struct VectorOperand {
408    pub elements: Vec<Operand>,
409}
410
411#[derive(Debug)]
412pub struct Predicate {
413    pub register: Register,
414    pub negated: bool,
415}
416
/// Rounding modes for floating-point results.
#[derive(Debug)]
pub enum FloatRoundingMode {
    /// Round to nearest even.
    Rn,
    /// Round to nearest, with ties away from zero.
    Rna,
    /// Round towards zero.
    Rz,
    /// Round towards -∞.
    Rm,
    /// Round towards +∞.
    Rp,
}
430
/// Rounding modes that round to an integer.
#[derive(Debug)]
pub enum IntegerRoundingMode {
    /// Round to the nearest integer, picking the even integer when the source
    /// is equidistant between two integers.
    Rni,
    /// Round to the nearest integer in the direction of zero.
    Rzi,
    /// Round to the nearest integer in the direction of negative infinity.
    Rmi,
    /// Round to the nearest integer in the direction of positive infinity.
    Rpi,
}
442
/// Mode modifier recorded for the multiply-family opcodes (`Mul`, `Mul24`,
/// `Mad`).
#[derive(Debug)]
pub enum MulMode {
    /// `hi` mode.
    Hi,
    /// `lo` mode.
    Lo,
    /// `wide` mode.
    Wide,
}
449
/// Direction modifier of [`Opcode::Shf`].
#[derive(Debug)]
pub enum ShfDirection {
    /// Shift to the left.
    Left,
    /// Shift to the right.
    Right,
}
455
/// Mode modifier of [`Opcode::Shf`].
#[derive(Debug)]
pub enum ShfMode {
    /// `wrap` mode.
    Wrap,
    /// `clamp` mode.
    Clamp,
}
461
/// Comparison operators used by [`Opcode::Set`] and [`Opcode::Setp`].
///
/// Each of the six basic comparisons also has an `…Unsigned` counterpart.
#[derive(Debug, Clone, Copy)]
pub enum PredicateOp {
    LessThan,
    LessEqual,
    GreaterThan,
    GreaterEqual,
    Equal,
    NotEqual,
    EqualUnsigned,
    NotEqualUnsigned,
    LessThanUnsigned,
    LessEqualUnsigned,
    GreaterThanUnsigned,
    GreaterEqualUnsigned,
}
477
/// Three-dimensional (`M`×`N`×`K`) shapes, as used by [`Opcode::Mma`].
#[derive(Debug)]
pub enum Shape3 {
    M16N8K8,
    M16N8K16,
    M16N8K32,
}
484
/// Two-dimensional (`M`×`N`) shapes, as used by [`Opcode::LdMatrix`].
#[derive(Debug)]
pub enum Shape2 {
    M8N8,
    M8N16,
    M16N8,
    M16N16,
}
492
493impl Module {
494    pub fn from_ptx(source: &str, file_id: usize) -> Result<Self, Diagnostic<usize>> {
495        let mut parser = Parser::new(file_id, source);
496        parser.parse()
497    }
498
499    pub fn from_ptx_path(path: &Path) -> Result<Self, String> {
500        let mut files = SimpleFiles::new();
501        let content = std::fs::read_to_string(path).expect("failed to read file");
502        let file_id = files.add(
503            path.to_str().expect("failed to convert path to str"),
504            &content,
505        );
506        Self::from_ptx(&content, file_id).map_err(|diagnostic| emit_string(diagnostic, files))
507    }
508}
509
510fn emit_string(diagnostic: Diagnostic<usize>, files: SimpleFiles<&str, &String>) -> String {
511    let mut buffer = Buffer::ansi();
512    let config = codespan_reporting::term::Config::default();
513    codespan_reporting::term::emit(&mut buffer, &config, &files, &diagnostic).unwrap();
514    String::from_utf8(buffer.into_inner()).unwrap()
515}
516
517impl From<Token> for MulMode {
518    fn from(token: Token) -> Self {
519        match token {
520            Token::Hi => MulMode::Hi,
521            Token::Lo => MulMode::Lo,
522            Token::Wide => MulMode::Wide,
523            _ => unreachable!(),
524        }
525    }
526}
527
528impl From<Token> for FloatRoundingMode {
529    fn from(token: Token) -> Self {
530        match token {
531            Token::Rn => FloatRoundingMode::Rn,
532            Token::Rna => FloatRoundingMode::Rna,
533            Token::Rz => FloatRoundingMode::Rz,
534            Token::Rm => FloatRoundingMode::Rm,
535            Token::Rp => FloatRoundingMode::Rp,
536            _ => unreachable!(),
537        }
538    }
539}
540
541impl From<Token> for IntegerRoundingMode {
542    fn from(token: Token) -> Self {
543        match token {
544            Token::Rni => IntegerRoundingMode::Rni,
545            Token::Rzi => IntegerRoundingMode::Rzi,
546            Token::Rmi => IntegerRoundingMode::Rmi,
547            Token::Rpi => IntegerRoundingMode::Rpi,
548            _ => unreachable!(),
549        }
550    }
551}