use crate::graph::Node;
use std::collections::HashMap;
fn node_fingerprint(node: &Node) -> String {
let op_str = node.op.as_str();
let mut sorted_inputs = node.inputs.clone();
sorted_inputs.sort();
let inputs_str = sorted_inputs.join(",");
let mut attr_parts: Vec<String> = Vec::new();
let mut float_keys: Vec<&String> = node.attrs.floats.keys().collect();
float_keys.sort();
for k in float_keys {
if let Some(v) = node.attrs.floats.get(k) {
attr_parts.push(format!("f:{k}={v}"));
}
}
let mut int_keys: Vec<&String> = node.attrs.ints.keys().collect();
int_keys.sort();
for k in int_keys {
if let Some(v) = node.attrs.ints.get(k) {
attr_parts.push(format!("i:{k}={v}"));
}
}
let mut str_keys: Vec<&String> = node.attrs.strings.keys().collect();
str_keys.sort();
for k in str_keys {
if let Some(v) = node.attrs.strings.get(k) {
attr_parts.push(format!("s:{k}={v}"));
}
}
let mut il_keys: Vec<&String> = node.attrs.int_lists.keys().collect();
il_keys.sort();
for k in il_keys {
if let Some(v) = node.attrs.int_lists.get(k) {
attr_parts.push(format!("il:{k}={v:?}"));
}
}
let mut fl_keys: Vec<&String> = node.attrs.float_lists.keys().collect();
fl_keys.sort();
for k in fl_keys {
if let Some(v) = node.attrs.float_lists.get(k) {
attr_parts.push(format!("fl:{k}={v:?}"));
}
}
let mut t_keys: Vec<&String> = node.attrs.tensors.keys().collect();
t_keys.sort();
for k in t_keys {
if let Some(t) = node.attrs.tensors.get(k) {
let data_hash: u64 = t
.data
.iter()
.fold(0u64, |acc, &val| acc.wrapping_add(val.to_bits() as u64));
attr_parts.push(format!("t:{k}=shape{:?}hash{data_hash}", t.shape));
}
}
let attrs_str = attr_parts.join(";");
format!("{op_str}|{inputs_str}|{attrs_str}")
}
/// Rewrites every input name that has an entry in `redirects` to its target.
fn apply_redirects(inputs: &mut [String], redirects: &HashMap<String, String>) {
    for input in inputs {
        if let Some(target) = redirects.get(input.as_str()) {
            input.clone_from(target);
        }
    }
}

/// Returns `true` when every non-empty output of `duplicate` has a non-empty
/// output at the same position in `original`, i.e. when all consumers of the
/// duplicate can be pointed at the original instead.
fn outputs_are_covered(duplicate: &[String], original: &[String]) -> bool {
    duplicate.iter().enumerate().all(|(idx, out)| {
        out.is_empty() || matches!(original.get(idx), Some(orig) if !orig.is_empty())
    })
}

