use std::time::Duration;
use tensorlogic_ir::{
EinsumGraph, EinsumNode, ExecutionProfile, NodeStats, OptimizationHint, ProfileGuidedOptimizer,
};
fn main() {
println!("=== Profile-Guided Optimization in TensorLogic ===\n");
example_basic_profiling();
example_node_statistics();
example_hot_nodes();
example_memory_intensive();
example_profile_merging();
example_optimization_hints();
example_profile_serialization();
example_complete_workflow();
}
/// Example 1: records three executions of node 0 and prints the aggregated statistics
/// (count, total/average/min/max time, peak memory) kept by `ExecutionProfile`.
fn example_basic_profiling() {
    println!("--- Example 1: Basic Profiling ---");

    let mut profile = ExecutionProfile::new();
    // Same node and memory footprint on every run; only the wall time varies.
    for elapsed_ms in [10, 12, 11] {
        profile.record_node(0, Duration::from_millis(elapsed_ms), 1024);
    }
    println!("Recorded 3 executions of node 0");

    if let Some(node0) = profile.node_stats.get(&0) {
        println!(" Execution count: {}", node0.execution_count);
        println!(" Total time: {:?}", node0.total_time);
        println!(" Average time: {:?}", node0.avg_time());
        println!(" Min time: {:?}", node0.min_time);
        println!(" Max time: {:?}", node0.max_time);
        println!(" Peak memory: {} bytes", node0.peak_memory);
    }
    println!();
}
/// Example 2: feeds three samples directly into a standalone `NodeStats` and prints
/// its derived metrics (average, spread, peak memory, hotness, performance score).
fn example_node_statistics() {
    println!("--- Example 2: Node Statistics ---");

    // (wall time in ms, memory in bytes) for each recorded execution, in order.
    let samples = [(50, 2048), (75, 3072), (60, 2560)];
    let mut node_stats = NodeStats::new();
    for (elapsed_ms, bytes) in samples {
        node_stats.record_execution(Duration::from_millis(elapsed_ms), bytes);
    }

    println!("Node statistics:");
    println!(" Executions: {}", node_stats.execution_count);
    println!(" Average time: {:?}", node_stats.avg_time());
    println!(" Time variance (max-min): {:?}", node_stats.time_variance());
    println!(" Peak memory: {} bytes", node_stats.peak_memory);
    // With 3 recorded executions, the two thresholds show both sides of `is_hot`.
    println!(" Is hot (threshold=10): {}", node_stats.is_hot(10));
    println!(" Is hot (threshold=2): {}", node_stats.is_hot(2));
    println!(" Performance score: {:.2}\n", node_stats.performance_score());
}
/// Example 3: records three nodes with very different frequency/latency profiles and
/// prints the top entries returned by `ExecutionProfile::get_hot_nodes`.
///
/// The simulated workload is kept in a single table that drives both the recorded
/// executions and the printed summary, so the two cannot drift apart.
fn example_hot_nodes() {
    println!("--- Example 3: Hot Node Identification ---");

    // (node id, number of executions, wall time per execution in ms, memory in bytes)
    let workload = [(0, 100, 5, 512), (1, 3, 500, 10240), (2, 20, 50, 2048)];

    let mut profile = ExecutionProfile::new();
    for &(node_id, runs, elapsed_ms, memory) in &workload {
        for _ in 0..runs {
            profile.record_node(node_id, Duration::from_millis(elapsed_ms), memory);
        }
    }

    println!("Execution summary:");
    for &(node_id, runs, elapsed_ms, _) in &workload {
        println!(" Node {}: {} executions @ {}ms", node_id, runs, elapsed_ms);
    }

    let hot_nodes = profile.get_hot_nodes(3);
    println!("\nTop 3 hot nodes (by performance score):");
    for (i, (node_id, score)) in hot_nodes.iter().enumerate() {
        println!(" {}. Node {} (score: {:.2})", i + 1, node_id, score);
    }
    println!();
}
/// Example 4: records nodes with increasing peak memory and lists those reported by
/// `get_memory_intensive_nodes` for a 100 MiB threshold, with peak memory shown in MB.
fn example_memory_intensive() {
    println!("--- Example 4: Memory-Intensive Operations ---");

    // (node id, wall time in ms, peak memory in bytes): 1 KiB, 50 MiB, 200 MiB.
    let allocations = [
        (0, 10, 1024),
        (1, 20, 50 * 1024 * 1024),
        (2, 30, 200 * 1024 * 1024),
    ];
    let mut profile = ExecutionProfile::new();
    for (node_id, elapsed_ms, memory) in allocations {
        profile.record_node(node_id, Duration::from_millis(elapsed_ms), memory);
    }

    let threshold = 100 * 1024 * 1024;
    let memory_nodes = profile.get_memory_intensive_nodes(threshold);

    println!("Memory-intensive nodes (>= 100 MB):");
    let bytes_per_mb = 1024.0 * 1024.0;
    for node_id in &memory_nodes {
        let stats = match profile.node_stats.get(node_id) {
            Some(stats) => stats,
            None => continue,
        };
        let megabytes = stats.peak_memory as f64 / bytes_per_mb;
        println!(" Node {}: {:.2} MB", node_id, megabytes);
    }
    println!();
}
/// Example 5: builds two single-run profiles that overlap on nodes 0 and 1, merges the
/// second into the first with `ExecutionProfile::merge`, and prints the combined stats.
///
/// Merged per-node stats are printed sorted by node id, so the output is deterministic
/// regardless of the iteration order of `node_stats`.
fn example_profile_merging() {
    println!("--- Example 5: Profile Merging ---");

    // First run: nodes 0 and 1.
    let mut profile1 = ExecutionProfile::new();
    profile1.record_node(0, Duration::from_millis(100), 1024);
    profile1.record_node(1, Duration::from_millis(200), 2048);
    profile1.total_executions = 1;

    // Second run: the same two nodes slightly slower, plus a new node 2.
    let mut profile2 = ExecutionProfile::new();
    profile2.record_node(0, Duration::from_millis(110), 1024);
    profile2.record_node(1, Duration::from_millis(210), 2048);
    profile2.record_node(2, Duration::from_millis(50), 512);
    profile2.total_executions = 1;

    println!(
        "Profile 1: {} nodes, {} total executions",
        profile1.node_stats.len(),
        profile1.total_executions
    );
    println!(
        "Profile 2: {} nodes, {} total executions",
        profile2.node_stats.len(),
        profile2.total_executions
    );

    profile1.merge(&profile2);

    println!("\nMerged profile:");
    println!(" Unique nodes: {}", profile1.node_stats.len());
    println!(" Total executions: {}", profile1.total_executions);

    // `node_stats` is a keyed map whose iteration order is not guaranteed to be sorted
    // (it may be a hash map), so sort by node id for reproducible output.
    let mut merged_nodes: Vec<_> = profile1.node_stats.iter().collect();
    merged_nodes.sort_unstable_by(|(a, _), (b, _)| a.cmp(b));
    for (node_id, stats) in merged_nodes {
        println!(
            " Node {}: {} executions, avg {:?}",
            node_id,
            stats.execution_count,
            stats.avg_time()
        );
    }
    println!();
}
/// Example 6: builds a profile with frequently co-executed nodes, one large allocation,
/// and a heavily accessed tensor, then prints the hints `ProfileGuidedOptimizer::analyze`
/// produces for a fresh `EinsumGraph::new()`.
fn example_optimization_hints() {
    println!("--- Example 6: Optimization Hints Generation ---");

    let mut profile = ExecutionProfile::new();
    // Nodes 0 and 1 execute back-to-back, alternating, 50 times each.
    for _ in 0..50 {
        for node_id in [0, 1] {
            profile.record_node(node_id, Duration::from_millis(10), 1024);
        }
    }
    // A single execution with a 200 MiB footprint.
    profile.record_node(2, Duration::from_millis(100), 200 * 1024 * 1024);
    // Tensor 0 is read repeatedly.
    for _ in 0..100 {
        profile.record_tensor_access(0, 4096);
    }

    let hot_threshold = 10;
    let memory_threshold = 100 * 1024 * 1024;
    let optimizer = ProfileGuidedOptimizer::new(profile)
        .with_hot_threshold(hot_threshold)
        .with_memory_threshold(memory_threshold);

    let target_graph = EinsumGraph::new();
    let hints = optimizer.analyze(&target_graph);

    println!("Generated {} optimization hints:", hints.len());
    for (position, hint) in (1..).zip(hints.iter()) {
        println!(" {}. {:?}", position, hint);
    }
    println!();
}
/// Example 7: round-trips a small profile through `ExecutionProfile::to_json` /
/// `ExecutionProfile::from_json`, printing a truncated preview of the JSON and the
/// restored node/tensor stat counts. Errors from either direction are printed, not propagated.
fn example_profile_serialization() {
    println!("--- Example 7: Profile Serialization ---");

    // Maximum number of characters of the JSON dump shown before truncating.
    const PREVIEW_CHARS: usize = 300;

    let mut profile = ExecutionProfile::new();
    profile.record_node(0, Duration::from_millis(50), 1024);
    profile.record_node(1, Duration::from_millis(75), 2048);
    profile.record_tensor_access(0, 4096);

    match profile.to_json() {
        Ok(json) => {
            println!("Exported profile to JSON:");
            // Cut on a char boundary: a raw byte slice like `&json[..300]` panics if
            // byte 300 lands inside a multi-byte UTF-8 character.
            let preview_end = json
                .char_indices()
                .nth(PREVIEW_CHARS)
                .map_or(json.len(), |(byte_idx, _)| byte_idx);
            println!("{}", &json[..preview_end]);
            if preview_end < json.len() {
                println!("... (truncated)");
            }

            match ExecutionProfile::from_json(&json) {
                Ok(restored) => {
                    println!("\n✓ Successfully restored profile from JSON");
                    println!(" Node stats: {}", restored.node_stats.len());
                    println!(" Tensor stats: {}", restored.tensor_stats.len());
                }
                Err(e) => println!("\n✗ Failed to restore: {}", e),
            }
        }
        Err(e) => println!("✗ Failed to export: {}", e),
    }
    println!();
}
/// Example 8: end-to-end PGO flow on a one-node matmul graph — build the graph,
/// simulate 10 profiled runs, inspect hot nodes, generate hints, and apply them to a
/// clone of the graph. Errors from `apply_hints` are printed rather than propagated.
///
/// # Panics
/// Panics if the freshly built graph rejects its own node or output, which would
/// indicate a bug in graph construction rather than a runtime condition.
fn example_complete_workflow() {
    println!("--- Example 8: Complete PGO Workflow ---");

    // Step 1: C = A · B expressed as an einsum.
    let mut graph = EinsumGraph::new();
    let a = graph.add_tensor("A");
    let b = graph.add_tensor("B");
    let c = graph.add_tensor("C");
    graph
        .add_node(EinsumNode::einsum("ij,jk->ik", vec![a, b], vec![c]))
        .expect("einsum node over freshly added tensors A, B, C should be accepted");
    graph
        .add_output(c)
        .expect("tensor C was just added to the graph and should be a valid output");
    println!("Step 1: Created computation graph");
    println!(" Tensors: {}", graph.tensor_count());
    println!(" Nodes: {}", graph.node_count());

    // Step 2: synthetic timings (50/60/70 ms cycle) and memory (1 or 2 KiB alternating).
    let mut profile = ExecutionProfile::new();
    println!("\nStep 2: Simulating execution...");
    for run in 0..10 {
        let time = Duration::from_millis(50 + (run % 3) * 10);
        let memory = 1024 * (1 + run % 2);
        // NOTE(review): assumes the single einsum node above was assigned id 0 — confirm.
        profile.record_node(0, time, memory);
        profile.record_tensor_access(a, 1024);
        profile.record_tensor_access(b, 2048);
        profile.record_tensor_access(c, 1024);
        profile.total_executions += 1;
    }
    println!(" Completed {} executions", profile.total_executions);

    // Step 3: report the hottest nodes; look stats up with `get` so a missing entry
    // is skipped instead of panicking on an index.
    println!("\nStep 3: Analyzing profile...");
    let hot_nodes = profile.get_hot_nodes(5);
    for (node_id, score) in &hot_nodes {
        if let Some(stats) = profile.node_stats.get(node_id) {
            println!(
                " Node {}: {} execs, avg {:?}, score: {:.2}",
                node_id,
                stats.execution_count,
                stats.avg_time(),
                score
            );
        }
    }

    // Step 4: derive hints; variants without a dedicated message fall back to Debug.
    println!("\nStep 4: Generating optimization hints...");
    let optimizer = ProfileGuidedOptimizer::new(profile);
    let hints = optimizer.analyze(&graph);
    println!(" Generated {} hints:", hints.len());
    for hint in &hints {
        match hint {
            OptimizationHint::FuseNodes(nodes) => {
                println!(" - Fuse nodes: {:?}", nodes);
            }
            OptimizationHint::CacheTensor(tid) => {
                println!(" - Cache tensor: {}", tid);
            }
            OptimizationHint::PreAllocate { tensor, size } => {
                println!(" - Pre-allocate tensor {} ({} bytes)", tensor, size);
            }
            _ => println!(" - {:?}", hint),
        }
    }

    // Step 5: apply to a clone so the original graph stays untouched for comparison.
    println!("\nStep 5: Applying optimizations...");
    let mut optimized_graph = graph.clone();
    match optimizer.apply_hints(&mut optimized_graph, &hints) {
        Ok(applied) => println!(" Applied {} optimization(s)", applied),
        Err(e) => println!(" Error applying hints: {}", e),
    }
    println!("\n✓ Complete PGO workflow finished");
}