use std::collections::HashMap;
use tensorlogic_ir::EinsumGraph;
use crate::dependency_analyzer::DependencyAnalysis;
/// Memory estimate for the output of a single graph node.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct NodeMemoryEstimate {
    /// Index of the node this estimate describes.
    pub node_index: usize,
    /// Estimated size in bytes of the node's output (0 when its shape is unknown).
    pub estimated_bytes: usize,
    /// Shape of the node's output, if one was supplied as a hint.
    pub output_shape: Option<Vec<usize>>,
}

/// Level-by-level execution schedule for an einsum graph, with memory and cost estimates.
///
/// The [`Default`] value is an empty plan: no levels, no estimates, zero peak memory and
/// zero FLOPs.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct EvaluationPlan {
    /// Node indices grouped into execution levels, in execution order.
    pub levels: Vec<Vec<usize>>,
    /// Largest number of estimated output bytes alive at the same time while walking the levels.
    pub peak_memory_bytes: usize,
    /// `freeable_after[l]` lists the nodes whose outputs may be released once level `l` has run.
    pub freeable_after: Vec<Vec<usize>>,
    /// Per-node memory estimates, one entry per node.
    pub node_estimates: Vec<NodeMemoryEstimate>,
    /// Rough cost estimate; a heuristic, not an exact floating-point operation count.
    pub estimated_total_flops: u64,
}
impl EvaluationPlan {
pub fn build(graph: &EinsumGraph, shape_hints: Option<&HashMap<usize, Vec<usize>>>) -> Self {
if graph.nodes.is_empty() {
return Self {
levels: Vec::new(),
peak_memory_bytes: 0,
freeable_after: Vec::new(),
node_estimates: Vec::new(),
estimated_total_flops: 0,
};
}
let analysis = DependencyAnalysis::analyze(graph);
let node_estimates: Vec<NodeMemoryEstimate> = (0..graph.nodes.len())
.map(|node_idx| {
let output_shape = shape_hints.and_then(|h| h.get(&node_idx)).cloned();
let estimated_bytes = output_shape
.as_ref()
.map(|s| s.iter().product::<usize>() * 8)
.unwrap_or(0);
NodeMemoryEstimate {
node_index: node_idx,
estimated_bytes,
output_shape,
}
})
.collect();
let num_levels = analysis.num_levels;
let mut freeable_after: Vec<Vec<usize>> = vec![Vec::new(); num_levels.max(1)];
for op in &analysis.operations {
if op.dependents.is_empty() {
let level = op.execution_level.min(num_levels.saturating_sub(1));
freeable_after[level].push(op.node_index);
} else {
let last_level = op
.dependents
.iter()
.filter_map(|&dep_idx| {
analysis.operations.get(dep_idx).map(|d| d.execution_level)
})
.max()
.unwrap_or(op.execution_level);
let level = last_level.min(num_levels.saturating_sub(1));
freeable_after[level].push(op.node_index);
}
}
let mut live_bytes: usize = 0;
let mut peak_memory_bytes: usize = 0;
let mut live_set: HashMap<usize, usize> = HashMap::new();
for (level_idx, level_nodes) in analysis.execution_levels.iter().enumerate() {
for &node_idx in level_nodes {
let bytes = node_estimates
.get(node_idx)
.map(|e| e.estimated_bytes)
.unwrap_or(0);
live_set.insert(node_idx, bytes);
live_bytes = live_bytes.saturating_add(bytes);
}
if live_bytes > peak_memory_bytes {
peak_memory_bytes = live_bytes;
}
for &freed_node in &freeable_after[level_idx] {
if let Some(bytes) = live_set.remove(&freed_node) {
live_bytes = live_bytes.saturating_sub(bytes);
}
}
}
let total_elements: usize = node_estimates
.iter()
.map(|e| {
e.output_shape
.as_ref()
.map(|s| s.iter().product::<usize>())
.unwrap_or(1)
})
.sum();
let estimated_total_flops = (total_elements as u64).saturating_mul(num_levels as u64 + 1);
Self {
levels: analysis.execution_levels,
peak_memory_bytes,
freeable_after,
node_estimates,
estimated_total_flops,
}
}
pub fn can_fit_in_memory(&self, available_bytes: usize) -> bool {
self.peak_memory_bytes <= available_bytes
}
pub fn summary(&self) -> String {
format!(
"EvaluationPlan {{ levels: {}, nodes: {}, peak_memory: {} bytes, flops: {} }}",
self.levels_count(),
self.total_nodes(),
self.peak_memory_bytes,
self.estimated_total_flops,
)
}
pub fn levels_count(&self) -> usize {
self.levels.len()
}
pub fn total_nodes(&self) -> usize {
self.levels.iter().map(|l| l.len()).sum()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use tensorlogic_ir::{EinsumGraph, EinsumNode, OpType};

    /// Graph `a -> relu -> b`, with `a` as input and `b` as output.
    fn single_node_graph() -> EinsumGraph {
        let mut graph = EinsumGraph::new();
        let input = graph.add_tensor("a");
        let output = graph.add_tensor("b");
        graph.add_input(input).unwrap();
        graph
            .add_node(EinsumNode {
                inputs: vec![input],
                outputs: vec![output],
                op: OpType::ElemUnary {
                    op: "relu".to_string(),
                },
                metadata: None,
            })
            .unwrap();
        graph.add_output(output).unwrap();
        graph
    }

    /// Graph `t0 -> relu -> t1 -> relu -> t2 -> relu -> t3`: three dependent nodes.
    fn three_node_chain_graph() -> EinsumGraph {
        let mut graph = EinsumGraph::new();
        // Array elements are evaluated left to right, so tensors are registered t0..t3 in order.
        let tensors = [
            graph.add_tensor("t0"),
            graph.add_tensor("t1"),
            graph.add_tensor("t2"),
            graph.add_tensor("t3"),
        ];
        graph.add_input(tensors[0]).unwrap();
        // Each consecutive pair of tensors becomes one relu node: (t0,t1), (t1,t2), (t2,t3).
        for link in tensors.windows(2) {
            graph
                .add_node(EinsumNode {
                    inputs: vec![link[0]],
                    outputs: vec![link[1]],
                    op: OpType::ElemUnary {
                        op: "relu".to_string(),
                    },
                    metadata: None,
                })
                .unwrap();
        }
        graph.add_output(tensors[3]).unwrap();
        graph
    }

    #[test]
    fn test_evaluation_plan_empty_graph() {
        let plan = EvaluationPlan::build(&EinsumGraph::new(), None);
        assert_eq!(plan.levels_count(), 0);
        assert_eq!(plan.total_nodes(), 0);
        assert_eq!(plan.peak_memory_bytes, 0);
    }

    #[test]
    fn test_evaluation_plan_single_node() {
        let plan = EvaluationPlan::build(&single_node_graph(), None);
        assert_eq!(plan.total_nodes(), 1);
        assert!(plan.levels_count() >= 1);
    }

    #[test]
    fn test_evaluation_plan_linear_chain() {
        let plan = EvaluationPlan::build(&three_node_chain_graph(), None);
        assert_eq!(plan.levels_count(), 3);
        assert_eq!(plan.total_nodes(), 3);
    }

    #[test]
    fn test_evaluation_plan_can_fit_large_memory() {
        let plan = EvaluationPlan::build(&single_node_graph(), None);
        assert!(plan.can_fit_in_memory(usize::MAX));
    }

    #[test]
    fn test_evaluation_plan_cannot_fit_zero() {
        let hints: HashMap<usize, Vec<usize>> = HashMap::from([(0, vec![1024])]);
        let plan = EvaluationPlan::build(&three_node_chain_graph(), Some(&hints));
        assert!(!plan.can_fit_in_memory(0));
    }

    #[test]
    fn test_evaluation_plan_summary_nonempty() {
        let text = EvaluationPlan::build(&single_node_graph(), None).summary();
        assert!(!text.is_empty());
        assert!(text.contains("EvaluationPlan"));
    }
}