/// Scalar data types that can appear as a PTX type suffix.
///
/// Formatting a value with [`std::fmt::Display`] yields the dotted PTX
/// spelling, e.g. `.u32` for [`PtxType::U32`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PtxType {
    /// Signed 8-bit integer (`.s8`).
    S8,
    /// Signed 16-bit integer (`.s16`).
    S16,
    /// Signed 32-bit integer (`.s32`).
    S32,
    /// Signed 64-bit integer (`.s64`).
    S64,
    /// Unsigned 8-bit integer (`.u8`).
    U8,
    /// Unsigned 16-bit integer (`.u16`).
    U16,
    /// Unsigned 32-bit integer (`.u32`).
    U32,
    /// Unsigned 64-bit integer (`.u64`).
    U64,
    /// 16-bit floating point (`.f16`).
    F16,
    /// 32-bit floating point (`.f32`).
    F32,
    /// 64-bit floating point (`.f64`).
    F64,
    /// Untyped 8-bit value (`.b8`).
    B8,
    /// Untyped 16-bit value (`.b16`).
    B16,
    /// Untyped 32-bit value (`.b32`).
    B32,
    /// Untyped 64-bit value (`.b64`).
    B64,
    /// Predicate (`.pred`).
    Pred,
}

impl PtxType {
    /// Returns the width of the type in bytes.
    ///
    /// [`PtxType::Pred`] is reported as one byte.
    pub fn size_bytes(&self) -> usize {
        match self {
            Self::S8 | Self::U8 | Self::B8 | Self::Pred => 1,
            Self::S16 | Self::U16 | Self::B16 | Self::F16 => 2,
            Self::S32 | Self::U32 | Self::B32 | Self::F32 => 4,
            Self::S64 | Self::U64 | Self::B64 | Self::F64 => 8,
        }
    }

    /// Returns `true` for the signed integer types (`S8`, `S16`, `S32`, `S64`).
    ///
    /// Floating-point types are *not* reported as signed by this predicate.
    pub fn is_signed(&self) -> bool {
        matches!(self, Self::S8 | Self::S16 | Self::S32 | Self::S64)
    }

    /// Returns `true` for the floating-point types (`F16`, `F32`, `F64`).
    pub fn is_float(&self) -> bool {
        matches!(self, Self::F16 | Self::F32 | Self::F64)
    }

    /// Returns `true` for every 64-bit wide type (`S64`, `U64`, `B64`, `F64`).
    pub fn is_64bit(&self) -> bool {
        matches!(self, Self::S64 | Self::U64 | Self::B64 | Self::F64)
    }
}

impl std::fmt::Display for PtxType {
    /// Writes the dotted PTX spelling of the type (for example `.f32`).
    ///
    /// Width, fill, alignment and precision flags supplied by the caller are
    /// applied exactly as they would be for a `str`, so
    /// `format!("{:<8}", ty)` produces a padded column.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let spelling = match self {
            Self::S8 => ".s8",
            Self::S16 => ".s16",
            Self::S32 => ".s32",
            Self::S64 => ".s64",
            Self::U8 => ".u8",
            Self::U16 => ".u16",
            Self::U32 => ".u32",
            Self::U64 => ".u64",
            Self::F16 => ".f16",
            Self::F32 => ".f32",
            Self::F64 => ".f64",
            Self::B8 => ".b8",
            Self::B16 => ".b16",
            Self::B32 => ".b32",
            Self::B64 => ".b64",
            Self::Pred => ".pred",
        };
        // `pad` honors the caller's formatting flags; a nested
        // `write!(f, "{}", ..)` would silently discard them.
        f.pad(spelling)
    }
}
97
/// Address spaces that a PTX memory operand can be qualified with.
///
/// Formatting a value with [`std::fmt::Display`] yields the dotted qualifier
/// (for example `.shared`); [`AddressSpace::Generic`] formats as the empty
/// string.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AddressSpace {
    /// No explicit qualifier; rendered as an empty string.
    Generic,
    /// Rendered as `.global`.
    Global,
    /// Rendered as `.shared`.
    Shared,
    /// Rendered as `.local`.
    Local,
    /// Rendered as `.const`.
    Const,
    /// Rendered as `.param`.
    Param,
    /// Rendered as `.tex`.
    Texture,
    /// Rendered as `.surf`.
    Surface,
}

