//! Rough memory estimation for compiled einsum graphs.
//!
//! Concrete tensor shapes are not available here, so estimates rely on fixed
//! placeholder shapes and a single element dtype; treat the numbers as coarse
//! approximations rather than exact figures.

use tensorlogic_ir::{EinsumGraph, OpType};
use crate::capabilities::DType;
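/// Memory footprint of a single tensor: its shape, element count, and size in
/// bytes for a given element dtype.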
#[derive(Debug, Clone)]
pub struct TensorMemory {
pub tensor_idx: usize,
pub shape: Vec<usize>,
pub element_count: usize,
pub bytes: usize,
}
impl TensorMemory {
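    /// Computes the footprint of a tensor with the given shape and element dtype.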
pub fn new(tensor_idx: usize, shape: Vec<usize>, dtype: DType) -> Self {
let element_count: usize = shape.iter().product();
let bytes = element_count * dtype.byte_size();
TensorMemory {
tensor_idx,
shape,
element_count,
bytes,
}
}
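    /// Size of this tensor in megabytes.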
pub fn megabytes(&self) -> f64 {
self.bytes as f64 / (1024.0 * 1024.0)
}
}
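/// Memory estimate for an entire graph, broken down into input, intermediate,
/// and output tensors, with total and peak byte counts.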
#[derive(Debug, Clone)]
pub struct MemoryEstimate {
pub input_memory: Vec<TensorMemory>,
pub intermediate_memory: Vec<TensorMemory>,
pub output_memory: Vec<TensorMemory>,
pub total_bytes: usize,
pub peak_bytes: usize,
}
impl MemoryEstimate {
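    /// Creates an empty estimate with zeroed totals.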
pub fn new() -> Self {
MemoryEstimate {
input_memory: Vec::new(),
intermediate_memory: Vec::new(),
output_memory: Vec::new(),
total_bytes: 0,
peak_bytes: 0,
}
}
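    /// Total estimated memory across all tensors, in megabytes.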
pub fn total_megabytes(&self) -> f64 {
self.total_bytes as f64 / (1024.0 * 1024.0)
}
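    /// Peak estimated memory, in megabytes.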
pub fn peak_megabytes(&self) -> f64 {
self.peak_bytes as f64 / (1024.0 * 1024.0)
}
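    /// Human-readable multi-line summary of the estimate.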
pub fn summary(&self) -> String {
format!(
"Memory Estimate:\n\
- Inputs: {} tensors, {:.2} MB\n\
- Intermediates: {} tensors, {:.2} MB\n\
- Outputs: {} tensors, {:.2} MB\n\
- Total: {:.2} MB\n\
- Peak: {:.2} MB",
self.input_memory.len(),
self.input_memory.iter().map(|t| t.megabytes()).sum::<f64>(),
self.intermediate_memory.len(),
self.intermediate_memory
.iter()
.map(|t| t.megabytes())
.sum::<f64>(),
self.output_memory.len(),
self.output_memory
.iter()
.map(|t| t.megabytes())
.sum::<f64>(),
self.total_megabytes(),
self.peak_megabytes()
)
}
}
impl Default for MemoryEstimate {
fn default() -> Self {
Self::new()
}
}
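/// Estimates the memory required to evaluate an [`EinsumGraph`], assuming a
/// single element dtype for every tensor.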
pub struct MemoryEstimator {
dtype: DType,
}
impl MemoryEstimator {
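    /// Creates an estimator that assumes every tensor has element type `dtype`.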
pub fn new(dtype: DType) -> Self {
MemoryEstimator { dtype }
}
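    /// Estimates memory for `graph`.
    ///
    /// Input tensors are assumed to have a placeholder shape of `[10]`, and node
    /// output shapes are derived heuristically from each op, so the result is a
    /// rough approximation. Peak memory is set equal to the total; use
    /// [`estimate_with_lifetime`](Self::estimate_with_lifetime) for a tighter peak.
    ///
    /// # Example
    ///
    /// A minimal sketch, assuming an `EinsumGraph` value named `graph` has
    /// already been built elsewhere:
    ///
    /// ```ignore
    /// let estimator = MemoryEstimator::new(DType::F64);
    /// let estimate = estimator.estimate(&graph);
    /// println!("{}", estimate.summary());
    /// ```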
pub fn estimate(&self, graph: &EinsumGraph) -> MemoryEstimate {
let mut estimate = MemoryEstimate::new();
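        // Treat every tensor listed in the graph as an input with a placeholder
        // shape, since concrete shapes are not available at this level.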
let default_shape = vec![10];
for idx in 0..graph.tensors.len() {
let mem = TensorMemory::new(idx, default_shape.clone(), self.dtype);
estimate.total_bytes += mem.bytes;
estimate.input_memory.push(mem);
}
let num_inputs = graph.tensors.len();
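        // Node outputs are assigned indices after the inputs; the last node's
        // output is treated as the graph output, all earlier ones as intermediates.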
for (node_idx, node) in graph.nodes.iter().enumerate() {
let tensor_idx = num_inputs + node_idx;
let shape = self.estimate_output_shape(node, graph);
let mem = TensorMemory::new(tensor_idx, shape, self.dtype);
estimate.total_bytes += mem.bytes;
if node_idx == graph.nodes.len() - 1 {
estimate.output_memory.push(mem);
} else {
estimate.intermediate_memory.push(mem);
}
}
estimate.peak_bytes = estimate.total_bytes;
estimate
}
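    /// Like [`estimate`](Self::estimate), but tracks tensor lifetimes: a tensor
    /// is considered freed after its last consuming node, which usually yields a
    /// lower (and more realistic) peak than the total.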
pub fn estimate_with_lifetime(&self, graph: &EinsumGraph) -> MemoryEstimate {
let mut estimate = self.estimate(graph);
let num_tensors = graph.tensors.len() + graph.nodes.len();
let mut alive = vec![false; num_tensors];
for item in alive.iter_mut().take(graph.tensors.len()) {
*item = true;
}
let mut peak_bytes = 0;
for (node_idx, node) in graph.nodes.iter().enumerate() {
let output_idx = graph.tensors.len() + node_idx;
alive[output_idx] = true;
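            // Sum the bytes of every currently live tensor. Intermediate indices
            // line up with node indices because only the last node's output is
            // stored in `output_memory`.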
let current_bytes: usize = alive
.iter()
.enumerate()
.filter(|(_, &is_alive)| is_alive)
.map(|(idx, _)| {
if idx < graph.tensors.len() {
&estimate.input_memory[idx]
} else {
let node_offset = idx - graph.tensors.len();
if node_offset < estimate.intermediate_memory.len() {
&estimate.intermediate_memory[node_offset]
} else {
&estimate.output_memory[0]
}
}
})
.map(|mem| mem.bytes)
.sum();
peak_bytes = peak_bytes.max(current_bytes);
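            // Free any input whose last use is this node.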
for &input_idx in &node.inputs {
if self.is_last_use(input_idx, node_idx, graph) {
alive[input_idx] = false;
}
}
}
estimate.peak_bytes = peak_bytes;
estimate
}
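    /// Heuristic output shape for a node: one placeholder dimension of size 10
    /// per einsum output axis, a fixed `[10]` for elementwise ops, and a
    /// `[10, 10]` base with the reduced axes removed for reductions.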
fn estimate_output_shape(
&self,
node: &tensorlogic_ir::EinsumNode,
_graph: &EinsumGraph,
) -> Vec<usize> {
match &node.op {
            OpType::Einsum { spec } => {
                if let Some(arrow_pos) = spec.find("->") {
                    // One placeholder dimension per output axis label; ignore
                    // commas and whitespace in the output spec.
                    let output_rank = spec[arrow_pos + 2..]
                        .chars()
                        .filter(|c| c.is_ascii_alphabetic())
                        .count();
                    vec![10; output_rank]
                } else {
                    vec![10]
                }
            }
OpType::ElemUnary { op: _ } | OpType::ElemBinary { op: _ } => {
vec![10]
}
OpType::Reduce { op: _, axes } => {
                // Start from a placeholder rank-2 shape and drop the reduced axes.
                let mut shape = vec![10, 10];
for &axis in axes.iter().rev() {
if axis < shape.len() {
shape.remove(axis);
}
}
if shape.is_empty() {
vec![1]
} else {
shape
}
}
}
}
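    /// Returns `true` if no node after `current_node` reads `tensor_idx`.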
    fn is_last_use(&self, tensor_idx: usize, current_node: usize, graph: &EinsumGraph) -> bool {
        // A tensor can be freed after `current_node` if no later node reads it.
        !graph
            .nodes
            .iter()
            .skip(current_node + 1)
            .any(|node| node.inputs.contains(&tensor_idx))
    }
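    /// Scales the single-sample estimate by `batch_size`, assuming memory grows
    /// linearly with the batch dimension.
    ///
    /// # Example
    ///
    /// A minimal sketch, assuming `estimator` and `graph` as in
    /// [`estimate`](Self::estimate):
    ///
    /// ```ignore
    /// let per_sample = estimator.estimate(&graph);
    /// let batched = estimator.estimate_batch(&graph, 32);
    /// assert_eq!(batched.total_bytes, per_sample.total_bytes * 32);
    /// ```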
pub fn estimate_batch(&self, graph: &EinsumGraph, batch_size: usize) -> MemoryEstimate {
let single_estimate = self.estimate(graph);
let mut batch_estimate = MemoryEstimate::new();
batch_estimate.total_bytes = single_estimate.total_bytes * batch_size;
batch_estimate.peak_bytes = single_estimate.peak_bytes * batch_size;
        // Scale element counts along with bytes so the TensorMemory invariant
        // (bytes == element_count * dtype size) still holds for batched tensors.
        for input in &single_estimate.input_memory {
            let mut batched = input.clone();
            batched.element_count *= batch_size;
            batched.bytes *= batch_size;
            batch_estimate.input_memory.push(batched);
        }
        for intermediate in &single_estimate.intermediate_memory {
            let mut batched = intermediate.clone();
            batched.element_count *= batch_size;
            batched.bytes *= batch_size;
            batch_estimate.intermediate_memory.push(batched);
        }
        for output in &single_estimate.output_memory {
            let mut batched = output.clone();
            batched.element_count *= batch_size;
            batched.bytes *= batch_size;
            batch_estimate.output_memory.push(batched);
        }
batch_estimate
}
}
impl Default for MemoryEstimator {
fn default() -> Self {
Self::new(DType::F64)
}
}