use std::fmt;
use crate::emit::{Emit, PtxWriter};
use crate::ir::{Operand, Register};
use crate::types::PtxType;
/// Relational operator for `setp`-style predicate instructions.
///
/// Only the six basic relations are modelled. [`CmpOp::ptx_str`] yields the
/// lowercase token that is spliced into an instruction mnemonic, e.g. the
/// `lt` in `setp.lt.u32`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CmpOp {
    /// Equal (`eq`).
    Eq,
    /// Not equal (`ne`).
    Ne,
    /// Less than (`lt`).
    Lt,
    /// Less than or equal (`le`).
    Le,
    /// Greater than (`gt`).
    Gt,
    /// Greater than or equal (`ge`).
    Ge,
}

impl CmpOp {
    /// Returns the PTX comparison token for this operator (`"eq"`, `"ne"`,
    /// `"lt"`, `"le"`, `"gt"` or `"ge"`), without any leading dot.
    pub fn ptx_str(&self) -> &'static str {
        match *self {
            CmpOp::Eq => "eq",
            CmpOp::Ne => "ne",
            CmpOp::Lt => "lt",
            CmpOp::Le => "le",
            CmpOp::Gt => "gt",
            CmpOp::Ge => "ge",
        }
    }
}
/// A control-flow, predicate-setting, barrier, or warp-shuffle PTX instruction.
///
/// Each variant is rendered to PTX text by this type's `Emit` implementation.
#[derive(Debug, Clone)]
pub enum ControlOp {
    /// Comparison that writes a predicate:
    /// `setp.<cmp_op><ty-suffix> dst, lhs, rhs;` (e.g. `setp.ge.u32`).
    SetP {
        /// Register receiving the comparison result.
        dst: Register,
        /// Relation to test.
        cmp_op: CmpOp,
        /// Left-hand comparison operand.
        lhs: Operand,
        /// Right-hand comparison operand.
        rhs: Operand,
        /// Operand type; its `ptx_suffix()` (e.g. `.u32`) ends the mnemonic.
        ty: PtxType,
    },
    /// Comparison combined with an existing predicate:
    /// `setp.<cmp_op>.and<ty-suffix> dst, lhs, rhs, src_pred;`.
    /// In PTX the `.and` form ANDs the comparison result with `src_pred`.
    SetPAnd {
        /// Register receiving the combined result.
        dst: Register,
        /// Relation to test.
        cmp_op: CmpOp,
        /// Left-hand comparison operand.
        lhs: Operand,
        /// Right-hand comparison operand.
        rhs: Operand,
        /// Operand type; its `ptx_suffix()` ends the mnemonic.
        ty: PtxType,
        /// Predicate combined with the comparison, emitted as the last operand.
        src_pred: Register,
    },
    /// Guarded branch: `@pred bra target;`, or `@!pred bra target;` when
    /// `negate` is set.
    BraPred {
        /// Guard predicate register.
        pred: Register,
        /// Label to branch to, written verbatim.
        target: String,
        /// When true, the guard is emitted with `!` (branch on false).
        negate: bool,
    },
    /// Unconditional branch: `bra target;`.
    Bra {
        /// Label to branch to, written verbatim.
        target: String,
    },
    /// Function return: `ret;`.
    Ret,
    /// Block-level barrier: `bar.sync barrier_id;`.
    BarSync {
        /// Barrier number, emitted in decimal.
        /// NOTE(review): PTX defines barriers 0..=15; this value is not
        /// range-checked before emission.
        barrier_id: u32,
    },
    /// Warp shuffle down: `shfl.sync.down.b32 dst, src, delta, c, 0xMASK;`.
    ShflSyncDown {
        /// Register receiving the shuffled value.
        dst: Register,
        /// Register whose value is shuffled.
        src: Register,
        /// Lane offset (immediate or register).
        delta: Operand,
        /// Raw `c` operand (clamp / segment-mask word in PTX), emitted in decimal.
        c: u32,
        /// Member mask, emitted as `0x` plus eight uppercase hex digits.
        mask: u32,
    },
    /// Warp shuffle up: `shfl.sync.up.b32 dst, src, delta, c, 0xMASK;`.
    ShflSyncUp {
        /// Register receiving the shuffled value.
        dst: Register,
        /// Register whose value is shuffled.
        src: Register,
        /// Lane offset (immediate or register).
        delta: Operand,
        /// Raw `c` operand (clamp / segment-mask word in PTX), emitted in decimal.
        c: u32,
        /// Member mask, emitted as `0x` plus eight uppercase hex digits.
        mask: u32,
    },
    /// Butterfly (XOR) warp shuffle:
    /// `shfl.sync.bfly.b32 dst, src, lane_mask, c, 0xMASK;`.
    ShflSyncBfly {
        /// Register receiving the shuffled value.
        dst: Register,
        /// Register whose value is shuffled.
        src: Register,
        /// Lane mask XORed with the lane index in PTX (immediate or register).
        lane_mask: Operand,
        /// Raw `c` operand (clamp / segment-mask word in PTX), emitted in decimal.
        c: u32,
        /// Member mask, emitted as `0x` plus eight uppercase hex digits.
        mask: u32,
    },
}
impl Emit for ControlOp {
fn emit(&self, w: &mut PtxWriter) -> fmt::Result {
match self {
ControlOp::SetP {
dst,
cmp_op,
lhs,
rhs,
ty,
} => {
let mnemonic = format!("setp.{}{}", cmp_op.ptx_str(), ty.ptx_suffix());
w.instruction(&mnemonic, &[dst as &dyn fmt::Display, lhs, rhs])
}
ControlOp::SetPAnd {
dst,
cmp_op,
lhs,
rhs,
ty,
src_pred,
} => {
let mnemonic = format!("setp.{}.and{}", cmp_op.ptx_str(), ty.ptx_suffix());
w.instruction(&mnemonic, &[dst as &dyn fmt::Display, lhs, rhs, src_pred])
}
ControlOp::BraPred {
pred,
target,
negate,
} => {
let neg = if *negate { "!" } else { "" };
w.line(&format!("@{neg}{pred} bra {target};"))
}
ControlOp::Bra { target } => w.instruction("bra", &[&target as &dyn fmt::Display]),
ControlOp::Ret => w.instruction("ret", &[]),
ControlOp::BarSync { barrier_id } => w.line(&format!("bar.sync {barrier_id};")),
ControlOp::ShflSyncDown {
dst,
src,
delta,
c,
mask,
} => w.line(&format!(
"shfl.sync.down.b32 {dst}, {src}, {delta}, {c}, 0x{mask:08X};"
)),
ControlOp::ShflSyncUp {
dst,
src,
delta,
c,
mask,
} => w.line(&format!(
"shfl.sync.up.b32 {dst}, {src}, {delta}, {c}, 0x{mask:08X};"
)),
ControlOp::ShflSyncBfly {
dst,
src,
lane_mask,
c,
mask,
} => w.line(&format!(
"shfl.sync.bfly.b32 {dst}, {src}, {lane_mask}, {c}, 0x{mask:08X};"
)),
}
}
}
#[cfg(test)]
mod tests {
    //! Exact-text tests for `ControlOp` emission. Each test indents the
    //! writer once, emits a single op and compares the finished output.

