use std::collections::{HashMap, HashSet, VecDeque};
use anyhow::{Result, anyhow};
use ronn_core::{DataType, GraphEdge, GraphNode, NodeId, SubGraph, TensorLayout};
use tracing::{debug, info};
/// Compiles a [`SubGraph`] into a sequence of (possibly fused) operations plus a
/// memory plan for the tensors they produce.
///
/// Behavior is driven by a [`FusionConfig`] and a [`MemoryConfig`]; use
/// [`KernelCompiler::new`] for the defaults or [`KernelCompiler::with_config`] to
/// supply both explicitly.
#[derive(Debug)]
pub struct KernelCompiler {
    // Controls whether and which adjacent nodes are merged into fused operations.
    fusion_config: FusionConfig,
    // Controls layout selection and buffer-reuse planning for tensors.
    memory_config: MemoryConfig,
}
/// Settings that decide which operator chains [`KernelCompiler`] may fuse.
#[derive(Clone, Debug)]
pub struct FusionConfig {
    /// Master switch; when `false`, every node is compiled as its own operation.
    pub enable_fusion: bool,
    /// Maximum number of nodes (including the first) in one fused group.
    pub max_fusion_depth: usize,
    /// Allows `Add`/`Mul`/`Sub` -> `ReLU` and `Add` -> `Add`, `Mul` -> `Mul` chains.
    pub enable_elementwise_fusion: bool,
    /// Allows `Conv` -> `BatchNormalization` and `BatchNormalization` -> `ReLU` chains.
    pub enable_conv_fusion: bool,
    /// Allows `MatMul`/`Gemm` -> `Add` chains.
    pub enable_matmul_fusion: bool,
}

impl Default for FusionConfig {
    /// Every fusion family enabled, with groups of up to four nodes.
    fn default() -> Self {
        FusionConfig {
            max_fusion_depth: 4,
            enable_fusion: true,
            enable_matmul_fusion: true,
            enable_conv_fusion: true,
            enable_elementwise_fusion: true,
        }
    }
}
/// Settings for tensor memory planning in [`KernelCompiler`].
#[derive(Clone, Debug)]
pub struct MemoryConfig {
    /// When `true`, compilation builds an optimized plan (lifetimes, layout choice,
    /// optional reuse); when `false`, a basic row-major plan without reuse.
    pub enable_optimization: bool,
    /// Layout used by the optimized plan: row-major if `true`, column-major otherwise.
    pub prefer_row_major: bool,
    /// Lets the optimized plan share buffers between tensors whose lifetimes do not overlap.
    pub enable_tensor_reuse: bool,
    /// NOTE(review): not read anywhere in this module; presumably a tolerated
    /// memory-overhead fraction (default 0.2) — confirm intended use.
    pub max_memory_overhead: f32,
}

