// tract_api/lib.rs

1use anyhow::{Result, ensure};
2use boow::Bow;
3use std::fmt::{Debug, Display};
4use std::path::Path;
5
6#[macro_use]
7pub mod macros;
8pub mod transform;
9
10pub use transform::{FloatPrecision, Pulse, SetSymbols, TransformConfig, TransformSpec};
11
12/// an implementation of tract's NNEF framework object
13///
14/// Entry point for NNEF model manipulation: loading from file, dumping to file.
15pub trait NnefInterface: Debug + Sized {
16    type Model: ModelInterface;
17    /// Load a NNEF model from the path into a tract-core model.
18    ///
19    /// * `path` can point to a directory, a `tar` file or a `tar.gz` file.
20    fn load(&self, path: impl AsRef<Path>) -> Result<Self::Model>;
21
22    /// Load a NNEF model from a buffer into a tract-core model.
23    ///
24    /// data is the content of a NNEF model, as a `tar` file or a `tar.gz` file.
25    fn load_buffer(&self, data: &[u8]) -> Result<Self::Model>;
26
27    /// Force the framework to emit strict NNEF instead of using the tract_core extension.
28    /// The tract_core extension is enabled by default; call this to opt out.
29    fn disable_tract_core(&mut self) -> Result<()>;
30
31    /// Allow the framework to use tract_extra extensions.
32    fn enable_tract_extra(&mut self) -> Result<()>;
33
34    /// Allow the framework to use tract_transformers extensions to support common transformer operators.
35    fn enable_tract_transformers(&mut self) -> Result<()>;
36
37    /// Allow the framework to use tract_onnx extensions to support operators in ONNX that are
38    /// absent from NNEF.
39    fn enable_onnx(&mut self) -> Result<()>;
40
41    /// Allow the framework to use tract_pulse extensions to support stateful streaming operation.
42    fn enable_pulse(&mut self) -> Result<()>;
43
44    /// Allow the framework to use a tract-proprietary extension that can support special characters
45    /// in node names. If disable, tract will replace everything by underscore '_' to keep
46    /// compatibility with NNEF. If enabled, the extended syntax will be used, allowing to maintain
47    /// the node names in serialized form.
48    fn enable_extended_identifier_syntax(&mut self) -> Result<()>;
49
50    /// Convenience function, similar to disable_tract_core but allowing method chaining.
51    fn without_tract_core(mut self) -> Result<Self> {
52        self.disable_tract_core()?;
53        Ok(self)
54    }
55
56    /// Convenience function, similar with enable_tract_extra but allowing method chaining.
57    fn with_tract_extra(mut self) -> Result<Self> {
58        self.enable_tract_extra()?;
59        Ok(self)
60    }
61
62    /// Convenience function, similar with enable_tract_transformers but allowing method chaining.
63    fn with_tract_transformers(mut self) -> Result<Self> {
64        self.enable_tract_transformers()?;
65        Ok(self)
66    }
67
68    /// Convenience function, similar with enable_onnx but allowing method chaining.
69    fn with_onnx(mut self) -> Result<Self> {
70        self.enable_onnx()?;
71        Ok(self)
72    }
73
74    /// Convenience function, similar with enable_pulse but allowing method chaining.
75    fn with_pulse(mut self) -> Result<Self> {
76        self.enable_pulse()?;
77        Ok(self)
78    }
79
80    /// Convenience function, similar with enable_extended_identifier_syntax but allowing method chaining.
81    fn with_extended_identifier_syntax(mut self) -> Result<Self> {
82        self.enable_extended_identifier_syntax()?;
83        Ok(self)
84    }
85
86    /// Dump a TypedModel as a NNEF directory.
87    ///
88    /// `path` is the directory name to dump to
89    fn write_model_to_dir(&self, path: impl AsRef<Path>, model: &Self::Model) -> Result<()>;
90
91    /// Dump a TypedModel as a NNEF tar file.
92    ///
93    /// This function creates a plain, non-compressed, archive.
94    ///
95    /// `path` is the archive name
96    fn write_model_to_tar(&self, path: impl AsRef<Path>, model: &Self::Model) -> Result<()>;
97    fn write_model_to_tar_gz(&self, path: impl AsRef<Path>, model: &Self::Model) -> Result<()>;
98}
99
100pub trait OnnxInterface: Debug {
101    type InferenceModel: InferenceModelInterface;
102    fn load(&self, path: impl AsRef<Path>) -> Result<Self::InferenceModel>;
103    /// Load a ONNX model from a buffer into an InferenceModel.
104    fn load_buffer(&self, data: &[u8]) -> Result<Self::InferenceModel>;
105}
106
107pub trait InferenceModelInterface: Debug + Sized {
108    type Model: ModelInterface;
109    type InferenceFact: InferenceFactInterface;
110    fn input_count(&self) -> Result<usize>;
111    fn output_count(&self) -> Result<usize>;
112    fn input_name(&self, id: usize) -> Result<String>;
113    fn output_name(&self, id: usize) -> Result<String>;
114
115    fn input_fact(&self, id: usize) -> Result<Self::InferenceFact>;
116
117    fn set_input_fact(
118        &mut self,
119        id: usize,
120        fact: impl AsFact<Self, Self::InferenceFact>,
121    ) -> Result<()>;
122
123    fn output_fact(&self, id: usize) -> Result<Self::InferenceFact>;
124
125    fn set_output_fact(
126        &mut self,
127        id: usize,
128        fact: impl AsFact<Self, Self::InferenceFact>,
129    ) -> Result<()>;
130
131    fn analyse(&mut self) -> Result<()>;
132
133    fn into_model(self) -> Result<Self::Model>;
134}
135
136pub trait ModelInterface: Debug + Sized {
137    type Fact: FactInterface;
138    type Runnable: RunnableInterface;
139    type Tensor: TensorInterface;
140    fn input_count(&self) -> Result<usize>;
141
142    fn output_count(&self) -> Result<usize>;
143
144    fn input_name(&self, id: usize) -> Result<String>;
145
146    fn output_name(&self, id: usize) -> Result<String>;
147
148    fn input_fact(&self, id: usize) -> Result<Self::Fact>;
149
150    fn output_fact(&self, id: usize) -> Result<Self::Fact>;
151
152    fn into_runnable(self) -> Result<Self::Runnable>;
153
154    fn transform(&mut self, spec: impl Into<TransformSpec>) -> Result<()>;
155
156    fn property_keys(&self) -> Result<Vec<String>>;
157
158    fn property(&self, name: impl AsRef<str>) -> Result<Self::Tensor>;
159
160    fn parse_fact(&self, spec: &str) -> Result<Self::Fact>;
161
162    fn input_facts(&self) -> Result<impl Iterator<Item = Self::Fact>> {
163        Ok((0..self.input_count()?)
164            .map(|ix| self.input_fact(ix))
165            .collect::<Result<Vec<_>>>()?
166            .into_iter())
167    }
168
169    fn output_facts(&self) -> Result<impl Iterator<Item = Self::Fact>> {
170        Ok((0..self.output_count()?)
171            .map(|ix| self.output_fact(ix))
172            .collect::<Result<Vec<_>>>()?
173            .into_iter())
174    }
175}
176
177pub trait RuntimeInterface: Debug {
178    type Runnable: RunnableInterface;
179    type Model: ModelInterface;
180    fn name(&self) -> Result<String>;
181    fn prepare(&self, model: Self::Model) -> Result<Self::Runnable>;
182}
183
184pub trait RunnableInterface: Debug + Send + Sync {
185    type Tensor: TensorInterface;
186    type Fact: FactInterface;
187    type State: StateInterface<Tensor = Self::Tensor>;
188    fn run(&self, inputs: impl IntoInputs<Self::Tensor>) -> Result<Vec<Self::Tensor>> {
189        self.spawn_state()?.run(inputs.into_inputs()?)
190    }
191
192    fn input_count(&self) -> Result<usize>;
193    fn output_count(&self) -> Result<usize>;
194    fn input_fact(&self, id: usize) -> Result<Self::Fact>;
195
196    fn output_fact(&self, id: usize) -> Result<Self::Fact>;
197
198    fn input_facts(&self) -> Result<impl Iterator<Item = Self::Fact>> {
199        Ok((0..self.input_count()?)
200            .map(|ix| self.input_fact(ix))
201            .collect::<Result<Vec<_>>>()?
202            .into_iter())
203    }
204
205    fn output_facts(&self) -> Result<impl Iterator<Item = Self::Fact>> {
206        Ok((0..self.output_count()?)
207            .map(|ix| self.output_fact(ix))
208            .collect::<Result<Vec<_>>>()?
209            .into_iter())
210    }
211
212    fn property_keys(&self) -> Result<Vec<String>>;
213    fn property(&self, name: impl AsRef<str>) -> Result<Self::Tensor>;
214
215    fn spawn_state(&self) -> Result<Self::State>;
216
217    fn cost_json(&self) -> Result<String>;
218
219    fn profile_json<I, IV, IE>(&self, inputs: Option<I>) -> Result<String>
220    where
221        I: IntoIterator<Item = IV>,
222        IV: TryInto<Self::Tensor, Error = IE>,
223        IE: Into<anyhow::Error> + Debug;
224}
225
226pub trait StateInterface: Debug + Clone + Send {
227    type Fact: FactInterface;
228    type Tensor: TensorInterface;
229
230    fn input_count(&self) -> Result<usize>;
231    fn output_count(&self) -> Result<usize>;
232
233    fn run(&mut self, inputs: impl IntoInputs<Self::Tensor>) -> Result<Vec<Self::Tensor>>;
234}
235
236pub trait TensorInterface: Debug + Sized + Clone + PartialEq + Send + Sync {
237    fn datum_type(&self) -> Result<DatumType>;
238    fn from_bytes(dt: DatumType, shape: &[usize], data: &[u8]) -> Result<Self>;
239    fn as_bytes(&self) -> Result<(DatumType, &[usize], &[u8])>;
240
241    fn from_slice<T: Datum>(shape: &[usize], data: &[T]) -> Result<Self> {
242        let data = unsafe {
243            std::slice::from_raw_parts(data.as_ptr() as *const u8, std::mem::size_of_val(data))
244        };
245        Self::from_bytes(T::datum_type(), shape, data)
246    }
247
248    fn as_slice<T: Datum>(&self) -> Result<&[T]> {
249        let (dt, _shape, data) = self.as_bytes()?;
250        ensure!(T::datum_type() == dt);
251        let data = unsafe {
252            std::slice::from_raw_parts(
253                data.as_ptr() as *const T,
254                data.len() / std::mem::size_of::<T>(),
255            )
256        };
257        Ok(data)
258    }
259
260    fn as_shape_and_slice<T: Datum>(&self) -> Result<(&[usize], &[T])> {
261        let (_, shape, _) = self.as_bytes()?;
262        let data = self.as_slice()?;
263        Ok((shape, data))
264    }
265
266    fn shape(&self) -> Result<&[usize]> {
267        let (_, shape, _) = self.as_bytes()?;
268        Ok(shape)
269    }
270
271    fn convert_to(&self, to: DatumType) -> Result<Self>;
272}
273
274pub trait FactInterface: Debug + Display + Clone {
275    type Dim: DimInterface;
276    fn datum_type(&self) -> Result<DatumType>;
277    fn rank(&self) -> Result<usize>;
278    fn dim(&self, axis: usize) -> Result<Self::Dim>;
279
280    fn dims(&self) -> Result<impl Iterator<Item = Self::Dim>> {
281        Ok((0..self.rank()?).map(|axis| self.dim(axis)).collect::<Result<Vec<_>>>()?.into_iter())
282    }
283}
284
285pub trait DimInterface: Debug + Display + Clone {
286    fn eval(&self, values: impl IntoIterator<Item = (impl AsRef<str>, i64)>) -> Result<Self>;
287    fn to_int64(&self) -> Result<i64>;
288}
289
290pub trait InferenceFactInterface: Debug + Display + Default + Clone {
291    fn empty() -> Result<Self>;
292}
293
294pub trait AsFact<M, F>: Debug {
295    fn as_fact(&self, model: &M) -> Result<Bow<'_, F>>;
296}
297
/// Element types a tensor can hold.
///
/// The enum is `#[repr(C)]` with explicit discriminants. In the values below the high nibble
/// differs per family (bool, unsigned, signed, float, complex integer, complex float) and the
/// low nibble is the byte width of one scalar component.
#[repr(C)]
#[derive(Debug, PartialEq, Eq, Copy, Clone)]
pub enum DatumType {
    Bool = 0x01,
    U8 = 0x11,
    U16 = 0x12,
    U32 = 0x14,
    U64 = 0x18,
    I8 = 0x21,
    I16 = 0x22,
    I32 = 0x24,
    I64 = 0x28,
    F16 = 0x32,
    F32 = 0x34,
    F64 = 0x38,
    #[cfg(feature = "complex")]
    ComplexI16 = 0x42,
    #[cfg(feature = "complex")]
    ComplexI32 = 0x44,
    #[cfg(feature = "complex")]
    ComplexI64 = 0x48,
    #[cfg(feature = "complex")]
    ComplexF16 = 0x52,
    #[cfg(feature = "complex")]
    ComplexF32 = 0x54,
    #[cfg(feature = "complex")]
    ComplexF64 = 0x58,
}

