use crate::graph::{Attributes, Node, OpKind};
use crate::tensor::Tensor;
use std::collections::{HashMap, HashSet};
pub fn cancel_consecutive_transpose(nodes: Vec<Node>) -> Vec<Node> {
let mut producer: HashMap<String, usize> = HashMap::new();
for (i, node) in nodes.iter().enumerate() {
for out in &node.outputs {
producer.insert(out.clone(), i);
}
}
let mut consumer_count: HashMap<String, usize> = HashMap::new();
for node in &nodes {
for inp in &node.inputs {
if !inp.is_empty() {
*consumer_count.entry(inp.clone()).or_insert(0) += 1;
}
}
}
let mut skip: HashSet<usize> = HashSet::new();
let mut replacements: HashMap<usize, Node> = HashMap::new();
let mut redirects: HashMap<String, String> = HashMap::new();
for (i, node) in nodes.iter().enumerate() {
if skip.contains(&i) {
continue;
}
if !matches!(node.op, OpKind::Transpose) {
continue;
}
let input_name = match node.inputs.first() {
Some(name) if !name.is_empty() => name,
_ => continue,
};
let prev_idx = match producer.get(input_name) {
Some(&idx) => idx,
None => continue,
};
if skip.contains(&prev_idx) {
continue;
}
if !matches!(nodes[prev_idx].op, OpKind::Transpose) {
continue;
}
if consumer_count.get(input_name).copied().unwrap_or(0) != 1 {
continue;
}
let perm1 = match nodes[prev_idx].attrs.int_lists.get("perm") {
Some(p) => p.clone(),
None => continue,
};
let perm2 = match node.attrs.int_lists.get("perm") {
Some(p) => p.clone(),
None => continue,
};
if perm1.len() != perm2.len() {
continue;
}
let composed: Vec<i64> = perm2
.iter()
.map(|&j| {
let j_usize = j as usize;
if j_usize < perm1.len() {
perm1[j_usize]
} else {
j
}
})
.collect();
let is_identity = composed.iter().enumerate().all(|(idx, &v)| v == idx as i64);
if is_identity {
let original_input = match nodes[prev_idx].inputs.first() {
Some(name) => name.clone(),
None => continue,
};
skip.insert(prev_idx);
skip.insert(i);
if let Some(out_name) = node.outputs.first() {
redirects.insert(out_name.clone(), original_input);
}
} else {
let mut new_attrs = Attributes::default();
new_attrs.int_lists.insert("perm".to_string(), composed);
let original_input = match nodes[prev_idx].inputs.first() {
Some(name) => name.clone(),
None => continue,
};
let collapsed = Node {
op: OpKind::Transpose,
name: format!("{}_collapsed_transpose", nodes[prev_idx].name),
inputs: vec![original_input],
outputs: node.outputs.clone(),
attrs: new_attrs,
};
skip.insert(prev_idx);
replacements.insert(i, collapsed);
}
}
nodes
.into_iter()
.enumerate()
.filter(|(i, _)| !skip.contains(i))
.map(|(i, mut n)| {
if let Some(replacement) = replacements.remove(&i) {
replacement
} else {
for inp in &mut n.inputs {
if let Some(redirect) = redirects.get(inp) {
*inp = redirect.clone();
}
}
n
}
})
.collect()
}
/// Collapses consecutive `Reshape` node pairs.
///
/// When a `Reshape` feeds another `Reshape` and the intermediate tensor has
/// exactly one consumer, only the final shape matters:
/// * if both reshapes use the same shape tensor the pair is a round trip —
///   both nodes are removed and consumers are rewired to the original input;
/// * otherwise the pair collapses into a single `Reshape` targeting the
///   second node's shape.
pub fn cancel_consecutive_reshape(nodes: Vec<Node>) -> Vec<Node> {
    if nodes.len() < 2 {
        return nodes;
    }
    // tensor name -> index of the node producing it
    let mut producer: HashMap<String, usize> = HashMap::new();
    for (i, node) in nodes.iter().enumerate() {
        for out in &node.outputs {
            producer.insert(out.clone(), i);
        }
    }
    // tensor name -> number of node inputs referencing it
    let mut consumer_count: HashMap<String, usize> = HashMap::new();
    for node in &nodes {
        for inp in &node.inputs {
            if !inp.is_empty() {
                *consumer_count.entry(inp.clone()).or_insert(0) += 1;
            }
        }
    }
    let mut skip: HashSet<usize> = HashSet::new();
    let mut replacements: HashMap<usize, Node> = HashMap::new();
    let mut redirects: HashMap<String, String> = HashMap::new();
    for (i, node) in nodes.iter().enumerate() {
        if skip.contains(&i) {
            continue;
        }
        if !matches!(node.op, OpKind::Reshape) {
            continue;
        }
        if node.inputs.is_empty() {
            continue;
        }
        let prev_out = &node.inputs[0];
        // The intermediate tensor must feed only this reshape.
        if consumer_count.get(prev_out).copied().unwrap_or(0) != 1 {
            continue;
        }
        let prev_idx = match producer.get(prev_out) {
            Some(&idx) => idx,
            None => continue,
        };
        if skip.contains(&prev_idx) {
            continue;
        }
        if !matches!(nodes[prev_idx].op, OpKind::Reshape) {
            continue;
        }
        if nodes[prev_idx].inputs.is_empty() {
            continue;
        }
        let original_input = nodes[prev_idx].inputs[0].clone();
        // Same shape tensor on both reshapes -> the pair is a no-op.
        let shapes_match = nodes[prev_idx].inputs.len() > 1
            && node.inputs.len() > 1
            && nodes[prev_idx].inputs[1] == node.inputs[1];
        if shapes_match {
            skip.insert(prev_idx);
            skip.insert(i);
            if let Some(out_name) = node.outputs.first() {
                // Resolve through existing redirects so chained eliminations
                // always point at a tensor that still has a producer.
                let target = redirects
                    .get(&original_input)
                    .cloned()
                    .unwrap_or(original_input);
                redirects.insert(out_name.clone(), target);
            }
        } else {
            // Keep the second reshape's shape input (if any); the first
            // reshape is dropped entirely.
            let mut new_inputs = vec![original_input];
            if node.inputs.len() > 1 {
                new_inputs.push(node.inputs[1].clone());
            }
            let collapsed = Node {
                op: OpKind::Reshape,
                name: format!("{}_collapsed_reshape", nodes[prev_idx].name),
                inputs: new_inputs,
                outputs: node.outputs.clone(),
                attrs: node.attrs.clone(),
            };
            skip.insert(prev_idx);
            replacements.insert(i, collapsed);
        }
    }
    nodes
        .into_iter()
        .enumerate()
        .filter(|(i, _)| !skip.contains(i))
        .map(|(i, mut n)| {
            if let Some(replacement) = replacements.remove(&i) {
                n = replacement;
            }
            // Rewire inputs of surviving and replacement nodes alike, so a
            // collapsed reshape whose source was removed by an earlier
            // same-shape pair still references a live tensor.
            for inp in &mut n.inputs {
                if let Some(redirect) = redirects.get(inp) {
                    *inp = redirect.clone();
                }
            }
            n
        })
        .collect()
}
/// Fuses `x * Sigmoid(x)` into a single `SiLU` node.
///
/// The `Mul` is replaced by a `SiLU` taking `x` directly and the `Sigmoid`
/// node is dropped. Fusion requires that the sigmoid output is consumed only
/// by the `Mul` and that the sigmoid's input matches the `Mul`'s other
/// operand; either operand order of the `Mul` is accepted.
pub fn fuse_mul_sigmoid_to_silu(nodes: Vec<Node>) -> Vec<Node> {
    if nodes.len() < 2 {
        return nodes;
    }
    // tensor name -> index of the node producing it
    let mut producer: HashMap<String, usize> = HashMap::new();
    for (i, node) in nodes.iter().enumerate() {
        for out in &node.outputs {
            producer.insert(out.clone(), i);
        }
    }
    // tensor name -> number of node inputs referencing it
    let mut consumer_count: HashMap<String, usize> = HashMap::new();
    for node in &nodes {
        for inp in &node.inputs {
            if !inp.is_empty() {
                *consumer_count.entry(inp.clone()).or_insert(0) += 1;
            }
        }
    }
    let mut skip: HashSet<usize> = HashSet::new();
    let mut replacements: HashMap<usize, Node> = HashMap::new();
    for (i, node) in nodes.iter().enumerate() {
        if skip.contains(&i) {
            continue;
        }
        if !matches!(node.op, OpKind::Mul) {
            continue;
        }
        if node.inputs.len() < 2 {
            continue;
        }
        // Checks whether `sig_candidate` is produced by a single-consumer
        // Sigmoid whose input is `x_candidate`; on success returns the x
        // tensor name and the sigmoid node index.
        let try_order = |sig_candidate: &str, x_candidate: &str| -> Option<(String, usize)> {
            let sig_idx = *producer.get(sig_candidate)?;
            if skip.contains(&sig_idx) {
                return None;
            }
            if !matches!(nodes[sig_idx].op, OpKind::Sigmoid) {
                return None;
            }
            if consumer_count.get(sig_candidate).copied().unwrap_or(0) != 1 {
                return None;
            }
            if nodes[sig_idx].inputs.first().map(String::as_str) != Some(x_candidate) {
                return None;
            }
            Some((x_candidate.to_string(), sig_idx))
        };
        let inp0 = &node.inputs[0];
        let inp1 = &node.inputs[1];
        // Try (inp1 = sigmoid, inp0 = x) first, then the reversed order.
        let (x_name, sigmoid_idx) = match try_order(inp1, inp0).or_else(|| try_order(inp0, inp1)) {
            Some(result) => result,
            None => continue,
        };
        let fused = Node {
            op: OpKind::SiLU,
            name: format!("{}_fused_silu", node.name),
            inputs: vec![x_name],
            outputs: node.outputs.clone(),
            attrs: Attributes::default(),
        };
        replacements.insert(i, fused);
        skip.insert(sigmoid_idx);
    }
    nodes
        .into_iter()
        .enumerate()
        .filter(|(i, _)| !skip.contains(i))
        .map(|(i, n)| replacements.remove(&i).unwrap_or(n))
        .collect()
}
pub fn fuse_div_sqrt_to_rsqrt(nodes: Vec<Node>, weights: &HashMap<String, Tensor>) -> Vec<Node> {
if nodes.len() < 2 {
return nodes;
}
let mut producer: HashMap<String, usize> = HashMap::new();
for (i, node) in nodes.iter().enumerate() {
for out in &node.outputs {
producer.insert(out.clone(), i);
}
}
let mut consumer_count: HashMap<String, usize> = HashMap::new();
for node in &nodes {
for inp in &node.inputs {
if !inp.is_empty() {
*consumer_count.entry(inp.clone()).or_insert(0) += 1;
}
}
}
let mut replacements: HashMap<usize, Node> = HashMap::new();
for (i, node) in nodes.iter().enumerate() {
if !matches!(node.op, OpKind::Div) {
continue;
}
if node.inputs.len() < 2 {
continue;
}
let numerator_name = &node.inputs[0];
let denominator_name = &node.inputs[1];
let is_const_one = match weights.get(numerator_name) {
Some(t) => t.numel() == 1 && (t.data[0] - 1.0).abs() < 1e-7,
None => false,
};
if !is_const_one {
continue;
}
let sqrt_idx = match producer.get(denominator_name) {
Some(&idx) => idx,
None => continue,
};
if !matches!(nodes[sqrt_idx].op, OpKind::Sqrt) {
continue;
}
if consumer_count.get(denominator_name).copied().unwrap_or(0) != 1 {
continue;
}
let fused = Node {
op: OpKind::Reciprocal,
name: format!("{}_fused_rsqrt", node.name),
inputs: vec![denominator_name.clone()],
outputs: node.outputs.clone(),
attrs: Attributes::default(),
};
replacements.insert(i, fused);
}
nodes
.into_iter()
.enumerate()
.map(|(i, n)| replacements.remove(&i).unwrap_or(n))
.collect()
}
/// Composes two stacked constant-index `Gather` nodes into one.
///
/// For `Gather(Gather(data, idx1, axis), idx2, axis)` where both index
/// tensors are constant weights and the axes match, the outer indices are
/// looked up in the inner indices to build one composed index tensor. The
/// inner gather becomes a `Constant` node materializing the composed
/// indices, and the outer gather becomes a single `Gather` reading the
/// original data.
pub fn fuse_gather_composition(nodes: Vec<Node>, weights: &HashMap<String, Tensor>) -> Vec<Node> {
    if nodes.len() < 2 {
        return nodes;
    }
    // tensor name -> index of the node producing it
    let mut producer: HashMap<String, usize> = HashMap::new();
    for (i, node) in nodes.iter().enumerate() {
        for out in &node.outputs {
            producer.insert(out.clone(), i);
        }
    }
    // tensor name -> number of node inputs referencing it
    let mut consumer_count: HashMap<String, usize> = HashMap::new();
    for node in &nodes {
        for inp in &node.inputs {
            if !inp.is_empty() {
                *consumer_count.entry(inp.clone()).or_insert(0) += 1;
            }
        }
    }
    // Nodes already consumed by a fusion; they must not take part in another
    // composition in the same pass (a gather chain g1->g2->g3 would otherwise
    // fuse g2 twice and leave g1's replacement producing the wrong tensor).
    let mut fused: HashSet<usize> = HashSet::new();
    let mut replacements: HashMap<usize, Node> = HashMap::new();
    for (i, node) in nodes.iter().enumerate() {
        if fused.contains(&i) {
            continue;
        }
        if !matches!(node.op, OpKind::Gather) {
            continue;
        }
        if node.inputs.len() < 2 {
            continue;
        }
        let outer_axis = node.attrs.i("axis", 0);
        let inner_result_name = &node.inputs[0];
        let outer_indices_name = &node.inputs[1];
        // The inner gather's result must feed only this gather.
        if consumer_count.get(inner_result_name).copied().unwrap_or(0) != 1 {
            continue;
        }
        let inner_idx = match producer.get(inner_result_name) {
            Some(&idx) => idx,
            None => continue,
        };
        if fused.contains(&inner_idx) {
            continue;
        }
        if !matches!(nodes[inner_idx].op, OpKind::Gather) {
            continue;
        }
        if nodes[inner_idx].inputs.len() < 2 {
            continue;
        }
        // Composing `idx1[idx2]` is only valid when both gathers act on the
        // same axis.
        let inner_axis = nodes[inner_idx].attrs.i("axis", 0);
        if outer_axis != inner_axis {
            continue;
        }
        let orig_data_name = &nodes[inner_idx].inputs[0];
        let inner_indices_name = &nodes[inner_idx].inputs[1];
        let inner_indices = match weights.get(inner_indices_name) {
            Some(t) => t,
            None => continue,
        };
        let outer_indices = match weights.get(outer_indices_name) {
            Some(t) => t,
            None => continue,
        };
        let mut composed_data = Vec::with_capacity(outer_indices.data.len());
        let mut valid = true;
        for &oi in &outer_indices.data {
            // Negative indices are not resolved here: a negative f64 `as
            // usize` saturates to 0 and would silently mis-compose, so bail
            // out and leave the pair unfused.
            if oi < 0.0 {
                valid = false;
                break;
            }
            let idx = oi as usize;
            if idx >= inner_indices.data.len() {
                valid = false;
                break;
            }
            composed_data.push(inner_indices.data[idx]);
        }
        if !valid {
            continue;
        }
        let composed_name = format!("{}_composed_indices", node.name);
        let composed_shape = outer_indices.shape.clone();
        let mut const_attrs = Attributes::default();
        const_attrs.tensors.insert(
            "value".to_string(),
            Tensor::new(composed_data, composed_shape),
        );
        let const_node = Node {
            op: OpKind::Constant,
            name: format!("{}_const", composed_name),
            inputs: vec![],
            outputs: vec![composed_name.clone()],
            attrs: const_attrs,
        };
        let mut fused_attrs = Attributes::default();
        fused_attrs.ints.insert("axis".to_string(), inner_axis);
        let fused_gather = Node {
            op: OpKind::Gather,
            name: format!("{}_fused_gather", nodes[inner_idx].name),
            inputs: vec![orig_data_name.clone(), composed_name],
            outputs: node.outputs.clone(),
            attrs: fused_attrs,
        };
        replacements.insert(inner_idx, const_node);
        replacements.insert(i, fused_gather);
        fused.insert(inner_idx);
        fused.insert(i);
    }
    nodes
        .into_iter()
        .enumerate()
        .map(|(i, n)| replacements.remove(&i).unwrap_or(n))
        .collect()
}
/// Removes inference-mode `Dropout` nodes (identity at inference time).
///
/// A dropout is eliminated when its `training_mode` attribute is 0 (the
/// default) and its optional mask output (second output) is unused;
/// consumers of the data output are rewired to the dropout's input.
/// Dropouts in training mode, or whose mask output has consumers, are kept.
pub fn eliminate_dropout_inference(nodes: Vec<Node>) -> Vec<Node> {
    if nodes.is_empty() {
        return nodes;
    }
    let mut skip: HashSet<usize> = HashSet::new();
    let mut redirects: HashMap<String, String> = HashMap::new();
    for (i, node) in nodes.iter().enumerate() {
        if !matches!(node.op, OpKind::Dropout) {
            continue;
        }
        if node.inputs.is_empty() {
            continue;
        }
        if node.attrs.i("training_mode", 0) != 0 {
            continue;
        }
        // Keep the dropout if anything consumes its mask output.
        let mask_used = node.outputs.get(1).is_some_and(|mask_name| {
            !mask_name.is_empty()
                && nodes
                    .iter()
                    .any(|n| n.inputs.iter().any(|inp| inp == mask_name))
        });
        if mask_used {
            continue;
        }
        let data_input = &node.inputs[0];
        if let Some(out_name) = node.outputs.first() {
            if !out_name.is_empty() {
                // Resolve through existing redirects so consecutive dropouts
                // collapse to the first live tensor instead of pointing at an
                // already-removed dropout's output.
                let target = redirects
                    .get(data_input)
                    .cloned()
                    .unwrap_or_else(|| data_input.clone());
                redirects.insert(out_name.clone(), target);
            }
        }
        skip.insert(i);
    }
    nodes
        .into_iter()
        .enumerate()
        .filter(|(i, _)| !skip.contains(i))
        .map(|(_, mut n)| {
            for inp in &mut n.inputs {
                if let Some(redirect) = redirects.get(inp) {
                    *inp = redirect.clone();
                }
            }
            n
        })
        .collect()
}
/// Drops a `Transpose` that feeds a single-consumer `Reshape`.
///
/// The reshape is rewritten to read the transpose's input directly when the
/// transpose is an identity permutation, or when the constant target shape
/// indicates a flatten (rank-1 target) or a rank-reducing reshape with fully
/// known, non-negative dims.
///
/// NOTE(review): for a non-identity permutation, flattening a transposed
/// tensor generally yields a different element order than flattening the
/// original input, so dropping the transpose is not order-preserving in
/// general. The in-file tests pin this behavior — confirm the intended
/// reshape semantics before tightening the condition.
pub fn simplify_transpose_reshape(
    nodes: Vec<Node>,
    weights: &HashMap<String, Tensor>,
) -> Vec<Node> {
    if nodes.len() < 2 {
        return nodes;
    }
    // tensor name -> index of the node producing it
    let mut producer: HashMap<String, usize> = HashMap::new();
    for (i, node) in nodes.iter().enumerate() {
        for out in &node.outputs {
            producer.insert(out.clone(), i);
        }
    }
    // tensor name -> number of node inputs referencing it
    let mut consumer_count: HashMap<String, usize> = HashMap::new();
    for node in &nodes {
        for inp in &node.inputs {
            if !inp.is_empty() {
                *consumer_count.entry(inp.clone()).or_insert(0) += 1;
            }
        }
    }
    let mut skip: HashSet<usize> = HashSet::new();
    let mut replacements: HashMap<usize, Node> = HashMap::new();
    for (i, node) in nodes.iter().enumerate() {
        if skip.contains(&i) {
            continue;
        }
        if !matches!(node.op, OpKind::Reshape) {
            continue;
        }
        if node.inputs.is_empty() {
            continue;
        }
        let reshape_input = &node.inputs[0];
        // The transpose output must feed only this reshape.
        if consumer_count.get(reshape_input).copied().unwrap_or(0) != 1 {
            continue;
        }
        let transpose_idx = match producer.get(reshape_input) {
            Some(&idx) => idx,
            None => continue,
        };
        if skip.contains(&transpose_idx) {
            continue;
        }
        if !matches!(nodes[transpose_idx].op, OpKind::Transpose) {
            continue;
        }
        let perm = match nodes[transpose_idx].attrs.int_lists.get("perm") {
            Some(p) => p.clone(),
            None => continue,
        };
        let is_identity = perm.iter().enumerate().all(|(idx, &v)| v == idx as i64);
        // For non-identity perms, only proceed when the constant target shape
        // looks like a flatten (single dim) or a rank-reducing reshape with
        // fully known, non-negative dims — see the NOTE(review) above.
        let is_contiguous_transpose = if !is_identity {
            if let Some(shape_name) = node.inputs.get(1) {
                if let Some(shape_tensor) = weights.get(shape_name) {
                    let rank = perm.len();
                    if rank == 0 {
                        false
                    } else {
                        let target_dims: Vec<i64> =
                            shape_tensor.data.iter().map(|&v| v as i64).collect();
                        target_dims.len() == 1
                            || (target_dims.len() < rank && target_dims.iter().all(|&d| d >= 0))
                    }
                } else {
                    false
                }
            } else {
                false
            }
        } else {
            false
        };
        if !is_identity && !is_contiguous_transpose {
            continue;
        }
        let original_input = match nodes[transpose_idx].inputs.first() {
            Some(name) => name.clone(),
            None => continue,
        };
        // Keep every extra reshape input (shape tensor etc.), swapping only
        // the data input for the transpose's source.
        let mut new_inputs = vec![original_input];
        for inp in node.inputs.iter().skip(1) {
            new_inputs.push(inp.clone());
        }
        let simplified = Node {
            op: OpKind::Reshape,
            name: format!("{}_simplified", node.name),
            inputs: new_inputs,
            outputs: node.outputs.clone(),
            attrs: node.attrs.clone(),
        };
        skip.insert(transpose_idx);
        replacements.insert(i, simplified);
    }
    nodes
        .into_iter()
        .enumerate()
        .filter(|(i, _)| !skip.contains(i))
        .map(|(i, n)| replacements.remove(&i).unwrap_or(n))
        .collect()
}
#[cfg(test)]
mod tests {
    //! Unit tests for the graph-rewrite passes in this file. Each test builds
    //! a tiny node list with `make_node`, runs one pass, and asserts on the
    //! resulting node count, ops, and rewired tensor names.
    use super::*;
    use crate::optimizer::test_utils::make_node;

