use std::collections::{HashMap, HashSet, VecDeque};
use super::{EinsumGraph, OpType};
/// Aggregate memory-related statistics for an `EinsumGraph`.
///
/// All fields are tensor/node *counts* — tensor shapes are not consulted,
/// so no byte-size estimate is produced.
#[derive(Clone, Debug, PartialEq)]
#[allow(dead_code)] pub struct MemoryFootprint {
/// Total number of tensors declared in the graph.
pub tensor_count: usize,
/// Tensors that are neither graph inputs nor graph outputs.
pub intermediate_count: usize,
/// Largest number of tensors touched (read or written) by any single node.
pub peak_live_tensors: usize,
/// Number of nodes per operation kind ("einsum", "elem_unary", "elem_binary", "reduce").
pub op_counts: HashMap<String, usize>,
}
impl MemoryFootprint {
    /// Estimate memory-related statistics for `graph`.
    ///
    /// Counts total tensors, intermediates (tensors that are neither inputs
    /// nor outputs), the peak number of tensors touched by any single node
    /// (see `compute_liveness`), and the number of nodes per operation kind.
    #[allow(dead_code)]
    pub fn estimate(graph: &EinsumGraph) -> Self {
        let tensor_count = graph.tensors.len();
        // saturating_sub: a malformed graph (e.g. a tensor registered as both
        // an input and an output) could otherwise underflow usize and panic
        // in debug builds; clamp to 0 instead.
        let intermediate_count = tensor_count
            .saturating_sub(graph.inputs.len())
            .saturating_sub(graph.outputs.len());
        let liveness = compute_liveness(graph);
        let peak_live_tensors = liveness
            .values()
            .map(|live_set| live_set.len())
            .max()
            .unwrap_or(0);
        let mut op_counts: HashMap<String, usize> = HashMap::new();
        for node in &graph.nodes {
            let op_name = match &node.op {
                OpType::Einsum { .. } => "einsum",
                OpType::ElemUnary { .. } => "elem_unary",
                OpType::ElemBinary { .. } => "elem_binary",
                OpType::Reduce { .. } => "reduce",
            };
            *op_counts.entry(op_name.to_string()).or_insert(0) += 1;
        }
        Self {
            tensor_count,
            intermediate_count,
            peak_live_tensors,
            op_counts,
        }
    }
}
/// Longest dependency chain through the graph, plus a parallelism estimate.
#[derive(Clone, Debug, PartialEq)]
#[allow(dead_code)] pub struct CriticalPath {
/// Node indices along the critical path, ordered from earliest to latest.
pub path: Vec<usize>,
/// Length (in node steps) of the longest input-to-output chain.
pub cost: usize,
/// Largest number of nodes sharing the same dependency depth.
pub max_parallelism: usize,
}
impl CriticalPath {
    /// Find the longest producer→consumer chain ending at a graph output.
    ///
    /// Computes per-tensor distances, picks the output with the greatest
    /// distance, then walks predecessor links back toward the inputs,
    /// collecting the node that produces each tensor along the way.
    /// Assumes the graph is acyclic (predecessor links terminate at inputs).
    #[allow(dead_code)]
    pub fn analyze(graph: &EinsumGraph) -> Self {
        let (distances, predecessors) = compute_distances(graph);
        // The critical path ends at the output tensor with the largest distance;
        // (0, 0) covers graphs with no reachable outputs.
        let (critical_output, max_distance) = graph
            .outputs
            .iter()
            .filter_map(|&output_idx| distances.get(&output_idx).map(|&dist| (output_idx, dist)))
            .max_by_key(|(_, dist)| *dist)
            .unwrap_or((0, 0));
        let mut path = Vec::new();
        let mut current = critical_output;
        // Fix: the original read `predecessors.get(¤t)` — `&current` had been
        // mangled by HTML-entity corruption (`&curren` → `¤`), breaking the build.
        while let Some(&pred) = predecessors.get(&current) {
            if let Some(node_idx) = find_producing_node(graph, current) {
                path.push(node_idx);
            }
            current = pred;
        }
        path.reverse();
        let max_parallelism = compute_max_parallelism(graph, &distances);
        Self {
            path,
            cost: max_distance,
            max_parallelism,
        }
    }
}
/// Level-by-level execution plan (Kahn topological levels): nodes within one
/// level have no dependencies on each other and may run concurrently.
#[derive(Clone, Debug, PartialEq, Eq)]
#[allow(dead_code)] pub struct ExecutionSchedule {
/// Node indices grouped by topological level, earliest level first.
pub levels: Vec<Vec<usize>>,
/// Number of levels (equals `levels.len()`).
pub depth: usize,
}
impl ExecutionSchedule {
    /// Partition the graph's nodes into topological levels using Kahn's
    /// algorithm: a node is placed in a level once all of the nodes producing
    /// its inputs have been placed in earlier levels.
    #[allow(dead_code)]
    pub fn compute(graph: &EinsumGraph) -> Self {
        let node_count = graph.nodes.len();
        // Build producer→consumer edges; tensors with no producing node
        // (graph inputs) contribute no edge and no pending count.
        let mut successors: HashMap<usize, Vec<usize>> = HashMap::new();
        let mut pending_inputs = vec![0usize; node_count];
        for (consumer, node) in graph.nodes.iter().enumerate() {
            for &tensor in &node.inputs {
                if let Some(producer) = find_producing_node(graph, tensor) {
                    successors.entry(producer).or_default().push(consumer);
                    pending_inputs[consumer] += 1;
                }
            }
        }
        // Seed the frontier with every node that depends on nothing.
        let mut frontier: VecDeque<usize> = (0..node_count)
            .filter(|&i| pending_inputs[i] == 0)
            .collect();
        let mut levels = Vec::new();
        while !frontier.is_empty() {
            // Snapshot the frontier size: everything queued right now forms
            // one level; nodes released below belong to the next level.
            let batch = frontier.len();
            let mut level = Vec::with_capacity(batch);
            for _ in 0..batch {
                let Some(node_idx) = frontier.pop_front() else { break };
                level.push(node_idx);
                if let Some(consumers) = successors.get(&node_idx) {
                    for &consumer in consumers {
                        pending_inputs[consumer] -= 1;
                        if pending_inputs[consumer] == 0 {
                            frontier.push_back(consumer);
                        }
                    }
                }
            }
            if !level.is_empty() {
                levels.push(level);
            }
        }
        let depth = levels.len();
        Self { levels, depth }
    }

