tract_api/
lib.rs

1use anyhow::{ensure, Result};
2use boow::Bow;
3use std::fmt::{Debug, Display};
4use std::path::Path;
5
6#[macro_use]
7pub mod macros;
8
9/// an implementation of tract's NNEF framework object
10///
11/// Entry point for NNEF model manipulation: loading from file, dumping to file.
12pub trait NnefInterface: Sized {
13    type Model: ModelInterface;
14    /// Load a NNEF model from the path into a tract-core model.
15    ///
16    /// * `path` can point to a directory, a `tar` file or a `tar.gz` file.
17    fn model_for_path(&self, path: impl AsRef<Path>) -> Result<Self::Model>;
18
19    /// Transform model according to transform spec
20    fn transform_model(&self, model: &mut Self::Model, transform_spec: &str) -> Result<()>;
21
22    /// Allow the framework to use tract_core extensions instead of a stricter NNEF definition.
23    fn enable_tract_core(&mut self) -> Result<()>;
24
25    /// Allow the framework to use tract_extra extensions.
26    fn enable_tract_extra(&mut self) -> Result<()>;
27
28    /// Allow the framework to use tract_transformers extensions to support common transformer operators.
29    fn enable_tract_transformers(&mut self) -> Result<()>;
30
31    /// Allow the framework to use tract_onnx extensions to support operators in ONNX that are
32    /// absent from NNEF.
33    fn enable_onnx(&mut self) -> Result<()>;
34
35    /// Allow the framework to use tract_pulse extensions to support stateful streaming operation.
36    fn enable_pulse(&mut self) -> Result<()>;
37
38    /// Allow the framework to use a tract-proprietary extension that can support special characters
39    /// in node names. If disable, tract will replace everything by underscore '_' to keep
40    /// compatibility with NNEF. If enabled, the extended syntax will be used, allowing to maintain
41    /// the node names in serialized form.
42    fn enable_extended_identifier_syntax(&mut self) -> Result<()>;
43
44    /// Convenience function, similar with enable_tract_core but allowing method chaining.
45    fn with_tract_core(mut self) -> Result<Self> {
46        self.enable_tract_core()?;
47        Ok(self)
48    }
49
50    /// Convenience function, similar with enable_tract_core but allowing method chaining.
51    fn with_tract_extra(mut self) -> Result<Self> {
52        self.enable_tract_extra()?;
53        Ok(self)
54    }
55
56    /// Convenience function, similar with enable_tract_transformers but allowing method chaining.
57    fn with_tract_transformers(mut self) -> Result<Self> {
58        self.enable_tract_transformers()?;
59        Ok(self)
60    }
61
62    /// Convenience function, similar with enable_onnx but allowing method chaining.
63    fn with_onnx(mut self) -> Result<Self> {
64        self.enable_onnx()?;
65        Ok(self)
66    }
67
68    /// Convenience function, similar with enable_pulse but allowing method chaining.
69    fn with_pulse(mut self) -> Result<Self> {
70        self.enable_pulse()?;
71        Ok(self)
72    }
73
74    /// Convenience function, similar with enable_extended_identifier_syntax but allowing method chaining.
75    fn with_extended_identifier_syntax(mut self) -> Result<Self> {
76        self.enable_extended_identifier_syntax()?;
77        Ok(self)
78    }
79
80    /// Dump a TypedModel as a NNEF directory.
81    ///
82    /// `path` is the directory name to dump to
83    fn write_model_to_dir(&self, path: impl AsRef<Path>, model: &Self::Model) -> Result<()>;
84
85    /// Dump a TypedModel as a NNEF tar file.
86    ///
87    /// This function creates a plain, non-compressed, archive.
88    ///
89    /// `path` is the archive name
90    fn write_model_to_tar(&self, path: impl AsRef<Path>, model: &Self::Model) -> Result<()>;
91    fn write_model_to_tar_gz(&self, path: impl AsRef<Path>, model: &Self::Model) -> Result<()>;
92}
93
94pub trait OnnxInterface {
95    type InferenceModel: InferenceModelInterface;
96    fn model_for_path(&self, path: impl AsRef<Path>) -> Result<Self::InferenceModel>;
97}
98
99pub trait InferenceModelInterface: Sized {
100    type Model: ModelInterface;
101    type InferenceFact: InferenceFactInterface;
102    fn set_output_names(
103        &mut self,
104        outputs: impl IntoIterator<Item = impl AsRef<str>>,
105    ) -> Result<()>;
106    fn input_count(&self) -> Result<usize>;
107    fn output_count(&self) -> Result<usize>;
108    fn input_name(&self, id: usize) -> Result<String>;
109    fn output_name(&self, id: usize) -> Result<String>;
110
111    fn input_fact(&self, id: usize) -> Result<Self::InferenceFact>;
112
113    fn set_input_fact(
114        &mut self,
115        id: usize,
116        fact: impl AsFact<Self, Self::InferenceFact>,
117    ) -> Result<()>;
118
119    fn output_fact(&self, id: usize) -> Result<Self::InferenceFact>;
120
121    fn set_output_fact(
122        &mut self,
123        id: usize,
124        fact: impl AsFact<Self, Self::InferenceFact>,
125    ) -> Result<()>;
126
127    fn analyse(&mut self) -> Result<()>;
128
129    fn into_typed(self) -> Result<Self::Model>;
130
131    fn into_optimized(self) -> Result<Self::Model>;
132}
133
134pub trait ModelInterface: Sized {
135    type Fact: FactInterface;
136    type Runnable: RunnableInterface;
137    type Value: ValueInterface;
138    fn input_count(&self) -> Result<usize>;
139
140    fn output_count(&self) -> Result<usize>;
141
142    fn input_name(&self, id: usize) -> Result<String>;
143
144    fn output_name(&self, id: usize) -> Result<String>;
145
146    fn set_output_names(
147        &mut self,
148        outputs: impl IntoIterator<Item = impl AsRef<str>>,
149    ) -> Result<()>;
150
151    fn input_fact(&self, id: usize) -> Result<Self::Fact>;
152
153    fn output_fact(&self, id: usize) -> Result<Self::Fact>;
154
155    fn declutter(&mut self) -> Result<()>;
156
157    fn optimize(&mut self) -> Result<()>;
158
159    fn into_decluttered(self) -> Result<Self>;
160
161    fn into_optimized(self) -> Result<Self>;
162
163    fn into_runnable(self) -> Result<Self::Runnable>;
164
165    fn concretize_symbols(
166        &mut self,
167        values: impl IntoIterator<Item = (impl AsRef<str>, i64)>,
168    ) -> Result<()>;
169
170    fn transform(&mut self, transform: &str) -> Result<()>;
171
172    fn pulse(&mut self, name: impl AsRef<str>, value: impl AsRef<str>) -> Result<()>;
173
174    fn cost_json(&self) -> Result<String>;
175
176    fn profile_json<I, IV, IE, S, SV, SE>(
177        &self,
178        inputs: Option<I>,
179        state_initializers: Option<S>,
180    ) -> Result<String>
181    where
182        I: IntoIterator<Item = IV>,
183        IV: TryInto<Self::Value, Error = IE>,
184        IE: Into<anyhow::Error> + Debug,
185        S: IntoIterator<Item = SV>,
186        SV: TryInto<Self::Value, Error = SE>,
187        SE: Into<anyhow::Error> + Debug;
188
189    fn property_keys(&self) -> Result<Vec<String>>;
190
191    fn property(&self, name: impl AsRef<str>) -> Result<Self::Value>;
192}
193
194pub trait RunnableInterface {
195    type Value: ValueInterface;
196    type State: StateInterface<Value = Self::Value>;
197    fn run<I, V, E>(&self, inputs: I) -> Result<Vec<Self::Value>>
198    where
199        I: IntoIterator<Item = V>,
200        V: TryInto<Self::Value, Error = E>,
201        E: Into<anyhow::Error>,
202    {
203        self.spawn_state()?.run(inputs)
204    }
205
206    fn input_count(&self) -> Result<usize>;
207    fn output_count(&self) -> Result<usize>;
208
209    fn spawn_state(&self) -> Result<Self::State>;
210}
211
212pub trait StateInterface {
213    type Fact: FactInterface;
214    type Value: ValueInterface;
215
216    fn input_count(&self) -> Result<usize>;
217    fn output_count(&self) -> Result<usize>;
218
219    fn run<I, V, E>(&mut self, inputs: I) -> Result<Vec<Self::Value>>
220    where
221        I: IntoIterator<Item = V>,
222        V: TryInto<Self::Value, Error = E>,
223        E: Into<anyhow::Error>;
224
225    fn initializable_states_count(&self) -> Result<usize>;
226
227    fn get_states_facts(&self) -> Result<Vec<Self::Fact>>;
228
229    fn set_states<I, V, E>(&mut self, state_initializers: I) -> Result<()>
230    where
231        I: IntoIterator<Item = V>,
232        V: TryInto<Self::Value, Error = E>,
233        E: Into<anyhow::Error> + Debug;
234
235    fn get_states(&self) -> Result<Vec<Self::Value>>;
236}
237
238pub trait ValueInterface: Sized + Clone {
239    fn from_bytes(dt: DatumType, shape: &[usize], data: &[u8]) -> Result<Self>;
240    fn as_bytes(&self) -> Result<(DatumType, &[usize], &[u8])>;
241
242    fn from_slice<T: Datum>(shape: &[usize], data: &[T]) -> Result<Self> {
243        let data = unsafe {
244            std::slice::from_raw_parts(data.as_ptr() as *const u8, std::mem::size_of_val(data))
245        };
246        Self::from_bytes(T::datum_type(), shape, data)
247    }
248
249    fn as_slice<T: Datum>(&self) -> Result<(&[usize], &[T])> {
250        let (dt, shape, data) = self.as_bytes()?;
251        ensure!(T::datum_type() == dt);
252        let data = unsafe {
253            std::slice::from_raw_parts(
254                data.as_ptr() as *const T,
255                data.len() / std::mem::size_of::<T>(),
256            )
257        };
258        Ok((shape, data))
259    }
260
261    fn view<T: Datum>(&self) -> Result<ndarray::ArrayViewD<'_, T>> {
262        let (shape, data) = self.as_slice()?;
263        Ok(unsafe { ndarray::ArrayViewD::from_shape_ptr(shape, data.as_ptr()) })
264    }
265}
266
/// Marker trait for fact types: no methods of its own, only the `Debug`, `Display` and
/// `Clone` requirements.
pub trait FactInterface: Debug + Display + Clone {}
268pub trait InferenceFactInterface: Debug + Display + Default + Clone {
269    fn empty() -> Result<Self>;
270}
271
272pub trait AsFact<M, F> {
273    fn as_fact(&self, model: &mut M) -> Result<Bow<'_, F>>;
274}
275
/// Element type tag with explicit, C-compatible discriminants.
///
/// The discriminant values follow a visible pattern: the high nibble selects the family
/// (`0x0_` bool, `0x1_` unsigned, `0x2_` signed, `0x3_` float, `0x4_` complex integer,
/// `0x5_` complex float) and the low nibble is the byte width of one scalar component.
/// The complex variants only exist when the `complex` feature is enabled.
#[repr(C)]
#[allow(non_camel_case_types)]
#[derive(Debug, PartialEq, Eq, Copy, Clone)]
pub enum DatumType {
    TRACT_DATUM_TYPE_BOOL = 0x01,
    TRACT_DATUM_TYPE_U8 = 0x11,
    TRACT_DATUM_TYPE_U16 = 0x12,
    TRACT_DATUM_TYPE_U32 = 0x14,
    TRACT_DATUM_TYPE_U64 = 0x18,
    TRACT_DATUM_TYPE_I8 = 0x21,
    TRACT_DATUM_TYPE_I16 = 0x22,
    TRACT_DATUM_TYPE_I32 = 0x24,
    TRACT_DATUM_TYPE_I64 = 0x28,
    TRACT_DATUM_TYPE_F16 = 0x32,
    TRACT_DATUM_TYPE_F32 = 0x34,
    TRACT_DATUM_TYPE_F64 = 0x38,
    #[cfg(feature = "complex")]
    TRACT_DATUM_TYPE_COMPLEX_I16 = 0x42,
    #[cfg(feature = "complex")]
    TRACT_DATUM_TYPE_COMPLEX_I32 = 0x44,
    #[cfg(feature = "complex")]
    TRACT_DATUM_TYPE_COMPLEX_I64 = 0x48,
    #[cfg(feature = "complex")]
    TRACT_DATUM_TYPE_COMPLEX_F16 = 0x52,
    #[cfg(feature = "complex")]
    TRACT_DATUM_TYPE_COMPLEX_F32 = 0x54,
    #[cfg(feature = "complex")]
    TRACT_DATUM_TYPE_COMPLEX_F64 = 0x58,
}

