use std::collections::{HashMap, HashSet};
use super::{EinsumGraph, OpType};
use crate::error::IrError;
/// Per-tensor memory record: estimated size plus the lifetime interval
/// (first/last node step referencing the tensor) computed by
/// `analyze_tensor_lifetimes`.
#[derive(Debug, Clone, PartialEq)]
pub struct TensorMemory {
    // Index of the tensor within `EinsumGraph::tensors`.
    pub tensor_idx: usize,
    // Estimated footprint of this tensor in bytes (see `estimate_tensor_size`).
    pub size_bytes: usize,
    // First node step at which the tensor is used; `None` if never referenced.
    pub first_use: Option<usize>,
    // Last node step at which the tensor is used; `None` if never referenced.
    pub last_use: Option<usize>,
}
/// Result of `analyze_memory`: per-tensor records, aggregate memory figures,
/// and a dependency-respecting execution schedule.
#[derive(Debug, Clone)]
pub struct MemoryAnalysis {
    // One record per tensor in the graph.
    pub tensors: Vec<TensorMemory>,
    // Maximum total size of simultaneously-live tensors over all steps.
    pub peak_memory_bytes: usize,
    // Sum of all tensor sizes, ignoring lifetimes.
    pub total_memory_bytes: usize,
    // Average live memory per step divided by `peak_memory_bytes` (0.0..=1.0).
    pub avg_utilization: f64,
    // Node indices in a memory-aware topological order.
    pub optimal_schedule: Vec<usize>,
}
impl MemoryAnalysis {
pub fn new() -> Self {
Self {
tensors: Vec::new(),
peak_memory_bytes: 0,
total_memory_bytes: 0,
avg_utilization: 0.0,
optimal_schedule: Vec::new(),
}
}
pub fn memory_waste_ratio(&self) -> f64 {
if self.peak_memory_bytes == 0 {
return 0.0;
}
let avg_memory = self.total_memory_bytes as f64 * self.avg_utilization;
(self.peak_memory_bytes as f64 - avg_memory) / self.peak_memory_bytes as f64
}
}
impl Default for MemoryAnalysis {
fn default() -> Self {
Self::new()
}
}
/// Analyzes memory usage of `graph`, assuming every tensor element occupies
/// `element_size_bytes` bytes.
///
/// Computes per-tensor lifetimes and sizes, the peak and total memory, an
/// average-utilization estimate, and a memory-aware execution schedule.
/// An empty graph yields an all-zero analysis.
///
/// # Errors
/// Propagates any `IrError` raised while generating the schedule.
pub fn analyze_memory(
    graph: &EinsumGraph,
    element_size_bytes: usize,
) -> Result<MemoryAnalysis, IrError> {
    if graph.nodes.is_empty() {
        return Ok(MemoryAnalysis::new());
    }
    let tensor_lifetimes = analyze_tensor_lifetimes(graph);
    let mut tensor_memories = Vec::with_capacity(tensor_lifetimes.len());
    for (tensor_idx, (first_use, last_use)) in tensor_lifetimes.iter().enumerate() {
        let size_bytes = estimate_tensor_size(graph, tensor_idx, element_size_bytes);
        tensor_memories.push(TensorMemory {
            tensor_idx,
            size_bytes,
            first_use: *first_use,
            last_use: *last_use,
        });
    }
    let peak_memory_bytes = compute_peak_memory(graph, &tensor_memories);
    let total_memory_bytes: usize = tensor_memories.iter().map(|t| t.size_bytes).sum();
    // Average utilization = (avg live tensors per step * avg tensor size) / peak.
    // The early return above guarantees `graph.nodes` is non-empty here (the
    // old code re-checked it in a dead branch); we DO need to guard against an
    // empty tensor list and a zero peak, both of which previously produced a
    // NaN/zero division.
    let avg_utilization = if tensor_memories.is_empty() || peak_memory_bytes == 0 {
        0.0
    } else {
        let total_live: usize = (0..graph.nodes.len())
            .map(|step| count_live_tensors_at_step(step, &tensor_memories))
            .sum();
        let avg_live = total_live as f64 / graph.nodes.len() as f64;
        let avg_tensor_size = total_memory_bytes as f64 / tensor_memories.len() as f64;
        (avg_live * avg_tensor_size) / peak_memory_bytes as f64
    };
    let optimal_schedule = generate_memory_optimal_schedule(graph, &tensor_memories)?;
    Ok(MemoryAnalysis {
        tensors: tensor_memories,
        peak_memory_bytes,
        total_memory_bytes,
        avg_utilization,
        optimal_schedule,
    })
}
fn analyze_tensor_lifetimes(graph: &EinsumGraph) -> Vec<(Option<usize>, Option<usize>)> {
let mut lifetimes = vec![(None, None); graph.tensors.len()];
for (node_idx, node) in graph.nodes.iter().enumerate() {
for &input_idx in &node.inputs {
if input_idx < lifetimes.len() {
let (ref mut first, ref mut last) = lifetimes[input_idx];
*first = Some(first.map_or(node_idx, |f: usize| f.min(node_idx)));
*last = Some(last.map_or(node_idx, |l: usize| l.max(node_idx)));
}
}
for &output_idx in &node.outputs {
if output_idx < lifetimes.len() {
let (ref mut first, ref mut last) = lifetimes[output_idx];
*first = Some(first.map_or(node_idx, |f: usize| f.min(node_idx)));
*last = Some(last.map_or(node_idx, |l: usize| l.max(node_idx)));
}
}
}
lifetimes
}
/// Placeholder size estimate for a tensor.
///
/// Tensor shapes are not consulted: every tensor is assumed to hold a fixed
/// element count. NOTE(review): replace with real shape-based sizing once
/// `EinsumGraph` exposes tensor shapes — presumably why `_graph` and
/// `_tensor_idx` are already part of the signature.
fn estimate_tensor_size(
    _graph: &EinsumGraph,
    _tensor_idx: usize,
    element_size_bytes: usize,
) -> usize {
    // Assumed element count per tensor (placeholder heuristic).
    const ASSUMED_ELEMENTS: usize = 1000;
    ASSUMED_ELEMENTS * element_size_bytes
}
/// Maximum total size of simultaneously-live tensors across all execution
/// steps; `0` for a graph with no nodes.
fn compute_peak_memory(graph: &EinsumGraph, tensors: &[TensorMemory]) -> usize {
    (0..graph.nodes.len())
        .map(|step| {
            tensors
                .iter()
                .filter(|t| is_tensor_live_at_step(t, step))
                .map(|t| t.size_bytes)
                .sum::<usize>()
        })
        .max()
        .unwrap_or(0)
}
/// True when `tensor` occupies memory at `step`, i.e. the step lies within
/// its inclusive `[first_use, last_use]` interval. A tensor with no recorded
/// lifetime is never live.
fn is_tensor_live_at_step(tensor: &TensorMemory, step: usize) -> bool {
    if let (Some(first), Some(last)) = (tensor.first_use, tensor.last_use) {
        (first..=last).contains(&step)
    } else {
        false
    }
}
/// Number of tensors live at the given execution step.
fn count_live_tensors_at_step(step: usize, tensors: &[TensorMemory]) -> usize {
    let mut live = 0;
    for tensor in tensors {
        if is_tensor_live_at_step(tensor, step) {
            live += 1;
        }
    }
    live
}
/// Builds a dependency-respecting execution order for the graph's nodes.
///
/// Currently infallible (`_tensors` is unused); the `Result` signature leaves
/// room for future validation without breaking callers.
fn generate_memory_optimal_schedule(
    graph: &EinsumGraph,
    _tensors: &[TensorMemory],
) -> Result<Vec<usize>, IrError> {
    let deps = build_dependencies(graph);
    Ok(topological_sort_memory_aware(graph, &deps))
}
/// Maps every node index to the list of node indices it depends on.
///
/// A node depends on the producer of each of its input tensors; self-loops
/// (a node consuming its own output) are skipped. If several nodes claim to
/// produce the same tensor, the last one in iteration order wins — same as
/// the original insert-overwrite behavior.
fn build_dependencies(graph: &EinsumGraph) -> HashMap<usize, Vec<usize>> {
    // Producer of each tensor index; later producers overwrite earlier ones.
    let producer_of: HashMap<usize, usize> = graph
        .nodes
        .iter()
        .enumerate()
        .flat_map(|(node_idx, node)| node.outputs.iter().map(move |&out| (out, node_idx)))
        .collect();
    graph
        .nodes
        .iter()
        .enumerate()
        .map(|(node_idx, node)| {
            let deps: Vec<usize> = node
                .inputs
                .iter()
                .filter_map(|input| producer_of.get(input).copied())
                .filter(|&producer| producer != node_idx)
                .collect();
            (node_idx, deps)
        })
        .collect()
}
/// Produces a memory-aware topological order of the graph's nodes.
///
/// `dependencies` maps each node to the nodes it depends on (its producers).
/// A node becomes "ready" once all of its dependencies are scheduled; among
/// the ready nodes, `select_next_node_memory_aware` picks the next one. If a
/// dependency cycle prevents progress, the partial schedule is returned.
///
/// Bug fix: the previous version incremented `in_degree[dep]` for every
/// producer, so consumers started at in-degree 0 and were scheduled BEFORE
/// the nodes producing their inputs (a relu→tanh chain came out as [1, 0]).
/// The in-degree of a node must be the count of its own dependencies, and it
/// is the *dependents* of a scheduled node whose degrees get decremented.
fn topological_sort_memory_aware(
    graph: &EinsumGraph,
    dependencies: &HashMap<usize, Vec<usize>>,
) -> Vec<usize> {
    let node_count = graph.nodes.len();
    let mut schedule = Vec::with_capacity(node_count);
    let mut scheduled = HashSet::new();
    // in_degree[i] = number of not-yet-scheduled dependencies of node i.
    let mut in_degree = vec![0usize; node_count];
    // dependents[p] = nodes that depend on p, so their in-degrees can be
    // decremented once p has been scheduled. Duplicate edges (a node reading
    // two tensors from the same producer) are counted consistently on both
    // sides, so the bookkeeping still balances out.
    let mut dependents: HashMap<usize, Vec<usize>> = HashMap::new();
    for (&node, deps) in dependencies {
        if node < node_count {
            in_degree[node] = deps.len();
        }
        for &dep in deps {
            dependents.entry(dep).or_default().push(node);
        }
    }
    while schedule.len() < node_count {
        let ready: Vec<usize> = (0..node_count)
            .filter(|&i| !scheduled.contains(&i) && in_degree[i] == 0)
            .collect();
        if ready.is_empty() {
            // No progress possible: dependency cycle. Return what we have.
            break;
        }
        let next = select_next_node_memory_aware(graph, &ready);
        schedule.push(next);
        scheduled.insert(next);
        if let Some(consumers) = dependents.get(&next) {
            for &consumer in consumers {
                if consumer < node_count {
                    in_degree[consumer] = in_degree[consumer].saturating_sub(1);
                }
            }
        }
    }
    schedule
}
/// Picks from `candidates` the node with the fewest outputs — a cheap proxy
/// for the smallest immediate memory growth. Ties keep the earliest
/// candidate (matching `min_by_key`'s first-minimum rule); an empty slice
/// yields 0.
fn select_next_node_memory_aware(graph: &EinsumGraph, candidates: &[usize]) -> usize {
    let mut best: Option<(usize, usize)> = None; // (output_count, node_idx)
    for &idx in candidates {
        // Out-of-range candidates sort last, as in the original.
        let outputs = graph.nodes.get(idx).map_or(usize::MAX, |n| n.outputs.len());
        // Strict `<` keeps the first of equally-good candidates.
        if best.map_or(true, |(best_outputs, _)| outputs < best_outputs) {
            best = Some((outputs, idx));
        }
    }
    best.map_or(0, |(_, idx)| idx)
}
/// Returns the indices of nodes whose output buffer could alias their input:
/// element-wise unary ops whose first input has no other consumer.
///
/// # Errors
/// Currently never fails; the `Result` mirrors the other analysis entry points.
pub fn analyze_inplace_opportunities(graph: &EinsumGraph) -> Result<Vec<usize>, IrError> {
    let candidates = graph
        .nodes
        .iter()
        .enumerate()
        .filter(|(idx, node)| can_be_inplace(&node.op) && has_single_input_use(graph, *idx))
        .map(|(idx, _)| idx)
        .collect();
    Ok(candidates)
}
/// Only element-wise unary ops are candidates for in-place execution; every
/// other op kind (einsum contractions, etc.) needs a distinct output buffer.
fn can_be_inplace(op: &OpType) -> bool {
    match op {
        OpType::ElemUnary { .. } => true,
        _ => false,
    }
}
/// True when the first input tensor of node `node_idx` is consumed by exactly
/// one node in the graph (necessarily this one).
///
/// NOTE(review): only the FIRST input is inspected, and a node listing the
/// same tensor twice still counts as one consumer — confirm this matches the
/// in-place rewrite's safety requirements for multi-input ops.
fn has_single_input_use(graph: &EinsumGraph, node_idx: usize) -> bool {
    let node = &graph.nodes[node_idx];
    let input_tensor = match node.inputs.first() {
        Some(&tensor) => tensor,
        None => return false, // no inputs — nothing to reuse
    };
    let mut consumers = 0;
    for other in &graph.nodes {
        if other.inputs.contains(&input_tensor) {
            consumers += 1;
        }
    }
    consumers == 1
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::EinsumNode;

    // Default analysis must be all-zero (delegates to `new`).
    #[test]
    fn test_memory_analysis_default() {
        let analysis = MemoryAnalysis::default();
        assert_eq!(analysis.peak_memory_bytes, 0);
        assert_eq!(analysis.total_memory_bytes, 0);
    }

    // An empty graph short-circuits to the all-zero analysis.
    #[test]
    fn test_analyze_empty_graph() {
        let graph = EinsumGraph::new();
        let analysis = analyze_memory(&graph, 8).expect("unwrap");
        assert_eq!(analysis.peak_memory_bytes, 0);
        assert_eq!(analysis.tensors.len(), 0);
    }

    // One relu node over two tensors: both tensors get records, peak > 0.
    #[test]
    fn test_analyze_single_node() {
        let mut graph = EinsumGraph::new();
        let a = graph.add_tensor("A");
        let b = graph.add_tensor("B");
        graph
            .add_node(EinsumNode::elem_unary("relu", a, b))
            .expect("unwrap");
        let analysis = analyze_memory(&graph, 8).expect("unwrap");
        assert!(analysis.peak_memory_bytes > 0);
        assert_eq!(analysis.tensors.len(), 2);
    }

    // Input and output of the only node both live exactly at step 0.
    #[test]
    fn test_tensor_lifetime_single_use() {
        let mut graph = EinsumGraph::new();
        let a = graph.add_tensor("A");
        let b = graph.add_tensor("B");
        graph
            .add_node(EinsumNode::elem_unary("relu", a, b))
            .expect("unwrap");
        let lifetimes = analyze_tensor_lifetimes(&graph);
        assert_eq!(lifetimes[a], (Some(0), Some(0)));
        assert_eq!(lifetimes[b], (Some(0), Some(0)));
    }

    // B is produced at step 0 and consumed at step 1, so it spans both steps.
    #[test]
    fn test_tensor_lifetime_multiple_uses() {
        let mut graph = EinsumGraph::new();
        let a = graph.add_tensor("A");
        let b = graph.add_tensor("B");
        let c = graph.add_tensor("C");
        graph
            .add_node(EinsumNode::elem_unary("relu", a, b))
            .expect("unwrap");
        graph
            .add_node(EinsumNode::elem_unary("tanh", b, c))
            .expect("unwrap");
        let lifetimes = analyze_tensor_lifetimes(&graph);
        assert_eq!(lifetimes[b], (Some(0), Some(1)));
    }

    // Placeholder estimator: 1000 assumed elements * 8 bytes each.
    #[test]
    fn test_estimate_tensor_size() {
        let graph = EinsumGraph::new();
        let size = estimate_tensor_size(&graph, 0, 8);
        assert_eq!(size, 8000);
    }

    // Liveness interval [2, 5] is inclusive on both ends.
    #[test]
    fn test_is_tensor_live_at_step() {
        let tensor = TensorMemory {
            tensor_idx: 0,
            size_bytes: 1000,
            first_use: Some(2),
            last_use: Some(5),
        };
        assert!(!is_tensor_live_at_step(&tensor, 0));
        assert!(!is_tensor_live_at_step(&tensor, 1));
        assert!(is_tensor_live_at_step(&tensor, 2));
        assert!(is_tensor_live_at_step(&tensor, 3));
        assert!(is_tensor_live_at_step(&tensor, 5));
        assert!(!is_tensor_live_at_step(&tensor, 6));
    }

    // Zero peak must not divide by zero — the ratio is defined as 0.
    #[test]
    fn test_memory_waste_ratio_zero_peak() {
        let analysis = MemoryAnalysis {
            peak_memory_bytes: 0,
            total_memory_bytes: 1000,
            avg_utilization: 0.5,
            ..Default::default()
        };
        assert_eq!(analysis.memory_waste_ratio(), 0.0);
    }

    // Only element-wise unary ops qualify for in-place execution.
    #[test]
    fn test_can_be_inplace() {
        assert!(can_be_inplace(&OpType::ElemUnary {
            op: "relu".to_string()
        }));
        assert!(!can_be_inplace(&OpType::Einsum {
            spec: "ij,jk->ik".to_string()
        }));
    }

    // No nodes means no in-place candidates.
    #[test]
    fn test_analyze_inplace_opportunities_empty() {
        let graph = EinsumGraph::new();
        let candidates = analyze_inplace_opportunities(&graph).expect("unwrap");
        assert!(candidates.is_empty());
    }

    // A lone unary node whose input has one consumer is a candidate.
    #[test]
    fn test_analyze_inplace_single_use() {
        let mut graph = EinsumGraph::new();
        let a = graph.add_tensor("A");
        let b = graph.add_tensor("B");
        graph
            .add_node(EinsumNode::elem_unary("relu", a, b))
            .expect("unwrap");
        let candidates = analyze_inplace_opportunities(&graph).expect("unwrap");
        assert_eq!(candidates.len(), 1);
    }

    // Node 1 consumes node 0's output, so it depends on node 0 (and 0 on nothing).
    #[test]
    fn test_build_dependencies() {
        let mut graph = EinsumGraph::new();
        let a = graph.add_tensor("A");
        let b = graph.add_tensor("B");
        let c = graph.add_tensor("C");
        graph
            .add_node(EinsumNode::elem_unary("relu", a, b))
            .expect("unwrap");
        graph
            .add_node(EinsumNode::elem_unary("tanh", b, c))
            .expect("unwrap");
        let deps = build_dependencies(&graph);
        assert_eq!(deps.get(&0).expect("unwrap").len(), 0);
        assert_eq!(deps.get(&1).expect("unwrap"), &vec![0]);
    }

    // A single-node graph trivially schedules as [0].
    #[test]
    fn test_topological_sort_simple() {
        let mut graph = EinsumGraph::new();
        let a = graph.add_tensor("A");
        let b = graph.add_tensor("B");
        graph
            .add_node(EinsumNode::elem_unary("relu", a, b))
            .expect("unwrap");
        let deps = build_dependencies(&graph);
        let schedule = topological_sort_memory_aware(&graph, &deps);
        assert_eq!(schedule, vec![0]);
    }
}