use cudarc::cublaslt::sys::cublasLtEpilogue_t;
/// Rust-side selector for a cuBLASLt epilogue (`cublasLtEpilogue_t`).
///
/// Each variant is translated to a `cublasLtEpilogue_t` constant by
/// `Epilogue::to_cublas`; `uses_bias`, `uses_aux` and `produces_bias_grad`
/// classify the variants.
///
/// The enum is `#[repr(u32)]` with implicit discriminants, so the numeric
/// values follow declaration order starting at 0 and change if variants are
/// reordered. NOTE(review): nothing here ties those values to the numeric
/// values of `cublasLtEpilogue_t` — convert with `to_cublas`, not with a cast.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u32)]
pub enum Epilogue {
/// Translated to `CUBLASLT_EPILOGUE_DEFAULT`.
None,
/// Translated to `CUBLASLT_EPILOGUE_RELU`.
Relu,
/// Translated to `CUBLASLT_EPILOGUE_BIAS`; reported by `uses_bias`.
Bias,
/// Translated to `CUBLASLT_EPILOGUE_RELU_BIAS`; reported by `uses_bias`.
ReluBias,
/// Translated to `CUBLASLT_EPILOGUE_RELU_AUX`; reported by `uses_aux`.
ReluAux,
/// Translated to `CUBLASLT_EPILOGUE_RELU_AUX_BIAS`; reported by both
/// `uses_aux` and `uses_bias`.
ReluAuxBias,
/// Translated to `CUBLASLT_EPILOGUE_GELU`.
Gelu,
/// Translated to `CUBLASLT_EPILOGUE_GELU_AUX`; reported by `uses_aux`.
GeluAux,
/// Translated to `CUBLASLT_EPILOGUE_GELU_BIAS`; reported by `uses_bias`.
GeluBias,
/// Translated to `CUBLASLT_EPILOGUE_GELU_AUX_BIAS`; reported by both
/// `uses_aux` and `uses_bias`.
GeluAuxBias,
/// ReLU-gradient variant (cuBLASLt `DRELU` family) that is reported by
/// `uses_aux` but not by `produces_bias_grad`.
DRelu,
/// ReLU-gradient variant translated to `CUBLASLT_EPILOGUE_DRELU_BGRAD`;
/// reported by `uses_aux` and `produces_bias_grad`.
DReluBgrad,
/// GELU-gradient variant (cuBLASLt `DGELU` family) that is reported by
/// `uses_aux` but not by `produces_bias_grad`.
DGelu,
/// GELU-gradient variant translated to `CUBLASLT_EPILOGUE_DGELU_BGRAD`;
/// reported by `uses_aux` and `produces_bias_grad`.
DGeluBgrad,
/// Translated to `CUBLASLT_EPILOGUE_BGRADA`; reported by `produces_bias_grad`.
BgradA,
/// Translated to `CUBLASLT_EPILOGUE_BGRADB`; reported by `produces_bias_grad`.
BgradB,
}
impl Epilogue {
    /// Returns the `cublasLtEpilogue_t` constant that corresponds to this variant.
    ///
    /// The mapping is one-to-one. In particular the plain gradient variants
    /// (`DRelu`, `DGelu`) select the plain `DRELU` / `DGELU` epilogues, and only
    /// the `*Bgrad` variants select the `_BGRAD` epilogues. This keeps the
    /// requested epilogue in agreement with [`Epilogue::produces_bias_grad`],
    /// which reports `false` for `DRelu` and `DGelu`.
    pub fn to_cublas(self) -> cublasLtEpilogue_t {
        match self {
            Self::None => cublasLtEpilogue_t::CUBLASLT_EPILOGUE_DEFAULT,
            Self::Relu => cublasLtEpilogue_t::CUBLASLT_EPILOGUE_RELU,
            Self::Bias => cublasLtEpilogue_t::CUBLASLT_EPILOGUE_BIAS,
            Self::ReluBias => cublasLtEpilogue_t::CUBLASLT_EPILOGUE_RELU_BIAS,
            Self::ReluAux => cublasLtEpilogue_t::CUBLASLT_EPILOGUE_RELU_AUX,
            Self::ReluAuxBias => cublasLtEpilogue_t::CUBLASLT_EPILOGUE_RELU_AUX_BIAS,
            Self::Gelu => cublasLtEpilogue_t::CUBLASLT_EPILOGUE_GELU,
            Self::GeluAux => cublasLtEpilogue_t::CUBLASLT_EPILOGUE_GELU_AUX,
            Self::GeluBias => cublasLtEpilogue_t::CUBLASLT_EPILOGUE_GELU_BIAS,
            Self::GeluAuxBias => cublasLtEpilogue_t::CUBLASLT_EPILOGUE_GELU_AUX_BIAS,
            // Gradient without bias gradient: must not request the `_BGRAD`
            // epilogue, which additionally produces a bias-gradient output.
            Self::DRelu => cublasLtEpilogue_t::CUBLASLT_EPILOGUE_DRELU,
            Self::DReluBgrad => cublasLtEpilogue_t::CUBLASLT_EPILOGUE_DRELU_BGRAD,
            Self::DGelu => cublasLtEpilogue_t::CUBLASLT_EPILOGUE_DGELU,
            Self::DGeluBgrad => cublasLtEpilogue_t::CUBLASLT_EPILOGUE_DGELU_BGRAD,
            Self::BgradA => cublasLtEpilogue_t::CUBLASLT_EPILOGUE_BGRADA,
            Self::BgradB => cublasLtEpilogue_t::CUBLASLT_EPILOGUE_BGRADB,
        }
    }

