use super::operand::Operand;
use super::register::Register;
use super::types::{
AtomOp, CacheQualifier, CmpOp, FenceScope, MemorySpace, MulMode, PtxType, RoundingMode,
SpecialReg, VectorWidth,
};
/// Which step of a `wmma` warp-level matrix operation to emit.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum WmmaOp {
    /// Load a fragment of matrix A (`wmma.load.a`).
    LoadA,
    /// Load a fragment of matrix B (`wmma.load.b`).
    LoadB,
    /// Store the result fragment D (`wmma.store.d`).
    StoreD,
    /// The multiply-accumulate step itself (`wmma.mma`).
    Mma,
}
/// Tile shape (`M x N x K`) of a `wmma` operation.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum WmmaShape {
    /// 16 x 16 x 16 tile.
    M16N16K16,
    /// 8 x 32 x 16 tile.
    M8N32K16,
    /// 32 x 8 x 16 tile.
    M32N8K16,
}

impl WmmaShape {
    /// PTX shape qualifier, leading dot included (for example `.m16n16k16`).
    #[must_use]
    const fn as_ptx_str(self) -> &'static str {
        match self {
            Self::M16N16K16 => ".m16n16k16",
            Self::M8N32K16 => ".m8n32k16",
            Self::M32N8K16 => ".m32n8k16",
        }
    }
}
/// Memory layout of a `wmma` matrix fragment.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum WmmaLayout {
    /// Row-major storage (`.row`).
    RowMajor,
    /// Column-major storage (`.col`).
    ColMajor,
}

impl WmmaLayout {
    /// PTX layout qualifier, leading dot included.
    #[must_use]
    const fn as_ptx_str(self) -> &'static str {
        match self {
            Self::RowMajor => ".row",
            Self::ColMajor => ".col",
        }
    }
}
/// Tile shape (`M x N x K`) of an `mma.sync` operation.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum MmaShape {
    /// 16 x 8 x 8 tile.
    M16N8K8,
    /// 16 x 8 x 16 tile.
    M16N8K16,
    /// 16 x 8 x 32 tile.
    M16N8K32,
    /// 8 x 8 x 16 tile.
    M8N8K16,
    /// 8 x 8 x 32 tile.
    M8N8K32,
}

impl MmaShape {
    /// PTX shape qualifier, leading dot included (for example `.m16n8k16`).
    #[must_use]
    pub const fn as_ptx_str(self) -> &'static str {
        match self {
            Self::M16N8K8 => ".m16n8k8",
            Self::M16N8K16 => ".m16n8k16",
            Self::M16N8K32 => ".m16n8k32",
            Self::M8N8K16 => ".m8n8k16",
            Self::M8N8K32 => ".m8n8k32",
        }
    }
}
/// Tile shape of a warpgroup `wgmma.mma_async` operation; `M` is 64 and `K` is 16
/// for every variant, only `N` varies.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum WgmmaShape {
    /// N = 8.
    M64N8K16,
    /// N = 16.
    M64N16K16,
    /// N = 32.
    M64N32K16,
    /// N = 64.
    M64N64K16,
    /// N = 128.
    M64N128K16,
    /// N = 256.
    M64N256K16,
}

impl WgmmaShape {
    /// PTX shape qualifier, leading dot included (for example `.m64n128k16`).
    #[must_use]
    const fn as_ptx_str(self) -> &'static str {
        match self {
            Self::M64N8K16 => ".m64n8k16",
            Self::M64N16K16 => ".m64n16k16",
            Self::M64N32K16 => ".m64n32k16",
            Self::M64N64K16 => ".m64n64k16",
            Self::M64N128K16 => ".m64n128k16",
            Self::M64N256K16 => ".m64n256k16",
        }
    }
}
/// Reduction operator of a `redux.sync` instruction.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum ReduxOp {
    /// Integer sum (`.add`).
    Add,
    /// Minimum (`.min`).
    Min,
    /// Maximum (`.max`).
    Max,
    /// Bitwise AND (`.and`).
    And,
    /// Bitwise OR (`.or`).
    Or,
    /// Bitwise XOR (`.xor`).
    Xor,
}

impl ReduxOp {
    /// PTX operator qualifier, leading dot included.
    #[must_use]
    const fn as_ptx_str(self) -> &'static str {
        match self {
            Self::Add => ".add",
            Self::Min => ".min",
            Self::Max => ".max",
            Self::And => ".and",
            Self::Or => ".or",
            Self::Xor => ".xor",
        }
    }
}
/// Shape and matrix count of a `stmatrix` instruction: one, two or four 8x8 matrices.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum StmatrixShape {
    /// One 8x8 matrix (`.m8n8.x1`).
    M8n8x1,
    /// Two 8x8 matrices (`.m8n8.x2`).
    M8n8x2,
    /// Four 8x8 matrices (`.m8n8.x4`).
    M8n8x4,
}