impl Default for MemoryConfig {
    /// Optimization, row-major layout and tensor reuse all enabled; overhead budget 0.2.
    fn default() -> Self {
        MemoryConfig {
            max_memory_overhead: 0.2,
            enable_tensor_reuse: true,
            prefer_row_major: true,
            enable_optimization: true,
        }
    }
}
/// One unit of work produced by [`KernelCompiler::compile`]: either a single graph
/// node or a chain of nodes merged by fusion.
#[derive(Debug, Clone)]
pub struct FusedOperation {
    /// Identifier assigned to this operation by the compiler.
    pub id: usize,
    /// Graph nodes covered by this operation, in the order they were grouped.
    pub nodes: Vec<GraphNode>,
    /// Fusion pattern this group was classified as.
    pub fusion_type: FusionType,
    /// Tensor indices this operation consumes; fed into tensor-lifetime analysis.
    pub inputs: Vec<usize>,
    /// Tensor indices this operation produces; fed into tensor-lifetime analysis.
    pub outputs: Vec<usize>,
    /// Estimated relative cost; a single unfused node costs 1.0.
    pub cost: f64,
}
/// Classification of a fused operation by the operator pattern it covers.
///
/// Derives `Eq` and `Hash` (all payloads are `String`) so values can be used as
/// keys in sets and maps, e.g. when counting how often each pattern fired.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum FusionType {
    /// A single node; nothing was fused.
    None,
    /// A chain made only of `Add`, `Mul`, `Sub` and `ReLU` nodes.
    ElementWise,
    /// Exactly `Conv` -> `BatchNormalization` -> `ReLU`.
    ConvBnRelu,
    /// `MatMul` or `Gemm` followed by `Add`.
    MatMulBias,
    /// Any other multi-node group; the payload is the debug-formatted list of op types.
    Custom(String),
}
/// Memory layout decisions for every tensor of a compiled subgraph.
#[derive(Debug, Clone)]
pub struct MemoryPlan {
    /// Number of entries in `tensor_info`.
    pub tensor_count: usize,
    /// Per-tensor details; entry `k` describes the tensor with id `k`.
    pub tensor_info: Vec<TensorInfo>,
    /// Maps a tensor id to the id of an earlier tensor whose buffer it may reuse
    /// (same size and dtype, non-overlapping lifetimes).
    pub reuse_map: HashMap<usize, usize>,
    /// Sum of `size_bytes` over all tensors, without subtracting reuse savings.
    pub total_memory: usize,
}
/// Planning information for a single tensor.
#[derive(Debug, Clone)]
pub struct TensorInfo {
    /// Tensor id; equals this entry's index in [`MemoryPlan::tensor_info`].
    pub id: usize,
    /// Dimensions of the tensor.
    pub shape: Vec<usize>,
    /// Element type.
    pub dtype: DataType,
    /// Memory layout chosen by the planner.
    pub layout: TensorLayout,
    /// Element count (product of `shape`) times the element size of `dtype`.
    pub size_bytes: usize,
    /// `(first, last)` operation index over which the tensor must stay allocated.
    pub lifetime: (usize, usize),
}
/// Everything produced by [`KernelCompiler::compile`].
#[derive(Debug, Clone)]
pub struct CompilationResult {
    /// Operations to execute, in emission order.
    pub fused_ops: Vec<FusedOperation>,
    /// Tensor sizes, layouts, lifetimes and buffer-reuse decisions.
    pub memory_plan: MemoryPlan,
    /// Summary metrics for this compilation.
    pub stats: CompilationStats,
}
/// Summary metrics reported by [`KernelCompiler::compile`].
#[derive(Debug, Clone)]
pub struct CompilationStats {
    /// Number of nodes in the input subgraph.
    pub original_ops: usize,
    /// Number of operations after fusion.
    pub fused_ops: usize,
    /// Fraction of tensor bytes saved by buffer reuse (0.0 when nothing is reused);
    /// logged multiplied by 100 as a percentage.
    pub memory_reduction: f32,
    /// Estimated fraction of execution cost saved by fusion, derived from the
    /// per-operation cost estimates; logged multiplied by 100 as a percentage.
    pub performance_improvement: f32,
}
impl KernelCompiler {
    /// Creates a compiler using [`FusionConfig::default`] and [`MemoryConfig::default`].
    pub fn new() -> Self {
        Self {
            fusion_config: FusionConfig::default(),
            memory_config: MemoryConfig::default(),
        }
    }

    /// Creates a compiler with explicit fusion and memory settings.
    pub fn with_config(fusion_config: FusionConfig, memory_config: MemoryConfig) -> Self {
        Self {
            fusion_config,
            memory_config,
        }
    }

    /// Compiles `subgraph` into a list of (possibly fused) operations and a memory plan.
    ///
    /// Operations are emitted in topological order, so each runs after every
    /// operation producing its inputs. With fusion disabled, every node becomes its
    /// own [`FusedOperation`] tagged [`FusionType::None`].
    ///
    /// Tensor model used by the memory plan: tensor `k` is the output of
    /// `subgraph.nodes[k]`.
    ///
    /// # Errors
    ///
    /// Returns an error if the subgraph has duplicate node ids, an edge naming a
    /// node that is not in `subgraph.nodes`, or a cycle.
    pub fn compile(&self, subgraph: &SubGraph) -> Result<CompilationResult> {
        debug!("Compiling subgraph with {} nodes", subgraph.nodes.len());

        let topology = self.analyze_topology(subgraph)?;
        let fusion_candidates = if self.fusion_config.enable_fusion {
            self.detect_fusion_opportunities(subgraph, &topology)?
        } else {
            self.unfused_operations(subgraph, &topology)?
        };
        let fused_ops = self.apply_fusion(subgraph, fusion_candidates)?;

        let memory_plan = if self.memory_config.enable_optimization {
            self.optimize_memory_layout(subgraph, &fused_ops)?
        } else {
            self.create_basic_memory_plan(subgraph)?
        };

        let stats = CompilationStats {
            original_ops: subgraph.nodes.len(),
            fused_ops: fused_ops.len(),
            memory_reduction: self.calculate_memory_reduction(subgraph, &memory_plan),
            performance_improvement: self.estimate_performance_improvement(&fused_ops),
        };
        info!(
            "Compilation complete: {} -> {} ops, {:.1}% memory reduction, {:.1}% performance improvement",
            stats.original_ops,
            stats.fused_ops,
            stats.memory_reduction * 100.0,
            stats.performance_improvement * 100.0
        );

        Ok(CompilationResult {
            fused_ops,
            memory_plan,
            stats,
        })
    }

