use std::collections::{HashMap, HashSet};
use serde::{Deserialize, Serialize};
use crate::{EinsumGraph, EinsumNode, IrError, OpType};
/// Constant-related facts recorded for a single tensor of an `EinsumGraph`.
///
/// `analyze_constants` creates one record per tensor it classifies as
/// constant; the simplification helpers in this file read the `is_zero` and
/// `is_identity` flags to recognise operands that make an operation redundant.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ConstantInfo {
/// Index of the tensor this record describes.
pub tensor_idx: usize,
/// Whether the tensor is treated as known at compile time.
pub is_compile_time_constant: bool,
/// Whether the tensor is flagged as an identity value; a `"mul"` with such
/// an operand is simplified to its other operand.
pub is_identity: bool,
/// Whether the tensor is flagged as a zero value; an `"add"` with such an
/// operand is simplified to its other operand, and a `"mul"` with one is
/// reported as eliminable.
pub is_zero: bool,
}
/// Outcome of a constant-propagation analysis over an `EinsumGraph`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ConstantPropagationResult {
/// Indices of every tensor classified as constant.
pub constant_tensors: HashSet<usize>,
/// Per-tensor details, keyed by tensor index.
pub constant_info: HashMap<usize, ConstantInfo>,
/// Counter of foldable operations found by the analysis.
pub foldable_operations: usize,
/// Heuristic speedup factor; `1.0` means no expected improvement.
pub estimated_speedup: f64,
}
impl ConstantPropagationResult {
    /// Builds the empty result: no constant tensors, nothing foldable, and a
    /// neutral speedup factor of `1.0`.
    pub fn none() -> Self {
        let constant_tensors = HashSet::new();
        let constant_info = HashMap::new();
        Self {
            constant_tensors,
            constant_info,
            foldable_operations: 0,
            estimated_speedup: 1.0,
        }
    }

    /// Reports whether `tensor_idx` is a member of `constant_tensors`.
    pub fn is_constant(&self, tensor_idx: usize) -> bool {
        let known = &self.constant_tensors;
        known.contains(&tensor_idx)
    }

    /// Looks up the [`ConstantInfo`] stored for `tensor_idx`, returning
    /// `None` when no record exists for that tensor.
    pub fn get_info(&self, tensor_idx: usize) -> Option<&ConstantInfo> {
        let records = &self.constant_info;
        records.get(&tensor_idx)
    }
}
/// Counters describing what a constant-folding pass found in a graph.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct FoldingStats {
/// Operations whose first output is a constant tensor.
pub operations_folded: usize,
/// Operations replaced by one of their own operands (for example adding a
/// zero or multiplying by an identity).
pub operations_simplified: usize,
/// Operations reported as eliminable (multiplication by a zero operand).
pub operations_eliminated: usize,
/// Heuristic speedup factor; `1.0` means no expected improvement.
pub estimated_speedup: f64,
}
impl FoldingStats {
    /// Builds the "nothing happened" statistics: every counter is zero and
    /// the speedup factor is the neutral `1.0`.
    pub fn none() -> Self {
        let (operations_folded, operations_simplified, operations_eliminated) = (0, 0, 0);
        Self {
            operations_folded,
            operations_simplified,
            operations_eliminated,
            estimated_speedup: 1.0,
        }
    }

