// trueno_ptx_debug/parser/types.rs

//! PTX Type System definitions

/// Scalar data types recognised by the PTX parser.
///
/// The [`Display`](std::fmt::Display) implementation renders every variant in
/// its PTX spelling, including the leading dot (for example `.u32`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PtxType {
    /// 8-bit signed integer (`.s8`)
    S8,
    /// 16-bit signed integer (`.s16`)
    S16,
    /// 32-bit signed integer (`.s32`)
    S32,
    /// 64-bit signed integer (`.s64`)
    S64,
    /// 8-bit unsigned integer (`.u8`)
    U8,
    /// 16-bit unsigned integer (`.u16`)
    U16,
    /// 32-bit unsigned integer (`.u32`)
    U32,
    /// 64-bit unsigned integer (`.u64`)
    U64,
    /// 16-bit floating point (`.f16`)
    F16,
    /// 32-bit floating point (`.f32`)
    F32,
    /// 64-bit floating point (`.f64`)
    F64,
    /// 8-bit untyped bits (`.b8`)
    B8,
    /// 16-bit untyped bits (`.b16`)
    B16,
    /// 32-bit untyped bits (`.b32`)
    B32,
    /// 64-bit untyped bits (`.b64`)
    B64,
    /// Predicate, i.e. a boolean (`.pred`)
    Pred,
}

impl PtxType {
    /// Returns the size of a value of this type in bytes.
    ///
    /// `Pred` is reported as occupying one byte.
    pub fn size_bytes(&self) -> usize {
        match self {
            PtxType::S8 | PtxType::U8 | PtxType::B8 | PtxType::Pred => 1,
            PtxType::S16 | PtxType::U16 | PtxType::B16 | PtxType::F16 => 2,
            PtxType::S32 | PtxType::U32 | PtxType::B32 | PtxType::F32 => 4,
            PtxType::S64 | PtxType::U64 | PtxType::B64 | PtxType::F64 => 8,
        }
    }

    /// Returns `true` for the signed integer types (`S8`, `S16`, `S32`, `S64`).
    ///
    /// Floating point types are not counted as signed by this predicate.
    pub fn is_signed(&self) -> bool {
        matches!(
            self,
            PtxType::S8 | PtxType::S16 | PtxType::S32 | PtxType::S64
        )
    }

    /// Returns `true` for the floating point types (`F16`, `F32`, `F64`).
    pub fn is_float(&self) -> bool {
        matches!(self, PtxType::F16 | PtxType::F32 | PtxType::F64)
    }

    /// Returns `true` for the 64-bit wide types (`S64`, `U64`, `B64`, `F64`).
    pub fn is_64bit(&self) -> bool {
        matches!(
            self,
            PtxType::S64 | PtxType::U64 | PtxType::B64 | PtxType::F64
        )
    }
}

impl std::fmt::Display for PtxType {
    /// Writes the PTX spelling of the type (with its leading dot).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let spelling = match self {
            PtxType::S8 => ".s8",
            PtxType::S16 => ".s16",
            PtxType::S32 => ".s32",
            PtxType::S64 => ".s64",
            PtxType::U8 => ".u8",
            PtxType::U16 => ".u16",
            PtxType::U32 => ".u32",
            PtxType::U64 => ".u64",
            PtxType::F16 => ".f16",
            PtxType::F32 => ".f32",
            PtxType::F64 => ".f64",
            PtxType::B8 => ".b8",
            PtxType::B16 => ".b16",
            PtxType::B32 => ".b32",
            PtxType::B64 => ".b64",
            PtxType::Pred => ".pred",
        };
        // `pad` (unlike `write!(f, "{}", ..)`) applies the caller's width,
        // fill and alignment flags, so `{:>8}` lines up in tabular output.
        f.pad(spelling)
    }
}

