use std::fmt;
/// The operation performed by a node in the compute graph.
///
/// Variants are classified by the predicates on `impl OpKind`
/// (`is_elementwise`, `is_activation`, `is_reduction`) and rendered in
/// snake_case by the `Display` impl.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum OpKind {
// Structured linear-algebra ops.
MatMul,
BiasAdd,
// Activation functions (see `is_activation`; all are also elementwise).
Relu,
Gelu,
Sigmoid,
Tanh,
Swish,
// Convolution and normalization.
Conv2d,
BatchNorm,
// Binary elementwise arithmetic.
Add,
Sub,
Mul,
Div,
// Unary elementwise math.
Neg,
Square,
Exp,
Log,
Sqrt,
// Reductions (see `is_reduction`).
Sum,
Mean,
Max,
Min,
// Graph source nodes (presumably carry no compute — no fusion rule targets them).
Input,
Constant,
/// User-defined op identified by name; displayed as `custom(<name>)`.
Custom(String),
}
impl fmt::Display for OpKind {
    /// Writes the op's snake_case mnemonic (e.g. `matmul`, `bias_add`).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Every variant except `Custom` maps to a static label; resolve the
        // label first and emit it with a single write.
        let label = match self {
            OpKind::MatMul => "matmul",
            OpKind::BiasAdd => "bias_add",
            OpKind::Relu => "relu",
            OpKind::Gelu => "gelu",
            OpKind::Sigmoid => "sigmoid",
            OpKind::Tanh => "tanh",
            OpKind::Swish => "swish",
            OpKind::Conv2d => "conv2d",
            OpKind::BatchNorm => "batch_norm",
            OpKind::Add => "add",
            OpKind::Sub => "sub",
            OpKind::Mul => "mul",
            OpKind::Div => "div",
            OpKind::Neg => "neg",
            OpKind::Square => "square",
            OpKind::Exp => "exp",
            OpKind::Log => "log",
            OpKind::Sqrt => "sqrt",
            OpKind::Sum => "sum",
            OpKind::Mean => "mean",
            OpKind::Max => "max",
            OpKind::Min => "min",
            OpKind::Input => "input",
            OpKind::Constant => "constant",
            // The only dynamic case: include the user-supplied name.
            OpKind::Custom(name) => return write!(f, "custom({})", name),
        };
        f.write_str(label)
    }
}
impl OpKind {
    /// True for ops that act independently on each element of their input:
    /// binary/unary arithmetic plus every activation function.
    pub fn is_elementwise(&self) -> bool {
        // Activations are a subset of the elementwise ops, so delegate for
        // that half and enumerate only the arithmetic ops here.
        self.is_activation()
            || matches!(
                self,
                OpKind::Add
                    | OpKind::Sub
                    | OpKind::Mul
                    | OpKind::Div
                    | OpKind::Neg
                    | OpKind::Square
                    | OpKind::Exp
                    | OpKind::Log
                    | OpKind::Sqrt
            )
    }

    /// True for nonlinear activation functions.
    pub fn is_activation(&self) -> bool {
        match self {
            OpKind::Relu | OpKind::Gelu | OpKind::Sigmoid | OpKind::Tanh | OpKind::Swish => true,
            _ => false,
        }
    }