impl StmatrixShape {
    /// PTX `.shape.num` qualifier pair, leading dot included.
    #[must_use]
    const fn as_ptx_str(self) -> &'static str {
        match self {
            Self::M8n8x1 => ".m8n8.x1",
            Self::M8n8x2 => ".m8n8.x2",
            Self::M8n8x4 => ".m8n8.x4",
        }
    }
}
/// Direction of a `setmaxnreg` register-budget adjustment.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum SetmaxnregAction {
    /// Raise the per-thread register limit (`.inc`).
    Inc,
    /// Lower the per-thread register limit (`.dec`).
    Dec,
}

impl SetmaxnregAction {
    /// PTX action qualifier, leading dot included.
    #[must_use]
    const fn as_ptx_str(self) -> &'static str {
        match self {
            Self::Inc => ".inc",
            Self::Dec => ".dec",
        }
    }
}
/// Action performed by a `griddepcontrol` instruction.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum GridDepAction {
    /// `.launch_dependents` qualifier.
    LaunchDependents,
    /// `.wait` qualifier.
    Wait,
}

impl GridDepAction {
    /// PTX action qualifier, leading dot included.
    #[must_use]
    const fn as_ptx_str(self) -> &'static str {
        match self {
            Self::LaunchDependents => ".launch_dependents",
            Self::Wait => ".wait",
        }
    }
}
/// A single PTX instruction, or assembler-level pseudo-instruction, inside a kernel body.
///
/// Variants are grouped by category below. Type fields (`ty`, `typ`, `src_typ`, `*_ty`)
/// carry the PTX type qualifier appended to the mnemonic; register and operand fields
/// are rendered through their `Display` implementations when emitted.
#[derive(Clone, Debug)]
pub enum Instruction {
    // ---------------------------------------------------------------------------
    // Integer and floating-point arithmetic
    // ---------------------------------------------------------------------------
    /// `add` — `dst = a + b`.
    Add {
        ty: PtxType,
        dst: Register,
        a: Operand,
        b: Operand,
    },
    /// `sub` — `dst = a - b`.
    Sub {
        ty: PtxType,
        dst: Register,
        a: Operand,
        b: Operand,
    },
    /// `mul` with an explicit [`MulMode`] qualifier.
    Mul {
        ty: PtxType,
        mode: MulMode,
        dst: Register,
        a: Operand,
        b: Operand,
    },
    /// `mad` (multiply-add) with an explicit [`MulMode`] qualifier.
    Mad {
        ty: PtxType,
        mode: MulMode,
        dst: Register,
        a: Operand,
        b: Operand,
        c: Operand,
    },
    /// `mad.lo` — multiply-add keeping the low half of the product.
    MadLo {
        typ: PtxType,
        dst: Register,
        a: Operand,
        b: Operand,
        c: Operand,
    },
    /// `mad.hi` — multiply-add keeping the high half of the product.
    MadHi {
        typ: PtxType,
        dst: Register,
        a: Operand,
        b: Operand,
        c: Operand,
    },
    /// `mad.wide` — `src_typ` is the (narrower) source operand type.
    MadWide {
        src_typ: PtxType,
        dst: Register,
        a: Operand,
        b: Operand,
        c: Operand,
    },
    /// `fma` with a mandatory rounding mode.
    Fma {
        rnd: RoundingMode,
        ty: PtxType,
        dst: Register,
        a: Operand,
        b: Operand,
        c: Operand,
    },
    /// `div` — `dst = a / b`.
    Div {
        ty: PtxType,
        dst: Register,
        a: Operand,
        b: Operand,
    },
    /// `rem` — `dst = a % b`.
    Rem {
        ty: PtxType,
        dst: Register,
        a: Operand,
        b: Operand,
    },
    /// `neg` — arithmetic negation.
    Neg {
        ty: PtxType,
        dst: Register,
        src: Operand,
    },
    /// `abs` — absolute value.
    Abs {
        ty: PtxType,
        dst: Register,
        src: Operand,
    },
    /// `min` of two operands.
    Min {
        ty: PtxType,
        dst: Register,
        a: Operand,
        b: Operand,
    },
    /// `max` of two operands.
    Max {
        ty: PtxType,
        dst: Register,
        a: Operand,
        b: Operand,
    },
    /// `dp4a` — each `signed_*` flag selects `.s32` (true) or `.u32` (false).
    Dp4a {
        dst: Register,
        a: Operand,
        b: Operand,
        c: Operand,
        signed_a: bool,
        signed_b: bool,
    },
    /// `dp2a` — `lo` selects the `.lo` (true) or `.hi` (false) half; signedness as in
    /// [`Instruction::Dp4a`].
    Dp2a {
        dst: Register,
        a: Operand,
        b: Operand,
        c: Operand,
        signed_a: bool,
        signed_b: bool,
        lo: bool,
    },