    use super::*;
    use crate::types::RegKind;

    /// Builds a register with the given kind, index and PTX type.
    fn reg(kind: RegKind, index: u32, ptx_type: PtxType) -> Register {
        Register {
            kind,
            index,
            ptx_type,
        }
    }

    #[test]
    fn emit_setp_and_lt_u32() {
        let mut w = PtxWriter::new();
        w.indent();
        let op = ControlOp::SetPAnd {
            dst: reg(RegKind::P, 3, PtxType::Pred),
            cmp_op: CmpOp::Lt,
            lhs: Operand::Reg(reg(RegKind::R, 5, PtxType::U32)),
            rhs: Operand::Reg(reg(RegKind::R, 10, PtxType::U32)),
            ty: PtxType::U32,
            src_pred: reg(RegKind::P, 2, PtxType::Pred),
        };
        op.emit(&mut w).unwrap();
        assert_eq!(w.finish(), " setp.lt.and.u32 %p3, %r5, %r10, %p2;\n");
    }

    #[test]
    fn emit_setp_ge_u32() {
        let mut w = PtxWriter::new();
        w.indent();
        let op = ControlOp::SetP {
            dst: reg(RegKind::P, 1, PtxType::Pred),
            cmp_op: CmpOp::Ge,
            lhs: Operand::Reg(reg(RegKind::R, 1, PtxType::U32)),
            rhs: Operand::Reg(reg(RegKind::R, 2, PtxType::U32)),
            ty: PtxType::U32,
        };
        op.emit(&mut w).unwrap();
        assert_eq!(w.finish(), " setp.ge.u32 %p1, %r1, %r2;\n");
    }

