tract_libcli/
model.rs

1use tract_core::internal::*;
2use tract_core::{downcast_rs, dyn_clone};
3
4/// Common methods for all variants of model.
5pub trait Model:
6    downcast_rs::Downcast + std::fmt::Debug + dyn_clone::DynClone + Send + Sync
7{
8    /// Lookup node id by name
9    fn node_id_by_name(&self, name: &str) -> TractResult<usize>;
10
11    /// Node name by id
12    fn node_name(&self, id: usize) -> &str;
13
14    /// Node op by id
15    fn node_op(&self, id: usize) -> &dyn Op;
16
17    /// Node is const
18    fn node_const(&self, id: usize) -> bool;
19
20    /// Node op by id
21    fn node_op_name(&self, id: usize) -> StaticName;
22
23    /// Node inputs by id
24    fn node_inputs(&self, id: usize) -> &[OutletId];
25
26    /// Number of outputs for a node, by id.
27    fn node_output_count(&self, id: usize) -> usize;
28
29    /// Number nodes
30    fn nodes_len(&self) -> usize;
31
32    /// Formatted node label
33    fn node_display(&self, id: usize) -> String;
34
35    /// Formatted node label
36    fn node_debug(&self, id: usize) -> String;
37
38    /// Eval order for the model
39    fn eval_order(&self) -> TractResult<Vec<usize>>;
40
41    /// Eval order for the model
42    fn eval_order_opt_ram(&self) -> TractResult<Vec<usize>>;
43
44    /// Inputs of the model
45    fn input_outlets(&self) -> &[OutletId];
46
47    fn set_input_names(&mut self, names: &[&str]) -> TractResult<()>;
48    fn set_output_names(&mut self, names: &[&str]) -> TractResult<()>;
49
50    /// Outputs of the model
51    fn output_outlets(&self) -> &[OutletId];
52
53    /// Tensorfact for an outlet
54    fn outlet_typedfact(&self, outlet: OutletId) -> TractResult<TypedFact>;
55
56    /// Short outlet formatter (id plus fact)
57    fn outlet_fact_format(&self, outlet: OutletId) -> String;
58
59    /// Labels for an outlet
60    fn outlet_label(&self, id: OutletId) -> Option<&str>;
61
62    /// List consumers of an outlet
63    fn outlet_successors(&self, outlet: OutletId) -> &[InletId];
64
65    /// Subnets of a node
66    fn nested_models(&self, id: usize) -> Vec<(String, &dyn Model)> {
67        if let Some(submodel) =
68            self.node_op(id).downcast_ref::<tract_core::ops::submodel::SubmodelOp>()
69        {
70            return vec![("submodel".into(), submodel.model())];
71        }
72        if let Some(lir) = self.node_op(id).downcast_ref::<tract_core::ops::scan::OptScan>() {
73            return vec![("loop".into(), lir.plan.model())];
74        }
75        if let Some(mir) = self.node_op(id).downcast_ref::<tract_core::ops::scan::Scan>() {
76            return vec![("loop".into(), &mir.body)];
77        }
78        if let Some(mir) = self.node_op(id).downcast_ref::<tract_core::ops::logic::IfThenElse>() {
79            return vec![("then".into(), &mir.then_body), ("else".into(), &mir.else_body)];
80        }
81        #[cfg(feature = "hir")]
82        if let Some(hir) = self.node_op(id).downcast_ref::<tract_hir::ops::scan::InferenceScan>() {
83            return vec![("loop".into(), &hir.body)];
84        }
85        #[cfg(feature = "onnx")]
86        if let Some(hir) = self.node_op(id).downcast_ref::<tract_onnx::ops::logic::If>() {
87            return vec![("then".into(), &hir.then_body), ("else".into(), &hir.else_body)];
88        }
89        vec![]
90    }
91
92    /// Subnets of a node
93    fn nested_models_iters(&self, id: usize, input: &[&TypedFact]) -> Option<TDim> {
94        if let Some(submodel) =
95            self.node_op(id).downcast_ref::<tract_core::ops::submodel::SubmodelOp>()
96        {
97            submodel.iteration_count(input)
98        } else if let Some(lir) = self.node_op(id).downcast_ref::<tract_core::ops::scan::OptScan>()
99        {
100            lir.iteration_count(input)
101        } else if let Some(mir) = self.node_op(id).downcast_ref::<tract_core::ops::scan::Scan>() {
102            mir.iteration_count(input)
103        } else {
104            None
105        }
106    }
107
108    fn auto_outputs(&mut self) -> TractResult<()>;
109
110    fn properties(&self) -> &HashMap<String, Arc<Tensor>>;
111
112    fn symbols(&self) -> &SymbolScope;
113
114    fn get_or_intern_symbol(&self, name: &str) -> Symbol;
115
116    fn rename_node(&mut self, id: usize, name: &str) -> TractResult<()>;
117}
118
// Generate the downcasting helpers for `dyn Model` trait objects.
downcast_rs::impl_downcast!(Model);
// Make boxed `dyn Model` trait objects cloneable.
dyn_clone::clone_trait_object!(Model);
121
122impl<F, O> Model for Graph<F, O>
123where
124    F: Fact + Hash + Clone + 'static,
125    O: std::fmt::Debug
126        + std::fmt::Display
127        + AsRef<dyn Op>
128        + AsMut<dyn Op>
129        + Clone
130        + 'static
131        + Send
132        + Sync,
133    Graph<F, O>: Send + Sync + 'static,
134{
135    fn node_id_by_name(&self, name: &str) -> TractResult<usize> {
136        self.nodes
137            .iter()
138            .find(|n| n.name == name)
139            .map(|n| n.id)
140            .with_context(|| format!("No node found for name: \"{name}\""))
141    }
142
143    fn node_name(&self, id: usize) -> &str {
144        &self.nodes[id].name
145    }
146
147    fn node_op_name(&self, id: usize) -> StaticName {
148        self.node(id).op().name()
149    }
150
151    fn node_const(&self, id: usize) -> bool {
152        self.node_op_name(id) == "Const"
153    }
154
155    fn node_inputs(&self, id: usize) -> &[OutletId] {
156        &self.nodes[id].inputs
157    }
158
159    fn node_output_count(&self, id: usize) -> usize {
160        self.nodes[id].outputs.len()
161    }
162
163    fn nodes_len(&self) -> usize {
164        self.nodes.len()
165    }
166
167    fn node_display(&self, id: usize) -> String {
168        format!("{}", self.nodes[id])
169    }
170
171    fn node_debug(&self, id: usize) -> String {
172        format!("{:?}", self.nodes[id])
173    }
174
175    fn eval_order(&self) -> TractResult<Vec<usize>> {
176        tract_core::model::order::eval_order(self)
177    }
178
179    fn eval_order_opt_ram(&self) -> TractResult<Vec<usize>> {
180        tract_core::model::order::eval_order_opt_ram(self)
181    }
182
183    fn input_outlets(&self) -> &[OutletId] {
184        &self.inputs
185    }
186
187    fn set_input_names(&mut self, names: &[&str]) -> TractResult<()> {
188        self.set_input_names(names.iter())
189    }
190
191    fn set_output_names(&mut self, names: &[&str]) -> TractResult<()> {
192        self.set_output_names(names)
193    }
194
195    fn output_outlets(&self) -> &[OutletId] {
196        &self.outputs
197    }
198
199    fn node_op(&self, id: usize) -> &dyn Op {
200        self.nodes[id].op.as_ref()
201    }
202
203    fn outlet_typedfact(&self, outlet: OutletId) -> TractResult<TypedFact> {
204        Ok(self.outlet_fact(outlet)?.to_typed_fact()?.into_owned())
205    }
206
207    fn outlet_fact_format(&self, outlet: OutletId) -> String {
208        format!("{:?}", self.outlet_fact(outlet).unwrap())
209    }
210
211    fn outlet_label(&self, id: OutletId) -> Option<&str> {
212        self.outlet_label(id)
213    }
214
215    fn outlet_successors(&self, outlet: OutletId) -> &[InletId] {
216        &self.nodes[outlet.node].outputs[outlet.slot].successors
217    }
218
219    fn auto_outputs(&mut self) -> TractResult<()> {
220        self.auto_outputs()
221    }
222
223    fn properties(&self) -> &HashMap<String, Arc<Tensor>> {
224        &self.properties
225    }
226
227    fn symbols(&self) -> &SymbolScope {
228        &self.symbols
229    }
230    fn rename_node(&mut self, id: usize, name: &str) -> TractResult<()> {
231        self.rename_node(id, name)
232    }
233
234    fn get_or_intern_symbol(&self, name: &str) -> Symbol {
235        self.symbols.sym(name)
236    }
237}