use crate::graph::OpKind;
use crate::optimizer::test_utils::make_node;
use crate::tensor::Tensor;
use std::collections::HashMap;
use super::{fold_batch_norm_inference, fuse_conv_add_relu, fuse_conv_batchnorm};
use super::{fuse_conv_clip_to_conv_relu6, fuse_conv_relu};
#[test]
fn test_fuse_conv_batchnorm() {
    // Conv followed by BatchNorm collapses into a single Conv whose
    // fused weight/bias tensors are registered under new keys.
    let conv_node = make_node(
        OpKind::Conv,
        "conv",
        vec!["x", "conv_w", "conv_b"],
        vec!["conv_out"],
    );
    let mut bn_node = make_node(
        OpKind::BatchNorm,
        "bn",
        vec!["conv_out", "bn_scale", "bn_bias", "bn_mean", "bn_var"],
        vec!["bn_out"],
    );
    bn_node.attrs.floats.insert("epsilon".to_string(), 1e-5);
    // Identity-like BN parameters keep the arithmetic trivial here.
    let mut weights = HashMap::from([
        ("conv_w".to_string(), Tensor::new(vec![1.0], vec![1, 1, 1, 1])),
        ("conv_b".to_string(), Tensor::new(vec![0.0], vec![1])),
        ("bn_scale".to_string(), Tensor::new(vec![1.0], vec![1])),
        ("bn_bias".to_string(), Tensor::new(vec![0.0], vec![1])),
        ("bn_mean".to_string(), Tensor::new(vec![0.0], vec![1])),
        ("bn_var".to_string(), Tensor::new(vec![1.0], vec![1])),
    ]);
    let result = fuse_conv_batchnorm(vec![conv_node, bn_node], &mut weights);
    assert_eq!(result.len(), 1);
    let fused = &result[0];
    assert!(matches!(fused.op, OpKind::Conv));
    // The fused Conv must take over the BatchNorm's output name.
    assert_eq!(fused.outputs[0], "bn_out");
    assert!(weights.contains_key("conv_fused_weight"));
    assert!(weights.contains_key("conv_fused_bias"));
}
#[test]
fn test_fuse_conv_batchnorm_no_conv_bias() {
    // A Conv without a bias input fuses as if its bias were zero; check
    // the fused weight and bias values numerically.
    let conv_node = make_node(OpKind::Conv, "conv", vec!["x", "conv_w"], vec!["conv_out"]);
    let mut bn_node = make_node(
        OpKind::BatchNorm,
        "bn",
        vec!["conv_out", "bn_scale", "bn_bias", "bn_mean", "bn_var"],
        vec!["bn_out"],
    );
    bn_node.attrs.floats.insert("epsilon".to_string(), 1e-5);
    let mut weights = HashMap::from([
        ("conv_w".to_string(), Tensor::new(vec![2.0], vec![1, 1, 1, 1])),
        ("bn_scale".to_string(), Tensor::new(vec![3.0], vec![1])),
        ("bn_bias".to_string(), Tensor::new(vec![0.5], vec![1])),
        ("bn_mean".to_string(), Tensor::new(vec![1.0], vec![1])),
        ("bn_var".to_string(), Tensor::new(vec![4.0], vec![1])),
    ]);
    let result = fuse_conv_batchnorm(vec![conv_node, bn_node], &mut weights);
    assert_eq!(result.len(), 1);
    // w' = w * scale / sqrt(var + eps)
    // b' = (0 - mean) * scale / sqrt(var + eps) + bn_bias
    let inv_std = 1.0 / (4.0f32 + 1e-5).sqrt();
    let expected_w = 2.0 * 3.0 * inv_std;
    let expected_b = (0.0 - 1.0) * 3.0 * inv_std + 0.5;
    let fused_w = weights.get("conv_fused_weight").expect("fused weight");
    assert!((fused_w.data[0] - expected_w).abs() < 1e-5);
    let fused_b = weights.get("conv_fused_bias").expect("fused bias");
    assert!((fused_b.data[0] - expected_b).abs() < 1e-5);
}
#[test]
fn test_fuse_conv_batchnorm_multiple_consumers() {
    // conv_out also feeds a Relu, so fusing the BatchNorm would change
    // what the Relu observes — the pass must leave all three nodes intact.
    let conv_node = make_node(
        OpKind::Conv,
        "conv",
        vec!["x", "conv_w", "conv_b"],
        vec!["conv_out"],
    );
    let mut bn_node = make_node(
        OpKind::BatchNorm,
        "bn",
        vec!["conv_out", "bn_scale", "bn_bias", "bn_mean", "bn_var"],
        vec!["bn_out"],
    );
    bn_node.attrs.floats.insert("epsilon".to_string(), 1e-5);
    let relu_node = make_node(OpKind::Relu, "relu", vec!["conv_out"], vec!["relu_out"]);
    let mut weights = HashMap::from([
        ("conv_w".to_string(), Tensor::new(vec![1.0], vec![1, 1, 1, 1])),
        ("conv_b".to_string(), Tensor::new(vec![0.0], vec![1])),
        ("bn_scale".to_string(), Tensor::new(vec![1.0], vec![1])),
        ("bn_bias".to_string(), Tensor::new(vec![0.0], vec![1])),
        ("bn_mean".to_string(), Tensor::new(vec![0.0], vec![1])),
        ("bn_var".to_string(), Tensor::new(vec![1.0], vec![1])),
    ]);
    let result = fuse_conv_batchnorm(vec![conv_node, bn_node, relu_node], &mut weights);
    assert_eq!(result.len(), 3);
}
#[test]
fn test_fuse_conv_relu() {
    // Conv -> Relu collapses into one Conv tagged with an "activation" attr.
    let nodes = vec![
        make_node(OpKind::Conv, "conv", vec!["x", "w", "b"], vec!["conv_out"]),
        make_node(OpKind::Relu, "relu", vec!["conv_out"], vec!["relu_out"]),
    ];
    let result = fuse_conv_relu(nodes);
    assert_eq!(result.len(), 1);
    let fused = &result[0];
    assert!(matches!(fused.op, OpKind::Conv));
    assert_eq!(fused.outputs[0], "relu_out");
    assert_eq!(fused.attrs.s("activation"), "relu");
}
#[test]
fn test_fuse_conv_clip_as_relu() {
    // A Clip with range [0, +inf) is equivalent to a Relu and should be
    // fused as the "relu" activation rather than a generic clip.
    let conv_node = make_node(OpKind::Conv, "conv", vec!["x", "w"], vec!["conv_out"]);
    let mut clip_node = make_node(OpKind::Clip, "clip", vec!["conv_out"], vec!["clip_out"]);
    clip_node.attrs.floats.insert("min".to_string(), 0.0);
    clip_node.attrs.floats.insert("max".to_string(), f32::INFINITY);
    let result = fuse_conv_relu(vec![conv_node, clip_node]);
    assert_eq!(result.len(), 1);
    let fused = &result[0];
    assert!(matches!(fused.op, OpKind::Conv));
    assert_eq!(fused.outputs[0], "clip_out");
    assert_eq!(fused.attrs.s("activation"), "relu");
}
#[test]
fn test_fuse_conv_clip_general() {
    // A bounded Clip ([0, 6]) fuses as a "clip" activation and must carry
    // the min/max bounds over as activation attributes.
    let conv_node = make_node(OpKind::Conv, "conv", vec!["x", "w"], vec!["conv_out"]);
    let mut clip_node = make_node(OpKind::Clip, "clip", vec!["conv_out"], vec!["clip_out"]);
    clip_node.attrs.floats.insert("min".to_string(), 0.0);
    clip_node.attrs.floats.insert("max".to_string(), 6.0);
    let result = fuse_conv_relu(vec![conv_node, clip_node]);
    assert_eq!(result.len(), 1);
    let fused = &result[0];
    assert_eq!(fused.attrs.s("activation"), "clip");
    assert_eq!(fused.attrs.f("activation_min", -1.0), 0.0);
    assert_eq!(fused.attrs.f("activation_max", -1.0), 6.0);
}
#[test]
fn test_fuse_conv_relu_no_fusion_multiple_consumers() {
    // conv_out is consumed by both a Relu and an Add, so fusion must be
    // skipped and all three nodes kept.
    let nodes = vec![
        make_node(OpKind::Conv, "conv", vec!["x", "w"], vec!["conv_out"]),
        make_node(OpKind::Relu, "relu", vec!["conv_out"], vec!["relu_out"]),
        make_node(OpKind::Add, "add", vec!["conv_out", "other"], vec!["add_out"]),
    ];
    let result = fuse_conv_relu(nodes);
    assert_eq!(result.len(), 3);
}
#[test]
fn test_fuse_conv_clip_to_conv_relu6_basic() {
    // Conv followed by Clip(0, 6) becomes a single Conv with the "relu6"
    // activation and the clip bounds preserved as attributes.
    let conv_node = make_node(OpKind::Conv, "conv", vec!["x", "w"], vec!["conv_out"]);
    let mut clip_node = make_node(OpKind::Clip, "clip", vec!["conv_out"], vec!["clip_out"]);
    clip_node.attrs.floats.insert("min".to_string(), 0.0);
    clip_node.attrs.floats.insert("max".to_string(), 6.0);
    let result = fuse_conv_clip_to_conv_relu6(vec![conv_node, clip_node]);
    assert_eq!(result.len(), 1);
    let fused = &result[0];
    assert!(matches!(fused.op, OpKind::Conv));
    assert_eq!(fused.attrs.s("activation"), "relu6");
    assert_eq!(fused.attrs.f("activation_min", -1.0), 0.0);
    assert_eq!(fused.attrs.f("activation_max", -1.0), 6.0);
    assert_eq!(fused.outputs, vec!["clip_out"]);
}
#[test]
fn test_fuse_conv_clip_to_conv_relu6_wrong_range() {
    // Only the exact [0, 6] range qualifies as relu6; [0, 1] must not fuse.
    let conv_node = make_node(OpKind::Conv, "conv", vec!["x", "w"], vec!["conv_out"]);
    let mut clip_node = make_node(OpKind::Clip, "clip", vec!["conv_out"], vec!["clip_out"]);
    clip_node.attrs.floats.insert("min".to_string(), 0.0);
    clip_node.attrs.floats.insert("max".to_string(), 1.0);
    let result = fuse_conv_clip_to_conv_relu6(vec![conv_node, clip_node]);
    assert_eq!(result.len(), 2);
}
#[test]
fn test_fuse_conv_clip_to_conv_relu6_not_conv() {
    // The producer of the Clip input is a Relu, not a Conv, so the
    // relu6 pattern does not match and both nodes survive.
    let relu_node = make_node(OpKind::Relu, "relu", vec!["x"], vec!["relu_out"]);
    let mut clip_node = make_node(OpKind::Clip, "clip", vec!["relu_out"], vec!["clip_out"]);
    clip_node.attrs.floats.insert("min".to_string(), 0.0);
    clip_node.attrs.floats.insert("max".to_string(), 6.0);
    let result = fuse_conv_clip_to_conv_relu6(vec![relu_node, clip_node]);
    assert_eq!(result.len(), 2);
}
#[test]
fn test_fuse_conv_clip_to_conv_relu6_multiple_consumers() {
    // conv_out is also consumed by an Add, so fusing into relu6 would hide
    // the pre-activation tensor; the pass must keep all three nodes.
    let conv_node = make_node(OpKind::Conv, "conv", vec!["x", "w"], vec!["conv_out"]);
    let mut clip_node = make_node(OpKind::Clip, "clip", vec!["conv_out"], vec!["clip_out"]);
    clip_node.attrs.floats.insert("min".to_string(), 0.0);
    clip_node.attrs.floats.insert("max".to_string(), 6.0);
    let add_node = make_node(OpKind::Add, "add", vec!["conv_out", "other"], vec!["add_out"]);
    let result = fuse_conv_clip_to_conv_relu6(vec![conv_node, clip_node, add_node]);
    assert_eq!(result.len(), 3);
}
#[test]
fn test_fold_batch_norm_inference_basic() {
    // A standalone BatchNorm folds into Mul (x * factor) followed by
    // Add (+ shift), with the precomputed tensors stored in `weights`.
    let mut bn_node = make_node(
        OpKind::BatchNorm,
        "bn",
        vec!["x", "scale", "bias", "mean", "var"],
        vec!["bn_out"],
    );
    bn_node.attrs.floats.insert("epsilon".to_string(), 1e-5);
    let mut weights = HashMap::from([
        ("scale".to_string(), Tensor::new(vec![2.0], vec![1])),
        ("bias".to_string(), Tensor::new(vec![0.5], vec![1])),
        ("mean".to_string(), Tensor::new(vec![1.0], vec![1])),
        ("var".to_string(), Tensor::new(vec![4.0], vec![1])),
    ]);
    let result = fold_batch_norm_inference(vec![bn_node], &mut weights);
    assert_eq!(result.len(), 2);
    assert!(matches!(result[0].op, OpKind::Mul));
    assert!(matches!(result[1].op, OpKind::Add));
    // The chain starts from the original input and ends at the BN output.
    assert_eq!(result[0].inputs[0], "x");
    assert_eq!(result[1].outputs, vec!["bn_out"]);
    // factor = scale / sqrt(var + eps); shift = bias - mean * factor
    let inv_std = 1.0 / (4.0f32 + 1e-5).sqrt();
    let expected_factor = 2.0 * inv_std;
    let expected_shift = 0.5 - 1.0 * expected_factor;
    let factor = weights.get("bn_bn_factor").expect("factor weight");
    assert!((factor.data[0] - expected_factor).abs() < 1e-5);
    let shift = weights.get("bn_bn_shift").expect("shift weight");
    assert!((shift.data[0] - expected_shift).abs() < 1e-5);
}
#[test]
fn test_fold_batch_norm_inference_skips_conv_preceded() {
    // A BatchNorm fed by a Conv is left for the conv-bn fusion pass;
    // this fold must not touch it.
    let conv_node = make_node(OpKind::Conv, "conv", vec!["inp", "w"], vec!["conv_out"]);
    let mut bn_node = make_node(
        OpKind::BatchNorm,
        "bn",
        vec!["conv_out", "scale", "bias", "mean", "var"],
        vec!["bn_out"],
    );
    bn_node.attrs.floats.insert("epsilon".to_string(), 1e-5);
    let mut weights = HashMap::from([
        ("scale".to_string(), Tensor::new(vec![1.0], vec![1])),
        ("bias".to_string(), Tensor::new(vec![0.0], vec![1])),
        ("mean".to_string(), Tensor::new(vec![0.0], vec![1])),
        ("var".to_string(), Tensor::new(vec![1.0], vec![1])),
    ]);
    let result = fold_batch_norm_inference(vec![conv_node, bn_node], &mut weights);
    assert_eq!(result.len(), 2);
    assert!(matches!(result[0].op, OpKind::Conv));
    assert!(matches!(result[1].op, OpKind::BatchNorm));
}
#[test]
fn test_fold_batch_norm_inference_missing_weights() {
    // Only "scale" is present; without bias/mean/var the fold cannot run
    // and the BatchNorm node must pass through unchanged.
    let mut bn_node = make_node(
        OpKind::BatchNorm,
        "bn",
        vec!["x", "scale", "bias", "mean", "var"],
        vec!["bn_out"],
    );
    bn_node.attrs.floats.insert("epsilon".to_string(), 1e-5);
    let mut weights =
        HashMap::from([("scale".to_string(), Tensor::new(vec![1.0], vec![1]))]);
    let result = fold_batch_norm_inference(vec![bn_node], &mut weights);
    assert_eq!(result.len(), 1);
    assert!(matches!(result[0].op, OpKind::BatchNorm));
}
#[test]
fn test_fold_batch_norm_inference_multi_channel() {
    // Three-channel BN: the fold must produce per-channel factor/shift
    // tensors with matching shape, computed channel-wise.
    let mut bn_node = make_node(
        OpKind::BatchNorm,
        "bn",
        vec!["x", "scale", "bias", "mean", "var"],
        vec!["bn_out"],
    );
    bn_node.attrs.floats.insert("epsilon".to_string(), 0.001);
    let mut weights = HashMap::from([
        ("scale".to_string(), Tensor::new(vec![1.0, 2.0, 3.0], vec![3])),
        ("bias".to_string(), Tensor::new(vec![0.1, 0.2, 0.3], vec![3])),
        ("mean".to_string(), Tensor::new(vec![0.5, 1.0, 1.5], vec![3])),
        ("var".to_string(), Tensor::new(vec![1.0, 2.0, 4.0], vec![3])),
    ]);
    let result = fold_batch_norm_inference(vec![bn_node], &mut weights);
    assert_eq!(result.len(), 2);
    assert!(matches!(result[0].op, OpKind::Mul));
    assert!(matches!(result[1].op, OpKind::Add));
    let factor = weights.get("bn_bn_factor").expect("factor");
    assert_eq!(factor.shape, vec![3]);
    let shift = weights.get("bn_bn_shift").expect("shift");
    assert_eq!(shift.shape, vec![3]);
    // Spot-check channel 0: factor = scale / sqrt(var + eps).
    let inv_std_0 = 1.0 / (1.0f32 + 0.001).sqrt();
    let expected_f0 = 1.0 * inv_std_0;
    assert!((factor.data[0] - expected_f0).abs() < 1e-5);
}
#[test]
fn test_fold_batch_norm_inference_shape_mismatch() {
    // "bias" has 1 element while the other BN params have 2; the fold
    // must refuse to run and keep the BatchNorm node as-is.
    let mut bn_node = make_node(
        OpKind::BatchNorm,
        "bn",
        vec!["x", "scale", "bias", "mean", "var"],
        vec!["bn_out"],
    );
    bn_node.attrs.floats.insert("epsilon".to_string(), 1e-5);
    let mut weights = HashMap::from([
        ("scale".to_string(), Tensor::new(vec![1.0, 2.0], vec![2])),
        ("bias".to_string(), Tensor::new(vec![0.0], vec![1])),
        ("mean".to_string(), Tensor::new(vec![0.0, 0.0], vec![2])),
        ("var".to_string(), Tensor::new(vec![1.0, 1.0], vec![2])),
    ]);
    let result = fold_batch_norm_inference(vec![bn_node], &mut weights);
    assert_eq!(result.len(), 1);
    assert!(matches!(result[0].op, OpKind::BatchNorm));
}
#[test]
fn test_fuse_conv_add_relu_basic() {
    // Conv -> Add(residual) -> Relu collapses into one ConvAddRelu node
    // whose inputs are conv inputs followed by the residual.
    let nodes = vec![
        make_node(OpKind::Conv, "conv", vec!["x", "w", "b"], vec!["conv_out"]),
        make_node(OpKind::Add, "add", vec!["conv_out", "residual"], vec!["add_out"]),
        make_node(OpKind::Relu, "relu", vec!["add_out"], vec!["relu_out"]),
    ];
    let result = fuse_conv_add_relu(nodes);
    assert_eq!(result.len(), 1);
    let fused = &result[0];
    assert!(matches!(fused.op, OpKind::ConvAddRelu));
    assert_eq!(fused.inputs, vec!["x", "w", "b", "residual"]);
    assert_eq!(fused.outputs, vec!["relu_out"]);
}
#[test]
fn test_fuse_conv_add_relu_reversed_add_inputs() {
    // Add is commutative: the residual appearing first in the Add inputs
    // must still fuse, with the residual normalized to the last slot.
    let nodes = vec![
        make_node(OpKind::Conv, "conv", vec!["x", "w", "b"], vec!["conv_out"]),
        make_node(OpKind::Add, "add", vec!["residual", "conv_out"], vec!["add_out"]),
        make_node(OpKind::Relu, "relu", vec!["add_out"], vec!["relu_out"]),
    ];
    let result = fuse_conv_add_relu(nodes);
    assert_eq!(result.len(), 1);
    let fused = &result[0];
    assert!(matches!(fused.op, OpKind::ConvAddRelu));
    assert_eq!(fused.inputs, vec!["x", "w", "b", "residual"]);
}
#[test]
fn test_fuse_conv_add_relu_no_bias() {
    // A bias-less Conv still fuses; the bias slot is filled with an empty
    // name so the residual keeps its fixed position at index 3.
    let nodes = vec![
        make_node(OpKind::Conv, "conv", vec!["x", "w"], vec!["conv_out"]),
        make_node(OpKind::Add, "add", vec!["conv_out", "residual"], vec!["add_out"]),
        make_node(OpKind::Relu, "relu", vec!["add_out"], vec!["relu_out"]),
    ];
    let result = fuse_conv_add_relu(nodes);
    assert_eq!(result.len(), 1);
    let fused = &result[0];
    assert!(matches!(fused.op, OpKind::ConvAddRelu));
    assert_eq!(fused.inputs[2], "");
    assert_eq!(fused.inputs[3], "residual");
}
#[test]
fn test_fuse_conv_add_relu_no_fusion_conv_multiple_consumers() {
    // An extra consumer of conv_out prevents fusion; all four nodes and
    // the original Conv must remain.
    let nodes = vec![
        make_node(OpKind::Conv, "conv", vec!["x", "w", "b"], vec!["conv_out"]),
        make_node(OpKind::Add, "add", vec!["conv_out", "residual"], vec!["add_out"]),
        make_node(OpKind::Relu, "relu", vec!["add_out"], vec!["relu_out"]),
        make_node(OpKind::Relu, "extra", vec!["conv_out"], vec!["extra_out"]),
    ];
    let result = fuse_conv_add_relu(nodes);
    assert_eq!(result.len(), 4);
    assert!(matches!(result[0].op, OpKind::Conv));
}
#[test]
fn test_fuse_conv_add_relu_no_fusion_add_multiple_consumers() {
    // add_out feeds both the Relu and a Sigmoid, so fusing would remove a
    // tensor another node still needs; nothing is fused.
    let nodes = vec![
        make_node(OpKind::Conv, "conv", vec!["x", "w", "b"], vec!["conv_out"]),
        make_node(OpKind::Add, "add", vec!["conv_out", "residual"], vec!["add_out"]),
        make_node(OpKind::Relu, "relu", vec!["add_out"], vec!["relu_out"]),
        make_node(OpKind::Sigmoid, "extra", vec!["add_out"], vec!["extra_out"]),
    ];
    let result = fuse_conv_add_relu(nodes);
    assert_eq!(result.len(), 4);
    assert!(matches!(result[0].op, OpKind::Conv));
}
#[test]
fn test_fuse_conv_add_relu_no_fusion_not_relu() {
    // The Add feeds a Sigmoid, not a Relu, so the Conv+Add+Relu pattern
    // does not match and all three nodes survive.
    let nodes = vec![
        make_node(OpKind::Conv, "conv", vec!["x", "w", "b"], vec!["conv_out"]),
        make_node(OpKind::Add, "add", vec!["conv_out", "residual"], vec!["add_out"]),
        make_node(OpKind::Sigmoid, "sigmoid", vec!["add_out"], vec!["sig_out"]),
    ];
    let result = fuse_conv_add_relu(nodes);
    assert_eq!(result.len(), 3);
}