    /// Width of the widest level — the most nodes that could run at once.
    #[allow(dead_code)]
    pub fn max_parallelism(&self) -> usize {
        self.levels.iter().map(Vec::len).max().unwrap_or(0)
    }
}
/// Per-tensor read/write statistics over the whole graph.
#[derive(Clone, Debug, PartialEq, Eq)]
#[allow(dead_code)] pub struct DataFlowAnalysis {
/// Tensors read (consumed as an input) by more than one node.
pub reused_tensors: HashSet<usize>,
/// Tensors written (produced as an output) by more than one node.
pub overwritten_tensors: HashSet<usize>,
/// Read count per tensor; tensors that are never read are absent.
pub fan_out: HashMap<usize, usize>,
}
impl DataFlowAnalysis {
#[allow(dead_code)] pub fn analyze(graph: &EinsumGraph) -> Self {
let mut read_count: HashMap<usize, usize> = HashMap::new();
let mut write_count: HashMap<usize, usize> = HashMap::new();
for node in &graph.nodes {
for &input in &node.inputs {
*read_count.entry(input).or_insert(0) += 1;
}
for &output in &node.outputs {
*write_count.entry(output).or_insert(0) += 1;
}
}
let reused_tensors = read_count
.iter()
.filter(|(_, &count)| count > 1)
.map(|(&tensor, _)| tensor)
.collect();
let overwritten_tensors = write_count
.iter()
.filter(|(_, &count)| count > 1)
.map(|(&tensor, _)| tensor)
.collect();
let fan_out = read_count;
Self {
reused_tensors,
overwritten_tensors,
fan_out,
}
}
}
/// For each node, collect the set of tensors it touches (inputs and outputs).
///
/// NOTE(review): this is a per-node "touched tensors" set, not classic
/// dataflow liveness (no propagation between nodes) — callers use it only
/// for the per-node peak count; confirm before reusing it for anything else.
#[allow(dead_code)]
fn compute_liveness(graph: &EinsumGraph) -> HashMap<usize, HashSet<usize>> {
    graph
        .nodes
        .iter()
        .enumerate()
        .map(|(node_idx, node)| {
            let touched: HashSet<usize> = node
                .inputs
                .iter()
                .chain(node.outputs.iter())
                .copied()
                .collect();
            (node_idx, touched)
        })
        .collect()
}
/// Compute, for every reachable tensor, its distance (in node steps) from the
/// graph inputs, plus a predecessor link per tensor for path reconstruction.
///
/// Processes nodes in topological order (via `ExecutionSchedule`) so each
/// node's inputs are finalized before its outputs are assigned.
#[allow(dead_code)]
fn compute_distances(graph: &EinsumGraph) -> (HashMap<usize, usize>, HashMap<usize, usize>) {
    let mut distances: HashMap<usize, usize> = HashMap::new();
    let mut predecessors: HashMap<usize, usize> = HashMap::new();
    for &input_idx in &graph.inputs {
        distances.insert(input_idx, 0);
    }
    let schedule = ExecutionSchedule::compute(graph);
    for level in &schedule.levels {
        for &node_idx in level {
            let node = &graph.nodes[node_idx];
            // Fix: track *which* input attains the maximum distance, not just
            // the maximum itself. The original always recorded
            // `node.inputs.first()` as the predecessor, so critical-path
            // reconstruction could follow the wrong edge on multi-input nodes.
            let critical_input = node
                .inputs
                .iter()
                .filter_map(|&input| distances.get(&input).map(|&dist| (input, dist)))
                .max_by_key(|&(_, dist)| dist);
            let max_input_distance = critical_input.map_or(0, |(_, dist)| dist);
            for &output in &node.outputs {
                let new_distance = max_input_distance + 1;
                if new_distance > *distances.get(&output).unwrap_or(&0) {
                    distances.insert(output, new_distance);
                    if let Some((input, _)) = critical_input {
                        predecessors.insert(output, input);
                    } else if let Some(&input) = node.inputs.first() {
                        // Preserve the original fallback when no input has a
                        // known distance (e.g. dangling input tensors).
                        predecessors.insert(output, input);
                    }
                }
            }
        }
    }
    (distances, predecessors)
}
/// Index of the first node whose outputs include `tensor_idx`, or `None` if
/// no node produces it (i.e. it is a graph input or dangling).
#[allow(dead_code)]
fn find_producing_node(graph: &EinsumGraph, tensor_idx: usize) -> Option<usize> {
    for (node_idx, node) in graph.nodes.iter().enumerate() {
        if node.outputs.contains(&tensor_idx) {
            return Some(node_idx);
        }
    }
    None
}
/// Bucket each node by its dependency depth (max distance over its inputs,
/// 0 when none is known) and return the size of the fullest bucket — an
/// upper bound on how many nodes could execute simultaneously.
#[allow(dead_code)]
fn compute_max_parallelism(graph: &EinsumGraph, distances: &HashMap<usize, usize>) -> usize {
    let mut nodes_per_level: HashMap<usize, usize> = HashMap::new();
    for node in &graph.nodes {
        let mut level = 0;
        for input in &node.inputs {
            if let Some(&dist) = distances.get(input) {
                level = level.max(dist);
            }
        }
        *nodes_per_level.entry(level).or_insert(0) += 1;
    }
    nodes_per_level.into_values().max().unwrap_or(0)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::EinsumNode;

    /// Two inputs feeding one einsum node: A, B -> C (outer product).
    fn create_simple_graph() -> EinsumGraph {
        let mut graph = EinsumGraph::new();
        let a = graph.add_tensor("A");
        let b = graph.add_tensor("B");
        let c = graph.add_tensor("C");
        graph.add_input(a).expect("unwrap");
        graph.add_input(b).expect("unwrap");
        graph
            .add_node(EinsumNode::einsum("i,j->ij", vec![a, b], vec![c]))
            .expect("unwrap");
        graph.add_output(c).expect("unwrap");
        graph
    }

    /// Three identity einsum nodes in series: t0 -> t1 -> t2 -> t3.
    fn create_chain_graph() -> EinsumGraph {
        let mut graph = EinsumGraph::new();
        let t0 = graph.add_tensor("t0");
        let t1 = graph.add_tensor("t1");
        let t2 = graph.add_tensor("t2");
        let t3 = graph.add_tensor("t3");
        graph.add_input(t0).expect("unwrap");
        graph
            .add_node(EinsumNode::einsum("i->i", vec![t0], vec![t1]))
            .expect("unwrap");
        graph
            .add_node(EinsumNode::einsum("i->i", vec![t1], vec![t2]))
            .expect("unwrap");
        graph
            .add_node(EinsumNode::einsum("i->i", vec![t2], vec![t3]))
            .expect("unwrap");
        graph.add_output(t3).expect("unwrap");
        graph
    }

    #[test]
    fn test_memory_footprint_simple() {
        let graph = create_simple_graph();
        let footprint = MemoryFootprint::estimate(&graph);
        // Three tensors, all of which are inputs or outputs.
        assert_eq!(footprint.tensor_count, 3);
        assert_eq!(footprint.intermediate_count, 0);
        assert!(footprint.op_counts.contains_key("einsum"));
    }

    #[test]
    fn test_critical_path_simple() {
        let graph = create_simple_graph();
        let critical = CriticalPath::analyze(&graph);
        // A single node lies on the path from inputs to the output.
        assert_eq!(critical.path.len(), 1);
        assert!(critical.cost > 0);
    }

    #[test]
    fn test_critical_path_chain() {
        let graph = create_chain_graph();
        let critical = CriticalPath::analyze(&graph);
        // All three chained nodes are on the critical path.
        assert_eq!(critical.path.len(), 3);
        assert_eq!(critical.cost, 3);
    }

    #[test]
    fn test_execution_schedule_simple() {
        let graph = create_simple_graph();
        let schedule = ExecutionSchedule::compute(&graph);
        // One node => one level of width one.
        assert_eq!(schedule.depth, 1);
        assert_eq!(schedule.max_parallelism(), 1);
    }

    #[test]
    fn test_execution_schedule_chain() {
        let graph = create_chain_graph();
        let schedule = ExecutionSchedule::compute(&graph);
        // Sequential chain: three levels, no parallelism.
        assert_eq!(schedule.depth, 3);
        assert_eq!(schedule.max_parallelism(), 1);
    }

    #[test]
    fn test_data_flow_analysis() {
        // Tensor `a` fans out to two consumers.
        let mut graph = EinsumGraph::new();
        let a = graph.add_tensor("A");
        let b = graph.add_tensor("B");
        let c = graph.add_tensor("C");
        graph.add_input(a).expect("unwrap");
        graph
            .add_node(EinsumNode::einsum("i->i", vec![a], vec![b]))
            .expect("unwrap");
        graph
            .add_node(EinsumNode::einsum("i->i", vec![a], vec![c]))
            .expect("unwrap");
        graph.add_output(b).expect("unwrap");
        graph.add_output(c).expect("unwrap");
        let analysis = DataFlowAnalysis::analyze(&graph);
        assert!(analysis.reused_tensors.contains(&a));
        assert_eq!(*analysis.fan_out.get(&a).expect("unwrap"), 2);
    }

    #[test]
    fn test_data_flow_no_reuse() {
        let graph = create_simple_graph();
        let analysis = DataFlowAnalysis::analyze(&graph);
        // Each tensor is read at most once in the simple graph.
        assert!(analysis.reused_tensors.is_empty());
    }
}