mod arithmetic;
mod atomic;
mod comparison;
mod control;
mod core;
mod emit;
mod memory;
mod sync;
mod atomics_debug;
mod bitwise_ops;
mod conversions;
mod generic_mem;
mod global_mem;
mod inplace_ops;
mod misc_ops;
mod precise_math;
mod tensor_core;
mod warp_vote;
mod ptx_module;
pub use arithmetic::PtxArithmetic;
pub use atomic::PtxAtomic;
pub use comparison::PtxComparison;
pub use control::PtxControl;
pub use core::KernelBuilderCore;
pub use memory::PtxMemory;
pub use sync::PtxSync;
pub use ptx_module::{KernelParam, PtxKernel, PtxModule};
use super::instructions::{Operand, Predicate, PtxInstruction, PtxOp};
use super::registers::{PtxReg, RegisterAllocator, VirtualReg};
use super::types::{PtxStateSpace, PtxType};
/// Generates an in-place DP4A method.
///
/// The generated method emits one instruction that writes its result back into
/// `acc` and also reads `acc` as its final source operand, so the accumulator is
/// both destination and input. It must be expanded inside an `impl` whose `self`
/// has an `instructions: Vec<PtxInstruction>` field.
///
/// Arguments, in order: method name, `PtxOp` variant, `PtxType` variant, doc string.
macro_rules! impl_dp4a_inplace {
    ($method:ident, $opcode:ident, $ptx_ty:ident, $docstr:expr) => {
        #[doc = $docstr]
        pub fn $method(&mut self, acc: VirtualReg, a: VirtualReg, b: VirtualReg) {
            // Operand layout: dst = acc; srcs = a, b, acc (accumulator read last).
            let dp4a = PtxInstruction::new(PtxOp::$opcode, PtxType::$ptx_ty)
                .dst(Operand::Reg(acc))
                .src(Operand::Reg(a))
                .src(Operand::Reg(b))
                .src(Operand::Reg(acc));
            self.instructions.push(dp4a);
        }
    };
}
/// Accumulates PTX instructions, presumably for a single kernel body.
///
/// Virtual registers are drawn from a mutably borrowed [`RegisterAllocator`], so
/// the allocator must outlive the builder.
pub struct KernelBuilder<'regs> {
    /// Source of the virtual registers used by emitted instructions.
    pub(crate) registers: &'regs mut RegisterAllocator,
    /// Instructions emitted so far, appended in emission order.
    pub(crate) instructions: Vec<PtxInstruction>,
    /// Label names tracked by the builder (exposed through `labels_mut`).
    pub(crate) labels: Vec<String>,
}
impl<'a> core::KernelBuilderCore for KernelBuilder<'a> {
fn registers_mut(&mut self) -> &mut RegisterAllocator {
self.registers
}
fn instructions_mut(&mut self) -> &mut Vec<PtxInstruction> {
&mut self.instructions
}
fn labels_mut(&mut self) -> &mut Vec<String> {
&mut self.labels
}
}
impl<'r> KernelBuilder<'r> {
    /// Creates a builder with no instructions or labels that allocates virtual
    /// registers from `registers`.
    pub(crate) fn new(registers: &'r mut RegisterAllocator) -> Self {
        Self {
            instructions: Vec::new(),
            labels: Vec::new(),
            registers,
        }
    }

    /// Copies the special register `reg` into a freshly allocated virtual register.
    ///
    /// Both the new register and the emitted `Mov` use `reg.data_type()`.
    pub fn special_reg(&mut self, reg: PtxReg) -> VirtualReg {
        let out = self.registers.allocate_virtual(reg.data_type());
        let mov = PtxInstruction::new(PtxOp::Mov, reg.data_type())
            .dst(Operand::Reg(out))
            .src(Operand::SpecialReg(reg));
        self.instructions.push(mov);
        out
    }

    /// Loads the kernel parameter `name` as `U32` into a new virtual register.
    pub fn load_param_u32(&mut self, name: &str) -> VirtualReg {
        let out = self.registers.allocate_virtual(PtxType::U32);
        let ld = PtxInstruction::new(PtxOp::LdParam, PtxType::U32)
            .dst(Operand::Reg(out))
            .src(Operand::Param(String::from(name)));
        self.instructions.push(ld);
        out
    }

    /// Loads the kernel parameter `name` as `U64` into a new virtual register.
    pub fn load_param_u64(&mut self, name: &str) -> VirtualReg {
        let out = self.registers.allocate_virtual(PtxType::U64);
        let ld = PtxInstruction::new(PtxOp::LdParam, PtxType::U64)
            .dst(Operand::Reg(out))
            .src(Operand::Param(String::from(name)));
        self.instructions.push(ld);
        out
    }

    /// Loads the kernel parameter `name` as `F32` into a new virtual register.
    pub fn load_param_f32(&mut self, name: &str) -> VirtualReg {
        let out = self.registers.allocate_virtual(PtxType::F32);
        let ld = PtxInstruction::new(PtxOp::LdParam, PtxType::F32)
            .dst(Operand::Reg(out))
            .src(Operand::Param(String::from(name)));
        self.instructions.push(ld);
        out
    }

