use super::operand::Operand;
use super::register::Register;
use super::types::{
AtomOp, CacheQualifier, CmpOp, FenceScope, MemorySpace, MulMode, PtxType, RoundingMode,
SpecialReg, VectorWidth,
};
/// Selects which `wmma` operation an `Instruction::Wmma` performs.
///
/// The exact PTX emitted for each kind is produced by the instruction
/// emitter, which is not part of this file. The per-variant notes below are
/// inferred from the variant names.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum WmmaOp {
    /// Load of the A operand fragment (presumably `wmma.load.a`).
    LoadA,
    /// Load of the B operand fragment (presumably `wmma.load.b`).
    LoadB,
    /// Store of the D result fragment (presumably `wmma.store.d`).
    StoreD,
    /// The matrix multiply-accumulate itself (presumably `wmma.mma`).
    Mma,
}
/// Matrix shape of a `wmma` operation, spelled `.mMnNkK` in PTX.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum WmmaShape {
    /// M = 16, N = 16, K = 16 (`.m16n16k16`).
    M16N16K16,
    /// M = 8, N = 32, K = 16 (`.m8n32k16`).
    M8N32K16,
    /// M = 32, N = 8, K = 16 (`.m32n8k16`).
    M32N8K16,
}

impl WmmaShape {
    /// Returns the PTX shape qualifier for this shape, including the leading
    /// `.` (for example `".m16n16k16"`).
    #[must_use]
    pub(crate) const fn as_ptx_str(self) -> &'static str {
        match self {
            WmmaShape::M16N16K16 => ".m16n16k16",
            WmmaShape::M8N32K16 => ".m8n32k16",
            WmmaShape::M32N8K16 => ".m32n8k16",
        }
    }
}
/// In-memory layout of a `wmma` matrix fragment.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum WmmaLayout {
    /// Row-major storage (PTX `.row`).
    RowMajor,
    /// Column-major storage (PTX `.col`).
    ColMajor,
}

impl WmmaLayout {
    /// Returns the PTX layout qualifier, dot included: `".row"` or `".col"`.
    #[must_use]
    pub(crate) const fn as_ptx_str(self) -> &'static str {
        use WmmaLayout::{ColMajor, RowMajor};
        match self {
            RowMajor => ".row",
            ColMajor => ".col",
        }
    }
}
/// Matrix shape of an `mma` operation, spelled `.mMnNkK` in PTX.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum MmaShape {
    /// `.m16n8k8`
    M16N8K8,
    /// `.m16n8k16`
    M16N8K16,
    /// `.m16n8k32`
    M16N8K32,
    /// `.m8n8k16`
    M8N8K16,
    /// `.m8n8k32`
    M8N8K32,
}

impl MmaShape {
    /// Returns the PTX shape qualifier for this shape, including the leading
    /// `.` (for example `".m16n8k16"`).
    #[must_use]
    pub const fn as_ptx_str(self) -> &'static str {
        match self {
            // M = 8 shapes.
            MmaShape::M8N8K16 => ".m8n8k16",
            MmaShape::M8N8K32 => ".m8n8k32",
            // M = 16 shapes.
            MmaShape::M16N8K8 => ".m16n8k8",
            MmaShape::M16N8K16 => ".m16n8k16",
            MmaShape::M16N8K32 => ".m16n8k32",
        }
    }
}
/// Matrix shape of a warpgroup `wgmma` operation.
///
/// Every variant has M = 64 and K = 16; only N varies (8 to 256).
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum WgmmaShape {
    /// `.m64n8k16`
    M64N8K16,
    /// `.m64n16k16`
    M64N16K16,
    /// `.m64n32k16`
    M64N32K16,
    /// `.m64n64k16`
    M64N64K16,
    /// `.m64n128k16`
    M64N128K16,
    /// `.m64n256k16`
    M64N256K16,
}

