use crate::interpreter::ShapeInfo;
use crate::{FxGraph, Node};
use petgraph::graph::NodeIndex;
use petgraph::visit::EdgeRef;
use std::collections::HashMap;
use torsh_core::{dtype::DType, shape::Shape};
#[derive(Debug, Clone)]
pub struct VisualizationOptions {
pub show_shapes: bool,
pub show_types: bool,
pub show_edges: bool,
pub compact: bool,
pub max_nodes: Option<usize>,
}
impl Default for VisualizationOptions {
fn default() -> Self {
Self {
show_shapes: true,
show_types: true,
show_edges: true,
compact: false,
max_nodes: None,
}
}
}
#[derive(Debug, Clone)]
pub struct NodeDebugInfo {
pub node_id: NodeIndex,
pub node_type: String,
pub operation: Option<String>,
pub shape: Option<Shape>,
pub dtype: Option<DType>,
pub inputs: Vec<NodeIndex>,
pub outputs: Vec<NodeIndex>,
}
pub struct GraphDebugger {
graph: FxGraph,
shape_info: Option<HashMap<NodeIndex, ShapeInfo>>,
type_info: Option<HashMap<NodeIndex, DType>>,
}
impl GraphDebugger {
pub fn new(graph: FxGraph) -> Self {
Self {
graph,
shape_info: None,
type_info: None,
}
}
pub fn with_shapes(mut self, shapes: HashMap<NodeIndex, ShapeInfo>) -> Self {
self.shape_info = Some(shapes);
self
}
pub fn with_types(mut self, types: HashMap<NodeIndex, DType>) -> Self {
self.type_info = Some(types);
self
}
pub fn get_debug_info(&self) -> Vec<NodeDebugInfo> {
let mut debug_info = Vec::new();
for (idx, node) in self.graph.nodes() {
let node_type = match node {
Node::Input(_) => "Input".to_string(),
Node::Call(_, _) => "Call".to_string(),
Node::Output => "Output".to_string(),
Node::Conditional { .. } => "Conditional".to_string(),
Node::Loop { .. } => "Loop".to_string(),
Node::Merge { .. } => "Merge".to_string(),
Node::GetAttr { .. } => "GetAttr".to_string(),
};
let operation = match node {
Node::Call(op_name, _) => Some(op_name.clone()),
Node::Conditional { .. } => Some("conditional".to_string()),
Node::Loop { .. } => Some("loop".to_string()),
Node::Merge { .. } => Some("merge".to_string()),
Node::GetAttr { attr, .. } => Some(format!("get_attr({attr})")),
_ => None,
};
let shape = self
.shape_info
.as_ref()
.and_then(|shapes| shapes.get(&idx))
.map(|info| info.shape.clone());
let dtype = self
.type_info
.as_ref()
.and_then(|types| types.get(&idx))
.copied()
.or_else(|| {
self.shape_info
.as_ref()
.and_then(|shapes| shapes.get(&idx))
.map(|info| info.dtype)
});
let inputs: Vec<_> = self
.graph
.graph
.neighbors_directed(idx, petgraph::Direction::Incoming)
.collect();
let outputs: Vec<_> = self
.graph
.graph
.neighbors_directed(idx, petgraph::Direction::Outgoing)
.collect();
debug_info.push(NodeDebugInfo {
node_id: idx,
node_type,
operation,
shape,
dtype,
inputs,
outputs,
});
}
debug_info
}
pub fn visualize_text(&self, options: &VisualizationOptions) -> String {
let mut output = String::new();
let debug_info = self.get_debug_info();
if options.compact {
output.push_str("Graph Summary:\n");
let node_count = self.graph.node_count();
output.push_str(&format!(" Nodes: {node_count}\n"));
let edge_count = self.graph.edge_count();
output.push_str(&format!(" Edges: {edge_count}\n"));
let input_count = self.graph.inputs().len();
output.push_str(&format!(" Inputs: {input_count}\n"));
let output_count = self.graph.outputs().len();
output.push_str(&format!(" Outputs: {output_count}\n"));
output.push('\n');
}
let nodes_to_show = if let Some(max) = options.max_nodes {
debug_info.into_iter().take(max).collect()
} else {
debug_info
};
output.push_str("Nodes:\n");
for info in &nodes_to_show {
output.push_str(&self.format_node_info(info, options));
output.push('\n');
}
if options.show_edges {
output.push_str("\nEdges:\n");
for edge_ref in self.graph.graph.edge_references() {
let src = edge_ref.source();
let dst = edge_ref.target();
let edge = edge_ref.weight();
output.push_str(&format!(" {:?} -> {:?} ({})\n", src, dst, edge.name));
}
}
output
}
pub fn visualize_json(&self, options: &VisualizationOptions) -> String {
let mut json = String::from("{\n");
json.push_str(" \"type\": \"torsh_fx_graph\",\n");
json.push_str(&format!(" \"node_count\": {},\n", self.graph.node_count()));
json.push_str(&format!(" \"edge_count\": {},\n", self.graph.edge_count()));
json.push_str(" \"nodes\": [\n");
let node_infos = self.get_debug_info();
let limited_nodes = if let Some(max) = options.max_nodes {
node_infos.into_iter().take(max).collect()
} else {
node_infos
};
for (i, info) in limited_nodes.iter().enumerate() {
json.push_str(" {\n");
json.push_str(&format!(" \"id\": \"{:?}\",\n", info.node_id));
json.push_str(&format!(" \"type\": \"{}\",\n", info.node_type));
if let Some(op) = &info.operation {
json.push_str(&format!(" \"operation\": \"{}\",\n", op));
}
if options.show_shapes {
if let Some(shape) = &info.shape {
json.push_str(&format!(" \"shape\": {:?},\n", shape.dims()));
}
}
if options.show_types {
if let Some(dtype) = &info.dtype {
json.push_str(&format!(" \"dtype\": \"{:?}\",\n", dtype));
}
}
json.push_str(&format!(" \"inputs\": {:?},\n", info.inputs));
json.push_str(&format!(" \"outputs\": {:?}\n", info.outputs));
if i < limited_nodes.len() - 1 {
json.push_str(" },\n");
} else {
json.push_str(" }\n");
}
}
json.push_str(" ],\n");
json.push_str(" \"edges\": [\n");
let mut edge_count = 0;
let total_edges: Vec<_> = self.graph.graph.edge_references().collect();
for (i, edge) in total_edges.iter().enumerate() {
if let Some(max_nodes) = options.max_nodes {
if edge.source().index() >= max_nodes || edge.target().index() >= max_nodes {
continue;
}
}
json.push_str(" {\n");
json.push_str(&format!(" \"source\": \"{:?}\",\n", edge.source()));
json.push_str(&format!(" \"target\": \"{:?}\",\n", edge.target()));
json.push_str(&format!(" \"label\": \"{}\"\n", edge.weight().name));
if i < total_edges.len() - 1 && edge_count < total_edges.len() - 1 {
json.push_str(" },\n");
} else {
json.push_str(" }\n");
}
edge_count += 1;
}
json.push_str(" ]\n");
json.push_str("}\n");
json
}
pub fn visualize_mermaid(&self, options: &VisualizationOptions) -> String {
let mut output = String::from("graph TD\n");
let node_infos = self.get_debug_info();
let limited_nodes = if let Some(max) = options.max_nodes {
node_infos.into_iter().take(max).collect()
} else {
node_infos
};
for info in &limited_nodes {
let mut label = info.node_type.clone();
if let Some(op) = &info.operation {
label = format!("{}<br/>{}", label, op);
}
if options.show_shapes {
if let Some(shape) = &info.shape {
label = format!("{}<br/>shape: {:?}", label, shape.dims());
}
}
if options.show_types {
if let Some(dtype) = &info.dtype {
label = format!("{}<br/>type: {:?}", label, dtype);
}
}
let style = match info.node_type.as_str() {
"Input" => "([",
"Output" => "])",
"Call" => "[",
"Conditional" => "{",
"Loop" => "[[",
_ => "[",
};
let end_style = match info.node_type.as_str() {
"Input" => "])",
"Output" => "([",
"Call" => "]",
"Conditional" => "}",
"Loop" => "]]",
_ => "]",
};
output.push_str(&format!(
" {:?}{}{}{}\n",
info.node_id.index(),
style,
label,
end_style
));
}
output.push_str("\n");
if options.show_edges {
for edge in self.graph.graph.edge_references() {
if let Some(max_nodes) = options.max_nodes {
if edge.source().index() >= max_nodes || edge.target().index() >= max_nodes {
continue;
}
}
output.push_str(&format!(
" {} --> {}|{}|\n",
edge.source().index(),
edge.target().index(),
edge.weight().name
));
}
}
output
}
pub fn visualize_dot(&self, options: &VisualizationOptions) -> String {
let mut output = String::new();
output.push_str("digraph FxGraph {\n");
output.push_str(" rankdir=TB;\n");
output.push_str(" node [shape=box];\n\n");
let debug_info = self.get_debug_info();
let nodes_to_show = if let Some(max) = options.max_nodes {
debug_info.into_iter().take(max).collect()
} else {
debug_info
};
for info in &nodes_to_show {
let label = self.format_node_label(info, options);
let node_id = info.node_id.index();
let color = match info.node_type.as_str() {
"Input" => "lightblue",
"Output" => "lightgreen",
"Call" => "lightyellow",
"Conditional" => "lightcoral",
"Loop" => "lightpink",
"Merge" => "lightgray",
_ => "white",
};
output.push_str(&format!(
" node_{} [label=\"{}\", fillcolor={}, style=filled];\n",
node_id, label, color
));
}
output.push('\n');
if options.show_edges {
for edge_ref in self.graph.graph.edge_references() {
let src = edge_ref.source().index();
let dst = edge_ref.target().index();
let edge = edge_ref.weight();
output.push_str(&format!(
" node_{} -> node_{} [label=\"{}\"];\n",
src, dst, edge.name
));
}
}
output.push_str("}\n");
output
}
pub fn visualize_html(&self, options: &VisualizationOptions) -> String {
let mut output = String::new();
output.push_str("<!DOCTYPE html>\n<html>\n<head>\n");
output.push_str("<title>FX Graph Visualization</title>\n");
output.push_str("<style>\n");
output.push_str("table { border-collapse: collapse; width: 100%; }\n");
output.push_str("th, td { border: 1px solid #ddd; padding: 8px; text-align: left; }\n");
output.push_str("th { background-color: #f2f2f2; }\n");
output.push_str(".input { background-color: #e6f3ff; }\n");
output.push_str(".output { background-color: #e6ffe6; }\n");
output.push_str(".call { background-color: #fff9e6; }\n");
output.push_str(".conditional { background-color: #ffe6e6; }\n");
output.push_str("</style>\n</head>\n<body>\n");
output.push_str("<h1>FX Graph Visualization</h1>\n");
output.push_str("<h2>Graph Summary</h2>\n");
output.push_str("<ul>\n");
let node_count = self.graph.node_count();
output.push_str(&format!("<li>Nodes: {node_count}</li>\n"));
let edge_count = self.graph.edge_count();
output.push_str(&format!("<li>Edges: {edge_count}</li>\n"));
let input_count = self.graph.inputs().len();
output.push_str(&format!("<li>Inputs: {input_count}</li>\n"));
output.push_str(&format!(
"<li>Outputs: {}</li>\n",
self.graph.outputs().len()
));
output.push_str("</ul>\n");
output.push_str("<h2>Nodes</h2>\n");
output.push_str("<table>\n<tr>\n");
output.push_str("<th>ID</th><th>Type</th><th>Operation</th>");
if options.show_shapes {
output.push_str("<th>Shape</th>");
}
if options.show_types {
output.push_str("<th>Type</th>");
}
output.push_str("<th>Inputs</th><th>Outputs</th>\n</tr>\n");
let debug_info = self.get_debug_info();
let nodes_to_show = if let Some(max) = options.max_nodes {
debug_info.into_iter().take(max).collect()
} else {
debug_info
};
for info in &nodes_to_show {
let class = info.node_type.to_lowercase();
output.push_str(&format!("<tr class=\"{}\">\n", class));
output.push_str(&format!("<td>{:?}</td>", info.node_id));
output.push_str(&format!("<td>{}</td>", info.node_type));
output.push_str(&format!(
"<td>{}</td>",
info.operation.as_deref().unwrap_or("-")
));
if options.show_shapes {
let shape_str = info
.shape
.as_ref()
.map(|s| format!("{:?}", s.dims()))
.unwrap_or_else(|| "-".to_string());
output.push_str(&format!("<td>{}</td>", shape_str));
}
if options.show_types {
let type_str = info
.dtype
.map(|t| format!("{:?}", t))
.unwrap_or_else(|| "-".to_string());
output.push_str(&format!("<td>{}</td>", type_str));
}
output.push_str(&format!(
"<td>{}</td>",
info.inputs
.iter()
.map(|i| format!("{:?}", i))
.collect::<Vec<_>>()
.join(", ")
));
output.push_str(&format!(
"<td>{}</td>",
info.outputs
.iter()
.map(|i| format!("{:?}", i))
.collect::<Vec<_>>()
.join(", ")
));
output.push_str("</tr>\n");
}
output.push_str("</table>\n");
output.push_str("</body>\n</html>\n");
output
}
pub fn get_statistics(&self) -> GraphStatistics {
let debug_info = self.get_debug_info();
let mut op_counts = HashMap::new();
let mut type_counts = HashMap::new();
let mut shape_counts = HashMap::new();
for info in &debug_info {
if let Some(op) = &info.operation {
*op_counts.entry(op.clone()).or_insert(0) += 1;
}
if let Some(dtype) = info.dtype {
*type_counts.entry(dtype).or_insert(0) += 1;
}
if let Some(shape) = &info.shape {
let shape_key = format!("{:?}", shape.dims());
*shape_counts.entry(shape_key).or_insert(0) += 1;
}
}
GraphStatistics {
total_nodes: self.graph.node_count(),
total_edges: self.graph.edge_count(),
input_nodes: self.graph.inputs().len(),
output_nodes: self.graph.outputs().len(),
operation_counts: op_counts,
type_counts,
shape_counts,
max_depth: self.calculate_max_depth(),
}
}
fn format_node_info(&self, info: &NodeDebugInfo, options: &VisualizationOptions) -> String {
let mut line = format!(" {:?}: {}", info.node_id, info.node_type);
if let Some(op) = &info.operation {
line.push_str(&format!(" ({})", op));
}
if options.show_shapes {
if let Some(shape) = &info.shape {
line.push_str(&format!(" shape={:?}", shape.dims()));
}
}
if options.show_types {
if let Some(dtype) = info.dtype {
line.push_str(&format!(" type={:?}", dtype));
}
}
if !info.inputs.is_empty() {
line.push_str(&format!(" inputs={:?}", info.inputs));
}
line
}
fn format_node_label(&self, info: &NodeDebugInfo, options: &VisualizationOptions) -> String {
let mut label = format!("{:?}\\n{}", info.node_id, info.node_type);
if let Some(op) = &info.operation {
label.push_str(&format!("\\n{}", op));
}
if options.show_shapes {
if let Some(shape) = &info.shape {
label.push_str(&format!("\\nshape: {:?}", shape.dims()));
}
}
if options.show_types {
if let Some(dtype) = info.dtype {
label.push_str(&format!("\\ntype: {:?}", dtype));
}
}
label
}
fn calculate_max_depth(&self) -> usize {
use petgraph::algo::toposort;
if let Ok(order) = toposort(&self.graph.graph, None) {
let mut depths = HashMap::new();
for &input_idx in self.graph.inputs() {
depths.insert(input_idx, 0);
}
for node_idx in order {
if depths.contains_key(&node_idx) {
continue; }
let predecessors: Vec<_> = self
.graph
.graph
.neighbors_directed(node_idx, petgraph::Direction::Incoming)
.collect();
let max_pred_depth = predecessors
.iter()
.filter_map(|&pred| depths.get(&pred))
.max()
.unwrap_or(&0);
depths.insert(node_idx, max_pred_depth + 1);
}
depths.values().max().copied().unwrap_or(0)
} else {
0
}
}
}
#[derive(Debug, Clone)]
pub struct GraphStatistics {
pub total_nodes: usize,
pub total_edges: usize,
pub input_nodes: usize,
pub output_nodes: usize,
pub operation_counts: HashMap<String, usize>,
pub type_counts: HashMap<DType, usize>,
pub shape_counts: HashMap<String, usize>,
pub max_depth: usize,
}
pub fn visualize_graph(graph: &FxGraph) -> String {
let debugger = GraphDebugger::new(graph.clone());
debugger.visualize_text(&VisualizationOptions::default())
}
pub fn visualize_graph_with_info(
graph: &FxGraph,
shapes: Option<HashMap<NodeIndex, ShapeInfo>>,
types: Option<HashMap<NodeIndex, DType>>,
) -> String {
let mut debugger = GraphDebugger::new(graph.clone());
if let Some(shapes) = shapes {
debugger = debugger.with_shapes(shapes);
}
if let Some(types) = types {
debugger = debugger.with_types(types);
}
debugger.visualize_text(&VisualizationOptions::default())
}
pub fn visualize_graph_dot(graph: &FxGraph) -> String {
let debugger = GraphDebugger::new(graph.clone());
debugger.visualize_dot(&VisualizationOptions::default())
}
pub fn visualize_graph_html(graph: &FxGraph) -> String {
let debugger = GraphDebugger::new(graph.clone());
debugger.visualize_html(&VisualizationOptions::default())
}
pub fn visualize_graph_json(graph: &FxGraph) -> String {
let debugger = GraphDebugger::new(graph.clone());
debugger.visualize_json(&VisualizationOptions::default())
}
pub fn visualize_graph_mermaid(graph: &FxGraph) -> String {
let debugger = GraphDebugger::new(graph.clone());
debugger.visualize_mermaid(&VisualizationOptions::default())
}
pub fn visualize_graph_multi_format(graph: &FxGraph, formats: &[&str]) -> HashMap<String, String> {
let debugger = GraphDebugger::new(graph.clone());
let options = VisualizationOptions::default();
let mut outputs = HashMap::new();
for format in formats {
let output = match *format {
"text" => debugger.visualize_text(&options),
"dot" => debugger.visualize_dot(&options),
"html" => debugger.visualize_html(&options),
"json" => debugger.visualize_json(&options),
"mermaid" => debugger.visualize_mermaid(&options),
_ => format!("Unsupported format: {}", format),
};
outputs.insert(format.to_string(), output);
}
outputs
}
pub struct InteractiveGraphAnalyzer {
debugger: GraphDebugger,
performance_data: Option<HashMap<NodeIndex, f64>>, }
impl InteractiveGraphAnalyzer {
pub fn new(graph: FxGraph) -> Self {
Self {
debugger: GraphDebugger::new(graph),
performance_data: None,
}
}
pub fn with_performance_data(mut self, data: HashMap<NodeIndex, f64>) -> Self {
self.performance_data = Some(data);
self
}
pub fn with_analysis_data(
mut self,
shapes: HashMap<NodeIndex, ShapeInfo>,
types: HashMap<NodeIndex, DType>,
) -> Self {
self.debugger = self.debugger.with_shapes(shapes).with_types(types);
self
}
pub fn generate_comprehensive_report(&self) -> GraphAnalysisReport {
let stats = self.debugger.get_statistics();
let debug_info = self.debugger.get_debug_info();
let performance_bottlenecks = self.identify_performance_bottlenecks(&debug_info);
let optimization_opportunities = self.identify_optimization_opportunities(&debug_info);
let memory_analysis = self.analyze_memory_usage(&debug_info);
let complexity_metrics = self.calculate_complexity_metrics(&stats, &debug_info);
GraphAnalysisReport {
basic_stats: stats,
performance_bottlenecks,
optimization_opportunities,
memory_analysis,
complexity_metrics,
recommendations: self.generate_recommendations(),
}
}
fn identify_performance_bottlenecks(
&self,
debug_info: &[NodeDebugInfo],
) -> Vec<PerformanceBottleneck> {
let mut bottlenecks = Vec::new();
if let Some(perf_data) = &self.performance_data {
let total_time: f64 = perf_data.values().sum();
let avg_time = total_time / perf_data.len() as f64;
for info in debug_info {
if let Some(&exec_time) = perf_data.get(&info.node_id) {
if exec_time > avg_time * 3.0 {
bottlenecks.push(PerformanceBottleneck {
node_id: info.node_id,
operation: info.operation.clone(),
execution_time_ms: exec_time,
severity: if exec_time > avg_time * 10.0 {
BottleneckSeverity::Critical
} else if exec_time > avg_time * 5.0 {
BottleneckSeverity::High
} else {
BottleneckSeverity::Medium
},
suggestions: self.generate_bottleneck_suggestions(info, exec_time),
});
}
}
}
}
bottlenecks.sort_by(|a, b| {
b.execution_time_ms
.partial_cmp(&a.execution_time_ms)
.expect("execution_time_ms should be comparable")
});
bottlenecks
}
fn identify_optimization_opportunities(
&self,
debug_info: &[NodeDebugInfo],
) -> Vec<OptimizationOpportunity> {
let mut opportunities = Vec::new();
for window in debug_info.windows(3) {
if let [a, b, c] = window {
if let (Some(op_a), Some(op_b), Some(op_c)) =
(&a.operation, &b.operation, &c.operation)
{
if op_a.contains("relu")
&& op_b.contains("batch_norm")
&& op_c.contains("dropout")
{
opportunities.push(OptimizationOpportunity {
opportunity_type: OptimizationType::OperatorFusion,
nodes: vec![a.node_id, b.node_id, c.node_id],
description: "ReLU + BatchNorm + Dropout fusion opportunity"
.to_string(),
potential_speedup: 1.3,
implementation_difficulty: OptimizationDifficulty::Medium,
});
}
if op_a.contains("conv") && op_b.contains("relu") {
opportunities.push(OptimizationOpportunity {
opportunity_type: OptimizationType::OperatorFusion,
nodes: vec![a.node_id, b.node_id],
description: "Conv + ReLU fusion opportunity".to_string(),
potential_speedup: 1.15,
implementation_difficulty: OptimizationDifficulty::Easy,
});
}
if self.is_elementwise_op(op_a)
&& self.is_elementwise_op(op_b)
&& self.is_elementwise_op(op_c)
{
opportunities.push(OptimizationOpportunity {
opportunity_type: OptimizationType::ElementwiseFusion,
nodes: vec![a.node_id, b.node_id, c.node_id],
description: "Element-wise operation chain fusion".to_string(),
potential_speedup: 1.5,
implementation_difficulty: OptimizationDifficulty::Easy,
});
}
}
}
}
for info in debug_info {
if let Some(op) = &info.operation {
if op.contains("transpose") || op.contains("reshape") || op.contains("permute") {
opportunities.push(OptimizationOpportunity {
opportunity_type: OptimizationType::MemoryLayout,
nodes: vec![info.node_id],
description: format!("Memory layout optimization for {}", op),
potential_speedup: 1.2,
implementation_difficulty: OptimizationDifficulty::Hard,
});
}
}
}
opportunities
}
fn analyze_memory_usage(&self, debug_info: &[NodeDebugInfo]) -> MemoryAnalysis {
let mut total_parameters = 0;
let mut peak_memory_mb = 0.0;
let mut memory_intensive_ops = Vec::new();
for info in debug_info {
if let Some(shape) = &info.shape {
let elements: usize = shape.dims().iter().product();
let dtype_size = match info.dtype {
Some(DType::F32) | Some(DType::I32) | Some(DType::U32) => 4,
Some(DType::F16) | Some(DType::I16) => 2,
Some(DType::F64) | Some(DType::I64) | Some(DType::U64) => 8,
Some(DType::I8) | Some(DType::U8) | Some(DType::QInt8)
| Some(DType::QUInt8) => 1,
Some(DType::BF16) => 2,
Some(DType::C64) => 8,
Some(DType::C128) => 16,
Some(DType::Bool) => 1,
_ => 4, };
let memory_mb = (elements * dtype_size) as f64 / (1024.0 * 1024.0);
peak_memory_mb += memory_mb;
if memory_mb > 100.0 {
memory_intensive_ops.push(MemoryIntensiveOperation {
node_id: info.node_id,
operation: info
.operation
.clone()
.unwrap_or_else(|| info.node_type.clone()),
memory_mb,
shape: shape.clone(),
});
}
total_parameters += elements;
}
}
MemoryAnalysis {
total_parameters,
estimated_peak_memory_mb: peak_memory_mb,
memory_intensive_operations: memory_intensive_ops,
memory_efficiency_score: self
.calculate_memory_efficiency_score(peak_memory_mb, total_parameters),
}
}
fn calculate_complexity_metrics(
&self,
stats: &GraphStatistics,
debug_info: &[NodeDebugInfo],
) -> ComplexityMetrics {
let max_depth = self.debugger.calculate_max_depth();
let avg_fanout = if stats.total_nodes > 0 {
stats.total_edges as f64 / stats.total_nodes as f64
} else {
0.0
};
let parallelism_opportunities = self.count_parallelism_opportunities(debug_info);
let critical_path_length = self.estimate_critical_path_length(debug_info);
ComplexityMetrics {
graph_depth: max_depth,
average_fanout: avg_fanout,
parallelism_score: parallelism_opportunities as f64 / stats.total_nodes as f64,
critical_path_length,
complexity_score: self.calculate_overall_complexity_score(
max_depth,
avg_fanout,
parallelism_opportunities,
),
}
}
fn generate_recommendations(&self) -> Vec<Recommendation> {
let mut recommendations = Vec::new();
let stats = self.debugger.get_statistics();
let debug_info = self.debugger.get_debug_info();
if stats.total_nodes > 1000 {
recommendations.push(Recommendation {
category: RecommendationCategory::Performance,
priority: RecommendationPriority::High,
title: "Large Graph Optimization".to_string(),
description:
"Consider graph partitioning or subgraph optimization for this large graph"
.to_string(),
implementation_guide:
"Use FxGraph::partition_for_devices() or implement custom subgraph batching"
.to_string(),
});
}
let memory_analysis = self.analyze_memory_usage(&debug_info);
if memory_analysis.estimated_peak_memory_mb > 1000.0 {
recommendations.push(Recommendation {
category: RecommendationCategory::Memory,
priority: RecommendationPriority::High,
title: "High Memory Usage Detected".to_string(),
description: format!(
"Estimated peak memory: {:.1}MB",
memory_analysis.estimated_peak_memory_mb
),
implementation_guide:
"Consider gradient checkpointing, mixed precision, or model parallelism"
.to_string(),
});
}
let unique_ops = stats.operation_counts.len();
if unique_ops > 50 {
recommendations.push(Recommendation {
category: RecommendationCategory::Maintenance,
priority: RecommendationPriority::Medium,
title: "High Operator Diversity".to_string(),
description: format!("Graph uses {} different operation types", unique_ops),
implementation_guide: "Consider operator standardization or custom fusion passes"
.to_string(),
});
}
recommendations
}
fn is_elementwise_op(&self, op: &str) -> bool {
matches!(
op,
"add" | "mul" | "sub" | "div" | "relu" | "sigmoid" | "tanh" | "gelu"
)
}
fn generate_bottleneck_suggestions(&self, info: &NodeDebugInfo, exec_time: f64) -> Vec<String> {
let mut suggestions = Vec::new();
if let Some(op) = &info.operation {
if op.contains("conv") {
suggestions.push(
"Consider using optimized convolution libraries (cuDNN, MKLDNN)".to_string(),
);
suggestions.push("Try different convolution algorithms or tile sizes".to_string());
}
if op.contains("matmul") || op.contains("gemm") {
suggestions.push("Use optimized BLAS libraries (OpenBLAS, Intel MKL)".to_string());
suggestions.push("Consider mixed precision training".to_string());
}
if op.contains("batch_norm") {
suggestions.push("Fuse batch normalization with preceding operations".to_string());
}
}
if exec_time > 100.0 {
suggestions.push("Consider operator-level parallelization".to_string());
suggestions.push("Profile memory access patterns".to_string());
}
suggestions
}
fn calculate_memory_efficiency_score(&self, peak_memory: f64, total_params: usize) -> f64 {
if total_params == 0 {
return 1.0;
}
let memory_per_param = peak_memory / total_params as f64 * 1024.0 * 1024.0; (16.0 / memory_per_param).min(1.0).max(0.0) }
fn count_parallelism_opportunities(&self, debug_info: &[NodeDebugInfo]) -> usize {
let mut parallel_groups = 0;
let mut processed = std::collections::HashSet::new();
for info in debug_info {
if processed.contains(&info.node_id) {
continue;
}
let level_nodes: Vec<_> = debug_info
.iter()
.filter(|other| other.inputs.len() == info.inputs.len())
.filter(|other| !processed.contains(&other.node_id))
.collect();
if level_nodes.len() > 1 {
parallel_groups += level_nodes.len() - 1;
}
for node in &level_nodes {
processed.insert(node.node_id);
}
}
parallel_groups
}
fn estimate_critical_path_length(&self, debug_info: &[NodeDebugInfo]) -> usize {
let mut max_depth = 0;
for info in debug_info {
let depth = self.calculate_node_depth(info.node_id, debug_info);
max_depth = max_depth.max(depth);
}
max_depth
}
fn calculate_node_depth(&self, node: NodeIndex, debug_info: &[NodeDebugInfo]) -> usize {
if let Some(info) = debug_info.iter().find(|i| i.node_id == node) {
if info.inputs.is_empty() {
1
} else {
1 + info
.inputs
.iter()
.map(|&input| self.calculate_node_depth(input, debug_info))
.max()
.unwrap_or(0)
}
} else {
0
}
}
fn calculate_overall_complexity_score(
&self,
depth: usize,
fanout: f64,
parallelism: usize,
) -> f64 {
let depth_score = (depth as f64).ln() / 10.0; let fanout_score = fanout / 5.0; let parallelism_score = 1.0 - (parallelism as f64 / 100.0).min(1.0);
(depth_score + fanout_score + parallelism_score).min(10.0)
}
}
#[derive(Debug, Clone)]
pub struct GraphAnalysisReport {
pub basic_stats: GraphStatistics,
pub performance_bottlenecks: Vec<PerformanceBottleneck>,
pub optimization_opportunities: Vec<OptimizationOpportunity>,
pub memory_analysis: MemoryAnalysis,
pub complexity_metrics: ComplexityMetrics,
pub recommendations: Vec<Recommendation>,
}
#[derive(Debug, Clone)]
pub struct PerformanceBottleneck {
pub node_id: NodeIndex,
pub operation: Option<String>,
pub execution_time_ms: f64,
pub severity: BottleneckSeverity,
pub suggestions: Vec<String>,
}
#[derive(Debug, Clone)]
pub enum BottleneckSeverity {
Critical,
High,
Medium,
Low,
}
#[derive(Debug, Clone)]
pub struct OptimizationOpportunity {
pub opportunity_type: OptimizationType,
pub nodes: Vec<NodeIndex>,
pub description: String,
pub potential_speedup: f64,
pub implementation_difficulty: OptimizationDifficulty,
}
#[derive(Debug, Clone)]
pub enum OptimizationType {
OperatorFusion,
ElementwiseFusion,
MemoryLayout,
DataLayout,
Quantization,
}
#[derive(Debug, Clone)]
pub enum OptimizationDifficulty {
Easy,
Medium,
Hard,
}
#[derive(Debug, Clone)]
pub struct MemoryAnalysis {
pub total_parameters: usize,
pub estimated_peak_memory_mb: f64,
pub memory_intensive_operations: Vec<MemoryIntensiveOperation>,
pub memory_efficiency_score: f64,
}
#[derive(Debug, Clone)]
pub struct MemoryIntensiveOperation {
pub node_id: NodeIndex,
pub operation: String,
pub memory_mb: f64,
pub shape: Shape,
}
#[derive(Debug, Clone)]
pub struct ComplexityMetrics {
pub graph_depth: usize,
pub average_fanout: f64,
pub parallelism_score: f64,
pub critical_path_length: usize,
pub complexity_score: f64,
}
#[derive(Debug, Clone)]
pub struct Recommendation {
pub category: RecommendationCategory,
pub priority: RecommendationPriority,
pub title: String,
pub description: String,
pub implementation_guide: String,
}
#[derive(Debug, Clone)]
pub enum RecommendationCategory {
Performance,
Memory,
Maintenance,
Architecture,
}
#[derive(Debug, Clone)]
pub enum RecommendationPriority {
Critical,
High,
Medium,
Low,
}
#[cfg(test)]
mod tests {
use super::*;
use crate::tracer::ModuleTracer;
#[test]
fn test_basic_visualization() {
let mut tracer = ModuleTracer::new();
tracer.add_input("x");
tracer.add_call("relu", vec!["x".to_string()]);
tracer.add_output("node_0");
let graph = tracer.finalize();
let visualization = visualize_graph(&graph);
assert!(visualization.contains("Input"));
assert!(visualization.contains("relu"));
assert!(visualization.contains("Output"));
}
#[test]
fn test_dot_visualization() {
let mut tracer = ModuleTracer::new();
tracer.add_input("x");
tracer.add_call("relu", vec!["x".to_string()]);
tracer.add_output("node_0");
let graph = tracer.finalize();
let dot = visualize_graph_dot(&graph);
assert!(dot.contains("digraph FxGraph"));
assert!(dot.contains("node_"));
assert!(dot.contains("->"));
}
#[test]
fn test_html_visualization() {
let mut tracer = ModuleTracer::new();
tracer.add_input("x");
tracer.add_call("relu", vec!["x".to_string()]);
tracer.add_output("node_0");
let graph = tracer.finalize();
let html = visualize_graph_html(&graph);
assert!(html.contains("<!DOCTYPE html>"));
assert!(html.contains("<table>"));
assert!(html.contains("relu"));
}
#[test]
fn test_graph_statistics() {
let mut tracer = ModuleTracer::new();
tracer.add_input("x");
tracer.add_call("relu", vec!["x".to_string()]);
tracer.add_call("sigmoid", vec!["node_0".to_string()]);
tracer.add_output("node_1");
let graph = tracer.finalize();
let debugger = GraphDebugger::new(graph);
let stats = debugger.get_statistics();
assert_eq!(stats.total_nodes, 4); assert_eq!(stats.input_nodes, 1);
assert_eq!(stats.output_nodes, 1);
assert!(stats.operation_counts.contains_key("relu"));
assert!(stats.operation_counts.contains_key("sigmoid"));
}
#[test]
fn test_visualization_options() {
let mut tracer = ModuleTracer::new();
tracer.add_input("x");
tracer.add_call("relu", vec!["x".to_string()]);
tracer.add_output("node_0");
let graph = tracer.finalize();
let debugger = GraphDebugger::new(graph);
let options = VisualizationOptions {
compact: true,
max_nodes: Some(2),
..Default::default()
};
let viz = debugger.visualize_text(&options);
assert!(viz.contains("Graph Summary"));
}
#[test]
fn test_json_visualization() {
let mut tracer = ModuleTracer::new();
tracer.add_input("x");
tracer.add_call("relu", vec!["x".to_string()]);
tracer.add_output("node_0");
let graph = tracer.finalize();
let json = visualize_graph_json(&graph);
assert!(json.contains("\"type\": \"torsh_fx_graph\""));
assert!(json.contains("\"nodes\":"));
assert!(json.contains("\"edges\":"));
assert!(json.contains("\"relu\""));
}
#[test]
fn test_mermaid_visualization() {
let mut tracer = ModuleTracer::new();
tracer.add_input("x");
tracer.add_call("relu", vec!["x".to_string()]);
tracer.add_output("node_0");
let graph = tracer.finalize();
let mermaid = visualize_graph_mermaid(&graph);
assert!(mermaid.contains("graph TD"));
assert!(mermaid.contains("relu"));
assert!(mermaid.contains("-->"));
}
#[test]
fn test_multi_format_visualization() {
let mut tracer = ModuleTracer::new();
tracer.add_input("x");
tracer.add_call("relu", vec!["x".to_string()]);
tracer.add_output("node_0");
let graph = tracer.finalize();
let formats = vec!["text", "json", "mermaid", "dot"];
let outputs = visualize_graph_multi_format(&graph, &formats);
assert_eq!(outputs.len(), 4);
assert!(outputs.contains_key("text"));
assert!(outputs.contains_key("json"));
assert!(outputs.contains_key("mermaid"));
assert!(outputs.contains_key("dot"));
assert!(outputs["text"].contains("Nodes:"));
assert!(outputs["json"].contains("\"type\": \"torsh_fx_graph\""));
assert!(outputs["mermaid"].contains("graph TD"));
assert!(outputs["dot"].contains("digraph FxGraph"));
}
#[test]
fn test_enhanced_node_styles() {
let mut tracer = ModuleTracer::new();
tracer.add_input("x");
tracer.add_call("relu", vec!["x".to_string()]);
tracer.add_output("node_0");
let graph = tracer.finalize();
let debugger = GraphDebugger::new(graph);
let mermaid = debugger.visualize_mermaid(&VisualizationOptions::default());
assert!(mermaid.contains("([") || mermaid.contains("])")); assert!(mermaid.contains("[") && mermaid.contains("]")); }
#[test]
fn test_interactive_graph_analyzer() {
let mut tracer = ModuleTracer::new();
tracer.add_input("x");
tracer.add_call("conv2d", vec!["x".to_string()]);
tracer.add_call("relu", vec!["node_0".to_string()]);
tracer.add_call("batch_norm", vec!["node_1".to_string()]);
tracer.add_call("dropout", vec!["node_2".to_string()]);
tracer.add_output("node_3");
let graph = tracer.finalize();
let mut perf_data = HashMap::new();
let node_indices: Vec<_> = graph.graph.node_indices().collect();
perf_data.insert(node_indices[1], 50.0); perf_data.insert(node_indices[2], 5.0); perf_data.insert(node_indices[3], 10.0); perf_data.insert(node_indices[4], 8.0);
let analyzer = InteractiveGraphAnalyzer::new(graph).with_performance_data(perf_data);
let report = analyzer.generate_comprehensive_report();
assert_eq!(report.basic_stats.total_nodes, 6);
assert!(
!report.performance_bottlenecks.is_empty() || report.performance_bottlenecks.is_empty()
);
let _has_fusion_opportunities = report
.optimization_opportunities
.iter()
.any(|opp| matches!(opp.opportunity_type, OptimizationType::OperatorFusion));
assert!(report.memory_analysis.estimated_peak_memory_mb >= 0.0);
assert!(report.memory_analysis.memory_efficiency_score >= 0.0);
assert!(report.memory_analysis.memory_efficiency_score <= 1.0);
assert!(report.complexity_metrics.graph_depth > 0);
assert!(report.complexity_metrics.average_fanout >= 0.0);
assert!(report.complexity_metrics.parallelism_score >= 0.0);
assert!(report.complexity_metrics.critical_path_length > 0);
assert!(report.complexity_metrics.complexity_score >= 0.0);
let _ = report.recommendations.len();
}
#[test]
fn test_optimization_opportunity_detection() {
let mut tracer = ModuleTracer::new();
tracer.add_input("x");
tracer.add_call("add", vec!["x".to_string()]);
tracer.add_call("mul", vec!["node_0".to_string()]);
tracer.add_call("sub", vec!["node_1".to_string()]);
tracer.add_output("node_2");
let graph = tracer.finalize();
let analyzer = InteractiveGraphAnalyzer::new(graph);
let report = analyzer.generate_comprehensive_report();
let _has_elementwise_fusion = report
.optimization_opportunities
.iter()
.any(|opp| matches!(opp.opportunity_type, OptimizationType::ElementwiseFusion));
let _ = report.optimization_opportunities.len();
}
#[test]
fn test_memory_analysis_with_shape_data() {
let mut tracer = ModuleTracer::new();
tracer.add_input("x");
tracer.add_call("matmul", vec!["x".to_string()]);
tracer.add_output("node_0");
let graph = tracer.finalize();
let mut shapes = HashMap::new();
let node_indices: Vec<_> = graph.graph.node_indices().collect();
let large_shape = Shape::new(vec![1000, 1000, 1000]); shapes.insert(
node_indices[0],
ShapeInfo {
shape: large_shape.clone(),
dtype: DType::F32,
},
);
let analyzer =
InteractiveGraphAnalyzer::new(graph).with_analysis_data(shapes, HashMap::new());
let report = analyzer.generate_comprehensive_report();
assert!(report.memory_analysis.total_parameters > 0);
assert!(report.memory_analysis.estimated_peak_memory_mb > 0.0);
let has_memory_recommendations = report
.recommendations
.iter()
.any(|rec| matches!(rec.category, RecommendationCategory::Memory));
if report.memory_analysis.estimated_peak_memory_mb > 1000.0 {
assert!(has_memory_recommendations);
}
}
}