use super::super::instructions::{Operand, Predicate, PtxInstruction, PtxOp};
use super::super::registers::VirtualReg;
use super::super::types::PtxType;
use super::core::KernelBuilderCore;
/// Control-flow (labels, branches, return) and `mov` helpers for PTX kernel
/// builders.
///
/// Every method has a default body, so an implementor only needs
/// [`KernelBuilderCore`]. Each helper appends exactly one instruction through
/// `instructions_mut()`; the `mov_*` helpers additionally allocate a new
/// virtual register through `registers_mut()` and return it to the caller.
pub trait PtxControl: KernelBuilderCore {
    /// Defines a label called `name` at the current end of the instruction
    /// stream.
    ///
    /// `name` is recorded in `labels_mut()`, and a placeholder instruction
    /// (`Mov`, type `B32`, no operands) whose `label` field is `"{name}:"` is
    /// appended. NOTE(review): this presumably relies on the emitter printing
    /// label-carrying placeholders as a label line rather than as a real
    /// `mov` — confirm against the emitter.
    fn label(&mut self, name: &str) {
        self.labels_mut().push(name.to_owned());
        let mut marker = PtxInstruction::new(PtxOp::Mov, PtxType::B32);
        marker.label = Some(format!("{name}:"));
        self.instructions_mut().push(marker);
    }

    /// Appends an unconditional `Bra` to `target`.
    ///
    /// The target name is stored in the instruction's label slot via
    /// `PtxInstruction::label`; no predicate is attached.
    fn branch(&mut self, target: &str) {
        let jump = PtxInstruction::new(PtxOp::Bra, PtxType::B32).label(target);
        self.instructions_mut().push(jump);
    }

    /// Appends a `Bra` to `target` guarded by the predicate register `pred`
    /// (predicate not negated).
    fn branch_if(&mut self, pred: VirtualReg, target: &str) {
        let guarded = PtxInstruction::new(PtxOp::Bra, PtxType::B32)
            .predicated(Predicate {
                reg: pred,
                negated: false,
            })
            .label(target);
        self.instructions_mut().push(guarded);
    }

    /// Appends a `Bra` to `target` guarded by the *negated* predicate
    /// register `pred`.
    fn branch_if_not(&mut self, pred: VirtualReg, target: &str) {
        let guarded = PtxInstruction::new(PtxOp::Bra, PtxType::B32)
            .predicated(Predicate {
                reg: pred,
                negated: true,
            })
            .label(target);
        self.instructions_mut().push(guarded);
    }

    /// Appends a `Ret` instruction.
    ///
    /// The instruction is tagged with type `Pred`. NOTE(review): the type
    /// looks like a placeholder the emitter ignores for `ret` — confirm.
    fn ret(&mut self) {
        let exit = PtxInstruction::new(PtxOp::Ret, PtxType::Pred);
        self.instructions_mut().push(exit);
    }

    /// Allocates a fresh `U64` register, appends a `Mov` of the immediate
    /// `val` (as `Operand::ImmU64`) into it, and returns the register.
    fn mov_u64_imm(&mut self, val: u64) -> VirtualReg {
        let out = self.registers_mut().allocate_virtual(PtxType::U64);
        let instr = PtxInstruction::new(PtxOp::Mov, PtxType::U64)
            .dst(Operand::Reg(out))
            .src(Operand::ImmU64(val));
        self.instructions_mut().push(instr);
        out
    }

    /// Allocates a fresh `U32` register, appends a `Mov` of the immediate
    /// `val` into it, and returns the register.
    ///
    /// The immediate is carried as `Operand::ImmI64`; widening `u32` to `i64`
    /// is lossless, so values above `i32::MAX` stay positive.
    fn mov_u32_imm(&mut self, val: u32) -> VirtualReg {
        let out = self.registers_mut().allocate_virtual(PtxType::U32);
        let instr = PtxInstruction::new(PtxOp::Mov, PtxType::U32)
            .dst(Operand::Reg(out))
            .src(Operand::ImmI64(i64::from(val)));
        self.instructions_mut().push(instr);
        out
    }

    /// Allocates a fresh `F32` register, appends a `Mov` of the immediate
    /// `val` (as `Operand::ImmF32`) into it, and returns the register.
    fn mov_f32_imm(&mut self, val: f32) -> VirtualReg {
        let out = self.registers_mut().allocate_virtual(PtxType::F32);
        let instr = PtxInstruction::new(PtxOp::Mov, PtxType::F32)
            .dst(Operand::Reg(out))
            .src(Operand::ImmF32(val));
        self.instructions_mut().push(instr);
        out
    }