/// Removes nodes that recompute a value already produced by an earlier node.
///
/// Two nodes are considered equivalent when [`node_fingerprint`] returns the
/// same string for both. The first occurrence is kept; later occurrences are
/// dropped and every reference to their outputs is rewritten to the matching
/// output of the kept node. The relative order of the surviving nodes is
/// preserved.
///
/// Behaviour worth knowing:
/// * Inputs are redirected *before* a node is fingerprinted, so a consumer of
///   an eliminated node is compared using the surviving names. When producers
///   appear before their consumers in `nodes`, whole chains of duplicates are
///   removed in one call.
/// * A final pass applies the redirects to all surviving nodes, so consumers
///   that were listed before the duplicate they read from are rewritten too.
/// * A node is only dropped when each of its non-empty outputs has a non-empty
///   counterpart on the kept node; otherwise it is retained, because removing
///   it would leave consumers of the unmatched outputs without a producer.
pub fn eliminate_common_subexpressions(nodes: Vec<Node>) -> Vec<Node> {
    // fingerprint -> outputs of the first node that produced it
    let mut seen: HashMap<String, Vec<String>> = HashMap::with_capacity(nodes.len());
    // output name of a removed node -> equivalent output of the kept node
    let mut redirects: HashMap<String, String> = HashMap::new();
    let mut kept: Vec<Node> = Vec::with_capacity(nodes.len());

    for mut node in nodes {
        apply_redirects(&mut node.inputs, &redirects);
        let fp = node_fingerprint(&node);

        let mergeable = seen
            .get(&fp)
            .filter(|original| outputs_are_covered(&node.outputs, original.as_slice()));

        if let Some(original_outputs) = mergeable {
            for (dup_out, orig_out) in node.outputs.iter().zip(original_outputs.iter()) {
                // `outputs_are_covered` guarantees `orig_out` is non-empty here.
                if !dup_out.is_empty() {
                    redirects.insert(dup_out.clone(), orig_out.clone());
                }
            }
            continue;
        }

        // Keep the outputs of the first node seen for this fingerprint.
        seen.entry(fp).or_insert_with(|| node.outputs.clone());
        kept.push(node);
    }

    if !redirects.is_empty() {
        for node in &mut kept {
            apply_redirects(&mut node.inputs, &redirects);
        }
    }

    kept
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::OpKind;
    use crate::optimizer::test_utils::make_node;

    /// Two nodes with the same op and inputs collapse into the first one.
    #[test]
    fn test_cse_removes_duplicate_nodes() {
        let first = make_node(OpKind::Add, "add1", vec!["x", "y"], vec!["out1"]);
        let second = make_node(OpKind::Add, "add2", vec!["x", "y"], vec!["out2"]);

        let optimized = eliminate_common_subexpressions(vec![first, second]);

        assert_eq!(optimized.len(), 1);
        assert_eq!(optimized[0].name, "add1");
    }

    /// Nodes whose integer attributes differ are both kept.
    #[test]
    fn test_cse_preserves_different_attrs() {
        let mut ungrouped = make_node(OpKind::Conv, "conv1", vec!["x", "w"], vec!["out1"]);
        ungrouped.attrs.ints.insert("group".to_string(), 1);

        let mut grouped = make_node(OpKind::Conv, "conv2", vec!["x", "w"], vec!["out2"]);
        grouped.attrs.ints.insert("group".to_string(), 4);

        let optimized = eliminate_common_subexpressions(vec![ungrouped, grouped]);

        assert_eq!(optimized.len(), 2);
    }

    /// A consumer of the removed duplicate is rewired to the surviving output.
    #[test]
    fn test_cse_redirects_consumers() {
        let graph = vec![
            make_node(OpKind::Add, "add1", vec!["x", "y"], vec!["out1"]),
            make_node(OpKind::Add, "add2", vec!["x", "y"], vec!["out2"]),
            make_node(OpKind::Relu, "relu", vec!["out2"], vec!["relu_out"]),
        ];

        let optimized = eliminate_common_subexpressions(graph);

        assert_eq!(optimized.len(), 2);
        assert_eq!(optimized[0].name, "add1");
        assert_eq!(optimized[1].name, "relu");
        assert_eq!(optimized[1].inputs[0], "out1");
    }

    /// Same op but different inputs: nothing is removed.
    #[test]
    fn test_cse_different_inputs_not_removed() {
        let graph = vec![
            make_node(OpKind::Add, "add1", vec!["x", "y"], vec!["out1"]),
            make_node(OpKind::Add, "add2", vec!["x", "z"], vec!["out2"]),
        ];

        let optimized = eliminate_common_subexpressions(graph);

        assert_eq!(optimized.len(), 2);
    }

    /// Same inputs but different ops: nothing is removed.
    #[test]
    fn test_cse_different_ops_not_removed() {
        let graph = vec![
            make_node(OpKind::Add, "add1", vec!["x", "y"], vec!["out1"]),
            make_node(OpKind::Mul, "mul1", vec!["x", "y"], vec!["out2"]),
        ];

        let optimized = eliminate_common_subexpressions(graph);

        assert_eq!(optimized.len(), 2);
    }

    /// Nodes whose float attributes differ are both kept.
    #[test]
    fn test_cse_preserves_different_float_attrs() {
        let mut unit_alpha = make_node(OpKind::Add, "add1", vec!["x", "y"], vec!["out1"]);
        unit_alpha.attrs.floats.insert("alpha".to_string(), 1.0);

        let mut double_alpha = make_node(OpKind::Add, "add2", vec!["x", "y"], vec!["out2"]);
        double_alpha.attrs.floats.insert("alpha".to_string(), 2.0);

        let optimized = eliminate_common_subexpressions(vec![unit_alpha, double_alpha]);

        assert_eq!(optimized.len(), 2);
    }

    /// An empty graph stays empty.
    #[test]
    fn test_cse_empty_graph() {
        let graph: Vec<Node> = Vec::new();

        let optimized = eliminate_common_subexpressions(graph);

        assert!(optimized.is_empty());
    }

    /// Nodes with identical op, inputs and attributes are merged.
    #[test]
    fn test_cse_with_attrs_match() {
        let mut first = make_node(OpKind::Gemm, "gemm1", vec!["x", "w", "b"], vec!["out1"]);
        first.attrs.floats.insert("alpha".to_string(), 1.0);
        first.attrs.ints.insert("transB".to_string(), 1);

        let mut second = make_node(OpKind::Gemm, "gemm2", vec!["x", "w", "b"], vec!["out2"]);
        second.attrs.floats.insert("alpha".to_string(), 1.0);
        second.attrs.ints.insert("transB".to_string(), 1);

        let optimized = eliminate_common_subexpressions(vec![first, second]);

        assert_eq!(optimized.len(), 1);
        assert_eq!(optimized[0].name, "gemm1");
    }
}