1use crate::internal::*;
2use tract_itertools::Itertools;
3
4pub mod dump;
5pub mod dump_doc;
6pub mod parse;
7pub mod quant;
8
9#[derive(Clone, Debug)]
10pub struct ProtoModel {
11 pub doc: Document,
12 pub tensors: HashMap<Identifier, Arc<Tensor>>,
13 pub quantization: Option<HashMap<Identifier, QuantFormat>>,
14 pub resources: HashMap<String, Arc<dyn Resource>>,
15}
16
17impl ProtoModel {
18 pub fn validate(&self) -> TractResult<()> {
19 self.doc.validate()
20 }
21}
22
23#[derive(Clone, Debug, PartialEq, Eq)]
24pub enum QuantFormat {
25 Linear { params: QParams, bits: i8, signed: bool },
26}
27
28impl QuantFormat {
29 pub fn from_dt(datum_type: DatumType) -> Option<QuantFormat> {
30 if let Some(params) = datum_type.qparams() {
31 let quant_format = QuantFormat::Linear {
32 params,
33 bits: 8 * datum_type.size_of() as i8,
34 signed: datum_type.is_signed(),
35 };
36 Some(quant_format)
37 } else {
38 None
39 }
40 }
41
42 pub fn datum_type(&self) -> DatumType {
43 match self {
44 QuantFormat::Linear { params, bits, signed } => match (bits, signed) {
45 (8, true) => DatumType::QI8(*params),
46 (8, false) => DatumType::QU8(*params),
47 (32, true) => DatumType::QI32(*params),
48 (32, false) => DatumType::U32,
49 _ => todo!(),
50 },
51 }
52 }
53}
54
55#[derive(Clone, Debug, PartialEq, Eq)]
56pub struct Document {
57 pub version: NumericLiteral,
58 pub extension: Vec<(Identifier, String)>,
59 pub fragments: Vec<FragmentDef>,
60 pub graph_def: GraphDef,
61}
62
63impl Document {
64 pub fn validate(&self) -> TractResult<()> {
65 for frag in &self.fragments {
66 frag.validate()?;
67 }
68 Ok(())
69 }
70}
71
72#[derive(Clone, Debug, PartialEq, Eq)]
73pub enum TypeSpec {
74 Single(TypeName),
75 Tensor(TypeName),
76 Array(Box<TypeSpec>),
77 Tuple(Vec<TypeSpec>),
78}
79
80impl TypeSpec {
81 pub fn array(self) -> TypeSpec {
82 TypeSpec::Array(Box::new(self))
83 }
84 pub fn named(self, s: impl AsRef<str>) -> Parameter {
85 Parameter { id: s.as_ref().into(), spec: self, lit: None, doc: None }
86 }
87}
88
89#[derive(Clone, Copy, Debug, PartialEq, Eq)]
90pub enum TypeName {
91 Integer,
92 Scalar,
93 #[cfg(feature = "complex")]
94 Complex,
95 Logical,
96 String,
97 Any,
98}
99
100impl TypeName {
101 pub fn tensor(self) -> TypeSpec {
102 TypeSpec::Tensor(self)
103 }
104 pub fn spec(self) -> TypeSpec {
105 TypeSpec::Single(self)
106 }
107 pub fn array(self) -> TypeSpec {
108 self.spec().array()
109 }
110 pub fn named(self, s: impl AsRef<str>) -> Parameter {
111 self.spec().named(s)
112 }
113}
114
115#[derive(Clone, Debug, PartialEq, Eq)]
116pub struct GraphDef {
117 pub id: Identifier,
118 pub parameters: Vec<Identifier>,
119 pub results: Vec<Identifier>,
120 pub body: Vec<Assignment>,
121}
122
123#[derive(Clone, Debug, PartialEq, Eq)]
124pub struct FragmentDef {
125 pub decl: FragmentDecl,
126 pub body: Option<Vec<Assignment>>,
127}
128
129impl FragmentDef {
130 pub fn validate(&self) -> TractResult<()> {
131 self.decl.validate().with_context(|| format!("Invalid fragment {:?}", self.decl.id))
132 }
133}
134
135#[derive(Clone, Debug, PartialEq, Eq)]
136pub struct FragmentDecl {
137 pub id: Identifier,
138 pub generic_decl: Option<Option<TypeName>>,
139 pub parameters: Vec<Parameter>,
140 pub results: Vec<Result_>,
141}
142
143impl FragmentDecl {
144 pub fn validate(&self) -> TractResult<()> {
145 if let Some(dup) = self
146 .parameters
147 .iter()
148 .map(|p| &p.id)
149 .sorted()
150 .group_by(|x| x.to_owned())
151 .into_iter()
152 .find_map(|(key, values)| if values.count() > 1 { Some(key) } else { None })
153 {
154 bail!("Duplicate parameter name found {:?}", dup);
155 }
156 if let Some(dup) = self
157 .results
158 .iter()
159 .map(|p| &p.id)
160 .sorted()
161 .group_by(|x| x.to_owned())
162 .into_iter()
163 .find_map(|(key, values)| if values.count() > 1 { Some(key) } else { None })
164 {
165 bail!("Duplicate result name found {:?}", dup);
166 }
167 if let Some(dup) = self
168 .parameters
169 .iter()
170 .map(|p| &p.id)
171 .chain(self.results.iter().map(|p| &p.id))
172 .sorted()
173 .group_by(|x| x.to_owned())
174 .into_iter()
175 .find_map(|(key, values)| if values.count() > 1 { Some(key) } else { None })
176 {
177 bail!("Same name used as parameter and result {:?}", dup);
178 }
179 Ok(())
180 }
181}
182
183#[derive(Clone, Debug, PartialEq, Eq)]
184pub struct Parameter {
185 pub id: Identifier,
186 pub spec: TypeSpec,
187 pub lit: Option<Literal>,
188 pub doc: Option<String>,
189}
190
191impl Parameter {
192 pub fn default(self, lit: impl Into<Literal>) -> Parameter {
193 Parameter { lit: Some(lit.into()), ..self }
194 }
195
196 pub fn doc(mut self, s: impl Into<String>) -> Parameter {
197 self.doc = Some(s.into());
198 self
199 }
200}
201
202pub fn param(s: impl AsRef<str>, spec: TypeSpec) -> Parameter {
203 Parameter { id: s.as_ref().into(), spec, lit: None, doc: None }
204}
205
206#[derive(Clone, Debug, PartialEq, Eq)]
207pub struct Result_ {
208 pub id: Identifier,
209 pub spec: TypeSpec,
210}
211
212impl<S: Into<String>> From<(S, TypeSpec)> for Result_ {
213 fn from(v: (S, TypeSpec)) -> Result_ {
214 Result_ { id: Identifier(v.0.into()), spec: v.1 }
215 }
216}
217
/// Name used for fragments, parameters, results and variables.
#[derive(Debug, Clone, Default, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub struct Identifier(pub String);

