Skip to main content

tract_hir/infer/
ops.rs

1use super::Factoid;
2use crate::infer::*;
3use std::fmt;
4use tract_data::TooEarly;
5
6tract_core::dyn_clone::clone_trait_object!(InferenceOp);
7
8/// An operation with tensor type inference
9pub trait InferenceOp: Op {
10    /// Infers properties about the input and output tensors.
11    ///
12    /// The `inputs` and `outputs` arguments correspond to properties about
13    /// the input and output tensors that are already known.
14    ///
15    /// The default implementation will call the private infer_facts method,
16    /// which is usually implemented using the InferenceRulesOp trait. It will
17    /// also try to eval() the op if its a EvalOp and if the inputs are
18    /// fully determined.
19    ///
20    /// Returns Err in case of an unrecoverable error during the inference,
21    /// and the refined properties about the inputs and outputs otherwise.
22    fn infer(
23        &mut self,
24        inputs: TVec<&InferenceFact>,
25        outputs: TVec<&InferenceFact>,
26        observed: TVec<&InferenceFact>,
27    ) -> TractResult<(TVec<InferenceFact>, TVec<InferenceFact>, TVec<InferenceFact>)> {
28        let (infered_inputs, infered_outputs, observed) =
29            self.infer_facts(inputs, outputs, observed).context("Infering facts")?;
30
31        if self.is_stateless() && infered_inputs.iter().all(|i| i.value.is_concrete()) {
32            let input_values = infered_inputs
33                .iter()
34                .map(|i| i.value.concretize().unwrap().into_tvalue())
35                .collect(); // checked
36            match self.eval(input_values) {
37                Ok(values) => {
38                    let output_values = values
39                        .into_iter()
40                        .map(|t| t.into_arc_tensor().try_into())
41                        .collect::<TractResult<TVec<_>>>()?;
42                    return Ok((infered_inputs, output_values, observed));
43                }
44                Err(e) if e.root_cause().downcast_ref::<TooEarly>().is_some() => (),
45                Err(e) => return Err(e).context("Eager eval during inference"),
46            }
47        }
48
49        Ok((infered_inputs, infered_outputs, observed))
50    }
51
52    /// Allow an op to specify a supplementary list of outlets facts that
53    /// will trigger inference again.
54    fn observe_outlets(
55        &self,
56        _model: &InferenceModel,
57        _node: &InferenceNode,
58    ) -> TractResult<Vec<OutletId>> {
59        Ok(vec![])
60    }
61
62    /// Infer properties about inputs and output tensors. This method does not
63    /// need to deal with the "trivial" stateless op with fully determined
64    /// inputs cases.
65    ///
66    /// Most of the time, it is implemented using InferenceRulesOp.
67    fn infer_facts(
68        &mut self,
69        inputs: TVec<&InferenceFact>,
70        outputs: TVec<&InferenceFact>,
71        observed: TVec<&InferenceFact>,
72    ) -> TractResult<(TVec<InferenceFact>, TVec<InferenceFact>, TVec<InferenceFact>)>;
73
74    /// Early pass on inference model, after analyse, but before translation to
75    /// typed network. Meant to deal with some framework idiosyncrasies that
76    /// manifest with temporaries nodes that can run some form of inference but
77    /// require refactoring the network before it can be evaluated.
78    ///
79    /// Called after succesful analyse, but before translating to typed model.
80    #[allow(unused_variables)]
81    fn incorporate(
82        &self,
83        model: &InferenceModel,
84        node: &InferenceNode,
85    ) -> TractResult<Option<InferenceModelPatch>> {
86        Ok(None)
87    }
88
89    fn nboutputs(&self) -> TractResult<usize> {
90        Ok(1)
91    }
92
93    /// Reinterpret the InferenceOp as an Op.
94    fn as_op(&self) -> &dyn Op;
95
96    /// Reinterpret the InferenceOp as an Op, mutably.
97    fn as_op_mut(&mut self) -> &mut dyn Op;
98
99    /// Called during translation to TypedModel.
100    #[allow(unused_variables)]
101    fn to_typed(
102        &self,
103        source: &InferenceModel,
104        node: &InferenceNode,
105        target: &mut TypedModel,
106        mapping: &HashMap<OutletId, OutletId>,
107    ) -> TractResult<TVec<OutletId>> {
108        bail!("Operator can not be made a TypedOp.")
109    }
110}
111
112impl std::fmt::Display for Box<dyn InferenceOp> {
113    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
114        write!(fmt, "{}", self.name())
115    }
116}
117
118impl<O: InferenceOp> From<O> for Box<dyn InferenceOp> {
119    fn from(it: O) -> Box<dyn InferenceOp> {
120        Box::new(it)
121    }
122}
123
124impl AsRef<dyn Op> for dyn InferenceOp {
125    fn as_ref(&self) -> &dyn Op {
126        self.as_op()
127    }
128}
129
130impl AsRef<dyn Op> for Box<dyn InferenceOp> {
131    fn as_ref(&self) -> &dyn Op {
132        self.as_op()
133    }
134}
135
136impl AsMut<dyn Op> for dyn InferenceOp {
137    fn as_mut(&mut self) -> &mut dyn Op {
138        self.as_op_mut()
139    }
140}
141
142impl AsMut<dyn Op> for Box<dyn InferenceOp> {
143    fn as_mut(&mut self) -> &mut dyn Op {
144        self.as_op_mut()
145    }
146}