    // Two transposes whose perms compose to identity are both removed.
    #[test]
    fn test_cancel_consecutive_transpose_identity() {
        let mut node1 = make_node(OpKind::Transpose, "t1", vec!["x"], vec!["t1_out"]);
        node1.attrs.int_lists.insert("perm".to_string(), vec![1, 0]);
        let mut node2 = make_node(OpKind::Transpose, "t2", vec!["t1_out"], vec!["t2_out"]);
        node2.attrs.int_lists.insert("perm".to_string(), vec![1, 0]);
        let nodes = vec![node1, node2];
        let result = cancel_consecutive_transpose(nodes);
        assert_eq!(result.len(), 0);
    }

    // [2,0,1] then [1,2,0] also composes to identity — both removed.
    #[test]
    fn test_cancel_consecutive_transpose_non_identity() {
        let mut node1 = make_node(OpKind::Transpose, "t1", vec!["x"], vec!["t1_out"]);
        node1
            .attrs
            .int_lists
            .insert("perm".to_string(), vec![2, 0, 1]);
        let mut node2 = make_node(OpKind::Transpose, "t2", vec!["t1_out"], vec!["t2_out"]);
        node2
            .attrs
            .int_lists
            .insert("perm".to_string(), vec![1, 2, 0]);
        let nodes = vec![node1, node2];
        let result = cancel_consecutive_transpose(nodes);
        assert_eq!(result.len(), 0);
    }

