tract_nnef/
ast.rs

1use crate::internal::*;
2use tract_itertools::Itertools;
3
4pub mod dump;
5pub mod dump_doc;
6pub mod parse;
7pub mod quant;
8
/// A parsed NNEF model: the graph document together with its tensor data,
/// optional quantization annotations and extra named resources.
#[derive(Clone, Debug)]
pub struct ProtoModel {
    /// The parsed document (fragment definitions and main graph).
    pub doc: Document,
    /// Tensor data keyed by identifier.
    pub tensors: HashMap<Identifier, Arc<Tensor>>,
    /// Quantization format per identifier, when quantization information is present.
    pub quantization: Option<HashMap<Identifier, QuantFormat>>,
    /// Additional resources attached to the model, keyed by name.
    pub resources: HashMap<String, Arc<dyn Resource>>,
}
16
17impl ProtoModel {
18    pub fn validate(&self) -> TractResult<()> {
19        self.doc.validate()
20    }
21}
22
/// Quantization encoding attached to a tensor.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum QuantFormat {
    /// Linear quantization with parameters `params`, stored on `bits` bits,
    /// as signed or unsigned integers according to `signed`.
    Linear { params: QParams, bits: i8, signed: bool },
}
27
28impl QuantFormat {
29    pub fn from_dt(datum_type: DatumType) -> Option<QuantFormat> {
30        if let Some(params) = datum_type.qparams() {
31            let quant_format = QuantFormat::Linear {
32                params,
33                bits: 8 * datum_type.size_of() as i8,
34                signed: datum_type.is_signed(),
35            };
36            Some(quant_format)
37        } else {
38            None
39        }
40    }
41
42    pub fn datum_type(&self) -> DatumType {
43        match self {
44            QuantFormat::Linear { params, bits, signed } => match (bits, signed) {
45                (8, true) => DatumType::QI8(*params),
46                (8, false) => DatumType::QU8(*params),
47                (32, true) => DatumType::QI32(*params),
48                (32, false) => DatumType::U32,
49                _ => todo!(),
50            },
51        }
52    }
53}
54
/// A whole NNEF document: version, extension declarations, fragment
/// definitions and the main graph.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct Document {
    /// Declared version, kept as literal text (`NumericLiteral` is a `String`).
    pub version: NumericLiteral,
    /// Extension declarations as (identifier, string) pairs.
    /// NOTE(review): presumably extension name and its raw text — confirm against the parser.
    pub extension: Vec<(Identifier, String)>,
    /// Fragment definitions declared by the document.
    pub fragments: Vec<FragmentDef>,
    /// The main graph definition.
    pub graph_def: GraphDef,
}
62
63impl Document {
64    pub fn validate(&self) -> TractResult<()> {
65        for frag in &self.fragments {
66            frag.validate()?;
67        }
68        Ok(())
69    }
70}
71
/// Type specification used for fragment parameters and results.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum TypeSpec {
    /// A single value of the given type (as opposed to `Tensor`).
    Single(TypeName),
    /// A tensor whose elements have the given type.
    Tensor(TypeName),
    /// An array of the inner specification.
    Array(Box<TypeSpec>),
    /// A tuple of specifications.
    Tuple(Vec<TypeSpec>),
}
79
80impl TypeSpec {
81    pub fn array(self) -> TypeSpec {
82        TypeSpec::Array(Box::new(self))
83    }
84    pub fn named(self, s: impl AsRef<str>) -> Parameter {
85        Parameter { id: s.as_ref().into(), spec: self, lit: None, doc: None }
86    }
87}
88
/// Primitive type names usable in type specifications and generic invocations.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum TypeName {
    Integer,
    Scalar,
    /// Only available when the crate is built with the `complex` feature.
    #[cfg(feature = "complex")]
    Complex,
    Logical,
    String,
    Any,
}
99
100impl TypeName {
101    pub fn tensor(self) -> TypeSpec {
102        TypeSpec::Tensor(self)
103    }
104    pub fn spec(self) -> TypeSpec {
105        TypeSpec::Single(self)
106    }
107    pub fn array(self) -> TypeSpec {
108        self.spec().array()
109    }
110    pub fn named(self, s: impl AsRef<str>) -> Parameter {
111        self.spec().named(s)
112    }
113}
114
/// The main graph of a document: its name, inputs, outputs and body.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct GraphDef {
    /// Graph name.
    pub id: Identifier,
    /// Identifiers of the graph inputs.
    pub parameters: Vec<Identifier>,
    /// Identifiers of the graph outputs.
    pub results: Vec<Identifier>,
    /// Statements making up the graph.
    pub body: Vec<Assignment>,
}

