1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99
use crate::model::TypedModel; use crate::TractResult; use std::fmt::Debug; mod prop_const; mod push_split_down; use self::prop_const::PropConst; use self::push_split_down::PushSplitDown; pub trait DeclutterPass: Debug + Send + Sync { fn pass(&self, model: &mut TypedModel) -> TractResult<bool>; } pub trait CodegenPass: Debug + Send + Sync { fn pass(&self, model: &mut TypedModel) -> TractResult<bool>; } pub fn declutter() -> Vec<Box<DeclutterPass>> { vec![Box::new(PropConst) as _, Box::new(NormalizeOps)] } pub fn codegen() -> Vec<Box<CodegenPass>> { vec![Box::new(CodegenOps), Box::new(PushSplitDown)] } #[derive(Debug)] pub struct NormalizeOps; impl DeclutterPass for NormalizeOps { fn pass(&self, model: &mut TypedModel) -> TractResult<bool> { let mut done_something = false; loop { let mut done_something_this_time = false; for id in model.eval_order()? { let reduced = { let node = &model.nodes()[id]; debug!("Decluttering {}", node); node.op .declutter(model, node) .map_err(|e| format!("{:?} node {}, {:?}", self, node, e))? }; if let Some(red) = reduced { { let node = &model.nodes()[id]; debug!("Apply a model patch for {:?}: {}", self, node); } red.apply(model)?; if cfg!(debug_assertions) { model.check_edges()?; } done_something_this_time = true } } done_something = done_something || done_something_this_time; if !done_something_this_time { break; } } Ok(done_something) } } #[derive(Debug)] pub struct CodegenOps; impl CodegenPass for CodegenOps { fn pass(&self, model: &mut TypedModel) -> TractResult<bool> { let mut done_something = false; loop { let mut done_something_this_time = false; for id in model.eval_order()? { let reduced = { let node = &model.nodes()[id]; debug!("Codegen {}", node); node.op .codegen(model, node) .map_err(|e| format!("{:?} node {}, {:?}", self, node, e))? }; if let Some(red) = reduced { { let node = &model.nodes()[id]; debug!("Apply a model patch for {:?} {}", self, node); } red.apply(model)?; if cfg!(debug_assertions) { model.check_edges()?; } done_something_this_time = true } } done_something = done_something || done_something_this_time; if !done_something_this_time { break; } } Ok(done_something) } }