    /// Covers a relation other than the two used above, so a mix-up in the
    /// `CmpOp` token spliced into the mnemonic would show up here.
    #[test]
    fn emit_setp_ne_u32() {
        let mut w = PtxWriter::new();
        w.indent();
        let op = ControlOp::SetP {
            dst: reg(RegKind::P, 1, PtxType::Pred),
            cmp_op: CmpOp::Ne,
            lhs: Operand::Reg(reg(RegKind::R, 1, PtxType::U32)),
            rhs: Operand::Reg(reg(RegKind::R, 2, PtxType::U32)),
            ty: PtxType::U32,
        };
        op.emit(&mut w).unwrap();
        assert_eq!(w.finish(), " setp.ne.u32 %p1, %r1, %r2;\n");
    }

    #[test]
    fn emit_bra_pred() {
        let mut w = PtxWriter::new();
        w.indent();
        let op = ControlOp::BraPred {
            pred: reg(RegKind::P, 1, PtxType::Pred),
            target: "$L__BB0_2".to_string(),
            negate: false,
        };
        op.emit(&mut w).unwrap();
        assert_eq!(w.finish(), " @%p1 bra $L__BB0_2;\n");
    }

    #[test]
    fn emit_bra_pred_negated() {
        let mut w = PtxWriter::new();
        w.indent();
        let op = ControlOp::BraPred {
            pred: reg(RegKind::P, 1, PtxType::Pred),
            target: "IF_END_0".to_string(),
            negate: true,
        };
        op.emit(&mut w).unwrap();
        assert_eq!(w.finish(), " @!%p1 bra IF_END_0;\n");
    }

    #[test]
    fn emit_ret() {
        let mut w = PtxWriter::new();
        w.indent();
        ControlOp::Ret.emit(&mut w).unwrap();
        assert_eq!(w.finish(), " ret;\n");
    }

    #[test]
    fn emit_bra_unconditional() {
        let mut w = PtxWriter::new();
        w.indent();
        let op = ControlOp::Bra {
            target: "LOOP".to_string(),
        };
        op.emit(&mut w).unwrap();
        assert_eq!(w.finish(), " bra LOOP;\n");
    }

    #[test]
    fn control_via_ptx_instruction() {
        use crate::ir::PtxInstruction;
        let mut w = PtxWriter::new();
        w.indent();
        let instr = PtxInstruction::Control(ControlOp::Ret);
        instr.emit(&mut w).unwrap();
        assert_eq!(w.finish(), " ret;\n");
    }

    #[test]
    fn emit_bar_sync() {
        let mut w = PtxWriter::new();
        w.indent();
        ControlOp::BarSync { barrier_id: 0 }.emit(&mut w).unwrap();
        assert_eq!(w.finish(), " bar.sync 0;\n");
    }

    /// The id must be emitted as given, not defaulted to 0.
    #[test]
    fn emit_bar_sync_nonzero_barrier() {
        let mut w = PtxWriter::new();
        w.indent();
        ControlOp::BarSync { barrier_id: 15 }.emit(&mut w).unwrap();
        assert_eq!(w.finish(), " bar.sync 15;\n");
    }