impl WgmmaShape {
    /// Returns the PTX shape qualifier for this shape, including the leading
    /// `.` (for example `".m64n128k16"`).
    #[must_use]
    pub(crate) const fn as_ptx_str(self) -> &'static str {
        use WgmmaShape::{M64N128K16, M64N16K16, M64N256K16, M64N32K16, M64N64K16, M64N8K16};
        match self {
            M64N8K16 => ".m64n8k16",
            M64N16K16 => ".m64n16k16",
            M64N32K16 => ".m64n32k16",
            M64N64K16 => ".m64n64k16",
            M64N128K16 => ".m64n128k16",
            M64N256K16 => ".m64n256k16",
        }
    }
}
/// Reduction operator carried by `Instruction::Redux`.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum ReduxOp {
    /// Sum (`.add`).
    Add,
    /// Minimum (`.min`).
    Min,
    /// Maximum (`.max`).
    Max,
    /// Bitwise AND (`.and`).
    And,
    /// Bitwise OR (`.or`).
    Or,
    /// Bitwise XOR (`.xor`).
    Xor,
}

impl ReduxOp {
    /// Returns the PTX operator qualifier, dot included (for example `".add"`).
    #[must_use]
    pub(crate) const fn as_ptx_str(self) -> &'static str {
        match self {
            // Arithmetic reductions.
            ReduxOp::Add => ".add",
            ReduxOp::Min => ".min",
            ReduxOp::Max => ".max",
            // Bitwise reductions.
            ReduxOp::And => ".and",
            ReduxOp::Or => ".or",
            ReduxOp::Xor => ".xor",
        }
    }
}
/// Shape and matrix count for `stmatrix`.
///
/// Every variant uses 8x8 matrices (`.m8n8`). The `xN` suffix is the PTX
/// `.num` qualifier, i.e. how many such matrices one instruction stores.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum StmatrixShape {
    /// One 8x8 matrix (`.m8n8.x1`).
    M8n8x1,
    /// Two 8x8 matrices (`.m8n8.x2`).
    M8n8x2,
    /// Four 8x8 matrices (`.m8n8.x4`).
    M8n8x4,
}

impl StmatrixShape {
    /// Returns the combined shape and count qualifiers, for example
    /// `".m8n8.x4"`.
    #[must_use]
    pub(crate) const fn as_ptx_str(self) -> &'static str {
        use StmatrixShape::{M8n8x1, M8n8x2, M8n8x4};
        match self {
            M8n8x1 => ".m8n8.x1",
            M8n8x2 => ".m8n8.x2",
            M8n8x4 => ".m8n8.x4",
        }
    }
}
/// Direction of a `setmaxnreg` adjustment to the per-thread register limit.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum SetmaxnregAction {
    /// Raise the limit (`.inc`).
    Inc,
    /// Lower the limit (`.dec`).
    Dec,
}

impl SetmaxnregAction {
    /// Returns the PTX action qualifier: `".inc"` or `".dec"`.
    #[must_use]
    pub(crate) const fn as_ptx_str(self) -> &'static str {
        match self {
            SetmaxnregAction::Inc => ".inc",
            SetmaxnregAction::Dec => ".dec",
        }
    }
}
/// Action performed by a `griddepcontrol` instruction.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum GridDepAction {
    /// Signal that dependent grids may be launched (`.launch_dependents`).
    LaunchDependents,
    /// Wait on prerequisite grids (`.wait`).
    Wait,
}

