1use crate::internal::*;
2use std::collections::HashSet;
3use std::fmt::Debug;
4use tract_itertools::Itertools;
5
6pub mod change_axes;
7mod concat_then_einsum;
8mod op_optim;
9mod prop_const;
10mod push_split_down;
11mod slice;
12
13use self::change_axes::ChangeAxes;
14use self::prop_const::PropConst;
15use self::push_split_down::PushSplitDown;
16use self::slice::PushSliceUp;
17use op_optim::OpOptim;
18
/// An optimization pass that proposes patches for a `TypedModel`.
///
/// An `OptimizerSession` drives a pass by calling `reset` once, then calling `next`
/// repeatedly until it returns `Ok(None)` (or until the session's step limit is reached).
///
/// Passes must be `Debug` (their debug representation is used in log messages, patch
/// context and error context) and clonable as trait objects through `dyn_clone`.
pub trait TypedPass: Debug + Send + Sync + dyn_clone::DynClone {
    /// Called by the session before it starts pulling patches from this pass.
    fn reset(&mut self) -> TractResult<()>;
    /// Produces the next patch to apply to `model`, or `Ok(None)` when the pass has nothing
    /// more to propose.
    fn next(
        &mut self,
        session: &mut OptimizerSession,
        model: &TypedModel,
    ) -> TractResult<Option<TypedModelPatch>>;
}

// Makes `Box<dyn TypedPass>` implement `Clone`, so a list of passes can be cloned.
dyn_clone::clone_trait_object!(TypedPass);
29
/// An ordered list of passes, with an optional cap on the number of patches to apply.
#[derive(Debug)]
pub struct Optimizer {
    /// Passes to run, in order.
    pub passes: Vec<Box<dyn TypedPass>>,
    /// When set, a session stops applying patches once this many have been applied.
    pub steps: Option<usize>,
}
35
36impl Optimizer {
37 fn passes(passes: Vec<Box<dyn TypedPass>>) -> Optimizer {
38 Optimizer { passes, steps: None }
39 }
40
41 pub fn add_pass(&mut self, idx: usize, pass: Box<dyn TypedPass>) {
42 let num_pass = self.passes.len();
43 if idx > num_pass {
44 log::warn!("Cannot add new pass {pass:?} at index {idx}. Optimizer currently as {num_pass} passes, pass will be added as the last pass.");
45 self.passes.push(pass);
46 } else {
47 self.passes.insert(idx, pass);
48 }
49 }
50
51 pub fn stopping_at(self, steps: usize) -> Optimizer {
52 Optimizer { steps: Some(steps), ..self }
53 }
54
55 pub fn prop_consts() -> Optimizer {
56 Optimizer::passes(vec![Box::<PropConst>::default()])
57 }
58
59 pub fn declutter() -> Optimizer {
60 Optimizer::passes(vec![
61 Box::<PropConst>::default(),
62 Box::new(OpOptim("declutter", TypedOp::declutter_with_session, 0)),
63 Box::new(PushSliceUp),
64 Box::new(PushSplitDown),
65 Box::<concat_then_einsum::ConcatThenEinsum>::default(),
66 Box::<ChangeAxes>::default(),
67 ])
68 }
69
70 pub fn codegen() -> Optimizer {
71 Optimizer::passes(vec![
72 Box::<PropConst>::default(),
73 Box::new(OpOptim(
74 "codegen",
75 |op, _session, model, node| TypedOp::codegen(op, model, node),
76 0,
77 )),
78 Box::new(OpOptim("declutter", TypedOp::declutter_with_session, 0)),
79 Box::new(PushSplitDown),
80 Box::new(OpOptim(
81 "fuse",
82 |op, _session, model, node| TypedOp::fuse(op, model, node),
83 0,
84 )),
85 ])
86 }
87
88 pub fn optimize(&self, model: &mut TypedModel) -> TractResult<()> {
89 self.session().optimize(model)
90 }
91
92 pub fn session(&self) -> OptimizerSession<'_> {
93 OptimizerSession { optimizer: self, counter: 0, seen: Default::default() }
94 }
95}
96
/// State of one optimization run driven by an `Optimizer`.
#[derive(Debug)]
pub struct OptimizerSession<'o> {
    /// The optimizer providing the pass list and the optional step limit.
    optimizer: &'o Optimizer,
    /// Number of patches applied so far in this session.
    counter: usize,
    /// `dont_apply_twice` keys of patches already applied; a patch whose key is already in
    /// this set is skipped.
    seen: HashSet<String>,
}
103
104impl OptimizerSession<'_> {
105 pub fn optimize(&mut self, model: &mut TypedModel) -> TractResult<()> {
106 model.check_consistency().context("during optimizer preflight check")?;
107 model.compact().context("during optimizer preflight compaction")?;
108 model.check_names().context("after optimizer preflight compaction")?;
109 for i in 0.. {
110 let old = self.counter;
111 self.run_all_passes(i, model)?;
112 if old == self.counter {
113 return Ok(());
114 }
115 model.compact()?;
116 }
117 unreachable!()
118 }
119
120 pub fn run_all_passes(&mut self, i: usize, model: &mut TypedModel) -> TractResult<()> {
121 let mut passes = self.optimizer.passes.clone();
122 for p in passes.iter_mut() {
123 self.run_one_pass_outer(i, p.as_mut(), model)
124 .with_context(|| format!("running pass {p:?}"))?;
125 model.compact()?;
126 model
127 .check_consistency()
128 .with_context(|| format!("consistency check after pass {p:?}"))?;
129 }
130 Ok(())
131 }
132
133 pub fn run_one_pass_outer(
134 &mut self,
135 i: usize,
136 p: &mut dyn TypedPass,
137 model: &mut TypedModel,
138 ) -> TractResult<()> {
139 loop {
140 let old_counter = self.counter;
141 self.run_one_pass_inner(i, p, model)?;
142 if self.counter == old_counter {
143 return Ok(());
144 }
145 model.compact().with_context(|| format!("after pass {p:?}"))?;
146 }
147 }
148
149 pub fn run_one_pass_inner(
150 &mut self,
151 i: usize,
152 p: &mut dyn TypedPass,
153 model: &mut TypedModel,
154 ) -> TractResult<()> {
155 p.reset()?;
156 if let Some(steps) = self.optimizer.steps {
157 if self.counter >= steps {
158 return Ok(());
159 }
160 }
161 while let Some(mut patch) = p.next(self, model)? {
162 patch.push_context(format!("{p:?}/{i}"));
163 patch.model.check_consistency().context("checking patch internal consistency")?;
164 model
165 .check_consistency()
166 .context("Checking target model consistency before patching")?;
167 if let Some(watchdog) = patch.dont_apply_twice.take() {
168 if self.seen.contains(&watchdog) {
169 debug!("Loop detected: {watchdog} seen before");
170 continue;
171 } else {
172 self.seen.insert(watchdog);
173 }
174 }
175 let patch_name = patch.context.iter().rev().join(" >> ");
176 debug!("applying patch #{}: {patch_name}", self.counter);
177 patch.apply(model).with_context(|| format!("Applying patch {patch_name}"))?;
178 model
179 .check_consistency()
180 .context("Checking target model consistency after patching")?;
181 self.counter += 1;
182 if let Some(steps) = self.optimizer.steps {
183 if self.counter >= steps {
184 return Ok(());
185 }
186 }
187 }
188 model.check_consistency().with_context(|| format!("after pass {p:?}"))?;
189 Ok(())
190 }
191}