// tract_core/optim/mod.rs

1use crate::internal::*;
2use std::collections::HashSet;
3use std::fmt::Debug;
4use tract_itertools::Itertools;
5
6pub mod change_axes;
7mod concat_then_einsum;
8mod op_optim;
9mod prop_const;
10mod push_split_down;
11mod slice;
12
13use self::change_axes::ChangeAxes;
14use self::prop_const::PropConst;
15use self::push_split_down::PushSplitDown;
16use self::slice::PushSliceUp;
17use op_optim::OpOptim;
18
19pub trait TypedPass: Debug + Send + Sync + dyn_clone::DynClone {
20    fn reset(&mut self) -> TractResult<()>;
21    fn next(
22        &mut self,
23        session: &mut OptimizerSession,
24        model: &TypedModel,
25    ) -> TractResult<Option<TypedModelPatch>>;
26}
27
28dyn_clone::clone_trait_object!(TypedPass);
29
30#[derive(Debug)]
31pub struct Optimizer {
32    pub passes: Vec<Box<dyn TypedPass>>,
33    pub steps: Option<usize>,
34}
35
36impl Optimizer {
37    fn passes(passes: Vec<Box<dyn TypedPass>>) -> Optimizer {
38        Optimizer { passes, steps: None }
39    }
40
41    pub fn add_pass(&mut self, idx: usize, pass: Box<dyn TypedPass>) {
42        let num_pass = self.passes.len();
43        if idx > num_pass {
44            log::warn!("Cannot add new pass {pass:?} at index {idx}. Optimizer currently as {num_pass} passes, pass will be added as the last pass.");
45            self.passes.push(pass);
46        } else {
47            self.passes.insert(idx, pass);
48        }
49    }
50
51    pub fn stopping_at(self, steps: usize) -> Optimizer {
52        Optimizer { steps: Some(steps), ..self }
53    }
54
55    pub fn prop_consts() -> Optimizer {
56        Optimizer::passes(vec![Box::<PropConst>::default()])
57    }
58
59    pub fn declutter() -> Optimizer {
60        Optimizer::passes(vec![
61            Box::<PropConst>::default(),
62            Box::new(OpOptim("declutter", TypedOp::declutter_with_session, 0)),
63            Box::new(PushSliceUp),
64            Box::new(PushSplitDown),
65            Box::<concat_then_einsum::ConcatThenEinsum>::default(),
66            Box::<ChangeAxes>::default(),
67        ])
68    }
69
70    pub fn codegen() -> Optimizer {
71        Optimizer::passes(vec![
72            Box::<PropConst>::default(),
73            Box::new(OpOptim(
74                "codegen",
75                |op, _session, model, node| TypedOp::codegen(op, model, node),
76                0,
77            )),
78            Box::new(OpOptim("declutter", TypedOp::declutter_with_session, 0)),
79            Box::new(PushSplitDown),
80            Box::new(OpOptim(
81                "fuse",
82                |op, _session, model, node| TypedOp::fuse(op, model, node),
83                0,
84            )),
85        ])
86    }
87
88    pub fn optimize(&self, model: &mut TypedModel) -> TractResult<()> {
89        self.session().optimize(model)
90    }
91
92    pub fn session(&self) -> OptimizerSession<'_> {
93        OptimizerSession { optimizer: self, counter: 0, seen: Default::default() }
94    }
95}
96
97#[derive(Debug)]
98pub struct OptimizerSession<'o> {
99    optimizer: &'o Optimizer,
100    counter: usize,
101    seen: HashSet<String>,
102}
103
104impl OptimizerSession<'_> {
105    pub fn optimize(&mut self, model: &mut TypedModel) -> TractResult<()> {
106        model.check_consistency().context("during optimizer preflight check")?;
107        model.compact().context("during optimizer preflight compaction")?;
108        model.check_names().context("after optimizer preflight compaction")?;
109        for i in 0.. {
110            let old = self.counter;
111            self.run_all_passes(i, model)?;
112            if old == self.counter {
113                return Ok(());
114            }
115            model.compact()?;
116        }
117        unreachable!()
118    }
119
120    pub fn run_all_passes(&mut self, i: usize, model: &mut TypedModel) -> TractResult<()> {
121        let mut passes = self.optimizer.passes.clone();
122        for p in passes.iter_mut() {
123            self.run_one_pass_outer(i, p.as_mut(), model)
124                .with_context(|| format!("running pass {p:?}"))?;
125            model.compact()?;
126            model
127                .check_consistency()
128                .with_context(|| format!("consistency check after pass {p:?}"))?;
129        }
130        Ok(())
131    }
132
133    pub fn run_one_pass_outer(
134        &mut self,
135        i: usize,
136        p: &mut dyn TypedPass,
137        model: &mut TypedModel,
138    ) -> TractResult<()> {
139        loop {
140            let old_counter = self.counter;
141            self.run_one_pass_inner(i, p, model)?;
142            if self.counter == old_counter {
143                return Ok(());
144            }
145            model.compact().with_context(|| format!("after pass {p:?}"))?;
146        }
147    }
148
149    pub fn run_one_pass_inner(
150        &mut self,
151        i: usize,
152        p: &mut dyn TypedPass,
153        model: &mut TypedModel,
154    ) -> TractResult<()> {
155        p.reset()?;
156        if let Some(steps) = self.optimizer.steps {
157            if self.counter >= steps {
158                return Ok(());
159            }
160        }
161        while let Some(mut patch) = p.next(self, model)? {
162            patch.push_context(format!("{p:?}/{i}"));
163            patch.model.check_consistency().context("checking patch internal consistency")?;
164            model
165                .check_consistency()
166                .context("Checking target model consistency before patching")?;
167            if let Some(watchdog) = patch.dont_apply_twice.take() {
168                if self.seen.contains(&watchdog) {
169                    debug!("Loop detected: {watchdog} seen before");
170                    continue;
171                } else {
172                    self.seen.insert(watchdog);
173                }
174            }
175            let patch_name = patch.context.iter().rev().join(" >> ");
176            debug!("applying patch #{}: {patch_name}", self.counter);
177            patch.apply(model).with_context(|| format!("Applying patch {patch_name}"))?;
178            model
179                .check_consistency()
180                .context("Checking target model consistency after patching")?;
181            self.counter += 1;
182            if let Some(steps) = self.optimizer.steps {
183                if self.counter >= steps {
184                    return Ok(());
185                }
186            }
187        }
188        model.check_consistency().with_context(|| format!("after pass {p:?}"))?;
189        Ok(())
190    }
191}