1use crate::internal::*;
2use crate::model::*;
3use crate::ops;
4use crate::ops::konst::Const;
5use crate::optim::OptimizerSession;
6use crate::plan::{FrozenSimpleState, SimplePlan, SimpleState};
7use crate::transform::ModelTransform;
8use tract_data::TooEarly;
9use tract_linalg::block_quant::BlockQuantValue;
10use tract_num_traits::Zero;
11
12pub type TypedModel = Graph<TypedFact, Box<dyn TypedOp>>;
14pub type TypedNode = Node<TypedFact, Box<dyn TypedOp>>;
16pub type TypedModelPatch = ModelPatch<TypedFact, Box<dyn TypedOp>>;
18pub type TypedSimplePlan<M> = SimplePlan<TypedFact, Box<dyn TypedOp>, M>;
20pub type TypedRunnableModel<M> = RunnableModel<TypedFact, Box<dyn TypedOp>, M>;
22pub type TypedSimpleState<M, P> = SimpleState<TypedFact, Box<dyn TypedOp>, M, P>;
24pub type TypedFrozenSimpleState<M, P> = FrozenSimpleState<TypedFact, Box<dyn TypedOp>, M, P>;
26
27pub type RunnableModel<F, O, M> = SimplePlan<F, O, M>;
29
30impl SpecialOps<TypedFact, Box<dyn TypedOp>> for TypedModel {
31 fn is_source(op: &Box<dyn TypedOp>) -> bool {
32 op.as_op().downcast_ref::<ops::source::TypedSource>().is_some()
33 }
34
35 fn create_dummy(&self) -> Box<dyn TypedOp> {
36 Box::new(crate::ops::dummy::Dummy::new())
37 }
38
39 fn create_source(&self, fact: TypedFact) -> Box<dyn TypedOp> {
40 Box::new(crate::ops::source::TypedSource::new(fact))
41 }
42
43 fn wire_node(
44 &mut self,
45 name: impl Into<String>,
46 op: impl Into<Box<dyn TypedOp>>,
47 inputs: &[OutletId],
48 ) -> TractResult<TVec<OutletId>> {
49 let op = op.into();
50 let name = name.into();
51 if let Some(konst) = op.downcast_ref::<Const>() {
52 for node in &self.nodes {
53 if node.op_as::<Const>().is_some_and(|other| other == konst) {
54 return Ok(tvec!(node.id.into()));
55 }
56 }
57 }
58 if self.nodes.iter().any(|n| n.name == name) {
59 bail!("Duplicate node name: {name}");
60 }
61 {
62 let input_facts = inputs
63 .iter()
64 .map(|o| self.outlet_fact(*o).cloned())
65 .collect::<TractResult<TVec<_>>>()?;
66
67 let input_facts: TVec<_> = input_facts.iter().collect();
68 let mut output_facts = op
69 .output_facts(&input_facts)
70 .with_context(|| format!("in output_facts invocation for {name}: {}", op.name()))?;
71
72 #[cfg(all(debug_assertions, feature = "paranoid_assertions"))]
73 for o in &output_facts {
74 o.consistent()?;
75 }
76
77 if op.is_stateless() && input_facts.len() > 0 {
78 if let Some(tensors) = input_facts
79 .iter()
80 .map(|f| {
81 f.konst
82 .as_ref()
83 .filter(|k| k.volume() < 16 && !k.datum_type().is_opaque())
84 .cloned()
85 .map(|t| t.into_tvalue())
86 })
87 .collect::<Option<TVec<_>>>()
88 {
89 if let Ok(outputs) = op.eval_with_session(usize::MAX, &SessionState::default(), tensors) {
90 return outputs
91 .into_iter()
92 .enumerate()
93 .map(|(ix, o)| {
94 let name =
95 if ix == 0 { name.clone() } else { format!("{name}.{ix}") };
96 self.wire_node(
97 name.clone(),
98 Const::new_with_opt_opaque_fact(
99 o.into_tensor().into(),
100 output_facts[ix].opaque_fact.clone(),
101 )?,
102 &[],
103 )
104 .with_context(|| format!("Eager const-folding {name}"))
105 .map(|vec| vec[0])
106 })
107 .collect::<TractResult<TVec<OutletId>>>();
108 }
109 }
110 }
111
112 for fact in &mut output_facts {
113 if fact.konst.is_none() && fact.shape.is_concrete() && fact.shape.volume().is_zero()
114 {
115 let tensor =
116 Tensor::zero_dt(fact.datum_type, fact.shape.as_concrete().unwrap())?;
117 fact.konst = Some(tensor.into_arc_tensor());
118 }
119 }
120 let id = self.add_node(&name, &op, output_facts)?;
121 inputs
122 .iter()
123 .enumerate()
124 .try_for_each(|(ix, i)| self.add_edge(*i, InletId::new(id, ix)))?;
125 TractResult::Ok(
126 self.node(id)
127 .outputs
128 .iter()
129 .enumerate()
130 .map(|(ix, _)| OutletId::new(id, ix))
131 .collect(),
132 )
133 }
134 .with_context(|| format!("Wiring node \"{name}\", {op:?}"))
135 }
136
137 fn add_const(
138 &mut self,
139 name: impl Into<String>,
140 v: impl IntoArcTensor,
141 ) -> TractResult<OutletId> {
142 let v = v.into_arc_tensor();
143 for node in &self.nodes {
144 if node.op_is::<Const>() && node.outputs[0].fact.konst.as_ref() == Some(&v) {
145 return Ok(node.id.into());
146 }
147 }
148 let mut fact = TypedFact::from(v.clone());
149 let name = name.into();
150 if v.datum_type().is_opaque() && v.volume() == 1 {
152 if let Some(bqv) = v.as_slice::<Opaque>()?[0].downcast_ref::<BlockQuantValue>() {
153 let opaque = Box::new(bqv.fact.clone());
154 fact.opaque_fact = Some(opaque.clone());
155 return self
156 .add_node(
157 name,
158 crate::ops::konst::Const::new_with_opaque_fact(v, opaque)?,
159 tvec!(fact),
160 )
161 .map(|id| id.into());
162 }
163 }
164 self.add_node(name, crate::ops::konst::Const::new(v)?, tvec!(fact)).map(|id| id.into())
165 }
166}
167
168impl TypedModel {
169 pub fn into_optimized(mut self) -> TractResult<TypedModel> {
170 self.declutter()?;
171 self.optimize()?;
172 Ok(self)
173 }
174 #[cfg(not(all(debug_assertions, feature = "paranoid_assertions")))]
175 #[inline]
176 pub fn check_consistency(&self) -> TractResult<()> {
177 Ok(())
178 }
179
180 #[cfg(all(debug_assertions, feature = "paranoid_assertions"))]
181 pub fn check_consistency(&self) -> TractResult<()> {
182 self.check_edges()?;
183 for node_id in &self.eval_order()? {
184 let input_facts = self.node_input_facts(*node_id)?;
185 let node = &self.nodes[*node_id];
186 if node.id != *node_id {
187 bail!("Node at position {} has id {}", node_id, node.id);
188 }
189 let output_facts = node.op.output_facts(&input_facts)?;
190 if node.outputs.len() != output_facts.len() {
191 bail!(
192 "Inconsistent model, node output count mismatch. Op says {}, node says {}. {}",
193 output_facts.len(),
194 node.outputs.len(),
195 node
196 );
197 }
198 if node
199 .outputs
200 .iter()
201 .map(|o| &o.fact)
202 .zip(output_facts.iter())
203 .any(|(a, b)| a.datum_type != b.datum_type || a.shape != b.shape)
204 {
205 bail!(
206 "Inconsistent model, output types mismatch. Op says: {:?}, node says: {:?}. {} with inputs {:?}. {}",
207 output_facts, node.outputs.iter().map(|o| &o.fact).collect::<Vec<_>>(), node, input_facts, node)
208 }
209 }
218 for node in &self.nodes {
219 for (ix, output) in node.outputs.iter().enumerate() {
220 output.fact.consistent().with_context(|| {
221 format!("Inconsistent fact {:?}: {:?}", OutletId::new(node.id, ix), output.fact)
222 })?
223 }
224 }
225 self.axes_mapping().context("Checking model axes mapping")?;
226 Ok(())
227 }
228
229 pub fn into_decluttered(mut self) -> TractResult<TypedModel> {
230 self.declutter()?;
231 Ok(self)
232 }
233
234 pub fn transform(&mut self, transform: &dyn ModelTransform) -> TractResult<()> {
236 transform.transform(self)
237 }
238
239 pub fn declutter(&mut self) -> TractResult<()> {
241 crate::optim::Optimizer::declutter().session().optimize(self)
242 }
243
244 pub fn optimize_with_session(&mut self, session: &mut OptimizerSession) -> TractResult<()> {
246 session.optimize(self)?;
247 self.properties.insert("tract_stage".to_string(), rctensor0("optimized".to_string()));
248 Ok(())
249 }
250
251 pub fn concretize_dims(&self, values: &SymbolValues) -> TractResult<TypedModel> {
252 values.translate_model(self)
253 }
254
255 pub fn prop_consts(&mut self) -> TractResult<()> {
256 crate::optim::Optimizer::prop_consts().optimize(self)
257 }
258
259 pub fn optimize(&mut self) -> TractResult<()> {
261 crate::optim::Optimizer::codegen().optimize(self)
262 }
263
264 pub fn node_axes_mapping(&self, id: usize) -> TractResult<AxesMapping> {
265 let (inputs, outputs) = self.node_facts(id)?;
266 self.nodes[id].op.axes_mapping(&inputs, &outputs)
267 }
268
269 pub fn axes_mapping(&self) -> TractResult<AxesMapping> {
270 crate::axes::for_model(self)
271 }
272
273 pub fn compute_const_facts(&mut self) -> TractResult<()> {
274 for n in self.eval_order()? {
275 let node = self.node(n);
276 let (inputs, outputs) = self.node_facts(n)?;
277 if node.op.is_stateless()
278 && inputs.iter().all(|i| i.konst.is_some())
279 && outputs.iter().any(|o| o.konst.is_none())
280 {
281 let inputs_ref =
282 inputs.iter().map(|f| f.konst.clone().unwrap().into_tvalue()).collect();
283 match node.op.eval_with_session(node.id, &SessionState::default(), inputs_ref) {
284 Ok(res) => {
285 drop(inputs);
286 drop(outputs);
287 for (ix, output) in res.into_iter().enumerate() {
288 self.nodes[n].outputs[ix].fact.konst = Some(output.into_arc_tensor());
289 }
290 }
291 Err(e) => {
292 if !e.root_cause().is::<TooEarly>() {
293 Err(e).with_context(|| {
294 format!("Eager eval {} during const fact computation", self.node(n))
295 })?;
296 }
297 }
298 }
299 }
300 }
301 Ok(())
302 }
303}
304
305use crate::model::translator::Translate;
306impl Translate<TypedFact, Box<dyn TypedOp>, TypedFact, Box<dyn TypedOp>> for SymbolValues {
307 fn translate_node(
308 &self,
309 source: &TypedModel,
310 node: &TypedNode,
311 target: &mut TypedModel,
312 mapping: &HashMap<OutletId, OutletId>,
313 ) -> TractResult<TVec<OutletId>> {
314 target.check_consistency()?;
315 let outlets = node.op.concretize_dims(source, node, target, mapping, self)?;
316 for &outlet in &outlets {
317 let fact = &mut target.nodes[outlet.node].outputs[outlet.slot].fact;
318 if fact.shape.volume().is_zero() {
319 if let Some(shape) = fact.shape.as_concrete() {
320 let tensor = Tensor::zero_dt(fact.datum_type, shape)?;
321 fact.konst = Some(tensor.into_arc_tensor());
322 }
323 }
324 fact.consistent()?;
325 }
326 Ok(outlets)
327 }
328}
329
#[cfg(test)]
mod test {
    use super::*;

    /// Compile-time guard: `TypedModel` must implement `Sync`.
    #[test]
    fn test() {
        fn require_sync<S: Sync>() {}
        require_sync::<TypedModel>();
    }
}