tract_nnef/
ser.rs

1use crate::ast::*;
2use crate::internal::*;
3use tract_core::ndarray::ArrayViewD;
4use tract_core::ndarray::Axis;
5use tract_itertools::Itertools;
6use tract_linalg::block_quant::BlockQuantValue;
7
8pub fn rewrite_model(model: &mut TypedModel) -> TractResult<()> {
9    model.prop_consts()?;
10    tract_core::ops::einsum::prefix_matmul::rewrite_einsum_to_prefix_matmul(model)?;
11    Rewriter::default()
12        .with_rule_for(
13            "rewrite_block_quant_const_to_scalar",
14            crate::ops::nnef::ser::rewrite_block_quant_const_to_scalar,
15        )
16        .with_rule_for(
17            "rewrite_matmul_to_same_rank",
18            crate::ops::nnef::ser::rewrite_matmul_to_same_rank,
19        )
20        .with_rule_for("rewrite_conv_with_n_axis", tract_core::ops::cnn::rewrite_conv_with_n_axis)
21        .with_rule_for(
22            "rewrite_deconv_with_n_axis",
23            tract_core::ops::cnn::rewrite_deconv_with_n_axis,
24        )
25        .with_rule_for(
26            "rewrite_kernel_conv_in_oihw",
27            crate::ops::nnef::ser::rewrite_kernel_conv_in_oihw,
28        )
29        .with_rule_for(
30            "rewrite_kernel_deconv_in_oihw",
31            crate::ops::nnef::ser::rewrite_kernel_deconv_in_oihw,
32        )
33        .with_rule_for(
34            "rewrite_consistent_quantized_conv",
35            crate::ops::nnef::ser::rewrite_consistent_quantized_conv,
36        )
37        .with_rule_for("expand_mean_of_square", tract_core::ops::nn::expand_mean_of_squares)
38        .rewrite(&(), model)
39}
40
41pub fn to_proto_model(framework: &Nnef, model: &TypedModel) -> TractResult<ProtoModel> {
42    let mut fixed_model = model.clone();
43    rewrite_model(&mut fixed_model)?;
44    let mut into_ast = IntoAst::new(framework, &fixed_model);
45    into_ast.translate().context("Translating model to AST")?;
46    into_ast.into_proto_model().context("Translating AST to proto model")
47}
48
49pub fn to_fragment_def(
50    parent: &IntoAst,
51    model: &TypedModel,
52) -> TractResult<(FragmentDef, Vec<RequiredTensorParameter>)> {
53    let mut into_ast = IntoAst::new(parent.framework, model);
54    into_ast.parent = Some(parent);
55    into_ast.translate()?;
56    into_ast.into_fragment()
57}
58
/// State accumulated while translating a [`TypedModel`] into an NNEF AST.
pub struct IntoAst<'a> {
    /// Framework whose registries are asked, in order, to serialize each node.
    pub framework: &'a Nnef,
    /// Enclosing translator; set by `to_fragment_def`, `None` otherwise.
    pub parent: Option<&'a IntoAst<'a>>,
    /// Ids of the registries recorded as required so far (see `ensure_registry`).
    pub registries: Vec<Identifier>,
    /// Model being translated.
    pub model: &'a TypedModel,
    /// Identifiers of the graph parameters (one per model input).
    pub parameters: Vec<Identifier>,
    /// Identifiers of the graph results (one per model output).
    pub results: Vec<Identifier>,
    /// Expression each already-translated outlet maps to.
    pub mapping: HashMap<OutletId, Arc<RValue>>,
    /// Tensors emitted as `variable`s, keyed by their label.
    pub tensors: HashMap<Identifier, Arc<Tensor>>,
    /// Quantization format recorded for identifiers whose datum type yields one.
    pub quantization: HashMap<Identifier, QuantFormat>,
    /// Resources handed over as-is to the resulting `ProtoModel`.
    pub resources: HashMap<String, Arc<dyn Resource>>,
    /// Fragment definitions to emit in the document, keyed by fragment id.
    pub fragments: HashMap<Identifier, FragmentDef>,
    /// Assignments making up the graph (or fragment) body, in emission order.
    pub body: Vec<Assignment>,
}
73
/// A tensor that a fragment produced by `IntoAst::into_fragment` takes as a parameter.
pub struct RequiredTensorParameter {
    /// Name of the fragment parameter receiving the tensor.
    pub parameter_id: Identifier,
    /// Label under which the tensor was registered in `IntoAst::tensors`.
    pub label: Identifier,
    /// The tensor itself.
    pub value: Arc<Tensor>,
}
79
80impl<'a> IntoAst<'a> {
81    pub fn new(framework: &'a Nnef, model: &'a TypedModel) -> IntoAst<'a> {
82        IntoAst {
83            framework,
84            registries: Default::default(),
85            model,
86            parameters: Default::default(),
87            results: Default::default(),
88            mapping: Default::default(),
89            tensors: Default::default(),
90            quantization: Default::default(),
91            resources: Default::default(),
92            fragments: Default::default(),
93            body: Default::default(),
94            parent: None,
95        }
96    }
97
98    fn ensure_registry(&mut self, id: &Identifier) -> TractResult<()> {
99        if !self.framework.registries.iter().any(|r| &r.id == id) {
100            bail!("Registry {} required, consider allowing it on the NNEF framework.", id.0);
101        }
102        if !self.registries.iter().any(|r| r == id) {
103            self.registries.push(id.clone());
104        }
105        Ok(())
106    }
107
108    fn translate(&mut self) -> TractResult<()> {
109        for input in self.model.input_outlets()? {
110            let left = self.scoped_id(&self.model.node(input.node).name);
111            self.parameters.push(left.clone());
112            self.node(self.model.node(input.node))?;
113            self.mapping.insert(*input, RValue::Identifier(left).into());
114        }
115        for node in self.model.eval_order()? {
116            if self.model.input_outlets()?.iter().any(|io| io.node == node) {
117                continue;
118            }
119            self.node(self.model.node(node))
120                .with_context(|| format!("translating node {}", self.model.node(node)))?;
121        }
122        let outlets: Vec<OutletId> = self.model.output_outlets()?.to_vec();
123        for (ix, o) in outlets.into_iter().enumerate() {
124            let rv = if let Some(label) = self.model.outlet_label(o) {
125                self.force_variable_and_name(label, &self.mapping[&o].clone())
126            } else {
127                self.force_variable(format!("output_{ix}"), &self.mapping[&o].clone())
128            };
129            if let RValue::Identifier(name) = rv.as_ref() {
130                self.results.push(name.clone());
131            } else {
132                unreachable!()
133            };
134        }
135        Ok(())
136    }
137
138    pub fn into_fragment(self) -> TractResult<(FragmentDef, Vec<RequiredTensorParameter>)> {
139        let mut tensor_params = vec![];
140        for (name, t) in &self.tensors {
141            tensor_params.push(RequiredTensorParameter {
142                parameter_id: self.scoped_id(name),
143                label: name.clone(),
144                value: t.clone(),
145            })
146        }
147        let IntoAst { body, mut parameters, results, .. } = self;
148        parameters.extend(tensor_params.iter().map(|rtp| rtp.parameter_id.clone()).sorted());
149        let body = body
150            .into_iter()
151            .filter(|assign| match &assign.left {
152                LValue::Identifier(id) => !parameters.contains(id),
153                _ => true,
154            })
155            .collect();
156        Ok((
157            FragmentDef {
158                decl: FragmentDecl {
159                    id: Identifier("network".into()),
160                    generic_decl: None,
161                    parameters: parameters
162                        .into_iter()
163                        .map(|s| TypeName::Scalar.tensor().named(s))
164                        .collect(),
165                    results: results
166                        .into_iter()
167                        .map(|s| Result_ { id: s, spec: TypeName::Scalar.tensor() })
168                        .collect(),
169                },
170                body: Some(body),
171            },
172            tensor_params,
173        ))
174    }
175
176    pub fn into_proto_model(mut self) -> TractResult<ProtoModel> {
177        let mut properties = self
178            .model
179            .properties
180            .iter()
181            .sorted_by_key(|(k, _v)| k.to_owned())
182            .map(|(k, v)| Ok(tuple_2(string(k), self.konst(k, v)?.as_ref().clone())))
183            .collect::<TractResult<Vec<_>>>()?;
184        let version = env!("CARGO_PKG_VERSION");
185        properties.push(tuple_2(
186            string("tract_nnef_ser_version"),
187            self.konst("tract_nnef_ser_version", &rctensor0(version.to_string()))?.as_ref().clone(),
188        ));
189        properties.push(tuple_2(
190            string("tract_nnef_format_version"),
191            self.konst("tract_nnef_format_version", &rctensor0("beta1".to_string()))?
192                .as_ref()
193                .clone(),
194        ));
195        let properties: Assignment = assignment("properties", Arc::new(array(properties)));
196        let IntoAst { mut fragments, body, tensors, parameters, results, .. } = self;
197        let mut extension = vec![];
198        self.registries.sort();
199        for reg in self.registries {
200            if reg.0 != "tract_nnef" {
201                extension.push(("tract_registry".into(), reg.0));
202            }
203        }
204        for sym in self.model.symbols.all_symbols() {
205            extension.push(("tract_symbol".into(), sym.to_string()));
206        }
207        let locked = self.model.symbols.0.lock();
208        for assert in locked.borrow().all_assertions() {
209            extension.push(("tract_assert".into(), assert.to_string()));
210        }
211        for scenario in locked.borrow().scenarios() {
212            for assert in locked.borrow().scenario(scenario) {
213                extension.push(("tract_assert".into(), format!("{scenario}: {assert}")));
214            }
215        }
216        let properties = FragmentDef {
217            decl: FragmentDecl {
218                id: Identifier("tract_core_properties".to_string()),
219                generic_decl: None,
220                parameters: vec![],
221                results: vec![Result_ {
222                    id: Identifier("properties".to_string()),
223                    spec: TypeSpec::Tuple(vec![TypeName::String.spec(), TypeName::Scalar.tensor()])
224                        .array(),
225                }],
226            },
227            body: Some(vec![properties]),
228        };
229        fragments.insert(properties.decl.id.clone(), properties);
230        let doc = Document {
231            version: "1.0".into(),
232            extension,
233            fragments: fragments.into_values().collect(),
234            graph_def: GraphDef { id: Identifier("network".into()), parameters, results, body },
235        };
236        let quantization = if self.quantization.len() > 0 { Some(self.quantization) } else { None };
237        Ok(ProtoModel { doc, tensors, quantization, resources: self.resources })
238    }
239
240    fn node(&mut self, node: &TypedNode) -> TractResult<TVec<Arc<RValue>>> {
241        let mut required_registries = Vec::new();
242        for reg in &self.framework.registries {
243            if let Some(outputs) = reg.serialize(self, node).context("Serializing op")? {
244                if self.ensure_registry(&reg.id).is_err() {
245                    required_registries.push(&reg.id);
246                    continue;
247                };
248                let scoped = self.scoped_id(&node.name);
249                let names: Vec<_> = (0..node.outputs.len())
250                    .map(|ix| {
251                        if ix > 0 {
252                            Identifier(format!("{}_{}", scoped.0, ix))
253                        } else {
254                            scoped.clone()
255                        }
256                    })
257                    .collect();
258                if node.outputs.len() > 1 {
259                    self.body.push(Assignment {
260                        left: LValue::Tuple(
261                            names.iter().map(|n| LValue::Identifier(n.clone())).collect(),
262                        ),
263                        right: outputs.as_ref().clone(),
264                    });
265                } else {
266                    self.assignment(names[0].clone(), outputs);
267                };
268
269                for (outlet, name) in node.outputs.iter().zip(names.iter()) {
270                    if let Some(qf) = QuantFormat::from_dt(outlet.fact.datum_type) {
271                        self.quantization.insert(name.clone(), qf);
272                    }
273                }
274
275                let mut outputs = tvec!();
276                for (ix, o) in names.into_iter().enumerate() {
277                    let rv = Arc::new(ident(o));
278                    self.mapping.insert((node.id, ix).into(), rv.clone());
279                    outputs.push(rv);
280                }
281
282                return Ok(outputs);
283            }
284        }
285        if required_registries.is_empty() {
286            bail!("No serializer found for node {}", node);
287        } else if required_registries.len() == 1 {
288            bail!(
289                "Registry {} required, consider allowing it on the NNEF framework.",
290                required_registries[0].0
291            );
292        } else {
293            bail!("One of the following registries is required: {:?}, consider allowing one on the NNEF framework.", required_registries);
294        }
295    }
296
297    pub fn scoped_id(&self, name: impl AsRef<str>) -> Identifier {
298        let name = name.as_ref().to_string();
299        Identifier(name)
300    }
301
302    pub fn force_variable(&mut self, name: impl AsRef<str>, exp: &Arc<RValue>) -> Arc<RValue> {
303        if let RValue::Identifier(_) = exp.as_ref() {
304            exp.clone()
305        } else {
306            let name = self.scoped_id(name);
307            self.assignment(name.clone(), exp.clone());
308            ident(name).into()
309        }
310    }
311
312    pub fn force_variable_and_name(
313        &mut self,
314        name: impl Into<String>,
315        exp: &Arc<RValue>,
316    ) -> Arc<RValue> {
317        let name = name.into();
318        if let RValue::Identifier(id) = exp.as_ref() {
319            if name == id.0 {
320                return exp.clone();
321            }
322        }
323        let name = self.scoped_id(name);
324        self.assignment(name.clone(), exp.clone());
325        ident(name).into()
326    }
327
328    pub fn konst(
329        &mut self,
330        name: impl AsRef<str>,
331        tensor: &Arc<Tensor>,
332    ) -> TractResult<Arc<RValue>> {
333        self.do_konst(name, tensor, false)
334    }
335
336    pub fn konst_variable(
337        &mut self,
338        name: impl AsRef<str>,
339        tensor: &Arc<Tensor>,
340    ) -> TractResult<Arc<RValue>> {
341        self.do_konst(name, tensor, true)
342    }
343
344    fn dump_rec_tensor<T: Datum>(
345        t: &ArrayViewD<T>,
346        el: impl for<'t> Fn(&'t T) -> RValue + Copy,
347    ) -> RValue {
348        if t.ndim() == 0 {
349            el(&t.as_slice().unwrap()[0])
350        } else {
351            let values: TVec<RValue> = (0..t.shape()[0])
352                .map(|i| Self::dump_rec_tensor(&t.index_axis(Axis(0), i), el))
353                .collect();
354            array(values)
355        }
356    }
357
358    fn do_konst(
359        &mut self,
360        name: impl AsRef<str>,
361        tensor: &Arc<Tensor>,
362        force_variable: bool,
363    ) -> TractResult<Arc<RValue>> {
364        let mut name: Identifier = name.as_ref().into();
365        let have_tract_core = self.ensure_registry(&"tract_core".into()).is_ok();
366        if tensor.datum_type() == TDim::datum_type() {
367            return Ok(Self::dump_rec_tensor(&tensor.to_array_view::<TDim>()?, tdim).into());
368        }
369        if !force_variable && tensor.len() <= 8 {
370            if tensor.datum_type() == String::datum_type() {
371                return Ok(Self::dump_rec_tensor(&tensor.to_array_view::<String>()?, |f| {
372                    string(f)
373                })
374                .into());
375            } else if tensor.datum_type() == DatumType::F32 {
376                return Ok(
377                    Self::dump_rec_tensor(&tensor.to_array_view::<f32>()?, |f| numeric(f)).into()
378                );
379            } else if have_tract_core && tensor.datum_type() == DatumType::F16 {
380                let array =
381                    Self::dump_rec_tensor(&tensor.to_array_view::<f16>()?, |f| numeric(f)).into();
382                return Ok(invocation("tract_core_cast", &[array], &[("to", string("f16"))]));
383            } else if have_tract_core && tensor.datum_type().is_integer() {
384                if let Ok(value) = tensor.cast_to::<i64>() {
385                    let value =
386                        Self::dump_rec_tensor(&value.to_array_view::<i64>().unwrap(), |i| {
387                            numeric(i)
388                        });
389                    let to = string(format!("{:?}", tensor.datum_type()).to_lowercase());
390                    return Ok(invocation("tract_core_cast", &[value.into()], &[("to", to)]));
391                }
392            };
393        }
394
395        if self.tensors.contains_key(&name) {
396            name = (0..)
397                .map(|it| Identifier::from(&*format!("{}_{}", name.0, it)))
398                .find(|it| !self.tensors.contains_key(it))
399                .unwrap();
400        }
401
402        self.tensors.insert(name.clone(), tensor.clone());
403        let id = self.scoped_id(&name);
404        let shape = if tensor.datum_type().is_opaque() {
405            if let Some(bqv) = tensor.to_scalar::<Opaque>()?.downcast_ref::<BlockQuantValue>() {
406                bqv.fact.shape()
407            } else {
408                bail!("Unexpected opaque tensor in serialization {tensor:?}");
409            }
410        } else {
411            tensor.shape()
412        };
413        self.assignment(
414            id.clone(),
415            RValue::Invocation(Invocation {
416                id: "variable".into(),
417                generic_type_name: Some(TypeName::Scalar),
418                arguments: vec![
419                    named_arg("label", string(name.0)),
420                    named_arg("shape", ints(shape)),
421                ],
422            })
423            .into(),
424        );
425        if let Some(qp) = QuantFormat::from_dt(tensor.datum_type()) {
426            self.quantization.insert(id.clone(), qp);
427        }
428        Ok(ident(id).into())
429    }
430
431    fn assignment(&mut self, name: impl AsRef<str>, right: Arc<RValue>) {
432        let name = name.as_ref();
433        if *right == ident(name) {
434            return;
435        }
436        self.body.push(assignment(name, right))
437    }
438}
439
440pub fn assignment(name: impl AsRef<str>, right: Arc<RValue>) -> Assignment {
441    Assignment { left: LValue::Identifier(name.as_ref().into()), right: right.as_ref().to_owned() }
442}
443
444pub fn ints(shape: &[usize]) -> RValue {
445    RValue::Array(shape.iter().map(|s| RValue::Literal(Literal::Numeric(s.to_string()))).collect())
446}
447
448pub fn tdims(shape: &[TDim]) -> RValue {
449    RValue::Array(shape.iter().map(tdim).collect())
450}
451
452pub fn tdim(dim: &TDim) -> RValue {
453    match dim {
454        TDim::Val(x) => numeric(x),
455        TDim::Sym(s) => ident(s.to_string()),
456        TDim::Add(terms) => terms
457            .iter()
458            .map(tdim)
459            .reduce(|x, y| RValue::Binary(x.boxed(), "+".to_string(), y.boxed()))
460            .unwrap(),
461        TDim::Mul(terms) => terms
462            .iter()
463            .map(tdim)
464            .reduce(|x, y| RValue::Binary(x.boxed(), "*".to_string(), y.boxed()))
465            .unwrap(),
466        TDim::MulInt(x, y) => RValue::Binary(numeric(x).boxed(), "*".to_string(), tdim(y).boxed()),
467        TDim::Div(x, y) => RValue::Binary(tdim(x).boxed(), "/".to_string(), numeric(y).boxed()),
468        TDim::Broadcast(_) => todo!(),
469        TDim::Min(_) | TDim::Max(_) => todo!(),
470    }
471}
472
473pub fn string(s: impl AsRef<str>) -> RValue {
474    RValue::Literal(Literal::String(s.as_ref().into()))
475}
476
477pub fn datum_type(dt: DatumType) -> RValue {
478    string(format!("{:?}", dt.unquantized()).to_lowercase())
479}
480
481pub fn logical(b: bool) -> RValue {
482    RValue::Literal(Literal::Logical(b))
483}
484
485pub fn lident(s: impl AsRef<str>) -> LValue {
486    LValue::Identifier(s.as_ref().into())
487}
488
489pub fn ident(s: impl AsRef<str>) -> RValue {
490    RValue::Identifier(s.as_ref().into())
491}
492
493pub fn array(items: impl AsRef<[RValue]>) -> RValue {
494    RValue::Array(items.as_ref().to_vec())
495}
496
497pub fn tuple_2(a: RValue, b: RValue) -> RValue {
498    RValue::Tuple(vec![a, b])
499}
500
501pub fn tuple_3(a: RValue, b: RValue, c: RValue) -> RValue {
502    RValue::Tuple(vec![a, b, c])
503}
504
505pub fn tuple_4(a: RValue, b: RValue, c: RValue, d: RValue) -> RValue {
506    RValue::Tuple(vec![a, b, c, d])
507}
508
509pub fn numeric<D: std::fmt::Debug>(num: D) -> RValue {
510    RValue::Literal(Literal::Numeric(format!("{num:?}")))
511}
512
513pub fn named_arg(id: &str, rv: RValue) -> Argument {
514    Argument { id: Some(id.into()), rvalue: rv }
515}
516
517pub fn invocation(
518    id: impl AsRef<str>,
519    positional: &[Arc<RValue>],
520    named: &[(&str, RValue)],
521) -> Arc<RValue> {
522    let arguments = positional
523        .iter()
524        .map(|rv| Argument { id: None, rvalue: rv.as_ref().clone() })
525        .chain(named.iter().map(|(n, v)| named_arg(n, v.clone())))
526        .collect();
527    RValue::Invocation(Invocation { id: id.as_ref().into(), generic_type_name: None, arguments })
528        .into()
529}