impl std::fmt::Display for AddressSpace {
    /// Writes the dotted qualifier for this address space.
    ///
    /// Width, fill, alignment and precision flags supplied by the caller are
    /// applied exactly as they would be for a `str`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let qualifier = match self {
            Self::Generic => "",
            Self::Global => ".global",
            Self::Shared => ".shared",
            Self::Local => ".local",
            Self::Const => ".const",
            Self::Param => ".param",
            Self::Texture => ".tex",
            Self::Surface => ".surf",
        };
        // `pad` honors the caller's formatting flags; a nested
        // `write!(f, "{}", ..)` would silently discard them.
        f.pad(qualifier)
    }
}
134
/// GPU architecture targets (`sm_XX`) known to this crate.
///
/// [`SmTarget::Unknown`] is the [`Default`] value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum SmTarget {
    /// Target not known or not specified.
    #[default]
    Unknown,
    /// `sm_50`
    Sm50,
    /// `sm_52`
    Sm52,
    /// `sm_60`
    Sm60,
    /// `sm_61`
    Sm61,
    /// `sm_70`
    Sm70,
    /// `sm_75`
    Sm75,
    /// `sm_80`
    Sm80,
    /// `sm_86`
    Sm86,
    /// `sm_89`
    Sm89,
    /// `sm_90`
    Sm90,
}

impl SmTarget {
    /// Returns the lowest PTX ISA version, as `(major, minor)`, that accepts
    /// this target.
    ///
    /// [`SmTarget::Unknown`] maps to `(1, 0)`. The tuple compares
    /// lexicographically, so `a.min_ptx_version() >= (6, 0)` works as expected.
    pub fn min_ptx_version(&self) -> (u8, u8) {
        match self {
            Self::Unknown => (1, 0),
            Self::Sm50 => (4, 0),
            // sm_52 first appears in PTX ISA 4.1, one release after sm_50.
            Self::Sm52 => (4, 1),
            Self::Sm60 | Self::Sm61 => (5, 0),
            Self::Sm70 => (6, 0),
            Self::Sm75 => (6, 3),
            Self::Sm80 => (7, 0),
            // sm_86 first appears in PTX ISA 7.1, one release after sm_80.
            Self::Sm86 => (7, 1),
            Self::Sm89 => (7, 8),
            // NOTE(review): the PTX ISA `.target` table appears to accept
            // sm_90 from 7.8 already; (8, 0) is kept because existing callers
            // may rely on it -- confirm before lowering.
            Self::Sm90 => (8, 0),
        }
    }

    /// Returns `true` for `Sm70` and every newer target listed in this enum.
    pub fn has_tensor_cores(&self) -> bool {
        matches!(
            self,
            Self::Sm70 | Self::Sm75 | Self::Sm80 | Self::Sm86 | Self::Sm89 | Self::Sm90
        )
    }
}
191
/// Instruction opcodes recognized by this crate.
///
/// Variant names follow the PTX mnemonics they stand for; anything that is
/// not recognized is represented by [`Opcode::Unknown`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Opcode {
    // Data movement and conversion.
    Ld,
    St,
    Mov,
    Cvta,
    Cvt,

    // Arithmetic.
    Add,
    Sub,
    Mul,
    Div,
    Rem,
    Mad,
    Fma,
    Neg,
    Abs,
    Min,
    Max,

    // Bitwise logic and shifts.
    And,
    Or,
    Xor,
    Not,
    Shl,
    Shr,

    // Comparison and selection.
    Setp,
    Selp,

    // Control flow.
    Bra,
    Call,
    Ret,
    Exit,

    // Barriers, fences and atomics.
    Bar,
    MemBar,
    Atom,
    Red,

    // Texture/surface access, warp-level and matrix operations, async copy.
    Tex,
    Tld4,
    Suld,
    Sust,
    Shfl,
    Vote,
    Mma,
    Wmma,
    LdMatrix,
    Cp,
    Prefetch,

    /// Fallback for an opcode with no dedicated variant.
    Unknown,
}

impl Opcode {
    /// Returns `true` for opcodes classified as loads:
    /// `Ld`, `Tex`, `Tld4`, `Suld` and `LdMatrix`.
    pub fn is_load(&self) -> bool {
        matches!(
            *self,
            Self::Ld | Self::Tex | Self::Tld4 | Self::Suld | Self::LdMatrix
        )
    }

    /// Returns `true` for opcodes classified as stores: `St` and `Sust`.
    pub fn is_store(&self) -> bool {
        matches!(*self, Self::St | Self::Sust)
    }