/// A fragment: its declaration and, optionally, a body of statements.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct FragmentDef {
    /// Signature of the fragment.
    pub decl: FragmentDecl,
    /// Statements of the fragment, if it has a body.
    /// NOTE(review): `None` presumably means a declaration-only (primitive) fragment — confirm.
    pub body: Option<Vec<Assignment>>,
}
128
129impl FragmentDef {
130    pub fn validate(&self) -> TractResult<()> {
131        self.decl.validate().with_context(|| format!("Invalid fragment {:?}", self.decl.id))
132    }
133}
134
/// The signature of a fragment: name, generic declaration, parameters and results.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct FragmentDecl {
    /// Fragment name.
    pub id: Identifier,
    /// Generic type declaration.
    /// NOTE(review): presumably outer `None` = not generic, `Some(None)` = generic
    /// without a default type, `Some(Some(t))` = generic defaulting to `t` — confirm
    /// against the parser.
    pub generic_decl: Option<Option<TypeName>>,
    /// Declared parameters, in order.
    pub parameters: Vec<Parameter>,
    /// Declared results, in order.
    pub results: Vec<Result_>,
}
142
143impl FragmentDecl {
144    pub fn validate(&self) -> TractResult<()> {
145        if let Some(dup) = self
146            .parameters
147            .iter()
148            .map(|p| &p.id)
149            .sorted()
150            .group_by(|x| x.to_owned())
151            .into_iter()
152            .find_map(|(key, values)| if values.count() > 1 { Some(key) } else { None })
153        {
154            bail!("Duplicate parameter name found {:?}", dup);
155        }
156        if let Some(dup) = self
157            .results
158            .iter()
159            .map(|p| &p.id)
160            .sorted()
161            .group_by(|x| x.to_owned())
162            .into_iter()
163            .find_map(|(key, values)| if values.count() > 1 { Some(key) } else { None })
164        {
165            bail!("Duplicate result name found {:?}", dup);
166        }
167        if let Some(dup) = self
168            .parameters
169            .iter()
170            .map(|p| &p.id)
171            .chain(self.results.iter().map(|p| &p.id))
172            .sorted()
173            .group_by(|x| x.to_owned())
174            .into_iter()
175            .find_map(|(key, values)| if values.count() > 1 { Some(key) } else { None })
176        {
177            bail!("Same name used as parameter and result {:?}", dup);
178        }
179        Ok(())
180    }
181}
182
/// A declared fragment parameter.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct Parameter {
    /// Parameter name.
    pub id: Identifier,
    /// Parameter type.
    pub spec: TypeSpec,
    /// Default value, if any (set by `Parameter::default`).
    pub lit: Option<Literal>,
    /// Optional documentation text (set by `Parameter::doc`).
    pub doc: Option<String>,
}
190
191impl Parameter {
192    pub fn default(self, lit: impl Into<Literal>) -> Parameter {
193        Parameter { lit: Some(lit.into()), ..self }
194    }
195
196    pub fn doc(mut self, s: impl Into<String>) -> Parameter {
197        self.doc = Some(s.into());
198        self
199    }
200}
201
202pub fn param(s: impl AsRef<str>, spec: TypeSpec) -> Parameter {
203    Parameter { id: s.as_ref().into(), spec, lit: None, doc: None }
204}
205
/// A declared fragment result: its name and type.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct Result_ {
    /// Result name.
    pub id: Identifier,
    /// Result type.
    pub spec: TypeSpec,
}
211
212impl<S: Into<String>> From<(S, TypeSpec)> for Result_ {
213    fn from(v: (S, TypeSpec)) -> Result_ {
214        Result_ { id: Identifier(v.0.into()), spec: v.1 }
215    }
216}
217
/// A name in the AST (fragment, parameter, result or value), stored as owned text.
/// Ordering and hashing are those of the underlying `String`.
#[derive(Clone, Debug, PartialEq, Eq, Default, Ord, PartialOrd, Hash)]
pub struct Identifier(pub String);
220
221impl From<&str> for Identifier {
222    fn from(value: &str) -> Self {
223        Identifier(value.to_string())
224    }
225}
226
227impl AsRef<str> for Identifier {
228    fn as_ref(&self) -> &str {
229        &self.0
230    }
231}
232
/// A single statement `left = right` in a graph or fragment body.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct Assignment {
    /// The assignment target(s).
    pub left: LValue,
    /// The expression being assigned.
    pub right: RValue,
}

