use super::registers::VirtualReg;
use super::types::{PtxStateSpace, PtxType};
/// Opcode of a single PTX instruction.
///
/// Variant names mirror PTX mnemonics; type, rounding and state-space
/// qualifiers are carried separately on `PtxInstruction`. Because the enum is
/// fieldless it derives `Eq` and `Hash` (matching the other modifier enums in
/// this module), so opcodes can be compared totally and used as
/// `HashMap`/`HashSet` keys.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum PtxOp {
    // Arithmetic.
    Add,
    Sub,
    Mul,
    Mad,
    /// Presumably PTX `mad.lo` (low half of the integer product) — confirm against the emitter.
    MadLo,
    Div,
    Rem,
    Abs,
    Neg,
    Min,
    Max,
    // Reciprocal, roots, transcendentals and fused multiply-add.
    Rcp,
    Rsqrt,
    Sqrt,
    Sin,
    Cos,
    Ex2,
    Lg2,
    Fma,
    // Comparison (paired with a `CmpOp` elsewhere — NOTE(review): pairing not visible here).
    Setp,
    // Bitwise logic and shifts.
    And,
    Or,
    Xor,
    Not,
    Shl,
    Shr,
    // Data movement.
    Mov,
    Ld,
    St,
    /// Presumably PTX `ld.param` — confirm against the emitter.
    LdParam,
    // Conversion and selection.
    Cvt,
    Cvta,
    Selp,
    // Warp shuffles, one variant per PTX `shfl` mode (down/up/bfly/idx).
    ShflDown,
    ShflUp,
    ShflBfly,
    ShflIdx,
    // Warp votes, one variant per mode (see also `Vote` below).
    VoteAll,
    VoteAny,
    VoteBallot,
    // Control flow.
    Bra,
    Call,
    Ret,
    Exit,
    // Synchronization.
    Bar,
    MemBar,
    // Texture and surface access.
    Tex,
    Suld,
    Sust,
    // Atomics.
    AtomAdd,
    AtomMin,
    AtomMax,
    AtomExch,
    AtomCas,
    // Warp-level matrix multiply-accumulate (WMMA).
    WmmaLoadA,
    WmmaLoadB,
    WmmaLoadC,
    WmmaMma,
    WmmaStoreD,
    // Four-way byte dot product.
    Dp4a,
    /// Presumably a `dp4a` signedness variant (unsigned × signed?) — confirm against the emitter.
    Dp4aUS,
    /// Presumably a `dp4a` signedness variant (signed × signed?) — confirm against the emitter.
    Dp4aS32,
    // Bit manipulation.
    Popc,
    Clz,
    Bfind,
    Bfe,
    Bfi,
    // Memory-ordering / cache hints.
    /// Presumably PTX `ld.volatile` — confirm against the emitter.
    LdVolatile,
    Prefetch,
    /// Generic vote. NOTE(review): how this relates to `VoteAll`/`VoteAny`/`VoteBallot`
    /// is not visible here.
    Vote,
}
/// Comparison operator for comparison instructions such as `setp`.
///
/// `Eq`, `Ne`, `Lt`, `Le`, `Gt` and `Ge` use the familiar names; `Lo`, `Ls`,
/// `Hi` and `Hs` are PTX's unsigned-integer comparisons (lower,
/// lower-or-same, higher, higher-or-same).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CmpOp {
    Eq,
    Ne,
    Lt,
    Le,
    Gt,
    Ge,
    Lo,
    Ls,
    Hi,
    Hs,
}

impl CmpOp {
    /// Returns the PTX spelling of the operator, without a leading dot
    /// (for example `CmpOp::Lt` yields `"lt"`).
    #[must_use]
    pub const fn to_ptx_string(self) -> &'static str {
        match self {
            CmpOp::Eq => "eq",
            CmpOp::Ne => "ne",
            CmpOp::Lt => "lt",
            CmpOp::Le => "le",
            CmpOp::Gt => "gt",
            CmpOp::Ge => "ge",
            // Unsigned-integer orderings.
            CmpOp::Lo => "lo",
            CmpOp::Ls => "ls",
            CmpOp::Hi => "hi",
            CmpOp::Hs => "hs",
        }
    }
}
/// Matrix layout qualifier (`row` / `col`) for WMMA operations.
///
/// Defaults to row-major.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum WmmaLayout {
    #[default]
    RowMajor,
    ColMajor,
}