    // ---------------------------------------------------------------------------
    // Special-function math
    // ---------------------------------------------------------------------------
    /// `rcp` — reciprocal; the rounding qualifier is omitted when `rnd` is `None`.
    Rcp {
        rnd: Option<RoundingMode>,
        ty: PtxType,
        dst: Register,
        src: Operand,
    },
    /// `rsqrt` — reciprocal square root; `approx` adds the `.approx` qualifier.
    Rsqrt {
        approx: bool,
        ty: PtxType,
        dst: Register,
        src: Operand,
    },
    /// `sqrt` — the rounding qualifier is omitted when `rnd` is `None`.
    Sqrt {
        rnd: Option<RoundingMode>,
        ty: PtxType,
        dst: Register,
        src: Operand,
    },
    /// `ex2` — base-2 exponential; `approx` adds the `.approx` qualifier.
    Ex2 {
        approx: bool,
        ty: PtxType,
        dst: Register,
        src: Operand,
    },
    /// `lg2` — base-2 logarithm; `approx` adds the `.approx` qualifier.
    Lg2 {
        approx: bool,
        ty: PtxType,
        dst: Register,
        src: Operand,
    },
    /// `sin`; `approx` adds the `.approx` qualifier.
    Sin {
        approx: bool,
        ty: PtxType,
        dst: Register,
        src: Operand,
    },
    /// `cos`; `approx` adds the `.approx` qualifier.
    Cos {
        approx: bool,
        ty: PtxType,
        dst: Register,
        src: Operand,
    },

    // ---------------------------------------------------------------------------
    // Bit manipulation, shifts and logic
    // ---------------------------------------------------------------------------
    /// `brev` — bit reversal.
    Brev {
        ty: PtxType,
        dst: Register,
        src: Operand,
    },
    /// `clz` — count leading zeros.
    Clz {
        ty: PtxType,
        dst: Register,
        src: Operand,
    },
    /// `popc` — population count.
    Popc {
        ty: PtxType,
        dst: Register,
        src: Operand,
    },
    /// `bfind` — find most significant set bit.
    Bfind {
        ty: PtxType,
        dst: Register,
        src: Operand,
    },
    /// `bfe` — bit-field extract of `len` bits starting at `start`.
    Bfe {
        ty: PtxType,
        dst: Register,
        src: Operand,
        start: Operand,
        len: Operand,
    },
    /// `bfi` — bit-field insert of `insert` into `base`.
    Bfi {
        ty: PtxType,
        dst: Register,
        insert: Operand,
        base: Operand,
        start: Operand,
        len: Operand,
    },
    /// `shl` — shift left by `amount`.
    Shl {
        ty: PtxType,
        dst: Register,
        src: Operand,
        amount: Operand,
    },
    /// `shr` — shift right by `amount`.
    Shr {
        ty: PtxType,
        dst: Register,
        src: Operand,
        amount: Operand,
    },
    /// `and` — bitwise AND.
    And {
        ty: PtxType,
        dst: Register,
        a: Operand,
        b: Operand,
    },
    /// `or` — bitwise OR.
    Or {
        ty: PtxType,
        dst: Register,
        a: Operand,
        b: Operand,
    },
    /// `xor` — bitwise XOR.
    Xor {
        ty: PtxType,
        dst: Register,
        a: Operand,
        b: Operand,
    },

    // ---------------------------------------------------------------------------
    // Comparison and conversion
    // ---------------------------------------------------------------------------
    /// `setp` — writes the comparison result of `a` and `b` into `dst`.
    SetP {
        cmp: CmpOp,
        ty: PtxType,
        dst: Register,
        a: Operand,
        b: Operand,
    },
    /// `cvt` from `src_ty` to `dst_ty`; rounding is omitted when `rnd` is `None`.
    Cvt {
        rnd: Option<RoundingMode>,
        dst_ty: PtxType,
        src_ty: PtxType,
        dst: Register,
        src: Operand,
    },

    // ---------------------------------------------------------------------------
    // Memory access and special registers
    // ---------------------------------------------------------------------------
    /// `ld` from `addr` into `dst`.
    Load {
        space: MemorySpace,
        qualifier: CacheQualifier,
        vec: VectorWidth,
        ty: PtxType,
        dst: Register,
        addr: Operand,
    },
    /// `st` of `src` to `addr`.
    Store {
        space: MemorySpace,
        qualifier: CacheQualifier,
        vec: VectorWidth,
        ty: PtxType,
        addr: Operand,
        src: Register,
    },
    /// `ld.param` of the kernel parameter named `param_name`.
    LoadParam {
        ty: PtxType,
        dst: Register,
        param_name: String,
    },
    /// `mov.u32` from a special register such as a thread or block index.
    MovSpecial {
        dst: Register,
        special: SpecialReg,
    },