    // Non-identity composition collapses to one transpose with the composed perm.
    #[test]
    fn test_cancel_consecutive_transpose_compose() {
        let mut node1 = make_node(OpKind::Transpose, "t1", vec!["x"], vec!["t1_out"]);
        node1
            .attrs
            .int_lists
            .insert("perm".to_string(), vec![1, 2, 0]);
        let mut node2 = make_node(OpKind::Transpose, "t2", vec!["t1_out"], vec!["t2_out"]);
        node2
            .attrs
            .int_lists
            .insert("perm".to_string(), vec![1, 2, 0]);
        let nodes = vec![node1, node2];
        let result = cancel_consecutive_transpose(nodes);
        assert_eq!(result.len(), 1);
        assert!(matches!(result[0].op, OpKind::Transpose));
        let perm = result[0].attrs.int_lists.get("perm").expect("perm attr");
        assert_eq!(perm, &vec![2, 0, 1]);
    }

    // Downstream consumer of an eliminated pair is rewired to the source tensor.
    #[test]
    fn test_cancel_consecutive_transpose_redirect() {
        let mut node1 = make_node(OpKind::Transpose, "t1", vec!["x"], vec!["t1_out"]);
        node1.attrs.int_lists.insert("perm".to_string(), vec![1, 0]);
        let mut node2 = make_node(OpKind::Transpose, "t2", vec!["t1_out"], vec!["t2_out"]);
        node2.attrs.int_lists.insert("perm".to_string(), vec![1, 0]);
        let relu = make_node(OpKind::Relu, "relu", vec!["t2_out"], vec!["out"]);
        let nodes = vec![node1, node2, relu];
        let result = cancel_consecutive_transpose(nodes);
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].name, "relu");
        assert_eq!(result[0].inputs[0], "x");
    }

    // A lone transpose is left untouched.
    #[test]
    fn test_cancel_single_transpose() {
        let mut node = make_node(OpKind::Transpose, "t1", vec!["x"], vec!["t1_out"]);
        node.attrs.int_lists.insert("perm".to_string(), vec![1, 0]);
        let nodes = vec![node];
        let result = cancel_consecutive_transpose(nodes);
        assert_eq!(result.len(), 1);
    }

    // Two reshapes with different shapes collapse into one targeting the second shape.
    #[test]
    fn test_cancel_consecutive_reshape_collapse() {
        let r1 = make_node(OpKind::Reshape, "r1", vec!["x", "shape1"], vec!["r1_out"]);
        let r2 = make_node(
            OpKind::Reshape,
            "r2",
            vec!["r1_out", "shape2"],
            vec!["r2_out"],
        );
        let nodes = vec![r1, r2];
        let result = cancel_consecutive_reshape(nodes);
        assert_eq!(result.len(), 1);
        assert!(matches!(result[0].op, OpKind::Reshape));
        assert_eq!(result[0].inputs[0], "x");
        assert_eq!(result[0].inputs[1], "shape2");
        assert_eq!(result[0].outputs[0], "r2_out");
    }

    // Identical shape tensors -> round trip; both reshapes removed, consumer rewired.
    #[test]
    fn test_cancel_consecutive_reshape_same_shape_eliminates_both() {
        let r1 = make_node(OpKind::Reshape, "r1", vec!["x", "shape_a"], vec!["r1_out"]);
        let r2 = make_node(
            OpKind::Reshape,
            "r2",
            vec!["r1_out", "shape_a"],
            vec!["r2_out"],
        );
        let relu = make_node(OpKind::Relu, "relu", vec!["r2_out"], vec!["out"]);
        let nodes = vec![r1, r2, relu];
        let result = cancel_consecutive_reshape(nodes);
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].name, "relu");
        assert_eq!(result[0].inputs[0], "x");
    }

    // The intermediate reshape output has a second consumer -> no collapse.
    #[test]
    fn test_cancel_consecutive_reshape_no_cancel_multiple_consumers() {
        let r1 = make_node(OpKind::Reshape, "r1", vec!["x", "shape1"], vec!["r1_out"]);
        let r2 = make_node(
            OpKind::Reshape,
            "r2",
            vec!["r1_out", "shape2"],
            vec!["r2_out"],
        );
        let relu = make_node(OpKind::Relu, "relu", vec!["r1_out"], vec!["relu_out"]);
        let nodes = vec![r1, r2, relu];
        let result = cancel_consecutive_reshape(nodes);
        assert_eq!(result.len(), 3);
    }

    // A lone reshape is left untouched.
    #[test]
    fn test_cancel_consecutive_reshape_single_node() {
        let r1 = make_node(OpKind::Reshape, "r1", vec!["x", "shape1"], vec!["r1_out"]);
        let nodes = vec![r1];
        let result = cancel_consecutive_reshape(nodes);
        assert_eq!(result.len(), 1);
    }

    // A chain of three reshapes shrinks; the final output name is preserved.
    #[test]
    fn test_cancel_consecutive_reshape_three_reshapes() {
        let r1 = make_node(OpKind::Reshape, "r1", vec!["x", "s1"], vec!["r1_out"]);
        let r2 = make_node(OpKind::Reshape, "r2", vec!["r1_out", "s2"], vec!["r2_out"]);
        let r3 = make_node(OpKind::Reshape, "r3", vec!["r2_out", "s3"], vec!["r3_out"]);
        let nodes = vec![r1, r2, r3];
        let result = cancel_consecutive_reshape(nodes);
        assert!(result.len() <= 2);
        let last = result.last().expect("should have at least one node");
        assert_eq!(last.outputs[0], "r3_out");
    }

    // x * Sigmoid(x) fuses into SiLU(x).
    #[test]
    fn test_fuse_mul_sigmoid_to_silu_basic() {
        let sigmoid = make_node(OpKind::Sigmoid, "sigmoid", vec!["x"], vec!["sig_out"]);
        let mul = make_node(OpKind::Mul, "mul", vec!["x", "sig_out"], vec!["mul_out"]);
        let nodes = vec![sigmoid, mul];
        let result = fuse_mul_sigmoid_to_silu(nodes);
        assert_eq!(result.len(), 1);
        assert!(matches!(result[0].op, OpKind::SiLU));
        assert_eq!(result[0].inputs, vec!["x"]);
        assert_eq!(result[0].outputs, vec!["mul_out"]);
    }

    // Either operand order of the Mul is accepted.
    #[test]
    fn test_fuse_mul_sigmoid_to_silu_reversed_mul_inputs() {
        let sigmoid = make_node(OpKind::Sigmoid, "sigmoid", vec!["x"], vec!["sig_out"]);
        let mul = make_node(OpKind::Mul, "mul", vec!["sig_out", "x"], vec!["mul_out"]);
        let nodes = vec![sigmoid, mul];
        let result = fuse_mul_sigmoid_to_silu(nodes);
        assert_eq!(result.len(), 1);
        assert!(matches!(result[0].op, OpKind::SiLU));
        assert_eq!(result[0].inputs, vec!["x"]);
    }

    // Sigmoid output has a second consumer -> no fusion.
    #[test]
    fn test_fuse_mul_sigmoid_to_silu_no_fusion_multiple_consumers() {
        let sigmoid = make_node(OpKind::Sigmoid, "sigmoid", vec!["x"], vec!["sig_out"]);
        let mul = make_node(OpKind::Mul, "mul", vec!["x", "sig_out"], vec!["mul_out"]);
        let relu = make_node(OpKind::Relu, "relu", vec!["sig_out"], vec!["relu_out"]);
        let nodes = vec![sigmoid, mul, relu];
        let result = fuse_mul_sigmoid_to_silu(nodes);
        assert_eq!(result.len(), 3);
    }

    // Sigmoid input differs from the Mul's other operand -> no fusion.
    #[test]
    fn test_fuse_mul_sigmoid_to_silu_no_fusion_different_input() {
        let sigmoid = make_node(OpKind::Sigmoid, "sigmoid", vec!["y"], vec!["sig_out"]);
        let mul = make_node(OpKind::Mul, "mul", vec!["x", "sig_out"], vec!["mul_out"]);
        let nodes = vec![sigmoid, mul];
        let result = fuse_mul_sigmoid_to_silu(nodes);
        assert_eq!(result.len(), 2);
    }

    // The fused SiLU keeps the Mul's output name so downstream stays wired.
    #[test]
    fn test_fuse_mul_sigmoid_to_silu_preserves_downstream() {
        let sigmoid = make_node(OpKind::Sigmoid, "sigmoid", vec!["x"], vec!["sig_out"]);
        let mul = make_node(OpKind::Mul, "mul", vec!["x", "sig_out"], vec!["mul_out"]);
        let relu = make_node(OpKind::Relu, "relu", vec!["mul_out"], vec!["relu_out"]);
        let nodes = vec![sigmoid, mul, relu];
        let result = fuse_mul_sigmoid_to_silu(nodes);
        assert_eq!(result.len(), 2);
        assert!(matches!(result[0].op, OpKind::SiLU));
        assert_eq!(result[0].outputs, vec!["mul_out"]);
        assert_eq!(result[1].inputs, vec!["mul_out"]);
    }

    // Div(1, Sqrt(x)) becomes Reciprocal(sqrt_out); the Sqrt node is kept.
    #[test]
    fn test_fuse_div_sqrt_to_rsqrt_basic() {
        let sqrt = make_node(OpKind::Sqrt, "sqrt", vec!["x"], vec!["sqrt_out"]);
        let div = make_node(OpKind::Div, "div", vec!["one", "sqrt_out"], vec!["div_out"]);
        let nodes = vec![sqrt, div];
        let mut weights = HashMap::new();
        weights.insert("one".to_string(), Tensor::new(vec![1.0], vec![1]));
        let result = fuse_div_sqrt_to_rsqrt(nodes, &weights);
        assert_eq!(result.len(), 2);
        assert!(matches!(result[0].op, OpKind::Sqrt));
        assert!(matches!(result[1].op, OpKind::Reciprocal));
        assert_eq!(result[1].inputs, vec!["sqrt_out"]);
        assert_eq!(result[1].outputs, vec!["div_out"]);
    }

    // Numerator is 2.0, not 1.0 -> no rewrite.
    #[test]
    fn test_fuse_div_sqrt_to_rsqrt_not_const_one() {
        let sqrt = make_node(OpKind::Sqrt, "sqrt", vec!["x"], vec!["sqrt_out"]);
        let div = make_node(OpKind::Div, "div", vec!["two", "sqrt_out"], vec!["div_out"]);
        let nodes = vec![sqrt, div];
        let mut weights = HashMap::new();
        weights.insert("two".to_string(), Tensor::new(vec![2.0], vec![1]));
        let result = fuse_div_sqrt_to_rsqrt(nodes, &weights);
        assert_eq!(result.len(), 2);
        assert!(matches!(result[1].op, OpKind::Div));
    }

    // Denominator producer is not a Sqrt -> no rewrite.
    #[test]
    fn test_fuse_div_sqrt_to_rsqrt_not_sqrt() {
        let relu = make_node(OpKind::Relu, "relu", vec!["x"], vec!["relu_out"]);
        let div = make_node(OpKind::Div, "div", vec!["one", "relu_out"], vec!["div_out"]);
        let nodes = vec![relu, div];
        let mut weights = HashMap::new();
        weights.insert("one".to_string(), Tensor::new(vec![1.0], vec![1]));
        let result = fuse_div_sqrt_to_rsqrt(nodes, &weights);
        assert_eq!(result.len(), 2);
        assert!(matches!(result[1].op, OpKind::Div));
    }

    // Sqrt output has a second consumer -> no rewrite.
    #[test]
    fn test_fuse_div_sqrt_to_rsqrt_sqrt_multiple_consumers() {
        let sqrt = make_node(OpKind::Sqrt, "sqrt", vec!["x"], vec!["sqrt_out"]);
        let div = make_node(OpKind::Div, "div", vec!["one", "sqrt_out"], vec!["div_out"]);
        let relu = make_node(OpKind::Relu, "relu", vec!["sqrt_out"], vec!["relu_out"]);
        let nodes = vec![sqrt, div, relu];
        let mut weights = HashMap::new();
        weights.insert("one".to_string(), Tensor::new(vec![1.0], vec![1]));
        let result = fuse_div_sqrt_to_rsqrt(nodes, &weights);
        assert_eq!(result.len(), 3);
        assert!(matches!(result[1].op, OpKind::Div));
    }

    // idx1 = [2,0,1], idx2 = [1,2] -> composed indices idx1[idx2] = [0,1].
    #[test]
    fn test_fuse_gather_composition_basic() {
        let mut gather1 = make_node(OpKind::Gather, "g1", vec!["data", "idx1"], vec!["g1_out"]);
        gather1.attrs.ints.insert("axis".to_string(), 0);
        let mut gather2 = make_node(OpKind::Gather, "g2", vec!["g1_out", "idx2"], vec!["g2_out"]);
        gather2.attrs.ints.insert("axis".to_string(), 0);
        let nodes = vec![gather1, gather2];
        let mut weights = HashMap::new();
        weights.insert(
            "idx1".to_string(),
            Tensor::new(vec![2.0, 0.0, 1.0], vec![3]),
        );
        weights.insert("idx2".to_string(), Tensor::new(vec![1.0, 2.0], vec![2]));
        let result = fuse_gather_composition(nodes, &weights);
        assert_eq!(result.len(), 2);
        assert!(matches!(result[0].op, OpKind::Constant));
        assert!(matches!(result[1].op, OpKind::Gather));
        assert_eq!(result[1].inputs[0], "data");
        assert_eq!(result[1].outputs[0], "g2_out");
        let composed = result[0]
            .attrs
            .tensors
            .get("value")
            .expect("composed tensor");
        assert_eq!(composed.data, vec![0.0, 1.0]);
    }

    // Gathers on different axes cannot be composed.
    #[test]
    fn test_fuse_gather_composition_no_fusion_different_axis() {
        let mut gather1 = make_node(OpKind::Gather, "g1", vec!["data", "idx1"], vec!["g1_out"]);
        gather1.attrs.ints.insert("axis".to_string(), 0);
        let mut gather2 = make_node(OpKind::Gather, "g2", vec!["g1_out", "idx2"], vec!["g2_out"]);
        gather2.attrs.ints.insert("axis".to_string(), 1);
        let nodes = vec![gather1, gather2];
        let mut weights = HashMap::new();
        weights.insert("idx1".to_string(), Tensor::new(vec![0.0, 1.0], vec![2]));
        weights.insert("idx2".to_string(), Tensor::new(vec![0.0], vec![1]));
        let result = fuse_gather_composition(nodes, &weights);
        assert_eq!(result.len(), 2);
        assert!(matches!(result[0].op, OpKind::Gather));
        assert!(matches!(result[1].op, OpKind::Gather));
    }

    // Inner gather output has a second consumer -> no fusion.
    #[test]
    fn test_fuse_gather_composition_no_fusion_multiple_consumers() {
        let mut gather1 = make_node(OpKind::Gather, "g1", vec!["data", "idx1"], vec!["g1_out"]);
        gather1.attrs.ints.insert("axis".to_string(), 0);
        let mut gather2 = make_node(OpKind::Gather, "g2", vec!["g1_out", "idx2"], vec!["g2_out"]);
        gather2.attrs.ints.insert("axis".to_string(), 0);
        let relu = make_node(OpKind::Relu, "relu", vec!["g1_out"], vec!["relu_out"]);
        let nodes = vec![gather1, gather2, relu];
        let mut weights = HashMap::new();
        weights.insert("idx1".to_string(), Tensor::new(vec![0.0, 1.0], vec![2]));
        weights.insert("idx2".to_string(), Tensor::new(vec![0.0], vec![1]));
        let result = fuse_gather_composition(nodes, &weights);
        assert_eq!(result.len(), 3);
    }

    // Index tensors not present in weights (dynamic) -> no fusion.
    #[test]
    fn test_fuse_gather_composition_no_fusion_non_constant_indices() {
        let mut gather1 = make_node(
            OpKind::Gather,
            "g1",
            vec!["data", "dynamic_idx1"],
            vec!["g1_out"],
        );
        gather1.attrs.ints.insert("axis".to_string(), 0);
        let mut gather2 = make_node(
            OpKind::Gather,
            "g2",
            vec!["g1_out", "dynamic_idx2"],
            vec!["g2_out"],
        );
        gather2.attrs.ints.insert("axis".to_string(), 0);
        let nodes = vec![gather1, gather2];
        let weights = HashMap::new();
        let result = fuse_gather_composition(nodes, &weights);
        assert_eq!(result.len(), 2);
    }

    // Inference-mode dropout is removed and its consumer rewired.
    #[test]
    fn test_eliminate_dropout_inference_basic() {
        let dropout = make_node(OpKind::Dropout, "dropout", vec!["x"], vec!["dropout_out"]);
        let relu = make_node(OpKind::Relu, "relu", vec!["dropout_out"], vec!["out"]);
        let nodes = vec![dropout, relu];
        let result = eliminate_dropout_inference(nodes);
        assert_eq!(result.len(), 1);
        assert!(matches!(result[0].op, OpKind::Relu));
        assert_eq!(result[0].inputs[0], "x");
    }

    // Dropout between softmax and matmul drops out; matmul reads softmax output.
    #[test]
    fn test_eliminate_dropout_after_softmax() {
        let softmax = make_node(OpKind::Softmax, "softmax", vec!["x"], vec!["sm_out"]);
        let dropout = make_node(
            OpKind::Dropout,
            "dropout",
            vec!["sm_out"],
            vec!["dropout_out"],
        );
        let matmul = make_node(
            OpKind::MatMul,
            "matmul",
            vec!["dropout_out", "v"],
            vec!["out"],
        );
        let nodes = vec![softmax, dropout, matmul];
        let result = eliminate_dropout_inference(nodes);
        assert_eq!(result.len(), 2);
        assert!(matches!(result[0].op, OpKind::Softmax));
        assert!(matches!(result[1].op, OpKind::MatMul));
        assert_eq!(result[1].inputs[0], "sm_out");
    }

    // training_mode = 1 keeps the dropout node.
    #[test]
    fn test_eliminate_dropout_training_mode_not_eliminated() {
        let mut dropout = make_node(OpKind::Dropout, "dropout", vec!["x"], vec!["dropout_out"]);
        dropout.attrs.ints.insert("training_mode".to_string(), 1);
        let relu = make_node(OpKind::Relu, "relu", vec!["dropout_out"], vec!["out"]);
        let nodes = vec![dropout, relu];
        let result = eliminate_dropout_inference(nodes);
        assert_eq!(result.len(), 2);
        assert!(matches!(result[0].op, OpKind::Dropout));
    }

    // A consumer of the mask output keeps the dropout node alive.
    #[test]
    fn test_eliminate_dropout_mask_output_used() {
        let dropout = make_node(
            OpKind::Dropout,
            "dropout",
            vec!["x"],
            vec!["dropout_out", "dropout_mask"],
        );
        let relu = make_node(OpKind::Relu, "relu", vec!["dropout_out"], vec!["out1"]);
        let mask_user = make_node(
            OpKind::Identity,
            "mask_user",
            vec!["dropout_mask"],
            vec!["out2"],
        );
        let nodes = vec![dropout, relu, mask_user];
        let result = eliminate_dropout_inference(nodes);
        assert_eq!(result.len(), 3);
        assert!(matches!(result[0].op, OpKind::Dropout));
    }

    // Identity-perm transpose before a reshape is dropped.
    #[test]
    fn test_simplify_transpose_reshape_identity_perm() {
        let mut transpose = make_node(OpKind::Transpose, "transpose", vec!["x"], vec!["t_out"]);
        transpose
            .attrs
            .int_lists
            .insert("perm".to_string(), vec![0, 1, 2]);
        let reshape = make_node(
            OpKind::Reshape,
            "reshape",
            vec!["t_out", "target_shape"],
            vec!["out"],
        );
        let nodes = vec![transpose, reshape];
        let weights = HashMap::new();
        let result = simplify_transpose_reshape(nodes, &weights);
        assert_eq!(result.len(), 1);
        assert!(matches!(result[0].op, OpKind::Reshape));
        assert_eq!(result[0].inputs[0], "x");
        assert_eq!(result[0].inputs[1], "target_shape");
        assert_eq!(result[0].outputs[0], "out");
    }

    // Non-identity transpose followed by a rank-1 (flatten) reshape is dropped.
    #[test]
    fn test_simplify_transpose_reshape_flatten_after_transpose() {
        let mut transpose = make_node(OpKind::Transpose, "transpose", vec!["x"], vec!["t_out"]);
        transpose
            .attrs
            .int_lists
            .insert("perm".to_string(), vec![0, 2, 1]);
        let reshape = make_node(
            OpKind::Reshape,
            "reshape",
            vec!["t_out", "flat_shape"],
            vec!["out"],
        );
        let nodes = vec![transpose, reshape];
        let mut weights = HashMap::new();
        weights.insert("flat_shape".to_string(), Tensor::new(vec![-1.0], vec![1]));
        let result = simplify_transpose_reshape(nodes, &weights);
        assert_eq!(result.len(), 1);
        assert!(matches!(result[0].op, OpKind::Reshape));
        assert_eq!(result[0].inputs[0], "x");
    }

    // Same-rank target shape after a real transpose -> keep both nodes.
    #[test]
    fn test_simplify_transpose_reshape_no_simplification_non_trivial() {
        let mut transpose = make_node(OpKind::Transpose, "transpose", vec!["x"], vec!["t_out"]);
        transpose
            .attrs
            .int_lists
            .insert("perm".to_string(), vec![1, 0]);
        let reshape = make_node(
            OpKind::Reshape,
            "reshape",
            vec!["t_out", "shape_2d"],
            vec!["out"],
        );
        let nodes = vec![transpose, reshape];
        let mut weights = HashMap::new();
        weights.insert("shape_2d".to_string(), Tensor::new(vec![3.0, 4.0], vec![2]));
        let result = simplify_transpose_reshape(nodes, &weights);
        assert_eq!(result.len(), 2);
        assert!(matches!(result[0].op, OpKind::Transpose));
        assert!(matches!(result[1].op, OpKind::Reshape));
    }

    // The transpose output has a second consumer -> keep everything.
    #[test]
    fn test_simplify_transpose_reshape_no_simplification_multiple_consumers() {
        let mut transpose = make_node(OpKind::Transpose, "transpose", vec!["x"], vec!["t_out"]);
        transpose
            .attrs
            .int_lists
            .insert("perm".to_string(), vec![0, 1, 2]);
        let reshape = make_node(
            OpKind::Reshape,
            "reshape",
            vec!["t_out", "shape"],
            vec!["out1"],
        );
        let relu = make_node(OpKind::Relu, "relu", vec!["t_out"], vec!["out2"]);
        let nodes = vec![transpose, reshape, relu];
        let weights = HashMap::new();
        let result = simplify_transpose_reshape(nodes, &weights);
        assert_eq!(result.len(), 3);
    }
}