    /// Total number of recorded transformations across all three counters.
    pub fn total_transformations(&self) -> usize {
        let folded_and_simplified = self.operations_folded + self.operations_simplified;
        folded_and_simplified + self.operations_eliminated
    }
}
/// Determines which tensors of `graph` are compile-time constants.
///
/// The analysis runs in two phases:
///
/// 1. **Seeding** – every tensor whose metadata carries `constant = "true"`
///    is recorded, together with its `identity` / `zero` flags.
/// 2. **Propagation** – the outputs of any node that has at least one input
///    and whose inputs are *all* constant are marked constant as well. This
///    is repeated until no new tensor is discovered, so the result does not
///    depend on the order of `graph.nodes`.
///
/// `foldable_operations` counts each such node exactly once, regardless of
/// how many outputs it produces. `estimated_speedup` is a heuristic:
/// `1.0 + 0.3 * (foldable operations / total operations)`.
///
/// # Errors
///
/// The current implementation never fails; the `Result` is kept for
/// interface stability.
pub fn analyze_constants(graph: &EinsumGraph) -> Result<ConstantPropagationResult, IrError> {
    let mut result = ConstantPropagationResult::none();

    // Phase 1: seed with tensors explicitly flagged as constant.
    for (tensor_idx, metadata) in &graph.tensor_metadata {
        if !is_compile_time_constant(metadata) {
            continue;
        }
        result.constant_tensors.insert(*tensor_idx);
        result.constant_info.insert(
            *tensor_idx,
            ConstantInfo {
                tensor_idx: *tensor_idx,
                is_compile_time_constant: true,
                is_identity: is_identity_value(metadata),
                is_zero: is_zero_value(metadata),
            },
        );
    }

    // Phase 2: propagate through the graph until a fixed point is reached.
    let mut changed = true;
    while changed {
        changed = false;
        for node in &graph.nodes {
            // Nodes without inputs are never treated as derived constants.
            let all_inputs_constant = !node.inputs.is_empty()
                && node
                    .inputs
                    .iter()
                    .all(|idx| result.constant_tensors.contains(idx));
            if !all_inputs_constant {
                continue;
            }

            let mut produced_new_constant = false;
            for &output_idx in &node.outputs {
                // `insert` returns `true` only for tensors not seen before.
                if result.constant_tensors.insert(output_idx) {
                    result.constant_info.insert(
                        output_idx,
                        ConstantInfo {
                            tensor_idx: output_idx,
                            is_compile_time_constant: true,
                            // Nothing is known about the value of a derived
                            // constant, so both flags stay conservative.
                            is_identity: false,
                            is_zero: false,
                        },
                    );
                    produced_new_constant = true;
                }
            }

            // Count the operation once, not once per output tensor, so the
            // folding ratio below can never exceed the number of nodes.
            if produced_new_constant {
                result.foldable_operations += 1;
                changed = true;
            }
        }
    }

    if result.foldable_operations > 0 {
        let total_ops = graph.nodes.len();
        let folding_ratio = result.foldable_operations as f64 / total_ops.max(1) as f64;
        result.estimated_speedup = 1.0 + folding_ratio * 0.3;
    }
    Ok(result)
}
/// Applies the findings of a constant analysis to `graph`.
///
/// Each node is classified into at most one category, checked in this order:
///
/// * **simplified** – `try_simplify_operation` names an operand that can
///   stand in for the node's first output; consumers are rewired to it;
/// * **eliminated** – `try_eliminate_operation` reports the node as removable;
/// * **folded** – the node's first output is a known constant.
///
/// Only simplifications change the graph: every node input and every graph
/// output that refers to a simplified tensor is redirected to its
/// replacement. Nodes themselves are never removed here.
///
/// `estimated_speedup` is a heuristic:
/// `1.0 + 0.4 * (transformations / total operations)`.
///
/// # Errors
///
/// The current implementation never fails.
pub fn apply_constant_folding(
    graph: &mut EinsumGraph,
    constants: &ConstantPropagationResult,
) -> Result<FoldingStats, IrError> {
    let mut stats = FoldingStats::none();
    // Maps a simplified node's first output to the operand that replaces it.
    let mut rewires: HashMap<usize, usize> = HashMap::new();

    for node in &graph.nodes {
        match try_simplify_operation(node, constants) {
            Some(forwarded) => {
                // A simplifiable node without outputs is left uncounted.
                if let Some(&produced) = node.outputs.first() {
                    rewires.insert(produced, forwarded);
                    stats.operations_simplified += 1;
                }
            }
            None if try_eliminate_operation(node, constants) => {
                stats.operations_eliminated += 1;
            }
            None => {
                // Nodes without outputs are looked up under `usize::MAX`.
                let first_output = node.outputs.first().copied().unwrap_or(usize::MAX);
                if constants.is_constant(first_output) {
                    stats.operations_folded += 1;
                }
            }
        }
    }

    // Redirect one tensor reference if a replacement is registered for it.
    let redirect = |slot: &mut usize| {
        if let Some(&target) = rewires.get(&*slot) {
            *slot = target;
        }
    };
    graph
        .nodes
        .iter_mut()
        .flat_map(|node| node.inputs.iter_mut())
        .for_each(&redirect);
    graph.outputs.iter_mut().for_each(&redirect);

    let transformed = stats.total_transformations();
    if transformed > 0 {
        let total_ops = graph.nodes.len().max(1);
        let optimization_ratio = transformed as f64 / total_ops as f64;
        stats.estimated_speedup = 1.0 + optimization_ratio * 0.4;
    }
    Ok(stats)
}
/// Runs constant analysis and folding repeatedly (at most three passes).
///
/// Each pass re-analyses the graph and rewires consumers of simplified
/// operations, which can expose further simplifications to the next pass.
/// Iteration stops early when a pass reports no transformations.
///
/// # Statistics
///
/// `apply_constant_folding` never removes nodes, so every pass re-counts all
/// transformable operations still present in the graph. Summing the passes
/// would therefore count the same operations several times and inflate the
/// speedup estimate. The returned counters are instead those of the most
/// recent pass that found anything, and `estimated_speedup` is the heuristic
/// `1.0 + 0.5 * (transformations / total operations)`.
///
/// # Errors
///
/// Propagates any error from `analyze_constants` or `apply_constant_folding`.
pub fn fold_constants_aggressive(graph: &mut EinsumGraph) -> Result<FoldingStats, IrError> {
    const MAX_PASSES: usize = 3;

    let mut total_stats = FoldingStats::none();
    for _ in 0..MAX_PASSES {
        let constants = analyze_constants(graph)?;
        let stats = apply_constant_folding(graph, &constants)?;
        if stats.total_transformations() == 0 {
            break;
        }
        // Overwrite rather than accumulate: each pass is a full census of the
        // graph, not a delta relative to the previous pass.
        total_stats.operations_folded = stats.operations_folded;
        total_stats.operations_simplified = stats.operations_simplified;
        total_stats.operations_eliminated = stats.operations_eliminated;
    }

    let transformed = total_stats.total_transformations();
    if transformed > 0 {
        let total_ops = graph.nodes.len().max(1);
        let optimization_ratio = transformed as f64 / total_ops as f64;
        total_stats.estimated_speedup = 1.0 + optimization_ratio * 0.5;
    }
    Ok(total_stats)
}
pub fn identify_constant_subgraphs(graph: &EinsumGraph) -> Result<Vec<Vec<usize>>, IrError> {
let constants = analyze_constants(graph)?;
let mut subgraphs = Vec::new();
let mut visited = HashSet::new();
for (node_idx, node) in graph.nodes.iter().enumerate() {
if visited.contains(&node_idx) {
continue;
}
let all_constant = node.inputs.iter().all(|&idx| constants.is_constant(idx));
if all_constant && !node.inputs.is_empty() {
let mut subgraph = vec![node_idx];
visited.insert(node_idx);
let mut changed = true;
while changed {
changed = false;
for (idx, n) in graph.nodes.iter().enumerate() {
if visited.contains(&idx) {
continue;
}
let depends_on_subgraph = n.inputs.iter().any(|&input_idx| {
graph.nodes.iter().enumerate().any(|(sub_idx, sub_node)| {
subgraph.contains(&sub_idx) && sub_node.outputs.contains(&input_idx)
})
});
if depends_on_subgraph {
subgraph.push(idx);
visited.insert(idx);
changed = true;
}
}
}
if !subgraph.is_empty() {
subgraphs.push(subgraph);
}
}
}
Ok(subgraphs)
}
/// Returns `true` only when `metadata` has a `"constant"` attribute whose
/// value is exactly `"true"`; a missing attribute or any other value yields
/// `false`.
fn is_compile_time_constant(metadata: &crate::Metadata) -> bool {
    let flag = metadata.get_attribute("constant");
    flag.is_some_and(|value| value == "true")
}
/// Returns `true` only when `metadata` has an `"identity"` attribute whose
/// value is exactly `"true"`; a missing attribute or any other value yields
/// `false`.
fn is_identity_value(metadata: &crate::Metadata) -> bool {
    let flag = metadata.get_attribute("identity");
    flag.is_some_and(|value| value == "true")
}
/// Returns `true` only when `metadata` has a `"zero"` attribute whose value
/// is exactly `"true"`; a missing attribute or any other value yields `false`.
fn is_zero_value(metadata: &crate::Metadata) -> bool {
    let flag = metadata.get_attribute("zero");
    flag.is_some_and(|value| value == "true")
}
/// Detects element-wise binary operations that reduce to one of their operands.
///
/// Recognised patterns (only for nodes with exactly two inputs):
///
/// * `"add"` where one operand is flagged zero → the other operand;
/// * `"mul"` where one operand is flagged identity → the other operand.
///
/// In both cases the right operand is inspected first. Returns the index of
/// the tensor that can replace the node's result, or `None` when no pattern
/// applies.
fn try_simplify_operation(
    node: &EinsumNode,
    constants: &ConstantPropagationResult,
) -> Option<usize> {
    let OpType::ElemBinary { op } = &node.op else {
        return None;
    };
    if node.inputs.len() != 2 {
        return None;
    }
    let (lhs, rhs) = (node.inputs[0], node.inputs[1]);

    let flagged_zero =
        |idx: usize| constants.get_info(idx).is_some_and(|info| info.is_zero);
    let flagged_identity =
        |idx: usize| constants.get_info(idx).is_some_and(|info| info.is_identity);

    if op == "add" {
        if flagged_zero(rhs) {
            return Some(lhs);
        }
        if flagged_zero(lhs) {
            return Some(rhs);
        }
    }
    if op == "mul" {
        if flagged_identity(rhs) {
            return Some(lhs);
        }
        if flagged_identity(lhs) {
            return Some(rhs);
        }
    }
    None
}
/// Reports whether `node` is a two-operand element-wise `"mul"` in which at
/// least one operand is flagged as zero.
///
/// Any other operation kind, operator name, or operand count yields `false`.
fn try_eliminate_operation(node: &EinsumNode, constants: &ConstantPropagationResult) -> bool {
    let OpType::ElemBinary { op } = &node.op else {
        return false;
    };
    if node.inputs.len() != 2 || op != "mul" {
        return false;
    }
    node.inputs
        .iter()
        .any(|&operand| constants.get_info(operand).is_some_and(|info| info.is_zero))
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Metadata;

    /// Metadata flagged only as a compile-time constant.
    fn create_constant_metadata() -> Metadata {
        Metadata::new().with_attribute("constant", "true")
    }

    /// Metadata flagged as a constant zero.
    fn create_zero_metadata() -> Metadata {
        create_constant_metadata().with_attribute("zero", "true")
    }

    /// Metadata flagged as a constant identity value.
    fn create_identity_metadata() -> Metadata {
        create_constant_metadata().with_attribute("identity", "true")
    }

    /// Analysis result that knows about exactly one constant tensor.
    fn single_constant(
        tensor_idx: usize,
        is_identity: bool,
        is_zero: bool,
    ) -> ConstantPropagationResult {
        let mut analysis = ConstantPropagationResult::none();
        analysis.constant_tensors.insert(tensor_idx);
        analysis.constant_info.insert(
            tensor_idx,
            ConstantInfo {
                tensor_idx,
                is_compile_time_constant: true,
                is_identity,
                is_zero,
            },
        );
        analysis
    }

    #[test]
    fn test_constant_info() {
        let info = ConstantInfo {
            tensor_idx: 0,
            is_compile_time_constant: true,
            is_identity: false,
            is_zero: false,
        };

        assert_eq!(info.tensor_idx, 0);
        assert!(info.is_compile_time_constant);
        assert!(!info.is_identity);
        assert!(!info.is_zero);
    }

    #[test]
    fn test_constant_propagation_result_none() {
        let empty = ConstantPropagationResult::none();

        assert!(empty.constant_tensors.is_empty());
        assert!(empty.constant_info.is_empty());
        assert_eq!(empty.foldable_operations, 0);
        assert_eq!(empty.estimated_speedup, 1.0);
    }

    #[test]
    fn test_folding_stats_none() {
        let empty = FoldingStats::none();

        assert_eq!(empty.operations_folded, 0);
        assert_eq!(empty.operations_simplified, 0);
        assert_eq!(empty.operations_eliminated, 0);
        assert_eq!(empty.total_transformations(), 0);
    }

    #[test]
    fn test_analyze_constants_empty_graph() {
        let graph = EinsumGraph::new();

        let analysis = analyze_constants(&graph).expect("unwrap");

        assert!(analysis.constant_tensors.is_empty());
    }

    #[test]
    fn test_analyze_constants_with_metadata() {
        let mut graph = EinsumGraph::new();
        let a = graph.add_tensor_with_metadata("A", create_constant_metadata());
        let b = graph.add_tensor("B");
        graph
            .add_node(EinsumNode::elem_unary("relu", a, b))
            .expect("unwrap");

        let analysis = analyze_constants(&graph).expect("unwrap");

        assert!(analysis.is_constant(a));
        // The output of an operation on constants is itself constant.
        assert!(analysis.is_constant(b));
        assert_eq!(analysis.foldable_operations, 1);
    }

    #[test]
    fn test_simplify_add_zero() {
        let mut graph = EinsumGraph::new();
        let x = graph.add_tensor("x");
        let zero = graph.add_tensor_with_metadata("zero", create_zero_metadata());
        let result = graph.add_tensor("result");
        let node = EinsumNode::elem_binary("add", x, zero, result);
        let analysis = single_constant(zero, false, true);

        let simplified = try_simplify_operation(&node, &analysis);

        assert_eq!(simplified, Some(x));
    }

    #[test]
    fn test_simplify_mul_one() {
        let mut graph = EinsumGraph::new();
        let x = graph.add_tensor("x");
        let one = graph.add_tensor_with_metadata("one", create_identity_metadata());
        let result = graph.add_tensor("result");
        let node = EinsumNode::elem_binary("mul", x, one, result);
        let analysis = single_constant(one, true, false);

        let simplified = try_simplify_operation(&node, &analysis);

        assert_eq!(simplified, Some(x));
    }

    #[test]
    fn test_eliminate_mul_zero() {
        let mut graph = EinsumGraph::new();
        let x = graph.add_tensor("x");
        let zero = graph.add_tensor_with_metadata("zero", create_zero_metadata());
        let result = graph.add_tensor("result");
        let node = EinsumNode::elem_binary("mul", x, zero, result);
        let analysis = single_constant(zero, false, true);

        let should_eliminate = try_eliminate_operation(&node, &analysis);

        assert!(should_eliminate);
    }

    #[test]
    fn test_apply_constant_folding() {
        let mut graph = EinsumGraph::new();
        let a = graph.add_tensor_with_metadata("A", create_constant_metadata());
        let b = graph.add_tensor_with_metadata("B", create_constant_metadata());
        let c = graph.add_tensor("C");
        graph
            .add_node(EinsumNode::elem_binary("add", a, b, c))
            .expect("unwrap");

        let constants = analyze_constants(&graph).expect("unwrap");
        let stats = apply_constant_folding(&mut graph, &constants).expect("unwrap");

        assert!(stats.operations_folded > 0 || stats.total_transformations() > 0);
    }

    #[test]
    fn test_fold_constants_aggressive() {
        let mut graph = EinsumGraph::new();
        let a = graph.add_tensor_with_metadata("A", create_constant_metadata());
        let b = graph.add_tensor_with_metadata("B", create_constant_metadata());
        let c = graph.add_tensor("C");
        let d = graph.add_tensor("D");
        graph
            .add_node(EinsumNode::elem_binary("add", a, b, c))
            .expect("unwrap");
        graph
            .add_node(EinsumNode::elem_unary("relu", c, d))
            .expect("unwrap");

        let stats = fold_constants_aggressive(&mut graph).expect("unwrap");

        assert!(stats.operations_folded >= 1);
    }

    #[test]
    fn test_identify_constant_subgraphs() {
        let mut graph = EinsumGraph::new();
        let a = graph.add_tensor_with_metadata("A", create_constant_metadata());
        let b = graph.add_tensor_with_metadata("B", create_constant_metadata());
        let c = graph.add_tensor("C");
        graph
            .add_node(EinsumNode::elem_binary("add", a, b, c))
            .expect("unwrap");

        let subgraphs = identify_constant_subgraphs(&graph).expect("unwrap");

        assert!(!subgraphs.is_empty());
    }

    #[test]
    fn test_is_constant_metadata_helpers() {
        let plain = create_constant_metadata();
        assert!(is_compile_time_constant(&plain));

        let zero = create_zero_metadata();
        assert!(is_compile_time_constant(&zero));
        assert!(is_zero_value(&zero));

        let identity = create_identity_metadata();
        assert!(is_compile_time_constant(&identity));
        assert!(is_identity_value(&identity));
    }

    #[test]
    fn test_constant_propagation_through_chain() {
        let mut graph = EinsumGraph::new();
        let a = graph.add_tensor_with_metadata("A", create_constant_metadata());
        let b = graph.add_tensor("B");
        let c = graph.add_tensor("C");
        let d = graph.add_tensor("D");
        // A -> B -> C -> D, each link a unary operation.
        for (src, dst) in [(a, b), (b, c), (c, d)] {
            graph
                .add_node(EinsumNode::elem_unary("relu", src, dst))
                .expect("unwrap");
        }

        let analysis = analyze_constants(&graph).expect("unwrap");

        assert!(analysis.is_constant(a));
        assert!(analysis.is_constant(b));
        assert!(analysis.is_constant(c));
        assert!(analysis.is_constant(d));
        assert_eq!(analysis.foldable_operations, 3);
    }

    #[test]
    fn test_mixed_constant_and_variable_graph() {
        let mut graph = EinsumGraph::new();
        let const_a = graph.add_tensor_with_metadata("const_A", create_constant_metadata());
        let var_x = graph.add_tensor("var_X");
        let result = graph.add_tensor("result");
        graph
            .add_node(EinsumNode::elem_binary("add", const_a, var_x, result))
            .expect("unwrap");

        let analysis = analyze_constants(&graph).expect("unwrap");

        assert!(analysis.is_constant(const_a));
        assert!(!analysis.is_constant(var_x));
        // One variable operand is enough to make the result non-constant.
        assert!(!analysis.is_constant(result));
    }

    #[test]
    fn test_folding_stats_total_transformations() {
        let stats = FoldingStats {
            operations_folded: 2,
            operations_simplified: 3,
            operations_eliminated: 1,
            estimated_speedup: 1.5,
        };

        assert_eq!(stats.total_transformations(), 6);
    }

    #[test]
    fn test_speedup_estimation() {
        let mut graph = EinsumGraph::new();
        let a = graph.add_tensor_with_metadata("A", create_constant_metadata());
        let b = graph.add_tensor("B");
        graph
            .add_node(EinsumNode::elem_unary("relu", a, b))
            .expect("unwrap");

        let analysis = analyze_constants(&graph).expect("unwrap");

        assert!(analysis.estimated_speedup > 1.0);
    }
}