use super::registers::VirtualReg;
use super::types::{PtxStateSpace, PtxType};
/// Operation kind carried by a [`PtxInstruction`].
///
/// The enum is a plain tag with no payload, so it derives `Copy` and `Eq`
/// like the other fieldless enums in this module (`CmpOp`, `WmmaLayout`,
/// `RoundingMode`).
///
/// Variant names appear to follow PTX instruction mnemonics, and the section
/// comments below group them by the instruction family the names suggest.
/// How each variant is actually rendered is decided by the emitter, which is
/// not part of this file — confirm there before relying on a specific
/// spelling. Variant order is preserved so discriminant values do not change.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PtxOp {
    // --- Arithmetic and math ---
    Add,
    Sub,
    Mul,
    Mad,
    MadLo,
    Div,
    Rem,
    Abs,
    Neg,
    Min,
    Max,
    Rcp,
    Rsqrt,
    Sqrt,
    Sin,
    Cos,
    Ex2,
    Lg2,
    Fma,

    // --- Comparison ---
    Setp,

    // --- Bitwise logic and shifts ---
    And,
    Or,
    Xor,
    Not,
    Shl,
    Shr,

    // --- Data movement, conversion, selection ---
    Mov,
    Ld,
    St,
    LdParam,
    Cvt,
    Cvta,
    Selp,

    // --- Warp shuffle and vote ---
    ShflDown,
    ShflUp,
    ShflBfly,
    ShflIdx,
    VoteAll,
    VoteAny,
    VoteBallot,

    // --- Control flow and barriers ---
    Bra,
    Call,
    Ret,
    Exit,
    Bar,
    MemBar,

    // --- Texture and surface access ---
    Tex,
    Suld,
    Sust,

    // --- Atomics ---
    AtomAdd,
    AtomMin,
    AtomMax,
    AtomExch,
    AtomCas,

    // --- WMMA (matrix fragments) ---
    WmmaLoadA,
    WmmaLoadB,
    WmmaLoadC,
    WmmaMma,
    WmmaStoreD,

    // --- Dot product (NOTE(review): exact signedness of each variant is
    // determined by the emitter — confirm there) ---
    Dp4a,
    Dp4aUS,
    Dp4aS32,

    // --- Bit manipulation ---
    Popc,
    Clz,
    Bfind,
    Bfe,
    Bfi,

    // --- Later additions (NOTE(review): `Vote` coexists with the
    // `VoteAll`/`VoteAny`/`VoteBallot` variants above; which one callers
    // should prefer is not visible from this file) ---
    LdVolatile,
    Vote,
}
/// Comparison operator attached to a compare-style instruction.
///
/// Each variant corresponds to one lowercase mnemonic returned by
/// [`CmpOp::to_ptx_string`]. In the PTX ISA, `lo`/`ls`/`hi`/`hs` are the
/// unsigned integer orderings (lower, lower-or-same, higher, higher-or-same).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CmpOp {
    /// Equal (`eq`).
    Eq,
    /// Not equal (`ne`).
    Ne,
    /// Less than (`lt`).
    Lt,
    /// Less than or equal (`le`).
    Le,
    /// Greater than (`gt`).
    Gt,
    /// Greater than or equal (`ge`).
    Ge,
    /// Lower (`lo`).
    Lo,
    /// Lower or same (`ls`).
    Ls,
    /// Higher (`hi`).
    Hi,
    /// Higher or same (`hs`).
    Hs,
}

impl CmpOp {
    /// Returns the two-letter mnemonic for this comparison.
    ///
    /// The returned text has no leading `.`; the caller is responsible for
    /// any separator it needs.
    #[must_use]
    pub const fn to_ptx_string(self) -> &'static str {
        match self {
            // Equality.
            CmpOp::Eq => "eq",
            CmpOp::Ne => "ne",
            // Ordering spelled with lt/le/gt/ge.
            CmpOp::Lt => "lt",
            CmpOp::Le => "le",
            CmpOp::Gt => "gt",
            CmpOp::Ge => "ge",
            // Ordering spelled with lo/ls/hi/hs.
            CmpOp::Lo => "lo",
            CmpOp::Ls => "ls",
            CmpOp::Hi => "hi",
            CmpOp::Hs => "hs",
        }
    }
}
/// Memory layout qualifier for a WMMA matrix fragment.
///
/// `RowMajor` is the `Default` value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum WmmaLayout {
    /// Row-major layout, spelled `row`.
    #[default]
    RowMajor,
    /// Column-major layout, spelled `col`.
    ColMajor,
}