impl DatumType {
    /// Size in bytes of one element of this type (complex types count both components).
    pub fn size_of(&self) -> usize {
        use DatumType::*;
        match self {
            Bool | U8 | I8 => 1,
            U16 | I16 | F16 => 2,
            U32 | I32 | F32 => 4,
            U64 | I64 | F64 => 8,
            #[cfg(feature = "complex")]
            ComplexI16 | ComplexF16 => 4,
            #[cfg(feature = "complex")]
            ComplexI32 | ComplexF32 => 8,
            #[cfg(feature = "complex")]
            ComplexI64 | ComplexF64 => 16,
        }
    }

    /// `true` for `Bool` only.
    pub fn is_bool(&self) -> bool {
        matches!(self, DatumType::Bool)
    }

    /// `true` for every type except `Bool`.
    pub fn is_number(&self) -> bool {
        !self.is_bool()
    }

    /// `true` for `U8`, `U16`, `U32` and `U64`.
    pub fn is_unsigned(&self) -> bool {
        use DatumType::*;
        matches!(self, U8 | U16 | U32 | U64)
    }

    /// `true` for `I8`, `I16`, `I32` and `I64` (floats are not reported as signed).
    pub fn is_signed(&self) -> bool {
        use DatumType::*;
        matches!(self, I8 | I16 | I32 | I64)
    }