    /// Builds producer/consumer adjacency for every node plus a topological
    /// execution order.
    fn analyze_topology(&self, subgraph: &SubGraph) -> Result<TopologyInfo> {
        // Sort first so malformed graphs are rejected before adjacency is built.
        let execution_order = self.topological_sort(subgraph)?;

        let mut node_dependencies: HashMap<NodeId, Vec<NodeId>> = HashMap::new();
        let mut node_dependents: HashMap<NodeId, Vec<NodeId>> = HashMap::new();
        for edge in &subgraph.edges {
            node_dependencies
                .entry(edge.to_node)
                .or_default()
                .push(edge.from_node);
            node_dependents
                .entry(edge.from_node)
                .or_default()
                .push(edge.to_node);
        }

        Ok(TopologyInfo {
            node_dependencies,
            node_dependents,
            execution_order,
        })
    }

    /// Orders node ids so every node follows all of its producers (Kahn's algorithm).
    ///
    /// Ready nodes are taken in the order they appear in `subgraph.nodes`, so the
    /// result is deterministic for a given input.
    ///
    /// # Errors
    ///
    /// Fails on duplicate node ids, on edges referencing unknown nodes, and when
    /// the graph contains a cycle.
    fn topological_sort(&self, subgraph: &SubGraph) -> Result<Vec<NodeId>> {
        let mut in_degree: HashMap<NodeId, usize> = HashMap::with_capacity(subgraph.nodes.len());
        let mut adjacency: HashMap<NodeId, Vec<NodeId>> =
            HashMap::with_capacity(subgraph.nodes.len());
        for node in &subgraph.nodes {
            if in_degree.insert(node.id, 0).is_some() {
                return Err(anyhow!("Duplicate node id {} in subgraph", node.id));
            }
            adjacency.insert(node.id, Vec::new());
        }

        for edge in &subgraph.edges {
            let successors = adjacency.get_mut(&edge.from_node).ok_or_else(|| {
                anyhow!("Edge source node {} is not in the subgraph", edge.from_node)
            })?;
            successors.push(edge.to_node);
            let degree = in_degree.get_mut(&edge.to_node).ok_or_else(|| {
                anyhow!("Edge target node {} is not in the subgraph", edge.to_node)
            })?;
            *degree += 1;
        }

        // Seed from the declared node order, not HashMap iteration order: the
        // latter is randomized per process and made fusion non-deterministic.
        let mut queue: VecDeque<NodeId> = subgraph
            .nodes
            .iter()
            .map(|node| node.id)
            .filter(|id| in_degree.get(id) == Some(&0))
            .collect();

        let mut result = Vec::with_capacity(subgraph.nodes.len());
        while let Some(node_id) = queue.pop_front() {
            result.push(node_id);
            for &neighbor in adjacency.get(&node_id).into_iter().flatten() {
                if let Some(degree) = in_degree.get_mut(&neighbor) {
                    *degree -= 1;
                    if *degree == 0 {
                        queue.push_back(neighbor);
                    }
                }
            }
        }

        if result.len() != subgraph.nodes.len() {
            return Err(anyhow!("Cycle detected in subgraph"));
        }
        Ok(result)
    }

    /// Maps each node id to its index in `subgraph.nodes`.
    ///
    /// That index doubles as the tensor index of the node's output, because
    /// `estimate_tensor_shapes` emits one tensor per node in the same order.
    fn node_positions(&self, subgraph: &SubGraph) -> HashMap<NodeId, usize> {
        subgraph
            .nodes
            .iter()
            .enumerate()
            .map(|(pos, node)| (node.id, pos))
            .collect()
    }

