use crate::arch::SmVersion;
use crate::builder::KernelBuilder;
use crate::error::PtxGenError;
use crate::ir::PtxType;
use super::elementwisetemplate_type::ElementwiseTemplate;
use super::functions::{float_one_literal, float_zero_literal, scalar_param_type};
use super::types::ElementwiseOp;
impl ElementwiseTemplate {
    /// Creates a template for `op` at the given `precision`, targeting `target`.
    ///
    /// No validation happens here; the precision is checked by [`Self::generate`].
    #[must_use]
    pub const fn new(op: ElementwiseOp, precision: PtxType, target: SmVersion) -> Self {
        Self {
            op,
            precision,
            target,
        }
    }

    /// Returns the kernel entry name, `elementwise_<op>_<type>`.
    ///
    /// `<type>` is the precision's PTX type string with any leading `.` removed.
    #[must_use]
    pub fn kernel_name(&self) -> String {
        let type_str = self.precision.as_ptx_str().trim_start_matches('.');
        format!("elementwise_{}_{}", self.op.as_str(), type_str)
    }

    /// Generates the PTX text for this template's operation.
    ///
    /// # Errors
    ///
    /// Returns [`PtxGenError::InvalidType`] when the precision is not one of
    /// `F16`, `BF16`, `F32` or `F64`; otherwise returns whatever error the
    /// selected generator (ultimately the kernel builder's `build`) reports.
    pub fn generate(&self) -> Result<String, PtxGenError> {
        self.validate_precision()?;
        match self.op {
            ElementwiseOp::Add => self.generate_binary_arith("add"),
            ElementwiseOp::Sub => self.generate_binary_arith("sub"),
            ElementwiseOp::Mul => self.generate_binary_arith("mul"),
            ElementwiseOp::Div => self.generate_div(),
            ElementwiseOp::Relu => self.generate_relu(),
            ElementwiseOp::Gelu => self.generate_gelu(),
            ElementwiseOp::Sigmoid => self.generate_sigmoid(),
            ElementwiseOp::Silu => self.generate_silu(),
            ElementwiseOp::Tanh => self.generate_tanh(),
            ElementwiseOp::Neg => self.generate_unary("neg"),
            ElementwiseOp::Abs => self.generate_unary("abs"),
            ElementwiseOp::Sqrt => self.generate_sqrt(),
            ElementwiseOp::Rsqrt => self.generate_rsqrt(),
            ElementwiseOp::Exp => self.generate_exp(),
            ElementwiseOp::Log => self.generate_log(),
            ElementwiseOp::Ceil => self.generate_ceil(),
            ElementwiseOp::Floor => self.generate_floor(),
            ElementwiseOp::HardSigmoid => self.generate_hard_sigmoid(),
            ElementwiseOp::HardSwish => self.generate_hard_swish(),
            ElementwiseOp::Softplus => self.generate_softplus(),
            ElementwiseOp::LeakyRelu => self.generate_leaky_relu(),
            ElementwiseOp::OneMinus => self.generate_one_minus(),
            ElementwiseOp::Scale => self.generate_scale(),
            ElementwiseOp::AddScalar => self.generate_add_scalar(),
            ElementwiseOp::FusedAddRelu => self.generate_fused_add_relu(),
            ElementwiseOp::FusedScaleAdd => self.generate_fused_scale_add(),
            ElementwiseOp::Pow => self.generate_pow(),
            ElementwiseOp::Min => self.generate_binary_minmax("min"),
            ElementwiseOp::Max | ElementwiseOp::OrMax => self.generate_binary_minmax("max"),
            ElementwiseOp::CmpEq => self.generate_binary_cmp("eq"),
            ElementwiseOp::CmpNe => self.generate_binary_cmp("ne"),
            ElementwiseOp::CmpLt => self.generate_binary_cmp("lt"),
            ElementwiseOp::CmpGt => self.generate_binary_cmp("gt"),
            ElementwiseOp::CmpLe => self.generate_binary_cmp("le"),
            ElementwiseOp::CmpGe => self.generate_binary_cmp("ge"),
            ElementwiseOp::OrProbSum => self.generate_or_prob_sum(),
            ElementwiseOp::Nand => self.generate_nand(),
            ElementwiseOp::Nor => self.generate_nor(),
            ElementwiseOp::Xor => self.generate_xor(),
            ElementwiseOp::Fill => self.generate_fill(),
        }
    }

