use anyhow::Result;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet, VecDeque};
use std::fmt;
use uuid::Uuid;
/// A directed computation graph (expected to be acyclic): nodes are
/// operations and `edges[node]` lists that node's dependencies (inputs).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ComputationGraph {
    /// Unique identifier for this graph instance.
    pub id: Uuid,
    /// Operation nodes keyed by node id.
    pub nodes: HashMap<String, GraphNode>,
    /// Reverse adjacency: node id -> ids of the nodes it depends on.
    pub edges: HashMap<String, Vec<String>>,
    /// Nodes with no dependencies (graph inputs).
    pub root_nodes: HashSet<String>,
    /// Nodes that no other node depends on (graph outputs).
    pub leaf_nodes: HashSet<String>,
    /// Aggregate counts and cost estimates for the whole graph.
    pub metadata: GraphMetadata,
}
/// Summary statistics captured when a graph is built.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GraphMetadata {
    /// Human-readable graph name.
    pub name: String,
    /// Total number of nodes.
    pub node_count: usize,
    /// Total number of dependency edges.
    pub edge_count: usize,
    /// Depth of the deepest node (roots have depth 0).
    pub max_depth: usize,
    /// Sum of per-node memory estimates, in bytes.
    pub estimated_memory_usage: u64,
    /// Sum of per-node FLOP estimates.
    pub estimated_flops: u64,
    /// Creation timestamp (UTC).
    pub created_at: chrono::DateTime<chrono::Utc>,
}
/// A single operation in the computation graph.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GraphNode {
    /// Unique node id (also used as the key in `ComputationGraph::nodes`).
    pub id: String,
    /// Display name (currently the same as `id` when built via `create_graph`).
    pub name: String,
    pub operation_type: OperationType,
    /// Input tensor shapes; empty when unknown (e.g. graphs from `create_graph`).
    pub input_shapes: Vec<Vec<usize>>,
    /// Output tensor shapes; empty when unknown.
    pub output_shapes: Vec<Vec<usize>>,
    /// Estimated floating-point operation count for this node.
    pub flop_count: u64,
    /// Estimated memory footprint in bytes.
    pub memory_usage: u64,
    /// Measured runtime in microseconds, when profiling data is available.
    pub execution_time_us: Option<u64>,
    /// Learnable parameter count for parameterized operations, else None.
    pub parameter_count: Option<u64>,
    /// Position in a topological ordering; None until computed.
    pub topo_order: Option<usize>,
    /// Dependency depth from the roots (roots are 0).
    pub depth: usize,
    /// Free-form key/value annotations.
    pub metadata: HashMap<String, String>,
}
/// Kinds of operations a graph node can represent.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum OperationType {
    // Elementwise arithmetic
    Add,
    Subtract,
    Multiply,
    Divide,
    // Linear algebra
    MatMul,
    Dot,
    // Activations
    ReLU,
    Sigmoid,
    Tanh,
    GELU,
    Softmax,
    // Normalization
    LayerNorm,
    BatchNorm,
    RMSNorm,
    // Convolutions
    Conv1D,
    Conv2D,
    Conv3D,
    ConvTranspose,
    // Pooling
    MaxPool,
    AvgPool,
    AdaptivePool,
    // Shape manipulation / data movement
    Reshape,
    Transpose,
    Concat,
    Split,
    Slice,
    Gather,
    Scatter,
    // Reductions
    Sum,
    Mean,
    Max,
    Min,
    // Attention variants
    Attention,
    MultiHeadAttention,
    SelfAttention,
    CrossAttention,
    // Embeddings
    Embedding,
    PositionalEmbedding,
    // Loss functions
    CrossEntropyLoss,
    MSELoss,
    L1Loss,
    // Control flow
    If,
    While,
    Loop,
    /// Escape hatch for operations not covered by the named variants.
    Custom(String),
}
/// Toggles and thresholds controlling which analyses the analyzer runs.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GraphAnalysisConfig {
    pub enable_memory_analysis: bool,
    pub enable_flop_analysis: bool,
    pub enable_optimization_analysis: bool,
    pub enable_bottleneck_detection: bool,
    pub enable_dataflow_analysis: bool,
    /// Measured runtime (microseconds) above which a node is flagged as a bottleneck.
    pub bottleneck_threshold_us: u64,
    /// Memory footprint (bytes) above which layout optimization is suggested.
    pub large_memory_threshold: u64,
}
impl Default for GraphAnalysisConfig {
fn default() -> Self {
Self {
enable_memory_analysis: true,
enable_flop_analysis: true,
enable_optimization_analysis: true,
enable_bottleneck_detection: true,
enable_dataflow_analysis: true,
bottleneck_threshold_us: 1000, large_memory_threshold: 1024 * 1024 * 100, }
}
}
/// Stores computation graphs and runs the configured analyses over them.
#[derive(Debug)]
pub struct ComputationGraphAnalyzer {
    /// Which analyses to run and their thresholds.
    config: GraphAnalysisConfig,
    /// Registered graphs, keyed by graph id.
    graphs: HashMap<Uuid, ComputationGraph>,
    /// Cached results of `analyze_graph`, keyed by graph id.
    analysis_results: HashMap<Uuid, GraphAnalysisResult>,
}
/// Aggregated output of `analyze_graph`; the optional sections mirror the
/// corresponding `GraphAnalysisConfig` toggles (None when disabled).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GraphAnalysisResult {
    pub graph_id: Uuid,
    pub memory_analysis: Option<MemoryAnalysis>,
    pub flop_analysis: Option<FlopAnalysis>,
    pub optimization_opportunities: Vec<OptimizationOpportunity>,
    pub bottleneck_analysis: Option<BottleneckAnalysis>,
    pub dataflow_analysis: Option<DataFlowAnalysis>,
    /// One representative node id per depth level, roots first.
    pub critical_path: Vec<String>,
    /// Always computed regardless of config toggles.
    pub statistics: GraphStatistics,
    /// Human-readable suggestions derived from the sections above.
    pub recommendations: Vec<String>,
}
/// Memory-usage breakdown produced by `analyze_memory_usage`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryAnalysis {
    /// Sum of all per-node memory estimates, in bytes.
    pub total_memory_usage: u64,
    /// Peak usage; currently approximated by the total (no liveness simulation).
    pub peak_memory_usage: u64,
    /// Bytes attributed to each operation type.
    pub memory_by_operation: HashMap<OperationType, u64>,
    /// Top memory-consuming nodes as (node id, bytes), largest first.
    pub memory_hotspots: Vec<(String, u64)>,
    /// Currently a fixed placeholder estimate, not measured.
    pub fragmentation_ratio: f64,
    pub optimization_suggestions: Vec<String>,
}
/// Compute-cost breakdown produced by `analyze_flop_usage`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FlopAnalysis {
    /// Sum of all per-node FLOP estimates.
    pub total_flops: u64,
    /// FLOPs attributed to each operation type.
    pub flops_by_operation: HashMap<OperationType, u64>,
    /// Top compute-heavy nodes as (node id, FLOPs), largest first.
    pub compute_hotspots: Vec<(String, u64)>,
    /// Total FLOPs divided by total estimated memory bytes.
    pub arithmetic_intensity: f64,
    pub complexity_analysis: ComplexityAnalysis,
}
/// Coarse complexity summary; the string fields are currently fixed
/// placeholders emitted by `analyze_flop_usage`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ComplexityAnalysis {
    pub time_complexity: String,
    pub space_complexity: String,
    /// Heuristic fraction in [0, 1]; currently a fixed placeholder.
    pub parallelization_potential: f64,
    /// Length of the longest sequential dependency chain (graph max depth).
    pub sequential_dependencies: usize,
}
/// A concrete optimization the analyzer recommends applying to the graph.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OptimizationOpportunity {
    pub optimization_type: OptimizationType,
    pub description: String,
    /// Node ids this optimization would touch.
    pub affected_nodes: Vec<String>,
    pub estimated_improvement: EstimatedImprovement,
    /// Relative effort score; small integers used (e.g. 2-3 in detectors).
    pub implementation_difficulty: u8,
    pub priority: OptimizationPriority,
}
/// Category of a detected optimization opportunity.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum OptimizationType {
    /// Combine adjacent ops into one kernel (e.g. MatMul + Add -> GEMM).
    OperationFusion,
    RedundancyElimination,
    MemoryLayoutOptimization,
    AlgorithmicOptimization,
    Parallelization,
    DataAccessOptimization,
    PrecisionOptimization,
    Memoization,
    ControlFlowOptimization,
}
/// Priority of an optimization opportunity. The derived `Ord` follows
/// declaration order, so `Low < Medium < High < Critical`.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub enum OptimizationPriority {
    Low,
    Medium,
    High,
    Critical,
}
/// Heuristic estimate of what an optimization would gain.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EstimatedImprovement {
    /// Multiplicative speedup (1.0 = no change).
    pub speedup_factor: f64,
    /// Expected memory savings in bytes.
    pub memory_reduction: u64,
    /// Fractional energy savings; heuristic.
    pub energy_savings: f64,
}
/// Output of `analyze_bottlenecks`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BottleneckAnalysis {
    /// Nodes whose measured runtime exceeds the configured threshold.
    pub bottleneck_nodes: Vec<String>,
    /// Node ids along the critical path, roots first.
    pub critical_path_nodes: Vec<String>,
    /// Sum of measured runtimes along the critical path (unmeasured nodes skipped).
    pub critical_path_time_us: u64,
    /// Nodes whose operation type is considered data-parallel friendly.
    pub parallelizable_nodes: Vec<String>,
    pub scheduling_suggestions: Vec<String>,
}
/// Output of `analyze_dataflow`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DataFlowAnalysis {
    /// Node id -> ids of its dependencies (copy of the graph edges).
    pub data_dependencies: HashMap<String, Vec<String>>,
    /// Node id -> set of values live at that node (currently its inputs).
    pub live_variables: HashMap<String, HashSet<String>>,
    /// Producing node id -> lifetime of the value it produces.
    pub variable_lifetimes: HashMap<String, VariableLifetime>,
    /// Currently a fixed placeholder list, not computed from the graph.
    pub memory_reuse_opportunities: Vec<MemoryReuseOpportunity>,
}
/// Lifetime of a value from its producing node to its last consumer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VariableLifetime {
    /// Node that produces the value.
    pub birth_node: String,
    /// Last node observed consuming the value.
    pub death_node: String,
    /// All nodes that consume the value.
    pub usage_nodes: Vec<String>,
    /// Estimated footprint of the value in bytes.
    pub memory_footprint: u64,
}
/// A set of values whose buffers could share storage.
/// NOTE(review): currently only produced as a fixed placeholder by `analyze_dataflow`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryReuseOpportunity {
    pub reusable_variables: Vec<String>,
    /// Estimated bytes saved by sharing.
    pub memory_savings: u64,
    /// Relative implementation effort score.
    pub complexity: u8,
}
/// Structural statistics computed by `calculate_statistics`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GraphStatistics {
    /// Node count per operation type.
    pub nodes_by_type: HashMap<OperationType, usize>,
    /// Mean number of dependencies per node.
    pub average_fan_in: f64,
    /// Mean number of dependents per node (equals fan-in graph-wide).
    pub average_fan_out: f64,
    /// Approximated by the graph's max depth.
    pub diameter: usize,
    /// Currently always 0.0 (not computed).
    pub clustering_coefficient: f64,
    /// Approximated as the node count (exact for an acyclic graph).
    pub strongly_connected_components: usize,
}
impl ComputationGraphAnalyzer {
/// Build an analyzer with the given configuration and no graphs loaded.
pub fn new(config: GraphAnalysisConfig) -> Self {
    let graphs = HashMap::new();
    let analysis_results = HashMap::new();
    Self { config, graphs, analysis_results }
}
/// Register an externally constructed graph, keyed by its own id.
/// Replaces any previously registered graph with the same id.
pub fn add_graph(&mut self, graph: ComputationGraph) -> Result<()> {
    self.graphs.insert(graph.id, graph);
    Ok(())
}
/// Build and register a graph from `(node_id, operation, dependency_ids)` triples,
/// returning the new graph's id.
///
/// Shapes are unknown at this point, so per-node FLOP/memory figures fall back
/// to the estimators' defaults.
///
/// # Errors
/// Propagates failures from depth/topological-order computation.
pub fn create_graph(
    &mut self,
    name: String,
    operations: Vec<(String, OperationType, Vec<String>)>,
) -> Result<Uuid> {
    let graph_id = Uuid::new_v4();
    let mut nodes = HashMap::new();
    let mut edges = HashMap::new();
    let mut root_nodes = HashSet::new();
    let mut leaf_nodes = HashSet::new();
    for (node_id, op_type, dependencies) in &operations {
        // Shapes are unknown here, so the estimators receive an empty slice
        // and return their per-operation defaults.
        let node = GraphNode {
            id: node_id.clone(),
            name: node_id.clone(),
            operation_type: op_type.clone(),
            input_shapes: vec![],
            output_shapes: vec![],
            flop_count: self.estimate_flops(op_type, &[]),
            memory_usage: self.estimate_memory(op_type, &[]),
            execution_time_us: None,
            parameter_count: self.estimate_parameters(op_type),
            topo_order: None,
            depth: 0,
            metadata: HashMap::new(),
        };
        nodes.insert(node_id.clone(), node);
        // A node with no dependencies is a graph input.
        if dependencies.is_empty() {
            root_nodes.insert(node_id.clone());
        }
        edges.insert(node_id.clone(), dependencies.clone());
    }
    // A leaf is a node that no other node lists as a dependency.
    let all_dependencies: HashSet<String> = edges.values().flatten().cloned().collect();
    for node_id in nodes.keys() {
        if !all_dependencies.contains(node_id) {
            leaf_nodes.insert(node_id.clone());
        }
    }
    // Fills in `depth` and `topo_order` on every node.
    self.calculate_depth_and_topo_order(&mut nodes, &edges)?;
    let metadata = GraphMetadata {
        name,
        node_count: nodes.len(),
        edge_count: edges.values().map(|deps| deps.len()).sum(),
        max_depth: nodes.values().map(|n| n.depth).max().unwrap_or(0),
        estimated_memory_usage: nodes.values().map(|n| n.memory_usage).sum(),
        estimated_flops: nodes.values().map(|n| n.flop_count).sum(),
        created_at: chrono::Utc::now(),
    };
    let graph = ComputationGraph {
        id: graph_id,
        nodes,
        edges,
        root_nodes,
        leaf_nodes,
        metadata,
    };
    self.graphs.insert(graph_id, graph);
    Ok(graph_id)
}
/// Run all enabled analyses on the graph, cache the result, and return it.
///
/// # Errors
/// Returns an error if `graph_id` is unknown or any sub-analysis fails.
pub fn analyze_graph(&mut self, graph_id: Uuid) -> Result<GraphAnalysisResult> {
    let graph = self
        .graphs
        .get(&graph_id)
        .ok_or_else(|| anyhow::anyhow!("Graph not found: {}", graph_id))?;
    // Statistics are always computed; the optional sections start empty and
    // are filled in below according to the config toggles.
    let mut result = GraphAnalysisResult {
        graph_id,
        memory_analysis: None,
        flop_analysis: None,
        optimization_opportunities: Vec::new(),
        bottleneck_analysis: None,
        dataflow_analysis: None,
        critical_path: Vec::new(),
        statistics: self.calculate_statistics(graph)?,
        recommendations: Vec::new(),
    };
    if self.config.enable_memory_analysis {
        result.memory_analysis = Some(self.analyze_memory_usage(graph)?);
    }
    if self.config.enable_flop_analysis {
        result.flop_analysis = Some(self.analyze_flop_usage(graph)?);
    }
    if self.config.enable_optimization_analysis {
        result.optimization_opportunities = self.detect_optimization_opportunities(graph)?;
    }
    if self.config.enable_bottleneck_detection {
        result.bottleneck_analysis = Some(self.analyze_bottlenecks(graph)?);
    }
    if self.config.enable_dataflow_analysis {
        result.dataflow_analysis = Some(self.analyze_dataflow(graph)?);
    }
    result.critical_path = self.find_critical_path(graph)?;
    // Recommendations are derived from the sections computed above.
    result.recommendations = self.generate_recommendations(&result)?;
    // Cache for later retrieval via `get_analysis_result`.
    self.analysis_results.insert(graph_id, result.clone());
    Ok(result)
}
/// Return the cached result of a prior `analyze_graph` call, if any.
pub fn get_analysis_result(&self, graph_id: Uuid) -> Option<&GraphAnalysisResult> {
    self.analysis_results.get(&graph_id)
}
/// Render the graph in Graphviz DOT format, one box per node labeled with its
/// name, operation type, and estimated GFLOP/MB figures; edges point from a
/// dependency to its dependent.
///
/// # Errors
/// Returns an error if `graph_id` is unknown.
pub fn export_to_dot(&self, graph_id: Uuid) -> Result<String> {
    let graph = self
        .graphs
        .get(&graph_id)
        .ok_or_else(|| anyhow::anyhow!("Graph not found: {}", graph_id))?;
    let mut dot = String::new();
    dot.push_str(&format!("digraph \"{}\" {{\n", graph.metadata.name));
    dot.push_str(" rankdir=TB;\n");
    dot.push_str(" node [shape=box, style=filled];\n\n");
    for node in graph.nodes.values() {
        let color = self.get_node_color(&node.operation_type);
        // `{:?}` formats the operation type directly; the previous code wrapped
        // it in a redundant inner `format!` call.
        let label = format!(
            "{}\\n{:?}\\n{:.1} GFLOP\\n{:.1} MB",
            node.name,
            node.operation_type,
            node.flop_count as f64 / 1e9,
            node.memory_usage as f64 / (1024.0 * 1024.0)
        );
        dot.push_str(&format!(
            " \"{}\" [label=\"{}\", fillcolor=\"{}\"];\n",
            node.id, label, color
        ));
    }
    dot.push('\n');
    // `edges[node]` lists dependencies, so draw dependency -> dependent.
    for (node_id, dependencies) in &graph.edges {
        for dep in dependencies {
            dot.push_str(&format!(" \"{}\" -> \"{}\";\n", dep, node_id));
        }
    }
    dot.push_str("}\n");
    Ok(dot)
}
fn calculate_depth_and_topo_order(
&self,
nodes: &mut HashMap<String, GraphNode>,
edges: &HashMap<String, Vec<String>>,
) -> Result<()> {
let mut in_degree: HashMap<String, usize> = HashMap::new();
let mut adj_list: HashMap<String, Vec<String>> = HashMap::new();
for node_id in nodes.keys() {
in_degree.insert(node_id.clone(), 0);
adj_list.insert(node_id.clone(), Vec::new());
}
for (node_id, dependencies) in edges {
in_degree.insert(node_id.clone(), dependencies.len());
for dep in dependencies {
if let Some(adj) = adj_list.get_mut(dep) {
adj.push(node_id.clone());
}
}
}
let mut queue = VecDeque::new();
let mut topo_order = 0;
for (node_id, °ree) in &in_degree {
if degree == 0 {
queue.push_back((node_id.clone(), 0)); }
}
while let Some((node_id, depth)) = queue.pop_front() {
if let Some(node) = nodes.get_mut(&node_id) {
node.depth = depth;
node.topo_order = Some(topo_order);
topo_order += 1;
}
if let Some(neighbors) = adj_list.get(&node_id) {
for neighbor in neighbors {
if let Some(degree) = in_degree.get_mut(neighbor) {
*degree -= 1;
if *degree == 0 {
queue.push_back((neighbor.clone(), depth + 1));
}
}
}
}
}
Ok(())
}
/// Rough FLOP estimate for one operation given its input shapes; falls back
/// to fixed per-operation defaults when shapes are unavailable.
fn estimate_flops(&self, op_type: &OperationType, shapes: &[Vec<usize>]) -> u64 {
    // Element count of the first input shape, if one is present.
    let first_elems = || shapes.first().map(|s| s.iter().product::<usize>() as u64);
    match op_type {
        OperationType::MatMul => {
            // 2*m*k*n multiply-adds when both operand shapes are known.
            if let (Some(a_shape), Some(b_shape)) = (shapes.first(), shapes.get(1)) {
                if a_shape.len() >= 2 && b_shape.len() >= 2 {
                    let m = a_shape[a_shape.len() - 2];
                    let k = a_shape[a_shape.len() - 1];
                    let n = b_shape[b_shape.len() - 1];
                    return (2 * m * k * n) as u64;
                }
            }
            1000000
        },
        // One FLOP per element for elementwise arithmetic and activations.
        OperationType::Add
        | OperationType::Subtract
        | OperationType::Multiply
        | OperationType::ReLU
        | OperationType::Sigmoid
        | OperationType::Tanh => first_elems().unwrap_or(1000),
        // Normalization costs roughly five passes over the data.
        OperationType::LayerNorm | OperationType::BatchNorm => {
            shapes.first().map(|s| (s.iter().product::<usize>() * 5) as u64).unwrap_or(5000)
        },
        _ => 1000,
    }
}
/// Rough memory estimate in bytes for one operation, assuming 4-byte (f32)
/// elements; fixed 1 KiB fallbacks when shapes are unavailable.
fn estimate_memory(&self, op_type: &OperationType, shapes: &[Vec<usize>]) -> u64 {
    const ELEMENT_SIZE: u64 = 4;
    let bytes_of = |s: &Vec<usize>| s.iter().product::<usize>() as u64 * ELEMENT_SIZE;
    match op_type {
        // MatMul is charged for every operand, with a 1 KiB floor.
        OperationType::MatMul => shapes.iter().map(bytes_of).sum::<u64>().max(1024),
        // Everything else is charged for its first input only.
        _ => shapes.first().map(bytes_of).unwrap_or(1024),
    }
}
/// Fixed, coarse parameter counts for parameterized operations; None for
/// parameter-free operation types.
fn estimate_parameters(&self, op_type: &OperationType) -> Option<u64> {
    let count = match op_type {
        OperationType::MatMul => 1000000,
        OperationType::Conv2D => 500000,
        OperationType::Embedding => 2000000,
        OperationType::LayerNorm => 1000,
        _ => return None,
    };
    Some(count)
}
/// Aggregate per-node memory estimates into totals, a per-operation
/// breakdown, and a top-10 hotspot list.
fn analyze_memory_usage(&self, graph: &ComputationGraph) -> Result<MemoryAnalysis> {
    let total_memory_usage: u64 = graph.nodes.values().map(|n| n.memory_usage).sum();
    let mut memory_by_operation: HashMap<OperationType, u64> = HashMap::new();
    for node in graph.nodes.values() {
        let entry = memory_by_operation.entry(node.operation_type.clone()).or_insert(0);
        *entry += node.memory_usage;
    }
    // Ten heaviest nodes, largest first (stable sort preserves tie order).
    let mut memory_hotspots: Vec<(String, u64)> = graph
        .nodes
        .values()
        .map(|n| (n.id.clone(), n.memory_usage))
        .collect();
    memory_hotspots.sort_by(|a, b| b.1.cmp(&a.1));
    memory_hotspots.truncate(10);
    Ok(MemoryAnalysis {
        total_memory_usage,
        // Peak is approximated by the total (no liveness-based simulation yet).
        peak_memory_usage: total_memory_usage,
        memory_by_operation,
        memory_hotspots,
        // Fixed placeholder estimate.
        fragmentation_ratio: 0.1,
        optimization_suggestions: vec![
            "Consider memory pooling for frequently allocated tensors".to_string(),
            "Implement in-place operations where possible".to_string(),
            "Use gradient checkpointing for memory-intensive layers".to_string(),
        ],
    })
}
/// Aggregate per-node FLOP estimates into totals, a per-operation breakdown,
/// a top-10 hotspot list, and an overall arithmetic-intensity figure.
fn analyze_flop_usage(&self, graph: &ComputationGraph) -> Result<FlopAnalysis> {
    let mut total_flops: u64 = 0;
    let mut flops_by_operation: HashMap<OperationType, u64> = HashMap::new();
    for node in graph.nodes.values() {
        total_flops += node.flop_count;
        *flops_by_operation.entry(node.operation_type.clone()).or_insert(0) += node.flop_count;
    }
    // Ten most compute-heavy nodes, largest first (stable sort keeps tie order).
    let mut compute_hotspots: Vec<(String, u64)> = graph
        .nodes
        .values()
        .map(|n| (n.id.clone(), n.flop_count))
        .collect();
    compute_hotspots.sort_by(|a, b| b.1.cmp(&a.1));
    compute_hotspots.truncate(10);
    // FLOPs per byte of estimated memory traffic; 0 when no memory is used.
    let total_memory: u64 = graph.nodes.values().map(|n| n.memory_usage).sum();
    let arithmetic_intensity = match total_memory {
        0 => 0.0,
        bytes => total_flops as f64 / bytes as f64,
    };
    Ok(FlopAnalysis {
        total_flops,
        flops_by_operation,
        compute_hotspots,
        arithmetic_intensity,
        // Placeholder complexity summary; sequential depth comes from metadata.
        complexity_analysis: ComplexityAnalysis {
            time_complexity: "O(n)".to_string(),
            space_complexity: "O(n)".to_string(),
            parallelization_potential: 0.7,
            sequential_dependencies: graph.metadata.max_depth,
        },
    })
}
/// Run every opportunity detector and concatenate the results in order:
/// fusion, then redundancy, then memory-layout candidates.
fn detect_optimization_opportunities(
    &self,
    graph: &ComputationGraph,
) -> Result<Vec<OptimizationOpportunity>> {
    let mut all = self.detect_fusion_opportunities(graph)?;
    all.append(&mut self.detect_redundancy_opportunities(graph)?);
    all.append(&mut self.detect_memory_optimizations(graph)?);
    Ok(all)
}
/// Find Add nodes that are fed by a MatMul: each such pair can be fused
/// into a single GEMM kernel.
fn detect_fusion_opportunities(
    &self,
    graph: &ComputationGraph,
) -> Result<Vec<OptimizationOpportunity>> {
    let mut opportunities = Vec::new();
    for node in graph.nodes.values() {
        if node.operation_type != OperationType::Add {
            continue;
        }
        // Walk the node's dependencies; missing edge entries mean no deps.
        for dep in graph.edges.get(&node.id).into_iter().flatten() {
            let feeds_from_matmul = graph
                .nodes
                .get(dep)
                .map(|d| d.operation_type == OperationType::MatMul)
                .unwrap_or(false);
            if feeds_from_matmul {
                opportunities.push(OptimizationOpportunity {
                    optimization_type: OptimizationType::OperationFusion,
                    description:
                        "Fuse MatMul and Add operations into a single GEMM operation"
                            .to_string(),
                    affected_nodes: vec![dep.clone(), node.id.clone()],
                    estimated_improvement: EstimatedImprovement {
                        speedup_factor: 1.2,
                        memory_reduction: 1024 * 1024,
                        energy_savings: 0.1,
                    },
                    implementation_difficulty: 2,
                    priority: OptimizationPriority::Medium,
                });
            }
        }
    }
    Ok(opportunities)
}
/// Detect redundant (duplicate) computations. Not implemented yet: a
/// common-subexpression scan would go here; currently reports nothing.
fn detect_redundancy_opportunities(
    &self,
    _graph: &ComputationGraph,
) -> Result<Vec<OptimizationOpportunity>> {
    Ok(Vec::new())
}
/// Flag every node whose memory footprint exceeds the configured
/// `large_memory_threshold` as a layout-optimization candidate.
fn detect_memory_optimizations(
    &self,
    graph: &ComputationGraph,
) -> Result<Vec<OptimizationOpportunity>> {
    let opportunities = graph
        .nodes
        .values()
        .filter(|node| node.memory_usage > self.config.large_memory_threshold)
        .map(|node| OptimizationOpportunity {
            optimization_type: OptimizationType::MemoryLayoutOptimization,
            description: format!(
                "Optimize memory layout for large operation: {}",
                node.name
            ),
            affected_nodes: vec![node.id.clone()],
            estimated_improvement: EstimatedImprovement {
                speedup_factor: 1.1,
                // Heuristic: roughly a quarter of the footprint is recoverable.
                memory_reduction: node.memory_usage / 4,
                energy_savings: 0.05,
            },
            implementation_difficulty: 3,
            priority: OptimizationPriority::Medium,
        })
        .collect();
    Ok(opportunities)
}
/// Identify slow nodes (by measured runtime), data-parallel-friendly nodes,
/// and the time spent along the critical path.
fn analyze_bottlenecks(&self, graph: &ComputationGraph) -> Result<BottleneckAnalysis> {
    // Nodes whose measured runtime exceeds the configured threshold;
    // unprofiled nodes (no measurement) are never flagged.
    let bottleneck_nodes: Vec<String> = graph
        .nodes
        .values()
        .filter(|n| {
            n.execution_time_us
                .map_or(false, |t| t > self.config.bottleneck_threshold_us)
        })
        .map(|n| n.id.clone())
        .collect();
    // Operation kinds considered amenable to data parallelism.
    let parallelizable_nodes: Vec<String> = graph
        .nodes
        .values()
        .filter(|n| {
            matches!(
                n.operation_type,
                OperationType::MatMul | OperationType::Conv2D | OperationType::Add
            )
        })
        .map(|n| n.id.clone())
        .collect();
    let critical_path_nodes = self.find_critical_path(graph)?;
    // Sum measured runtimes along the path, skipping unprofiled nodes.
    let critical_path_time_us: u64 = critical_path_nodes
        .iter()
        .filter_map(|id| graph.nodes.get(id))
        .filter_map(|node| node.execution_time_us)
        .sum();
    Ok(BottleneckAnalysis {
        bottleneck_nodes,
        critical_path_nodes,
        critical_path_time_us,
        parallelizable_nodes,
        scheduling_suggestions: vec![
            "Consider parallel execution of independent operations".to_string(),
            "Use asynchronous execution for I/O operations".to_string(),
            "Implement pipeline parallelism for sequential operations".to_string(),
        ],
    })
}
fn analyze_dataflow(&self, graph: &ComputationGraph) -> Result<DataFlowAnalysis> {
let mut data_dependencies = HashMap::new();
let mut live_variables = HashMap::new();
let mut variable_lifetimes = HashMap::new();
for (node_id, dependencies) in &graph.edges {
data_dependencies.insert(node_id.clone(), dependencies.clone());
live_variables.insert(node_id.clone(), dependencies.iter().cloned().collect());
for dep in dependencies {
if !variable_lifetimes.contains_key(dep) {
variable_lifetimes.insert(
dep.clone(),
VariableLifetime {
birth_node: dep.clone(),
death_node: node_id.clone(),
usage_nodes: vec![node_id.clone()],
memory_footprint: graph
.nodes
.get(dep)
.map(|n| n.memory_usage)
.unwrap_or(0),
},
);
} else {
let lifetime = variable_lifetimes
.get_mut(dep)
.expect("variable lifetime should exist for previously seen dependency");
lifetime.death_node = node_id.clone();
lifetime.usage_nodes.push(node_id.clone());
}
}
}
let memory_reuse_opportunities = vec![MemoryReuseOpportunity {
reusable_variables: vec!["var1".to_string(), "var2".to_string()],
memory_savings: 1024 * 1024, complexity: 2,
}];
Ok(DataFlowAnalysis {
data_dependencies,
live_variables,
variable_lifetimes,
memory_reuse_opportunities,
})
}
/// Build a representative critical path: one node per depth level, returned
/// in execution order (roots first).
///
/// Fixes a bug in the previous version, which decremented the depth counter
/// twice per iteration (once inside the search loop and once after it),
/// skipping every other depth level and never including depth 0.
fn find_critical_path(&self, graph: &ComputationGraph) -> Result<Vec<String>> {
    let mut path = Vec::with_capacity(graph.metadata.max_depth + 1);
    // Walk from the deepest level back to the roots, picking one node per level.
    for depth in (0..=graph.metadata.max_depth).rev() {
        // NOTE: ties within a level are broken by HashMap iteration order,
        // so the path is representative rather than unique.
        if let Some(node) = graph.nodes.values().find(|n| n.depth == depth) {
            path.push(node.id.clone());
        }
    }
    // An empty graph yields an empty path (max_depth defaults to 0, no match).
    path.reverse();
    Ok(path)
}
/// Compute structural statistics for the graph.
///
/// Adds a guard for the empty graph: the previous version divided by
/// `graph.nodes.len()`, producing NaN averages when no nodes exist.
fn calculate_statistics(&self, graph: &ComputationGraph) -> Result<GraphStatistics> {
    let mut nodes_by_type: HashMap<OperationType, usize> = HashMap::new();
    for node in graph.nodes.values() {
        *nodes_by_type.entry(node.operation_type.clone()).or_insert(0) += 1;
    }
    // Every edge contributes one unit of fan-in (at its target) and one unit
    // of fan-out (at its source), so the graph-wide totals coincide.
    let total_edges: usize = graph.edges.values().map(|deps| deps.len()).sum();
    let node_count = graph.nodes.len();
    let average_fan = if node_count == 0 {
        0.0
    } else {
        total_edges as f64 / node_count as f64
    };
    Ok(GraphStatistics {
        nodes_by_type,
        average_fan_in: average_fan,
        average_fan_out: average_fan,
        // Approximation: use the max dependency depth as the diameter.
        diameter: graph.metadata.max_depth,
        // Not computed yet.
        clustering_coefficient: 0.0,
        // Each node is its own SCC — exact for an acyclic graph.
        strongly_connected_components: node_count,
    })
}
/// Derive human-readable recommendations from the completed analysis
/// sections; disabled sections simply contribute nothing.
fn generate_recommendations(&self, analysis: &GraphAnalysisResult) -> Result<Vec<String>> {
    let mut recommendations: Vec<String> = Vec::new();
    let mut add = |text: &str| recommendations.push(text.to_string());
    if let Some(memory) = analysis.memory_analysis.as_ref() {
        // More than 1 GiB of estimated usage.
        if memory.total_memory_usage > 1024 * 1024 * 1024 {
            add("Consider using gradient checkpointing to reduce memory usage");
        }
        if memory.fragmentation_ratio > 0.2 {
            add("Implement memory pooling to reduce fragmentation");
        }
    }
    if let Some(flops) = analysis.flop_analysis.as_ref() {
        // Below one FLOP per byte the workload is memory-bound.
        if flops.arithmetic_intensity < 1.0 {
            add("Consider kernel fusion to improve arithmetic intensity");
        }
        if flops.complexity_analysis.parallelization_potential > 0.5 {
            add("Explore parallelization opportunities for compute-intensive operations");
        }
    }
    if analysis.optimization_opportunities.len() > 3 {
        add("Multiple optimization opportunities detected - prioritize by estimated impact");
    }
    if let Some(bottlenecks) = analysis.bottleneck_analysis.as_ref() {
        if !bottlenecks.bottleneck_nodes.is_empty() {
            add("Address bottleneck operations through optimization or parallelization");
        }
    }
    Ok(recommendations)
}
/// Fill color used for a node in the DOT export, grouped by operation family;
/// uncategorized operations render light gray.
fn get_node_color(&self, op_type: &OperationType) -> &'static str {
    match op_type {
        // Linear algebra
        OperationType::MatMul | OperationType::Dot => "lightblue",
        // Elementwise arithmetic
        OperationType::Add
        | OperationType::Subtract
        | OperationType::Multiply
        | OperationType::Divide => "lightgreen",
        // Activations
        OperationType::ReLU
        | OperationType::Sigmoid
        | OperationType::Tanh
        | OperationType::GELU => "orange",
        // Normalization
        OperationType::LayerNorm | OperationType::BatchNorm | OperationType::RMSNorm => {
            "yellow"
        },
        // Convolutions
        OperationType::Conv1D | OperationType::Conv2D | OperationType::Conv3D => "lightcoral",
        // Attention
        OperationType::Attention | OperationType::MultiHeadAttention => "purple",
        // Embeddings
        OperationType::Embedding | OperationType::PositionalEmbedding => "pink",
        _ => "lightgray",
    }
}
}
impl fmt::Display for OperationType {
    /// Human-readable form: `Custom` shows its payload unquoted; every other
    /// variant reuses its `Debug` name.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        if let OperationType::Custom(name) = self {
            write!(f, "Custom({})", name)
        } else {
            write!(f, "{:?}", self)
        }
    }
}
impl Default for ComputationGraphAnalyzer {
    /// Analyzer using the default config (all analyses enabled).
    fn default() -> Self {
        Self::new(GraphAnalysisConfig::default())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    // Builds a 5-node linear chain (input -> linear1 -> relu1 -> linear2 -> output)
    // and checks statistics plus a non-empty critical path.
    #[test]
    fn test_computation_graph_creation() {
        let mut analyzer = ComputationGraphAnalyzer::default();
        let operations = vec![
            (
                "input".to_string(),
                OperationType::Custom("Input".to_string()),
                vec![],
            ),
            (
                "linear1".to_string(),
                OperationType::MatMul,
                vec!["input".to_string()],
            ),
            (
                "relu1".to_string(),
                OperationType::ReLU,
                vec!["linear1".to_string()],
            ),
            (
                "linear2".to_string(),
                OperationType::MatMul,
                vec!["relu1".to_string()],
            ),
            (
                "output".to_string(),
                OperationType::Custom("Output".to_string()),
                vec!["linear2".to_string()],
            ),
        ];
        let graph_id = analyzer
            .create_graph("test_model".to_string(), operations)
            .expect("operation failed in test");
        let analysis = analyzer.analyze_graph(graph_id).expect("operation failed in test");
        // 4 distinct types: Custom("Input"), MatMul, ReLU, Custom("Output").
        assert_eq!(analysis.statistics.nodes_by_type.len(), 4);
        assert!(!analysis.critical_path.is_empty());
    }
    // A MatMul feeding an Add must be reported as a fusion opportunity.
    #[test]
    fn test_optimization_detection() {
        let mut analyzer = ComputationGraphAnalyzer::default();
        let operations = vec![
            (
                "input".to_string(),
                OperationType::Custom("Input".to_string()),
                vec![],
            ),
            (
                "matmul".to_string(),
                OperationType::MatMul,
                vec!["input".to_string()],
            ),
            (
                "add".to_string(),
                OperationType::Add,
                vec!["matmul".to_string()],
            ),
        ];
        let graph_id = analyzer
            .create_graph("fusion_test".to_string(), operations)
            .expect("operation failed in test");
        let analysis = analyzer.analyze_graph(graph_id).expect("operation failed in test");
        assert!(analysis
            .optimization_opportunities
            .iter()
            .any(|op| op.optimization_type == OptimizationType::OperationFusion));
    }
    // The DOT export must contain the digraph header and node operation labels.
    #[test]
    fn test_dot_export() {
        let mut analyzer = ComputationGraphAnalyzer::default();
        let operations = vec![
            ("a".to_string(), OperationType::MatMul, vec![]),
            ("b".to_string(), OperationType::ReLU, vec!["a".to_string()]),
        ];
        let graph_id = analyzer
            .create_graph("simple".to_string(), operations)
            .expect("operation failed in test");
        let dot = analyzer.export_to_dot(graph_id).expect("operation failed in test");
        assert!(dot.contains("digraph"));
        assert!(dot.contains("MatMul"));
        assert!(dot.contains("ReLU"));
    }
}