1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
use crate::internal::*;
use std::collections::HashSet;
use std::fmt::Debug;
use tract_itertools::Itertools;

pub mod change_axes;
mod op_optim;
mod prop_const;
mod push_split_down;
mod slice;

use self::change_axes::ChangeAxes;
use self::prop_const::PropConst;
use self::push_split_down::PushSplitDown;
use self::slice::PushSliceUp;
use op_optim::OpOptim;

/// A transformation pass that an [`OptimizerSession`] runs over a [`TypedModel`].
///
/// A pass does not mutate the model itself: it proposes patches one at a time
/// through [`TypedPass::next`], and the session is the one applying them.
pub trait TypedPass: Debug + Send + Sync + dyn_clone::DynClone {
    /// Called by the session at the start of every run of this pass, before
    /// the first call to [`TypedPass::next`] of that run.
    fn reset(&mut self) -> TractResult<()>;
    /// Proposes the next patch for `model`.
    ///
    /// The session keeps calling this method, handling each returned patch,
    /// until it yields `Ok(None)`, which ends the current run of the pass.
    fn next(
        &mut self,
        session: &mut OptimizerSession,
        model: &TypedModel,
    ) -> TractResult<Option<TypedModelPatch>>;
}

// Makes `Box<dyn TypedPass>` implement `Clone`: the session clones the
// optimizer's pass list before running it.
dyn_clone::clone_trait_object!(TypedPass);

/// An ordered list of passes to run over a typed model, with an optional cap
/// on the number of patches a session may apply.
#[derive(Debug)]
pub struct Optimizer {
    /// The passes, in the order a session runs them.
    pub passes: Vec<Box<dyn TypedPass>>,
    /// When `Some(n)`, a session stops applying patches once its patch counter
    /// reaches `n`. `None` means no limit.
    pub steps: Option<usize>,
}

impl Optimizer {
    /// Builds an optimizer running `passes` in order, without any step limit.
    fn passes(passes: Vec<Box<dyn TypedPass>>) -> Optimizer {
        Optimizer { passes, steps: None }
    }

    /// Inserts `pass` at position `idx` in the pass list.
    ///
    /// `idx` may be anything from `0` to the current number of passes
    /// (inclusive). A larger index is not an error: a warning is logged and
    /// the pass is appended at the end instead.
    pub fn add_pass(&mut self, idx: usize, pass: Box<dyn TypedPass>) {
        let num_pass = self.passes.len();
        if idx > num_pass {
            // `Vec::insert` would panic on an out-of-range index, so fall back to `push`.
            log::warn!("Cannot add new pass {:?} at index {}. Optimizer currently has {} passes, pass will be added as the last pass.", pass, idx, num_pass);
            self.passes.push(pass);
        } else {
            self.passes.insert(idx, pass);
        }
    }

    /// Consumes the optimizer and returns it with the step limit set to `steps`.
    ///
    /// Sessions created from the returned optimizer stop applying patches once
    /// `steps` patches have been applied.
    pub fn stopping_at(self, steps: usize) -> Optimizer {
        Optimizer { steps: Some(steps), ..self }
    }

    /// Builds the "declutter" optimizer.
    ///
    /// Passes, in order: `PropConst`, the `"declutter"` `OpOptim` (calling
    /// `TypedOp::declutter_with_session`), `PushSliceUp`, `PushSplitDown` and a
    /// default `ChangeAxes`.
    pub fn declutter() -> Optimizer {
        Optimizer::passes(vec![
            Box::new(PropConst),
            Box::new(OpOptim("declutter", TypedOp::declutter_with_session, 0)),
            Box::new(PushSliceUp),
            Box::new(PushSplitDown),
            Box::<ChangeAxes>::default(),
        ])
    }

    /// Builds the "codegen" optimizer.
    ///
    /// Passes, in order: `PropConst`, the `"codegen"` `OpOptim` (calling
    /// `TypedOp::codegen`), the `"declutter"` `OpOptim`, `PushSplitDown` and
    /// the `"fuse"` `OpOptim` (calling `TypedOp::fuse`).
    pub fn codegen() -> Optimizer {
        Optimizer::passes(vec![
            Box::new(PropConst),
            Box::new(OpOptim(
                "codegen",
                // `codegen` does not take the session: adapt it to the `OpOptim` callback shape.
                |op, _session, model, node| TypedOp::codegen(op, model, node),
                0,
            )),
            Box::new(OpOptim("declutter", TypedOp::declutter_with_session, 0)),
            Box::new(PushSplitDown),
            Box::new(OpOptim(
                "fuse",
                // Same adaptation as for `codegen`: the session argument is dropped.
                |op, _session, model, node| TypedOp::fuse(op, model, node),
                0,
            )),
        ])
    }

    /// Optimizes `model` in place using a fresh session.
    ///
    /// # Errors
    ///
    /// Propagates any error reported by [`OptimizerSession::optimize`].
    pub fn optimize(&self, model: &mut TypedModel) -> TractResult<()> {
        self.session().optimize(model)
    }

