Skip to main content

tract_api/
lib.rs

1use anyhow::{Result, ensure};
2use boow::Bow;
3use std::fmt::{Debug, Display};
4use std::path::Path;
5
6#[macro_use]
7pub mod macros;
8
9/// an implementation of tract's NNEF framework object
10///
11/// Entry point for NNEF model manipulation: loading from file, dumping to file.
12pub trait NnefInterface: Sized {
13    type Model: ModelInterface;
14    /// Load a NNEF model from the path into a tract-core model.
15    ///
16    /// * `path` can point to a directory, a `tar` file or a `tar.gz` file.
17    fn load(&self, path: impl AsRef<Path>) -> Result<Self::Model>;
18
19    /// Load a NNEF model from a buffer into a tract-core model.
20    ///
21    /// data is the content of a NNEF model, as a `tar` file or a `tar.gz` file.
22    fn load_buffer(&self, data: &[u8]) -> Result<Self::Model>;
23
24    /// Allow the framework to use tract_core extensions instead of a stricter NNEF definition.
25    fn enable_tract_core(&mut self) -> Result<()>;
26
27    /// Allow the framework to use tract_extra extensions.
28    fn enable_tract_extra(&mut self) -> Result<()>;
29
30    /// Allow the framework to use tract_transformers extensions to support common transformer operators.
31    fn enable_tract_transformers(&mut self) -> Result<()>;
32
33    /// Allow the framework to use tract_onnx extensions to support operators in ONNX that are
34    /// absent from NNEF.
35    fn enable_onnx(&mut self) -> Result<()>;
36
37    /// Allow the framework to use tract_pulse extensions to support stateful streaming operation.
38    fn enable_pulse(&mut self) -> Result<()>;
39
40    /// Allow the framework to use a tract-proprietary extension that can support special characters
41    /// in node names. If disable, tract will replace everything by underscore '_' to keep
42    /// compatibility with NNEF. If enabled, the extended syntax will be used, allowing to maintain
43    /// the node names in serialized form.
44    fn enable_extended_identifier_syntax(&mut self) -> Result<()>;
45
46    /// Convenience function, similar with enable_tract_core but allowing method chaining.
47    fn with_tract_core(mut self) -> Result<Self> {
48        self.enable_tract_core()?;
49        Ok(self)
50    }
51
52    /// Convenience function, similar with enable_tract_core but allowing method chaining.
53    fn with_tract_extra(mut self) -> Result<Self> {
54        self.enable_tract_extra()?;
55        Ok(self)
56    }
57
58    /// Convenience function, similar with enable_tract_transformers but allowing method chaining.
59    fn with_tract_transformers(mut self) -> Result<Self> {
60        self.enable_tract_transformers()?;
61        Ok(self)
62    }
63
64    /// Convenience function, similar with enable_onnx but allowing method chaining.
65    fn with_onnx(mut self) -> Result<Self> {
66        self.enable_onnx()?;
67        Ok(self)
68    }
69
70    /// Convenience function, similar with enable_pulse but allowing method chaining.
71    fn with_pulse(mut self) -> Result<Self> {
72        self.enable_pulse()?;
73        Ok(self)
74    }
75
76    /// Convenience function, similar with enable_extended_identifier_syntax but allowing method chaining.
77    fn with_extended_identifier_syntax(mut self) -> Result<Self> {
78        self.enable_extended_identifier_syntax()?;
79        Ok(self)
80    }
81
82    /// Dump a TypedModel as a NNEF directory.
83    ///
84    /// `path` is the directory name to dump to
85    fn write_model_to_dir(&self, path: impl AsRef<Path>, model: &Self::Model) -> Result<()>;
86
87    /// Dump a TypedModel as a NNEF tar file.
88    ///
89    /// This function creates a plain, non-compressed, archive.
90    ///
91    /// `path` is the archive name
92    fn write_model_to_tar(&self, path: impl AsRef<Path>, model: &Self::Model) -> Result<()>;
93    fn write_model_to_tar_gz(&self, path: impl AsRef<Path>, model: &Self::Model) -> Result<()>;
94}
95
96pub trait OnnxInterface {
97    type InferenceModel: InferenceModelInterface;
98    fn load(&self, path: impl AsRef<Path>) -> Result<Self::InferenceModel>;
99    /// Load a ONNX model from a buffer into an InferenceModel.
100    fn load_buffer(&self, data: &[u8]) -> Result<Self::InferenceModel>;
101}
102
103pub trait InferenceModelInterface: Sized {
104    type Model: ModelInterface;
105    type InferenceFact: InferenceFactInterface;
106    fn set_output_names(
107        &mut self,
108        outputs: impl IntoIterator<Item = impl AsRef<str>>,
109    ) -> Result<()>;
110    fn input_count(&self) -> Result<usize>;
111    fn output_count(&self) -> Result<usize>;
112    fn input_name(&self, id: usize) -> Result<String>;
113    fn output_name(&self, id: usize) -> Result<String>;
114
115    fn input_fact(&self, id: usize) -> Result<Self::InferenceFact>;
116
117    fn set_input_fact(
118        &mut self,
119        id: usize,
120        fact: impl AsFact<Self, Self::InferenceFact>,
121    ) -> Result<()>;
122
123    fn output_fact(&self, id: usize) -> Result<Self::InferenceFact>;
124
125    fn set_output_fact(
126        &mut self,
127        id: usize,
128        fact: impl AsFact<Self, Self::InferenceFact>,
129    ) -> Result<()>;
130
131    fn analyse(&mut self) -> Result<()>;
132
133    fn into_tract(self) -> Result<Self::Model>;
134}
135
136pub trait ModelInterface: Sized {
137    type Fact: FactInterface;
138    type Runnable: RunnableInterface;
139    type Value: ValueInterface;
140    fn input_count(&self) -> Result<usize>;
141
142    fn output_count(&self) -> Result<usize>;
143
144    fn input_name(&self, id: usize) -> Result<String>;
145
146    fn output_name(&self, id: usize) -> Result<String>;
147
148    fn set_output_names(
149        &mut self,
150        outputs: impl IntoIterator<Item = impl AsRef<str>>,
151    ) -> Result<()>;
152
153    fn input_fact(&self, id: usize) -> Result<Self::Fact>;
154
155    fn output_fact(&self, id: usize) -> Result<Self::Fact>;
156
157    fn into_runnable(self) -> Result<Self::Runnable>;
158
159    fn concretize_symbols(
160        &mut self,
161        values: impl IntoIterator<Item = (impl AsRef<str>, i64)>,
162    ) -> Result<()>;
163
164    fn transform(&mut self, transform: &str) -> Result<()>;
165
166    fn pulse(&mut self, name: impl AsRef<str>, value: impl AsRef<str>) -> Result<()>;
167
168    fn property_keys(&self) -> Result<Vec<String>>;
169
170    fn property(&self, name: impl AsRef<str>) -> Result<Self::Value>;
171
172    fn parse_fact(&self, spec: &str) -> Result<Self::Fact>;
173
174    fn input_facts(&self) -> Result<impl Iterator<Item = Self::Fact>> {
175        Ok((0..self.input_count()?)
176            .map(|ix| self.input_fact(ix))
177            .collect::<Result<Vec<_>>>()?
178            .into_iter())
179    }
180
181    fn output_facts(&self) -> Result<impl Iterator<Item = Self::Fact>> {
182        Ok((0..self.output_count()?)
183            .map(|ix| self.output_fact(ix))
184            .collect::<Result<Vec<_>>>()?
185            .into_iter())
186    }
187}
188
189pub trait RuntimeInterface {
190    type Runnable: RunnableInterface;
191    type Model: ModelInterface;
192    fn prepare(&self, model: Self::Model) -> Result<Self::Runnable>;
193}
194
195pub trait RunnableInterface: Send + Sync {
196    type Value: ValueInterface;
197    type Fact: FactInterface;
198    type State: StateInterface<Value = Self::Value>;
199    fn run<I, V, E>(&self, inputs: I) -> Result<Vec<Self::Value>>
200    where
201        I: IntoIterator<Item = V>,
202        V: TryInto<Self::Value, Error = E>,
203        E: Into<anyhow::Error>,
204    {
205        self.spawn_state()?.run(inputs)
206    }
207
208    fn input_count(&self) -> Result<usize>;
209    fn output_count(&self) -> Result<usize>;
210    fn input_fact(&self, id: usize) -> Result<Self::Fact>;
211
212    fn output_fact(&self, id: usize) -> Result<Self::Fact>;
213
214    fn input_facts(&self) -> Result<impl Iterator<Item = Self::Fact>> {
215        Ok((0..self.input_count()?)
216            .map(|ix| self.input_fact(ix))
217            .collect::<Result<Vec<_>>>()?
218            .into_iter())
219    }
220
221    fn output_facts(&self) -> Result<impl Iterator<Item = Self::Fact>> {
222        Ok((0..self.output_count()?)
223            .map(|ix| self.output_fact(ix))
224            .collect::<Result<Vec<_>>>()?
225            .into_iter())
226    }
227
228    fn property_keys(&self) -> Result<Vec<String>>;
229    fn property(&self, name: impl AsRef<str>) -> Result<Self::Value>;
230
231    fn spawn_state(&self) -> Result<Self::State>;
232
233    fn cost_json(&self) -> Result<String>;
234
235    fn profile_json<I, IV, IE, S, SV, SE>(
236        &self,
237        inputs: Option<I>,
238        state_initializers: Option<S>,
239    ) -> Result<String>
240    where
241        I: IntoIterator<Item = IV>,
242        IV: TryInto<Self::Value, Error = IE>,
243        IE: Into<anyhow::Error> + Debug,
244        S: IntoIterator<Item = SV>,
245        SV: TryInto<Self::Value, Error = SE>,
246        SE: Into<anyhow::Error> + Debug;
247}
248
249pub trait StateInterface {
250    type Fact: FactInterface;
251    type Value: ValueInterface;
252
253    fn input_count(&self) -> Result<usize>;
254    fn output_count(&self) -> Result<usize>;
255
256    fn run<I, V, E>(&mut self, inputs: I) -> Result<Vec<Self::Value>>
257    where
258        I: IntoIterator<Item = V>,
259        V: TryInto<Self::Value, Error = E>,
260        E: Into<anyhow::Error>;
261
262    #[doc(hidden)]
263    #[deprecated]
264    fn initializable_states_count(&self) -> Result<usize>;
265
266    #[doc(hidden)]
267    #[deprecated]
268    fn get_states_facts(&self) -> Result<Vec<Self::Fact>>;
269
270    #[doc(hidden)]
271    #[deprecated]
272    fn set_states<I, V, E>(&mut self, state_initializers: I) -> Result<()>
273    where
274        I: IntoIterator<Item = V>,
275        V: TryInto<Self::Value, Error = E>,
276        E: Into<anyhow::Error> + Debug;
277
278    #[doc(hidden)]
279    #[deprecated]
280    fn get_states(&self) -> Result<Vec<Self::Value>>;
281}
282
283pub trait ValueInterface: Debug + Sized + Clone + PartialEq + Send + Sync {
284    fn datum_type(&self) -> Result<DatumType>;
285    fn from_bytes(dt: DatumType, shape: &[usize], data: &[u8]) -> Result<Self>;
286    fn as_bytes(&self) -> Result<(DatumType, &[usize], &[u8])>;
287
288    fn from_slice<T: Datum>(shape: &[usize], data: &[T]) -> Result<Self> {
289        let data = unsafe {
290            std::slice::from_raw_parts(data.as_ptr() as *const u8, std::mem::size_of_val(data))
291        };
292        Self::from_bytes(T::datum_type(), shape, data)
293    }
294
295    fn as_slice<T: Datum>(&self) -> Result<&[T]> {
296        let (dt, _shape, data) = self.as_bytes()?;
297        ensure!(T::datum_type() == dt);
298        let data = unsafe {
299            std::slice::from_raw_parts(
300                data.as_ptr() as *const T,
301                data.len() / std::mem::size_of::<T>(),
302            )
303        };
304        Ok(data)
305    }
306
307    fn as_shape_and_slice<T: Datum>(&self) -> Result<(&[usize], &[T])> {
308        let (_, shape, _) = self.as_bytes()?;
309        let data = self.as_slice()?;
310        Ok((shape, data))
311    }
312
313    fn shape(&self) -> Result<&[usize]> {
314        let (_, shape, _) = self.as_bytes()?;
315        Ok(shape)
316    }
317
318    fn view<T: Datum>(&self) -> Result<ndarray::ArrayViewD<'_, T>> {
319        let (shape, data) = self.as_shape_and_slice()?;
320        Ok(unsafe { ndarray::ArrayViewD::from_shape_ptr(shape, data.as_ptr()) })
321    }
322
323    fn convert_to(&self, to: DatumType) -> Result<Self>;
324}
325
326pub trait FactInterface: Debug + Display + Clone {
327    type Dim: DimInterface;
328    fn datum_type(&self) -> Result<DatumType>;
329    fn rank(&self) -> Result<usize>;
330    fn dim(&self, axis: usize) -> Result<Self::Dim>;
331
332    fn dims(&self) -> Result<impl Iterator<Item = Self::Dim>> {
333        Ok((0..self.rank()?).map(|axis| self.dim(axis)).collect::<Result<Vec<_>>>()?.into_iter())
334    }
335}
336
337pub trait DimInterface: Debug + Display + Clone {
338    fn eval(&self, values: impl IntoIterator<Item = (impl AsRef<str>, i64)>) -> Result<Self>;
339    fn to_int64(&self) -> Result<i64>;
340}
341
342pub trait InferenceFactInterface: Debug + Display + Default + Clone {
343    fn empty() -> Result<Self>;
344}
345
346pub trait AsFact<M, F> {
347    fn as_fact(&self, model: &M) -> Result<Bow<'_, F>>;
348}
349
/// Element types understood by tract.
///
/// `#[repr(C)]` with explicit discriminants keeps the numeric codes stable across an FFI
/// boundary (presumably shared with tract's C API). The discriminants appear to encode a
/// category in the high nibble and the byte width of one scalar component in the low nibble:
/// 0x0 bool, 0x1 unsigned, 0x2 signed, 0x3 float, 0x4 complex integer, 0x5 complex float.
#[repr(C)]
// Variant names mirror C-style constant names.
#[allow(non_camel_case_types)]
#[derive(Debug, PartialEq, Eq, Copy, Clone)]
pub enum DatumType {
    TRACT_DATUM_TYPE_BOOL = 0x01,
    TRACT_DATUM_TYPE_U8 = 0x11,
    TRACT_DATUM_TYPE_U16 = 0x12,
    TRACT_DATUM_TYPE_U32 = 0x14,
    TRACT_DATUM_TYPE_U64 = 0x18,
    TRACT_DATUM_TYPE_I8 = 0x21,
    TRACT_DATUM_TYPE_I16 = 0x22,
    TRACT_DATUM_TYPE_I32 = 0x24,
    TRACT_DATUM_TYPE_I64 = 0x28,
    TRACT_DATUM_TYPE_F16 = 0x32,
    TRACT_DATUM_TYPE_F32 = 0x34,
    TRACT_DATUM_TYPE_F64 = 0x38,
    #[cfg(feature = "complex")]
    TRACT_DATUM_TYPE_COMPLEX_I16 = 0x42,
    #[cfg(feature = "complex")]
    TRACT_DATUM_TYPE_COMPLEX_I32 = 0x44,
    #[cfg(feature = "complex")]
    TRACT_DATUM_TYPE_COMPLEX_I64 = 0x48,
    #[cfg(feature = "complex")]
    TRACT_DATUM_TYPE_COMPLEX_F16 = 0x52,
    #[cfg(feature = "complex")]
    TRACT_DATUM_TYPE_COMPLEX_F32 = 0x54,
    #[cfg(feature = "complex")]
    TRACT_DATUM_TYPE_COMPLEX_F64 = 0x58,
}