impl WmmaLayout {
    /// Returns the PTX layout keyword, `"row"` or `"col"`, without a leading dot.
    #[must_use]
    pub const fn to_ptx_string(self) -> &'static str {
        if matches!(self, WmmaLayout::ColMajor) {
            "col"
        } else {
            "row"
        }
    }
}
/// WMMA tile shape `M x N x K` (A is `M x K`, B is `K x N`, C/D are `M x N`).
///
/// The associated constants provide the shapes `m16n16k16`, `m8n32k16` and
/// `m32n8k16`. Other combinations can be built with a struct literal, but
/// nothing here checks that PTX accepts them.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct WmmaShape {
    /// Row count of A and of C/D.
    pub m: u32,
    /// Column count of B and of C/D.
    pub n: u32,
    /// Inner (reduction) dimension: columns of A, rows of B.
    pub k: u32,
}

impl WmmaShape {
    pub const M16N16K16: Self = WmmaShape { m: 16, n: 16, k: 16 };
    pub const M8N32K16: Self = WmmaShape { m: 8, n: 32, k: 16 };
    pub const M32N8K16: Self = WmmaShape { m: 32, n: 8, k: 16 };

    /// Formats the shape as PTX spells it, e.g. `"m16n16k16"` (no leading dot).
    #[must_use]
    pub fn to_ptx_string(self) -> String {
        let WmmaShape { m, n, k } = self;
        format!("m{m}n{n}k{k}")
    }
}
/// Rounding-mode modifier for arithmetic and conversion instructions.
///
/// `Rn`, `Rz`, `Rp` and `Rm` are PTX's floating-point rounding modes:
/// nearest-even, toward zero, toward +infinity and toward -infinity. The
/// `Rni`, `Rzi`, `Rpi` and `Rmi` variants round to an integer value in the
/// same directions. `Rn` is the default.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum RoundingMode {
    #[default]
    Rn,
    Rz,
    Rp,
    Rm,
    Rni,
    Rzi,
    Rpi,
    Rmi,
}