    /// Emits a `U64` `Mov` of the immediate `val` into the existing register `dst`.
    ///
    /// No register is allocated. Presumably `dst` should be a `U64` register;
    /// this is not checked here.
    pub fn mov_u64_into(&mut self, dst: VirtualReg, val: u64) {
        let mov = PtxInstruction::new(PtxOp::Mov, PtxType::U64)
            .dst(Operand::Reg(dst))
            .src(Operand::ImmU64(val));
        self.instructions.push(mov);
    }

    /// Emits a `U32` `Mov` of the immediate `val` into the existing register `dst`.
    ///
    /// No register is allocated.
    pub fn mov_u32_into(&mut self, dst: VirtualReg, val: u32) {
        // The immediate is carried as `ImmI64`; widening a `u32` to `i64` is lossless.
        let mov = PtxInstruction::new(PtxOp::Mov, PtxType::U32)
            .dst(Operand::Reg(dst))
            .src(Operand::ImmI64(i64::from(val)));
        self.instructions.push(mov);
    }

    /// Emits a `U32` add of register `a` and immediate `b` into a new register,
    /// which is returned.
    pub fn add_u32(&mut self, a: VirtualReg, b: u32) -> VirtualReg {
        let sum = self.registers.allocate_virtual(PtxType::U32);
        let add = PtxInstruction::new(PtxOp::Add, PtxType::U32)
            .dst(Operand::Reg(sum))
            .src(Operand::Reg(a))
            .src(Operand::ImmU64(u64::from(b)));
        self.instructions.push(add);
        sum
    }

    /// Emits a `U32` `Dp4a` with sources `a`, `b`, `c` into a new register,
    /// which is returned.
    ///
    /// Per the PTX ISA definition of `dp4a`, the result is `c` plus the four-way
    /// byte dot product of `a` and `b`.
    pub fn dp4a_u32(&mut self, a: VirtualReg, b: VirtualReg, c: VirtualReg) -> VirtualReg {
        let result = self.registers.allocate_virtual(PtxType::U32);
        let dp4a = PtxInstruction::new(PtxOp::Dp4a, PtxType::U32)
            .dst(Operand::Reg(result))
            .src(Operand::Reg(a))
            .src(Operand::Reg(b))
            .src(Operand::Reg(c));
        self.instructions.push(dp4a);
        result
    }

    // In-place DP4A variants. Each writes back into `acc` and also reads `acc`
    // (see `impl_dp4a_inplace!`).
    impl_dp4a_inplace!(
        dp4a_u32_inplace,
        Dp4a,
        U32,
        "DP4A u32 in-place: acc += dot4(a, b) where a,b are packed u8x4"
    );
    impl_dp4a_inplace!(
        dp4a_u32_s32_inplace,
        Dp4aUS,
        S32,
        "DP4A u32\u{00d7}s32 in-place: acc += dot4(u8x4, s8x4)"
    );
    impl_dp4a_inplace!(
        dp4a_s32_inplace,
        Dp4aS32,
        S32,
        "DP4A s32 in-place: acc += dot4(s8x4, s8x4)"
    );

    /// Emits a `Bar` instruction for barrier `barrier_id`.
    ///
    /// The barrier mode travels in the instruction label as `"sync <id>"`.
    /// Presumably the emitter renders this as `bar.sync <id>`; confirm in `emit`.
    pub fn bar_sync(&mut self, barrier_id: u32) {
        let modifier = format!("sync {}", barrier_id);
        self.instructions
            .push(PtxInstruction::new(PtxOp::Bar, PtxType::B32).label(modifier));
    }

    /// Emits a `MemBar` instruction labelled `cta`.
    ///
    /// Presumably this is rendered as `membar.cta`.
    pub fn membar_cta(&mut self) {
        let fence = PtxInstruction::new(PtxOp::MemBar, PtxType::B32).label(String::from("cta"));
        self.instructions.push(fence);
    }

    /// Emits a `MemBar` instruction labelled `gl`.
    ///
    /// Presumably this is rendered as `membar.gl`.
    pub fn membar_gl(&mut self) {
        let fence = PtxInstruction::new(PtxOp::MemBar, PtxType::B32).label(String::from("gl"));
        self.instructions.push(fence);
    }

    /// Emits a `U16` `St` in the shared state space.
    ///
    /// `addr` is the address register and `val` the value register, in that
    /// source order.
    pub fn st_shared_u16(&mut self, addr: VirtualReg, val: VirtualReg) {
        let store = PtxInstruction::new(PtxOp::St, PtxType::U16)
            .src(Operand::Reg(addr))
            .src(Operand::Reg(val))
            .space(PtxStateSpace::Shared);
        self.instructions.push(store);
    }
}
#[cfg(test)]
mod tests;