use std::collections::HashMap;
use anyhow::Result;
use tensorlogic_ir::{EinsumGraph, EinsumNode, OpType};
/// Counters describing what a run of [`optimize_einsum_graph`] changed.
#[derive(Debug, Clone)]
pub struct EinsumOptResult {
    /// Value returned by the einsum-merging pass.
    pub merged_count: usize,
    /// Value returned by the identity-elimination pass.
    pub identity_eliminated: usize,
    /// Value returned by the contraction-reordering pass.
    pub reordered_count: usize,
    /// Rough speedup ratio; `1.0` means "no estimated gain".
    pub estimated_speedup: f64,
}

impl EinsumOptResult {
    /// Creates a result with every counter at zero and a speedup of `1.0`.
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns `true` when at least one of the three counters is non-zero.
    pub fn has_changes(&self) -> bool {
        [
            self.merged_count,
            self.identity_eliminated,
            self.reordered_count,
        ]
        .iter()
        .any(|&count| count != 0)
    }
}

impl Default for EinsumOptResult {
    /// Same value as [`EinsumOptResult::new`]: zeroed counters, speedup `1.0`.
    fn default() -> Self {
        Self {
            merged_count: 0,
            identity_eliminated: 0,
            reordered_count: 0,
            estimated_speedup: 1.0,
        }
    }
}
pub fn optimize_einsum_graph(graph: &mut EinsumGraph) -> Result<EinsumOptResult> {
let mut result = EinsumOptResult::new();
result.identity_eliminated = eliminate_identity_ops(graph)?;
result.merged_count = merge_consecutive_einsums(graph)?;
result.reordered_count = optimize_contraction_order(graph)?;
let total_eliminated = result.merged_count + result.identity_eliminated;
if total_eliminated > 0 {
let total_ops = graph.nodes.len() + total_eliminated;
result.estimated_speedup = total_ops as f64 / graph.nodes.len().max(1) as f64;
}
Ok(result)
}
/// Removes every node for which [`is_identity_op`] holds and rewires the
/// remaining nodes and the graph outputs to read the identity's input instead.
///
/// Aliases are resolved transitively: if identity `A` feeds identity `B`,
/// consumers of `B`'s output are pointed at `A`'s *input*, not at `A`'s
/// (now removed) output.
///
/// # Returns
/// The number of nodes removed.
fn eliminate_identity_ops(graph: &mut EinsumGraph) -> Result<usize> {
    let mut tensor_map: HashMap<usize, usize> = HashMap::new();
    let mut eliminated = 0;

    for (idx, node) in graph.nodes.iter().enumerate() {
        if !is_identity_op(node) {
            continue;
        }
        eliminated += 1;
        if let Some(input_tensor) = get_first_input(node) {
            // NOTE(review): this assumes node `idx` writes tensor `idx + 1`,
            // the same convention `build_dependency_graph` uses. Confirm
            // against how the graph assigns output tensor ids.
            tensor_map.insert(idx + 1, input_tensor);
        }
    }

    if eliminated == 0 {
        return Ok(0);
    }

    // Follow alias chains to their root. The hop count is bounded by the map
    // size so a (malformed) cyclic alias cannot loop forever.
    let resolved: HashMap<usize, usize> = tensor_map
        .keys()
        .map(|&output_tensor| {
            let mut source = output_tensor;
            for _ in 0..tensor_map.len() {
                match tensor_map.get(&source) {
                    Some(&next) => source = next,
                    None => break,
                }
            }
            (output_tensor, source)
        })
        .collect();

    // Single O(n) sweep instead of repeated `Vec::remove`.
    graph.nodes.retain(|node| !is_identity_op(node));

    for node in graph.nodes.iter_mut() {
        remap_node_inputs(node, &resolved);
    }
    for output in graph.outputs.iter_mut() {
        if let Some(&new_idx) = resolved.get(&*output) {
            *output = new_idx;
        }
    }

    Ok(eliminated)
}
fn merge_consecutive_einsums(graph: &mut EinsumGraph) -> Result<usize> {
let mut merged = 0;
let mut changed = true;
use std::collections::HashSet;
let mut processed_nodes: HashSet<usize> = HashSet::new();
let max_iterations = graph.nodes.len() * 2;
let mut iteration = 0;
while changed && iteration < max_iterations {
changed = false;
iteration += 1;
let dependencies = build_dependency_graph(graph);
let mut merge_candidate: Option<(usize, usize, String, Vec<usize>)> = None;
for (idx, node) in graph.nodes.iter().enumerate() {
if processed_nodes.contains(&idx) {
continue;
}
if let OpType::Einsum { spec } = &node.op {
if is_identity_op(node) {
continue;
}
for &input_tensor in &node.inputs {
if let Some(&producer_idx) = dependencies.get(&input_tensor) {
if processed_nodes.contains(&producer_idx) {
continue;
}
if let OpType::Einsum { spec: prev_spec } = &graph.nodes[producer_idx].op {
if is_identity_op(&graph.nodes[producer_idx]) {
continue;
}
let prev_inputs = &graph.nodes[producer_idx].inputs;
if let Some(merged_spec) =
try_merge_einsum_specs(prev_spec, spec, input_tensor)
{
let mut merged_inputs = prev_inputs.clone();
for &inp in &node.inputs {
if inp != input_tensor {
merged_inputs.push(inp);
}
}
merge_candidate =
Some((idx, producer_idx, merged_spec, merged_inputs));
break;
}
}
}
}
if merge_candidate.is_some() {
break;
}
}
}
if let Some((consumer_idx, producer_idx, merged_spec, merged_inputs)) = merge_candidate {
graph.nodes[consumer_idx].op = OpType::Einsum { spec: merged_spec };
graph.nodes[consumer_idx].inputs = merged_inputs;
processed_nodes.insert(producer_idx);
merged += 1;
changed = true;
}
}
Ok(merged)
}
fn optimize_contraction_order(graph: &mut EinsumGraph) -> Result<usize> {
let mut reordered = 0;
for node in graph.nodes.iter_mut() {
if let OpType::Einsum { spec } = &node.op {
if node.inputs.len() > 2 {
if let Some(new_order) = find_optimal_contraction_order(spec, &node.inputs) {
node.inputs = new_order;
reordered += 1;
}
}
}
}
Ok(reordered)
}
/// Reports whether `node` is a no-op.
///
/// A node counts as an identity when it is either
/// * a single-input einsum whose spec has exactly one `->` and whose trimmed
///   left and right sides are the same string (e.g. `"ab->ab"`), or
/// * an element-wise unary op named `"identity"`.
///
/// Binary element-wise ops and reductions are never identities.
fn is_identity_op(node: &EinsumNode) -> bool {
    match &node.op {
        OpType::Einsum { spec } if node.inputs.len() == 1 => {
            let mut sides = spec.split("->");
            match (sides.next(), sides.next(), sides.next()) {
                // Exactly two sides: compare the subscripts textually.
                (Some(lhs), Some(rhs), None) => lhs.trim() == rhs.trim(),
                _ => false,
            }
        }
        OpType::ElemUnary { op } => op == "identity",
        OpType::Einsum { .. } | OpType::ElemBinary { .. } | OpType::Reduce { .. } => false,
    }
}
/// Returns the first entry of `node.inputs`, or `None` when the node has no
/// inputs.
fn get_first_input(node: &EinsumNode) -> Option<usize> {
    let head = node.inputs.first()?;
    Some(*head)
}
/// Rewrites each input of `node` that appears as a key in `tensor_map` to the
/// mapped value. Inputs without an entry are left unchanged. The lookup is a
/// single hop; mapped values are not looked up again.
fn remap_node_inputs(node: &mut EinsumNode, tensor_map: &HashMap<usize, usize>) {
    for slot in node.inputs.iter_mut() {
        let current = *slot;
        *slot = tensor_map.get(&current).copied().unwrap_or(current);
    }
}
/// Builds a map from tensor id to the index of the node treated as its
/// producer.
///
/// The map contains one entry `idx + 1 -> idx` per node in `graph.nodes`;
/// node contents are not inspected.
// NOTE(review): this encodes the assumption that node `idx` writes tensor
// `idx + 1` — confirm against how the graph assigns output tensors.
fn build_dependency_graph(graph: &EinsumGraph) -> HashMap<usize, usize> {
    let node_count = graph.nodes.len();
    (0..node_count)
        .map(|producer| (producer + 1, producer))
        .collect()
}
/// Fuses a producer einsum spec into the consumer spec that reads its output.
///
/// The merged spec lists the producer's operands first, followed by the
/// consumer's operands minus the one standing for the intermediate tensor,
/// and keeps the consumer's output.
///
/// The intermediate operand is taken to be the first consumer operand with as
/// many labels as the producer's output; the tensor id is not used because
/// the spec carries no tensor ids. Only that single operand is replaced, so
/// the caller must ensure the consumer reads the intermediate exactly once.
///
/// Labels are reconciled so the merged spec computes the same result:
/// * each producer output label is renamed to the label the consumer uses for
///   that axis of the intermediate;
/// * producer labels that are summed away internally are renamed to unused
///   letters when they would collide with a consumer label.
///
/// Operand whitespace is trimmed in the result.
///
/// # Returns
/// `None` when either spec lacks exactly one `->`, when no consumer operand
/// has a matching label count, when a repeated producer output label would
/// need two different names, or when the `a-z`/`A-Z` label pool is exhausted.
fn try_merge_einsum_specs(
    prev_spec: &str,
    curr_spec: &str,
    _intermediate_tensor: usize,
) -> Option<String> {
    /// Splits `lhs->rhs` into trimmed operand subscripts and trimmed output.
    fn split_spec(spec: &str) -> Option<(Vec<&str>, &str)> {
        let sides: Vec<&str> = spec.split("->").collect();
        if sides.len() != 2 {
            return None;
        }
        Some((sides[0].split(',').map(|s| s.trim()).collect(), sides[1].trim()))
    }

    let (prev_inputs, prev_output) = split_spec(prev_spec)?;
    let (curr_inputs, curr_output) = split_spec(curr_spec)?;

    let rank = prev_output.chars().count();
    let position = curr_inputs
        .iter()
        .position(|operand| operand.chars().count() == rank)?;
    let intermediate = curr_inputs[position];

    // Producer output label -> consumer label for the same axis.
    let mut rename: HashMap<char, char> = HashMap::new();
    for (from, to) in prev_output.chars().zip(intermediate.chars()) {
        if *rename.entry(from).or_insert(to) != to {
            return None;
        }
    }

    let curr_labels: Vec<char> = curr_inputs
        .iter()
        .flat_map(|operand| operand.chars())
        .chain(curr_output.chars())
        .collect();

    // Producer-internal labels keep their name unless the consumer already
    // uses it, in which case they get a letter unused by either spec.
    let mut taken: Vec<char> = prev_spec.chars().chain(curr_spec.chars()).collect();
    let mut pool = ('a'..='z').chain('A'..='Z');
    for label in prev_inputs.iter().flat_map(|operand| operand.chars()) {
        if rename.contains_key(&label) {
            continue;
        }
        let replacement = if curr_labels.contains(&label) {
            let candidate = pool.by_ref().find(|c| !taken.contains(c))?;
            taken.push(candidate);
            candidate
        } else {
            label
        };
        rename.insert(label, replacement);
    }

    let mut operands: Vec<String> = prev_inputs
        .iter()
        .map(|operand| {
            operand
                .chars()
                .map(|c| rename.get(&c).copied().unwrap_or(c))
                .collect::<String>()
        })
        .collect();
    operands.extend(
        curr_inputs
            .iter()
            .enumerate()
            .filter(|&(i, _)| i != position)
            .map(|(_, operand)| (*operand).to_string()),
    );

    Some(format!("{}->{}", operands.join(","), curr_output))
}
/// Suggests a new operand order for a multi-operand einsum.
///
/// The current heuristic is deliberately simple: if any index label occurs
/// more than once across the operand subscripts, the suggestion is the input
/// list reversed.
///
/// # Returns
/// `Some(order)` only when the suggestion differs from `inputs`. `None` when
/// there are two or fewer inputs, the spec does not contain exactly one `->`,
/// the operand count differs from `inputs.len()`, no label repeats, or the
/// reversed list equals the original.
fn find_optimal_contraction_order(spec: &str, inputs: &[usize]) -> Option<Vec<usize>> {
    if inputs.len() < 3 {
        return None;
    }

    let mut sides = spec.split("->");
    let lhs = match (sides.next(), sides.next(), sides.next()) {
        (Some(lhs), Some(_), None) => lhs,
        _ => return None,
    };

    if lhs.split(',').count() != inputs.len() {
        return None;
    }

    // A label that occurs twice shows up as an adjacent pair once sorted.
    let mut labels: Vec<char> = lhs
        .split(',')
        .flat_map(|operand| operand.trim().chars())
        .collect();
    labels.sort_unstable();
    let has_repeated_label = labels.windows(2).any(|pair| pair[0] == pair[1]);
    if !has_repeated_label {
        return None;
    }

    let reversed: Vec<usize> = inputs.iter().rev().copied().collect();
    if reversed.as_slice() == inputs {
        None
    } else {
        Some(reversed)
    }
}
// Unit tests for the einsum optimization helpers. Graph-level passes are only
// exercised on an empty graph here.
#[cfg(test)]
mod tests {
use super::*;
// A fresh result has zeroed counters, a speedup of 1.0, and reports no changes.
#[test]
fn test_einsum_opt_result_creation() {
let result = EinsumOptResult::new();
assert_eq!(result.merged_count, 0);
assert_eq!(result.identity_eliminated, 0);
assert_eq!(result.reordered_count, 0);
assert_eq!(result.estimated_speedup, 1.0);
assert!(!result.has_changes());
}
// Each counter, on its own, is enough for `has_changes` to return true.
#[test]
fn test_einsum_opt_result_has_changes() {
let mut result = EinsumOptResult::new();
assert!(!result.has_changes());
result.merged_count = 1;
assert!(result.has_changes());
result = EinsumOptResult::new();
result.identity_eliminated = 1;
assert!(result.has_changes());
result = EinsumOptResult::new();
result.reordered_count = 1;
assert!(result.has_changes());
}
// Only a single-input einsum with textually equal sides is an identity;
// a two-input contraction and a transpose ("ab->ba") are not.
#[test]
fn test_is_identity_op() {
let node = EinsumNode::new("ab->ab", vec![0], vec![1]);
assert!(is_identity_op(&node));
let node = EinsumNode::new("ab,bc->ac", vec![0, 1], vec![2]);
assert!(!is_identity_op(&node));
let node = EinsumNode::new("ab->ba", vec![0], vec![1]);
assert!(!is_identity_op(&node));
}
// The first input is returned for einsum, unary, and reduce nodes alike.
#[test]
fn test_get_first_input() {
let node = EinsumNode::new("ab->a", vec![5, 6], vec![7]);
assert_eq!(get_first_input(&node), Some(5));
let node = EinsumNode::elem_unary("relu", 10, 11);
assert_eq!(get_first_input(&node), Some(10));
let node = EinsumNode::reduce("sum", vec![0], 7, 8);
assert_eq!(get_first_input(&node), Some(7));
}
// Each pass must succeed and report zero work on an empty graph.
#[test]
fn test_eliminate_identity_ops_empty_graph() {
let mut graph = EinsumGraph::new();
let eliminated = eliminate_identity_ops(&mut graph).unwrap();
assert_eq!(eliminated, 0);
}
#[test]
fn test_merge_consecutive_einsums_empty_graph() {
let mut graph = EinsumGraph::new();
let merged = merge_consecutive_einsums(&mut graph).unwrap();
assert_eq!(merged, 0);
}
#[test]
fn test_optimize_contraction_order_empty_graph() {
let mut graph = EinsumGraph::new();
let reordered = optimize_contraction_order(&mut graph).unwrap();
assert_eq!(reordered, 0);
}
// The full pipeline on an empty graph reports no changes.
#[test]
fn test_optimize_einsum_graph_empty() {
let mut graph = EinsumGraph::new();
let result = optimize_einsum_graph(&mut graph).unwrap();
assert_eq!(result.merged_count, 0);
assert_eq!(result.identity_eliminated, 0);
assert_eq!(result.reordered_count, 0);
assert!(!result.has_changes());
}
// Two-input einsums are never reordered.
#[test]
fn test_find_optimal_contraction_order_simple() {
let result = find_optimal_contraction_order("ab,bc->ac", &[0, 1]);
assert!(result.is_none());
}
// Inputs present in the map are replaced by their mapped values.
#[test]
fn test_remap_node_inputs() {
let mut node = EinsumNode::new("ab,bc->ac", vec![0, 1], vec![2]);
let mut tensor_map = HashMap::new();
tensor_map.insert(0, 5);
tensor_map.insert(1, 6);
remap_node_inputs(&mut node, &tensor_map);
assert_eq!(node.inputs, vec![5, 6]);
}
}