impl DatumType {
    /// Size in bytes of one element of this type.
    ///
    /// Complex types hold two scalar components, so they are twice as wide as the matching
    /// real type.
    pub fn size_of(&self) -> usize {
        use DatumType::*;
        match self {
            TRACT_DATUM_TYPE_BOOL | TRACT_DATUM_TYPE_U8 | TRACT_DATUM_TYPE_I8 => 1,
            TRACT_DATUM_TYPE_U16 | TRACT_DATUM_TYPE_I16 | TRACT_DATUM_TYPE_F16 => 2,
            TRACT_DATUM_TYPE_U32 | TRACT_DATUM_TYPE_I32 | TRACT_DATUM_TYPE_F32 => 4,
            TRACT_DATUM_TYPE_U64 | TRACT_DATUM_TYPE_I64 | TRACT_DATUM_TYPE_F64 => 8,
            // The complex arms must name the complex float variants; naming the real float
            // variants here made these patterns unreachable and left the match
            // non-exhaustive under `--features complex`.
            #[cfg(feature = "complex")]
            TRACT_DATUM_TYPE_COMPLEX_I16 | TRACT_DATUM_TYPE_COMPLEX_F16 => 4,
            #[cfg(feature = "complex")]
            TRACT_DATUM_TYPE_COMPLEX_I32 | TRACT_DATUM_TYPE_COMPLEX_F32 => 8,
            #[cfg(feature = "complex")]
            TRACT_DATUM_TYPE_COMPLEX_I64 | TRACT_DATUM_TYPE_COMPLEX_F64 => 16,
        }
    }

