use super::super::instructions::{CmpOp, Operand, PtxInstruction, PtxOp};
use super::super::registers::VirtualReg;
use super::super::types::PtxType;
use super::core::KernelBuilderCore;
pub trait PtxComparison: KernelBuilderCore {
fn setp_ge_u32(&mut self, a: VirtualReg, b: VirtualReg) -> VirtualReg {
let pred = self.registers_mut().allocate_virtual(PtxType::Pred);
let mut instr = PtxInstruction::new(PtxOp::Setp, PtxType::U32)
.dst(Operand::Reg(pred))
.src(Operand::Reg(a))
.src(Operand::Reg(b));
instr.label = Some(CmpOp::Ge.to_ptx_string().to_string());
self.instructions_mut().push(instr);
pred
}
fn setp_eq_u32(&mut self, a: VirtualReg, b: VirtualReg) -> VirtualReg {
let pred = self.registers_mut().allocate_virtual(PtxType::Pred);
let mut instr = PtxInstruction::new(PtxOp::Setp, PtxType::U32)
.dst(Operand::Reg(pred))
.src(Operand::Reg(a))
.src(Operand::Reg(b));
instr.label = Some(CmpOp::Eq.to_ptx_string().to_string());
self.instructions_mut().push(instr);
pred
}
fn setp_ne_u32(&mut self, a: VirtualReg, b: VirtualReg) -> VirtualReg {
let pred = self.registers_mut().allocate_virtual(PtxType::Pred);
let mut instr = PtxInstruction::new(PtxOp::Setp, PtxType::U32)
.dst(Operand::Reg(pred))
.src(Operand::Reg(a))
.src(Operand::Reg(b));
instr.label = Some(CmpOp::Ne.to_ptx_string().to_string());
self.instructions_mut().push(instr);
pred
}
fn setp_lt_u32(&mut self, a: VirtualReg, b: VirtualReg) -> VirtualReg {
let pred = self.registers_mut().allocate_virtual(PtxType::Pred);
let mut instr = PtxInstruction::new(PtxOp::Setp, PtxType::U32)
.dst(Operand::Reg(pred))
.src(Operand::Reg(a))
.src(Operand::Reg(b));
instr.label = Some(CmpOp::Lt.to_ptx_string().to_string());
self.instructions_mut().push(instr);
pred
}
fn setp_gt_u32(&mut self, a: VirtualReg, b: VirtualReg) -> VirtualReg {
let pred = self.registers_mut().allocate_virtual(PtxType::Pred);
let mut instr = PtxInstruction::new(PtxOp::Setp, PtxType::U32)
.dst(Operand::Reg(pred))
.src(Operand::Reg(a))
.src(Operand::Reg(b));
instr.label = Some(CmpOp::Gt.to_ptx_string().to_string());
self.instructions_mut().push(instr);
pred
}
fn setp_le_u32(&mut self, a: VirtualReg, b: VirtualReg) -> VirtualReg {
let pred = self.registers_mut().allocate_virtual(PtxType::Pred);
let mut instr = PtxInstruction::new(PtxOp::Setp, PtxType::U32)
.dst(Operand::Reg(pred))
.src(Operand::Reg(a))
.src(Operand::Reg(b));
instr.label = Some(CmpOp::Le.to_ptx_string().to_string());
self.instructions_mut().push(instr);
pred
}
fn setp_lt_f32(&mut self, a: VirtualReg, b: VirtualReg) -> VirtualReg {
let pred = self.registers_mut().allocate_virtual(PtxType::Pred);
let mut instr = PtxInstruction::new(PtxOp::Setp, PtxType::F32)
.dst(Operand::Reg(pred))
.src(Operand::Reg(a))
.src(Operand::Reg(b));
instr.label = Some(CmpOp::Lt.to_ptx_string().to_string());
self.instructions_mut().push(instr);
pred
}
fn setp_gt_f32(&mut self, a: VirtualReg, b: VirtualReg) -> VirtualReg {
let pred = self.registers_mut().allocate_virtual(PtxType::Pred);
let mut instr = PtxInstruction::new(PtxOp::Setp, PtxType::F32)
.dst(Operand::Reg(pred))
.src(Operand::Reg(a))
.src(Operand::Reg(b));
instr.label = Some(CmpOp::Gt.to_ptx_string().to_string());
self.instructions_mut().push(instr);
pred
}
fn setp_ge_u32_imm(&mut self, a: VirtualReg, b: u32) -> VirtualReg {
let pred = self.registers_mut().allocate_virtual(PtxType::Pred);
let mut instr = PtxInstruction::new(PtxOp::Setp, PtxType::U32)
.dst(Operand::Reg(pred))
.src(Operand::Reg(a))
.src(Operand::ImmI64(b as i64));
instr.label = Some(CmpOp::Ge.to_ptx_string().to_string());
self.instructions_mut().push(instr);
pred
}
fn setp_lt_u32_imm(&mut self, a: VirtualReg, b: u32) -> VirtualReg {
let pred = self.registers_mut().allocate_virtual(PtxType::Pred);
let mut instr = PtxInstruction::new(PtxOp::Setp, PtxType::U32)
.dst(Operand::Reg(pred))
.src(Operand::Reg(a))
.src(Operand::ImmI64(b as i64));
instr.label = Some(CmpOp::Lt.to_ptx_string().to_string());
self.instructions_mut().push(instr);
pred
}
}
/// Blanket implementation: any type implementing [`KernelBuilderCore`]
/// automatically gains the [`PtxComparison`] helpers, all of which are
/// provided trait methods.
impl<B> PtxComparison for B
where
    B: KernelBuilderCore,
{
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ptx::registers::RegisterAllocator;

    /// Minimal `KernelBuilderCore` that only records the emitted state.
    struct MockBuilder {
        registers: RegisterAllocator,
        instructions: Vec<PtxInstruction>,
        labels: Vec<String>,
    }

    impl MockBuilder {
        fn new() -> Self {
            Self {
                registers: RegisterAllocator::new(),
                instructions: Vec::new(),
                labels: Vec::new(),
            }
        }
    }

    impl KernelBuilderCore for MockBuilder {
        fn registers_mut(&mut self) -> &mut RegisterAllocator {
            &mut self.registers
        }
        fn instructions_mut(&mut self) -> &mut Vec<PtxInstruction> {
            &mut self.instructions
        }
        fn labels_mut(&mut self) -> &mut Vec<String> {
            &mut self.labels
        }
    }

    /// Asserts that `operand` is a register operand referring to `expected`.
    fn assert_reg(operand: &Operand, expected: VirtualReg) {
        match operand {
            Operand::Reg(reg) => assert_eq!(reg.id(), expected.id()),
            _ => panic!("Expected Reg operand"),
        }
    }

    /// Asserts that `operand` is an `ImmI64` operand holding `expected`.
    fn assert_imm(operand: &Operand, expected: i64) {
        match operand {
            Operand::ImmI64(val) => assert_eq!(*val, expected),
            _ => panic!("Expected ImmI64 operand"),
        }
    }

    #[test]
    fn test_setp_ge_u32() {
        let mut builder = MockBuilder::new();
        let a = builder.registers.allocate_virtual(PtxType::U32);
        let b = builder.registers.allocate_virtual(PtxType::U32);
        let pred = builder.setp_ge_u32(a, b);
        assert_eq!(builder.instructions.len(), 1);
        let instr = &builder.instructions[0];
        assert_eq!(instr.op, PtxOp::Setp);
        assert_eq!(instr.ty, PtxType::U32);
        assert_eq!(instr.label.as_deref(), Some("ge"));
        assert_eq!(instr.srcs.len(), 2);
        // Operand order decides the meaning of the comparison (a >= b, not b >= a).
        assert_reg(&instr.srcs[0], a);
        assert_reg(&instr.srcs[1], b);
        assert_eq!(pred.ty(), PtxType::Pred);
    }

    #[test]
    fn test_all_comparisons() {
        let mut builder = MockBuilder::new();
        let a = builder.registers.allocate_virtual(PtxType::U32);
        let b = builder.registers.allocate_virtual(PtxType::U32);
        let _ge = builder.setp_ge_u32(a, b);
        let _eq = builder.setp_eq_u32(a, b);
        let _ne = builder.setp_ne_u32(a, b);
        let _lt = builder.setp_lt_u32(a, b);
        let _gt = builder.setp_gt_u32(a, b);
        let _le = builder.setp_le_u32(a, b);
        assert_eq!(builder.instructions.len(), 6);
        let labels: Vec<_> = builder
            .instructions
            .iter()
            .map(|i| i.label.as_deref().unwrap())
            .collect();
        assert_eq!(labels, vec!["ge", "eq", "ne", "lt", "gt", "le"]);
        for instr in &builder.instructions {
            assert_eq!(instr.op, PtxOp::Setp);
            assert_eq!(instr.ty, PtxType::U32);
            assert_eq!(instr.srcs.len(), 2);
            assert_reg(&instr.srcs[0], a);
            assert_reg(&instr.srcs[1], b);
        }
    }

    #[test]
    fn test_f32_comparisons() {
        let mut builder = MockBuilder::new();
        let a = builder.registers.allocate_virtual(PtxType::F32);
        let b = builder.registers.allocate_virtual(PtxType::F32);
        let _lt = builder.setp_lt_f32(a, b);
        let _gt = builder.setp_gt_f32(a, b);
        assert_eq!(builder.instructions.len(), 2);
        assert_eq!(builder.instructions[0].ty, PtxType::F32);
        assert_eq!(builder.instructions[1].ty, PtxType::F32);
        assert_eq!(builder.instructions[0].label.as_deref(), Some("lt"));
        assert_eq!(builder.instructions[1].label.as_deref(), Some("gt"));
        for instr in &builder.instructions {
            assert_reg(&instr.srcs[0], a);
            assert_reg(&instr.srcs[1], b);
        }
    }

    #[test]
    fn test_setp_ge_u32_imm() {
        let mut builder = MockBuilder::new();
        let a = builder.registers.allocate_virtual(PtxType::U32);
        let pred = builder.setp_ge_u32_imm(a, 42);
        assert_eq!(builder.instructions.len(), 1);
        let instr = &builder.instructions[0];
        assert_eq!(instr.op, PtxOp::Setp);
        assert_eq!(instr.ty, PtxType::U32);
        assert_eq!(instr.label.as_deref(), Some("ge"));
        assert_eq!(instr.srcs.len(), 2);
        assert_reg(&instr.srcs[0], a);
        assert_imm(&instr.srcs[1], 42);
        assert_eq!(pred.ty(), PtxType::Pred);
    }

    #[test]
    fn test_setp_lt_u32_imm() {
        let mut builder = MockBuilder::new();
        let a = builder.registers.allocate_virtual(PtxType::U32);
        let pred = builder.setp_lt_u32_imm(a, 100);
        assert_eq!(builder.instructions.len(), 1);
        let instr = &builder.instructions[0];
        assert_eq!(instr.op, PtxOp::Setp);
        assert_eq!(instr.ty, PtxType::U32);
        assert_eq!(instr.label.as_deref(), Some("lt"));
        assert_eq!(instr.srcs.len(), 2);
        assert_reg(&instr.srcs[0], a);
        assert_imm(&instr.srcs[1], 100);
        assert_eq!(pred.ty(), PtxType::Pred);
    }

    #[test]
    fn test_imm_comparisons_boundary_values() {
        let mut builder = MockBuilder::new();
        let a = builder.registers.allocate_virtual(PtxType::U32);
        let _pred_zero = builder.setp_ge_u32_imm(a, 0);
        let _pred_max = builder.setp_lt_u32_imm(a, u32::MAX);
        assert_eq!(builder.instructions.len(), 2);
        assert_imm(&builder.instructions[0].srcs[1], 0);
        // u32::MAX must survive widening to i64 without sign wrap-around.
        assert_imm(&builder.instructions[1].srcs[1], i64::from(u32::MAX));
    }

    #[test]
    fn test_all_comparisons_generate_unique_predicates() {
        let mut builder = MockBuilder::new();
        let a = builder.registers.allocate_virtual(PtxType::U32);
        let b = builder.registers.allocate_virtual(PtxType::U32);
        let f32_a = builder.registers.allocate_virtual(PtxType::F32);
        let f32_b = builder.registers.allocate_virtual(PtxType::F32);
        let preds = [
            builder.setp_ge_u32(a, b),
            builder.setp_eq_u32(a, b),
            builder.setp_ne_u32(a, b),
            builder.setp_lt_u32(a, b),
            builder.setp_gt_u32(a, b),
            builder.setp_le_u32(a, b),
            builder.setp_lt_f32(f32_a, f32_b),
            builder.setp_gt_f32(f32_a, f32_b),
            builder.setp_ge_u32_imm(a, 10),
            builder.setp_lt_u32_imm(a, 20),
        ];
        for (i, first) in preds.iter().enumerate() {
            for (j, second) in preds.iter().enumerate().skip(i + 1) {
                assert_ne!(
                    first.id(),
                    second.id(),
                    "Predicate {} and {} should have different IDs",
                    i,
                    j
                );
            }
            assert_eq!(
                first.ty(),
                PtxType::Pred,
                "Predicate {} should be of type Pred",
                i
            );
        }
    }
}