    /// `true` for `F16`, `F32` and `F64` (complex float types are not included).
    pub fn is_float(&self) -> bool {
        use DatumType::*;
        matches!(self, F16 | F32 | F64)
    }
}
367
368pub trait Datum {
369    fn datum_type() -> DatumType;
370}
371
372// IntoInputs trait — ergonomic input conversion for run()
373pub trait IntoInputs<V: TensorInterface> {
374    fn into_inputs(self) -> Result<Vec<V>>;
375}
376
377// Arrays of anything convertible to Tensor
378impl<V, T, E, const N: usize> IntoInputs<V> for [T; N]
379where
380    V: TensorInterface,
381    T: TryInto<V, Error = E>,
382    E: Into<anyhow::Error>,
383{
384    fn into_inputs(self) -> Result<Vec<V>> {
385        self.into_iter().map(|v| v.try_into().map_err(|e| e.into())).collect()
386    }
387}
388
389// Vec<V> passthrough
390impl<V: TensorInterface> IntoInputs<V> for Vec<V> {
391    fn into_inputs(self) -> Result<Vec<V>> {
392        Ok(self)
393    }
394}
395
396// Tuples — each element converts independently
397macro_rules! impl_into_inputs_tuple {
398    ($($idx:tt : $T:ident),+) => {
399        impl<V, $($T),+> IntoInputs<V> for ($($T,)+)
400        where
401            V: TensorInterface,
402            $($T: TryInto<V>,
403              <$T as TryInto<V>>::Error: Into<anyhow::Error>,)+
404        {
405            fn into_inputs(self) -> Result<Vec<V>> {
406                Ok(vec![$(self.$idx.try_into().map_err(|e| e.into())?),+])
407            }
408        }
409    };
410}
411
412impl_into_inputs_tuple!(0: A);
413impl_into_inputs_tuple!(0: A, 1: B);
414impl_into_inputs_tuple!(0: A, 1: B, 2: C);
415impl_into_inputs_tuple!(0: A, 1: B, 2: C, 3: D);
416impl_into_inputs_tuple!(0: A, 1: B, 2: C, 3: D, 4: E_);
417impl_into_inputs_tuple!(0: A, 1: B, 2: C, 3: D, 4: E_, 5: F);
418impl_into_inputs_tuple!(0: A, 1: B, 2: C, 3: D, 4: E_, 5: F, 6: G);
419impl_into_inputs_tuple!(0: A, 1: B, 2: C, 3: D, 4: E_, 5: F, 6: G, 7: H);
420
421/// Convert any compatible input into a `V: TensorInterface`.
422pub fn tensor<V, T, E>(v: T) -> Result<V>
423where
424    V: TensorInterface,
425    T: TryInto<V, Error = E>,
426    E: Into<anyhow::Error>,
427{
428    v.try_into().map_err(|e| e.into())
429}
430
431macro_rules! impl_datum_type {
432    ($ty:ty, $c_repr:expr) => {
433        impl Datum for $ty {
434            fn datum_type() -> DatumType {
435                $c_repr
436            }
437        }
438    };
439}
440
441impl_datum_type!(bool, DatumType::Bool);
442impl_datum_type!(u8, DatumType::U8);
443impl_datum_type!(u16, DatumType::U16);
444impl_datum_type!(u32, DatumType::U32);
445impl_datum_type!(u64, DatumType::U64);
446impl_datum_type!(i8, DatumType::I8);
447impl_datum_type!(i16, DatumType::I16);
448impl_datum_type!(i32, DatumType::I32);
449impl_datum_type!(i64, DatumType::I64);
450impl_datum_type!(half::f16, DatumType::F16);
451impl_datum_type!(f32, DatumType::F32);
452impl_datum_type!(f64, DatumType::F64);