impl DatumType {
    /// Size in bytes of one element of this datum type.
    ///
    /// A complex element is two scalar components, hence twice the component width.
    pub fn size_of(&self) -> usize {
        use DatumType::*;
        match *self {
            TRACT_DATUM_TYPE_BOOL | TRACT_DATUM_TYPE_U8 | TRACT_DATUM_TYPE_I8 => 1,
            TRACT_DATUM_TYPE_U16 | TRACT_DATUM_TYPE_I16 | TRACT_DATUM_TYPE_F16 => 2,
            TRACT_DATUM_TYPE_U32 | TRACT_DATUM_TYPE_I32 | TRACT_DATUM_TYPE_F32 => 4,
            TRACT_DATUM_TYPE_U64 | TRACT_DATUM_TYPE_I64 | TRACT_DATUM_TYPE_F64 => 8,
            // Each complex integer variant is paired with the complex float variant of the
            // same component width, so the match is exhaustive when `complex` is enabled.
            #[cfg(feature = "complex")]
            TRACT_DATUM_TYPE_COMPLEX_I16 | TRACT_DATUM_TYPE_COMPLEX_F16 => 4,
            #[cfg(feature = "complex")]
            TRACT_DATUM_TYPE_COMPLEX_I32 | TRACT_DATUM_TYPE_COMPLEX_F32 => 8,
            #[cfg(feature = "complex")]
            TRACT_DATUM_TYPE_COMPLEX_I64 | TRACT_DATUM_TYPE_COMPLEX_F64 => 16,
        }
    }
}
323
324pub trait Datum {
325    fn datum_type() -> DatumType;
326}
327
328macro_rules! impl_datum_type {
329    ($ty:ty, $c_repr:expr) => {
330        impl Datum for $ty {
331            fn datum_type() -> DatumType {
332                $c_repr
333            }
334        }
335    };
336}
337
338impl_datum_type!(bool, DatumType::TRACT_DATUM_TYPE_BOOL);
339impl_datum_type!(u8, DatumType::TRACT_DATUM_TYPE_U8);
340impl_datum_type!(u16, DatumType::TRACT_DATUM_TYPE_U16);
341impl_datum_type!(u32, DatumType::TRACT_DATUM_TYPE_U32);
342impl_datum_type!(u64, DatumType::TRACT_DATUM_TYPE_U64);
343impl_datum_type!(i8, DatumType::TRACT_DATUM_TYPE_I8);
344impl_datum_type!(i16, DatumType::TRACT_DATUM_TYPE_I16);
345impl_datum_type!(i32, DatumType::TRACT_DATUM_TYPE_I32);
346impl_datum_type!(i64, DatumType::TRACT_DATUM_TYPE_I64);
347impl_datum_type!(half::f16, DatumType::TRACT_DATUM_TYPE_F16);
348impl_datum_type!(f32, DatumType::TRACT_DATUM_TYPE_F32);
349impl_datum_type!(f64, DatumType::TRACT_DATUM_TYPE_F64);