impl WmmaLayout {
    /// Returns the layout keyword (`"row"` or `"col"`), without a leading `.`.
    #[must_use]
    pub const fn to_ptx_string(self) -> &'static str {
        match self {
            WmmaLayout::ColMajor => "col",
            WmmaLayout::RowMajor => "row",
        }
    }
}
/// Dimensions of a WMMA tile, stored as the three extents `m`, `n`, `k`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct WmmaShape {
    /// The `m` extent.
    pub m: u32,
    /// The `n` extent.
    pub n: u32,
    /// The `k` extent.
    pub k: u32,
}

impl WmmaShape {
    /// The 16 x 16 x 16 shape.
    pub const M16N16K16: Self = WmmaShape { m: 16, n: 16, k: 16 };
    /// The 8 x 32 x 16 shape.
    pub const M8N32K16: Self = WmmaShape { m: 8, n: 32, k: 16 };
    /// The 32 x 8 x 16 shape.
    pub const M32N8K16: Self = WmmaShape { m: 32, n: 8, k: 16 };

    /// Formats the shape as `m<M>n<N>k<K>`, e.g. `m16n16k16`.
    ///
    /// Any field values are accepted; the method does not check that the
    /// combination is one of the predefined constants.
    #[must_use]
    pub fn to_ptx_string(self) -> String {
        let WmmaShape { m, n, k } = self;
        format!("m{m}n{n}k{k}")
    }
}
/// Rounding modifier that can be attached to an instruction.
///
/// `Rn` is the `Default` value. Per the PTX ISA, the two-letter forms are the
/// floating-point rounding modes and the `*i` forms round to an integer.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum RoundingMode {
    /// `.rn` — round to nearest even.
    #[default]
    Rn,
    /// `.rz` — round toward zero.
    Rz,
    /// `.rp` — round toward positive infinity.
    Rp,
    /// `.rm` — round toward negative infinity.
    Rm,
    /// `.rni` — round to nearest integer.
    Rni,
    /// `.rzi` — round to integer toward zero.
    Rzi,
    /// `.rpi` — round to integer toward positive infinity.
    Rpi,
    /// `.rmi` — round to integer toward negative infinity.
    Rmi,
}

