// tract_core/model/typed.rs

1use crate::internal::*;
2use crate::model::*;
3use crate::ops;
4use crate::ops::konst::Const;
5use crate::optim::OptimizerSession;
6use crate::plan::{FrozenSimpleState, SimplePlan, SimpleState};
7use crate::transform::ModelTransform;
8use tract_data::TooEarly;
9use tract_linalg::block_quant::BlockQuantValue;
10use tract_num_traits::Zero;
11
/// A model with completely determined types and shapes.
pub type TypedModel = Graph<TypedFact, Box<dyn TypedOp>>;
/// Node for a TypedModel graph.
pub type TypedNode = Node<TypedFact, Box<dyn TypedOp>>;
/// A ModelPatch for TypedModel.
pub type TypedModelPatch = ModelPatch<TypedFact, Box<dyn TypedOp>>;
/// An execution plan for TypedModel.
pub type TypedSimplePlan<M> = SimplePlan<TypedFact, Box<dyn TypedOp>, M>;
/// A runnable TypedModel (new name for SimplePlan).
pub type TypedRunnableModel<M> = RunnableModel<TypedFact, Box<dyn TypedOp>, M>;
/// An execution state for TypedModel.
pub type TypedSimpleState<M, P> = SimpleState<TypedFact, Box<dyn TypedOp>, M, P>;
/// An execution state for TypedModel, frozen (and Send).
pub type TypedFrozenSimpleState<M, P> = FrozenSimpleState<TypedFact, Box<dyn TypedOp>, M, P>;

/// A runnable model with fixed inputs and outputs.
pub type RunnableModel<F, O, M> = SimplePlan<F, O, M>;
29
30impl SpecialOps<TypedFact, Box<dyn TypedOp>> for TypedModel {
31    fn is_source(op: &Box<dyn TypedOp>) -> bool {
32        op.as_op().downcast_ref::<ops::source::TypedSource>().is_some()
33    }
34
35    fn create_dummy(&self) -> Box<dyn TypedOp> {
36        Box::new(crate::ops::dummy::Dummy::new())
37    }
38
39    fn create_source(&self, fact: TypedFact) -> Box<dyn TypedOp> {
40        Box::new(crate::ops::source::TypedSource::new(fact))
41    }
42
43    fn wire_node(
44        &mut self,
45        name: impl Into<String>,
46        op: impl Into<Box<dyn TypedOp>>,
47        inputs: &[OutletId],
48    ) -> TractResult<TVec<OutletId>> {
49        let op = op.into();
50        let name = name.into();
51        if let Some(konst) = op.downcast_ref::<Const>() {
52            for node in &self.nodes {
53                if node.op_as::<Const>().is_some_and(|other| other == konst) {
54                    return Ok(tvec!(node.id.into()));
55                }
56            }
57        }
58        if self.nodes.iter().any(|n| n.name == name) {
59            bail!("Duplicate node name: {name}");
60        }
61        {
62            let input_facts = inputs
63                .iter()
64                .map(|o| self.outlet_fact(*o).cloned())
65                .collect::<TractResult<TVec<_>>>()?;
66
67            let input_facts: TVec<_> = input_facts.iter().collect();
68            let mut output_facts = op
69                .output_facts(&input_facts)
70                .with_context(|| format!("in output_facts invocation for {name}: {}", op.name()))?;
71
72            #[cfg(all(debug_assertions, feature = "paranoid_assertions"))]
73            for o in &output_facts {
74                o.consistent()?;
75            }
76
77            if op.is_stateless() && input_facts.len() > 0 {
78                if let Some(tensors) = input_facts
79                    .iter()
80                    .map(|f| {
81                        f.konst
82                            .as_ref()
83                            .filter(|k| k.volume() < 16 && !k.datum_type().is_opaque())
84                            .cloned()
85                            .map(|t| t.into_tvalue())
86                    })
87                    .collect::<Option<TVec<_>>>()
88                {
89                    if let Ok(outputs) = op.eval_with_session(usize::MAX, &SessionState::default(), tensors) {
90                        return outputs
91                            .into_iter()
92                            .enumerate()
93                            .map(|(ix, o)| {
94                                let name =
95                                    if ix == 0 { name.clone() } else { format!("{name}.{ix}") };
96                                self.wire_node(
97                                    name.clone(),
98                                    Const::new_with_opt_opaque_fact(
99                                        o.into_tensor().into(),
100                                        output_facts[ix].opaque_fact.clone(),
101                                    )?,
102                                    &[],
103                                )
104                                .with_context(|| format!("Eager const-folding {name}"))
105                                .map(|vec| vec[0])
106                            })
107                            .collect::<TractResult<TVec<OutletId>>>();
108                    }
109                }
110            }
111
112            for fact in &mut output_facts {
113                if fact.konst.is_none() && fact.shape.is_concrete() && fact.shape.volume().is_zero()
114                {
115                    let tensor =
116                        Tensor::zero_dt(fact.datum_type, fact.shape.as_concrete().unwrap())?;
117                    fact.konst = Some(tensor.into_arc_tensor());
118                }
119            }
120            let id = self.add_node(&name, &op, output_facts)?;
121            inputs
122                .iter()
123                .enumerate()
124                .try_for_each(|(ix, i)| self.add_edge(*i, InletId::new(id, ix)))?;
125            TractResult::Ok(
126                self.node(id)
127                    .outputs
128                    .iter()
129                    .enumerate()
130                    .map(|(ix, _)| OutletId::new(id, ix))
131                    .collect(),
132            )
133        }
134        .with_context(|| format!("Wiring node \"{name}\", {op:?}"))
135    }
136
137    fn add_const(
138        &mut self,
139        name: impl Into<String>,
140        v: impl IntoArcTensor,
141    ) -> TractResult<OutletId> {
142        let v = v.into_arc_tensor();
143        for node in &self.nodes {
144            if node.op_is::<Const>() && node.outputs[0].fact.konst.as_ref() == Some(&v) {
145                return Ok(node.id.into());
146            }
147        }
148        let mut fact = TypedFact::from(v.clone());
149        let name = name.into();
150        // this feel incredibly hackish and dirty...
151        if v.datum_type().is_opaque() && v.volume() == 1 {
152            if let Some(bqv) = v.as_slice::<Opaque>()?[0].downcast_ref::<BlockQuantValue>() {
153                let opaque = Box::new(bqv.fact.clone());
154                fact.opaque_fact = Some(opaque.clone());
155                return self
156                    .add_node(
157                        name,
158                        crate::ops::konst::Const::new_with_opaque_fact(v, opaque)?,
159                        tvec!(fact),
160                    )
161                    .map(|id| id.into());
162            }
163        }
164        self.add_node(name, crate::ops::konst::Const::new(v)?, tvec!(fact)).map(|id| id.into())
165    }
166}
167
impl TypedModel {
    /// Declutter then codegen-optimize the model, consuming and returning it.
    pub fn into_optimized(mut self) -> TractResult<TypedModel> {
        self.declutter()?;
        self.optimize()?;
        Ok(self)
    }

