use std::collections::HashMap;
/// A description of a group of operations that can be fused together.
///
/// Only [`FusionPattern::Custom`] is consulted by `FusionEngine::analyze`;
/// the remaining variants can be registered but are currently skipped there.
#[derive(Debug, Clone, PartialEq)]
pub enum FusionPattern {
    /// A list of strings — presumably element-wise op names (not interpreted
    /// anywhere in this file; confirm with callers).
    ElementWise(Vec<String>),
    /// A single string — presumably the reduction op name (not interpreted
    /// anywhere in this file; confirm with callers).
    Reduction(String),
    /// A single string — presumably the matrix op name (not interpreted
    /// anywhere in this file; confirm with callers).
    MatrixOp(String),
    /// A named pattern: the fused-op name followed by the ordered op types
    /// that must appear one after another for the pattern to match.
    Custom(String, Vec<String>),
}
/// One operation in a `ComputeGraph`, together with its edges.
#[derive(Debug, Clone)]
pub struct GraphNode {
    /// Identifier handed out by `ComputeGraph::add_node`.
    pub id: usize,
    /// Operation name, e.g. `"Conv2d"`; compared verbatim against pattern ops.
    pub op_type: String,
    /// Ids of the nodes this node was declared to consume.
    pub inputs: Vec<usize>,
    /// Ids of later nodes that listed this node among their inputs.
    pub outputs: Vec<usize>,
    /// Whether the node may take part in a fusion match.
    pub fusible: bool,
}
/// An append-only collection of [`GraphNode`]s.
#[derive(Debug, Clone)]
pub struct ComputeGraph {
    /// Nodes in insertion order.
    nodes: Vec<GraphNode>,
    /// Id that the next inserted node will receive.
    next_id: usize,
}
impl ComputeGraph {
pub fn new() -> Self {
Self {
nodes: Vec::new(),
next_id: 0,
}
}
pub fn add_node(&mut self, op_type: String, inputs: Vec<usize>, fusible: bool) -> usize {
let id = self.next_id;
self.next_id += 1;
for &input_id in &inputs {
if let Some(node) = self.nodes.iter_mut().find(|n| n.id == input_id) {
node.outputs.push(id);
}
}
self.nodes.push(GraphNode {
id,
op_type,
inputs,
outputs: Vec::new(),
fusible,
});
id
}
pub fn nodes(&self) -> &[GraphNode] {
&self.nodes
}
pub fn get_node(&self, id: usize) -> Option<&GraphNode> {
self.nodes.iter().find(|n| n.id == id)
}
}
impl Default for ComputeGraph {
fn default() -> Self {
Self::new()
}
}
/// Holds the registered fusion patterns and a record of what has been fused.
pub struct FusionEngine {
    /// Patterns in registration order.
    patterns: Vec<FusionPattern>,
    /// Pattern name mapped to the op types of the nodes recorded by `fuse`.
    fused_ops: HashMap<String, Vec<String>>,
}
impl FusionEngine {
    /// Creates an engine pre-loaded with the built-in patterns.
    pub fn new() -> Self {
        let mut engine = Self {
            patterns: Vec::new(),
            fused_ops: HashMap::new(),
        };
        engine.register_default_patterns();
        engine
    }

    /// Registers the built-in `Custom` patterns, in the order listed below.
    fn register_default_patterns(&mut self) {
        // (fused-op name, ordered op types that make up the chain)
        const DEFAULT_PATTERNS: &[(&str, &[&str])] = &[
            ("ConvBNReLU", &["Conv2d", "BatchNorm", "ReLU"]),
            ("LinearReLU", &["Linear", "ReLU"]),
            ("GEMM", &["MatMul", "Add"]),
            ("AddReLU", &["Add", "ReLU"]),
            ("FMA", &["Mul", "Add"]),
            ("BNReLU", &["BatchNorm", "ReLU"]),
        ];
        for &(name, ops) in DEFAULT_PATTERNS {
            let ops = ops.iter().map(|op| (*op).to_string()).collect();
            self.add_pattern(FusionPattern::Custom(name.to_string(), ops));
        }
    }

    /// Appends `pattern` to the list consulted by [`FusionEngine::analyze`].
    pub fn add_pattern(&mut self, pattern: FusionPattern) {
        self.patterns.push(pattern);
    }

    /// Scans `graph` for every registered `Custom` pattern and returns all
    /// matches, grouped by pattern in registration order.
    ///
    /// Patterns of any other variant are skipped. Matches of different
    /// patterns may overlap (e.g. `ConvBNReLU` and `BNReLU` on the same nodes).
    pub fn analyze(&mut self, graph: &ComputeGraph) -> Vec<FusionOpportunity> {
        let mut opportunities = Vec::new();
        for pattern in &self.patterns {
            if let FusionPattern::Custom(name, ops) = pattern {
                opportunities.extend(self.find_pattern_matches(graph, name, ops));
            }
        }
        opportunities
    }