    /// Looks up a node by id through a prebuilt position index (O(1) instead of a
    /// linear scan of `subgraph.nodes`).
    fn lookup_node<'a>(
        &self,
        subgraph: &'a SubGraph,
        positions: &HashMap<NodeId, usize>,
        node_id: NodeId,
    ) -> Result<&'a GraphNode> {
        positions
            .get(&node_id)
            .and_then(|&pos| subgraph.nodes.get(pos))
            .ok_or_else(|| anyhow!("Node {} is not in the subgraph", node_id))
    }

    /// Wraps every node in its own single-node [`FusedOperation`], in execution order.
    fn unfused_operations(
        &self,
        subgraph: &SubGraph,
        topology: &TopologyInfo,
    ) -> Result<Vec<FusedOperation>> {
        let positions = self.node_positions(subgraph);
        topology
            .execution_order
            .iter()
            .map(|&node_id| -> Result<FusedOperation> {
                let node = self.lookup_node(subgraph, &positions, node_id)?;
                Ok(self.build_fused_operation(
                    vec![node.clone()],
                    FusionType::None,
                    1.0,
                    topology,
                    &positions,
                ))
            })
            .collect()
    }

    /// Greedily groups nodes into fusion chains, walking the execution order and
    /// starting a new group at every node not already claimed by an earlier one.
    fn detect_fusion_opportunities(
        &self,
        subgraph: &SubGraph,
        topology: &TopologyInfo,
    ) -> Result<Vec<FusedOperation>> {
        let positions = self.node_positions(subgraph);
        let mut fusion_candidates = Vec::new();
        let mut processed_nodes: HashSet<NodeId> = HashSet::new();

        for &node_id in &topology.execution_order {
            if processed_nodes.contains(&node_id) {
                continue;
            }
            let node = self.lookup_node(subgraph, &positions, node_id)?;
            let fusion_group = self.try_create_fusion_group(
                node,
                subgraph,
                topology,
                &positions,
                &processed_nodes,
            )?;
            processed_nodes.extend(fusion_group.nodes.iter().map(|n| n.id));
            fusion_candidates.push(fusion_group);
        }

        Ok(fusion_candidates)
    }

    /// Extends a chain from `start_node` through fusable successors, up to
    /// `max_fusion_depth` nodes in total.
    fn try_create_fusion_group(
        &self,
        start_node: &GraphNode,
        subgraph: &SubGraph,
        topology: &TopologyInfo,
        positions: &HashMap<NodeId, usize>,
        processed_nodes: &HashSet<NodeId>,
    ) -> Result<FusedOperation> {
        let mut fusion_group = vec![start_node.clone()];
        let mut group_ids: HashSet<NodeId> = std::iter::once(start_node.id).collect();

        // The start node counts toward the depth. `saturating_sub` keeps a depth
        // of 0 from underflowing; it then behaves like 1 (no fusion).
        let extra_nodes = self.fusion_config.max_fusion_depth.saturating_sub(1);
        for _ in 0..extra_nodes {
            let current_node = match fusion_group.last() {
                Some(node) => node,
                None => break,
            };
            let next_node = match self.find_fusable_successor(
                current_node,
                subgraph,
                topology,
                positions,
                processed_nodes,
                &group_ids,
            )? {
                Some(node) => node,
                None => break,
            };
            group_ids.insert(next_node.id);
            fusion_group.push(next_node);
        }

        let fusion_type = self.classify_fusion_group(&fusion_group);
        let cost = self.estimate_fusion_cost(&fusion_group, &fusion_type);
        Ok(self.build_fused_operation(fusion_group, fusion_type, cost, topology, positions))
    }

    /// Returns the first consumer of `current_node` that may join the chain.
    ///
    /// A consumer qualifies when it is unclaimed, `can_fuse` allows the pair, and
    /// every one of its producers is already available: either in an earlier group
    /// (`processed_nodes`) or in the chain being built (`group_ids`). Without that
    /// last check a node could be executed before one of its other inputs exists.
    fn find_fusable_successor(
        &self,
        current_node: &GraphNode,
        subgraph: &SubGraph,
        topology: &TopologyInfo,
        positions: &HashMap<NodeId, usize>,
        processed_nodes: &HashSet<NodeId>,
        group_ids: &HashSet<NodeId>,
    ) -> Result<Option<GraphNode>> {
        let dependents = topology
            .node_dependents
            .get(&current_node.id)
            .into_iter()
            .flatten();
        for &dependent_id in dependents {
            if processed_nodes.contains(&dependent_id) || group_ids.contains(&dependent_id) {
                continue;
            }
            let dependent_node = self.lookup_node(subgraph, positions, dependent_id)?;
            if !self.can_fuse(&current_node.op_type, &dependent_node.op_type) {
                continue;
            }
            let inputs_ready = topology
                .node_dependencies
                .get(&dependent_id)
                .into_iter()
                .flatten()
                .all(|producer| processed_nodes.contains(producer) || group_ids.contains(producer));
            if inputs_ready {
                return Ok(Some(dependent_node.clone()));
            }
        }
        Ok(None)
    }

    /// Whether a node of type `op1` may be fused with a directly following `op2`,
    /// subject to the per-family switches in [`FusionConfig`].
    fn can_fuse(&self, op1: &str, op2: &str) -> bool {
        match (op1, op2) {
            ("Add", "ReLU") | ("Mul", "ReLU") | ("Sub", "ReLU") => {
                self.fusion_config.enable_elementwise_fusion
            }
            ("Add", "Add") | ("Mul", "Mul") => self.fusion_config.enable_elementwise_fusion,
            ("Conv", "BatchNormalization") | ("BatchNormalization", "ReLU") => {
                self.fusion_config.enable_conv_fusion
            }
            ("MatMul", "Add") | ("Gemm", "Add") => self.fusion_config.enable_matmul_fusion,
            _ => false,
        }
    }

    /// Classifies a group by its op-type sequence; single nodes are [`FusionType::None`].
    fn classify_fusion_group(&self, nodes: &[GraphNode]) -> FusionType {
        if nodes.len() == 1 {
            return FusionType::None;
        }
        let op_types: Vec<&str> = nodes.iter().map(|n| n.op_type.as_str()).collect();
        match op_types.as_slice() {
            ["Conv", "BatchNormalization", "ReLU"] => FusionType::ConvBnRelu,
            ["MatMul", "Add"] | ["Gemm", "Add"] => FusionType::MatMulBias,
            ops if ops
                .iter()
                .all(|op| matches!(*op, "Add" | "Mul" | "Sub" | "ReLU")) =>
            {
                FusionType::ElementWise
            }
            _ => FusionType::Custom(format!("{:?}", op_types)),
        }
    }

    /// Relative cost of a group: one unit per node, scaled by a fixed heuristic
    /// multiplier for each fusion pattern.
    fn estimate_fusion_cost(&self, nodes: &[GraphNode], fusion_type: &FusionType) -> f64 {
        let base_cost = nodes.len() as f64;
        match fusion_type {
            FusionType::None => base_cost,
            FusionType::ElementWise => base_cost * 0.7,
            FusionType::ConvBnRelu => base_cost * 0.6,
            FusionType::MatMulBias => base_cost * 0.8,
            FusionType::Custom(_) => base_cost * 0.9,
        }
    }

    /// Assembles a [`FusedOperation`] and derives its tensor inputs and outputs.
    ///
    /// Outputs are the tensors of every node in the group. Inputs are the tensors
    /// of producers outside the group, deduplicated in first-seen order.
    fn build_fused_operation(
        &self,
        nodes: Vec<GraphNode>,
        fusion_type: FusionType,
        cost: f64,
        topology: &TopologyInfo,
        positions: &HashMap<NodeId, usize>,
    ) -> FusedOperation {
        let members: HashSet<NodeId> = nodes.iter().map(|node| node.id).collect();
        let outputs: Vec<usize> = nodes
            .iter()
            .filter_map(|node| positions.get(&node.id).copied())
            .collect();

        let mut inputs: Vec<usize> = Vec::new();
        for node in &nodes {
            for producer in topology.node_dependencies.get(&node.id).into_iter().flatten() {
                if members.contains(producer) {
                    continue;
                }
                if let Some(&tensor) = positions.get(producer) {
                    if !inputs.contains(&tensor) {
                        inputs.push(tensor);
                    }
                }
            }
        }

        let id = nodes.first().map(|node| node.id).unwrap_or_default();
        FusedOperation {
            id,
            nodes,
            fusion_type,
            inputs,
            outputs,
            cost,
        }
    }

    /// Hook for rewriting fusion candidates; currently returns them unchanged.
    fn apply_fusion(
        &self,
        _subgraph: &SubGraph,
        fusion_candidates: Vec<FusedOperation>,
    ) -> Result<Vec<FusedOperation>> {
        Ok(fusion_candidates)
    }

    /// Builds a plan with per-tensor lifetimes, the configured layout and optional
    /// buffer reuse.
    fn optimize_memory_layout(
        &self,
        subgraph: &SubGraph,
        fused_ops: &[FusedOperation],
    ) -> Result<MemoryPlan> {
        let lifetimes = self.analyze_tensor_lifetimes(fused_ops)?;
        // Tensors that leave the subgraph must outlive every operation; otherwise
        // reuse could hand their buffer to a later tensor and clobber the result.
        let graph_outputs: HashSet<&str> = subgraph.outputs.iter().map(String::as_str).collect();

        let mut tensor_info = Vec::new();
        let mut total_memory = 0;
        for (i, (shape, dtype)) in self.estimate_tensor_shapes(subgraph)?.into_iter().enumerate() {
            let size_bytes = self.calculate_tensor_size(&shape, dtype);
            let layout = if self.memory_config.prefer_row_major {
                TensorLayout::RowMajor
            } else {
                TensorLayout::ColumnMajor
            };
            let mut lifetime = lifetimes.get(&i).copied().unwrap_or((0, fused_ops.len()));
            let is_graph_output = subgraph
                .nodes
                .get(i)
                .into_iter()
                .flat_map(|node| node.outputs.iter())
                .any(|name| graph_outputs.contains(name.as_str()));
            if is_graph_output {
                lifetime.1 = lifetime.1.max(fused_ops.len());
            }

            tensor_info.push(TensorInfo {
                id: i,
                shape,
                dtype,
                layout,
                size_bytes,
                lifetime,
            });
            total_memory += size_bytes;
        }

        let reuse_map = if self.memory_config.enable_tensor_reuse {
            self.detect_tensor_reuse(&tensor_info)?
        } else {
            HashMap::new()
        };

        Ok(MemoryPlan {
            tensor_count: tensor_info.len(),
            tensor_info,
            reuse_map,
            total_memory,
        })
    }

    /// Builds a row-major plan in which every tensor lives for the whole subgraph
    /// and no buffers are shared.
    fn create_basic_memory_plan(&self, subgraph: &SubGraph) -> Result<MemoryPlan> {
        let mut tensor_info = Vec::new();
        let mut total_memory = 0;
        for (i, (shape, dtype)) in self.estimate_tensor_shapes(subgraph)?.into_iter().enumerate() {
            let size_bytes = self.calculate_tensor_size(&shape, dtype);
            tensor_info.push(TensorInfo {
                id: i,
                shape,
                dtype,
                layout: TensorLayout::RowMajor,
                size_bytes,
                lifetime: (0, subgraph.nodes.len()),
            });
            total_memory += size_bytes;
        }
        Ok(MemoryPlan {
            tensor_count: tensor_info.len(),
            tensor_info,
            reuse_map: HashMap::new(),
            total_memory,
        })
    }

    /// Returns one `(shape, dtype)` per node, in `subgraph.nodes` order.
    fn estimate_tensor_shapes(&self, subgraph: &SubGraph) -> Result<Vec<(Vec<usize>, DataType)>> {
        // NOTE(review): placeholder — every node is assumed to produce a 32x32 f32
        // tensor until real shape inference is wired in.
        Ok(subgraph
            .nodes
            .iter()
            .map(|_| (vec![32, 32], DataType::F32))
            .collect())
    }

    /// Byte size of a dense tensor: element count times element size.
    fn calculate_tensor_size(&self, shape: &[usize], dtype: DataType) -> usize {
        let element_count: usize = shape.iter().product();
        let element_size = match dtype {
            DataType::F32 | DataType::I32 | DataType::U32 => 4,
            DataType::F16 | DataType::BF16 => 2,
            DataType::F64 | DataType::I64 => 8,
            DataType::I8 | DataType::U8 | DataType::Bool => 1,
        };
        element_count * element_size
    }

    /// For each tensor index touched by any operation, the inclusive range of
    /// operation indices that read or write it.
    fn analyze_tensor_lifetimes(
        &self,
        fused_ops: &[FusedOperation],
    ) -> Result<HashMap<usize, (usize, usize)>> {
        let mut lifetimes: HashMap<usize, (usize, usize)> = HashMap::new();
        for (op_idx, fused_op) in fused_ops.iter().enumerate() {
            for &tensor_idx in fused_op.inputs.iter().chain(&fused_op.outputs) {
                let entry = lifetimes.entry(tensor_idx).or_insert((op_idx, op_idx));
                entry.0 = entry.0.min(op_idx);
                entry.1 = entry.1.max(op_idx);
            }
        }
        Ok(lifetimes)
    }

    /// Pairs each later tensor with an earlier one of identical size and dtype
    /// whose lifetime ends strictly before the later one begins.
    fn detect_tensor_reuse(&self, tensor_info: &[TensorInfo]) -> Result<HashMap<usize, usize>> {
        let mut reuse_map = HashMap::new();
        for (i, earlier) in tensor_info.iter().enumerate() {
            for (j, later) in tensor_info.iter().enumerate().skip(i + 1) {
                let disjoint = earlier.lifetime.1 < later.lifetime.0;
                if disjoint && earlier.size_bytes == later.size_bytes && earlier.dtype == later.dtype
                {
                    // A later pass may overwrite `j`'s entry with a closer donor.
                    reuse_map.insert(j, i);
                    // Each buffer goes to at most one successor, so shared buffers
                    // form chains of tensors with non-overlapping lifetimes.
                    break;
                }
            }
        }
        Ok(reuse_map)
    }

    /// Fraction of total tensor bytes that no longer need their own buffer.
    fn calculate_memory_reduction(&self, _subgraph: &SubGraph, memory_plan: &MemoryPlan) -> f32 {
        let original_memory: usize = memory_plan.tensor_info.iter().map(|t| t.size_bytes).sum();
        let reused_memory: usize = memory_plan
            .reuse_map
            .values()
            .filter_map(|&reused_from| memory_plan.tensor_info.get(reused_from))
            .map(|t| t.size_bytes)
            .sum();
        if original_memory > 0 {
            reused_memory as f32 / original_memory as f32
        } else {
            0.0
        }
    }

    /// Fraction of the unfused cost (one unit per node) saved by fusion.
    fn estimate_performance_improvement(&self, fused_ops: &[FusedOperation]) -> f32 {
        let total_savings: f64 = fused_ops
            .iter()
            .map(|op| (op.nodes.len() as f64 - op.cost).max(0.0))
            .sum();
        let total_original_cost: f64 = fused_ops.iter().map(|op| op.nodes.len() as f64).sum();
        if total_original_cost > 0.0 {
            (total_savings / total_original_cost) as f32
        } else {
            0.0
        }
    }
}
/// Dependency information for a subgraph, computed once per compilation.
#[derive(Debug)]
struct TopologyInfo {
    // For each node, the nodes that feed it (one entry per incoming edge).
    node_dependencies: HashMap<NodeId, Vec<NodeId>>,
    // For each node, the nodes it feeds (one entry per outgoing edge).
    node_dependents: HashMap<NodeId, Vec<NodeId>>,
    // Node ids ordered so that every node comes after all of its producers.
    execution_order: Vec<NodeId>,
}
impl Default for KernelCompiler {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    // Builds a node with the given id and op type and no attributes, inputs,
    // outputs or name; used where only the op-type sequence matters.
    fn bare_node(id: NodeId, op_type: &str) -> GraphNode {
        GraphNode {
            id,
            op_type: op_type.to_string(),
            attributes: HashMap::new(),
            inputs: vec![],
            outputs: vec![],
            name: None,
        }
    }