    #[test]
    fn emit_shfl_sync_down() {
        let mut w = PtxWriter::new();
        w.indent();
        let op = ControlOp::ShflSyncDown {
            dst: reg(RegKind::R, 2, PtxType::U32),
            src: reg(RegKind::R, 1, PtxType::U32),
            delta: Operand::ImmU32(1),
            c: 31,
            mask: 0xFFFFFFFF,
        };
        op.emit(&mut w).unwrap();
        assert_eq!(
            w.finish(),
            " shfl.sync.down.b32 %r2, %r1, 1, 31, 0xFFFFFFFF;\n"
        );
    }

    #[test]
    fn emit_shfl_sync_up() {
        let mut w = PtxWriter::new();
        w.indent();
        let op = ControlOp::ShflSyncUp {
            dst: reg(RegKind::R, 2, PtxType::U32),
            src: reg(RegKind::R, 1, PtxType::U32),
            delta: Operand::ImmU32(1),
            c: 0,
            mask: 0xFFFFFFFF,
        };
        op.emit(&mut w).unwrap();
        assert_eq!(
            w.finish(),
            " shfl.sync.up.b32 %r2, %r1, 1, 0, 0xFFFFFFFF;\n"
        );
    }

    #[test]
    fn emit_shfl_sync_bfly() {
        let mut w = PtxWriter::new();
        w.indent();
        let op = ControlOp::ShflSyncBfly {
            dst: reg(RegKind::R, 2, PtxType::U32),
            src: reg(RegKind::R, 1, PtxType::U32),
            lane_mask: Operand::ImmU32(1),
            c: 31,
            mask: 0xFFFFFFFF,
        };
        op.emit(&mut w).unwrap();
        assert_eq!(
            w.finish(),
            " shfl.sync.bfly.b32 %r2, %r1, 1, 31, 0xFFFFFFFF;\n"
        );
    }

    /// The all-ones masks used elsewhere cannot detect a missing `08` width
    /// in the mask format; a mask with leading zero nibbles must still print
    /// as eight hex digits.
    #[test]
    fn shfl_sync_mask_is_zero_padded() {
        let mut w = PtxWriter::new();
        w.indent();
        let op = ControlOp::ShflSyncBfly {
            dst: reg(RegKind::R, 2, PtxType::U32),
            src: reg(RegKind::R, 1, PtxType::U32),
            lane_mask: Operand::ImmU32(1),
            c: 31,
            mask: 0x0000_FFFF,
        };
        op.emit(&mut w).unwrap();
        assert_eq!(
            w.finish(),
            " shfl.sync.bfly.b32 %r2, %r1, 1, 31, 0x0000FFFF;\n"
        );
    }

    #[test]
    fn shfl_sync_down_with_register_delta() {
        let mut w = PtxWriter::new();
        w.indent();
        let op = ControlOp::ShflSyncDown {
            dst: reg(RegKind::R, 3, PtxType::U32),
            src: reg(RegKind::R, 0, PtxType::U32),
            delta: Operand::Reg(reg(RegKind::R, 4, PtxType::U32)),
            c: 31,
            mask: 0xFFFFFFFF,
        };
        op.emit(&mut w).unwrap();
        assert_eq!(
            w.finish(),
            " shfl.sync.down.b32 %r3, %r0, %r4, 31, 0xFFFFFFFF;\n"
        );
    }

    #[test]
    fn cmp_op_all_variants() {
        assert_eq!(CmpOp::Eq.ptx_str(), "eq");
        assert_eq!(CmpOp::Ne.ptx_str(), "ne");
        assert_eq!(CmpOp::Lt.ptx_str(), "lt");
        assert_eq!(CmpOp::Le.ptx_str(), "le");
        assert_eq!(CmpOp::Gt.ptx_str(), "gt");
        assert_eq!(CmpOp::Ge.ptx_str(), "ge");
    }
}