/// Address space (state space) qualifiers.
///
/// The [`Display`](std::fmt::Display) implementation renders the qualifier
/// with its leading dot; [`AddressSpace::Generic`] renders as the empty
/// string.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AddressSpace {
    /// Generic (unqualified) address
    Generic,
    /// Global memory (`.global`)
    Global,
    /// Shared memory, per-block (`.shared`)
    Shared,
    /// Local memory, per-thread (`.local`)
    Local,
    /// Constant memory (`.const`)
    Const,
    /// Parameter space (`.param`)
    Param,
    /// Texture memory (`.tex`)
    Texture,
    /// Surface memory (`.surf`)
    Surface,
}

impl std::fmt::Display for AddressSpace {
    /// Writes the qualifier text; nothing is written for `Generic` unless the
    /// caller requested a minimum width.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let qualifier = match self {
            AddressSpace::Generic => "",
            AddressSpace::Global => ".global",
            AddressSpace::Shared => ".shared",
            AddressSpace::Local => ".local",
            AddressSpace::Const => ".const",
            AddressSpace::Param => ".param",
            AddressSpace::Texture => ".tex",
            AddressSpace::Surface => ".surf",
        };
        // `pad` applies the caller's width, fill and alignment flags, which a
        // nested `write!(f, "{}", ..)` would silently drop.
        f.pad(qualifier)
    }
}

/// SM (Streaming Multiprocessor) target architecture.
///
/// [`SmTarget::Unknown`] is the [`Default`] value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum SmTarget {
    /// Unknown or unspecified target
    #[default]
    Unknown,
    /// SM 5.0 (Maxwell)
    Sm50,
    /// SM 5.2 (Maxwell)
    Sm52,
    /// SM 6.0 (Pascal)
    Sm60,
    /// SM 6.1 (Pascal)
    Sm61,
    /// SM 7.0 (Volta)
    Sm70,
    /// SM 7.5 (Turing)
    Sm75,
    /// SM 8.0 (Ampere)
    Sm80,
    /// SM 8.6 (Ampere)
    Sm86,
    /// SM 8.9 (Ada Lovelace)
    Sm89,
    /// SM 9.0 (Hopper)
    Sm90,
}

impl SmTarget {
    /// Returns the minimum PTX version for this target as `(major, minor)`.
    ///
    /// `Unknown` maps to `(1, 0)`. Because the result is a tuple, versions can
    /// be compared directly, e.g. `target.min_ptx_version() >= (7, 0)`.
    pub fn min_ptx_version(&self) -> (u8, u8) {
        // NOTE(review): several targets share one entry below (Sm50/Sm52,
        // Sm60/Sm61, Sm80/Sm86); confirm the exact minor versions against the
        // PTX ISA release-history table before relying on them.
        match self {
            SmTarget::Unknown => (1, 0),
            SmTarget::Sm50 | SmTarget::Sm52 => (4, 0),
            SmTarget::Sm60 | SmTarget::Sm61 => (5, 0),
            SmTarget::Sm70 => (6, 0),
            SmTarget::Sm75 => (6, 3),
            SmTarget::Sm80 | SmTarget::Sm86 => (7, 0),
            SmTarget::Sm89 => (7, 8),
            SmTarget::Sm90 => (8, 0),
        }
    }

    /// Returns `true` if this target supports Tensor Cores.
    ///
    /// That is every listed target from `Sm70` onwards; `Unknown` and the
    /// `Sm5x`/`Sm6x` targets return `false`.
    pub fn has_tensor_cores(&self) -> bool {
        matches!(
            self,
            SmTarget::Sm70
                | SmTarget::Sm75
                | SmTarget::Sm80
                | SmTarget::Sm86
                | SmTarget::Sm89
                | SmTarget::Sm90
        )
    }
}