    // Two-node graph: Add (id 0) feeding ReLU (id 1) through "temp1".
    fn create_test_subgraph() -> SubGraph {
        let nodes = vec![
            GraphNode {
                id: 0,
                op_type: "Add".to_string(),
                attributes: HashMap::new(),
                inputs: vec!["input1".to_string(), "input2".to_string()],
                outputs: vec!["temp1".to_string()],
                name: Some("add1".to_string()),
            },
            GraphNode {
                id: 1,
                op_type: "ReLU".to_string(),
                attributes: HashMap::new(),
                inputs: vec!["temp1".to_string()],
                outputs: vec!["output1".to_string()],
                name: Some("relu1".to_string()),
            },
        ];
        let edges = vec![GraphEdge {
            from_node: 0,
            to_node: 1,
            tensor_name: "temp1".to_string(),
            tensor_shape: Some(vec![32, 32]),
            tensor_dtype: DataType::F32,
        }];
        SubGraph {
            nodes,
            edges,
            inputs: vec!["input1".to_string(), "input2".to_string()],
            outputs: vec!["output1".to_string()],
        }
    }

    #[test]
    fn test_compiler_creation() {
        let compiler = KernelCompiler::new();
        assert!(compiler.fusion_config.enable_fusion);
        assert!(compiler.memory_config.enable_optimization);
    }