    // ---------------------------------------------------------------------------
    // Asynchronous copies and tensor-memory-accelerator loads
    // ---------------------------------------------------------------------------
    /// `cp.async.ca.shared.global` of `bytes` bytes.
    CpAsync {
        bytes: u32,
        dst_shared: Operand,
        src_global: Operand,
    },
    /// `cp.async.commit_group`.
    CpAsyncCommit,
    /// `cp.async.wait_group n`.
    CpAsyncWait {
        n: u32,
    },
    /// `cp.async.bulk.tensor` (1-D form).
    CpAsyncBulk {
        dst_smem: Register,
        src_gmem: Register,
        desc: Register,
    },
    /// `cp.async.bulk.tensor` tiled load completing on the mbarrier `barrier`.
    TmaLoad {
        dst_shared: Operand,
        desc: Register,
        coords: Vec<Register>,
        barrier: Register,
    },

    // ---------------------------------------------------------------------------
    // Atomics and reductions
    // ---------------------------------------------------------------------------
    /// `atom` read-modify-write returning the old value in `dst`.
    Atom {
        space: MemorySpace,
        op: AtomOp,
        ty: PtxType,
        dst: Register,
        addr: Operand,
        src: Operand,
    },
    /// `atom.cas` — compare-and-swap.
    AtomCas {
        space: MemorySpace,
        ty: PtxType,
        dst: Register,
        addr: Operand,
        compare: Operand,
        value: Operand,
    },
    /// `red` — reduction to memory without a result register.
    Red {
        space: MemorySpace,
        op: AtomOp,
        ty: PtxType,
        addr: Operand,
        src: Operand,
    },
    /// `redux.sync` — `membership_mask` is emitted as an 8-digit hex literal.
    Redux {
        op: ReduxOp,
        dst: Register,
        src: Operand,
        membership_mask: u32,
    },

    // ---------------------------------------------------------------------------
    // Control flow
    // ---------------------------------------------------------------------------
    /// `bra` to `target`; the optional predicate's `bool` requests negation (`@!p`).
    Branch {
        target: String,
        predicate: Option<(Register, bool)>,
    },
    /// A label definition.
    Label(String),
    /// `ret`.
    Return,

    // ---------------------------------------------------------------------------
    // Barriers, fences and synchronization
    // ---------------------------------------------------------------------------
    /// `bar.sync id`.
    BarSync {
        id: u32,
    },
    /// `bar.arrive id, count`.
    BarArrive {
        id: u32,
        count: u32,
    },
    /// `fence.acq_rel` at the given scope.
    FenceAcqRel {
        scope: FenceScope,
    },
    /// `fence.proxy.async` with scope and state-space qualifiers.
    FenceProxy {
        scope: FenceScope,
        space: MemorySpace,
    },
    /// `fence.mbarrier_init.release.cluster`.
    FenceCluster,
    /// `barrier.cluster.arrive`.
    BarrierCluster,
    /// `mbarrier.init` of the barrier at `addr` with an expected arrival `count`.
    MbarrierInit {
        addr: Operand,
        count: Operand,
    },
    /// `mbarrier.arrive` on the barrier at `addr`.
    MbarrierArrive {
        addr: Operand,
    },
    /// `mbarrier.try_wait.parity` on the barrier at `addr`.
    MbarrierWait {
        addr: Operand,
        phase: Operand,
    },
    /// `elect.sync` — `membership_mask` is emitted as an 8-digit hex literal.
    ElectSync {
        dst: Register,
        membership_mask: u32,
    },
    /// `griddepcontrol`.
    Griddepcontrol {
        action: GridDepAction,
    },
    /// `setmaxnreg` — adjusts the register budget to `reg_count`.
    Setmaxnreg {
        reg_count: u32,
        action: SetmaxnregAction,
    },

    // ---------------------------------------------------------------------------
    // Matrix / tensor-core operations
    // ---------------------------------------------------------------------------
    /// `wmma.*`; `addr` and `stride` are used only by the load/store forms.
    Wmma {
        op: WmmaOp,
        shape: WmmaShape,
        layout: WmmaLayout,
        ty: PtxType,
        fragments: Vec<Register>,
        addr: Option<Operand>,
        stride: Option<Operand>,
    },
    /// `mma.sync.aligned` with `.row.col` layouts.
    Mma {
        shape: MmaShape,
        a_ty: PtxType,
        b_ty: PtxType,
        c_ty: PtxType,
        d_ty: PtxType,
        d_regs: Vec<Register>,
        a_regs: Vec<Register>,
        b_regs: Vec<Register>,
        c_regs: Vec<Register>,
    },
    /// `wgmma.mma_async` — the five trailing integers are emitted verbatim in order.
    Wgmma {
        shape: WgmmaShape,
        d_ty: PtxType,
        a_ty: PtxType,
        b_ty: PtxType,
        desc_a: Register,
        desc_b: Register,
        d_regs: Vec<Register>,
        scale_d: i32,
        imm_scale_a: i32,
        imm_scale_b: i32,
        trans_a: i32,
        trans_b: i32,
    },
    /// `ldmatrix` — `num_fragments` of 2 or 4 selects `.x2`/`.x4`; anything else `.x1`.
    Ldmatrix {
        num_fragments: u32,
        trans: bool,
        dst_regs: Vec<Register>,
        src_addr: Operand,
    },
    /// `stmatrix` to shared memory; `trans` adds the `.trans` qualifier.
    Stmatrix {
        dst_addr: Operand,
        src: Register,
        shape: StmatrixShape,
        trans: bool,
    },
    /// `tcgen05.mma` taking two descriptor registers.
    Tcgen05Mma {
        a_desc: Register,
        b_desc: Register,
    },

