tract_core/optim/
concat_then_einsum.rs1use crate::internal::*;
2
3use crate::ops::array::{Slice, TypedConcat};
4use crate::ops::einsum::EinSum;
5use crate::ops::math::add;
6use crate::optim::OptimizerSession;
7use tract_itertools::Itertools;
8
9#[derive(Clone, Debug, Default)]
10pub struct ConcatThenEinsum(Option<InletId>);
11
12impl super::TypedPass for ConcatThenEinsum {
13 fn reset(&mut self) -> TractResult<()> {
14 self.0 = Default::default();
15 Ok(())
16 }
17
18 #[allow(clippy::comparison_chain)]
19 fn next(
20 &mut self,
21 _session: &mut OptimizerSession,
22 model: &TypedModel,
23 ) -> TractResult<Option<TypedModelPatch>> {
24 'outer: loop {
25 self.0 = if let Some(previous) = self.0 {
26 if let Some(next) = next_inlet(model, &previous) {
27 Some(next)
28 } else {
29 return Ok(None);
30 }
31 } else if let Some(first) =
32 model.nodes.iter().find(|n| n.inputs.len() > 0).map(|n| InletId::new(n.id, 0))
33 {
34 Some(first)
35 } else {
36 return Ok(None);
37 };
38 let inlet = self.0.unwrap();
39 let outlet = model.nodes[inlet.node].inputs[inlet.slot];
40 let concat_node = model.node(outlet.node);
41 if model.outputs.contains(&concat_node.id.into()) {
42 continue;
43 }
44 let einsum_node = &model.nodes[inlet.node];
45 if einsum_node.inputs.len() != 2 {
46 continue;
48 }
49 if let (Some(concat), Some(einsum)) =
50 (concat_node.op_as::<TypedConcat>(), einsum_node.op_as::<EinSum>())
51 {
52 let offsets = concat.offsets(&model.node_input_facts(concat_node.id)?)?;
53 let axis_info = einsum.axes.axis((InOut::In(inlet.slot), concat.axis))?;
54 if axis_info.outputs[0].len() > 0 {
56 continue;
57 }
58 let mut patch = TypedModelPatch::new(format!(
59 "Split Einsum for concat on axis {}",
60 axis_info.repr
61 ));
62 let mut inputs: TVec<TVec<OutletId>> = tvec!();
64 for (slot, input) in einsum_node.inputs.iter().enumerate() {
65 let tap = patch.tap_model(model, *input)?;
66 if axis_info.inputs[slot].len() > 1 {
67 continue 'outer;
68 } else if axis_info.inputs[slot].len() == 1 {
69 let mut slices = tvec!();
70 for (start, end) in offsets.iter().cloned().tuple_windows() {
71 let wire = patch.wire_node(
72 format!(
73 "{}.concat-einsum-slice-{}.{}.{}..{}",
74 einsum_node.name, axis_info.repr, slot, start, end
75 ),
76 Slice { axis: axis_info.inputs[slot][0], start, end },
77 &[tap],
78 )?;
79 slices.push(wire[0]);
80 }
81 inputs.push(slices);
82 } else {
83 inputs.push(tvec!(tap)); };
85 }
86 let mut einsums = tvec!();
87 for (ix, (start, end)) in offsets.iter().tuple_windows().enumerate() {
88 let mut einsum_inputs = tvec!();
89 for input_ix in 0..einsum_node.inputs.len() {
90 einsum_inputs
91 .push(inputs[input_ix].get(ix).cloned().unwrap_or(inputs[input_ix][0]));
92 }
93 let einsum = patch.wire_node(
94 format!(
95 "{}.concat-einsum-{}.{}..{}",
96 einsum_node.name, axis_info.repr, start, end
97 ),
98 einsum.clone(),
99 &einsum_inputs,
100 )?[0];
101 einsums.push(einsum);
102 }
103 let wire = if let Some(axis) = axis_info.outputs[0].first().cloned() {
104 patch.wire_node(
105 format!("{}.concat-einsum-{}.concat", einsum_node.name, axis_info.repr),
106 TypedConcat { axis },
107 &einsums,
108 )?[0]
109 } else {
110 let mut wire = einsums[0];
111 for ix in 1..einsums.len() {
112 wire = patch.wire_node(
113 format!(
114 "{}.concat-einsum-{}.add-{}",
115 einsum_node.name, axis_info.repr, ix
116 ),
117 add(),
118 &[wire, einsums[ix]],
119 )?[0]
120 }
121 wire
122 };
123 patch.shunt_outside(model, einsum_node.id.into(), wire)?;
124 return Ok(Some(patch));
125 }
126 }
127 }
128}
129
130fn next_inlet(model: &TypedModel, inlet: &InletId) -> Option<InletId> {
131 if inlet.slot + 1 < model.nodes[inlet.node].inputs.len() {
132 Some(InletId::new(inlet.node, inlet.slot + 1))
133 } else {
134 model.nodes[inlet.node + 1..]
135 .iter()
136 .find(|n| n.inputs.len() > 0)
137 .map(|n| InletId::new(n.id, 0))
138 }
139}