    /// Creates a new session borrowing this optimizer, with a patch counter at
    /// zero and an empty set of already-seen patch markers.
    pub fn session(&self) -> OptimizerSession {
        OptimizerSession { optimizer: self, counter: 0, seen: Default::default() }
    }
}

/// The state of one optimization run driven by an [`Optimizer`].
#[derive(Debug)]
pub struct OptimizerSession<'o> {
    /// The optimizer providing the pass list and the optional step limit.
    optimizer: &'o Optimizer,
    /// Number of patches applied so far in this session.
    counter: usize,
    /// `dont_apply_twice` markers of the patches already applied; a patch
    /// carrying a marker found here is skipped instead of being applied again.
    seen: HashSet<String>,
}

impl<'o> OptimizerSession<'o> {
    /// Optimizes `model` in place until a full round of passes applies no patch.
    ///
    /// The model is first checked for consistency, compacted and has its
    /// names checked. Rounds of all passes are then run, with a compaction
    /// between rounds, until the patch counter stops moving.
    ///
    /// # Errors
    ///
    /// Returns the first error raised by a preflight check, a pass, a
    /// compaction or a consistency check.
    pub fn optimize(&mut self, model: &mut TypedModel) -> TractResult<()> {
        model.check_consistency().context("during optimizer preflight check")?;
        model.compact().context("during optimizer preflight compaction")?;
        model.check_names().context("after optimizer preflight compaction")?;
        let mut round: usize = 0;
        loop {
            let applied_before = self.counter;
            self.run_all_passes(round, model)?;
            // A round that applied nothing means the model has stabilized.
            if self.counter == applied_before {
                break Ok(());
            }
            model.compact()?;
            round += 1;
        }
    }

    /// Runs every pass of the optimizer once, in order, on `model`.
    ///
    /// The pass list is cloned first, so the passes are run on private copies.
    /// After each pass the model is compacted and checked for consistency.
    /// `i` is the round number; it only ends up in the context attached to
    /// the patches.
    pub fn run_all_passes(&mut self, i: usize, model: &mut TypedModel) -> TractResult<()> {
        let mut pipeline = self.optimizer.passes.clone();
        for p in &mut pipeline {
            self.run_one_pass_outer(i, p.as_mut(), model)
                .with_context(|| format!("running pass {p:?}"))?;
            model.compact()?;
            model
                .check_consistency()
                .with_context(|| format!("consistency check after pass {p:?}"))?;
        }
        Ok(())
    }

    /// Runs pass `p` over and over on `model` until one run applies no patch.
    ///
    /// The model is compacted after every run that did apply something.
    pub fn run_one_pass_outer(
        &mut self,
        i: usize,
        p: &mut dyn TypedPass,
        model: &mut TypedModel,
    ) -> TractResult<()> {
        loop {
            let applied_before = self.counter;
            self.run_one_pass_inner(i, p, model)?;
            let progressed = self.counter != applied_before;
            if !progressed {
                break Ok(());
            }
            model.compact().with_context(|| format!("after pass {p:?}"))?;
        }
    }

    /// Performs a single run of pass `p` on `model`.
    ///
    /// The pass is reset, then each patch it proposes is checked and applied,
    /// bumping the patch counter. A patch whose `dont_apply_twice` marker has
    /// already been recorded is skipped. The run returns early, skipping the
    /// final consistency check, as soon as the optimizer's step limit (if
    /// any) is reached.
    pub fn run_one_pass_inner(
        &mut self,
        i: usize,
        p: &mut dyn TypedPass,
        model: &mut TypedModel,
    ) -> TractResult<()> {
        p.reset()?;
        // The optimizer is borrowed immutably, so the limit is fixed for the whole run.
        let budget = self.optimizer.steps;
        let exhausted = move |applied: usize| matches!(budget, Some(cap) if applied >= cap);
        if exhausted(self.counter) {
            return Ok(());
        }
        while let Some(mut patch) = p.next(self, model)? {
            patch.push_context(format!("{p:?}/{i}"));
            patch.model.check_consistency().context("checking patch internal consistency")?;
            model
                .check_consistency()
                .context("Checking target model consistency before patching")?;
            if let Some(watchdog) = patch.dont_apply_twice.take() {
                if self.seen.contains(&watchdog) {
                    debug!("Loop detected: {} seen before", watchdog);
                    continue;
                }
                // First time this marker shows up: remember it and apply the patch.
                self.seen.insert(watchdog);
            }
            debug!("applying patch #{}: {}", self.counter, patch.context.iter().rev().join(" >> "),);
            patch.apply(model)?;
            model
                .check_consistency()
                .context("Checking target model consistency after patching")?;
            self.counter += 1;
            if exhausted(self.counter) {
                return Ok(());
            }
        }
        model.check_consistency().with_context(|| format!("after pass {p:?}"))?;
        Ok(())
    }
}