use std::error::Error;
use std::sync::Arc;
use rten_base::byte_cast::cast_pod_slice;
use rten_tensor::{NdTensor, Tensor};
use rten_testing::TestCases;
use super::{GraphOptimizer, OptimizeError, OptimizeOptions};
use crate::Dimension;
use crate::constant_storage::{ArcSlice, ArcTensorView, ConstantStorage};
use crate::graph::builder::{Expr, OutputMeta, dims};
use crate::graph::{
CaptureEnv, Constant, Graph, Node, NodeId, OperatorNode, PlanOptions, TypedConstant,
};
use crate::ops::{
Add, Cast, ComputeShape, DimSpec, DynamicQuantizeLinear, Erf, Expand, FusedMatMul, Gather,
Gelu, GroupedQueryAttentionMatMul, Identity, IsNaN, LayerNormalization, MatMul, MatMulInteger,
Neg, Pow, ReduceMean, RepeatInterleave, Reshape, RmsNormalization, Shape, Sigmoid, Slice,
Softmax, Sqrt, Swish, Tanh, Transpose, Unsqueeze, Where,
};
use crate::value::{DataType, Value, ValueType};
/// Run the graph optimizer over `graph` with default options and no capture
/// environment, returning the optimized graph.
fn optimize_graph(graph: Graph) -> Result<Graph, OptimizeError> {
    GraphOptimizer::new().optimize(graph, None, OptimizeOptions::default())
}
/// Build a scalar f32 tensor view backed by reference-counted byte storage.
///
/// The value is serialized to little-endian bytes, wrapped in shared constant
/// storage, and re-interpreted as a single-element f32 slice.
fn arc_tensor_view(val: f32) -> ArcTensorView<f32> {
    let storage = Arc::new(ConstantStorage::Buffer(val.to_le_bytes().to_vec()));
    let elements = cast_pod_slice(storage.data()).unwrap();
    let slice = ArcSlice::new(storage.clone(), elements).unwrap();
    ArcTensorView::from_data(&[], slice)
}
/// Extensions to [`Expr`] for succinctly building operator expressions in
/// tests.
///
/// Each method wraps a single operator from `crate::ops`, so fusion test
/// graphs can be written as chained expressions instead of manual
/// `add_simple_op` calls.
trait OpExprs {
    fn cast(&self, to: DataType) -> Expr;
    fn erf(&self) -> Expr;
    fn identity(&self) -> Expr;
    fn is_nan(&self) -> Expr;
    fn permute(&self, perm: &[usize]) -> Expr;
    fn pow(&self, rhs: Expr) -> Expr;
    fn matmul(&self, rhs: Expr) -> Expr;
    // Reduce the mean over the last axis (`axes` fixed as an attribute).
    fn mean(&self) -> Expr;
    // Reduce the mean with the axes supplied as a runtime input.
    fn mean_axes(&self, axes: Expr) -> Expr;
    fn shape(&self) -> Expr;
    fn sigmoid(&self) -> Expr;
    fn slice(&self, starts: Expr, ends: Expr) -> Expr;
    fn square(&self) -> Expr;
    fn sqrt(&self) -> Expr;
    fn softmax(&self, axis: isize) -> Expr;
    fn tanh(&self) -> Expr;
    fn transpose(&self) -> Expr;
    fn where_(&self, if_true: Expr, if_false: Expr) -> Expr;
}
// Each implementation delegates to `Expr::unary`/`binary`/`apply` with the
// corresponding operator struct. Attribute values are chosen to match the
// patterns the optimizer's fusions look for.
impl OpExprs for Expr {
    fn cast(&self, to: DataType) -> Expr {
        self.unary(Cast { to })
    }
    fn erf(&self) -> Expr {
        self.unary(Erf {})
    }
    fn identity(&self) -> Expr {
        self.unary(Identity {})
    }
    fn is_nan(&self) -> Expr {
        self.unary(IsNaN {})
    }
    fn matmul(&self, rhs: Expr) -> Expr {
        self.binary(MatMul {}, rhs)
    }
    fn mean(&self) -> Expr {
        // Mean over the last axis, with axes as a static attribute.
        self.unary(ReduceMean {
            axes: Some(vec![-1]),
            keep_dims: false,
            noop_with_empty_axes: false,
        })
    }
    fn mean_axes(&self, axes: Expr) -> Expr {
        // Mean with axes passed as a second runtime input instead of an
        // attribute, to exercise the `ReduceMean` axes-fusion path.
        self.binary(
            ReduceMean {
                axes: None,
                keep_dims: false,
                noop_with_empty_axes: false,
            },
            axes,
        )
    }
    fn permute(&self, perm: &[usize]) -> Expr {
        self.unary(Transpose {
            perm: Some(perm.to_vec()),
        })
    }
    fn pow(&self, rhs: Expr) -> Expr {
        self.binary(Pow {}, rhs)
    }
    fn shape(&self) -> Expr {
        self.unary(Shape {
            start: None,
            end: None,
        })
    }
    fn slice(&self, starts: Expr, ends: Expr) -> Expr {
        self.apply(Slice {}, &[starts, ends], &[OutputMeta::NoMeta])
    }
    fn sigmoid(&self) -> Expr {
        self.unary(Sigmoid {})
    }
    fn square(&self) -> Expr {
        // Expressed as `pow(x, 2)` rather than `x * x`.
        self.binary(Pow {}, Expr::constant(2.0))
    }
    fn sqrt(&self) -> Expr {
        self.unary(Sqrt {})
    }
    fn softmax(&self, axis: isize) -> Expr {
        self.unary(Softmax {
            axis,
            flush_nans_to_zero: false,
        })
    }
    fn tanh(&self) -> Expr {
        self.unary(Tanh {})
    }
    fn transpose(&self) -> Expr {
        // Default permutation (reverse all dims).
        self.unary(Transpose { perm: None })
    }
    fn where_(&self, if_true: Expr, if_false: Expr) -> Expr {
        self.apply(Where {}, &[if_true, if_false], &[OutputMeta::NoMeta])
    }
}
/// Test helper to look up the operator which consumes a value node.
trait GetConsumingOp {
    // Return the first operator that takes `value` as an input, if any.
    fn get_consuming_op(&self, value: NodeId) -> Option<&OperatorNode>;
}
impl GetConsumingOp for Graph {
    /// Return the first operator node that consumes `value` as an input.
    ///
    /// Returns `None` if the value has no consumers, or if the first consumer
    /// is not an operator node.
    fn get_consuming_op(&self, value: NodeId) -> Option<&OperatorNode> {
        let consumers = self.get_consumers(value)?;
        let first_consumer = *consumers.first()?;
        self.get_node(first_consumer)?.as_operator()
    }
}
/// Optimizing a subgraph with a static capture environment should resolve
/// captured values to constants from the parent graph, converting them into
/// local constants and leaving the subgraph with no captures.
#[test]
fn test_convert_captured_values_to_constants() -> Result<(), Box<dyn Error>> {
    let mut graph = Graph::new();
    let const_tensor = arc_tensor_view(42.);
    graph.add_constant(Some("const_a"), const_tensor);
    // Subgraph captures `const_a` from the parent graph by name.
    let mut subgraph = Graph::new();
    let sg_val = subgraph.add_value(Some("const_a"), None, None);
    subgraph.set_captures(&[sg_val]);
    subgraph.set_output_ids(&[sg_val]);
    let optimizer = GraphOptimizer::new();
    let capture_env = CaptureEnv::top_level_static(&graph);
    let optimized_subgraph =
        optimizer.optimize(subgraph, Some(&capture_env), OptimizeOptions::default())?;
    let outputs = optimized_subgraph.output_ids();
    // The capture should now be a local constant node in the subgraph.
    assert!(optimized_subgraph.captures().is_empty());
    assert_eq!(outputs.len(), 1);
    let node = optimized_subgraph.get_node(outputs[0]).unwrap();
    assert_eq!(node.name(), Some("const_a"));
    assert!(matches!(node, Node::Constant(_)));
    Ok(())
}
/// Operators whose inputs are all constants ("add_1") should be evaluated
/// during optimization and replaced by a constant node, while operators with
/// a runtime input ("add_2") are kept but rewired to read the propagated
/// constant.
#[test]
fn test_constant_propagation() -> Result<(), Box<dyn Error>> {
    let mut graph = Graph::new();
    let const_a = graph.add_constant(Some("const_a"), Tensor::from([1, 2, 3]).into_arc());
    let const_b = graph.add_constant(Some("const_b"), Tensor::from([4, 5, 6]).into_arc());
    let (_, add_out) = graph.add_simple_op("add_1", Add {}, &[const_a, const_b]);
    let input = graph.add_value(Some("input"), None, None);
    let (add_op_2, add_2_out) = graph.add_simple_op("add_2", Add {}, &[add_out, input]);
    graph.set_input_ids(&[input]);
    graph.set_output_ids(&[add_out, add_2_out]);
    let optimizer = GraphOptimizer::new();
    let optimized_graph = optimizer.optimize(graph, None, OptimizeOptions::default())?;
    // The constant-only op's output is replaced by a new constant node; the
    // second output (with a runtime input) keeps its id.
    assert_eq!(optimized_graph.input_ids(), &[input]);
    assert_ne!(optimized_graph.output_ids()[0], add_out);
    assert_eq!(optimized_graph.output_ids()[1], add_2_out);
    let replaced_node = optimized_graph
        .get_node(optimized_graph.output_ids()[0])
        .and_then(|n| match &n {
            Node::Constant(c) => Some(c),
            _ => None,
        })
        .unwrap();
    let Constant::Int32(const_int) = replaced_node else {
        return Err("constant not an int".into());
    };
    // Element-wise sum of `const_a` and `const_b`.
    assert_eq!(const_int.view(), Tensor::from([5, 7, 9]));
    // `add_2` should now read the propagated constant instead of `add_out`.
    let op = optimized_graph
        .get_node(add_op_2)
        .and_then(|n| match &n {
            Node::Operator(op) => Some(op),
            _ => None,
        })
        .unwrap();
    let input_ids: Vec<_> = op.input_ids().iter().map(|id| id.unwrap()).collect();
    assert_eq!(input_ids.len(), 2);
    assert_ne!(input_ids[0], add_out);
    assert_eq!(input_ids[0], optimized_graph.output_ids()[0]);
    assert_eq!(input_ids[1], input);
    Ok(())
}
/// Fusion should still apply when the same value feeds multiple inputs of the
/// fused pattern (here `x` is both operands of the MatMul).
#[test]
fn test_fuse_op_with_duplicate_inputs() {
    let graph = {
        let x = Expr::value("x");
        let bias = Tensor::from([1., 2., 3.]);
        let expr = x.matmul(x.clone()) + bias;
        expr.build_graph(["x"])
    };
    let graph = optimize_graph(graph).unwrap();
    let (_, op) = graph.get_source_node(graph.output_ids()[0]).unwrap();
    // MatMul + bias Add fuses into a single FusedMatMul.
    assert_eq!(op.operator().name(), "FusedMatMul");
}
/// Fusion should apply inside a subgraph even when the pattern's input is a
/// value captured from an (empty) parent scope rather than a local input.
#[test]
fn test_fuse_op_with_captured_input() {
    let mut subgraph = {
        let x = Expr::value("x");
        // `x * sigmoid(x)` is the Silu pattern.
        let expr = x.clone() * x.sigmoid();
        expr.build_graph([])
    };
    let x_id = subgraph.get_node_id("x").unwrap();
    subgraph.set_captures(&[x_id]);
    let graph = Graph::new();
    let capture_env = CaptureEnv::top_level_static(&graph);
    let optimized_subgraph = GraphOptimizer::new()
        .optimize(subgraph, Some(&capture_env), OptimizeOptions::default())
        .unwrap();
    let (_, op) = optimized_subgraph
        .get_source_node(optimized_subgraph.output_ids()[0])
        .unwrap();
    assert_eq!(op.operator().name(), "Silu");
}
/// Transpose ops feeding a MatMul should fold into the MatMul as input
/// transforms, with the fused op reading the graph inputs directly.
#[test]
fn test_fuse_transpose() {
    let graph = {
        let x = Expr::value("x");
        let y = Expr::value("y");
        x.transpose().matmul(y.transpose()).build_graph(["x", "y"])
    };
    let graph = optimize_graph(graph).unwrap();
    let (_, op) = graph.get_source_node(graph.output_ids()[0]).unwrap();
    assert_eq!(op.operator().name(), "TransformInputs(MatMul)");
    // The fused op's inputs should be the original graph inputs, i.e. the
    // intermediate transpose outputs were removed.
    assert_eq!(
        op.input_ids(),
        graph
            .input_ids()
            .iter()
            .copied()
            .map(Some)
            .collect::<Vec<_>>()
    );
}
/// `x * sigmoid(x)` should fuse into a single Silu operator.
#[test]
fn test_fuse_silu() {
    let graph = {
        let x = Expr::value("x");
        let expr = x.clone() * x.sigmoid();
        expr.build_graph(["x"])
    };
    let graph = optimize_graph(graph).unwrap();
    let (_, op) = graph.get_source_node(graph.output_ids()[0]).unwrap();
    assert_eq!(op.operator().name(), "Silu");
}
/// `x * sigmoid(x * beta)` should fuse into a Swish operator that carries the
/// `beta` scale as an attribute.
#[test]
fn test_fuse_swish() {
    let graph = {
        let x = Expr::value("x");
        let beta = 1.7;
        let expr = x.clone() * (x.clone() * beta).sigmoid();
        expr.build_graph(["x"])
    };
    let graph = optimize_graph(graph).unwrap();
    let (_, op) = graph.get_source_node(graph.output_ids()[0]).unwrap();
    let swish_op = op.operator().downcast_ref::<Swish>().unwrap();
    assert_eq!(swish_op.beta, 1.7);
}
/// `MatMul(a, b) + bias` should fuse into a FusedMatMul with the bias folded
/// in.
#[test]
fn test_fuse_matmul_add() {
    let graph = {
        let a = Expr::value("a");
        let b = Expr::value("b");
        let bias = Tensor::from([1., 2., 3.]);
        let expr = a.matmul(b) + bias;
        expr.build_graph(["a", "b"])
    };
    let graph = optimize_graph(graph).unwrap();
    let (_, op) = graph.get_source_node(graph.output_ids()[0]).unwrap();
    assert_eq!(op.operator().name(), "FusedMatMul");
}
/// Scalar multiplications of MatMul operands, or a scalar division of the
/// MatMul output, should fold into a FusedMatMul `alpha` attribute.
#[test]
fn test_fuse_matmul_scaled() {
    // Case 1: both inputs pre-scaled. The scales multiply into `alpha`.
    let graph = {
        let a = Expr::value("a");
        let b = Expr::value("b");
        let expr = (a * 0.5).matmul(b * 0.3);
        expr.build_graph(["a", "b"])
    };
    let graph = optimize_graph(graph).unwrap();
    let (_, op) = graph.get_source_node(graph.output_ids()[0]).unwrap();
    assert_eq!(op.operator().name(), "FusedMatMul");
    let fused_matmul_op = op.operator().downcast_ref::<FusedMatMul>().unwrap();
    assert_eq!(fused_matmul_op.alpha, Some(0.5 * 0.3));
    // Case 2: output divided by a scalar. `alpha` becomes its reciprocal.
    let graph = {
        let a = Expr::value("a");
        let b = Expr::value("b");
        let expr = a.matmul(b) / 0.5;
        expr.build_graph(["a", "b"])
    };
    let graph = optimize_graph(graph).unwrap();
    let (_, op) = graph.get_source_node(graph.output_ids()[0]).unwrap();
    assert_eq!(op.operator().name(), "FusedMatMul");
    let fused_matmul_op = op.operator().downcast_ref::<FusedMatMul>().unwrap();
    assert_eq!(fused_matmul_op.alpha, Some(1. / 0.5));
}
/// When the output of one fusible pattern feeds another instance of the same
/// pattern, both should be fused independently (Silu feeding Silu).
#[test]
fn test_chained_fused_ops() {
    let graph = {
        let x = Expr::value("x");
        let y = x.clone() * x.sigmoid();
        let z = y.clone() * y.sigmoid();
        z.build_graph(["x"])
    };
    let graph = optimize_graph(graph).unwrap();
    // The graph output comes from the second (outer) Silu...
    let (_, fused_op) = graph.get_source_node(graph.output_ids()[0]).unwrap();
    assert_eq!(fused_op.operator().name(), "Silu");
    // ...whose input comes from the first (inner) Silu.
    let (_, fused_op_2) = graph
        .get_source_node(fused_op.input_ids()[0].unwrap())
        .unwrap();
    assert_eq!(fused_op_2.operator().name(), "Silu");
}
/// The exact GELU expression `x * (erf(x / sqrt(2)) + 1) * 0.5` should fuse
/// into a single non-approximate `Gelu` operator.
#[test]
fn test_fuse_gelu() {
    let graph = {
        let x = Expr::value("x");
        let sqrt_2 = (2.0f32).sqrt();
        let expr = x.clone() * ((x / sqrt_2).erf() + 1.0) * 0.5;
        expr.build_graph(["x"])
    };
    let graph = optimize_graph(graph).unwrap();
    let (_, op) = graph.get_source_node(graph.output_ids()[0]).unwrap();
    let gelu = op.operator().downcast_ref::<Gelu>().unwrap();
    // Erf-based GELU is the exact (non-tanh-approximation) form.
    assert!(!gelu.approximate);
}
/// The tanh-approximation GELU expression
/// `0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))` should fuse into
/// a `Gelu` operator with `approximate` set.
#[test]
fn test_fuse_approx_gelu() {
    let graph = {
        let x = Expr::value("x");
        let sqrt_2_pi = Expr::constant((2.0f32 / std::f32::consts::PI).sqrt());
        let expr = x.clone()
            * 0.5
            * (Expr::constant(1.)
                + (sqrt_2_pi * (x.clone() + x.pow(Expr::constant(3.0)) * 0.044715)).tanh());
        expr.build_graph(["x"])
    };
    let graph = optimize_graph(graph).unwrap();
    let (_, op) = graph.get_source_node(graph.output_ids()[0]).unwrap();
    let gelu = op.operator().downcast_ref::<Gelu>().unwrap();
    assert!(gelu.approximate);
}
/// Build a graph computing the decomposed LayerNormalization pattern:
/// `(x - mean(x)) / sqrt(mean((x - mean(x))^2) + eps) * scale [+ bias]`,
/// optionally with the trailing bias addition.
fn layer_norm_graph(with_bias: bool) -> Graph {
    let epsilon = 1e-6;
    let x = Expr::value("x");
    let x_mean = x.mean();
    let x_sub_mean = x.clone() - x_mean;
    // Normalize by the standard deviation over the last axis.
    let normalized = x_sub_mean.clone() / (x_sub_mean.square().mean() + epsilon).sqrt();
    let scale = Tensor::from([3., 4., 5.]);
    let expr = if with_bias {
        let bias = Tensor::from([1., 2., 3.]);
        normalized * scale + bias
    } else {
        normalized * scale
    };
    expr.build_graph(["x"])
}
/// The decomposed layer-norm pattern should fuse into a LayerNormalization
/// op, with the optional bias becoming the op's third input.
#[test]
fn test_fuse_layer_norm() {
    #[derive(Debug)]
    struct Case {
        with_bias: bool,
    }
    let cases = [Case { with_bias: true }, Case { with_bias: false }];
    cases.test_each(|&Case { with_bias }| {
        let graph = layer_norm_graph(with_bias);
        let graph = optimize_graph(graph).unwrap();
        let (_, op) = graph.get_source_node(graph.output_ids()[0]).unwrap();
        assert_eq!(op.operator().name(), "LayerNormalization");
        let layer_norm = op.operator().downcast_ref::<LayerNormalization>().unwrap();
        assert_eq!(layer_norm.epsilon, Some(1e-6));
        // Input index 2 is the bias; it is present iff the pattern had one.
        let bias_input = op.input_ids().get(2).copied().flatten();
        assert_eq!(bias_input.is_some(), with_bias);
    })
}
/// The decomposed RMS-norm pattern `x * (1 / sqrt(mean(x^2) + eps)) * scale`
/// should fuse into an RmsNormalization op carrying the epsilon.
#[test]
fn test_fuse_rms_norm() {
    let graph = {
        let x = Expr::value("x");
        let epsilon = 1e-6;
        let rms = (x.square().mean() + epsilon).sqrt();
        let scale = Tensor::from([3., 4., 5.]);
        let expr = x * (Expr::constant(1.) / rms) * scale;
        expr.build_graph(["x"])
    };
    let graph = optimize_graph(graph).unwrap();
    let (_, op) = graph.get_source_node(graph.output_ids()[0]).unwrap();
    let rms_norm = op.operator().downcast_ref::<RmsNormalization>().unwrap();
    assert_eq!(rms_norm.epsilon, Some(1e-6));
}
/// RMS-norm fusion should also match when the mean's reduction axis is given
/// as a positive index (here `2`, the last axis of a rank-3 input), which
/// requires shape metadata to resolve it against the last axis.
#[test]
fn test_fuse_rms_norm_with_positive_axes() {
    let graph = {
        let dims = &[
            Dimension::Symbolic("batch".to_string()),
            Dimension::Symbolic("seq".to_string()),
            Dimension::Fixed(16),
        ];
        let x = Expr::value_with_info("x", ValueType::Tensor(DataType::Float), dims);
        // Positive axis `2` rather than the usual `-1`.
        let axes = Expr::constant(Tensor::from([2i32]));
        let epsilon = 1e-6;
        // `apply` is used here so the Pow output carries shape metadata.
        let x_square = x.apply(
            Pow {},
            &[Expr::constant(2.0)],
            &[OutputMeta::Meta((DataType::Float, dims.to_vec()))],
        );
        let rms = (x_square.mean_axes(axes) + epsilon).sqrt();
        let scale = Tensor::from([3., 4., 5.]);
        let expr = x * (Expr::constant(1.) / rms) * scale;
        expr.build_graph(["x"])
    };
    let graph = optimize_graph(graph).unwrap();
    let (_, op) = graph.get_source_node(graph.output_ids()[0]).unwrap();
    let rms_norm = op.operator().downcast_ref::<RmsNormalization>().unwrap();
    assert_eq!(rms_norm.epsilon, Some(1e-6));
}
/// `softmax(qk + m, axis=-1)` (the attention-mask pattern) should fuse into
/// an AddSoftmax op.
#[test]
fn test_fuse_add_softmax() {
    let graph = {
        let qk = Expr::value("qk");
        let m = Expr::value("m");
        let expr = (qk + m).softmax(-1);
        expr.build_graph(["qk", "m"])
    };
    let graph = optimize_graph(graph).unwrap();
    let (_, op) = graph.get_source_node(graph.output_ids()[0]).unwrap();
    assert_eq!(op.operator().name(), "AddSoftmax");
}
/// AddSoftmax fusion should also match when the softmax axis is given as a
/// positive index equal to the last axis, resolved via shape metadata on the
/// Add output.
#[test]
fn test_fuse_add_softmax_positive_axes() {
    let graph = {
        let dims = [
            Dimension::Symbolic("batch".to_string()),
            Dimension::Fixed(768),
        ];
        let qk = Expr::value_with_info("qk", ValueType::Tensor(DataType::Float), &dims);
        let m = Expr::value("m");
        // `apply` attaches shape metadata to the Add output so axis `1` can
        // be recognized as the last axis.
        let expr = qk
            .apply(
                Add {},
                &[m],
                &[OutputMeta::Meta((DataType::Float, dims.to_vec()))],
            )
            .softmax(1);
        expr.build_graph(["qk", "m"])
    };
    let graph = optimize_graph(graph).unwrap();
    let (_, op) = graph.get_source_node(graph.output_ids()[0]).unwrap();
    assert_eq!(op.operator().name(), "AddSoftmax");
}
/// The "safe softmax" pattern `where(isnan(y), 0, y)` applied to a softmax
/// output should fold into the Softmax op's `flush_nans_to_zero` flag.
#[test]
fn test_fuse_safe_softmax() {
    let graph = {
        let x = Expr::value("x");
        let y = x.softmax(-1);
        let expr = y.is_nan().where_(Expr::constant(0.), y);
        expr.build_graph(["x"])
    };
    let graph = optimize_graph(graph).unwrap();
    let (_, op) = graph.get_source_node(graph.output_ids()[0]).unwrap();
    let softmax = op.operator().downcast_ref::<Softmax>().unwrap();
    // The original axis is preserved; only the NaN-flush flag is added.
    assert_eq!(softmax.axis, -1);
    assert!(softmax.flush_nans_to_zero);
}
/// `1 / x` should be replaced by a dedicated Reciprocal op.
#[test]
fn test_fuse_reciprocal() {
    let graph = {
        let x = Expr::value("x");
        let expr = Expr::constant(1.) / x;
        expr.build_graph(["x"])
    };
    let graph = optimize_graph(graph).unwrap();
    let (_, op) = graph.get_source_node(graph.output_ids()[0]).unwrap();
    assert_eq!(op.operator().name(), "Reciprocal");
}
/// A constant `axes` input to ReduceMean should be folded into the op's
/// `axes` attribute.
#[test]
fn test_fuse_reduce_mean_axes() {
    let graph = {
        let x = Expr::value("x");
        let axes = Expr::constant(Tensor::from([-1i32]));
        x.mean_axes(axes).build_graph(["x"])
    };
    let graph = optimize_graph(graph).unwrap();
    let (_, op) = graph.get_source_node(graph.output_ids()[0]).unwrap();
    let mean_op = op.operator().downcast_ref::<ReduceMean>().unwrap();
    assert_eq!(mean_op.axes.as_deref(), Some([-1].as_slice()));
}
/// Binary ops which combine a value with the identity element of the
/// operation (`x + 0`, `x - 0`, `x * 1`, `x / 1`) should be replaced by an
/// `Identity` op.
#[test]
fn test_fuse_identity_op() {
    let input = || Expr::value("x");
    let cases = [
        input() + 0.,
        input() - 0.,
        input() * 1.,
        input() / 1.,
    ];
    for expr in cases {
        let optimized = optimize_graph(expr.build_graph(["x"])).unwrap();
        let (_, op) = optimized
            .get_source_node(optimized.output_ids()[0])
            .unwrap();
        assert_eq!(op.operator().name(), "Identity");
    }
}
/// When a binary-identity op (`x * 1`) is eliminated, its consumer should be
/// rewired to the original input, graph inputs/outputs should be preserved,
/// and the eliminated node removed from the graph.
#[test]
fn test_eliminate_binary_identity_pattern() {
    let graph = {
        let x = Expr::value("x");
        let expr = x * 1. + 2.;
        expr.build_graph(["x"])
    };
    let input_id = graph.input_ids()[0];
    let output_id = graph.output_ids()[0];
    // Before optimization the input feeds a Mul node.
    let mul_op = graph.get_consumers(input_id).unwrap()[0];
    assert_eq!(
        graph
            .get_node(mul_op)
            .unwrap()
            .as_operator()
            .unwrap()
            .operator()
            .name(),
        "Mul"
    );
    let graph = optimize_graph(graph).unwrap();
    // Graph boundary node ids are stable across optimization.
    assert_eq!(graph.input_ids(), [input_id]);
    assert_eq!(graph.output_ids(), [output_id]);
    // The Add now reads the input directly; the Mul node is gone.
    let (_, op) = graph.get_source_node(output_id).unwrap();
    assert_eq!(op.operator().name(), "Add");
    assert_eq!(op.input_ids()[0], Some(input_id));
    assert!(graph.get_node(mul_op).is_none());
}
/// An explicit Identity op should be removed, so its consumer (Sqrt) reads
/// the graph input directly.
#[test]
fn test_eliminate_unary_identity_pattern() {
    let graph = {
        let x = Expr::value("x");
        x.identity().sqrt().build_graph(["x"])
    };
    let graph = optimize_graph(graph).unwrap();
    let input_id = graph.input_ids()[0];
    assert_eq!(
        graph.get_consuming_op(input_id).unwrap().operator().name(),
        "Sqrt"
    );
}
/// A Cast to the input's existing dtype is a no-op and should be removed, so
/// the downstream op (Erf) reads the graph input directly.
#[test]
fn test_eliminate_noop_cast() {
    let graph = {
        // The input's dtype must be known for the cast to be provably no-op.
        let x = Expr::value_with_info(
            "x",
            ValueType::Tensor(DataType::Float),
            &[Dimension::Symbolic("x".to_string())],
        );
        x.cast(DataType::Float).erf().build_graph(["x"])
    };
    let graph = optimize_graph(graph).unwrap();
    let input_id = graph.input_ids()[0];
    assert_eq!(
        graph.get_consuming_op(input_id).unwrap().operator().name(),
        "Erf"
    );
}
/// The quantized matmul pattern — DynamicQuantizeLinear feeding a
/// MatMulInteger whose output is cast to float and multiplied by the combined
/// scales — should fuse into a MatMulIntegerToFloat op.
#[test]
fn test_fuse_matmulinteger_cast_scale() {
    let graph = {
        let x = Expr::value("x");
        let weights = Expr::constant(Tensor::<i8>::zeros(&[4, 4]));
        let weights_zero = Expr::constant(Tensor::<i8>::zeros(&[4]));
        // DynamicQuantizeLinear produces (quantized data, scale, zero point).
        let quant = x.apply(
            DynamicQuantizeLinear {},
            &[],
            &[OutputMeta::NoMeta, OutputMeta::NoMeta, OutputMeta::NoMeta],
        );
        let quant_x = quant.output(0);
        let quant_scale = quant.output(1);
        let quant_zero = quant.output(2);
        let const_scale = Expr::constant(Tensor::from([0.1, 0.2, 0.3]));
        // int matmul -> cast to float -> multiply by input & weight scales.
        let expr = quant_x
            .apply(
                MatMulInteger {},
                &[weights, quant_zero, weights_zero],
                &[OutputMeta::NoMeta],
            )
            .unary(Cast {
                to: DataType::Float,
            })
            * (quant_scale * const_scale);
        expr.build_graph(["x"])
    };
    let graph = optimize_graph(graph).unwrap();
    let (_, op) = graph.get_source_node(graph.output_ids()[0]).unwrap();
    assert_eq!(op.operator().name(), "MatMulIntegerToFloat");
}
/// Slicing a Shape output over dimensions that are statically known (here the
/// fixed trailing dim of 64) should be folded into a constant.
#[test]
fn test_slice_shape_to_constant() {
    let graph = {
        let x = Expr::value_with_info(
            "x",
            ValueType::Tensor(DataType::Float),
            &[Dimension::Symbolic("batch".into()), Dimension::Fixed(64)],
        );
        // Slice [-1, MAX) selects only the last (fixed) dimension.
        let starts = Expr::constant(Tensor::from([-1i32]));
        let ends = Expr::constant(Tensor::from([i32::MAX]));
        let expr = x.shape().slice(starts, ends);
        expr.build_graph(["x"])
    };
    let graph = optimize_graph(graph).unwrap();
    let id_input = graph.output_ids()[0];
    let const_node = graph
        .get_node(id_input)
        .and_then(|n| n.as_constant())
        .unwrap();
    assert_eq!(const_node.as_scalar(), Some(64i32));
}
/// Optimization/fusion must not change the ids of the graph's input and
/// output nodes, even when intermediate nodes are replaced.
#[test]
fn test_optimize_preserves_input_output_nodes() {
    let graph = {
        let x = Expr::value("x");
        let y = Expr::value("y");
        x.transpose().matmul(y).build_graph(["x", "y"])
    };
    let orig_input_ids = graph.input_ids().to_vec();
    let orig_output_ids = graph.output_ids().to_vec();
    let graph = optimize_graph(graph).unwrap();
    // The transpose was fused into the MatMul...
    let (_, op) = graph.get_source_node(graph.output_ids()[0]).unwrap();
    assert_eq!(op.operator().name(), "TransformInputs(MatMul)");
    // ...but the graph's boundary node ids are unchanged.
    assert_eq!(graph.input_ids(), orig_input_ids);
    assert_eq!(graph.output_ids(), orig_output_ids);
}
/// Optimizing a graph whose input/output ids reference a non-existent node
/// should fail with a `RunError` rather than panicking.
#[test]
fn test_optimize_error() {
    let mut graph = Graph::new();
    let optimizer = GraphOptimizer::new();
    // Id that does not correspond to any node in the (empty) graph.
    let invalid_id = NodeId::from_u32(123);
    graph.set_input_ids(&[invalid_id]);
    graph.set_output_ids(&[invalid_id]);
    let result = optimizer.optimize(graph, None, OptimizeOptions::default());
    assert!(matches!(result, Err(OptimizeError::RunError(_))));
}
/// After a multi-op pattern (GELU) is fused, the individual ops it subsumed
/// should be removed from the graph, and the fused op takes the name of the
/// pattern's final op ("Mul_1" here).
#[test]
fn test_optimize_removes_unfused_ops() {
    let graph = {
        let x = Expr::value("x");
        let sqrt_2 = (2.0f32).sqrt();
        let expr = x.clone() * ((x / sqrt_2).erf() + 1.0) * 0.5;
        expr.build_graph(["x"])
    };
    // All the component ops exist before optimization...
    let ops = ["Mul", "Erf", "Div", "Add"];
    for op in ops {
        assert!(graph.get_node_id(op).is_some());
    }
    let optimized = optimize_graph(graph).unwrap();
    // ...and are gone afterwards.
    for op in ops {
        assert!(optimized.get_node_id(op).is_none());
    }
    // The node named after the last op in the pattern is now the fused Gelu.
    let fused_op = optimized
        .get_node_id("Mul_1")
        .and_then(|id| optimized.get_node(id))
        .and_then(|n| n.as_operator())
        .unwrap();
    assert_eq!(fused_op.operator().name(), "Gelu");
}
/// Fusion must be skipped when an intermediate value of the pattern (the Erf
/// output) is also consumed by an op outside the pattern, since fusing would
/// remove a value that is still needed.
#[test]
fn test_optimize_does_not_fuse_if_intermediate_outputs_reused() {
    let mut graph = {
        let x = Expr::value("x");
        let sqrt_2 = (2.0f32).sqrt();
        let expr = x.clone() * ((x / sqrt_2).erf() + 1.0) * 0.5;
        expr.build_graph(["x"])
    };
    // Add an extra consumer of the Erf output, outside the GELU pattern.
    let erf_out = graph.get_node_id("Erf_out").unwrap();
    graph.add_simple_op("neg", Neg {}, &[erf_out]);
    let optimized = optimize_graph(graph).unwrap();
    // The final Mul is left unfused.
    let fused_op = optimized
        .get_node_id("Mul_1")
        .and_then(|id| optimized.get_node(id))
        .and_then(|n| n.as_operator())
        .unwrap();
    assert_eq!(fused_op.operator().name(), "Mul");
}
/// Transposed inputs and an output scale should both fold into a single
/// `TransformInputs(FusedMatMul)` op, leaving a one-step execution plan.
#[test]
fn test_fuse_transpose_matmul_scaled() {
    let x = Expr::value("x");
    let y = Expr::value("y");
    let xy = x.transpose().matmul(y.transpose());
    let xy_scaled = xy / 8.;
    let graph = xy_scaled.build_graph(["x", "y"]);
    let input_ids = graph.input_ids().to_vec();
    let output_ids = graph.output_ids().to_vec();
    let optimized = optimize_graph(graph).unwrap();
    let plan = optimized
        .execution_plan(&input_ids, &output_ids, PlanOptions::default())
        .unwrap();
    let op_name = |node_id| {
        optimized
            .get_node(node_id)
            .and_then(|n| n.as_operator())
            .map(|op| op.operator().name())
    };
    // Everything collapsed into one fused op.
    assert_eq!(plan.len(), 1);
    assert_eq!(op_name(plan[0]), Some("TransformInputs(FusedMatMul)"));
}
/// Fusion must be skipped when an intermediate value of the pattern (the Erf
/// output) is itself a graph output, since fusing would remove it.
#[test]
fn test_optimize_does_not_fuse_if_intermediate_output_is_graph_output() {
    let mut graph = {
        let x = Expr::value("x");
        let sqrt_2 = (2.0f32).sqrt();
        let expr = x.clone() * ((x / sqrt_2).erf() + 1.0) * 0.5;
        expr.build_graph(["x"])
    };
    // Expose the Erf output as an additional graph output.
    let erf_out = graph.get_node_id("Erf_out").unwrap();
    let mut output_ids = graph.output_ids().to_vec();
    output_ids.push(erf_out);
    graph.set_output_ids(&output_ids);
    let optimized = optimize_graph(graph).unwrap();
    // The final Mul is left unfused.
    let fused_op = optimized
        .get_node_id("Mul_1")
        .and_then(|id| optimized.get_node(id))
        .and_then(|n| n.as_operator())
        .unwrap();
    assert_eq!(fused_op.operator().name(), "Mul");
}
/// The Unsqueeze -> Expand -> Reshape sequence used to repeat KV heads should
/// fuse into a RepeatInterleave op with the inferred axis and repeat count.
#[test]
fn test_fuse_repeat_interleave() {
    let batch_dim = Dimension::Symbolic("batch".to_string());
    let kv_heads = 8;
    let embed_dim = Dimension::Fixed(16);
    let n_repeats = 3;
    let repeat_axis = 1;
    // Input is (batch, kv_heads, embed).
    let input_shape = [
        batch_dim.clone(),
        Dimension::Fixed(kv_heads),
        embed_dim.clone(),
    ];
    let x = Expr::value_with_info("x", ValueType::Tensor(DataType::Float), &input_shape);
    // Unsqueeze inserts a new dim just after the repeat axis.
    let unsqueeze_axes = Expr::constant(Value::from(NdTensor::from([repeat_axis as i32 + 1])));
    // Expand broadcasts the new dim to `n_repeats`:
    // (batch, kv_heads, n_repeats, embed).
    let expand_shape = Expr::value_with_info(
        "expand_shape",
        ValueType::Tensor(DataType::Float),
        &[
            batch_dim.clone(),
            Dimension::Fixed(kv_heads),
            Dimension::Fixed(n_repeats),
            embed_dim.clone(),
        ],
    );
    // Reshape merges the head and repeat dims:
    // (batch, kv_heads * n_repeats, embed).
    let out_shape = [
        batch_dim.clone(),
        Dimension::Fixed(kv_heads * n_repeats),
        embed_dim.clone(),
    ];
    let reshape_shape = Expr::value_with_info(
        "reshape_shape",
        ValueType::Tensor(DataType::Float),
        &out_shape,
    );
    let t1 = x.binary(Unsqueeze {}, unsqueeze_axes);
    let t2 = t1.binary(Expand {}, expand_shape);
    let output_meta = OutputMeta::Meta((DataType::Float, out_shape.to_vec()));
    let y = t2.apply(
        Reshape { allow_zero: false },
        &[reshape_shape],
        &[output_meta],
    );
    let graph = y.build_graph([
        "x",
        "expand_shape",
        "reshape_shape",
    ]);
    let optimized = optimize_graph(graph).unwrap();
    // The fused op replaces the final Reshape node.
    let op = optimized
        .get_node_id("Reshape")
        .and_then(|id| optimized.get_node(id))
        .and_then(|n| n.as_operator())
        .unwrap();
    let repeat_interleave_op = op.operator().downcast_ref::<RepeatInterleave>().unwrap();
    assert_eq!(repeat_interleave_op.axis, repeat_axis);
    assert_eq!(repeat_interleave_op.repeats, n_repeats);
}
/// A Shape op on a value with (partially) known shape metadata should be
/// replaced by a ComputeShape op, with fixed dims inlined as static entries
/// and symbolic dims read dynamically from the graph input.
#[test]
fn test_fuse_compute_shape() {
    let in_shape = vec![
        Dimension::Symbolic("batch".to_string()),
        Dimension::Fixed(3),
        Dimension::Fixed(224),
        Dimension::Fixed(224),
    ];
    let x = Expr::value_with_info("pixels", ValueType::Tensor(DataType::Float), &in_shape);
    // A no-op Add (+0) whose output carries the same shape metadata, so the
    // Shape op reads an intermediate value rather than the input directly.
    let eps = Expr::constant(Value::from(NdTensor::from(0.)));
    let y = x.apply(
        Add {},
        &[eps],
        &[OutputMeta::Meta((DataType::Float, in_shape))],
    );
    let shape = y.unary(Shape {
        start: None,
        end: None,
    });
    // Both the value and its shape are graph outputs.
    let graph = Expr::make_graph([x], [y, shape]);
    let optimized = optimize_graph(graph).unwrap();
    let op = optimized
        .get_node_id("Shape")
        .and_then(|id| optimized.get_node(id))
        .and_then(|n| n.as_operator())
        .unwrap();
    let op = op.operator().downcast_ref::<ComputeShape>().unwrap();
    // Symbolic "batch" comes from input 0 / dim 0; fixed dims are static.
    assert_eq!(
        op.shape,
        [
            DimSpec::Dynamic { input: 0, dim: 0 },
            DimSpec::Static(3),
            DimSpec::Static(224),
            DimSpec::Static(224),
        ]
    );
}
/// A MatMul whose right-hand side is a KV tensor repeated via the
/// Unsqueeze -> Expand -> Reshape pattern should fuse into a
/// GroupedQueryAttentionMatMul that performs the head repetition internally.
#[test]
fn test_fuse_grouped_query_attention_matmul() {
    let batch_dim = Dimension::Symbolic("batch".to_string());
    let seq_dim = Dimension::Symbolic("seq".to_string());
    let kv_heads = 8;
    let n_repeats = 3;
    let query_heads = kv_heads * n_repeats;
    let d_model = 64;
    // Queries: (batch, query_heads, seq, d_model).
    let q_shape = [
        batch_dim.clone(),
        Dimension::Fixed(query_heads),
        seq_dim.clone(),
        Dimension::Fixed(d_model),
    ];
    let q = Expr::value_with_info("q", ValueType::Tensor(DataType::Float), &q_shape);
    // Keys/values: (batch, kv_heads, seq, d_model).
    let kv_shape = [
        batch_dim.clone(),
        Dimension::Fixed(kv_heads),
        seq_dim.clone(),
        Dimension::Fixed(d_model),
    ];
    let kv = Expr::value_with_info("kv", ValueType::Tensor(DataType::Float), &kv_shape);
    // Repeat KV heads along axis 1 via Unsqueeze/Expand/Reshape.
    let repeat_axis = 1;
    let unsqueeze_axes = Expr::constant(Value::from(NdTensor::from([repeat_axis as i32 + 1])));
    let expand_shape = Expr::value_with_info(
        "expand_shape",
        ValueType::Tensor(DataType::Float),
        &[
            batch_dim.clone(),
            Dimension::Fixed(kv_heads),
            Dimension::Fixed(n_repeats),
            seq_dim.clone(),
            Dimension::Fixed(d_model),
        ],
    );
    let out_shape = [
        batch_dim.clone(),
        Dimension::Fixed(query_heads),
        seq_dim.clone(),
        Dimension::Fixed(d_model),
    ];
    let reshape_shape = Expr::value_with_info(
        "reshape_shape",
        ValueType::Tensor(DataType::Float),
        &out_shape,
    );
    let t1 = kv.binary(Unsqueeze {}, unsqueeze_axes);
    let t2 = t1.binary(Expand {}, expand_shape);
    let output_meta = OutputMeta::Meta((DataType::Float, out_shape.to_vec()));
    let kv_repeated = t2.apply(
        Reshape { allow_zero: false },
        &[reshape_shape],
        &[output_meta],
    );
    let expr = q.matmul(kv_repeated);
    let graph = expr.build_graph(["q", "kv", "expand_shape", "reshape_shape"]);
    let optimized = optimize_graph(graph).unwrap();
    let (_, op) = optimized
        .get_source_node(optimized.output_ids()[0])
        .unwrap();
    let qkv_matmul = op
        .operator()
        .downcast_ref::<GroupedQueryAttentionMatMul>()
        .unwrap();
    assert_eq!(qkv_matmul.repeats, n_repeats);
    // No scale and no RHS transpose in this variant of the pattern.
    assert_eq!(qkv_matmul.alpha, None);
    assert!(!qkv_matmul.transpose_rhs);
}
/// The GQA MatMul fusion should also absorb a transpose of the repeated keys
/// (the `Q @ K^T` attention-score step) and an output scale, setting
/// `transpose_rhs` and `alpha` on the fused op.
#[test]
fn test_fuse_grouped_query_attention_matmul_with_transpose_and_scale() {
    let batch_dim = Dimension::Symbolic("batch".to_string());
    let seq_dim = Dimension::Symbolic("seq".to_string());
    let kv_heads = 8;
    let n_repeats = 3;
    let query_heads = kv_heads * n_repeats;
    let d_model = 64;
    // Standard attention scaling factor sqrt(d_model).
    let scale = (d_model as f32).sqrt();
    let q_shape = [
        batch_dim.clone(),
        Dimension::Fixed(query_heads),
        seq_dim.clone(),
        Dimension::Fixed(d_model),
    ];
    let q = Expr::value_with_info("q", ValueType::Tensor(DataType::Float), &q_shape);
    let k_shape = [
        batch_dim.clone(),
        Dimension::Fixed(kv_heads),
        seq_dim.clone(),
        Dimension::Fixed(d_model),
    ];
    let k = Expr::value_with_info("k", ValueType::Tensor(DataType::Float), &k_shape);
    // Repeat K heads along axis 1 via Unsqueeze/Expand/Reshape.
    let repeat_axis = 1;
    let unsqueeze_axes = Expr::constant(Value::from(NdTensor::from([repeat_axis as i32 + 1])));
    let expand_shape = Expr::value_with_info(
        "expand_shape",
        ValueType::Tensor(DataType::Float),
        &[
            batch_dim.clone(),
            Dimension::Fixed(kv_heads),
            Dimension::Fixed(n_repeats),
            seq_dim.clone(),
            Dimension::Fixed(d_model),
        ],
    );
    let repeated_shape = [
        batch_dim.clone(),
        Dimension::Fixed(query_heads),
        seq_dim.clone(),
        Dimension::Fixed(d_model),
    ];
    let reshape_shape = Expr::value_with_info(
        "reshape_shape",
        ValueType::Tensor(DataType::Float),
        &repeated_shape,
    );
    let t1 = k.binary(Unsqueeze {}, unsqueeze_axes);
    let t2 = t1.binary(Expand {}, expand_shape);
    let output_meta = OutputMeta::Meta((DataType::Float, repeated_shape.to_vec()));
    let k_repeated = t2.apply(
        Reshape { allow_zero: false },
        &[reshape_shape],
        &[output_meta],
    );
    // Transpose the last two dims so the MatMul computes Q @ K^T.
    let k_repeated_transposed = k_repeated.permute(&[0, 1, 3, 2]);
    let expr = q.matmul(k_repeated_transposed) / scale;
    let graph = expr.build_graph(["q", "k", "expand_shape", "reshape_shape"]);
    let optimized = optimize_graph(graph).unwrap();
    let (_, op) = optimized
        .get_source_node(optimized.output_ids()[0])
        .unwrap();
    let qkv_matmul = op
        .operator()
        .downcast_ref::<GroupedQueryAttentionMatMul>()
        .unwrap();
    assert_eq!(qkv_matmul.repeats, n_repeats);
    // The division by `scale` becomes a multiplication by its reciprocal.
    assert_eq!(qkv_matmul.alpha, Some(1.0 / scale));
    assert!(qkv_matmul.transpose_rhs);
}
/// With `infer_shapes` enabled, the optimizer should propagate shape and
/// dtype metadata through ops (here MatMul) onto the graph's output node.
#[test]
fn test_infer_shapes() {
    let graph = {
        let x = Expr::value_with_info(
            "data",
            ValueType::Tensor(DataType::Float),
            &dims!("batch", 64),
        );
        let w = Expr::constant(NdTensor::<f32, _>::zeros([64, 12]));
        // Output shape metadata is deliberately left empty so it must be
        // inferred.
        let out = x.apply(MatMul {}, &[w], &[OutputMeta::NoMeta]);
        out.build_graph(&["data"])
    };
    let optimizer = GraphOptimizer::new();
    let graph = optimizer
        .optimize(graph, None, OptimizeOptions { infer_shapes: true })
        .unwrap();
    let output = graph.get_node(graph.output_ids()[0]).unwrap();
    // ("batch", 64) x (64, 12) => ("batch", 12), float.
    assert_eq!(
        output.shape().as_deref(),
        Some(dims!("batch", 12).as_slice())
    );
    assert_eq!(output.dtype(), Some(ValueType::Tensor(DataType::Float)));
}
/// With `infer_shapes` enabled, values that shape inference proves to be
/// static (gathering the fixed dim from a Shape output) should be replaced by
/// constants.
#[test]
fn test_shape_inference_replaces_values_with_constants() {
    let graph = {
        let x = Expr::value_with_info(
            "data",
            ValueType::Tensor(DataType::Float),
            &dims!("batch", 64),
        );
        // Gather element 1 of the shape, i.e. the fixed dim 64.
        let indices = Expr::constant(1);
        let out = x
            .shape()
            .apply(Gather { axis: 0 }, &[indices], &[OutputMeta::NoMeta]);
        out.build_graph(&["data"])
    };
    let optimizer = GraphOptimizer::new();
    let graph = optimizer
        .optimize(graph, None, OptimizeOptions { infer_shapes: true })
        .unwrap();
    let output = graph.get_node(graph.output_ids()[0]).unwrap();
    assert_eq!(output.as_constant().and_then(|c| c.as_scalar()), Some(64));
}