1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
use crate::model::TypedModel;
use crate::TractResult;
use std::fmt::Debug;

mod prop_const;
mod push_split_down;

use self::prop_const::PropConst;
use self::push_split_down::PushSplitDown;

/// A model-wide transformation run during the declutter phase.
///
/// Implementors must be `Debug` (their `Debug` output is used in error and
/// log messages by the passes in this module), `Send` and `Sync`.
pub trait DeclutterPass
where
    Self: Debug + Send + Sync,
{
    /// Applies the pass to `model` in place.
    ///
    /// Returns `Ok(true)` when the model was modified, `Ok(false)` when it was
    /// left as-is (this is the convention followed by `NormalizeOps`).
    fn pass(&self, model: &mut TypedModel) -> TractResult<bool>;
}

/// A model-wide transformation run during the codegen phase.
///
/// Implementors must be `Debug` (their `Debug` output is used in error and
/// log messages by the passes in this module), `Send` and `Sync`.
pub trait CodegenPass
where
    Self: Debug + Send + Sync,
{
    /// Applies the pass to `model` in place.
    ///
    /// Returns `Ok(true)` when the model was modified, `Ok(false)` when it was
    /// left as-is (this is the convention followed by `CodegenOps`).
    fn pass(&self, model: &mut TypedModel) -> TractResult<bool>;
}

pub fn declutter() -> Vec<Box<DeclutterPass>> {
    vec![Box::new(PropConst) as _, Box::new(NormalizeOps)]
}

pub fn codegen() -> Vec<Box<CodegenPass>> {
    vec![Box::new(CodegenOps), Box::new(PushSplitDown)]
}

/// Declutter pass that asks the op of every node to simplify itself.
///
/// Each sweep walks the nodes in `model.eval_order()` order, calls the op's
/// `declutter` and applies any returned patch immediately. Sweeps repeat
/// until one of them applies no patch at all.
#[derive(Debug)]
pub struct NormalizeOps;

impl DeclutterPass for NormalizeOps {
    /// Runs sweeps until a fixpoint is reached.
    ///
    /// Returns `Ok(true)` if at least one patch was applied. The first error
    /// from `eval_order`, a node's `declutter`, a patch application or (in
    /// builds with debug assertions) `check_edges` aborts the pass.
    fn pass(&self, model: &mut TypedModel) -> TractResult<bool> {
        let mut changed_any = false;
        while self.sweep(model)? {
            changed_any = true;
        }
        Ok(changed_any)
    }
}

impl NormalizeOps {
    /// Performs one sweep over the model; returns whether any patch was applied.
    fn sweep(&self, model: &mut TypedModel) -> TractResult<bool> {
        let mut patched = false;
        for id in model.eval_order()? {
            // Scoped so the shared borrow of `model` ends before `apply` takes `&mut`.
            let patch = {
                let node = &model.nodes()[id];
                debug!("Decluttering {}", node);
                node.op
                    .declutter(model, node)
                    .map_err(|e| format!("{:?} node {}, {:?}", self, node, e))?
            };
            if let Some(patch) = patch {
                {
                    let node = &model.nodes()[id];
                    debug!("Apply a model patch for {:?}: {}", self, node);
                }
                patch.apply(model)?;
                // Extra graph validation (presumably edge consistency), debug builds only.
                if cfg!(debug_assertions) {
                    model.check_edges()?;
                }
                patched = true;
            }
        }
        Ok(patched)
    }
}

/// Codegen pass that asks the op of every node for a codegen replacement.
///
/// Each sweep walks the nodes in `model.eval_order()` order, calls the op's
/// `codegen` and applies any returned patch immediately. Sweeps repeat until
/// one of them applies no patch at all.
#[derive(Debug)]
pub struct CodegenOps;

impl CodegenPass for CodegenOps {
    /// Runs sweeps until a fixpoint is reached.
    ///
    /// Returns `Ok(true)` if at least one patch was applied. The first error
    /// from `eval_order`, a node's `codegen`, a patch application or (in
    /// builds with debug assertions) `check_edges` aborts the pass.
    fn pass(&self, model: &mut TypedModel) -> TractResult<bool> {
        let mut changed_any = false;
        loop {
            let mut changed_this_sweep = false;
            for id in model.eval_order()? {
                // Scoped so the shared borrow of `model` ends before `apply` takes `&mut`.
                let patch = {
                    let node = &model.nodes()[id];
                    debug!("Codegen {}", node);
                    node.op
                        .codegen(model, node)
                        .map_err(|e| format!("{:?} node {}, {:?}", self, node, e))?
                };
                let patch = match patch {
                    Some(patch) => patch,
                    None => continue,
                };
                {
                    let node = &model.nodes()[id];
                    debug!("Apply a model patch for {:?} {}", self, node);
                }
                patch.apply(model)?;
                // Extra graph validation (presumably edge consistency), debug builds only.
                if cfg!(debug_assertions) {
                    model.check_edges()?;
                }
                changed_this_sweep = true;
            }
            // A sweep that changed nothing means we reached the fixpoint.
            if !changed_this_sweep {
                return Ok(changed_any);
            }
            changed_any = true;
        }
    }
}