tract_core/optim/concat_then_einsum.rs

1use crate::internal::*;
2
3use crate::ops::array::{Slice, TypedConcat};
4use crate::ops::einsum::EinSum;
5use crate::ops::math::add;
6use crate::optim::OptimizerSession;
7use tract_itertools::Itertools;
8
9#[derive(Clone, Debug, Default)]
10pub struct ConcatThenEinsum(Option<InletId>);
11
12impl super::TypedPass for ConcatThenEinsum {
13    fn reset(&mut self) -> TractResult<()> {
14        self.0 = Default::default();
15        Ok(())
16    }
17
18    #[allow(clippy::comparison_chain)]
19    fn next(
20        &mut self,
21        _session: &mut OptimizerSession,
22        model: &TypedModel,
23    ) -> TractResult<Option<TypedModelPatch>> {
24        'outer: loop {
25            self.0 = if let Some(previous) = self.0 {
26                if let Some(next) = next_inlet(model, &previous) {
27                    Some(next)
28                } else {
29                    return Ok(None);
30                }
31            } else if let Some(first) =
32                model.nodes.iter().find(|n| n.inputs.len() > 0).map(|n| InletId::new(n.id, 0))
33            {
34                Some(first)
35            } else {
36                return Ok(None);
37            };
38            let inlet = self.0.unwrap();
39            let outlet = model.nodes[inlet.node].inputs[inlet.slot];
40            let concat_node = model.node(outlet.node);
41            if model.outputs.contains(&concat_node.id.into()) {
42                continue;
43            }
44            let einsum_node = &model.nodes[inlet.node];
45            if einsum_node.inputs.len() != 2 {
46                // should we try and apply this on quantized einsums ?
47                continue;
48            }
49            if let (Some(concat), Some(einsum)) =
50                (concat_node.op_as::<TypedConcat>(), einsum_node.op_as::<EinSum>())
51            {
52                let offsets = concat.offsets(&model.node_input_facts(concat_node.id)?)?;
53                let axis_info = einsum.axes.axis((InOut::In(inlet.slot), concat.axis))?;
54                // only split if axis is a summing axis
55                if axis_info.outputs[0].len() > 0 {
56                    continue;
57                }
58                let mut patch = TypedModelPatch::new(format!(
59                    "Split Einsum for concat on axis {}",
60                    axis_info.repr
61                ));
62                // inputs[einsum_input_slot][concated_slice]. concated_slice = 0 for broadcast
63                let mut inputs: TVec<TVec<OutletId>> = tvec!();
64                for (slot, input) in einsum_node.inputs.iter().enumerate() {
65                    let tap = patch.tap_model(model, *input)?;
66                    if axis_info.inputs[slot].len() > 1 {
67                        continue 'outer;
68                    } else if axis_info.inputs[slot].len() == 1 {
69                        let mut slices = tvec!();
70                        for (start, end) in offsets.iter().cloned().tuple_windows() {
71                            let wire = patch.wire_node(
72                                format!(
73                                    "{}.concat-einsum-slice-{}.{}.{}..{}",
74                                    einsum_node.name, axis_info.repr, slot, start, end
75                                ),
76                                Slice { axis: axis_info.inputs[slot][0], start, end },
77                                &[tap],
78                            )?;
79                            slices.push(wire[0]);
80                        }
81                        inputs.push(slices);
82                    } else {
83                        inputs.push(tvec!(tap)); // broadcast
84                    };
85                }
86                let mut einsums = tvec!();
87                for (ix, (start, end)) in offsets.iter().tuple_windows().enumerate() {
88                    let mut einsum_inputs = tvec!();
89                    for input_ix in 0..einsum_node.inputs.len() {
90                        einsum_inputs
91                            .push(inputs[input_ix].get(ix).cloned().unwrap_or(inputs[input_ix][0]));
92                    }
93                    let einsum = patch.wire_node(
94                        format!(
95                            "{}.concat-einsum-{}.{}..{}",
96                            einsum_node.name, axis_info.repr, start, end
97                        ),
98                        einsum.clone(),
99                        &einsum_inputs,
100                    )?[0];
101                    einsums.push(einsum);
102                }
103                let wire = if let Some(axis) = axis_info.outputs[0].first().cloned() {
104                    patch.wire_node(
105                        format!("{}.concat-einsum-{}.concat", einsum_node.name, axis_info.repr),
106                        TypedConcat { axis },
107                        &einsums,
108                    )?[0]
109                } else {
110                    let mut wire = einsums[0];
111                    for ix in 1..einsums.len() {
112                        wire = patch.wire_node(
113                            format!(
114                                "{}.concat-einsum-{}.add-{}",
115                                einsum_node.name, axis_info.repr, ix
116                            ),
117                            add(),
118                            &[wire, einsums[ix]],
119                        )?[0]
120                    }
121                    wire
122                };
123                patch.shunt_outside(model, einsum_node.id.into(), wire)?;
124                return Ok(Some(patch));
125            }
126        }
127    }
128}
129
130fn next_inlet(model: &TypedModel, inlet: &InletId) -> Option<InletId> {
131    if inlet.slot + 1 < model.nodes[inlet.node].inputs.len() {
132        Some(InletId::new(inlet.node, inlet.slot + 1))
133    } else {
134        model.nodes[inlet.node + 1..]
135            .iter()
136            .find(|n| n.inputs.len() > 0)
137            .map(|n| InletId::new(n.id, 0))
138    }
139}