    /// Consistency check. This variant is a no-op: the full check is only
    /// compiled in with debug_assertions plus the "paranoid_assertions" feature.
    #[cfg(not(all(debug_assertions, feature = "paranoid_assertions")))]
    #[inline]
    pub fn check_consistency(&self) -> TractResult<()> {
        Ok(())
    }

    /// Thorough internal consistency check: edges, node ids, op-recomputed
    /// output facts against stored facts, per-fact consistency, and the
    /// model-wide axes mapping.
    #[cfg(all(debug_assertions, feature = "paranoid_assertions"))]
    pub fn check_consistency(&self) -> TractResult<()> {
        self.check_edges()?;
        for node_id in &self.eval_order()? {
            let input_facts = self.node_input_facts(*node_id)?;
            let node = &self.nodes[*node_id];
            // A node's position in the nodes vector must match its stored id.
            if node.id != *node_id {
                bail!("Node at position {} has id {}", node_id, node.id);
            }
            // Re-run the op's fact inference and compare with the stored outputs.
            let output_facts = node.op.output_facts(&input_facts)?;
            if node.outputs.len() != output_facts.len() {
                bail!(
                    "Inconsistent model, node output count mismatch. Op says {}, node says {}. {}",
                    output_facts.len(),
                    node.outputs.len(),
                    node
                );
            }
            if node
                .outputs
                .iter()
                .map(|o| &o.fact)
                .zip(output_facts.iter())
                .any(|(a, b)| a.datum_type != b.datum_type || a.shape != b.shape)
            {
                bail!(
                    "Inconsistent model, output types mismatch. Op says: {:?}, node says: {:?}. {} with inputs {:?}. {}",
                    output_facts, node.outputs.iter().map(|o| &o.fact).collect::<Vec<_>>(), node, input_facts, node)
            }
            /* this is not true for regularly packed values
            if let Some(k) = node.op_as::<Const>() {
                ensure!(
                    !k.0.datum_type().is_opaque() || k.1.is_some(),
                    "Node {node} is missing an opaque fact"
                );
            }
            */
        }
        for node in &self.nodes {
            for (ix, output) in node.outputs.iter().enumerate() {
                output.fact.consistent().with_context(|| {
                    format!("Inconsistent fact {:?}: {:?}", OutletId::new(node.id, ix), output.fact)
                })?
            }
        }
        self.axes_mapping().context("Checking model axes mapping")?;
        Ok(())
    }