    /// Returns `true` for the forward epilogues that include a bias term:
    /// `Bias`, `ReluBias`, `ReluAuxBias`, `GeluBias` and `GeluAuxBias`.
    ///
    /// Variants that *produce* a bias gradient are not included here; they are
    /// reported by [`Epilogue::produces_bias_grad`].
    pub fn uses_bias(self) -> bool {
        matches!(
            self,
            Self::Bias | Self::ReluBias | Self::ReluAuxBias | Self::GeluBias | Self::GeluAuxBias
        )
    }

    /// Returns `true` for the epilogues that involve the auxiliary buffer: the
    /// forward `*Aux` variants and every ReLU/GELU gradient variant.
    ///
    /// NOTE(review): this predicate does not distinguish whether the auxiliary
    /// buffer is written or read by the epilogue — confirm against the caller.
    pub fn uses_aux(self) -> bool {
        matches!(
            self,
            Self::ReluAux
                | Self::ReluAuxBias
                | Self::GeluAux
                | Self::GeluAuxBias
                | Self::DRelu
                | Self::DReluBgrad
                | Self::DGelu
                | Self::DGeluBgrad
        )
    }

    /// Returns `true` for the epilogues that emit a bias gradient:
    /// `BgradA`, `BgradB`, `DReluBgrad` and `DGeluBgrad`.
    pub fn produces_bias_grad(self) -> bool {
        matches!(
            self,
            Self::BgradA | Self::BgradB | Self::DReluBgrad | Self::DGeluBgrad
        )
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Walks the variant-to-constant table of `to_cublas` one pair at a time.
    #[test]
    fn epilogue_round_trip() {
        // Shared check; on failure the message is the Debug form of the variant.
        let expect_mapping = |lhs: Epilogue, rhs: cublasLtEpilogue_t| {
            assert_eq!(lhs.to_cublas(), rhs, "{lhs:?}");
        };

        // Forward epilogues.
        expect_mapping(
            Epilogue::None,
            cublasLtEpilogue_t::CUBLASLT_EPILOGUE_DEFAULT,
        );
        expect_mapping(Epilogue::Relu, cublasLtEpilogue_t::CUBLASLT_EPILOGUE_RELU);
        expect_mapping(Epilogue::Bias, cublasLtEpilogue_t::CUBLASLT_EPILOGUE_BIAS);
        expect_mapping(
            Epilogue::ReluBias,
            cublasLtEpilogue_t::CUBLASLT_EPILOGUE_RELU_BIAS,
        );
        expect_mapping(
            Epilogue::ReluAux,
            cublasLtEpilogue_t::CUBLASLT_EPILOGUE_RELU_AUX,
        );
        expect_mapping(
            Epilogue::ReluAuxBias,
            cublasLtEpilogue_t::CUBLASLT_EPILOGUE_RELU_AUX_BIAS,
        );
        expect_mapping(Epilogue::Gelu, cublasLtEpilogue_t::CUBLASLT_EPILOGUE_GELU);
        expect_mapping(
            Epilogue::GeluAux,
            cublasLtEpilogue_t::CUBLASLT_EPILOGUE_GELU_AUX,
        );
        expect_mapping(
            Epilogue::GeluBias,
            cublasLtEpilogue_t::CUBLASLT_EPILOGUE_GELU_BIAS,
        );
        expect_mapping(
            Epilogue::GeluAuxBias,
            cublasLtEpilogue_t::CUBLASLT_EPILOGUE_GELU_AUX_BIAS,
        );

        // Gradient epilogues that also emit a bias gradient.
        expect_mapping(
            Epilogue::DReluBgrad,
            cublasLtEpilogue_t::CUBLASLT_EPILOGUE_DRELU_BGRAD,
        );
        expect_mapping(
            Epilogue::DGeluBgrad,
            cublasLtEpilogue_t::CUBLASLT_EPILOGUE_DGELU_BGRAD,
        );
        expect_mapping(
            Epilogue::BgradA,
            cublasLtEpilogue_t::CUBLASLT_EPILOGUE_BGRADA,
        );
        expect_mapping(
            Epilogue::BgradB,
            cublasLtEpilogue_t::CUBLASLT_EPILOGUE_BGRADB,
        );
    }

    /// Spot-checks the three classification predicates.
    #[test]
    fn epilogue_capability_flags() {
        // Bias consumers.
        assert!(Epilogue::Bias.uses_bias());
        assert!(!Epilogue::None.uses_bias());
        assert!(Epilogue::ReluBias.uses_bias());

        // Auxiliary-buffer users.
        assert!(Epilogue::GeluAux.uses_aux());

        // Bias-gradient producers.
        assert!(Epilogue::DReluBgrad.produces_bias_grad());
        assert!(Epilogue::BgradA.produces_bias_grad());
        assert!(!Epilogue::Relu.produces_bias_grad());
    }

    /// `Epilogue::None` must translate to the cuBLASLt default epilogue value.
    #[test]
    fn epilogue_default_is_none_variant() {
        let translated = Epilogue::None.to_cublas() as u32;
        let default_value = cublasLtEpilogue_t::CUBLASLT_EPILOGUE_DEFAULT as u32;
        assert_eq!(translated, default_value);
    }
}