    /// Rejects any precision other than `F16`, `BF16`, `F32` or `F64`.
    fn validate_precision(&self) -> Result<(), PtxGenError> {
        if matches!(
            self.precision,
            PtxType::F16 | PtxType::BF16 | PtxType::F32 | PtxType::F64
        ) {
            return Ok(());
        }
        Err(PtxGenError::InvalidType(format!(
            "elementwise operations require F16, BF16, F32, or F64, got {}",
            self.precision.as_ptx_str()
        )))
    }

    /// PTX type string of the configured precision, as returned by `as_ptx_str`.
    pub(super) const fn ty_str(&self) -> &'static str {
        self.precision.as_ptx_str()
    }

    // ------------------------------------------------------------------
    // Shared kernel scaffolding
    //
    // Every kernel in this impl has the same shape: declare the parameters,
    // cap the block at 256 threads, run the body under the builder's
    // `if_lt_u32(tid, n)` guard, turn the thread id into a byte offset
    // (`tid * size_bytes`), add it to each pointer, emit the op-specific
    // compute snippet, then `ret`. Only the parameter list and the compute
    // snippet differ, so the three helpers below own the scaffold.
    // ------------------------------------------------------------------

    /// Builds a kernel with parameters `(a_ptr, b_ptr, n)`.
    ///
    /// `compute` is emitted verbatim after the address setup; it may refer to
    /// `%rd_a` and `%rd_b`, the per-thread addresses derived from `a_ptr` and
    /// `b_ptr`.
    fn emit_pointwise_unary(&self, compute: String) -> Result<String, PtxGenError> {
        let kernel_name = self.kernel_name();
        let byte_size = self.precision.size_bytes();
        KernelBuilder::new(&kernel_name)
            .target(self.target)
            .param("a_ptr", PtxType::U64)
            .param("b_ptr", PtxType::U64)
            .param("n", PtxType::U32)
            .max_threads_per_block(256)
            .body(move |b| {
                let tid = b.global_thread_id_x();
                // Keep the register's textual name: `tid` itself is handed to the guard.
                let tid_name = tid.to_string();
                let n_reg = b.load_param_u32("n");
                b.if_lt_u32(tid, n_reg, move |b| {
                    let a_ptr = b.load_param_u64("a_ptr");
                    let b_ptr = b.load_param_u64("b_ptr");
                    b.raw_ptx(&format!(
                        "cvt.u64.u32 %rd_off, {tid_name};\n \
                         mul.lo.u64 %rd_off, %rd_off, {byte_size};\n \
                         add.u64 %rd_a, {a_ptr}, %rd_off;\n \
                         add.u64 %rd_b, {b_ptr}, %rd_off;"
                    ));
                    b.raw_ptx(&compute);
                });
                b.ret();
            })
            .build()
    }

    /// Builds a kernel with parameters `(a_ptr, b_ptr, <scalar_name>, n)`.
    ///
    /// The scalar parameter's type comes from `scalar_param_type(precision)`.
    /// The scaffold does not load the scalar; `compute` is expected to read it
    /// itself. Otherwise identical to [`Self::emit_pointwise_unary`].
    fn emit_pointwise_unary_with_scalar(
        &self,
        scalar_name: &'static str,
        compute: String,
    ) -> Result<String, PtxGenError> {
        let kernel_name = self.kernel_name();
        let byte_size = self.precision.size_bytes();
        let scalar_ty = scalar_param_type(self.precision);
        KernelBuilder::new(&kernel_name)
            .target(self.target)
            .param("a_ptr", PtxType::U64)
            .param("b_ptr", PtxType::U64)
            .param(scalar_name, scalar_ty)
            .param("n", PtxType::U32)
            .max_threads_per_block(256)
            .body(move |b| {
                let tid = b.global_thread_id_x();
                let tid_name = tid.to_string();
                let n_reg = b.load_param_u32("n");
                b.if_lt_u32(tid, n_reg, move |b| {
                    let a_ptr = b.load_param_u64("a_ptr");
                    let b_ptr = b.load_param_u64("b_ptr");
                    b.raw_ptx(&format!(
                        "cvt.u64.u32 %rd_off, {tid_name};\n \
                         mul.lo.u64 %rd_off, %rd_off, {byte_size};\n \
                         add.u64 %rd_a, {a_ptr}, %rd_off;\n \
                         add.u64 %rd_b, {b_ptr}, %rd_off;"
                    ));
                    b.raw_ptx(&compute);
                });
                b.ret();
            })
            .build()
    }

    /// Builds a kernel with parameters `(a_ptr, b_ptr, c_ptr, n)`.
    ///
    /// `compute` is emitted verbatim after the address setup; it may refer to
    /// `%rd_a`, `%rd_b` and `%rd_c`, the per-thread addresses derived from the
    /// three pointer parameters.
    fn emit_pointwise_binary(&self, compute: String) -> Result<String, PtxGenError> {
        let kernel_name = self.kernel_name();
        let byte_size = self.precision.size_bytes();
        KernelBuilder::new(&kernel_name)
            .target(self.target)
            .param("a_ptr", PtxType::U64)
            .param("b_ptr", PtxType::U64)
            .param("c_ptr", PtxType::U64)
            .param("n", PtxType::U32)
            .max_threads_per_block(256)
            .body(move |b| {
                let tid = b.global_thread_id_x();
                let tid_name = tid.to_string();
                let n_reg = b.load_param_u32("n");
                b.if_lt_u32(tid, n_reg, move |b| {
                    let a_ptr = b.load_param_u64("a_ptr");
                    let b_ptr = b.load_param_u64("b_ptr");
                    let c_ptr = b.load_param_u64("c_ptr");
                    b.raw_ptx(&format!(
                        "cvt.u64.u32 %rd_off, {tid_name};\n \
                         mul.lo.u64 %rd_off, %rd_off, {byte_size};\n \
                         add.u64 %rd_a, {a_ptr}, %rd_off;\n \
                         add.u64 %rd_b, {b_ptr}, %rd_off;\n \
                         add.u64 %rd_c, {c_ptr}, %rd_off;"
                    ));
                    b.raw_ptx(&compute);
                });
                b.ret();
            })
            .build()
    }

    // ------------------------------------------------------------------
    // Per-operation generators
    //
    // NOTE(review): the numeric constants in the snippets below are written
    // as `0f…` hex literals (PTX single-precision spelling) regardless of
    // `precision`, while the instructions carry the `{ty}` suffix. Confirm
    // this is intended for the non-F32 precisions `validate_precision` allows.
    // ------------------------------------------------------------------

    /// `c = a <op_name> b` using the single instruction `<op_name><ty>`.
    fn generate_binary_arith(&self, op_name: &str) -> Result<String, PtxGenError> {
        let ty = self.ty_str();
        self.emit_pointwise_binary(format!(
            "ld.global{ty} %f_a, [%rd_a];\n \
             ld.global{ty} %f_b, [%rd_b];\n \
             {op_name}{ty} %f_c, %f_a, %f_b;\n \
             st.global{ty} [%rd_c], %f_c;"
        ))
    }

    /// `c = a / b` via `div.rn<ty>`.
    fn generate_div(&self) -> Result<String, PtxGenError> {
        // Same template as the other binary arithmetic ops, with the rounding
        // modifier folded into the mnemonic.
        self.generate_binary_arith("div.rn")
    }

    /// `y = max(x, 0)`.
    fn generate_relu(&self) -> Result<String, PtxGenError> {
        let ty = self.ty_str();
        let zero_lit = float_zero_literal(self.precision);
        self.emit_pointwise_unary(format!(
            "ld.global{ty} %f_x, [%rd_a];\n \
             max{ty} %f_y, %f_x, {zero_lit};\n \
             st.global{ty} [%rd_b], %f_y;"
        ))
    }

    /// `y = 1 / (1 + exp(-x))`, with `exp` built from `ex2.approx`.
    fn generate_sigmoid(&self) -> Result<String, PtxGenError> {
        let ty = self.ty_str();
        // 0f3FB8AA3B = log2(e): exp(t) == ex2(t * log2(e)). 0f3F800000 = 1.0.
        self.emit_pointwise_unary(format!(
            "ld.global{ty} %f_x, [%rd_a];\n \
             neg{ty} %f_neg, %f_x;\n \
             mul{ty} %f_neg, %f_neg, 0f3FB8AA3B;\n \
             ex2.approx{ty} %f_exp, %f_neg;\n \
             add{ty} %f_denom, %f_exp, 0f3F800000;\n \
             rcp.approx{ty} %f_y, %f_denom;\n \
             st.global{ty} [%rd_b], %f_y;"
        ))
    }

    /// Tanh-approximation GELU:
    /// `y = 0.5 * x * (1 + tanh(k * (x + c * x^3)))`.
    fn generate_gelu(&self) -> Result<String, PtxGenError> {
        let ty = self.ty_str();
        // Constants: 0f3D372713 = c ~ 0.044715, 0f3F4C422A = k ~ 0.797885 (sqrt(2/pi)),
        // 0f40000000 = 2.0, 0f3FB8AA3B = log2(e), 0f3F800000 = 1.0, 0f3F000000 = 0.5.
        // tanh(a) is computed as 2 * sigmoid(2a) - 1; the following `add` of 1.0
        // turns it into the `1 + tanh(a)` factor (kept in the same register).
        self.emit_pointwise_unary(format!(
            "ld.global{ty} %f_x, [%rd_a];\n \
             mul{ty} %f_x3, %f_x, %f_x;\n \
             mul{ty} %f_x3, %f_x3, %f_x;\n \
             mul{ty} %f_x3, %f_x3, 0f3D372713;\n \
             add{ty} %f_inner, %f_x, %f_x3;\n \
             mul{ty} %f_inner, %f_inner, 0f3F4C422A;\n \
             mul{ty} %f_2a, %f_inner, 0f40000000;\n \
             neg{ty} %f_neg2a, %f_2a;\n \
             mul{ty} %f_neg2a, %f_neg2a, 0f3FB8AA3B;\n \
             ex2.approx{ty} %f_exp, %f_neg2a;\n \
             add{ty} %f_denom, %f_exp, 0f3F800000;\n \
             rcp.approx{ty} %f_sig, %f_denom;\n \
             mul{ty} %f_sig, %f_sig, 0f40000000;\n \
             sub{ty} %f_tanh, %f_sig, 0f3F800000;\n \
             add{ty} %f_tanh, %f_tanh, 0f3F800000;\n \
             mul{ty} %f_y, 0f3F000000, %f_x;\n \
             mul{ty} %f_y, %f_y, %f_tanh;\n \
             st.global{ty} [%rd_b], %f_y;"
        ))
    }

    /// `y = x * sigmoid(x)`.
    fn generate_silu(&self) -> Result<String, PtxGenError> {
        let ty = self.ty_str();
        // 0f3FB8AA3B = log2(e), 0f3F800000 = 1.0.
        self.emit_pointwise_unary(format!(
            "ld.global{ty} %f_x, [%rd_a];\n \
             neg{ty} %f_neg, %f_x;\n \
             mul{ty} %f_neg, %f_neg, 0f3FB8AA3B;\n \
             ex2.approx{ty} %f_exp, %f_neg;\n \
             add{ty} %f_denom, %f_exp, 0f3F800000;\n \
             rcp.approx{ty} %f_sig, %f_denom;\n \
             mul{ty} %f_y, %f_x, %f_sig;\n \
             st.global{ty} [%rd_b], %f_y;"
        ))
    }

    /// `y = tanh(x)`, computed as `2 * sigmoid(2x) - 1`.
    fn generate_tanh(&self) -> Result<String, PtxGenError> {
        let ty = self.ty_str();
        // 0f40000000 = 2.0, 0f3FB8AA3B = log2(e), 0f3F800000 = 1.0.
        self.emit_pointwise_unary(format!(
            "ld.global{ty} %f_x, [%rd_a];\n \
             mul{ty} %f_2x, %f_x, 0f40000000;\n \
             neg{ty} %f_neg, %f_2x;\n \
             mul{ty} %f_neg, %f_neg, 0f3FB8AA3B;\n \
             ex2.approx{ty} %f_exp, %f_neg;\n \
             add{ty} %f_denom, %f_exp, 0f3F800000;\n \
             rcp.approx{ty} %f_sig, %f_denom;\n \
             mul{ty} %f_y, %f_sig, 0f40000000;\n \
             sub{ty} %f_y, %f_y, 0f3F800000;\n \
             st.global{ty} [%rd_b], %f_y;"
        ))
    }

    /// `y = <op_name>(x)` using the single instruction `<op_name><ty>`.
    fn generate_unary(&self, op_name: &str) -> Result<String, PtxGenError> {
        let ty = self.ty_str();
        self.emit_pointwise_unary(format!(
            "ld.global{ty} %f_x, [%rd_a];\n \
             {op_name}{ty} %f_y, %f_x;\n \
             st.global{ty} [%rd_b], %f_y;"
        ))
    }

    /// `y = sqrt(x)` via `sqrt.rn<ty>`.
    fn generate_sqrt(&self) -> Result<String, PtxGenError> {
        self.generate_unary("sqrt.rn")
    }

    /// `y = 1 / sqrt(x)` via `rsqrt.approx<ty>`.
    fn generate_rsqrt(&self) -> Result<String, PtxGenError> {
        self.generate_unary("rsqrt.approx")
    }

    /// `y = exp(x)`, computed as `ex2(x * log2(e))`.
    fn generate_exp(&self) -> Result<String, PtxGenError> {
        let ty = self.ty_str();
        // 0f3FB8AA3B = log2(e).
        self.emit_pointwise_unary(format!(
            "ld.global{ty} %f_x, [%rd_a];\n \
             mul{ty} %f_x2, %f_x, 0f3FB8AA3B;\n \
             ex2.approx{ty} %f_y, %f_x2;\n \
             st.global{ty} [%rd_b], %f_y;"
        ))
    }

    /// `y = ln(x)`, computed as `lg2(x) * ln(2)`.
    fn generate_log(&self) -> Result<String, PtxGenError> {
        let ty = self.ty_str();
        // 0f3F317218 = ln(2).
        self.emit_pointwise_unary(format!(
            "ld.global{ty} %f_x, [%rd_a];\n \
             lg2.approx{ty} %f_lg, %f_x;\n \
             mul{ty} %f_y, %f_lg, 0f3F317218;\n \
             st.global{ty} [%rd_b], %f_y;"
        ))
    }

    /// `y = ceil(x)` via `cvt.rpi<ty><ty>`.
    fn generate_ceil(&self) -> Result<String, PtxGenError> {
        // `generate_unary` appends the second `<ty>` suffix to the mnemonic.
        self.generate_unary(&format!("cvt.rpi{}", self.ty_str()))
    }

    /// `y = floor(x)` via `cvt.rmi<ty><ty>`.
    fn generate_floor(&self) -> Result<String, PtxGenError> {
        // `generate_unary` appends the second `<ty>` suffix to the mnemonic.
        self.generate_unary(&format!("cvt.rmi{}", self.ty_str()))
    }

    /// `y = max(min(0.2 * x + 0.5, 1), 0)`.
    fn generate_hard_sigmoid(&self) -> Result<String, PtxGenError> {
        let ty = self.ty_str();
        let zero_lit = float_zero_literal(self.precision);
        // 0f3E4CCCCD = 0.2, 0f3F000000 = 0.5, 0f3F800000 = 1.0.
        self.emit_pointwise_unary(format!(
            "ld.global{ty} %f_x, [%rd_a];\n \
             mul{ty} %f_ax, %f_x, 0f3E4CCCCD;\n \
             add{ty} %f_lin, %f_ax, 0f3F000000;\n \
             min{ty} %f_clip, %f_lin, 0f3F800000;\n \
             max{ty} %f_y, %f_clip, {zero_lit};\n \
             st.global{ty} [%rd_b], %f_y;"
        ))
    }

    /// `y = x * max(min(x + 3, 6), 0) * (1/6)`.
    fn generate_hard_swish(&self) -> Result<String, PtxGenError> {
        let ty = self.ty_str();
        let zero_lit = float_zero_literal(self.precision);
        // 0f40400000 = 3.0, 0f40C00000 = 6.0, 0f3E2AAAAB ~ 1/6.
        self.emit_pointwise_unary(format!(
            "ld.global{ty} %f_x, [%rd_a];\n \
             add{ty} %f_xp3, %f_x, 0f40400000;\n \
             min{ty} %f_clip, %f_xp3, 0f40C00000;\n \
             max{ty} %f_clip, %f_clip, {zero_lit};\n \
             mul{ty} %f_div, %f_clip, 0f3E2AAAAB;\n \
             mul{ty} %f_y, %f_x, %f_div;\n \
             st.global{ty} [%rd_b], %f_y;"
        ))
    }

    /// `y = ln(1 + exp(x))`.
    fn generate_softplus(&self) -> Result<String, PtxGenError> {
        let ty = self.ty_str();
        // 0f3FB8AA3B = log2(e), 0f3F800000 = 1.0, 0f3F317218 = ln(2).
        self.emit_pointwise_unary(format!(
            "ld.global{ty} %f_x, [%rd_a];\n \
             mul{ty} %f_xe, %f_x, 0f3FB8AA3B;\n \
             ex2.approx{ty} %f_exp, %f_xe;\n \
             add{ty} %f_sum, %f_exp, 0f3F800000;\n \
             lg2.approx{ty} %f_lg, %f_sum;\n \
             mul{ty} %f_y, %f_lg, 0f3F317218;\n \
             st.global{ty} [%rd_b], %f_y;"
        ))
    }

    /// `y = x` when `x >= 0`, otherwise `0.01 * x` (the slope is fixed).
    fn generate_leaky_relu(&self) -> Result<String, PtxGenError> {
        let ty = self.ty_str();
        let zero_lit = float_zero_literal(self.precision);
        // 0f3C23D70A = 0.01.
        self.emit_pointwise_unary(format!(
            "ld.global{ty} %f_x, [%rd_a];\n \
             mul{ty} %f_leak, %f_x, 0f3C23D70A;\n \
             setp.ge{ty} %p_ge, %f_x, {zero_lit};\n \
             selp{ty} %f_y, %f_x, %f_leak, %p_ge;\n \
             st.global{ty} [%rd_b], %f_y;"
        ))
    }

    /// `y = alpha * x`, with `alpha` passed as a scalar kernel parameter.
    fn generate_scale(&self) -> Result<String, PtxGenError> {
        let ty = self.ty_str();
        self.emit_pointwise_unary_with_scalar(
            "alpha",
            format!(
                "ld.param{ty} %f_alpha, [%param_alpha];\n \
                 ld.global{ty} %f_x, [%rd_a];\n \
                 mul{ty} %f_y, %f_alpha, %f_x;\n \
                 st.global{ty} [%rd_b], %f_y;"
            ),
        )
    }

    /// `y = x + scalar`, with `scalar` passed as a scalar kernel parameter.
    fn generate_add_scalar(&self) -> Result<String, PtxGenError> {
        let ty = self.ty_str();
        self.emit_pointwise_unary_with_scalar(
            "scalar",
            format!(
                "ld.param{ty} %f_s, [%param_scalar];\n \
                 ld.global{ty} %f_x, [%rd_a];\n \
                 add{ty} %f_y, %f_x, %f_s;\n \
                 st.global{ty} [%rd_b], %f_y;"
            ),
        )
    }

    /// `c = max(a + b, 0)`.
    fn generate_fused_add_relu(&self) -> Result<String, PtxGenError> {
        let ty = self.ty_str();
        let zero_lit = float_zero_literal(self.precision);
        self.emit_pointwise_binary(format!(
            "ld.global{ty} %f_a, [%rd_a];\n \
             ld.global{ty} %f_b, [%rd_b];\n \
             add{ty} %f_sum, %f_a, %f_b;\n \
             max{ty} %f_y, %f_sum, {zero_lit};\n \
             st.global{ty} [%rd_c], %f_y;"
        ))
    }

    /// `y = 1 - x`.
    fn generate_one_minus(&self) -> Result<String, PtxGenError> {
        let ty = self.ty_str();
        let one_lit = float_one_literal(self.precision);
        self.emit_pointwise_unary(format!(
            "ld.global{ty} %f_x, [%rd_a];\n \
             sub{ty} %f_y, {one_lit}, %f_x;\n \
             st.global{ty} [%rd_b], %f_y;"
        ))
    }

    /// `c = a ^ b`, computed as `ex2(lg2(a) * b)`.
    fn generate_pow(&self) -> Result<String, PtxGenError> {
        let ty = self.ty_str();
        self.emit_pointwise_binary(format!(
            "ld.global{ty} %f_a, [%rd_a];\n \
             ld.global{ty} %f_b, [%rd_b];\n \
             lg2.approx{ty} %f_t1, %f_a;\n \
             mul{ty} %f_t2, %f_t1, %f_b;\n \
             ex2.approx{ty} %f_c, %f_t2;\n \
             st.global{ty} [%rd_c], %f_c;"
        ))
    }

    /// `c = min(a, b)` or `c = max(a, b)`, selected by `min_or_max`.
    fn generate_binary_minmax(&self, min_or_max: &str) -> Result<String, PtxGenError> {
        // `min<ty>` / `max<ty>` fit the plain three-operand arithmetic template.
        self.generate_binary_arith(min_or_max)
    }

    /// `c = 1` when `a <cond> b` holds, else `0` (`setp.<cond>` then `selp`).
    fn generate_binary_cmp(&self, cond: &str) -> Result<String, PtxGenError> {
        let ty = self.ty_str();
        let one_lit = float_one_literal(self.precision);
        let zero_lit = float_zero_literal(self.precision);
        self.emit_pointwise_binary(format!(
            "ld.global{ty} %f_a, [%rd_a];\n \
             ld.global{ty} %f_b, [%rd_b];\n \
             setp.{cond}{ty} %p_cmp, %f_a, %f_b;\n \
             selp{ty} %f_c, {one_lit}, {zero_lit}, %p_cmp;\n \
             st.global{ty} [%rd_c], %f_c;"
        ))
    }

    /// `c = a + b - a * b`, evaluated as `(a - a * b) + b`.
    fn generate_or_prob_sum(&self) -> Result<String, PtxGenError> {
        let ty = self.ty_str();
        self.emit_pointwise_binary(format!(
            "ld.global{ty} %f_a, [%rd_a];\n \
             ld.global{ty} %f_b, [%rd_b];\n \
             mul{ty} %f_t, %f_a, %f_b;\n \
             sub{ty} %f_s, %f_a, %f_t;\n \
             add{ty} %f_c, %f_s, %f_b;\n \
             st.global{ty} [%rd_c], %f_c;"
        ))
    }

    /// `c = 1 - a * b`.
    fn generate_nand(&self) -> Result<String, PtxGenError> {
        let ty = self.ty_str();
        let one_lit = float_one_literal(self.precision);
        self.emit_pointwise_binary(format!(
            "ld.global{ty} %f_a, [%rd_a];\n \
             ld.global{ty} %f_b, [%rd_b];\n \
             mul{ty} %f_t, %f_a, %f_b;\n \
             sub{ty} %f_c, {one_lit}, %f_t;\n \
             st.global{ty} [%rd_c], %f_c;"
        ))
    }

    /// `c = 1 - (a + b - a * b)`.
    fn generate_nor(&self) -> Result<String, PtxGenError> {
        let ty = self.ty_str();
        let one_lit = float_one_literal(self.precision);
        self.emit_pointwise_binary(format!(
            "ld.global{ty} %f_a, [%rd_a];\n \
             ld.global{ty} %f_b, [%rd_b];\n \
             mul{ty} %f_t, %f_a, %f_b;\n \
             sub{ty} %f_s, %f_a, %f_t;\n \
             add{ty} %f_u, %f_s, %f_b;\n \
             sub{ty} %f_c, {one_lit}, %f_u;\n \
             st.global{ty} [%rd_c], %f_c;"
        ))
    }
}