    /// True for ops that reduce over elements (sum/mean/max/min).
    pub fn is_reduction(&self) -> bool {
        matches!(self, OpKind::Sum | OpKind::Mean | OpKind::Max | OpKind::Min)
    }
}
/// A single operation node in the compute graph.
#[derive(Debug, Clone)]
pub struct GraphNode {
// Unique node id; other nodes reference it through their `inputs`.
pub id: usize,
// The operation this node performs.
pub op: OpKind,
// Ids of producer nodes feeding this node.
pub inputs: Vec<usize>,
// Output tensor shape. NOTE(review): `output_numel` treats an empty shape
// as "no shape recorded" (returns None), not as a scalar.
pub output_shape: Vec<usize>,
// Number of downstream consumers; the fusion predicates require exact
// counts (usually 1) before fusing across an edge.
pub num_consumers: usize,
}
impl GraphNode {
pub fn new(id: usize, op: OpKind, inputs: Vec<usize>, output_shape: Vec<usize>) -> Self {
Self {
id,
op,
inputs,
output_shape,
num_consumers: 0,
}
}
pub fn output_numel(&self) -> Option<usize> {
if self.output_shape.is_empty() {
None
} else {
Some(self.output_shape.iter().product())
}
}
}
/// A fusable subgraph shape recognized by the `can_fuse_*` predicates and
/// the `detect_two_node_pattern` / `detect_three_node_pattern` dispatchers.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum FusionPattern {
// Chain of two elementwise ops.
ElementWise,
// MatMul followed by BiasAdd/Add.
MatMulBias,
// MatMul followed by an activation.
MatMulActivation,
// MatMul -> bias add -> activation (three nodes).
MatMulBiasActivation,
// Conv2d followed by BatchNorm.
ConvBN,
// Conv2d -> BatchNorm -> activation (three nodes).
ConvBNActivation,
// NOTE(review): declared but not produced by any detector in this file.
ReductionElementWise,
// Mul followed by Add (scale-then-shift).
Affine,
// Sum followed by Div, foldable into a Mean.
SumDivToMean,
// Square followed by Mean (variance-style computation).
SquareMeanToVariance,
// Exp -> Sum -> Div with Div reading both (softmax shape).
Softmax,
/// User-defined pattern identified by name.
Custom(String),
}
impl fmt::Display for FusionPattern {
    /// Writes the pattern's snake_case name (e.g. `matmul_bias_activation`).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // All variants but `Custom` have a static label.
        let label = match self {
            FusionPattern::ElementWise => "elementwise_chain",
            FusionPattern::MatMulBias => "matmul_bias",
            FusionPattern::MatMulActivation => "matmul_activation",
            FusionPattern::MatMulBiasActivation => "matmul_bias_activation",
            FusionPattern::ConvBN => "conv_bn",
            FusionPattern::ConvBNActivation => "conv_bn_activation",
            FusionPattern::ReductionElementWise => "reduction_elementwise",
            FusionPattern::Affine => "affine",
            FusionPattern::SumDivToMean => "sum_div_to_mean",
            FusionPattern::SquareMeanToVariance => "square_mean_to_variance",
            FusionPattern::Softmax => "softmax",
            FusionPattern::Custom(name) => return write!(f, "custom({})", name),
        };
        f.write_str(label)
    }
}
/// True when `node_b` is a bias-style add (`BiasAdd` or plain `Add`) whose
/// only producer edge comes from the MatMul `node_a`, and `node_a` has no
/// other consumers (so the pair can be fused into one kernel).
pub fn can_fuse_matmul_bias(node_a: &GraphNode, node_b: &GraphNode) -> bool {
    let producer_is_matmul = node_a.op == OpKind::MatMul;
    let consumer_adds_bias = matches!(node_b.op, OpKind::BiasAdd | OpKind::Add);
    producer_is_matmul
        && consumer_adds_bias
        && node_b.inputs.contains(&node_a.id)
        && node_a.num_consumers == 1
}
/// True when the MatMul `node_a` feeds directly into the activation
/// `node_b` and nothing else consumes the MatMul's output.
pub fn can_fuse_matmul_activation(node_a: &GraphNode, node_b: &GraphNode) -> bool {
    node_a.op == OpKind::MatMul
        && node_a.num_consumers == 1
        && node_b.op.is_activation()
        && node_b.inputs.contains(&node_a.id)
}
/// True when the three nodes form MatMul -> bias add -> activation with a
/// single-consumer edge at each hop, so all three fuse into one kernel.
///
/// Bug fix: the previous version also required `node_b.op.is_elementwise()`,
/// which rejected `BiasAdd` chains — `BiasAdd` is not classified as
/// elementwise — even though `can_fuse_matmul_bias` explicitly accepts
/// `BiasAdd`. That check was redundant for `Add` and wrong for `BiasAdd`,
/// so it is dropped; `can_fuse_matmul_bias` already constrains `node_b`.
pub fn can_fuse_matmul_bias_activation(
    node_a: &GraphNode,
    node_b: &GraphNode,
    node_c: &GraphNode,
) -> bool {
    can_fuse_matmul_bias(node_a, node_b)
        && node_b.num_consumers == 1
        && node_c.op.is_activation()
        && node_c.inputs.contains(&node_b.id)
}
/// True when a Conv2d feeds a BatchNorm over a single-consumer edge.
pub fn can_fuse_conv_bn(node_a: &GraphNode, node_b: &GraphNode) -> bool {
    match (&node_a.op, &node_b.op) {
        // Only the Conv2d -> BatchNorm orientation qualifies.
        (OpKind::Conv2d, OpKind::BatchNorm) => {
            node_b.inputs.contains(&node_a.id) && node_a.num_consumers == 1
        }
        _ => false,
    }
}
/// True for the three-node chain Conv2d -> BatchNorm -> activation, with a
/// single-consumer edge at each hop.
pub fn can_fuse_conv_bn_activation(
    node_a: &GraphNode,
    node_b: &GraphNode,
    node_c: &GraphNode,
) -> bool {
    if !can_fuse_conv_bn(node_a, node_b) {
        return false;
    }
    // The BN output must flow only into the activation.
    node_b.num_consumers == 1 && node_c.op.is_activation() && node_c.inputs.contains(&node_b.id)
}
/// True when two elementwise ops are chained over a single-consumer edge,
/// so they can be fused into one pointwise kernel.
pub fn can_fuse_elementwise(node_a: &GraphNode, node_b: &GraphNode) -> bool {
    let both_elementwise = node_a.op.is_elementwise() && node_b.op.is_elementwise();
    both_elementwise && node_b.inputs.contains(&node_a.id) && node_a.num_consumers == 1
}
/// True for the scale-then-shift chain Mul -> Add over a single-consumer
/// edge (a fusable affine transform).
pub fn can_fuse_affine(node_a: &GraphNode, node_b: &GraphNode) -> bool {
    matches!((&node_a.op, &node_b.op), (OpKind::Mul, OpKind::Add))
        && node_b.inputs.contains(&node_a.id)
        && node_a.num_consumers == 1
}
/// True for the chain Sum -> Div over a single-consumer edge, which folds
/// into a single Mean.
pub fn can_fuse_sum_div_to_mean(node_a: &GraphNode, node_b: &GraphNode) -> bool {
    if node_a.op != OpKind::Sum || node_b.op != OpKind::Div {
        return false;
    }
    node_b.inputs.contains(&node_a.id) && node_a.num_consumers == 1
}
/// True for the chain Square -> Mean over a single-consumer edge
/// (a variance-style computation).
pub fn can_fuse_square_mean_to_variance(node_a: &GraphNode, node_b: &GraphNode) -> bool {
    let ops_match = node_a.op == OpKind::Square && node_b.op == OpKind::Mean;
    ops_match && node_b.inputs.contains(&node_a.id) && node_a.num_consumers == 1
}
/// True for the softmax-shaped subgraph `exp / sum(exp)`:
/// Exp (`node_a`) feeds both Sum (`node_b`) and Div (`node_c`), the Div
/// also reads the Sum, and consumer counts match that wiring exactly
/// (Exp has exactly the two internal consumers, Sum exactly one).
pub fn can_fuse_softmax(node_a: &GraphNode, node_b: &GraphNode, node_c: &GraphNode) -> bool {
    let ops_match =
        node_a.op == OpKind::Exp && node_b.op == OpKind::Sum && node_c.op == OpKind::Div;
    if !ops_match {
        return false;
    }
    let sum_reads_exp = node_b.inputs.contains(&node_a.id);
    let div_reads_both = node_c.inputs.contains(&node_a.id) && node_c.inputs.contains(&node_b.id);
    sum_reads_exp && div_reads_both && node_a.num_consumers == 2 && node_b.num_consumers == 1
}
pub fn detect_two_node_pattern(node_a: &GraphNode, node_b: &GraphNode) -> Option<FusionPattern> {
if can_fuse_matmul_bias(node_a, node_b) {
return Some(FusionPattern::MatMulBias);
}
if can_fuse_matmul_activation(node_a, node_b) {
return Some(FusionPattern::MatMulActivation);
}
if can_fuse_conv_bn(node_a, node_b) {
return Some(FusionPattern::ConvBN);
}
if can_fuse_sum_div_to_mean(node_a, node_b) {
return Some(FusionPattern::SumDivToMean);
}
if can_fuse_square_mean_to_variance(node_a, node_b) {
return Some(FusionPattern::SquareMeanToVariance);
}
if can_fuse_affine(node_a, node_b) {
return Some(FusionPattern::Affine);
}
if can_fuse_elementwise(node_a, node_b) {
return Some(FusionPattern::ElementWise);
}
None
}
/// Classifies a three-node chain against the three-node fusion rules,
/// returning the first pattern that matches.
pub fn detect_three_node_pattern(
    node_a: &GraphNode,
    node_b: &GraphNode,
    node_c: &GraphNode,
) -> Option<FusionPattern> {
    if can_fuse_matmul_bias_activation(node_a, node_b, node_c) {
        Some(FusionPattern::MatMulBiasActivation)
    } else if can_fuse_conv_bn_activation(node_a, node_b, node_c) {
        Some(FusionPattern::ConvBNActivation)
    } else if can_fuse_softmax(node_a, node_b, node_c) {
        Some(FusionPattern::Softmax)
    } else {
        None
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a node with a fixed `[2, 3]` output shape and the given
    /// consumer count — enough for the fusion predicates under test.
    fn make_node(id: usize, op: OpKind, inputs: Vec<usize>, consumers: usize) -> GraphNode {
        let mut node = GraphNode::new(id, op, inputs, vec![2, 3]);
        node.num_consumers = consumers;
        node
    }

    #[test]
    fn test_op_kind_is_elementwise() {
        assert!(OpKind::Add.is_elementwise());
        assert!(OpKind::Relu.is_elementwise());
        assert!(!OpKind::MatMul.is_elementwise());
        assert!(!OpKind::Conv2d.is_elementwise());
        assert!(!OpKind::Sum.is_elementwise());
    }

    #[test]
    fn test_op_kind_is_activation() {
        assert!(OpKind::Relu.is_activation());
        assert!(OpKind::Gelu.is_activation());
        assert!(OpKind::Sigmoid.is_activation());
        assert!(!OpKind::Add.is_activation());
        assert!(!OpKind::MatMul.is_activation());
    }

    #[test]
    fn test_op_kind_is_reduction() {
        assert!(OpKind::Sum.is_reduction());
        assert!(OpKind::Mean.is_reduction());
        assert!(!OpKind::Add.is_reduction());
    }

    #[test]
    fn test_op_kind_display() {
        assert_eq!(format!("{}", OpKind::MatMul), "matmul");
        assert_eq!(format!("{}", OpKind::Relu), "relu");
        assert_eq!(
            format!("{}", OpKind::Custom("my_op".to_string())),
            "custom(my_op)"
        );
    }

    #[test]
    fn test_can_fuse_matmul_bias() {
        let matmul = make_node(0, OpKind::MatMul, vec![], 1);
        let bias_add = make_node(1, OpKind::BiasAdd, vec![0], 1);
        assert!(can_fuse_matmul_bias(&matmul, &bias_add));
    }

    #[test]
    fn test_cannot_fuse_matmul_bias_multiple_consumers() {
        // A matmul with two consumers must not be fused away.
        let matmul = make_node(0, OpKind::MatMul, vec![], 2);
        let bias_add = make_node(1, OpKind::BiasAdd, vec![0], 1);
        assert!(!can_fuse_matmul_bias(&matmul, &bias_add));
    }

    #[test]
    fn test_cannot_fuse_matmul_bias_no_edge() {
        // No producer/consumer edge between the two nodes.
        let matmul = make_node(0, OpKind::MatMul, vec![], 1);
        let bias_add = make_node(1, OpKind::BiasAdd, vec![5], 1);
        assert!(!can_fuse_matmul_bias(&matmul, &bias_add));
    }

    #[test]
    fn test_can_fuse_matmul_activation() {
        let matmul = make_node(0, OpKind::MatMul, vec![], 1);
        let relu = make_node(1, OpKind::Relu, vec![0], 1);
        assert!(can_fuse_matmul_activation(&matmul, &relu));
    }

    #[test]
    fn test_cannot_fuse_matmul_nonactivation() {
        let matmul = make_node(0, OpKind::MatMul, vec![], 1);
        let add = make_node(1, OpKind::Add, vec![0], 1);
        assert!(!can_fuse_matmul_activation(&matmul, &add));
    }

    #[test]
    fn test_can_fuse_matmul_bias_activation() {
        let matmul = make_node(0, OpKind::MatMul, vec![], 1);
        let bias_add = make_node(1, OpKind::Add, vec![0], 1);
        let relu = make_node(2, OpKind::Relu, vec![1], 1);
        assert!(can_fuse_matmul_bias_activation(&matmul, &bias_add, &relu));
    }

    #[test]
    fn test_can_fuse_matmul_bias_gelu() {
        let matmul = make_node(0, OpKind::MatMul, vec![], 1);
        let bias_add = make_node(1, OpKind::Add, vec![0], 1);
        let gelu = make_node(2, OpKind::Gelu, vec![1], 1);
        assert!(can_fuse_matmul_bias_activation(&matmul, &bias_add, &gelu));
    }

    #[test]
    fn test_can_fuse_conv_bn() {
        let conv = make_node(0, OpKind::Conv2d, vec![], 1);
        let bn = make_node(1, OpKind::BatchNorm, vec![0], 1);
        assert!(can_fuse_conv_bn(&conv, &bn));
    }

    #[test]
    fn test_cannot_fuse_conv_bn_wrong_order() {
        // BatchNorm -> Conv2d is not a fusable orientation.
        let bn = make_node(0, OpKind::BatchNorm, vec![], 1);
        let conv = make_node(1, OpKind::Conv2d, vec![0], 1);
        assert!(!can_fuse_conv_bn(&bn, &conv));
    }

    #[test]
    fn test_can_fuse_conv_bn_relu() {
        let conv = make_node(0, OpKind::Conv2d, vec![], 1);
        let bn = make_node(1, OpKind::BatchNorm, vec![0], 1);
        let relu = make_node(2, OpKind::Relu, vec![1], 1);
        assert!(can_fuse_conv_bn_activation(&conv, &bn, &relu));
    }

    #[test]
    fn test_can_fuse_elementwise() {
        let add = make_node(0, OpKind::Add, vec![], 1);
        let mul = make_node(1, OpKind::Mul, vec![0], 1);
        assert!(can_fuse_elementwise(&add, &mul));
    }

    #[test]
    fn test_cannot_fuse_elementwise_non_elementwise() {
        let add = make_node(0, OpKind::Add, vec![], 1);
        let matmul = make_node(1, OpKind::MatMul, vec![0], 1);
        assert!(!can_fuse_elementwise(&add, &matmul));
    }

    #[test]
    fn test_can_fuse_affine() {
        let mul = make_node(0, OpKind::Mul, vec![], 1);
        let add = make_node(1, OpKind::Add, vec![0], 1);
        assert!(can_fuse_affine(&mul, &add));
    }

    #[test]
    fn test_can_fuse_sum_div_to_mean() {
        let sum = make_node(0, OpKind::Sum, vec![], 1);
        let div = make_node(1, OpKind::Div, vec![0], 1);
        assert!(can_fuse_sum_div_to_mean(&sum, &div));
    }

    #[test]
    fn test_can_fuse_square_mean_to_variance() {
        let sq = make_node(0, OpKind::Square, vec![], 1);
        let mean = make_node(1, OpKind::Mean, vec![0], 1);
        assert!(can_fuse_square_mean_to_variance(&sq, &mean));
    }

    #[test]
    fn test_can_fuse_softmax() {
        // exp feeds both sum and div; div reads exp and sum.
        let exp = make_node(0, OpKind::Exp, vec![], 2);
        let sum = make_node(1, OpKind::Sum, vec![0], 1);
        let div = make_node(2, OpKind::Div, vec![0, 1], 1);
        assert!(can_fuse_softmax(&exp, &sum, &div));
    }

    #[test]
    fn test_cannot_fuse_softmax_wrong_consumers() {
        // exp must have exactly two consumers (sum and div); one is too few.
        let exp = make_node(0, OpKind::Exp, vec![], 1);
        let sum = make_node(1, OpKind::Sum, vec![0], 1);
        let div = make_node(2, OpKind::Div, vec![0, 1], 1);
        assert!(!can_fuse_softmax(&exp, &sum, &div));
    }

    #[test]
    fn test_detect_two_node_matmul_bias() {
        let matmul = make_node(0, OpKind::MatMul, vec![], 1);
        let bias_add = make_node(1, OpKind::BiasAdd, vec![0], 1);
        assert_eq!(
            detect_two_node_pattern(&matmul, &bias_add),
            Some(FusionPattern::MatMulBias)
        );
    }

    #[test]
    fn test_detect_two_node_none() {
        let matmul = make_node(0, OpKind::MatMul, vec![], 1);
        let conv = make_node(1, OpKind::Conv2d, vec![0], 1);
        assert_eq!(detect_two_node_pattern(&matmul, &conv), None);
    }

    #[test]
    fn test_detect_three_node_matmul_bias_activation() {
        let matmul = make_node(0, OpKind::MatMul, vec![], 1);
        let bias_add = make_node(1, OpKind::Add, vec![0], 1);
        let relu = make_node(2, OpKind::Relu, vec![1], 1);
        assert_eq!(
            detect_three_node_pattern(&matmul, &bias_add, &relu),
            Some(FusionPattern::MatMulBiasActivation)
        );
    }

    #[test]
    fn test_graph_node_output_numel() {
        let node = GraphNode::new(0, OpKind::Add, vec![], vec![2, 3, 4]);
        assert_eq!(node.output_numel(), Some(24));
        // Empty shape means "no shape recorded", not scalar.
        let empty = GraphNode::new(1, OpKind::Add, vec![], vec![]);
        assert_eq!(empty.output_numel(), None);
    }

    #[test]
    fn test_fusion_pattern_display() {
        assert_eq!(format!("{}", FusionPattern::MatMulBias), "matmul_bias");
        assert_eq!(
            format!("{}", FusionPattern::MatMulBiasActivation),
            "matmul_bias_activation"
        );
        assert_eq!(
            format!("{}", FusionPattern::ConvBNActivation),
            "conv_bn_activation"
        );
    }
}