use std::collections::HashMap;
use serde::{Deserialize, Serialize};
use super::{EinsumGraph, EinsumNode, OpType};
/// Resource-cost estimate for a single operation (or an aggregate of
/// operations) in an [`EinsumGraph`].
///
/// All metrics are `f64` so they can represent fractional estimates and be
/// combined with [`OperationCost::add`] / [`OperationCost::max`].
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct OperationCost {
    /// Estimated floating-point operations.
    pub compute_flops: f64,
    /// Estimated bytes of memory used by the operation.
    pub memory_bytes: f64,
    /// Estimated bytes exchanged with other devices/workers.
    pub communication_bytes: f64,
    /// Estimated bytes read from / written to storage.
    pub io_bytes: f64,
    /// Estimated latency in milliseconds.
    pub latency_ms: f64,
    /// Open-ended extra metrics keyed by name (e.g. "spec_complexity").
    /// Defaults to empty when absent from serialized input.
    #[serde(default)]
    pub custom: HashMap<String, f64>,
}
impl Default for OperationCost {
fn default() -> Self {
Self {
compute_flops: 0.0,
memory_bytes: 0.0,
communication_bytes: 0.0,
io_bytes: 0.0,
latency_ms: 0.0,
custom: HashMap::new(),
}
}
}
impl OperationCost {
pub fn new() -> Self {
Self::default()
}
pub fn compute_only(flops: f64) -> Self {
Self {
compute_flops: flops,
..Default::default()
}
}
pub fn compute_and_memory(flops: f64, memory_bytes: f64) -> Self {
Self {
compute_flops: flops,
memory_bytes,
..Default::default()
}
}
pub fn with_custom(mut self, key: impl Into<String>, value: f64) -> Self {
self.custom.insert(key.into(), value);
self
}
pub fn add(&self, other: &OperationCost) -> OperationCost {
OperationCost {
compute_flops: self.compute_flops + other.compute_flops,
memory_bytes: self.memory_bytes.max(other.memory_bytes), communication_bytes: self.communication_bytes + other.communication_bytes,
io_bytes: self.io_bytes + other.io_bytes,
latency_ms: self.latency_ms + other.latency_ms,
custom: {
let mut merged = self.custom.clone();
for (k, v) in &other.custom {
*merged.entry(k.clone()).or_insert(0.0) += v;
}
merged
},
}
}
pub fn max(&self, other: &OperationCost) -> OperationCost {
OperationCost {
compute_flops: self.compute_flops.max(other.compute_flops),
memory_bytes: self.memory_bytes + other.memory_bytes, communication_bytes: self.communication_bytes.max(other.communication_bytes),
io_bytes: self.io_bytes.max(other.io_bytes),
latency_ms: self.latency_ms.max(other.latency_ms),
custom: {
let mut merged = self.custom.clone();
for (k, v) in &other.custom {
let entry = merged.entry(k.clone()).or_insert(0.0);
*entry = entry.max(*v);
}
merged
},
}
}
}
/// Cost annotations for an entire [`EinsumGraph`]: one [`OperationCost`] per
/// node (keyed by node index) plus an aggregated total.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct GraphCostModel {
    /// Per-node costs, keyed by the node's index in the graph.
    pub node_costs: HashMap<usize, OperationCost>,
    /// Aggregate over all annotated nodes; refreshed by `compute_total_cost`.
    pub total_cost: OperationCost,
    /// Free-form key/value annotations (e.g. device, precision).
    /// Defaults to empty when absent from serialized input.
    #[serde(default)]
    pub metadata: HashMap<String, String>,
}
impl GraphCostModel {
pub fn new() -> Self {
Self {
node_costs: HashMap::new(),
total_cost: OperationCost::default(),
metadata: HashMap::new(),
}
}
pub fn set_node_cost(&mut self, node_idx: usize, cost: OperationCost) {
self.node_costs.insert(node_idx, cost);
}
pub fn get_node_cost(&self, node_idx: usize) -> Option<&OperationCost> {
self.node_costs.get(&node_idx)
}
pub fn compute_total_cost(&mut self, graph: &EinsumGraph) {
self.total_cost = estimate_graph_cost(graph, self);
}
pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
self.metadata.insert(key.into(), value.into());
self
}
pub fn summary(&self) -> CostSummary {
CostSummary {
total_flops: self.total_cost.compute_flops,
total_memory_bytes: self.total_cost.memory_bytes,
total_communication_bytes: self.total_cost.communication_bytes,
total_io_bytes: self.total_cost.io_bytes,
total_latency_ms: self.total_cost.latency_ms,
node_count: self.node_costs.len(),
}
}
}
impl Default for GraphCostModel {
fn default() -> Self {
Self::new()
}
}
/// Flat, display-friendly snapshot of a [`GraphCostModel`]'s totals,
/// produced by `GraphCostModel::summary`.
#[derive(Clone, Debug, PartialEq)]
pub struct CostSummary {
    /// Total estimated floating-point operations.
    pub total_flops: f64,
    /// Total estimated memory bytes (aggregated per the model's `add` rule).
    pub total_memory_bytes: f64,
    /// Total estimated communication bytes.
    pub total_communication_bytes: f64,
    /// Total estimated I/O bytes.
    pub total_io_bytes: f64,
    /// Total estimated latency in milliseconds.
    pub total_latency_ms: f64,
    /// Number of nodes that have a recorded cost.
    pub node_count: usize,
}
/// Produces a rough cost estimate for a single graph node based on its
/// operation type and fan-in/fan-out.
///
/// NOTE(review): `_tensor_sizes` is currently ignored — the constants below
/// are coarse per-op-type placeholders, presumably until size-aware
/// estimation is wired in; confirm before relying on absolute magnitudes.
pub fn estimate_operation_cost(
    node: &EinsumNode,
    _tensor_sizes: &HashMap<usize, Vec<usize>>,
) -> OperationCost {
    match &node.op {
        OpType::Einsum { spec } => {
            // Scale with fan-in × fan-out; record the spec length as a
            // crude complexity indicator.
            let fan_in = node.inputs.len() as f64;
            let fan_out = node.outputs.len() as f64;
            OperationCost::compute_and_memory(
                1000.0 * fan_in * fan_out,
                100.0 * (fan_in + fan_out),
            )
            .with_custom("spec_complexity", spec.len() as f64)
        }
        OpType::ElemUnary { .. } => OperationCost::compute_and_memory(100.0, 50.0),
        OpType::ElemBinary { .. } => OperationCost::compute_and_memory(200.0, 100.0),
        OpType::Reduce { .. } => OperationCost::compute_and_memory(500.0, 75.0),
    }
}
/// Aggregates the recorded per-node costs over all of `graph`'s nodes using
/// [`OperationCost::add`]. Nodes without a recorded cost are skipped.
pub fn estimate_graph_cost(graph: &EinsumGraph, cost_model: &GraphCostModel) -> OperationCost {
    (0..graph.nodes.len())
        .filter_map(|idx| cost_model.get_node_cost(idx))
        .fold(OperationCost::default(), |running, node_cost| {
            running.add(node_cost)
        })
}
/// Builds a [`GraphCostModel`] for `graph` by estimating every node's cost
/// with [`estimate_operation_cost`] and then computing the aggregate total.
///
/// Tensor sizes are not yet tracked, so an empty size map is passed through.
pub fn auto_annotate_costs(graph: &EinsumGraph) -> GraphCostModel {
    let tensor_sizes = HashMap::new();
    let mut model = GraphCostModel::new();
    for (idx, node) in graph.nodes.iter().enumerate() {
        model.set_node_cost(idx, estimate_operation_cost(node, &tensor_sizes));
    }
    model.compute_total_cost(graph);
    model
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::EinsumNode;

    #[test]
    fn test_operation_cost_creation() {
        let cost = OperationCost::compute_only(1000.0);
        assert_eq!(cost.compute_flops, 1000.0);
        assert_eq!(cost.memory_bytes, 0.0);
    }

    // `add` sums flops but takes the max of memory (see OperationCost::add).
    #[test]
    fn test_operation_cost_add() {
        let lhs = OperationCost::compute_and_memory(1000.0, 500.0);
        let rhs = OperationCost::compute_and_memory(2000.0, 300.0);
        let total = lhs.add(&rhs);
        assert_eq!(total.compute_flops, 3000.0);
        assert_eq!(total.memory_bytes, 500.0);
    }

    // `max` takes the max of flops but sums memory (see OperationCost::max).
    #[test]
    fn test_operation_cost_max() {
        let lhs = OperationCost::compute_and_memory(1000.0, 500.0);
        let rhs = OperationCost::compute_and_memory(2000.0, 300.0);
        let combined = lhs.max(&rhs);
        assert_eq!(combined.compute_flops, 2000.0);
        assert_eq!(combined.memory_bytes, 800.0);
    }

    #[test]
    fn test_cost_model_creation() {
        let mut model = GraphCostModel::new();
        let cost = OperationCost::compute_only(1000.0);
        model.set_node_cost(0, cost.clone());
        assert_eq!(model.get_node_cost(0), Some(&cost));
    }

    #[test]
    fn test_estimate_einsum_cost() {
        let node = EinsumNode::einsum("ik,kj->ij", vec![0, 1], vec![2]);
        let sizes = HashMap::new();
        let cost = estimate_operation_cost(&node, &sizes);
        assert!(cost.compute_flops > 0.0);
        assert!(cost.memory_bytes > 0.0);
    }

    #[test]
    fn test_auto_annotate_costs() {
        // Build a minimal graph: C = outer product of inputs A and B.
        let mut graph = EinsumGraph::new();
        let a = graph.add_tensor("A");
        let b = graph.add_tensor("B");
        let c = graph.add_tensor("C");
        graph.add_input(a).expect("unwrap");
        graph.add_input(b).expect("unwrap");
        graph
            .add_node(EinsumNode::einsum("i,j->ij", vec![a, b], vec![c]))
            .expect("unwrap");
        graph.add_output(c).expect("unwrap");

        let cost_model = auto_annotate_costs(&graph);
        assert_eq!(cost_model.node_costs.len(), 1);
        assert!(cost_model.total_cost.compute_flops > 0.0);
    }

    #[test]
    fn test_cost_summary() {
        let mut model = GraphCostModel::new();
        model.set_node_cost(0, OperationCost::compute_and_memory(1000.0, 500.0));
        model.set_node_cost(1, OperationCost::compute_and_memory(2000.0, 300.0));
        assert_eq!(model.summary().node_count, 2);
    }

    #[test]
    fn test_custom_cost_metrics() {
        let cost = OperationCost::new()
            .with_custom("custom_metric", 42.0)
            .with_custom("another_metric", 100.0);
        assert_eq!(cost.custom.get("custom_metric"), Some(&42.0));
        assert_eq!(cost.custom.get("another_metric"), Some(&100.0));
    }

    #[test]
    fn test_cost_model_metadata() {
        let model = GraphCostModel::new()
            .with_metadata("device", "GPU")
            .with_metadata("precision", "fp32");
        assert_eq!(model.metadata.get("device"), Some(&"GPU".to_string()));
        assert_eq!(model.metadata.get("precision"), Some(&"fp32".to_string()));
    }
}