/// PTX opcodes known to the parser.
///
/// Anything the parser does not recognise is represented by
/// [`Opcode::Unknown`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Opcode {
    // Data Movement
    /// Load
    Ld,
    /// Store
    St,
    /// Move
    Mov,
    /// Convert address space
    Cvta,
    /// Convert type
    Cvt,

    // Arithmetic
    /// Add
    Add,
    /// Subtract
    Sub,
    /// Multiply
    Mul,
    /// Divide
    Div,
    /// Remainder
    Rem,
    /// Multiply-add
    Mad,
    /// Fused multiply-add
    Fma,
    /// Negate
    Neg,
    /// Absolute value
    Abs,
    /// Minimum
    Min,
    /// Maximum
    Max,

    // Logic
    /// Bitwise AND
    And,
    /// Bitwise OR
    Or,
    /// Bitwise XOR
    Xor,
    /// Bitwise NOT
    Not,
    /// Shift left
    Shl,
    /// Shift right
    Shr,

    // Comparison
    /// Set predicate
    Setp,
    /// Select
    Selp,

    // Control Flow
    /// Branch
    Bra,
    /// Call function
    Call,
    /// Return
    Ret,
    /// Exit kernel
    Exit,

    // Synchronization
    /// Barrier
    Bar,
    /// Memory barrier
    MemBar,
    /// Atomic operation
    Atom,
    /// Reduction operation
    Red,

    // Special
    /// Texture load
    Tex,
    /// Texture load 4
    Tld4,
    /// Surface load
    Suld,
    /// Surface store
    Sust,
    /// Warp shuffle
    Shfl,
    /// Warp vote
    Vote,
    /// Matrix multiply-accumulate
    Mma,
    /// Warp MMA
    Wmma,
    /// Load matrix
    LdMatrix,
    /// Copy (async)
    Cp,
    /// Prefetch
    Prefetch,

    /// Unknown opcode
    Unknown,
}

impl Opcode {
    /// Returns `true` for opcodes classified as loads: `Ld`, `Tex`, `Tld4`,
    /// `Suld` and `LdMatrix`.
    pub fn is_load(&self) -> bool {
        matches!(
            self,
            Opcode::Ld | Opcode::Tex | Opcode::Tld4 | Opcode::Suld | Opcode::LdMatrix
        )
    }

    /// Returns `true` for opcodes classified as stores: `St` and `Sust`.
    pub fn is_store(&self) -> bool {
        matches!(self, Opcode::St | Opcode::Sust)
    }

    /// Returns `true` for any memory operation: every load, every store, and
    /// the `Atom` / `Red` opcodes.
    pub fn is_memory_op(&self) -> bool {
        self.is_load() || self.is_store() || matches!(self, Opcode::Atom | Opcode::Red)
    }

    /// Returns `true` for the synchronization opcodes `Bar` and `MemBar`.
    ///
    /// `Atom` and `Red` are reported by [`Opcode::is_memory_op`] instead.
    pub fn is_sync(&self) -> bool {
        matches!(self, Opcode::Bar | Opcode::MemBar)
    }

    /// Returns `true` for the control-flow opcodes `Bra`, `Call`, `Ret` and
    /// `Exit`.
    pub fn is_branch(&self) -> bool {
        matches!(
            self,
            Opcode::Bra | Opcode::Call | Opcode::Ret | Opcode::Exit
        )
    }
}

