use std::fmt::Write as FmtWrite;
use oxicuda_ptx::ir::PtxType;
use crate::error::{BlasError, BlasResult};
/// Epilogue operation applied to GEMM accumulator values.
///
/// Every variant starts from the linear combination of the accumulator and the
/// previous output, scaled by alpha/beta. The name suffix identifies the extra
/// step(s) applied afterwards (see `generate_epilogue_ptx` for the emitted PTX).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum EpilogueOp {
    /// Plain linear combination, no extra step.
    LinearCombination,
    /// Linear combination followed by ReLU (`max(x, 0)`).
    LinearCombinationRelu,
    /// Linear combination followed by GELU. The generated PTX uses the sigmoid
    /// approximation `x * sigmoid(1.702 * x)`.
    LinearCombinationGelu,
    /// Linear combination plus a bias term (`%f_bias` in the generated PTX).
    LinearCombinationBias,
    /// Linear combination plus a bias term, followed by ReLU.
    LinearCombinationBiasRelu,
}
impl EpilogueOp {
pub fn as_str(self) -> &'static str {
match self {
Self::LinearCombination => "lincomb",
Self::LinearCombinationRelu => "lincomb_relu",
Self::LinearCombinationGelu => "lincomb_gelu",
Self::LinearCombinationBias => "lincomb_bias",
Self::LinearCombinationBiasRelu => "lincomb_bias_relu",
}
}
pub fn needs_bias(self) -> bool {
matches!(
self,
Self::LinearCombinationBias | Self::LinearCombinationBiasRelu
)
}
pub fn has_relu(self) -> bool {
matches!(
self,
Self::LinearCombinationRelu | Self::LinearCombinationBiasRelu
)
}
pub fn has_gelu(self) -> bool {
matches!(self, Self::LinearCombinationGelu)
}
pub fn to_ptx_kind(self) -> oxicuda_ptx::templates::gemm::EpilogueKind {
match self {
Self::LinearCombination => {
oxicuda_ptx::templates::gemm::EpilogueKind::LinearCombination
}
Self::LinearCombinationRelu => {
oxicuda_ptx::templates::gemm::EpilogueKind::LinearCombinationRelu
}
Self::LinearCombinationGelu => {
oxicuda_ptx::templates::gemm::EpilogueKind::LinearCombinationGelu
}
Self::LinearCombinationBias => {
oxicuda_ptx::templates::gemm::EpilogueKind::LinearCombinationBias
}
Self::LinearCombinationBiasRelu => {
oxicuda_ptx::templates::gemm::EpilogueKind::LinearCombinationBiasRelu
}
}
}
}
pub fn generate_epilogue_ptx(acc_type: PtxType, op: EpilogueOp) -> BlasResult<String> {
let ty = acc_type.as_ptx_str();
let mut ptx = String::with_capacity(512);
write_line(
&mut ptx,
&format!(" mul{ty} %f_result, %f_acc, %f_alpha;"),
)?;
write_line(
&mut ptx,
&format!(" fma.rn{ty} %f_result, %f_beta, %f_cold, %f_result;"),
)?;
if op.needs_bias() {
write_line(
&mut ptx,
&format!(" add{ty} %f_result, %f_result, %f_bias;"),
)?;
}
if op.has_relu() {
write_line(
&mut ptx,
&format!(" max{ty} %f_result, %f_result, 0f00000000;"),
)?;
} else if op.has_gelu() {
write_line(
&mut ptx,
&format!(" mul{ty} %f_gelu_s, %f_result, 0f3FDA6286; // 1.702"),
)?;
write_line(&mut ptx, &format!(" neg{ty} %f_gelu_s, %f_gelu_s;"))?;
write_line(
&mut ptx,
&format!(" mul{ty} %f_gelu_s, %f_gelu_s, 0f3FB8AA3B; // log2(e)"),
)?;
write_line(
&mut ptx,
&format!(" ex2.approx{ty} %f_gelu_s, %f_gelu_s;"),
)?;
write_line(
&mut ptx,
&format!(" add{ty} %f_gelu_s, %f_gelu_s, 0f3F800000; // +1.0"),
)?;
write_line(
&mut ptx,
&format!(" rcp.approx{ty} %f_gelu_s, %f_gelu_s;"),
)?;
write_line(
&mut ptx,
&format!(" mul{ty} %f_result, %f_result, %f_gelu_s;"),
)?;
}
Ok(ptx)
}
fn write_line(ptx: &mut String, line: &str) -> BlasResult<()> {
writeln!(ptx, "{line}").map_err(|e| BlasError::PtxGeneration(format!("fmt error: {e}")))
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    /// Every `EpilogueOp` variant, for table-driven checks.
    const ALL_OPS: [EpilogueOp; 5] = [
        EpilogueOp::LinearCombination,
        EpilogueOp::LinearCombinationRelu,
        EpilogueOp::LinearCombinationGelu,
        EpilogueOp::LinearCombinationBias,
        EpilogueOp::LinearCombinationBiasRelu,
    ];

    /// Generates the `.f32` epilogue for `op`, panicking on failure.
    fn gen_f32(op: EpilogueOp) -> String {
        generate_epilogue_ptx(PtxType::F32, op).expect("f32 epilogue generation failed")
    }

    #[test]
    fn epilogue_op_labels() {
        assert_eq!(EpilogueOp::LinearCombination.as_str(), "lincomb");
        assert_eq!(EpilogueOp::LinearCombinationRelu.as_str(), "lincomb_relu");
        assert_eq!(EpilogueOp::LinearCombinationGelu.as_str(), "lincomb_gelu");
        assert_eq!(EpilogueOp::LinearCombinationBias.as_str(), "lincomb_bias");
        assert_eq!(
            EpilogueOp::LinearCombinationBiasRelu.as_str(),
            "lincomb_bias_relu"
        );
    }

    #[test]
    fn epilogue_op_labels_are_unique() {
        let labels: HashSet<&str> = ALL_OPS.iter().map(|op| op.as_str()).collect();
        assert_eq!(labels.len(), ALL_OPS.len());
    }

    #[test]
    fn epilogue_needs_bias() {
        assert!(!EpilogueOp::LinearCombination.needs_bias());
        assert!(!EpilogueOp::LinearCombinationRelu.needs_bias());
        assert!(!EpilogueOp::LinearCombinationGelu.needs_bias());
        assert!(EpilogueOp::LinearCombinationBias.needs_bias());
        assert!(EpilogueOp::LinearCombinationBiasRelu.needs_bias());
    }

    #[test]
    fn epilogue_has_relu() {
        assert!(!EpilogueOp::LinearCombination.has_relu());
        assert!(EpilogueOp::LinearCombinationRelu.has_relu());
        assert!(!EpilogueOp::LinearCombinationGelu.has_relu());
        assert!(!EpilogueOp::LinearCombinationBias.has_relu());
        assert!(EpilogueOp::LinearCombinationBiasRelu.has_relu());
    }

    #[test]
    fn epilogue_has_gelu() {
        for op in ALL_OPS {
            assert_eq!(
                op.has_gelu(),
                op == EpilogueOp::LinearCombinationGelu,
                "{op:?}"
            );
        }
    }

    #[test]
    fn activations_are_mutually_exclusive() {
        for op in ALL_OPS {
            assert!(!(op.has_relu() && op.has_gelu()), "{op:?}");
        }
    }

    #[test]
    fn generate_linear_combination() {
        let ptx = gen_f32(EpilogueOp::LinearCombination);
        assert!(ptx.contains("mul.f32"));
        assert!(ptx.contains("fma.rn.f32"));
        assert!(!ptx.contains("max.f32"));
        assert!(!ptx.contains("%f_bias"));
        assert!(!ptx.contains("%f_gelu_s"));
    }

    #[test]
    fn generate_relu_epilogue() {
        let ptx = gen_f32(EpilogueOp::LinearCombinationRelu);
        assert!(ptx.contains("max.f32"));
        assert!(!ptx.contains("%f_bias"));
    }

    #[test]
    fn generate_bias_epilogue() {
        let ptx = gen_f32(EpilogueOp::LinearCombinationBias);
        assert!(ptx.contains("add.f32 %f_result, %f_result, %f_bias"));
        assert!(!ptx.contains("max.f32"));
    }

    #[test]
    fn generate_bias_relu_epilogue() {
        let ptx = gen_f32(EpilogueOp::LinearCombinationBiasRelu);
        let add = ptx
            .find("add.f32 %f_result, %f_result, %f_bias")
            .expect("bias add missing");
        let max = ptx.find("max.f32").expect("relu max missing");
        assert!(add < max, "bias must be applied before the activation");
    }

    #[test]
    fn generate_gelu_epilogue() {
        let ptx = gen_f32(EpilogueOp::LinearCombinationGelu);
        assert!(ptx.contains("ex2.approx.f32"));
        assert!(ptx.contains("rcp.approx.f32"));
        assert!(ptx.contains("neg.f32"));
        // Exact single-precision bit pattern of log2(e).
        assert!(ptx.contains("0f3FB8AA3B"));
        assert!(!ptx.contains("max.f32"));
    }

    #[test]
    fn generated_line_counts() {
        let expected = [
            (EpilogueOp::LinearCombination, 2),
            (EpilogueOp::LinearCombinationRelu, 3),
            (EpilogueOp::LinearCombinationGelu, 9),
            (EpilogueOp::LinearCombinationBias, 3),
            (EpilogueOp::LinearCombinationBiasRelu, 4),
        ];
        for (op, lines) in expected {
            let ptx = gen_f32(op);
            assert_eq!(ptx.lines().count(), lines, "{op:?}");
            assert!(ptx.ends_with('\n'), "{op:?}");
        }
    }

    #[test]
    fn to_ptx_kind_roundtrip() {
        let _ = EpilogueOp::LinearCombination.to_ptx_kind();
        let _ = EpilogueOp::LinearCombinationRelu.to_ptx_kind();
        let _ = EpilogueOp::LinearCombinationGelu.to_ptx_kind();
        let _ = EpilogueOp::LinearCombinationBias.to_ptx_kind();
        let _ = EpilogueOp::LinearCombinationBiasRelu.to_ptx_kind();
    }
}