impl RoundingMode {
    /// Returns the modifier text, *including* its leading `.`.
    ///
    /// Unlike the other `to_ptx_string` helpers in this module, the result
    /// can be appended to an opcode as-is.
    #[must_use]
    pub const fn to_ptx_string(self) -> &'static str {
        match self {
            // Floating-point rounding modes.
            RoundingMode::Rn => ".rn",
            RoundingMode::Rz => ".rz",
            RoundingMode::Rp => ".rp",
            RoundingMode::Rm => ".rm",
            // Round-to-integer modes.
            RoundingMode::Rni => ".rni",
            RoundingMode::Rzi => ".rzi",
            RoundingMode::Rpi => ".rpi",
            RoundingMode::Rmi => ".rmi",
        }
    }
}
/// One PTX instruction described as data: an operation, a type, operands,
/// and optional modifiers.
///
/// Values are normally created with `PtxInstruction::new` and then filled in
/// through the chained builder methods; all fields are public, so they can
/// also be read or written directly.
#[derive(Debug, Clone)]
pub struct PtxInstruction {
/// Operation kind.
pub op: PtxOp,
/// Type given to `new`. The `dst` builder method inspects it to decide
/// whether a destination goes into `dst` or `dsts`.
pub ty: PtxType,
/// Optional second type, set by `with_src_type`.
/// NOTE(review): presumably the source type of a conversion — confirm in the emitter.
pub src_type: Option<PtxType>,
/// Single destination operand; the `dst` builder method writes it for every
/// `ty` other than `V2F32`/`V4F32`.
pub dst: Option<Operand>,
/// Destination operand list; appended to by `push_dst`, and by `dst` when
/// `ty` is `V2F32` or `V4F32`.
pub dsts: Vec<Operand>,
/// Source operands, in the order they were added with `src`.
pub srcs: Vec<Operand>,
/// Optional guard predicate, set by `predicated`.
pub predicate: Option<Predicate>,
/// Optional state space, set by `space`.
pub state_space: Option<PtxStateSpace>,
/// Optional rounding modifier, set by `rounding`.
pub rounding: Option<RoundingMode>,
/// Optional label text, set by `label`.
/// NOTE(review): looks like a branch target (the tests use it with `Bra`) — confirm in the emitter.
pub label: Option<String>,
}
/// A source or destination operand of a [`PtxInstruction`].
///
/// Only `Debug` and `Clone` are derived, so operands cannot be compared with
/// `==`.
#[derive(Debug, Clone)]
pub enum Operand {
/// A virtual register.
Reg(VirtualReg),
/// A register described by `PtxReg` from the `registers` module.
SpecialReg(super::registers::PtxReg),
/// Signed 64-bit integer immediate.
ImmI64(i64),
/// Unsigned 64-bit integer immediate.
ImmU64(u64),
/// 32-bit floating-point immediate.
ImmF32(f32),
/// 64-bit floating-point immediate.
ImmF64(f64),
/// A parameter referred to by name.
Param(String),
/// An address formed from a base register and a signed 32-bit offset.
/// NOTE(review): the offset unit (presumably bytes) is decided by the emitter — confirm there.
Addr {
base: VirtualReg,
offset: i32,
},
/// A label referred to by name.
Label(String),
}
/// Guard predicate that can be attached to a [`PtxInstruction`] via its
/// `predicated` builder method.
#[derive(Debug, Clone)]
pub struct Predicate {
/// Register holding the predicate value.
pub reg: VirtualReg,
/// Whether the predicate is negated.
/// NOTE(review): presumably rendered as `@!p` rather than `@p` — confirm in the emitter.
pub negated: bool,
}
impl PtxInstruction {
    /// Starts a new instruction with the given operation and type.
    ///
    /// Every optional attribute is `None`, and both the destination list and
    /// the source list are empty; chain the builder methods below to fill
    /// them in.
    #[must_use]
    pub fn new(op: PtxOp, ty: PtxType) -> Self {
        Self {
            op,
            ty,
            // Operands.
            dst: None,
            dsts: vec![],
            srcs: vec![],
            // Optional modifiers.
            src_type: None,
            predicate: None,
            state_space: None,
            rounding: None,
            label: None,
        }
    }

    /// Records `src_type` as the instruction's source type, replacing any
    /// previous value.
    #[must_use]
    pub fn with_src_type(self, src_type: PtxType) -> Self {
        Self {
            src_type: Some(src_type),
            ..self
        }
    }

    /// Adds a destination operand.
    ///
    /// When the instruction type is `PtxType::V2F32` or `PtxType::V4F32`, the
    /// operand is appended to `dsts`, so repeated calls accumulate. For every
    /// other type the operand overwrites the single `dst` slot.
    #[must_use]
    pub fn dst(mut self, dst: Operand) -> Self {
        match self.ty {
            PtxType::V2F32 | PtxType::V4F32 => self.dsts.push(dst),
            _ => self.dst = Some(dst),
        }
        self
    }

    /// Appends `dst` to the destination list regardless of the instruction
    /// type.
    #[must_use]
    pub fn push_dst(self, dst: Operand) -> Self {
        let mut instr = self;
        instr.dsts.push(dst);
        instr
    }

    /// Appends `src` to the end of the source operand list.
    #[must_use]
    pub fn src(self, src: Operand) -> Self {
        let mut instr = self;
        instr.srcs.push(src);
        instr
    }

    /// Attaches a guard predicate, replacing any previous one.
    #[must_use]
    pub fn predicated(self, pred: Predicate) -> Self {
        Self {
            predicate: Some(pred),
            ..self
        }
    }

    /// Sets the state space, replacing any previous one.
    #[must_use]
    pub fn space(self, space: PtxStateSpace) -> Self {
        Self {
            state_space: Some(space),
            ..self
        }
    }

    /// Sets the rounding modifier, replacing any previous one.
    #[must_use]
    pub fn rounding(self, mode: RoundingMode) -> Self {
        Self {
            rounding: Some(mode),
            ..self
        }
    }