332/// Instruction modifiers
333#[derive(Debug, Clone, PartialEq, Eq, Hash)]
334pub enum Modifier {
335    // Address space
336    /// .shared
337    Shared,
338    /// .global
339    Global,
340    /// .local
341    Local,
342    /// .const
343    Const,
344    /// .param
345    Param,
346
347    // Types
348    /// .u32
349    U32,
350    /// .u64
351    U64,
352    /// .s32
353    S32,
354    /// .s64
355    S64,
356    /// .f32
357    F32,
358    /// .f64
359    F64,
360    /// .b32
361    B32,
362    /// .b64
363    B64,
364
365    // Synchronization
366    /// .sync
367    Sync,
368    /// .cta
369    Cta,
370    /// .gl
371    Gl,
372    /// .sys
373    Sys,
374
375    // Atomic
376    /// .add (atomic add)
377    AtomicAdd,
378    /// .cas (compare and swap)
379    AtomicCas,
380    /// .exch (exchange)
381    AtomicExch,
382    /// .min
383    AtomicMin,
384    /// .max
385    AtomicMax,
386
387    // Other
388    /// Other modifier
389    Other(String),
390}
391
392impl Modifier {
393    /// Get the address space if this is an address space modifier
394    pub fn as_address_space(&self) -> Option<AddressSpace> {
395        match self {
396            Modifier::Shared => Some(AddressSpace::Shared),
397            Modifier::Global => Some(AddressSpace::Global),
398            Modifier::Local => Some(AddressSpace::Local),
399            Modifier::Const => Some(AddressSpace::Const),
400            Modifier::Param => Some(AddressSpace::Param),
401            _ => None,
402        }
403    }
404
405    /// Get the type if this is a type modifier
406    pub fn as_type(&self) -> Option<PtxType> {
407        match self {
408            Modifier::U32 => Some(PtxType::U32),
409            Modifier::U64 => Some(PtxType::U64),
410            Modifier::S32 => Some(PtxType::S32),
411            Modifier::S64 => Some(PtxType::S64),
412            Modifier::F32 => Some(PtxType::F32),
413            Modifier::F64 => Some(PtxType::F64),
414            Modifier::B32 => Some(PtxType::B32),
415            Modifier::B64 => Some(PtxType::B64),
416            _ => None,
417        }
418    }
419}
420
#[cfg(test)]
mod tests {
    use super::*;

    /// Byte sizes reported for representative types of every width.
    #[test]
    fn test_ptx_type_size() {
        assert_eq!(PtxType::U8.size_bytes(), 1);
        assert_eq!(PtxType::U16.size_bytes(), 2);
        assert_eq!(PtxType::U32.size_bytes(), 4);
        assert_eq!(PtxType::U64.size_bytes(), 8);
        assert_eq!(PtxType::F16.size_bytes(), 2);
        assert_eq!(PtxType::F32.size_bytes(), 4);
        assert_eq!(PtxType::F64.size_bytes(), 8);
        assert_eq!(PtxType::Pred.size_bytes(), 1);
    }

    /// Signedness, float and 64-bit predicates.
    #[test]
    fn test_ptx_type_properties() {
        assert!(PtxType::S32.is_signed());
        assert!(!PtxType::U32.is_signed());
        assert!(PtxType::F32.is_float());
        assert!(!PtxType::U32.is_float());
        assert!(PtxType::U64.is_64bit());
        assert!(!PtxType::U32.is_64bit());
    }

    /// Textual forms produced by the `Display` implementations.
    #[test]
    fn test_display_spelling() {
        assert_eq!(PtxType::F32.to_string(), ".f32");
        assert_eq!(PtxType::Pred.to_string(), ".pred");
        assert_eq!(AddressSpace::Shared.to_string(), ".shared");
        assert_eq!(AddressSpace::Generic.to_string(), "");
    }

    /// Minimum PTX versions compare as `(major, minor)` tuples.
    #[test]
    fn test_sm_target_ptx_version() {
        assert!(SmTarget::Sm90.min_ptx_version() >= (8, 0));
        assert!(SmTarget::Sm70.min_ptx_version() >= (6, 0));
        assert_eq!(SmTarget::default(), SmTarget::Unknown);
    }

    /// One representative opcode per category predicate.
    #[test]
    fn test_opcode_categories() {
        assert!(Opcode::Ld.is_load());
        assert!(Opcode::St.is_store());
        assert!(Opcode::Bar.is_sync());
        assert!(Opcode::Bra.is_branch());
        assert!(Opcode::Atom.is_memory_op());
        assert!(!Opcode::Unknown.is_memory_op());
    }

    /// Modifier-to-type and modifier-to-address-space conversions, including
    /// modifiers that belong to neither group.
    #[test]
    fn test_modifier_conversion() {
        assert_eq!(
            Modifier::Shared.as_address_space(),
            Some(AddressSpace::Shared)
        );
        assert_eq!(Modifier::U32.as_type(), Some(PtxType::U32));
        assert_eq!(Modifier::U32.as_address_space(), None);
        assert_eq!(Modifier::Shared.as_type(), None);
        assert_eq!(Modifier::Other("volatile".to_string()).as_type(), None);
    }
}