    /// Declutter the model, consuming and returning it.
    pub fn into_decluttered(mut self) -> TractResult<TypedModel> {
        self.declutter()?;
        Ok(self)
    }

    /// Apply the given model transform to the network in place.
    pub fn transform(&mut self, transform: &dyn ModelTransform) -> TractResult<()> {
        transform.transform(self)
    }

    /// Perform declutter passes on the network.
    pub fn declutter(&mut self) -> TractResult<()> {
        crate::optim::Optimizer::declutter().session().optimize(self)
    }

    /// Perform optimization passes on the model, using a given optimizer
    /// session, then tag the model properties with stage "optimized".
    pub fn optimize_with_session(&mut self, session: &mut OptimizerSession) -> TractResult<()> {
        session.optimize(self)?;
        self.properties.insert("tract_stage".to_string(), rctensor0("optimized".to_string()));
        Ok(())
    }

    /// Build a copy of the model with symbolic dimensions substituted
    /// according to `values`.
    pub fn concretize_dims(&self, values: &SymbolValues) -> TractResult<TypedModel> {
        values.translate_model(self)
    }

    /// Run the constant-propagation optimizer passes on the model.
    pub fn prop_consts(&mut self) -> TractResult<()> {
        crate::optim::Optimizer::prop_consts().optimize(self)
    }

    /// Translate the graph to locally optimized operators (LIR or MIR ops).
    pub fn optimize(&mut self) -> TractResult<()> {
        crate::optim::Optimizer::codegen().optimize(self)
    }

    /// Axes mapping (input axes to output axes) for a single node.
    pub fn node_axes_mapping(&self, id: usize) -> TractResult<AxesMapping> {
        let (inputs, outputs) = self.node_facts(id)?;
        self.nodes[id].op.axes_mapping(&inputs, &outputs)
    }

    /// Axes mapping for the whole model.
    pub fn axes_mapping(&self) -> TractResult<AxesMapping> {
        crate::axes::for_model(self)
    }

    /// Walk the model in evaluation order and, for every stateless node whose
    /// inputs are all constant but whose outputs are not, evaluate the node
    /// and record the results as constant output facts. Evaluation errors
    /// rooted in TooEarly are tolerated (the value is just not computable
    /// yet); any other error is propagated with context.
    pub fn compute_const_facts(&mut self) -> TractResult<()> {
        for n in self.eval_order()? {
            let node = self.node(n);
            let (inputs, outputs) = self.node_facts(n)?;
            if node.op.is_stateless()
                && inputs.iter().all(|i| i.konst.is_some())
                && outputs.iter().any(|o| o.konst.is_none())
            {
                let inputs_ref =
                    inputs.iter().map(|f| f.konst.clone().unwrap().into_tvalue()).collect();
                match node.op.eval_with_session(node.id, &SessionState::default(), inputs_ref) {
                    Ok(res) => {
                        // Release the borrows on self before mutating the facts.
                        drop(inputs);
                        drop(outputs);
                        for (ix, output) in res.into_iter().enumerate() {
                            self.nodes[n].outputs[ix].fact.konst = Some(output.into_arc_tensor());
                        }
                    }
                    Err(e) => {
                        if !e.root_cause().is::<TooEarly>() {
                            Err(e).with_context(|| {
                                format!("Eager eval {} during const fact computation", self.node(n))
                            })?;
                        }
                    }
                }
            }
        }
        Ok(())
    }
}
304
305use crate::model::translator::Translate;
306impl Translate<TypedFact, Box<dyn TypedOp>, TypedFact, Box<dyn TypedOp>> for SymbolValues {
307    fn translate_node(
308        &self,
309        source: &TypedModel,
310        node: &TypedNode,
311        target: &mut TypedModel,
312        mapping: &HashMap<OutletId, OutletId>,
313    ) -> TractResult<TVec<OutletId>> {
314        target.check_consistency()?;
315        let outlets = node.op.concretize_dims(source, node, target, mapping, self)?;
316        for &outlet in &outlets {
317            let fact = &mut target.nodes[outlet.node].outputs[outlet.slot].fact;
318            if fact.shape.volume().is_zero() {
319                if let Some(shape) = fact.shape.as_concrete() {
320                    let tensor = Tensor::zero_dt(fact.datum_type, shape)?;
321                    fact.konst = Some(tensor.into_arc_tensor());
322                }
323            }
324            fact.consistent()?;
325        }
326        Ok(outlets)
327    }
328}
329
#[cfg(test)]
mod test {
    use super::*;

    /// Compile-time assertion that TypedModel can be shared between threads.
    #[test]
    fn test() {
        fn assert_sync<T: Sync>() {}
        assert_sync::<TypedModel>();
    }
}