impl From<&str> for Identifier {
    fn from(value: &str) -> Self {
        Identifier(String::from(value))
    }
}

impl AsRef<str> for Identifier {
    fn as_ref(&self) -> &str {
        self.0.as_str()
    }
}
232
233#[derive(Clone, Debug, PartialEq, Eq)]
234pub struct Assignment {
235 pub left: LValue,
236 pub right: RValue,
237}
238
239#[derive(Clone, Debug, PartialEq, Eq)]
240pub enum LValue {
241 Identifier(Identifier),
242 Array(Vec<LValue>),
243 Tuple(Vec<LValue>),
244}
245
246#[derive(Clone, Debug, PartialEq, Eq)]
247pub struct Invocation {
248 pub id: Identifier,
249 pub generic_type_name: Option<TypeName>,
250 pub arguments: Vec<Argument>,
251}
252
253#[derive(Clone, Debug, PartialEq, Eq)]
254pub struct Argument {
255 pub id: Option<Identifier>,
256 pub rvalue: RValue,
257}
258
259#[derive(Clone, Debug, PartialEq, Eq)]
260pub enum RValue {
261 Identifier(Identifier),
262 Literal(Literal),
263 Binary(Box<RValue>, String, Box<RValue>),
264 Unary(String, Box<RValue>),
265 Tuple(Vec<RValue>),
266 Array(Vec<RValue>),
267 Subscript(Box<RValue>, Box<Subscript>),
268 Comprehension(Box<Comprehension>),
269 IfThenElse(Box<IfThenElse>),
270 Invocation(Invocation),
271}
272
273impl RValue {
274 pub fn boxed(self) -> Box<RValue> {
275 Box::new(self)
276 }
277}
278
279#[derive(Clone, Debug, PartialEq, Eq)]
280pub struct Comprehension {
281 pub loop_iters: Vec<(Identifier, RValue)>,
282 pub filter: Option<RValue>,
283 pub yields: RValue,
284}
285
286#[derive(Clone, Debug, PartialEq, Eq)]
287pub enum Subscript {
288 Single(RValue),
289 Range(Option<RValue>, Option<RValue>),
290}
291
292#[derive(Clone, Debug, PartialEq, Eq)]
293pub struct IfThenElse {
294 pub cond: RValue,
295 pub then: RValue,
296 pub otherwise: RValue,
297}
298
/// Literal value.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Literal {
    Numeric(NumericLiteral),
    String(StringLiteral),
    Logical(LogicalLiteral),
    Array(Vec<Literal>),
    Tuple(Vec<Literal>),
}

impl From<bool> for Literal {
    fn from(value: bool) -> Self {
        Self::Logical(value)
    }
}

impl From<i64> for Literal {
    fn from(value: i64) -> Self {
        let text = value.to_string();
        Self::Numeric(text)
    }
}

impl From<f32> for Literal {
    /// Uses the `Debug` form, e.g. `1.0f32` becomes `"1.0"` where `Display` would give `"1"`.
    fn from(value: f32) -> Self {
        Self::Numeric(format!("{value:?}"))
    }
}

impl From<&str> for Literal {
    fn from(value: &str) -> Self {
        Self::String(value.to_owned())
    }
}

/// Numeric literal, kept as text.
pub type NumericLiteral = String;
/// String literal contents.
pub type StringLiteral = String;
/// Logical (boolean) literal.
pub type LogicalLiteral = bool;