    /// Sets the label text, replacing any previous one.
    ///
    /// Accepts anything convertible into a `String`.
    #[must_use]
    pub fn label(self, label: impl Into<String>) -> Self {
        Self {
            label: Some(label.into()),
            ..self
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Spot-checks a few comparison mnemonics.
    #[test]
    fn test_cmp_op_strings() {
        for (op, expected) in [(CmpOp::Eq, "eq"), (CmpOp::Lt, "lt"), (CmpOp::Ge, "ge")] {
            assert_eq!(op.to_ptx_string(), expected);
        }
    }

    /// Covers the comparison mnemonics not checked by `test_cmp_op_strings`.
    #[test]
    fn test_cmp_op_all_variants() {
        let remaining = [
            (CmpOp::Ne, "ne"),
            (CmpOp::Le, "le"),
            (CmpOp::Gt, "gt"),
            (CmpOp::Lo, "lo"),
            (CmpOp::Ls, "ls"),
            (CmpOp::Hi, "hi"),
            (CmpOp::Hs, "hs"),
        ];
        for (op, expected) in remaining {
            assert_eq!(op.to_ptx_string(), expected);
        }
    }

    /// Rounding modifiers include their leading dot.
    #[test]
    fn test_rounding_mode_strings() {
        for (mode, expected) in [(RoundingMode::Rn, ".rn"), (RoundingMode::Rz, ".rz")] {
            assert_eq!(mode.to_ptx_string(), expected);
        }
    }

    /// The builder records the op, type, destination, and both sources.
    #[test]
    fn test_instruction_builder() {
        let PtxInstruction {
            op, ty, dst, srcs, ..
        } = PtxInstruction::new(PtxOp::Add, PtxType::F32)
            .dst(Operand::ImmF32(0.0))
            .src(Operand::ImmF32(1.0))
            .src(Operand::ImmF32(2.0));

        assert_eq!(op, PtxOp::Add);
        assert_eq!(ty, PtxType::F32);
        assert!(dst.is_some());
        assert_eq!(srcs.len(), 2);
    }

    /// A predicate and a label can both be attached to the same instruction.
    #[test]
    fn test_instruction_predicated() {
        let guard = Predicate {
            reg: VirtualReg::new(0, PtxType::Pred),
            negated: false,
        };
        let branch = PtxInstruction::new(PtxOp::Bra, PtxType::B32)
            .predicated(guard)
            .label("exit");

        assert!(branch.predicate.is_some());
        assert!(branch.label.is_some());
    }

    /// `space` stores the requested state space.
    #[test]
    fn test_instruction_memory() {
        let load = PtxInstruction::new(PtxOp::Ld, PtxType::F32).space(PtxStateSpace::Global);
        let recorded = load.state_space;
        assert_eq!(recorded, Some(PtxStateSpace::Global));
    }

    /// Layout keywords are `row` and `col`.
    #[test]
    fn test_wmma_layout_strings() {
        for (layout, expected) in [(WmmaLayout::RowMajor, "row"), (WmmaLayout::ColMajor, "col")] {
            assert_eq!(layout.to_ptx_string(), expected);
        }
    }

    /// Each predefined shape formats as `m<M>n<N>k<K>`.
    #[test]
    fn test_wmma_shape_strings() {
        let shapes = [
            (WmmaShape::M16N16K16, "m16n16k16"),
            (WmmaShape::M8N32K16, "m8n32k16"),
            (WmmaShape::M32N8K16, "m32n8k16"),
        ];
        for (shape, expected) in shapes {
            assert_eq!(shape.to_ptx_string(), expected);
        }
    }

    /// The 16x16x16 constant carries the expected extents.
    #[test]
    fn test_wmma_shape_values() {
        let WmmaShape { m, n, k } = WmmaShape::M16N16K16;
        assert_eq!(m, 16);
        assert_eq!(n, 16);
        assert_eq!(k, 16);
    }

    /// Compile-time check that every WMMA operation variant exists.
    #[test]
    fn test_wmma_ops_exist() {
        let _wmma_ops = [
            PtxOp::WmmaLoadA,
            PtxOp::WmmaLoadB,
            PtxOp::WmmaLoadC,
            PtxOp::WmmaMma,
            PtxOp::WmmaStoreD,
        ];
    }
}