    #[test]
    fn test_compilation() -> Result<()> {
        let compiler = KernelCompiler::new();
        let subgraph = create_test_subgraph();
        let result = compiler.compile(&subgraph)?;
        assert!(result.fused_ops.len() <= subgraph.nodes.len());
        assert!(result.memory_plan.tensor_count > 0);
        assert_eq!(result.stats.original_ops, 2);
        Ok(())
    }

    #[test]
    fn test_fusion_merges_add_relu() -> Result<()> {
        let compiler = KernelCompiler::new();
        let subgraph = create_test_subgraph();
        let result = compiler.compile(&subgraph)?;
        assert_eq!(result.fused_ops.len(), 1);
        assert_eq!(result.fused_ops[0].fusion_type, FusionType::ElementWise);
        assert_eq!(result.fused_ops[0].nodes.len(), 2);
        Ok(())
    }

    #[test]
    fn test_fusion_detection() -> Result<()> {
        let compiler = KernelCompiler::new();
        assert!(compiler.can_fuse("Add", "ReLU"));
        assert!(compiler.can_fuse("Conv", "BatchNormalization"));
        assert!(compiler.can_fuse("MatMul", "Add"));
        assert!(!compiler.can_fuse("Add", "Conv"));
        assert!(!compiler.can_fuse("ReLU", "MatMul"));
        Ok(())
    }