/// The target side of an assignment: a name, or a (possibly nested) array
/// or tuple of targets.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum LValue {
    Identifier(Identifier),
    Array(Vec<LValue>),
    Tuple(Vec<LValue>),
}

/// A call to a fragment.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct Invocation {
    /// Name of the invoked fragment.
    pub id: Identifier,
    /// Explicit generic type argument, if one was given.
    pub generic_type_name: Option<TypeName>,
    /// Call arguments, in order.
    pub arguments: Vec<Argument>,
}

/// One argument of an [`Invocation`].
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct Argument {
    /// Parameter name for a named argument; `None` presumably for a positional one.
    pub id: Option<Identifier>,
    /// The argument value.
    pub rvalue: RValue,
}

/// An expression.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum RValue {
    /// Reference to a named value.
    Identifier(Identifier),
    /// A literal constant.
    Literal(Literal),
    /// Binary operation: left operand, operator text, right operand.
    Binary(Box<RValue>, String, Box<RValue>),
    /// Unary operation: operator text, operand.
    Unary(String, Box<RValue>),
    Tuple(Vec<RValue>),
    Array(Vec<RValue>),
    /// Indexing or slicing of the first expression by the second.
    Subscript(Box<RValue>, Box<Subscript>),
    Comprehension(Box<Comprehension>),
    IfThenElse(Box<IfThenElse>),
    Invocation(Invocation),
}
272
273impl RValue {
274    pub fn boxed(self) -> Box<RValue> {
275        Box::new(self)
276    }
277}
278
/// A comprehension expression: loop over values, optionally filter, and yield
/// one expression per retained iteration.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct Comprehension {
    /// Loop variables, each paired with the expression it iterates over.
    pub loop_iters: Vec<(Identifier, RValue)>,
    /// Optional condition selecting which iterations yield.
    pub filter: Option<RValue>,
    /// Expression produced for each retained iteration.
    pub yields: RValue,
}

/// The index part of an [`RValue::Subscript`].
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Subscript {
    /// A single index.
    Single(RValue),
    /// A range with optional bounds (presumably start, then end).
    Range(Option<RValue>, Option<RValue>),
}

/// A conditional expression selecting `then` or `otherwise` based on `cond`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct IfThenElse {
    pub cond: RValue,
    pub then: RValue,
    pub otherwise: RValue,
}

/// A literal constant.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Literal {
    /// A number, kept in textual form.
    Numeric(NumericLiteral),
    String(StringLiteral),
    Logical(LogicalLiteral),
    Array(Vec<Literal>),
    Tuple(Vec<Literal>),
}
307
308impl From<bool> for Literal {
309    fn from(b: bool) -> Literal {
310        Literal::Logical(b)
311    }
312}
313
314impl From<i64> for Literal {
315    fn from(i: i64) -> Literal {
316        Literal::Numeric(i.to_string())
317    }
318}
319
320impl From<f32> for Literal {
321    fn from(f: f32) -> Literal {
322        Literal::Numeric(format!("{f:?}"))
323    }
324}
325
326impl<'a> From<&'a str> for Literal {
327    fn from(s: &'a str) -> Literal {
328        Literal::String(s.to_string())
329    }
330}
331
/// Numeric literals are stored as text rather than as parsed numbers.
pub type NumericLiteral = String;
/// Text of a string literal.
pub type StringLiteral = String;
/// Value of a logical (boolean) literal.
pub type LogicalLiteral = bool;