use crate::ir::{Operand, PtxInstruction, Register, RegisterAllocator, SpecialReg};
use crate::types::PtxType;
/// Reads the `TidX` special register into a `U32` register obtained from `alloc`.
///
/// Returns the destination register together with the `Mov` instruction that
/// performs the read. The instruction is only returned, so the caller has to
/// place it in the instruction stream.
///
/// NOTE(review): `TidX` presumably maps to PTX `%tid.x` (thread index within
/// the block, x dimension) — confirm against the `SpecialReg` definition.
pub fn tid_x(alloc: &mut RegisterAllocator) -> (Register, PtxInstruction) {
    let source = SpecialReg::TidX;
    read_special(alloc, source)
}
/// Reads the `TidY` special register into a `U32` register obtained from `alloc`.
///
/// Returns the destination register together with the `Mov` instruction that
/// performs the read. The instruction is only returned, so the caller has to
/// place it in the instruction stream.
///
/// NOTE(review): `TidY` presumably maps to PTX `%tid.y` (thread index within
/// the block, y dimension) — confirm against the `SpecialReg` definition.
pub fn tid_y(alloc: &mut RegisterAllocator) -> (Register, PtxInstruction) {
    let source = SpecialReg::TidY;
    read_special(alloc, source)
}
/// Reads the `TidZ` special register into a `U32` register obtained from `alloc`.
///
/// Returns the destination register together with the `Mov` instruction that
/// performs the read. The instruction is only returned, so the caller has to
/// place it in the instruction stream.
///
/// NOTE(review): `TidZ` presumably maps to PTX `%tid.z` (thread index within
/// the block, z dimension) — confirm against the `SpecialReg` definition.
pub fn tid_z(alloc: &mut RegisterAllocator) -> (Register, PtxInstruction) {
    let source = SpecialReg::TidZ;
    read_special(alloc, source)
}
/// Reads the `NtidX` special register into a `U32` register obtained from `alloc`.
///
/// Returns the destination register together with the `Mov` instruction that
/// performs the read. The instruction is only returned, so the caller has to
/// place it in the instruction stream.
///
/// NOTE(review): `NtidX` presumably maps to PTX `%ntid.x` (threads per block,
/// x dimension) — confirm against the `SpecialReg` definition.
pub fn ntid_x(alloc: &mut RegisterAllocator) -> (Register, PtxInstruction) {
    let source = SpecialReg::NtidX;
    read_special(alloc, source)
}
/// Reads the `NtidY` special register into a `U32` register obtained from `alloc`.
///
/// Returns the destination register together with the `Mov` instruction that
/// performs the read. The instruction is only returned, so the caller has to
/// place it in the instruction stream.
///
/// NOTE(review): `NtidY` presumably maps to PTX `%ntid.y` (threads per block,
/// y dimension) — confirm against the `SpecialReg` definition.
pub fn ntid_y(alloc: &mut RegisterAllocator) -> (Register, PtxInstruction) {
    let source = SpecialReg::NtidY;
    read_special(alloc, source)
}
/// Reads the `NtidZ` special register into a `U32` register obtained from `alloc`.
///
/// Returns the destination register together with the `Mov` instruction that
/// performs the read. The instruction is only returned, so the caller has to
/// place it in the instruction stream.
///
/// NOTE(review): `NtidZ` presumably maps to PTX `%ntid.z` (threads per block,
/// z dimension) — confirm against the `SpecialReg` definition.
pub fn ntid_z(alloc: &mut RegisterAllocator) -> (Register, PtxInstruction) {
    let source = SpecialReg::NtidZ;
    read_special(alloc, source)
}
/// Reads the `CtaidX` special register into a `U32` register obtained from `alloc`.
///
/// Returns the destination register together with the `Mov` instruction that
/// performs the read. The instruction is only returned, so the caller has to
/// place it in the instruction stream.
///
/// NOTE(review): `CtaidX` presumably maps to PTX `%ctaid.x` (block index
/// within the grid, x dimension) — confirm against the `SpecialReg` definition.
pub fn ctaid_x(alloc: &mut RegisterAllocator) -> (Register, PtxInstruction) {
    let source = SpecialReg::CtaidX;
    read_special(alloc, source)
}
/// Reads the `CtaidY` special register into a `U32` register obtained from `alloc`.
///
/// Returns the destination register together with the `Mov` instruction that
/// performs the read. The instruction is only returned, so the caller has to
/// place it in the instruction stream.
///
/// NOTE(review): `CtaidY` presumably maps to PTX `%ctaid.y` (block index
/// within the grid, y dimension) — confirm against the `SpecialReg` definition.
pub fn ctaid_y(alloc: &mut RegisterAllocator) -> (Register, PtxInstruction) {
    let source = SpecialReg::CtaidY;
    read_special(alloc, source)
}
/// Reads the `CtaidZ` special register into a `U32` register obtained from `alloc`.
///
/// Returns the destination register together with the `Mov` instruction that
/// performs the read. The instruction is only returned, so the caller has to
/// place it in the instruction stream.
///
/// NOTE(review): `CtaidZ` presumably maps to PTX `%ctaid.z` (block index
/// within the grid, z dimension) — confirm against the `SpecialReg` definition.
pub fn ctaid_z(alloc: &mut RegisterAllocator) -> (Register, PtxInstruction) {
    let source = SpecialReg::CtaidZ;
    read_special(alloc, source)
}
/// Reads the `NctaidX` special register into a `U32` register obtained from `alloc`.
///
/// Returns the destination register together with the `Mov` instruction that
/// performs the read. The instruction is only returned, so the caller has to
/// place it in the instruction stream.
///
/// NOTE(review): `NctaidX` presumably maps to PTX `%nctaid.x` (number of
/// blocks in the grid, x dimension) — confirm against the `SpecialReg` definition.
pub fn nctaid_x(alloc: &mut RegisterAllocator) -> (Register, PtxInstruction) {
    let source = SpecialReg::NctaidX;
    read_special(alloc, source)
}
/// Reads the `NctaidY` special register into a `U32` register obtained from `alloc`.
///
/// Returns the destination register together with the `Mov` instruction that
/// performs the read. The instruction is only returned, so the caller has to
/// place it in the instruction stream.
///
/// NOTE(review): `NctaidY` presumably maps to PTX `%nctaid.y` (number of
/// blocks in the grid, y dimension) — confirm against the `SpecialReg` definition.
pub fn nctaid_y(alloc: &mut RegisterAllocator) -> (Register, PtxInstruction) {
    let source = SpecialReg::NctaidY;
    read_special(alloc, source)
}
/// Reads the `NctaidZ` special register into a `U32` register obtained from `alloc`.
///
/// Returns the destination register together with the `Mov` instruction that
/// performs the read. The instruction is only returned, so the caller has to
/// place it in the instruction stream.
///
/// NOTE(review): `NctaidZ` presumably maps to PTX `%nctaid.z` (number of
/// blocks in the grid, z dimension) — confirm against the `SpecialReg` definition.
pub fn nctaid_z(alloc: &mut RegisterAllocator) -> (Register, PtxInstruction) {
    let source = SpecialReg::NctaidZ;
    read_special(alloc, source)
}
/// Shared implementation behind the public special-register readers.
///
/// Requests a `U32` register from `alloc` and builds a `U32`-typed `Mov` whose
/// source operand is the special register `sr` and whose destination is that
/// register.
///
/// Returns `(destination register, mov instruction)`. The instruction is not
/// stored anywhere by this function; the caller decides where to emit it.
fn read_special(alloc: &mut RegisterAllocator, sr: SpecialReg) -> (Register, PtxInstruction) {
    let dst = alloc.alloc(PtxType::U32);
    (
        dst,
        PtxInstruction::Mov {
            dst,
            src: Operand::SpecialReg(sr),
            ty: PtxType::U32,
        },
    )
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::RegKind;

    /// Unpacks the fields of a `Mov` instruction, panicking on any other variant.
    fn expect_mov(instr: &PtxInstruction) -> (&Register, &Operand, &PtxType) {
        match instr {
            PtxInstruction::Mov { dst, src, ty } => (dst, src, ty),
            other => panic!("expected Mov, got {other:?}"),
        }
    }

    /// Unpacks a special-register operand, panicking on any other operand kind.
    fn expect_special(operand: &Operand) -> &SpecialReg {
        match operand {
            Operand::SpecialReg(sr) => sr,
            other => panic!("expected SpecialReg, got {other:?}"),
        }
    }

    /// `tid_x` on a fresh allocator yields register `R0` of type `U32` and a
    /// matching `U32` move from `SpecialReg::TidX`.
    #[test]
    fn special_tid_x() {
        let mut alloc = RegisterAllocator::new();
        let (reg, instr) = tid_x(&mut alloc);

        assert_eq!(reg.kind, RegKind::R);
        assert_eq!(reg.index, 0);
        assert_eq!(reg.ptx_type, PtxType::U32);

        let (dst, src, ty) = expect_mov(&instr);
        assert_eq!(*dst, reg);
        assert_eq!(*ty, PtxType::U32);
        assert_eq!(*expect_special(src), SpecialReg::TidX);
    }

    /// `ctaid_x` on a fresh allocator yields register `R0` and a move whose
    /// source is `SpecialReg::CtaidX`.
    #[test]
    fn special_ctaid_x() {
        let mut alloc = RegisterAllocator::new();
        let (reg, instr) = ctaid_x(&mut alloc);

        assert_eq!(reg.kind, RegKind::R);
        assert_eq!(reg.index, 0);

        let (_, src, _) = expect_mov(&instr);
        assert_eq!(*expect_special(src), SpecialReg::CtaidX);
    }
}