    #[test]
    fn test_topological_sort() -> Result<()> {
        let compiler = KernelCompiler::new();
        let subgraph = create_test_subgraph();
        let order = compiler.topological_sort(&subgraph)?;
        assert_eq!(order.len(), 2);
        assert_eq!(order[0], 0);
        assert_eq!(order[1], 1);
        Ok(())
    }

    #[test]
    fn test_cycle_is_rejected() {
        let compiler = KernelCompiler::new();
        let mut subgraph = create_test_subgraph();
        // Close the loop: ReLU feeds back into Add.
        subgraph.edges.push(GraphEdge {
            from_node: 1,
            to_node: 0,
            tensor_name: "back_edge".to_string(),
            tensor_shape: Some(vec![32, 32]),
            tensor_dtype: DataType::F32,
        });
        let err = compiler.topological_sort(&subgraph).unwrap_err();
        assert!(err.to_string().contains("Cycle"));
        assert!(compiler.compile(&subgraph).is_err());
    }

    #[test]
    fn test_memory_planning() -> Result<()> {
        let compiler = KernelCompiler::new();
        let subgraph = create_test_subgraph();
        let memory_plan = compiler.create_basic_memory_plan(&subgraph)?;
        assert_eq!(memory_plan.tensor_count, 2);
        assert!(memory_plan.total_memory > 0);
        for tensor in &memory_plan.tensor_info {
            assert!(!tensor.shape.is_empty());
            assert!(tensor.size_bytes > 0);
        }
        Ok(())
    }

    #[test]
    fn test_fusion_classification() -> Result<()> {
        let compiler = KernelCompiler::new();

        let single_node = vec![bare_node(0, "Add")];
        assert_eq!(
            compiler.classify_fusion_group(&single_node),
            FusionType::None
        );

        let conv_bn_relu = vec![
            bare_node(0, "Conv"),
            bare_node(1, "BatchNormalization"),
            bare_node(2, "ReLU"),
        ];
        assert_eq!(
            compiler.classify_fusion_group(&conv_bn_relu),
            FusionType::ConvBnRelu
        );
        Ok(())
    }

    #[test]
    fn test_custom_config() -> Result<()> {
        let fusion_config = FusionConfig {
            enable_fusion: false,
            max_fusion_depth: 2,
            ..Default::default()
        };
        let memory_config = MemoryConfig {
            enable_optimization: false,
            ..Default::default()
        };
        let compiler = KernelCompiler::with_config(fusion_config, memory_config);
        let subgraph = create_test_subgraph();
        let result = compiler.compile(&subgraph)?;
        assert_eq!(result.fused_ops.len(), subgraph.nodes.len());
        Ok(())
    }
}