    /// True for the boolean type.
    pub fn is_bool(&self) -> bool {
        matches!(self, Self::TRACT_DATUM_TYPE_BOOL)
    }

    /// True for every type except bool (integers, floats and, when enabled, complex types).
    pub fn is_number(&self) -> bool {
        !self.is_bool()
    }

    /// True for the unsigned integer types.
    pub fn is_unsigned(&self) -> bool {
        matches!(
            self,
            Self::TRACT_DATUM_TYPE_U8
                | Self::TRACT_DATUM_TYPE_U16
                | Self::TRACT_DATUM_TYPE_U32
                | Self::TRACT_DATUM_TYPE_U64
        )
    }

    /// True for the signed (real) integer types.
    pub fn is_signed(&self) -> bool {
        matches!(
            self,
            Self::TRACT_DATUM_TYPE_I8
                | Self::TRACT_DATUM_TYPE_I16
                | Self::TRACT_DATUM_TYPE_I32
                | Self::TRACT_DATUM_TYPE_I64
        )
    }

    /// True for the real floating-point types (complex floats are not included).
    pub fn is_float(&self) -> bool {
        matches!(
            self,
            Self::TRACT_DATUM_TYPE_F16 | Self::TRACT_DATUM_TYPE_F32 | Self::TRACT_DATUM_TYPE_F64
        )
    }
}
430
431pub trait Datum {
432    fn datum_type() -> DatumType;
433}
434
435macro_rules! impl_datum_type {
436    ($ty:ty, $c_repr:expr) => {
437        impl Datum for $ty {
438            fn datum_type() -> DatumType {
439                $c_repr
440            }
441        }
442    };
443}
444
445impl_datum_type!(bool, DatumType::TRACT_DATUM_TYPE_BOOL);
446impl_datum_type!(u8, DatumType::TRACT_DATUM_TYPE_U8);
447impl_datum_type!(u16, DatumType::TRACT_DATUM_TYPE_U16);
448impl_datum_type!(u32, DatumType::TRACT_DATUM_TYPE_U32);
449impl_datum_type!(u64, DatumType::TRACT_DATUM_TYPE_U64);
450impl_datum_type!(i8, DatumType::TRACT_DATUM_TYPE_I8);
451impl_datum_type!(i16, DatumType::TRACT_DATUM_TYPE_I16);
452impl_datum_type!(i32, DatumType::TRACT_DATUM_TYPE_I32);
453impl_datum_type!(i64, DatumType::TRACT_DATUM_TYPE_I64);
454impl_datum_type!(half::f16, DatumType::TRACT_DATUM_TYPE_F16);
455impl_datum_type!(f32, DatumType::TRACT_DATUM_TYPE_F32);
456impl_datum_type!(f64, DatumType::TRACT_DATUM_TYPE_F64);