use crate::graph::{Graph, TensorID};
use crate::Float;
use std::collections::{HashMap, HashSet};
/// Counts of optimisation opportunities found by `analyse_graph`.
///
/// Despite some field names, nothing in this module mutates the graph —
/// every field is a count of *detected* opportunities only.
#[derive(Debug, Clone, Default)]
pub struct TransformReport {
    // Nodes that duplicate an earlier identical computation (CSE).
    pub cse_eliminations: usize,
    // Nodes unreachable from the graph's outputs.
    pub dead_nodes: usize,
    // Nodes computable entirely from constant sources.
    pub constants_propagated: usize,
    // Fusion groups detected (matmul+bias, conv+batchnorm, element-wise chains).
    pub fusions_applied: usize,
    // Nodes for which a static shape was inferred.
    pub shapes_inferred: usize,
    // Algebraic identity rewrites detected (x+0, x*1, ...).
    pub algebraic_simplifications: usize,
}
impl TransformReport {
    /// Sum of every counter tracked by this report.
    pub fn total(&self) -> usize {
        [
            self.cse_eliminations,
            self.dead_nodes,
            self.constants_propagated,
            self.fusions_applied,
            self.shapes_inferred,
            self.algebraic_simplifications,
        ]
        .iter()
        .copied()
        .sum()
    }
}
/// A plain-data snapshot of one graph node, decoupled from the graph's
/// interior mutability so the analysis passes below can operate on
/// `&[NodeInfo]` slices without holding a graph borrow.
#[derive(Debug, Clone)]
struct NodeInfo {
    // Node id; the passes also use it directly as an index into
    // snapshot-ordered vectors (assumes ids are dense and snapshot-ordered).
    id: TensorID,
    // Position in a topological order of the graph. Presumably inputs rank
    // strictly lower than their consumers; the constant-folding pass re-sweeps
    // to a fixed point in case they do not — TODO confirm.
    topo_rank: usize,
    // Name of the producing op; empty for nodes without an op.
    op_name: String,
    // Ids of the nodes this node consumes.
    inputs: Vec<TensorID>,
    // True when the node has no incoming edges.
    is_source: bool,
    // Copied from the graph node; not consulted by the passes in this file.
    is_differentiable: bool,
    // `Some` when the node is a placeholder fed at run time.
    placeholder_name: Option<String>,
    // True when the node is backed by a variable.
    has_variable: bool,
}
/// Takes an immutable snapshot of every node in `graph` as plain-data
/// `NodeInfo` records, so the analyses below can run after the graph
/// borrow is released.
fn snapshot_graph<F: Float>(graph: &Graph<F>) -> Vec<NodeInfo> {
    // Borrow the node arena once for the whole traversal.
    let nodes = graph.node_set.borrow();
    nodes
        .iter()
        .map(|nd| NodeInfo {
            id: nd.id,
            topo_rank: nd.topo_rank,
            // Nodes without an op get an empty name.
            op_name: nd
                .op
                .as_ref()
                .map(|o| o.name().to_owned())
                .unwrap_or_default(),
            inputs: nd.incoming_nodes.iter().map(|inc| inc.id).collect(),
            // "Source" here means only "no incoming edges"; placeholders and
            // variables are distinguished by the two fields below.
            is_source: nd.incoming_nodes.is_empty(),
            is_differentiable: nd.is_differentiable,
            placeholder_name: nd.placeholder_name.map(|s| s.to_owned()),
            has_variable: nd.variable_id.is_some(),
        })
        .collect()
}
/// Returns, for every node id below `infos.len()`, how many nodes consume it
/// as an input. Input ids outside the snapshot range are ignored.
fn consumer_counts(infos: &[NodeInfo]) -> Vec<usize> {
    let n = infos.len();
    let mut counts = vec![0usize; n];
    let in_range_inputs = infos
        .iter()
        .flat_map(|info| info.inputs.iter().copied())
        .filter(|&inp| inp < n);
    for inp in in_range_inputs {
        counts[inp] += 1;
    }
    counts
}
/// CSE lookup key: an op name plus its (possibly canonically sorted) input ids.
type CseKey = (String, Vec<TensorID>);
/// Op-name substrings treated as commutative when canonicalising CSE keys.
const COMMUTATIVE_OPS: &[&str] = &["AddOp", "Add", "add", "MulOp", "Mul", "mul"];
/// Detects common-subexpression duplicates in `graph`. The returned map sends
/// each duplicate node to the earlier canonical node computing the same value.
pub fn detect_cse<F: Float>(graph: &Graph<F>) -> HashMap<TensorID, TensorID> {
    detect_cse_from_infos(&snapshot_graph(graph))
}
/// Walks nodes in topological order and maps every duplicated node to the
/// first ("canonical") node applying the same op to the same inputs.
/// Commutative ops get their input lists sorted so `add(a, b)` and
/// `add(b, a)` share one key.
fn detect_cse_from_infos(infos: &[NodeInfo]) -> HashMap<TensorID, TensorID> {
    let comm: HashSet<&str> = COMMUTATIVE_OPS.iter().copied().collect();
    let mut order: Vec<usize> = (0..infos.len()).collect();
    order.sort_by_key(|&i| infos[i].topo_rank);
    let mut seen: HashMap<CseKey, TensorID> = HashMap::new();
    let mut duplicates: HashMap<TensorID, TensorID> = HashMap::new();
    for &idx in &order {
        let info = &infos[idx];
        // Sources (constants, placeholders, variables) are distinct by
        // identity, not by op name — never merge them.
        if info.is_source {
            continue;
        }
        let mut key_inputs = info.inputs.clone();
        // BUG FIX: "MatMul"/"Matmul"/"matmul" contain the substring
        // "Mul"/"mul", so the plain substring test used to classify matrix
        // multiplication as commutative and sort its operands — merging
        // matmul(a, b) with matmul(b, a), which are different computations.
        // Exclude matmul-like names before canonicalising.
        let matmul_like = info.op_name.to_lowercase().contains("matmul");
        if !matmul_like && comm.iter().any(|&c| info.op_name.contains(c)) {
            key_inputs.sort_unstable();
        }
        let key: CseKey = (info.op_name.clone(), key_inputs);
        match seen.get(&key) {
            Some(&canonical) => {
                duplicates.insert(info.id, canonical);
            }
            None => {
                seen.insert(key, info.id);
            }
        }
    }
    duplicates
}
/// Finds nodes unreachable from the graph's outputs; see
/// `find_dead_nodes_from_infos` for the liveness heuristic used.
pub fn find_dead_nodes<F: Float>(graph: &Graph<F>) -> HashSet<TensorID> {
    find_dead_nodes_from_infos(&snapshot_graph(graph))
}
/// Marks every node not reachable (backwards through inputs) from an output.
/// With no explicit output set available, every node at the maximal
/// topological rank is treated as an output root.
fn find_dead_nodes_from_infos(infos: &[NodeInfo]) -> HashSet<TensorID> {
    let n = infos.len();
    if n == 0 {
        return HashSet::new();
    }
    let max_rank = infos.iter().map(|i| i.topo_rank).max().unwrap_or(0);
    // Seed the worklist with the presumed outputs.
    let mut stack: Vec<TensorID> = infos
        .iter()
        .filter(|i| i.topo_rank == max_rank)
        .map(|i| i.id)
        .collect();
    let mut live: HashSet<TensorID> = stack.iter().copied().collect();
    // Depth-first walk over the reversed edges.
    while let Some(nid) = stack.pop() {
        if nid >= n {
            continue;
        }
        for &inp in &infos[nid].inputs {
            // `insert` returns true only on first sight, so each node is
            // pushed at most once.
            if inp < n && live.insert(inp) {
                stack.push(inp);
            }
        }
    }
    (0..n).filter(|id| !live.contains(id)).collect()
}
/// Finds nodes whose value is fully determined by constant sources and could
/// therefore be folded ahead of time.
pub fn find_foldable_constants<F: Float>(graph: &Graph<F>) -> HashSet<TensorID> {
    find_foldable_constants_from_infos(&snapshot_graph(graph))
}
/// Constant propagation over the snapshot: sources that are neither
/// placeholders nor variables are constants, and so is any node all of whose
/// inputs are already constant.
fn find_foldable_constants_from_infos(infos: &[NodeInfo]) -> HashSet<TensorID> {
    let n = infos.len();
    let mut is_constant = vec![false; n];
    // Seed: inputless nodes that are not fed or trained at run time.
    for info in infos {
        if info.is_source && info.placeholder_name.is_none() && !info.has_variable {
            is_constant[info.id] = true;
        }
    }
    let mut order: Vec<usize> = (0..n).collect();
    order.sort_by_key(|&i| infos[i].topo_rank);
    // Propagate to a fixed point. A single topo-ordered sweep would suffice
    // with strictly increasing ranks along edges; the extra sweeps guard
    // against snapshots where that does not hold.
    loop {
        let mut changed = false;
        for &idx in &order {
            let info = &infos[idx];
            if is_constant[info.id] || info.is_source || info.inputs.is_empty() {
                continue;
            }
            if info.inputs.iter().all(|&inp| inp < n && is_constant[inp]) {
                is_constant[info.id] = true;
                changed = true;
            }
        }
        if !changed {
            break;
        }
    }
    (0..n).filter(|&id| is_constant[id]).collect()
}
/// A group of nodes that a backend could execute as a single fused kernel.
#[derive(Debug, Clone)]
pub struct FusionGroup {
    /// Which fusion pattern was matched.
    pub kind: FusionKind,
    /// All member node ids, in execution order.
    pub nodes: Vec<TensorID>,
    /// The node whose output the fused kernel would produce.
    pub output_node: TensorID,
}
/// The fusion patterns recognised by `detect_fusions`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FusionKind {
    /// MatMul directly followed by a bias add.
    MatMulBias,
    /// MatMul + bias add + pointwise activation (op name recorded).
    MatMulBiasActivation { activation: String },
    /// Convolution followed by batch normalisation.
    ConvBatchNorm,
    /// Convolution + batch norm + pointwise activation (op name recorded).
    ConvBatchNormActivation { activation: String },
    /// A linear chain of element-wise ops (op names in chain order).
    ElementWiseChain { ops: Vec<String> },
}
/// Substrings identifying activation ops (matched case-sensitively via
/// `str::contains`).
const ACTIVATION_OPS: &[&str] = &[
    "Relu", "relu", "Sigmoid", "sigmoid", "Tanh", "tanh", "Gelu", "gelu", "Swish", "swish",
];
/// Detects kernel-fusion opportunities in `graph`.
pub fn detect_fusions<F: Float>(graph: &Graph<F>) -> Vec<FusionGroup> {
    let infos = snapshot_graph(graph);
    detect_fusions_from_infos(&infos, &consumer_counts(&infos))
}
/// True if `name` looks like a pointwise activation op.
fn is_activation(name: &str) -> bool {
    ACTIVATION_OPS.iter().copied().any(|a| name.contains(a))
}
/// True if `name` looks like a matrix-multiplication op. Accepts both
/// CamelCase op names and the bare lowercase builder name "matmul".
fn is_matmul(name: &str) -> bool {
    ["MatMul", "Matmul"].iter().any(|&p| name.contains(p)) || name == "matmul"
}
/// True if `name` looks like a bias addition — or a plain add, which can
/// serve as one when it directly follows a matmul.
fn is_bias_add(name: &str) -> bool {
    const SUBSTRINGS: &[&str] = &["BiasAdd", "bias_add", "AddOp"];
    SUBSTRINGS.iter().any(|&p| name.contains(p)) || matches!(name, "Add" | "add")
}
/// True if `name` looks like a convolution op. The "Conv" substring already
/// covers "Conv2d"; the bare builder name "conv2d" is matched exactly.
fn is_conv(name: &str) -> bool {
    name.contains("Conv") || name == "conv2d"
}
/// True if `name` looks like a batch-normalisation op.
fn is_batchnorm(name: &str) -> bool {
    ["BatchNorm", "batch_norm"].iter().any(|&p| name.contains(p))
}
/// True if `name` contains any element-wise op substring.
///
/// The match is deliberately loose (plain `str::contains`), so names such as
/// "MatMul" also qualify because they contain "Mul".
fn is_elementwise(name: &str) -> bool {
    const EW: &[&str] = &[
        "Add", "add", "Sub", "sub", "Mul", "mul", "Div", "div", "Neg", "neg", "Exp", "exp", "Log",
        "log", "Sqrt", "sqrt", "Square", "square", "Abs", "abs", "Relu", "relu", "Sigmoid",
        "sigmoid", "Tanh", "tanh", "Gelu", "gelu",
    ];
    EW.iter().copied().any(|e| name.contains(e))
}
/// Builds the forward adjacency: `children[p]` lists every node that consumes
/// `p` as an input. Out-of-range input ids are skipped.
fn build_children(infos: &[NodeInfo]) -> Vec<Vec<TensorID>> {
    let n = infos.len();
    let mut children: Vec<Vec<TensorID>> = vec![Vec::new(); n];
    for info in infos {
        for parent in info.inputs.iter().copied().filter(|&p| p < n) {
            children[parent].push(info.id);
        }
    }
    children
}
/// Pattern-matches fusion opportunities over the node snapshot, visiting
/// nodes in topological order. Per node the patterns are tried in priority
/// order:
/// 1. MatMul -> bias-add, optionally followed by an activation;
/// 2. Conv -> batch-norm, optionally followed by an activation;
/// 3. a maximal single-consumer chain of element-wise ops (length >= 2).
///
/// The `fused` set guarantees each node joins at most one group, and every
/// interior link requires exactly one consumer so fusing never hides a value
/// some other node still reads.
fn detect_fusions_from_infos(infos: &[NodeInfo], consumers: &[usize]) -> Vec<FusionGroup> {
    let n = infos.len();
    let children = build_children(infos);
    let mut fused: HashSet<TensorID> = HashSet::new();
    let mut groups: Vec<FusionGroup> = Vec::new();
    let mut order: Vec<usize> = (0..n).collect();
    order.sort_by_key(|&i| infos[i].topo_rank);
    for &idx in &order {
        let info = &infos[idx];
        if fused.contains(&info.id) {
            continue;
        }
        // --- Pattern 1: MatMul -> BiasAdd [-> activation] ---
        if is_matmul(&info.op_name) && consumers[info.id] == 1 {
            let next_id = children[info.id].first().copied();
            if let Some(nid) = next_id {
                if nid < n && is_bias_add(&infos[nid].op_name) && !fused.contains(&nid) {
                    if consumers[nid] == 1 {
                        // The bias-add feeds exactly one node; try to absorb
                        // an activation into the group as well.
                        let act_id = children[nid].first().copied();
                        if let Some(aid) = act_id {
                            if aid < n
                                && is_activation(&infos[aid].op_name)
                                && !fused.contains(&aid)
                            {
                                fused.insert(info.id);
                                fused.insert(nid);
                                fused.insert(aid);
                                groups.push(FusionGroup {
                                    kind: FusionKind::MatMulBiasActivation {
                                        activation: infos[aid].op_name.clone(),
                                    },
                                    nodes: vec![info.id, nid, aid],
                                    output_node: aid,
                                });
                                continue;
                            }
                        }
                    }
                    // No suitable activation: fuse the matmul/bias pair only.
                    fused.insert(info.id);
                    fused.insert(nid);
                    groups.push(FusionGroup {
                        kind: FusionKind::MatMulBias,
                        nodes: vec![info.id, nid],
                        output_node: nid,
                    });
                    continue;
                }
            }
        }
        // --- Pattern 2: Conv -> BatchNorm [-> activation] ---
        // (Structurally a mirror of pattern 1.)
        if is_conv(&info.op_name) && consumers[info.id] == 1 {
            let next_id = children[info.id].first().copied();
            if let Some(nid) = next_id {
                if nid < n && is_batchnorm(&infos[nid].op_name) && !fused.contains(&nid) {
                    if consumers[nid] == 1 {
                        let act_id = children[nid].first().copied();
                        if let Some(aid) = act_id {
                            if aid < n
                                && is_activation(&infos[aid].op_name)
                                && !fused.contains(&aid)
                            {
                                fused.insert(info.id);
                                fused.insert(nid);
                                fused.insert(aid);
                                groups.push(FusionGroup {
                                    kind: FusionKind::ConvBatchNormActivation {
                                        activation: infos[aid].op_name.clone(),
                                    },
                                    nodes: vec![info.id, nid, aid],
                                    output_node: aid,
                                });
                                continue;
                            }
                        }
                    }
                    fused.insert(info.id);
                    fused.insert(nid);
                    groups.push(FusionGroup {
                        kind: FusionKind::ConvBatchNorm,
                        nodes: vec![info.id, nid],
                        output_node: nid,
                    });
                    continue;
                }
            }
        }
        // --- Pattern 3: element-wise chain ---
        // NOTE(review): `is_elementwise` matches by substring, so names like
        // "MatMul" (contains "Mul") that fell through the branches above can
        // start or join a chain here — confirm this is intended.
        if is_elementwise(&info.op_name) && !fused.contains(&info.id) {
            let mut chain = vec![info.id];
            let mut chain_ops = vec![info.op_name.clone()];
            let mut cur = info.id;
            loop {
                // Only extend the chain through single-consumer links.
                if consumers[cur] != 1 {
                    break;
                }
                let next = children[cur].first().copied();
                match next {
                    Some(nid)
                        if nid < n
                            && is_elementwise(&infos[nid].op_name)
                            && !fused.contains(&nid) =>
                    {
                        chain.push(nid);
                        chain_ops.push(infos[nid].op_name.clone());
                        cur = nid;
                    }
                    _ => break,
                }
            }
            // A chain of one node is not a fusion.
            if chain.len() >= 2 {
                let output = *chain.last().unwrap_or(&info.id);
                for &nid in &chain {
                    fused.insert(nid);
                }
                groups.push(FusionGroup {
                    kind: FusionKind::ElementWiseChain { ops: chain_ops },
                    nodes: chain,
                    output_node: output,
                });
            }
        }
    }
    groups
}
/// A statically inferred shape for one node. Dimensions are concrete `i64`
/// values; no symbolic/unknown dimensions are represented.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct InferredShape {
    /// The node this shape belongs to.
    pub node_id: TensorID,
    /// The inferred dimensions.
    pub dims: Vec<i64>,
}
/// Runs best-effort static shape inference over `graph`.
pub fn infer_shapes<F: Float>(graph: &Graph<F>) -> Vec<InferredShape> {
    infer_shapes_from_infos(&snapshot_graph(graph))
}
/// Best-effort shape propagation over the snapshot:
/// - element-wise ops inherit their first input's shape;
/// - 2-D matmul produces `[m, p]` from `[m, k] x [k, p]`.
///
/// Shapes start unknown and only nodes whose inputs already have inferred
/// shapes can be resolved; since nothing seeds source shapes, this pass
/// resolves nothing until callers extend it with a seeding mechanism.
fn infer_shapes_from_infos(infos: &[NodeInfo]) -> Vec<InferredShape> {
    let n = infos.len();
    let mut shapes: HashMap<TensorID, Vec<i64>> = HashMap::new();
    let mut order: Vec<usize> = (0..n).collect();
    order.sort_by_key(|&i| infos[i].topo_rank);
    for &idx in &order {
        let info = &infos[idx];
        // BUG FIX: the two rules must be mutually exclusive, with matmul
        // tried first. `is_elementwise` matches by substring, so "MatMulOp"
        // (contains "Mul") also satisfied the element-wise rule and could
        // record the lhs operand's shape for a matmul node whenever the
        // [m,k]x[k,p] rule below lacked operand shapes to overwrite it.
        if is_matmul(&info.op_name) && info.inputs.len() >= 2 {
            let dims = match (shapes.get(&info.inputs[0]), shapes.get(&info.inputs[1])) {
                (Some(l), Some(r)) if l.len() == 2 && r.len() == 2 => Some(vec![l[0], r[1]]),
                _ => None,
            };
            if let Some(dims) = dims {
                shapes.insert(info.id, dims);
            }
        } else if is_elementwise(&info.op_name) && !info.inputs.is_empty() {
            if let Some(dims) = shapes.get(&info.inputs[0]).cloned() {
                shapes.insert(info.id, dims);
            }
        }
    }
    shapes
        .into_iter()
        .map(|(node_id, dims)| InferredShape { node_id, dims })
        .collect()
}
/// One detected algebraic identity: `node_id` computes something expressible
/// in terms of the existing tensor `replacement`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AlgebraicSimplification {
    /// The node that could be rewritten away.
    pub node_id: TensorID,
    /// Which identity matched.
    pub rule: SimplificationRule,
    /// An existing tensor id involved in the rewrite.
    /// NOTE(review): for `SubSelf`/`DivSelf` the mathematically correct
    /// result is a fresh zeros/ones tensor, but the detector records an
    /// operand id here — confirm how downstream rewriters interpret this
    /// field before substituting it directly.
    pub replacement: TensorID,
}
/// Algebraic identities recognised by `detect_algebraic_simplifications`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SimplificationRule {
    /// x + 0 -> x
    AddZero,
    /// x * 1 -> x
    MulOne,
    /// x * 0 -> 0
    MulZero,
    /// x - 0 -> x
    SubZero,
    /// x / 1 -> x
    DivOne,
    /// x - x -> 0
    SubSelf,
    /// x / x -> 1
    DivSelf,
    /// log(exp(x)) -> x
    LogExp,
    /// exp(log(x)) -> x
    ExpLog,
}
/// Detects algebraic identity rewrites (x+0, x*1, x-x, log(exp(x)), ...)
/// available in `graph`.
pub fn detect_algebraic_simplifications<F: Float>(
    graph: &Graph<F>,
) -> Vec<AlgebraicSimplification> {
    detect_algebraic_simplifications_from_infos(&snapshot_graph(graph))
}
/// Heuristic: the node's op name suggests an all-zeros constant source.
fn is_zero_source(info: &NodeInfo) -> bool {
    let name = info.op_name.to_lowercase();
    ["zeros", "fill0"].iter().any(|&p| name.contains(p))
}
/// Heuristic: the node's op name suggests an all-ones constant source.
fn is_one_source(info: &NodeInfo) -> bool {
    let name = info.op_name.to_lowercase();
    ["ones", "fill1"].iter().any(|&p| name.contains(p))
}
/// Scans every node for algebraic identities (x+0, x*1, x*0, x-0, x/1, x-x,
/// x/x, log(exp(x)), exp(log(x))) and records the node together with the id
/// of the tensor the rewrite is expressed against.
///
/// Detection is purely name-based (see `is_zero_source` / `is_one_source`),
/// so it only recognises zeros/ones produced by dedicated source ops.
fn detect_algebraic_simplifications_from_infos(infos: &[NodeInfo]) -> Vec<AlgebraicSimplification> {
    let n = infos.len();
    let mut results: Vec<AlgebraicSimplification> = Vec::new();
    // --- Binary identities ---
    for info in infos {
        if info.inputs.len() != 2 {
            continue;
        }
        let lhs_id = info.inputs[0];
        let rhs_id = info.inputs[1];
        if lhs_id >= n || rhs_id >= n {
            continue;
        }
        let lhs = &infos[lhs_id];
        let rhs = &infos[rhs_id];
        let op = &info.op_name;
        // BUG FIX: "MatMul"/"Matmul"/"matmul" contain the substring
        // "Mul"/"mul", so the x*1 and x*0 rules below used to fire on matrix
        // multiplications, where they are simply wrong (A x ones != A and
        // A x zeros has a different shape than the zeros operand).
        let matmul_like = op.to_lowercase().contains("matmul");
        // x + 0 -> x (and 0 + x -> x).
        if op.contains("Add") || op.contains("add") {
            if is_zero_source(rhs) {
                results.push(AlgebraicSimplification {
                    node_id: info.id,
                    rule: SimplificationRule::AddZero,
                    replacement: lhs_id,
                });
            } else if is_zero_source(lhs) {
                results.push(AlgebraicSimplification {
                    node_id: info.id,
                    rule: SimplificationRule::AddZero,
                    replacement: rhs_id,
                });
            }
        }
        if (op.contains("Mul") || op.contains("mul")) && !matmul_like {
            // x * 1 -> x (and 1 * x -> x).
            if is_one_source(rhs) {
                results.push(AlgebraicSimplification {
                    node_id: info.id,
                    rule: SimplificationRule::MulOne,
                    replacement: lhs_id,
                });
            } else if is_one_source(lhs) {
                results.push(AlgebraicSimplification {
                    node_id: info.id,
                    rule: SimplificationRule::MulOne,
                    replacement: rhs_id,
                });
            }
            // x * 0 -> 0: the zero operand itself is the replacement.
            if is_zero_source(rhs) {
                results.push(AlgebraicSimplification {
                    node_id: info.id,
                    rule: SimplificationRule::MulZero,
                    replacement: rhs_id,
                });
            } else if is_zero_source(lhs) {
                results.push(AlgebraicSimplification {
                    node_id: info.id,
                    rule: SimplificationRule::MulZero,
                    replacement: lhs_id,
                });
            }
        }
        if op.contains("Sub") || op.contains("sub") {
            // x - 0 -> x.
            if is_zero_source(rhs) {
                results.push(AlgebraicSimplification {
                    node_id: info.id,
                    rule: SimplificationRule::SubZero,
                    replacement: lhs_id,
                });
            }
            // x - x -> 0. NOTE(review): the recorded replacement is the
            // operand, not a zeros tensor — see `AlgebraicSimplification`.
            if lhs_id == rhs_id {
                results.push(AlgebraicSimplification {
                    node_id: info.id,
                    rule: SimplificationRule::SubSelf,
                    replacement: lhs_id,
                });
            }
        }
        if op.contains("Div") || op.contains("div") {
            // x / 1 -> x.
            if is_one_source(rhs) {
                results.push(AlgebraicSimplification {
                    node_id: info.id,
                    rule: SimplificationRule::DivOne,
                    replacement: lhs_id,
                });
            }
            // x / x -> 1. Same caveat as SubSelf above.
            if lhs_id == rhs_id {
                results.push(AlgebraicSimplification {
                    node_id: info.id,
                    rule: SimplificationRule::DivSelf,
                    replacement: rhs_id,
                });
            }
        }
    }
    // --- Unary inverse pairs: log(exp(x)) and exp(log(x)) ---
    for info in infos {
        if info.inputs.len() != 1 {
            continue;
        }
        let inp_id = info.inputs[0];
        if inp_id >= n {
            continue;
        }
        let inner = &infos[inp_id];
        if inner.inputs.len() != 1 {
            continue;
        }
        let inner_inp = inner.inputs[0];
        let outer_op = info.op_name.to_lowercase();
        let inner_op = inner.op_name.to_lowercase();
        if outer_op.contains("log") && inner_op.contains("exp") {
            results.push(AlgebraicSimplification {
                node_id: info.id,
                rule: SimplificationRule::LogExp,
                replacement: inner_inp,
            });
        }
        if outer_op.contains("exp") && inner_op.contains("log") {
            results.push(AlgebraicSimplification {
                node_id: info.id,
                rule: SimplificationRule::ExpLog,
                replacement: inner_inp,
            });
        }
    }
    results
}
/// Runs every analysis pass over `graph` and returns how many opportunities
/// each one found. Nothing is mutated — this is a reporting-only entry point.
pub fn analyse_graph<F: Float>(graph: &Graph<F>) -> TransformReport {
    let infos = snapshot_graph(graph);
    let consumers = consumer_counts(&infos);
    TransformReport {
        cse_eliminations: detect_cse_from_infos(&infos).len(),
        dead_nodes: find_dead_nodes_from_infos(&infos).len(),
        constants_propagated: find_foldable_constants_from_infos(&infos).len(),
        fusions_applied: detect_fusions_from_infos(&infos, &consumers).len(),
        shapes_inferred: infer_shapes_from_infos(&infos).len(),
        algebraic_simplifications: detect_algebraic_simplifications_from_infos(&infos).len(),
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::AsGraph;
    use crate::tensor_ops as T;
    use crate::VariableEnvironment;

    /// Two structurally identical `add(a, b)` nodes must be reported as
    /// CSE duplicates.
    #[test]
    fn test_cse_detects_duplicates() {
        let env = VariableEnvironment::<f32>::new();
        env.run(|ctx| {
            let a = T::zeros(&[2, 2], ctx);
            let b = T::ones(&[2, 2], ctx);
            let c1 = T::add(a, b);
            let c2 = T::add(a, b);
            let _ = T::add(c1, c2);
            let dups = detect_cse(ctx.as_graph());
            assert!(!dups.is_empty(), "Should detect duplicate add(a,b)");
        });
    }

    /// A constant source with no consumers must show up as dead.
    #[test]
    fn test_dead_nodes() {
        let env = VariableEnvironment::<f32>::new();
        env.run(|ctx| {
            let a = T::zeros(&[2], ctx);
            let _dead = T::ones(&[2], ctx);
            let c = a + T::ones(&[2], ctx);
            let _ = c;
            let dead = find_dead_nodes(ctx.as_graph());
            assert!(!dead.is_empty(), "Should detect at least 1 dead node");
        });
    }

    /// Sources that are neither placeholders nor variables count as foldable.
    #[test]
    fn test_foldable_constants() {
        let env = VariableEnvironment::<f32>::new();
        env.run(|ctx| {
            let a = T::zeros(&[2], ctx);
            let b = T::ones(&[2], ctx);
            let _ = a + b;
            let foldable = find_foldable_constants(ctx.as_graph());
            assert!(
                foldable.len() >= 2,
                "Source nodes should be foldable constants, got {}",
                foldable.len()
            );
        });
    }

    /// x * 1 is recognised as a MulOne identity.
    #[test]
    fn test_algebraic_mul_one() {
        let env = VariableEnvironment::<f32>::new();
        env.run(|ctx| {
            let x = T::zeros(&[3], ctx);
            let one = T::ones(&[3], ctx);
            let _ = x * one;
            let simps = detect_algebraic_simplifications(ctx.as_graph());
            let mul_one = simps.iter().any(|s| s.rule == SimplificationRule::MulOne);
            assert!(mul_one, "Should detect x * 1 -> x");
        });
    }

    /// x + 0 is recognised as an AddZero identity.
    #[test]
    fn test_algebraic_add_zero() {
        let env = VariableEnvironment::<f32>::new();
        env.run(|ctx| {
            let x = T::ones(&[3], ctx);
            let zero = T::zeros(&[3], ctx);
            let _ = x + zero;
            let simps = detect_algebraic_simplifications(ctx.as_graph());
            let add_zero = simps.iter().any(|s| s.rule == SimplificationRule::AddZero);
            assert!(add_zero, "Should detect x + 0 -> x");
        });
    }

    /// add -> sigmoid forms a two-node element-wise chain.
    #[test]
    fn test_fusion_elementwise_chain() {
        let env = VariableEnvironment::<f32>::new();
        env.run(|ctx| {
            let a = T::zeros(&[4], ctx);
            let b = T::ones(&[4], ctx);
            let c = a + b;
            let d = T::sigmoid(c);
            let _ = d;
            let fusions = detect_fusions(ctx.as_graph());
            let has_ew = fusions
                .iter()
                .any(|f| matches!(f.kind, FusionKind::ElementWiseChain { .. }));
            assert!(has_ew, "Should detect element-wise chain (add -> sigmoid)");
        });
    }

    /// The aggregate report should at least count the two constant sources.
    #[test]
    fn test_analyse_graph_integration() {
        let env = VariableEnvironment::<f32>::new();
        env.run(|ctx| {
            let a = T::zeros(&[4, 4], ctx);
            let b = T::ones(&[4, 4], ctx);
            let c = a + b;
            let d = a * b;
            let _ = c + d;
            let report = analyse_graph(ctx.as_graph());
            assert!(report.constants_propagated >= 2);
        });
    }

    /// An empty graph yields an all-zero report.
    #[test]
    fn test_empty_graph_transforms() {
        let env = VariableEnvironment::<f32>::new();
        env.run(|ctx| {
            let report = analyse_graph(ctx.as_graph());
            assert_eq!(report.total(), 0);
        });
    }

    /// `infer_shapes_from_infos` starts from an empty shape map and exposes
    /// no way to seed source shapes, so nothing can be inferred here.
    ///
    /// BUG FIX: the original test built a seed `shapes` map it never passed
    /// anywhere and then asserted a tautology
    /// (`inferred.is_empty() || !inferred.is_empty()`), verifying nothing.
    /// Pin the actual, currently-empty result instead.
    #[test]
    fn test_shape_inference_elementwise() {
        let infos = vec![
            NodeInfo {
                id: 0,
                topo_rank: 0,
                op_name: "Zeros".to_owned(),
                inputs: vec![],
                is_source: true,
                is_differentiable: true,
                placeholder_name: None,
                has_variable: false,
            },
            NodeInfo {
                id: 1,
                topo_rank: 0,
                op_name: "Ones".to_owned(),
                inputs: vec![],
                is_source: true,
                is_differentiable: true,
                placeholder_name: None,
                has_variable: false,
            },
            NodeInfo {
                id: 2,
                topo_rank: 1,
                op_name: "AddOp".to_owned(),
                inputs: vec![0, 1],
                is_source: false,
                is_differentiable: true,
                placeholder_name: None,
                has_variable: false,
            },
        ];
        let inferred = infer_shapes_from_infos(&infos);
        assert!(
            inferred.is_empty(),
            "no source shapes are seeded, so nothing should be inferred"
        );
    }

    /// `total` sums every counter in the report.
    #[test]
    fn test_transform_report_total() {
        let r = TransformReport {
            cse_eliminations: 2,
            dead_nodes: 3,
            constants_propagated: 1,
            fusions_applied: 1,
            shapes_inferred: 4,
            algebraic_simplifications: 2,
        };
        assert_eq!(r.total(), 13);
    }
}