    /// Copies register `src` into a freshly allocated register of type `ty`
    /// via a `Mov` of that type, and returns the new register.
    fn mov_reg(&mut self, src: VirtualReg, ty: PtxType) -> VirtualReg {
        let out = self.registers_mut().allocate_virtual(ty);
        let copy = PtxInstruction::new(PtxOp::Mov, ty)
            .dst(Operand::Reg(out))
            .src(Operand::Reg(src));
        self.instructions_mut().push(copy);
        out
    }
}
/// Blanket implementation: any [`KernelBuilderCore`] gets the [`PtxControl`]
/// helpers through the trait's default methods.
impl<T> PtxControl for T
where
    T: KernelBuilderCore,
{
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ptx::registers::RegisterAllocator;

    /// Minimal in-memory `KernelBuilderCore` that lets the tests inspect
    /// exactly what the `PtxControl` helpers append.
    struct MockBuilder {
        registers: RegisterAllocator,
        instructions: Vec<PtxInstruction>,
        labels: Vec<String>,
    }

    impl MockBuilder {
        fn new() -> Self {
            Self {
                registers: RegisterAllocator::new(),
                instructions: Vec::new(),
                labels: Vec::new(),
            }
        }
    }

    impl KernelBuilderCore for MockBuilder {
        fn registers_mut(&mut self) -> &mut RegisterAllocator {
            &mut self.registers
        }
        fn instructions_mut(&mut self) -> &mut Vec<PtxInstruction> {
            &mut self.instructions
        }
        fn labels_mut(&mut self) -> &mut Vec<String> {
            &mut self.labels
        }
    }

    #[test]
    fn test_label_and_branch() {
        let mut builder = MockBuilder::new();
        builder.label("loop_start");
        builder.branch("loop_start");
        assert_eq!(builder.labels.len(), 1);
        assert_eq!(builder.labels[0], "loop_start");
        assert_eq!(builder.instructions.len(), 2);
        // The branch must be a `Bra` whose target is the label just defined.
        assert_eq!(builder.instructions[1].op, PtxOp::Bra);
        assert_eq!(builder.instructions[1].label, Some("loop_start".to_string()));
    }

    #[test]
    fn test_conditional_branch() {
        let mut builder = MockBuilder::new();
        let pred = builder.registers.allocate_virtual(PtxType::Pred);
        builder.branch_if(pred, "target");
        builder.branch_if_not(pred, "other");
        assert_eq!(builder.instructions.len(), 2);
        let taken_if_true = builder.instructions[0]
            .predicate
            .as_ref()
            .expect("branch_if must attach a predicate");
        let taken_if_false = builder.instructions[1]
            .predicate
            .as_ref()
            .expect("branch_if_not must attach a predicate");
        assert!(!taken_if_true.negated);
        assert!(taken_if_false.negated);
    }

    #[test]
    fn test_mov_immediates() {
        let mut builder = MockBuilder::new();
        let _a = builder.mov_u32_imm(42);
        let _b = builder.mov_u64_imm(12345);
        // Avoid literals like 3.14: they trip clippy::approx_constant, which
        // is deny-by-default. Any exactly representable value works here.
        let _c = builder.mov_f32_imm(1.5);
        assert_eq!(builder.instructions.len(), 3);
        for instr in &builder.instructions {
            assert_eq!(instr.op, PtxOp::Mov);
        }
    }

    #[test]
    fn test_ret() {
        let mut builder = MockBuilder::new();
        builder.ret();
        assert_eq!(builder.instructions.len(), 1);
        assert_eq!(builder.instructions[0].op, PtxOp::Ret);
    }

    #[test]
    fn test_mov_reg() {
        let mut builder = MockBuilder::new();
        let src = builder.mov_u32_imm(42);
        let dst = builder.mov_reg(src, PtxType::U32);
        assert_eq!(builder.instructions.len(), 2);
        assert_eq!(builder.instructions[0].op, PtxOp::Mov);
        assert_eq!(builder.instructions[1].op, PtxOp::Mov);
        assert_ne!(src, dst);
        match builder.instructions[1].srcs.first() {
            Some(Operand::Reg(r)) => assert_eq!(*r, src),
            _ => panic!("Expected register source operand"),
        }
        match &builder.instructions[1].dst {
            Some(Operand::Reg(r)) => assert_eq!(*r, dst),
            _ => panic!("Expected register destination operand"),
        }
    }

    #[test]
    fn test_mov_reg_f32() {
        let mut builder = MockBuilder::new();
        // See test_mov_immediates for why this is not 3.14.
        let src = builder.mov_f32_imm(2.5);
        let dst = builder.mov_reg(src, PtxType::F32);
        assert_eq!(builder.instructions.len(), 2);
        assert_eq!(builder.instructions[1].ty, PtxType::F32);
        assert_ne!(src, dst);
        match builder.instructions[1].srcs.first() {
            Some(Operand::Reg(r)) => assert_eq!(*r, src),
            _ => panic!("Expected register source operand"),
        }
    }

    #[test]
    fn test_multiple_labels() {
        let mut builder = MockBuilder::new();
        builder.label("start");
        builder.label("middle");
        builder.label("end");
        assert_eq!(builder.labels.len(), 3);
        assert_eq!(builder.labels[0], "start");
        assert_eq!(builder.labels[1], "middle");
        assert_eq!(builder.labels[2], "end");
        assert_eq!(builder.instructions.len(), 3);
        assert!(builder.instructions.iter().all(|instr| instr.label.is_some()));
    }

    #[test]
    fn test_label_format() {
        let mut builder = MockBuilder::new();
        builder.label("my_loop");
        let instr = &builder.instructions[0];
        assert_eq!(instr.label, Some("my_loop:".to_string()));
    }

    #[test]
    fn test_branch_creates_label_target() {
        let mut builder = MockBuilder::new();
        builder.branch("exit");
        let instr = &builder.instructions[0];
        assert_eq!(instr.op, PtxOp::Bra);
        assert_eq!(instr.label, Some("exit".to_string()));
        assert!(instr.predicate.is_none());
    }

    #[test]
    fn test_branch_if_predicate_structure() {
        let mut builder = MockBuilder::new();
        let pred = builder.registers.allocate_virtual(PtxType::Pred);
        builder.branch_if(pred, "target_label");
        let predicate = builder.instructions[0]
            .predicate
            .as_ref()
            .expect("branch_if must attach a predicate");
        assert_eq!(predicate.reg, pred);
        assert!(!predicate.negated);
    }

    #[test]
    fn test_branch_if_not_predicate_structure() {
        let mut builder = MockBuilder::new();
        let pred = builder.registers.allocate_virtual(PtxType::Pred);
        builder.branch_if_not(pred, "other_label");
        let predicate = builder.instructions[0]
            .predicate
            .as_ref()
            .expect("branch_if_not must attach a predicate");
        assert_eq!(predicate.reg, pred);
        assert!(predicate.negated);
    }

    #[test]
    fn test_mov_u64_imm_value() {
        let mut builder = MockBuilder::new();
        let reg = builder.mov_u64_imm(0xDEADBEEFCAFEBABE);
        assert_eq!(builder.instructions.len(), 1);
        let instr = &builder.instructions[0];
        assert_eq!(instr.op, PtxOp::Mov);
        assert_eq!(instr.ty, PtxType::U64);
        match &instr.dst {
            Some(Operand::Reg(r)) => assert_eq!(*r, reg),
            _ => panic!("Expected register destination"),
        }
        match instr.srcs.first() {
            Some(Operand::ImmU64(v)) => assert_eq!(*v, 0xDEADBEEFCAFEBABE),
            _ => panic!("Expected ImmU64 source"),
        }
    }

    #[test]
    fn test_mov_u32_imm_value() {
        let mut builder = MockBuilder::new();
        let reg = builder.mov_u32_imm(0xCAFEBABE);
        let instr = &builder.instructions[0];
        assert_eq!(instr.ty, PtxType::U32);
        match &instr.dst {
            Some(Operand::Reg(r)) => assert_eq!(*r, reg),
            _ => panic!("Expected register destination"),
        }
        match instr.srcs.first() {
            Some(Operand::ImmI64(v)) => assert_eq!(*v, 0xCAFEBABE_i64),
            _ => panic!("Expected ImmI64 source"),
        }
    }

    #[test]
    fn test_mov_u32_imm_max_is_zero_extended() {
        // u32 -> i64 must zero-extend; u32::MAX must not become -1.
        let mut builder = MockBuilder::new();
        builder.mov_u32_imm(u32::MAX);
        match builder.instructions[0].srcs.first() {
            Some(Operand::ImmI64(v)) => assert_eq!(*v, i64::from(u32::MAX)),
            _ => panic!("Expected ImmI64 source"),
        }
    }

    #[test]
    fn test_mov_f32_imm_value() {
        let mut builder = MockBuilder::new();
        let reg = builder.mov_f32_imm(std::f32::consts::E);
        let instr = &builder.instructions[0];
        assert_eq!(instr.ty, PtxType::F32);
        match &instr.dst {
            Some(Operand::Reg(r)) => assert_eq!(*r, reg),
            _ => panic!("Expected register destination"),
        }
        match instr.srcs.first() {
            Some(Operand::ImmF32(v)) => assert!((v - std::f32::consts::E).abs() < 1e-5),
            _ => panic!("Expected ImmF32 source"),
        }
    }

    #[test]
    fn test_ret_instruction_type() {
        let mut builder = MockBuilder::new();
        builder.ret();
        let instr = &builder.instructions[0];
        assert_eq!(instr.op, PtxOp::Ret);
        assert_eq!(instr.ty, PtxType::Pred);
    }
}