tract_core/optim/
push_split_down.rs

1use crate::internal::*;
2
3use crate::optim::OptimizerSession;
4use tract_itertools::Itertools;
5
/// Optimizer pass that merges equivalent sibling nodes.
///
/// Sibling nodes are distinct consumers of the same outlet. The pass has no configuration
/// or internal state.
#[derive(Debug, Clone)]
pub struct PushSplitDown;
8
9impl super::TypedPass for PushSplitDown {
10    fn reset(&mut self) -> TractResult<()> {
11        Ok(())
12    }
13    fn next(
14        &mut self,
15        _session: &mut OptimizerSession,
16        model: &TypedModel,
17    ) -> TractResult<Option<TypedModelPatch>> {
18        let mut patch = TypedModelPatch::default();
19        for node in model.eval_order()? {
20            for output in &model.node(node).outputs {
21                for (a, b) in output.successors.iter().tuple_combinations() {
22                    if a.node == b.node {
23                        // found where a square is implemented using a mul with duplicate input
24                        continue;
25                    }
26                    if patch.obliterate.contains(&b.node) {
27                        continue;
28                    }
29                    // dont merge outputs.
30                    if model.outputs.contains(&a.node.into())
31                        && model.outputs.contains(&b.node.into())
32                    {
33                        continue;
34                    }
35                    let a = model.node(a.node);
36                    let b = model.node(b.node);
37                    if a.same_as(b) {
38                        for slot in 0..b.outputs.len() {
39                            let tap = patch.tap_model(model, OutletId::new(a.id, slot))?;
40                            patch.shunt_outside(model, OutletId::new(b.id, slot), tap)?;
41                            patch.obliterate(b.id)?;
42                        }
43                    }
44                }
45            }
46        }
47        Ok(Some(patch).filter(|p| !p.is_empty()))
48    }
49}