    // ---------------------------------------------------------------------------
    // Texture and surface access
    // ---------------------------------------------------------------------------
    /// `tex.1d.v4` fetch.
    Tex1d {
        ty: PtxType,
        dst: Register,
        tex_ref: String,
        coord: Operand,
    },
    /// `tex.2d.v4` fetch.
    Tex2d {
        ty: PtxType,
        dst: Register,
        tex_ref: String,
        coord_x: Operand,
        coord_y: Operand,
    },
    /// `tex.3d.v4` fetch.
    Tex3d {
        ty: PtxType,
        dst: Register,
        tex_ref: String,
        coord_x: Operand,
        coord_y: Operand,
        coord_z: Operand,
    },
    /// `suld.b.1d` surface load.
    SurfLoad {
        ty: PtxType,
        dst: Register,
        surf_ref: String,
        coord: Operand,
    },
    /// `sust.b.1d` surface store.
    SurfStore {
        ty: PtxType,
        surf_ref: String,
        coord: Operand,
        src: Register,
    },

    // ---------------------------------------------------------------------------
    // Textual pseudo-instructions
    // ---------------------------------------------------------------------------
    /// A `//` line comment.
    Comment(String),
    /// Text emitted verbatim.
    Raw(String),
    /// A `.pragma` directive with the given string.
    Pragma(String),
}
impl Instruction {
#[must_use]
#[allow(clippy::too_many_lines)]
pub fn emit(&self) -> String {
match self {
Self::Add { ty, dst, a, b } => {
format!("add{} {dst}, {a}, {b};", ty.as_ptx_str())
}
Self::Sub { ty, dst, a, b } => {
format!("sub{} {dst}, {a}, {b};", ty.as_ptx_str())
}
Self::Mul {
ty,
mode,
dst,
a,
b,
} => {
format!(
"mul{}{} {dst}, {a}, {b};",
mode.as_ptx_str(),
ty.as_ptx_str()
)
}
Self::Mad {
ty,
mode,
dst,
a,
b,
c,
} => {
format!(
"mad{}{} {dst}, {a}, {b}, {c};",
mode.as_ptx_str(),
ty.as_ptx_str()
)
}
Self::MadLo { typ, dst, a, b, c } => {
format!("mad.lo{} {dst}, {a}, {b}, {c};", typ.as_ptx_str())
}
Self::MadHi { typ, dst, a, b, c } => {
format!("mad.hi{} {dst}, {a}, {b}, {c};", typ.as_ptx_str())
}
Self::MadWide {
src_typ,
dst,
a,
b,
c,
} => {
format!("mad.wide{} {dst}, {a}, {b}, {c};", src_typ.as_ptx_str())
}
Self::Fma {
rnd,
ty,
dst,
a,
b,
c,
} => {
format!(
"fma{}{} {dst}, {a}, {b}, {c};",
rnd.as_ptx_str(),
ty.as_ptx_str()
)
}
Self::Neg { ty, dst, src } => {
format!("neg{} {dst}, {src};", ty.as_ptx_str())
}
Self::Abs { ty, dst, src } => {
format!("abs{} {dst}, {src};", ty.as_ptx_str())
}
Self::Min { ty, dst, a, b } => {
format!("min{} {dst}, {a}, {b};", ty.as_ptx_str())
}
Self::Max { ty, dst, a, b } => {
format!("max{} {dst}, {a}, {b};", ty.as_ptx_str())
}
Self::Brev { ty, dst, src } => {
format!("brev{} {dst}, {src};", ty.as_ptx_str())
}
Self::Clz { ty, dst, src } => {
format!("clz{} {dst}, {src};", ty.as_ptx_str())
}
Self::Popc { ty, dst, src } => {
format!("popc{} {dst}, {src};", ty.as_ptx_str())
}
Self::Bfind { ty, dst, src } => {
format!("bfind{} {dst}, {src};", ty.as_ptx_str())
}
Self::Bfe {
ty,
dst,
src,
start,
len,
} => {
format!("bfe{} {dst}, {src}, {start}, {len};", ty.as_ptx_str())
}
Self::Bfi {
ty,
dst,
insert,
base,
start,
len,
} => {
format!(
"bfi{} {dst}, {insert}, {base}, {start}, {len};",
ty.as_ptx_str()
)
}
Self::Rcp { rnd, ty, dst, src } => {
let rnd_str = rnd.map_or(String::new(), |r| r.as_ptx_str().to_string());
format!("rcp{rnd_str}{} {dst}, {src};", ty.as_ptx_str())
}
Self::Rsqrt {
approx,
ty,
dst,
src,
} => {
let approx_str = if *approx { ".approx" } else { "" };
format!("rsqrt{approx_str}{} {dst}, {src};", ty.as_ptx_str())
}
Self::Sqrt { rnd, ty, dst, src } => {
let rnd_str = rnd.map_or(String::new(), |r| r.as_ptx_str().to_string());
format!("sqrt{rnd_str}{} {dst}, {src};", ty.as_ptx_str())
}
Self::Ex2 {
approx,
ty,
dst,
src,
} => {
let approx_str = if *approx { ".approx" } else { "" };
format!("ex2{approx_str}{} {dst}, {src};", ty.as_ptx_str())
}
Self::Lg2 {
approx,
ty,
dst,
src,
} => {
let approx_str = if *approx { ".approx" } else { "" };
format!("lg2{approx_str}{} {dst}, {src};", ty.as_ptx_str())
}
Self::Sin {
approx,
ty,
dst,
src,
} => {
let approx_str = if *approx { ".approx" } else { "" };
format!("sin{approx_str}{} {dst}, {src};", ty.as_ptx_str())
}
Self::Cos {
approx,
ty,
dst,
src,
} => {
let approx_str = if *approx { ".approx" } else { "" };
format!("cos{approx_str}{} {dst}, {src};", ty.as_ptx_str())
}
Self::Shl {
ty,
dst,
src,
amount,
} => {
format!("shl{} {dst}, {src}, {amount};", ty.as_ptx_str())
}
Self::Shr {
ty,
dst,
src,
amount,
} => {
format!("shr{} {dst}, {src}, {amount};", ty.as_ptx_str())
}
Self::Div { ty, dst, a, b } => {
format!("div{} {dst}, {a}, {b};", ty.as_ptx_str())
}
Self::Rem { ty, dst, a, b } => {
format!("rem{} {dst}, {a}, {b};", ty.as_ptx_str())
}
Self::And { ty, dst, a, b } => {
format!("and{} {dst}, {a}, {b};", ty.as_ptx_str())
}
Self::Or { ty, dst, a, b } => {
format!("or{} {dst}, {a}, {b};", ty.as_ptx_str())
}
Self::Xor { ty, dst, a, b } => {
format!("xor{} {dst}, {a}, {b};", ty.as_ptx_str())
}
Self::SetP { cmp, ty, dst, a, b } => {
format!(
"setp{}{} {dst}, {a}, {b};",
cmp.as_ptx_str(),
ty.as_ptx_str()
)
}
Self::Load {
space,
qualifier,
vec,
ty,
dst,
addr,
} => {
format!(
"ld{}{}{}{} {dst}, {addr};",
space.as_ptx_str(),
qualifier.as_ptx_str(),
vec.as_ptx_str(),
ty.as_ptx_str()
)
}
Self::Store {
space,
qualifier,
vec,
ty,
addr,
src,
} => {
format!(
"st{}{}{}{} {addr}, {src};",
space.as_ptx_str(),
qualifier.as_ptx_str(),
vec.as_ptx_str(),
ty.as_ptx_str()
)
}
Self::CpAsync {
bytes,
dst_shared,
src_global,
} => {
format!("cp.async.ca.shared.global [{dst_shared}], [{src_global}], {bytes};")
}
Self::CpAsyncCommit => "cp.async.commit_group;".to_string(),
Self::CpAsyncWait { n } => {
format!("cp.async.wait_group {n};")
}
Self::Cvt {
rnd,
dst_ty,
src_ty,
dst,
src,
} => {
let rnd_str = rnd.map_or(String::new(), |r| r.as_ptx_str().to_string());
format!(
"cvt{rnd_str}{}{} {dst}, {src};",
dst_ty.as_ptx_str(),
src_ty.as_ptx_str()
)
}
Self::Branch { target, predicate } => match predicate {
Some((pred, negated)) => {
let neg = if *negated { "!" } else { "" };
format!("@{neg}{pred} bra ${target};")
}
None => format!("bra ${target};"),
},
Self::Label(name) => format!("${name}:"),
Self::Return => "ret;".to_string(),
Self::BarSync { id } => format!("bar.sync {id};"),
Self::BarArrive { id, count } => {
format!("bar.arrive {id}, {count};")
}
Self::FenceAcqRel { scope } => {
format!("fence.acq_rel{};", scope.as_ptx_str())
}
Self::Wmma {
op,
shape,
layout,
ty,
fragments,
addr,
stride,
} => emit_wmma(
*op,
*shape,
*layout,
*ty,
fragments,
addr.as_ref(),
stride.as_ref(),
),
Self::Mma {
shape,
a_ty,
b_ty,
c_ty,
d_ty,
d_regs,
a_regs,
b_regs,
c_regs,
} => emit_mma(
*shape, *a_ty, *b_ty, *c_ty, *d_ty, d_regs, a_regs, b_regs, c_regs,
),
Self::Wgmma {
shape,
d_ty,
a_ty,
b_ty,
desc_a,
desc_b,
d_regs,
scale_d,
imm_scale_a,
imm_scale_b,
trans_a,
trans_b,
} => {
let d_list = reg_list(d_regs);
format!(
"wgmma.mma_async.sync.aligned{}{}{}{} {{{d_list}}}, {desc_a}, {desc_b}, {scale_d}, {imm_scale_a}, {imm_scale_b}, {trans_a}, {trans_b};",
shape.as_ptx_str(),
d_ty.as_ptx_str(),
a_ty.as_ptx_str(),
b_ty.as_ptx_str(),
)
}
Self::TmaLoad {
dst_shared,
desc,
coords,
barrier,
} => {
let coord_list = coords
.iter()
.map(ToString::to_string)
.collect::<Vec<_>>()
.join(", ");
format!(
"cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes [{dst_shared}], [{desc}, {{{coord_list}}}], [{barrier}];",
)
}
Self::Atom {
space,
op,
ty,
dst,
addr,
src,
} => {
format!(
"atom{}{}{} {dst}, [{addr}], {src};",
space.as_ptx_str(),
op.as_ptx_str(),
ty.as_ptx_str()
)
}
Self::AtomCas {
space,
ty,
dst,
addr,
compare,
value,
} => {
format!(
"atom{}.cas{} {dst}, [{addr}], {compare}, {value};",
space.as_ptx_str(),
ty.as_ptx_str()
)
}
Self::Red {
space,
op,
ty,
addr,
src,
} => {
format!(
"red{}{}{} [{addr}], {src};",
space.as_ptx_str(),
op.as_ptx_str(),
ty.as_ptx_str()
)
}
Self::MovSpecial { dst, special } => {
format!("mov.u32 {dst}, {};", special.as_ptx_str())
}
Self::LoadParam {
ty,
dst,
param_name,
} => {
format!("ld.param{} {dst}, [{param_name}];", ty.as_ptx_str())
}
Self::Comment(text) => format!("// {text}"),
Self::Raw(text) => text.clone(),
Self::Pragma(text) => format!(".pragma \"{text}\";"),
Self::Dp4a {
dst,
a,
b,
c,
signed_a,
signed_b,
} => {
let a_ty = if *signed_a { ".s32" } else { ".u32" };
let b_ty = if *signed_b { ".s32" } else { ".u32" };
format!("dp4a{a_ty}{b_ty} {dst}, {a}, {b}, {c};")
}
Self::Dp2a {
dst,
a,
b,
c,
signed_a,
signed_b,
lo,
} => {
let a_ty = if *signed_a { ".s32" } else { ".u32" };
let b_ty = if *signed_b { ".s32" } else { ".u32" };
let half = if *lo { ".lo" } else { ".hi" };
format!("dp2a{half}{a_ty}{b_ty} {dst}, {a}, {b}, {c};")
}
Self::Redux {
op,
dst,
src,
membership_mask,
} => {
format!(
"redux.sync{}.u32 {dst}, {src}, 0x{membership_mask:08x};",
op.as_ptx_str()
)
}
Self::Stmatrix {
dst_addr,
src,
shape,
trans,
} => {
let trans_str = if *trans { ".trans" } else { "" };
format!(
"stmatrix.sync.aligned{}{trans_str}.shared.b16 [{dst_addr}], {{{src}}};",
shape.as_ptx_str()
)
}
Self::ElectSync {
dst,
membership_mask,
} => {
format!("elect.sync {dst}, 0x{membership_mask:08x};")
}
Self::Setmaxnreg { reg_count, action } => {
format!("setmaxnreg{} {reg_count};", action.as_ptx_str())
}
Self::Griddepcontrol { action } => {
format!("griddepcontrol{};", action.as_ptx_str())
}
Self::FenceProxy { scope, space } => {
format!(
"fence.proxy.async{}{};",
scope.as_ptx_str(),
space.as_ptx_str()
)
}
Self::MbarrierInit { addr, count } => {
format!("mbarrier.init.shared.b64 [{addr}], {count};")
}
Self::MbarrierArrive { addr } => {
format!("mbarrier.arrive.shared.b64 [{addr}];")
}
Self::MbarrierWait { addr, phase } => {
format!("mbarrier.try_wait.parity.shared.b64 [{addr}], {phase};")
}
Self::Tcgen05Mma { a_desc, b_desc } => {
format!("tcgen05.mma.cta_group::1.kind::f32 [{a_desc}], [{b_desc}];")
}
Self::BarrierCluster => "barrier.cluster.arrive;".to_string(),
Self::FenceCluster => "fence.mbarrier_init.release.cluster;".to_string(),
Self::CpAsyncBulk {
dst_smem,
src_gmem,
desc,
} => {
format!(
"cp.async.bulk.tensor.1d.shared::cluster.global.tile.bulk_group [{dst_smem}], [{src_gmem}, {{{desc}}}];"
)
}
Self::Tex1d {
ty,
dst,
tex_ref,
coord,
} => {
format!(
"tex.1d.v4{}.s32 {dst}, [{tex_ref}, {{{coord}}}];",
ty.as_ptx_str()
)
}
Self::Tex2d {
ty,
dst,
tex_ref,
coord_x,
coord_y,
} => {
format!(
"tex.2d.v4{}.s32 {dst}, [{tex_ref}, {{{coord_x}, {coord_y}}}];",
ty.as_ptx_str()
)
}
Self::Tex3d {
ty,
dst,
tex_ref,
coord_x,
coord_y,
coord_z,
} => {
format!(
"tex.3d.v4{}.s32 {dst}, [{tex_ref}, {{{coord_x}, {coord_y}, {coord_z}}}];",
ty.as_ptx_str()
)
}
Self::SurfLoad {
ty,
dst,
surf_ref,
coord,
} => {
format!(
"suld.b.1d{} {dst}, [{surf_ref}, {{{coord}}}];",
ty.as_ptx_str()
)
}
Self::SurfStore {
ty,
surf_ref,
coord,
src,
} => {
format!(
"sust.b.1d{} [{surf_ref}, {{{coord}}}], {src};",
ty.as_ptx_str()
)
}
Self::Ldmatrix {
num_fragments,
trans,
dst_regs,
src_addr,
} => {
let trans_str = if *trans { ".trans" } else { "" };
let x_str = match num_fragments {
2 => ".x2",
4 => ".x4",
_ => ".x1",
};
let dst_list = dst_regs
.iter()
.map(ToString::to_string)
.collect::<Vec<_>>()
.join(", ");
format!(
"ldmatrix.sync.aligned.m8n8{x_str}{trans_str}.shared.b16 {{{dst_list}}}, [{src_addr}];"
)
}
}
}
}
/// Renders one `wmma.*` warp-level matrix instruction as a line of PTX.
///
/// Every form carries the same `<shape><layout><type>` qualifier suffix. The load and
/// store forms also take a memory operand: a missing `addr` renders as `[]`, and a
/// missing `stride` is omitted entirely. The `mma` form ignores `addr` and `stride`
/// and emits all `fragments` in a single brace-enclosed list.
fn emit_wmma(
    op: WmmaOp,
    shape: WmmaShape,
    layout: WmmaLayout,
    ty: PtxType,
    fragments: &[Register],
    addr: Option<&Operand>,
    stride: Option<&Operand>,
) -> String {
    let suffix = format!(
        "{}{}{}",
        shape.as_ptx_str(),
        layout.as_ptx_str(),
        ty.as_ptx_str()
    );
    let frag_list = reg_list(fragments);
    let addr_str = addr.map(ToString::to_string).unwrap_or_default();
    let stride_str = stride.map(|s| format!(", {s}")).unwrap_or_default();
    match op {
        WmmaOp::LoadA => format!(
            "wmma.load.a.sync.aligned{suffix} {{{frag_list}}}, [{addr_str}]{stride_str};"
        ),
        WmmaOp::LoadB => format!(
            "wmma.load.b.sync.aligned{suffix} {{{frag_list}}}, [{addr_str}]{stride_str};"
        ),
        WmmaOp::StoreD => format!(
            "wmma.store.d.sync.aligned{suffix} [{addr_str}], {{{frag_list}}}{stride_str};"
        ),
        WmmaOp::Mma => format!("wmma.mma.sync.aligned{suffix} {{{frag_list}}};"),
    }
}
#[allow(clippy::too_many_arguments)]
fn emit_mma(
shape: MmaShape,
a_ty: PtxType,
b_ty: PtxType,
c_ty: PtxType,
d_ty: PtxType,
d_regs: &[Register],
a_regs: &[Register],
b_regs: &[Register],
c_regs: &[Register],
) -> String {
let d_list = reg_list(d_regs);
let a_list = reg_list(a_regs);
let b_list = reg_list(b_regs);
let c_list = reg_list(c_regs);
format!(
"mma.sync.aligned{}.row.col{}{}{}{} {{{d_list}}}, {{{a_list}}}, {{{b_list}}}, {{{c_list}}};",
shape.as_ptx_str(),
d_ty.as_ptx_str(),
a_ty.as_ptx_str(),
b_ty.as_ptx_str(),
c_ty.as_ptx_str()
)
}
/// Joins the `Display` form of each item with `", "`.
///
/// Used to build the brace-enclosed register vectors of matrix and tensor
/// instructions. An empty slice yields an empty string. Accepts any `Display`
/// element type, so existing `&[Register]` callers are unaffected.
fn reg_list<T: std::fmt::Display>(regs: &[T]) -> String {
    use std::fmt::Write as _;

    let mut out = String::new();
    for (idx, reg) in regs.iter().enumerate() {
        if idx > 0 {
            out.push_str(", ");
        }
        // Formatting into a `String` is infallible, so the `fmt::Result` can be ignored.
        let _ = write!(out, "{reg}");
    }
    out
}
#[cfg(test)]
#[path = "instruction_tests.rs"]
mod tests;