    /// Returns `true` for any load, any store, or the atomic opcodes `Atom`
    /// and `Red`.
    pub fn is_memory_op(&self) -> bool {
        let is_atomic = matches!(*self, Self::Atom | Self::Red);
        is_atomic || self.is_load() || self.is_store()
    }

    /// Returns `true` for the synchronization opcodes `Bar` and `MemBar`.
    pub fn is_sync(&self) -> bool {
        matches!(*self, Self::Bar | Self::MemBar)
    }

    /// Returns `true` for the control-transfer opcodes `Bra`, `Call`, `Ret`
    /// and `Exit`.
    pub fn is_branch(&self) -> bool {
        matches!(*self, Self::Bra | Self::Call | Self::Ret | Self::Exit)
    }
}
331
332#[derive(Debug, Clone, PartialEq, Eq, Hash)]
334pub enum Modifier {
335 Shared,
338 Global,
340 Local,
342 Const,
344 Param,
346
347 U32,
350 U64,
352 S32,
354 S64,
356 F32,
358 F64,
360 B32,
362 B64,
364
365 Sync,
368 Cta,
370 Gl,
372 Sys,
374
375 AtomicAdd,
378 AtomicCas,
380 AtomicExch,
382 AtomicMin,
384 AtomicMax,
386
387 Other(String),
390}
391
392impl Modifier {
393 pub fn as_address_space(&self) -> Option<AddressSpace> {
395 match self {
396 Modifier::Shared => Some(AddressSpace::Shared),
397 Modifier::Global => Some(AddressSpace::Global),
398 Modifier::Local => Some(AddressSpace::Local),
399 Modifier::Const => Some(AddressSpace::Const),
400 Modifier::Param => Some(AddressSpace::Param),
401 _ => None,
402 }
403 }
404
405 pub fn as_type(&self) -> Option<PtxType> {
407 match self {
408 Modifier::U32 => Some(PtxType::U32),
409 Modifier::U64 => Some(PtxType::U64),
410 Modifier::S32 => Some(PtxType::S32),
411 Modifier::S64 => Some(PtxType::S64),
412 Modifier::F32 => Some(PtxType::F32),
413 Modifier::F64 => Some(PtxType::F64),
414 Modifier::B32 => Some(PtxType::B32),
415 Modifier::B64 => Some(PtxType::B64),
416 _ => None,
417 }
418 }
419}
420
#[cfg(test)]
mod tests {
    use super::*;

    /// Byte widths for a sample of integer and floating-point types.
    #[test]
    fn test_ptx_type_size() {
        let expected: [(PtxType, usize); 6] = [
            (PtxType::U8, 1),
            (PtxType::U16, 2),
            (PtxType::U32, 4),
            (PtxType::U64, 8),
            (PtxType::F32, 4),
            (PtxType::F64, 8),
        ];
        for (ty, bytes) in expected.iter() {
            assert_eq!(ty.size_bytes(), *bytes);
        }
    }

    /// Signedness, floating-point and 64-bit predicates, each checked with a
    /// positive example and `U32` as the negative example.
    #[test]
    fn test_ptx_type_properties() {
        let negative = PtxType::U32;

        assert!(PtxType::S32.is_signed());
        assert!(!negative.is_signed());

        assert!(PtxType::F32.is_float());
        assert!(!negative.is_float());

        assert!(PtxType::U64.is_64bit());
        assert!(!negative.is_64bit());
    }

    /// Lower bounds on the PTX version reported for selected targets.
    #[test]
    fn test_sm_target_ptx_version() {
        let floors: [(SmTarget, (u8, u8)); 2] =
            [(SmTarget::Sm90, (8, 0)), (SmTarget::Sm70, (6, 0))];
        for (target, floor) in floors.iter() {
            assert!(target.min_ptx_version() >= *floor);
        }
    }

    /// One representative opcode per category predicate.
    #[test]
    fn test_opcode_categories() {
        let load = Opcode::Ld;
        let store = Opcode::St;
        let barrier = Opcode::Bar;
        let branch = Opcode::Bra;

        assert!(load.is_load());
        assert!(store.is_store());
        assert!(barrier.is_sync());
        assert!(branch.is_branch());
    }

    /// Modifiers convert to their typed address-space / type counterparts.
    #[test]
    fn test_modifier_conversion() {
        let space = Modifier::Shared.as_address_space();
        assert_eq!(space, Some(AddressSpace::Shared));

        let ty = Modifier::U32.as_type();
        assert_eq!(ty, Some(PtxType::U32));
    }
}