use crate::ops::MetalFusedAxisOp;
use tract_core::internal::*;
use tract_core::tract_data::itertools::Itertools;
use tract_gpu::fact::DeviceTypedFactExt;
use tract_gpu::ops::change_axes::GpuAxisOp;
use tract_gpu::rule_ensure;
/// Returns `true` for the axis-op kinds that `collect_chain_of_axis_ops`
/// accepts unconditionally: `Add`, `Rm` and `Reshape`.
///
/// `Move` is deliberately excluded; it is only accepted when `can_fuse_move`
/// also holds for the node carrying it.
fn is_supported_axis_op(op: &GpuAxisOp) -> bool {
    let inner = &op.inner;
    matches!(inner, AxisOp::Reshape(..) | AxisOp::Add(_) | AxisOp::Rm(_))
}
/// Returns `true` when `model.single_succ` yields a consumer of `axis_node`
/// whose op is one of the GPU ops this file treats as able to take a `Move`
/// axis op fused into it. The ops are concat, apply-rope, scaled masked
/// softmax, slice, multi-broadcast and dynamic KV cache.
///
/// Any other outcome yields `false`: no single successor, or an error from
/// the graph query. Declining to fuse is the conservative answer for a
/// rewrite predicate, and it avoids panicking in the middle of optimisation.
fn can_fuse_move(model: &TypedModel, axis_node: &TypedNode) -> bool {
    let Ok(Some(succ)) = model.single_succ(axis_node.id) else {
        return false;
    };
    succ.op_is::<tract_gpu::ops::concat::GpuConcat>()
        || succ.op_is::<tract_gpu::ops::apply_rope::GpuApplyRope>()
        || succ.op_is::<tract_gpu::ops::scaled_masked_softmax::GpuScaledMaskedSoftmax>()
        || succ.op_is::<tract_gpu::ops::slice::GpuSlice>()
        || succ.op_is::<tract_gpu::ops::broadcast::GpuMultiBroadcastTo>()
        || succ.op_is::<tract_gpu::ops::dyn_kv_cache::GpuDynKVCache>()
}
/// Walks upstream from `cursor` and collects the run of fusable GPU axis ops
/// ending at it.
///
/// A node is part of the run when it is a `GpuAxisOp` that is either
/// `Add`/`Rm`/`Reshape`, or a `Move` for which `can_fuse_move` holds. The walk
/// moves to `model.single_prec` of the current node and stops at the first
/// non-fusable node, or when there is no single predecessor.
///
/// Returns `None` when `cursor` itself is not fusable. Otherwise it returns
/// the collected ops in upstream-to-downstream order, together with the most
/// upstream node of the run (whose `inputs[0]` is the run's source).
///
/// # Errors
/// Propagates errors from `model.single_prec`.
pub fn collect_chain_of_axis_ops<'a>(
    model: &'a TypedModel,
    mut cursor: &'a TypedNode,
) -> TractResult<Option<(TVec<GpuAxisOp>, &'a TypedNode)>> {
    let mut chain: TVec<GpuAxisOp> = tvec![];
    // Set on every accepted node, so it is `Some` exactly when `chain` is non-empty.
    let mut head: Option<&'a TypedNode> = None;

    while let Some(op) = cursor.op_as::<GpuAxisOp>() {
        let fusable = is_supported_axis_op(op)
            || (matches!(op.inner, AxisOp::Move(..)) && can_fuse_move(model, cursor));
        if !fusable {
            break;
        }
        chain.push(op.clone());
        head = Some(cursor);
        match model.single_prec(cursor.id)? {
            Some(prev) => cursor = prev,
            None => break,
        }
    }

    let Some(head) = head else {
        return Ok(None);
    };
    // Ops were pushed while walking upstream; report them in data-flow order.
    chain.reverse();
    Ok(Some((chain, head)))
}
fn split_succs(
model: &TypedModel,
axis_node: &TypedNode,
axis_node_name: &str,
axis_op: &GpuAxisOp,
) -> TractResult<Option<TypedModelPatch>> {
let succs = model.all_succ(axis_node.id)?.context("Expected node with successors")?;
let mut patch = TypedModelPatch::default();
let input = patch.tap_model(model, axis_node.inputs[0])?;
for (i, succ) in succs.iter().enumerate() {
let axis_out =
patch.wire_node(format!("{axis_node_name}.{i}"), axis_op.clone(), &[input])?[0];
let mut op_ins = patch.taps(model, &succ.inputs)?;
let (idx, _) = succ
.inputs
.iter()
.enumerate()
.find(|(_, inlet)| inlet.node == axis_node.id)
.context("Axis node not found in its successor inputs")?;
op_ins[idx] = axis_out;
let op_outs = patch.wire_node(succ.name.clone(), succ.op.clone(), &op_ins)?;
for out in op_outs {
patch.shunt_outside(model, succ.id.into(), out)?;
}
}
Ok(Some(patch))
}
/// Rewrite rule: folds chains of GPU axis ops that feed the consumer of
/// `axis_node` into a single `MetalFusedAxisOp` wrapping that consumer.
///
/// - `axis_op` must be `Add`/`Rm`/`Reshape`/`Move`.
/// - If `model.single_succ` finds no single consumer, the axis node is split
///   per consumer instead (see `split_succs`).
/// - The consumer must not be a `GpuAxisOp` or `MetalFusedAxisOp`, except for
///   a `GpuAxisOp` `Move`.
/// - For each consumer input, the chain found by `collect_chain_of_axis_ops`
///   is recorded in `grouped_axis_ops`. The fused op then reads from the
///   chain's head input instead; inputs without a chain are tapped unchanged.
/// - A `Move` consumer is only wrapped when its input has a chain and
///   `can_fuse_move` is false for it. Otherwise the `Move` is left in place,
///   presumably to be absorbed into its own consumer when the rule fires on it.
///
/// Returns `Ok(None)` when no input carries a chain, since the wrapper would
/// then be a pure pass-through.
///
/// # Errors
/// Propagates graph-query and patch-construction errors.
pub fn fuse_axis_op(
    _ctx: &(),
    model: &TypedModel,
    axis_node: &TypedNode,
    axis_node_name: &str,
    axis_op: &GpuAxisOp,
) -> TractResult<Option<TypedModelPatch>> {
    rule_ensure!(is_supported_axis_op(axis_op) || matches!(axis_op.inner, AxisOp::Move(..)));
    let Some(node) = model.single_succ(axis_node.id)? else {
        return split_succs(model, axis_node, axis_node_name, axis_op);
    };
    let succ_move = node.op_as::<GpuAxisOp>().filter(|op| matches!(op.inner, AxisOp::Move(..)));
    let is_axis_like = node.op_is::<GpuAxisOp>() || node.op_is::<MetalFusedAxisOp>();
    rule_ensure!(!is_axis_like || succ_move.is_some());

    let node_name = &node.name;
    let Some(in_nodes) = model.all_prec(node.id)? else {
        return Ok(None);
    };

    let mut patch = TypedModelPatch::default();
    let mut grouped_axis_ops: TVec<TVec<GpuAxisOp>> = tvec![];
    let mut tap_inputs = tvec![];
    for (in_idx, in_node) in in_nodes.into_iter().enumerate() {
        let (axis_ops, source) = match collect_chain_of_axis_ops(model, in_node)? {
            Some((axis_ops, head_of_chain)) => (axis_ops, head_of_chain.inputs[0]),
            None => (tvec![], node.inputs[in_idx]),
        };
        grouped_axis_ops.push(axis_ops);
        tap_inputs.push(patch.tap_model(model, source)?);
    }

    // Nothing to fuse on any input: wrapping the consumer would be a no-op.
    rule_ensure!(grouped_axis_ops.iter().any(|ops| !ops.is_empty()));
    if succ_move.is_some() {
        rule_ensure!(!grouped_axis_ops[0].is_empty() && !can_fuse_move(model, node));
    }

    let out = patch.wire_node(
        format!("{node_name}.fused_axis_op"),
        MetalFusedAxisOp { grouped_axis_ops, op: node.op.clone() },
        &tap_inputs,
    )?;
    patch.shunt_outside(model, node.id.into(), out[0])?;
    Ok(Some(patch))
}
/// Rewrite rule simplifying a GPU `Move(from, to)` axis op.
///
/// The rewrites are tried in order, and the first applicable one wins:
/// 1. Shape and natural strides are unchanged by the move (concrete shapes
///    only, e.g. when only size-1 axes are moved): the op is removed.
/// 2. `GpuAxisOp::simplify_axis_op` yields a different op for the input
///    shape: the node is replaced by it.
/// 3. `model.single_succ` yields another GPU `Move`, and the two compose (via
///    `perm_to_ops`) into exactly one axis op: both nodes are replaced by it.
/// 4. `model.single_prec` yields a GPU `Add(from)`: the `Add` + `Move` pair is
///    replaced by `Add(to)`.
///
/// Rewrite 4 only depends on the producer side, so it is attempted even when
/// the move has zero or several consumers.
///
/// Returns `Ok(None)` for non-`Move` ops or when no rewrite applies.
///
/// # Errors
/// Propagates graph-query, shape-change and patch-construction errors.
pub fn fuse_move_axis(
    _ctx: &(),
    model: &TypedModel,
    axis_node: &TypedNode,
    axis_node_name: &str,
    axis_op: &GpuAxisOp,
) -> TractResult<Option<TypedModelPatch>> {
    let AxisOp::Move(from, to) = axis_op.inner else {
        return Ok(None);
    };

    let in_fact = model.node_input_facts(axis_node.id)?[0];
    let in_shape = in_fact
        .as_device_fact()
        .map(|mf| mf.shape.clone())
        .unwrap_or_else(|| in_fact.shape.clone());
    let out_fact = model.node_output_facts(axis_node.id)?[0];
    let out_shape = out_fact
        .as_device_fact()
        .map(|mf| mf.shape.clone())
        .unwrap_or_else(|| out_fact.shape.clone());

    // 1. Layout-preserving move: same shape and same natural strides.
    if in_shape == out_shape {
        if let Some(in_strides) = in_shape.as_concrete().map(Tensor::natural_strides) {
            let mut out_strides = in_strides.clone();
            let moved_stride = out_strides.remove(from);
            out_strides.insert(to, moved_stride);
            if in_strides == out_strides {
                return TypedModelPatch::shunt_one_op(model, axis_node);
            }
        }
    }

    // 2. Shape-aware simplification of the move itself.
    let simpl_op = GpuAxisOp::simplify_axis_op(axis_op.inner.clone(), in_shape.dims());
    if simpl_op != *axis_op {
        return Ok(Some(TypedModelPatch::replace_single_op(
            model,
            axis_node,
            &[axis_node.inputs[0]],
            simpl_op,
        )?));
    }

    // 3. Move followed by Move: compose both permutations.
    if let Some(succ) = model.single_succ(axis_node.id)? {
        if let Some(&AxisOp::Move(from_2, to_2)) = succ.op_as::<GpuAxisOp>().map(|op| &op.inner)
        {
            // Indices above the largest touched axis are left in place by both
            // moves, so composing on this prefix is enough.
            let max_rank = from.max(to).max(from_2).max(to_2) + 1;
            let mut perm: TVec<usize> = (0..max_rank).collect();
            AxisOp::Move(from, to).change_shape_array(&mut perm, false)?;
            AxisOp::Move(from_2, to_2).change_shape_array(&mut perm, false)?;
            let new_axis_ops = perm_to_ops(&perm);
            if new_axis_ops.len() == 1 {
                let mut patch = TypedModelPatch::default();
                let inputs = patch.taps(model, &axis_node.inputs)?;
                let out = patch.wire_node(
                    format!("{axis_node_name}.fused_move_axis"),
                    GpuAxisOp::new(new_axis_ops[0].clone()),
                    &inputs,
                )?;
                patch.shunt_outside(model, succ.id.into(), out[0])?;
                return Ok(Some(patch));
            }
        }
    }

    // 4. Add(from) followed by Move(from, to) is just Add(to).
    if let Some(prev) = model.single_prec(axis_node.id)? {
        if let Some(&AxisOp::Add(ax)) = prev.op_as::<GpuAxisOp>().map(|op| &op.inner) {
            if ax == from {
                let mut patch = TypedModelPatch::default();
                let inputs = patch.taps(model, &prev.inputs)?;
                let out =
                    patch.wire_node(prev.name.clone(), GpuAxisOp::new(AxisOp::Add(to)), &inputs)?;
                patch.shunt_outside(model, axis_node.id.into(), out[0])?;
                return Ok(Some(patch));
            }
        }
    }

    Ok(None)
}