impl RoundingMode {
    /// Returns the PTX modifier *including* its leading dot (e.g. `".rn"`).
    ///
    /// Note the difference from `CmpOp::to_ptx_string` and
    /// `WmmaLayout::to_ptx_string`, which return their keyword without a dot.
    #[must_use]
    pub const fn to_ptx_string(self) -> &'static str {
        match self {
            RoundingMode::Rn => ".rn",
            RoundingMode::Rz => ".rz",
            RoundingMode::Rp => ".rp",
            RoundingMode::Rm => ".rm",
            // Integer-rounding forms.
            RoundingMode::Rni => ".rni",
            RoundingMode::Rzi => ".rzi",
            RoundingMode::Rpi => ".rpi",
            RoundingMode::Rmi => ".rmi",
        }
    }
}
/// A single PTX instruction, filled in through the builder methods on this type.
///
/// Start from `PtxInstruction::new(op, ty)`, which leaves every optional part
/// empty, then chain `dst`/`push_dst`/`src`/`predicated`/`space`/`rounding`/
/// `label`/`with_src_type` to populate the remaining fields.
#[derive(Debug, Clone)]
pub struct PtxInstruction {
/// Opcode.
pub op: PtxOp,
/// Instruction type. It also decides where `dst` stores a destination
/// (`V2F32`/`V4F32` go to `dsts`, every other type goes to `dst`).
pub ty: PtxType,
/// Optional second type, set by `with_src_type`. Presumably used by ops that name
/// both a destination and a source type (e.g. `Cvt`) — confirm against the emitter.
pub src_type: Option<PtxType>,
/// Single destination, set by `dst` when `ty` is not `V2F32`/`V4F32`;
/// each such call overwrites the previous value.
pub dst: Option<Operand>,
/// Multiple destinations. `dst` appends here when `ty` is `V2F32`/`V4F32`,
/// and `push_dst` always appends here.
pub dsts: Vec<Operand>,
/// Source operands, in the order they were added with `src`.
pub srcs: Vec<Operand>,
/// Optional guard predicate, set by `predicated`.
pub predicate: Option<Predicate>,
/// Optional state-space qualifier, set by `space`.
pub state_space: Option<PtxStateSpace>,
/// Optional rounding modifier, set by `rounding`.
pub rounding: Option<RoundingMode>,
/// Optional label string, set by `label`. NOTE(review): whether this labels the
/// instruction itself or names a branch target is not visible here.
pub label: Option<String>,
}
/// An operand of a `PtxInstruction` (a destination or a source).
#[derive(Debug, Clone)]
pub enum Operand {
/// A register from the `registers` module's `VirtualReg` type.
Reg(VirtualReg),
/// A special register (`registers::PtxReg`); presumably built-ins such as
/// thread/block indices — confirm in the registers module.
SpecialReg(super::registers::PtxReg),
/// Signed 64-bit integer immediate.
ImmI64(i64),
/// Unsigned 64-bit integer immediate.
ImmU64(u64),
/// 32-bit floating-point immediate.
ImmF32(f32),
/// 64-bit floating-point immediate.
ImmF64(f64),
/// A parameter referenced by name; presumably a kernel parameter (e.g. for
/// `LdParam`) — confirm against the emitter.
Param(String),
/// Register-plus-offset address.
Addr {
/// Base-address register.
base: VirtualReg,
/// Signed displacement added to `base`. NOTE(review): units presumably bytes — confirm.
offset: i32,
},
/// A label name; presumably a branch target (e.g. for `Bra`) — confirm against the emitter.
Label(String),
}
/// Guard predicate attached to an instruction via `PtxInstruction::predicated`.
#[derive(Debug, Clone)]
pub struct Predicate {
/// Predicate register that is tested.
pub reg: VirtualReg,
/// When `true` the guard is inverted (presumably emitted as `@!p` instead of
/// `@p` — confirm against the emitter).
pub negated: bool,
}
impl PtxInstruction {
    /// Creates an instruction with the given opcode and type.
    ///
    /// All optional parts (operands, predicate, state space, rounding,
    /// label, source type) start out empty.
    #[must_use]
    pub fn new(op: PtxOp, ty: PtxType) -> Self {
        Self {
            op,
            ty,
            label: None,
            rounding: None,
            state_space: None,
            predicate: None,
            srcs: vec![],
            dsts: vec![],
            dst: None,
            src_type: None,
        }
    }

    /// Records a secondary (source) type and returns the updated instruction.
    #[must_use]
    pub fn with_src_type(self, src_type: PtxType) -> Self {
        Self {
            src_type: Some(src_type),
            ..self
        }
    }

    /// Adds a destination operand.
    ///
    /// When the instruction type is `V2F32` or `V4F32` the operand is appended
    /// to `dsts`, so repeated calls collect several destinations. For every
    /// other type it replaces the single `dst`.
    #[must_use]
    pub fn dst(mut self, dst: Operand) -> Self {
        match self.ty {
            PtxType::V2F32 | PtxType::V4F32 => self.dsts.push(dst),
            _ => self.dst = Some(dst),
        }
        self
    }

    /// Appends a destination operand to `dsts`, whatever the instruction type.
    #[must_use]
    pub fn push_dst(mut self, dst: Operand) -> Self {
        self.dsts.push(dst);
        self
    }

    /// Appends a source operand; sources keep their insertion order.
    #[must_use]
    pub fn src(mut self, src: Operand) -> Self {
        self.srcs.push(src);
        self
    }

    /// Guards the instruction with `pred`, replacing any previous predicate.
    #[must_use]
    pub fn predicated(self, pred: Predicate) -> Self {
        Self {
            predicate: Some(pred),
            ..self
        }
    }

    /// Sets the state-space qualifier, replacing any previous one.
    #[must_use]
    pub fn space(self, space: PtxStateSpace) -> Self {
        Self {
            state_space: Some(space),
            ..self
        }
    }

    /// Sets the rounding modifier, replacing any previous one.
    #[must_use]
    pub fn rounding(self, mode: RoundingMode) -> Self {
        Self {
            rounding: Some(mode),
            ..self
        }
    }

    /// Sets the label, accepting anything convertible into a `String`.
    #[must_use]
    pub fn label(self, label: impl Into<String>) -> Self {
        Self {
            label: Some(label.into()),
            ..self
        }
    }
}
#[cfg(test)]
mod tests;