    /// Returns one opportunity per start id at which `ops` matches.
    ///
    /// An empty `ops` list never matches, so no opportunity with an empty
    /// node list is produced.
    fn find_pattern_matches(
        &self,
        graph: &ComputeGraph,
        pattern_name: &str,
        ops: &[String],
    ) -> Vec<FusionOpportunity> {
        if ops.is_empty() {
            return Vec::new();
        }
        // The estimate depends only on the pattern, so compute it once.
        let speedup = self.estimate_speedup(ops);
        (0..graph.nodes().len())
            .filter(|&start| self.matches_pattern_at(graph, start, ops))
            .map(|start| FusionOpportunity {
                pattern_name: pattern_name.to_string(),
                nodes: (start..start + ops.len()).collect(),
                estimated_speedup: speedup,
            })
            .collect()
    }

    /// Checks whether the nodes with ids `start, start + 1, ...` form the
    /// chain described by `ops`.
    ///
    /// Every node must exist, be fusible and carry the expected op type, and
    /// each node after the first must consume its predecessor. Without the
    /// last check, nodes that merely sit next to each other in id order would
    /// be reported as fusible even though no data flows between them.
    fn matches_pattern_at(&self, graph: &ComputeGraph, start: usize, ops: &[String]) -> bool {
        if ops.is_empty() || start + ops.len() > graph.nodes().len() {
            return false;
        }
        let mut previous_id: Option<usize> = None;
        for (offset, op) in ops.iter().enumerate() {
            let node = match graph.get_node(start + offset) {
                Some(node) => node,
                None => return false,
            };
            if node.op_type != *op || !node.fusible {
                return false;
            }
            if let Some(prev) = previous_id {
                if !node.inputs.contains(&prev) {
                    return false;
                }
            }
            previous_id = Some(node.id);
        }
        true
    }

    /// Returns a fixed speedup estimate chosen purely by pattern length.
    fn estimate_speedup(&self, ops: &[String]) -> f32 {
        match ops.len() {
            2 => 1.3,
            3 => 1.5,
            4 => 1.7,
            _ => 1.2,
        }
    }

    /// Records, per pattern name, the op types of the nodes in each
    /// opportunity.
    ///
    /// `graph` is only read. A node id that is not in the graph contributes
    /// an empty string, and a later opportunity with the same pattern name
    /// replaces the earlier entry.
    pub fn fuse(&mut self, graph: &mut ComputeGraph, opportunities: &[FusionOpportunity]) {
        for opp in opportunities {
            let op_types = opp
                .nodes
                .iter()
                .map(|&id| {
                    graph
                        .get_node(id)
                        .map(|n| n.op_type.clone())
                        .unwrap_or_default()
                })
                .collect();
            self.fused_ops.insert(opp.pattern_name.clone(), op_types);
        }
    }

    /// Returns the pattern-name to op-types record built by [`FusionEngine::fuse`].
    pub fn get_fused_ops(&self) -> &HashMap<String, Vec<String>> {
        &self.fused_ops
    }
}
impl Default for FusionEngine {
fn default() -> Self {
Self::new()
}
}
/// One place in a graph where a registered pattern matched.
#[derive(Debug, Clone)]
pub struct FusionOpportunity {
    /// Name of the pattern that matched.
    pub pattern_name: String,
    /// Ids of the matched nodes, in pattern order.
    pub nodes: Vec<usize>,
    /// Estimated speedup factor attached by the engine.
    pub estimated_speedup: f32,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Nodes receive sequential ids and edges are recorded in both directions.
    #[test]
    fn test_compute_graph() {
        let mut graph = ComputeGraph::new();
        let n1 = graph.add_node("Input".to_string(), vec![], false);
        let n2 = graph.add_node("Conv2d".to_string(), vec![n1], true);
        let n3 = graph.add_node("ReLU".to_string(), vec![n2], true);
        assert_eq!(graph.nodes().len(), 3);
        assert_eq!(graph.get_node(n2).unwrap().op_type, "Conv2d");
        assert_eq!(graph.get_node(n1).unwrap().outputs, vec![n2]);
        assert_eq!(graph.get_node(n2).unwrap().outputs, vec![n3]);
        assert_eq!(graph.get_node(n3).unwrap().inputs, vec![n2]);
        assert!(graph.get_node(n3 + 1).is_none());
    }

    /// A connected, fusible Linear -> ReLU chain is reported as `LinearReLU`.
    #[test]
    fn test_fusion_engine() {
        let mut engine = FusionEngine::new();
        let mut graph = ComputeGraph::new();
        let n1 = graph.add_node("Input".to_string(), vec![], false);
        let n2 = graph.add_node("Linear".to_string(), vec![n1], true);
        let n3 = graph.add_node("ReLU".to_string(), vec![n2], true);
        let opportunities = engine.analyze(&graph);
        assert!(!opportunities.is_empty());
        assert!(opportunities
            .iter()
            .any(|opp| opp.pattern_name == "LinearReLU" && opp.nodes == vec![n2, n3]));
    }

    /// A new engine comes with built-in patterns already registered.
    #[test]
    fn test_fusion_patterns() {
        let engine = FusionEngine::new();
        assert!(!engine.patterns.is_empty());
    }
}