impl GridDepAction {
    /// Returns the PTX action qualifier: `".launch_dependents"` or `".wait"`.
    #[must_use]
    pub(crate) const fn as_ptx_str(self) -> &'static str {
        match self {
            GridDepAction::Wait => ".wait",
            GridDepAction::LaunchDependents => ".launch_dependents",
        }
    }
}
/// One instruction (or assembler-level directive/annotation) of the PTX IR.
///
/// Variant names mirror PTX mnemonics. The one-line summaries below are
/// inferred from those names and from the field names. The text actually
/// emitted for each variant is produced by the `emit_impl` module
/// (`instruction_emit.rs`), which is not visible here.
/// NOTE(review): confirm the per-variant summaries against that emitter.
#[derive(Debug, Clone)]
pub enum Instruction {
    /// `add`: `dst = a + b`.
    Add {
        ty: PtxType,
        dst: Register,
        a: Operand,
        b: Operand,
    },
    /// `sub`: `dst = a - b`.
    Sub {
        ty: PtxType,
        dst: Register,
        a: Operand,
        b: Operand,
    },
    /// `mul`: `dst = a * b`. `mode` is a [`MulMode`], presumably the
    /// `.lo`/`.hi`/`.wide` selector.
    Mul {
        ty: PtxType,
        mode: MulMode,
        dst: Register,
        a: Operand,
        b: Operand,
    },
    /// `mad`: `dst = a * b + c`, with `mode` as for `Mul`.
    Mad {
        ty: PtxType,
        mode: MulMode,
        dst: Register,
        a: Operand,
        b: Operand,
        c: Operand,
    },
    /// `mad.lo`: multiply-add keeping the low half; the mode is fixed by the
    /// variant.
    /// NOTE(review): the type field is spelled `typ` here (and in `MadHi`)
    /// but `ty` in most other variants. Renaming it would break callers.
    MadLo {
        typ: PtxType,
        dst: Register,
        a: Operand,
        b: Operand,
        c: Operand,
    },
    /// `mad.hi`: multiply-add keeping the high half; the mode is fixed by the
    /// variant.
    MadHi {
        typ: PtxType,
        dst: Register,
        a: Operand,
        b: Operand,
        c: Operand,
    },
    /// `mad.wide`. `src_typ` is the type of the sources. The result is
    /// presumably twice as wide; confirm in the emitter.
    MadWide {
        src_typ: PtxType,
        dst: Register,
        a: Operand,
        b: Operand,
        c: Operand,
    },
    /// `fma` with an explicit rounding mode: `dst = a * b + c`.
    Fma {
        rnd: RoundingMode,
        ty: PtxType,
        dst: Register,
        a: Operand,
        b: Operand,
        c: Operand,
    },
    /// `neg`: `dst = -src`.
    Neg {
        ty: PtxType,
        dst: Register,
        src: Operand,
    },
    /// `abs`: `dst = |src|`.
    Abs {
        ty: PtxType,
        dst: Register,
        src: Operand,
    },
    /// `min`: `dst = min(a, b)`.
    Min {
        ty: PtxType,
        dst: Register,
        a: Operand,
        b: Operand,
    },
    /// `max`: `dst = max(a, b)`.
    Max {
        ty: PtxType,
        dst: Register,
        a: Operand,
        b: Operand,
    },
    /// `addc`: add with carry-in. `carry_out` presumably selects the `.cc`
    /// form that also writes the carry flag.
    Addc {
        ty: PtxType,
        dst: Register,
        a: Operand,
        b: Operand,
        carry_out: bool,
    },
    /// `selp`: `dst = pred ? a : b`.
    Selp {
        ty: PtxType,
        dst: Register,
        a: Operand,
        b: Operand,
        pred: Register,
    },
    /// `brev`: reverse the bit order of `src`.
    Brev {
        ty: PtxType,
        dst: Register,
        src: Operand,
    },
    /// `clz`: count leading zero bits of `src`.
    Clz {
        ty: PtxType,
        dst: Register,
        src: Operand,
    },
    /// `popc`: count set bits of `src`.
    Popc {
        ty: PtxType,
        dst: Register,
        src: Operand,
    },
    /// `bfind`: find the most significant (non-sign) bit of `src`.
    Bfind {
        ty: PtxType,
        dst: Register,
        src: Operand,
    },
    /// `bfe`: extract `len` bits of `src` starting at bit `start`.
    Bfe {
        ty: PtxType,
        dst: Register,
        src: Operand,
        start: Operand,
        len: Operand,
    },
    /// `bfi`: insert `len` bits of `insert` into `base` starting at bit `start`.
    Bfi {
        ty: PtxType,
        dst: Register,
        insert: Operand,
        base: Operand,
        start: Operand,
        len: Operand,
    },
    /// `rcp`: `dst = 1 / src`. `rnd == None` presumably selects the
    /// approximate form.
    Rcp {
        rnd: Option<RoundingMode>,
        ty: PtxType,
        dst: Register,
        src: Operand,
    },
    /// `rsqrt`: `dst = 1 / sqrt(src)`. `approx` presumably toggles `.approx`.
    Rsqrt {
        approx: bool,
        ty: PtxType,
        dst: Register,
        src: Operand,
    },
    /// `sqrt`: `dst = sqrt(src)`, with an optional rounding mode as for `Rcp`.
    Sqrt {
        rnd: Option<RoundingMode>,
        ty: PtxType,
        dst: Register,
        src: Operand,
    },
    /// `ex2`: `dst = 2^src`. `approx` presumably toggles `.approx`.
    Ex2 {
        approx: bool,
        ty: PtxType,
        dst: Register,
        src: Operand,
    },
    /// `lg2`: `dst = log2(src)`. `approx` presumably toggles `.approx`.
    Lg2 {
        approx: bool,
        ty: PtxType,
        dst: Register,
        src: Operand,
    },
    /// `sin`: `dst = sin(src)`. `approx` presumably toggles `.approx`.
    Sin {
        approx: bool,
        ty: PtxType,
        dst: Register,
        src: Operand,
    },
    /// `cos`: `dst = cos(src)`. `approx` presumably toggles `.approx`.
    Cos {
        approx: bool,
        ty: PtxType,
        dst: Register,
        src: Operand,
    },
    /// `shl`: `dst = src << amount`.
    Shl {
        ty: PtxType,
        dst: Register,
        src: Operand,
        amount: Operand,
    },
    /// `shr`: `dst = src >> amount`. Whether the shift is arithmetic or
    /// logical presumably follows `ty`.
    Shr {
        ty: PtxType,
        dst: Register,
        src: Operand,
        amount: Operand,
    },
    /// `div`: `dst = a / b`.
    Div {
        ty: PtxType,
        dst: Register,
        a: Operand,
        b: Operand,
    },
    /// `rem`: `dst = a % b`.
    Rem {
        ty: PtxType,
        dst: Register,
        a: Operand,
        b: Operand,
    },
    /// `and`: bitwise `dst = a & b`.
    And {
        ty: PtxType,
        dst: Register,
        a: Operand,
        b: Operand,
    },
    /// `or`: bitwise `dst = a | b`.
    Or {
        ty: PtxType,
        dst: Register,
        a: Operand,
        b: Operand,
    },
    /// `xor`: bitwise `dst = a ^ b`.
    Xor {
        ty: PtxType,
        dst: Register,
        a: Operand,
        b: Operand,
    },
    /// `setp`: write the comparison `a <cmp> b` into the predicate `dst`.
    SetP {
        cmp: CmpOp,
        ty: PtxType,
        dst: Register,
        a: Operand,
        b: Operand,
    },
    /// `ld`: load from `addr` in state space `space`, with a cache
    /// `qualifier` and optional vectorization `vec`.
    /// NOTE(review): only a single `dst` register is modeled even when `vec`
    /// requests a vector; confirm how the emitter expands it.
    Load {
        space: MemorySpace,
        qualifier: CacheQualifier,
        vec: VectorWidth,
        ty: PtxType,
        dst: Register,
        addr: Operand,
    },
    /// `st`: store `src` to `addr`, with the same qualifiers as `Load`.
    /// NOTE(review): only a single `src` register is modeled even for
    /// vector stores.
    Store {
        space: MemorySpace,
        qualifier: CacheQualifier,
        vec: VectorWidth,
        ty: PtxType,
        addr: Operand,
        src: Register,
    },
    /// `cp.async`: asynchronous copy of `bytes` bytes from `src_global` to
    /// `dst_shared`.
    CpAsync {
        bytes: u32,
        dst_shared: Operand,
        src_global: Operand,
    },
    /// `cp.async.commit_group`.
    CpAsyncCommit,
    /// `cp.async.wait_group n`.
    CpAsyncWait {
        n: u32,
    },
    /// `cvt`: convert `src` from `src_ty` to `dst_ty`, with an optional
    /// rounding mode.
    Cvt {
        rnd: Option<RoundingMode>,
        dst_ty: PtxType,
        src_ty: PtxType,
        dst: Register,
        src: Operand,
    },
    /// `bra` to the label `target`. `predicate`, when present, guards the
    /// branch with a register plus a flag. The flag presumably means
    /// "negated" (`@!p`). NOTE(review): confirm the polarity in the emitter.
    Branch {
        target: String,
        predicate: Option<(Register, bool)>,
    },
    /// Definition of the label with the given name.
    Label(String),
    /// `ret`.
    Return,
    /// `bar.sync` on barrier `id`.
    BarSync {
        id: u32,
    },
    /// `bar.arrive` on barrier `id`, expecting `count` threads.
    BarArrive {
        id: u32,
        count: u32,
    },
    /// `fence.acq_rel` at `scope`.
    FenceAcqRel {
        scope: FenceScope,
    },
    /// A `wmma` operation selected by `op`. `fragments` are the fragment
    /// registers. `addr`/`stride` are optional and are presumably only set
    /// for the load/store operations.
    Wmma {
        op: WmmaOp,
        shape: WmmaShape,
        layout: WmmaLayout,
        ty: PtxType,
        fragments: Vec<Register>,
        addr: Option<Operand>,
        stride: Option<Operand>,
    },
    /// `mma.sync` with the given shape and per-operand types. The `*_regs`
    /// vectors hold the D, A, B and C fragment registers.
    Mma {
        shape: MmaShape,
        a_ty: PtxType,
        b_ty: PtxType,
        c_ty: PtxType,
        d_ty: PtxType,
        d_regs: Vec<Register>,
        a_regs: Vec<Register>,
        b_regs: Vec<Register>,
        c_regs: Vec<Register>,
    },
    /// Warpgroup MMA (presumably `wgmma.mma_async`) that reads A and B
    /// through the descriptor registers `desc_a`/`desc_b` and accumulates
    /// into `d_regs`. The `scale_*`/`trans_*` fields are presumably the
    /// instruction's immediate operands.
    Wgmma {
        shape: WgmmaShape,
        d_ty: PtxType,
        a_ty: PtxType,
        b_ty: PtxType,
        desc_a: Register,
        desc_b: Register,
        d_regs: Vec<Register>,
        scale_d: i32,
        imm_scale_a: i32,
        imm_scale_b: i32,
        trans_a: i32,
        trans_b: i32,
    },
    /// TMA tensor load into `dst_shared`, using tensor descriptor `desc`,
    /// coordinates `coords`, and completion barrier `barrier` (presumably
    /// `cp.async.bulk.tensor`).
    TmaLoad {
        dst_shared: Operand,
        desc: Register,
        coords: Vec<Register>,
        barrier: Register,
    },
    /// `atom`: atomic `op` on `addr` in `space` with operand `src`. The
    /// previous value is written to `dst`.
    Atom {
        space: MemorySpace,
        op: AtomOp,
        ty: PtxType,
        dst: Register,
        addr: Operand,
        src: Operand,
    },
    /// `atom.cas`: if the value at `addr` equals `compare`, store `value`.
    /// The previous value is written to `dst`.
    AtomCas {
        space: MemorySpace,
        ty: PtxType,
        dst: Register,
        addr: Operand,
        compare: Operand,
        value: Operand,
    },
    /// `red`: like `Atom` but with no destination, so the previous value is
    /// not returned.
    Red {
        space: MemorySpace,
        op: AtomOp,
        ty: PtxType,
        addr: Operand,
        src: Operand,
    },
    /// Floating-point atomic add on global memory. This is a dedicated form
    /// with no `space`/`op` fields; see also `Atom`.
    AtomGlobalAddFloat {
        ty: PtxType,
        dst: Register,
        addr: Operand,
        src: Operand,
    },
    /// `mov` from the special register `special` into `dst`.
    MovSpecial {
        dst: Register,
        special: SpecialReg,
    },
    /// `ld.param` of the kernel parameter named `param_name` into `dst`.
    LoadParam {
        ty: PtxType,
        dst: Register,
        param_name: String,
    },
    /// A comment carried into the output text.
    Comment(String),
    /// Raw PTX text, presumably emitted verbatim.
    Raw(String),
    /// A pragma directive with the given text.
    Pragma(String),
    /// `dp4a`: four-way 8-bit dot product of `a` and `b`, accumulated into
    /// `c`. `signed_a`/`signed_b` select the signedness of each input.
    Dp4a {
        dst: Register,
        a: Operand,
        b: Operand,
        c: Operand,
        signed_a: bool,
        signed_b: bool,
    },
    /// `dp2a`: two-way dot product accumulated into `c`, with signedness as
    /// for `Dp4a`. `lo` presumably selects the `.lo` (versus `.hi`) half.
    Dp2a {
        dst: Register,
        a: Operand,
        b: Operand,
        c: Operand,
        signed_a: bool,
        signed_b: bool,
        lo: bool,
    },
    /// 1D texture fetch from `tex_ref` at `coord`.
    Tex1d {
        ty: PtxType,
        dst: Register,
        tex_ref: String,
        coord: Operand,
    },
    /// 2D texture fetch from `tex_ref` at (`coord_x`, `coord_y`).
    Tex2d {
        ty: PtxType,
        dst: Register,
        tex_ref: String,
        coord_x: Operand,
        coord_y: Operand,
    },
    /// 3D texture fetch from `tex_ref` at (`coord_x`, `coord_y`, `coord_z`).
    Tex3d {
        ty: PtxType,
        dst: Register,
        tex_ref: String,
        coord_x: Operand,
        coord_y: Operand,
        coord_z: Operand,
    },
    /// Surface load (presumably `suld`) from `surf_ref` at `coord`.
    SurfLoad {
        ty: PtxType,
        dst: Register,
        surf_ref: String,
        coord: Operand,
    },
    /// Surface store (presumably `sust`) of `src` to `surf_ref` at `coord`.
    SurfStore {
        ty: PtxType,
        surf_ref: String,
        coord: Operand,
        src: Register,
    },
    /// `redux.sync`: reduce `src` with `op` across the lanes in
    /// `membership_mask`, writing the result to `dst`.
    Redux {
        op: ReduxOp,
        dst: Register,
        src: Operand,
        membership_mask: u32,
    },
    /// `stmatrix`: store the fragment in `src` to `dst_addr` with the given
    /// shape, optionally transposed.
    /// NOTE(review): only a single `src` register is modeled even for the
    /// `.x2`/`.x4` shapes.
    Stmatrix {
        dst_addr: Operand,
        src: Register,
        shape: StmatrixShape,
        trans: bool,
    },
    /// `elect.sync` over the lanes in `membership_mask`. The result is
    /// written to `dst`.
    ElectSync {
        dst: Register,
        membership_mask: u32,
    },
    /// `setmaxnreg`: adjust the per-thread register limit to `reg_count` in
    /// the direction given by `action`.
    Setmaxnreg {
        reg_count: u32,
        action: SetmaxnregAction,
    },
    /// `griddepcontrol` with the given action.
    Griddepcontrol {
        action: GridDepAction,
    },
    /// Proxy fence (presumably `fence.proxy`) parameterized by `scope` and
    /// `space`.
    FenceProxy {
        scope: FenceScope,
        space: MemorySpace,
    },
    /// `mbarrier.init` of the barrier at `addr` with arrival count `count`.
    MbarrierInit {
        addr: Operand,
        count: Operand,
    },
    /// `mbarrier.arrive` on the barrier at `addr`.
    MbarrierArrive {
        addr: Operand,
    },
    /// Wait on the barrier at `addr` for `phase` (presumably a
    /// `mbarrier.try_wait.parity` loop).
    MbarrierWait {
        addr: Operand,
        phase: Operand,
    },
    /// `tcgen05.mma` taking A and B matrix descriptors.
    /// NOTE(review): no accumulator/destination operand is modeled here.
    Tcgen05Mma {
        a_desc: Register,
        b_desc: Register,
    },
    /// Cluster-wide barrier (presumably `barrier.cluster`).
    BarrierCluster,
    /// Cluster-scope fence.
    FenceCluster,
    /// Bulk asynchronous copy (presumably `cp.async.bulk`) from `src_gmem`
    /// to `dst_smem`.
    /// NOTE(review): the role of `desc` is not evident from here; document it.
    CpAsyncBulk {
        dst_smem: Register,
        src_gmem: Register,
        desc: Register,
    },
    /// `ldmatrix`: load `num_fragments` matrices (presumably 1, 2 or 4) from
    /// `src_addr` into `dst_regs`, optionally transposed.
    Ldmatrix {
        num_fragments: u32,
        trans: bool,
        dst_regs: Vec<Register>,
        src_addr: Operand,
    },
}
// Implementation of `Instruction` emission, kept in `instruction_emit.rs` in
// the same directory as this file. The role is inferred from the module and
// file names (NOTE(review): confirm). The module is private; it is a child of
// this module, so it can call the `pub(crate)` `as_ptx_str` helpers above.
#[path = "instruction_emit.rs"]
mod emit_impl;
// Unit tests for this module, kept in `instruction_tests.rs` and compiled only
// under `cfg(test)`.
#[cfg(test)]
#[path = "instruction_tests.rs"]
mod tests;