// cervo_nnef/lib.rs

/*! Contains utilities for using cervo with NNEF.

If you're going to defer loading NNEF files to runtime, consider
running [`init`] ahead of time to remove some overhead from the first
load call.

## Loading an inference model

```no_run
# fn load_bytes(s: &str) -> std::io::Cursor<Vec<u8>> { std::io::Cursor::new(vec![]) }
use cervo_core::prelude::InfererExt;

let model_data = load_bytes("model.nnef");
let model = cervo_nnef::builder(model_data)
    .build_fixed(&[2])?
    .with_default_epsilon("epsilon");
# Ok::<(), Box<dyn std::error::Error>>(())
```

*/
21
22use anyhow::Result;
23use cervo_core::prelude::{
24    BasicInferer, DynamicInferer, FixedBatchInferer, InfererBuilder, InfererProvider,
25    MemoizingDynamicInferer,
26};
27use std::{
28    ffi::OsStr,
29    io::Read,
30    path::{Path, PathBuf},
31};
32use tract_nnef::{framework::Nnef, prelude::*};
33
lazy_static::lazy_static! {
    /// Shared tract NNEF framework instance.
    ///
    /// Constructing the `Nnef` framework (with tract-core ops registered)
    /// is relatively expensive, so it is built lazily once and reused for
    /// every model load. Call [`init`] to force construction ahead of time.
    static ref NNEF: Nnef  = {
        tract_nnef::nnef().with_tract_core()
    };
}
39
40/// Initialize the global NNEF instance.
41///
42/// To ensure fast loading cervo uses a shared instance of the
43/// tract NNEF framework. If you don't want to pay for initialization
44/// on first-time load you can call this earlier to ensure it's set up
45/// ahead of time.
46pub fn init() {
47    use lazy_static::LazyStatic;
48    NNEF::initialize(&NNEF);
49}
50
/// Utility function to check if a file name is `.nnef.tar`.
///
/// The match is exact and case-sensitive: the outermost extension must be
/// `tar` and the remaining file stem must itself end in `.nnef`
/// (e.g. `model.nnef.tar`). Non-UTF-8 extensions or stems never match.
pub fn is_nnef_tar(path: &Path) -> bool {
    // Outer extension must be exactly `tar` (and valid UTF-8).
    if path.extension().and_then(OsStr::to_str) != Some("tar") {
        return false;
    }

    // Strip `.tar` and test the inner extension of the remaining stem.
    // `Path::new` borrows the stem directly — no `PathBuf` allocation.
    match path.file_stem().and_then(OsStr::to_str) {
        Some(stem) => Path::new(stem).extension().and_then(OsStr::to_str) == Some("nnef"),
        None => false,
    }
}
70
/// Parse NNEF data from `reader` into a tract `TypedModel` using the
/// shared `NNEF` framework instance.
fn model_for_reader(reader: &mut dyn Read) -> Result<TypedModel> {
    NNEF.model_for_read(reader)
}
74
/// A reader for providing NNEF data.
///
/// Wraps any [`Read`] source — a file, an in-memory cursor, etc. — so it
/// can be consumed by the `InfererProvider` implementation in this module.
pub struct NnefData<T: Read>(pub T);
77
78impl<T> NnefData<T>
79where
80    T: Read,
81{
82    fn load(&mut self) -> Result<TypedModel> {
83        model_for_reader(&mut self.0)
84    }
85}
86
87impl<T> InfererProvider for NnefData<T>
88where
89    T: Read,
90{
91    /// Build a [`BasicInferer`].
92    fn build_basic(mut self) -> Result<BasicInferer> {
93        let model = self.load()?;
94        BasicInferer::from_typed(model)
95    }
96
97    /// Build a [`BasicInferer`].
98    fn build_fixed(mut self, sizes: &[usize]) -> Result<FixedBatchInferer> {
99        let model = self.load()?;
100        FixedBatchInferer::from_typed(model, sizes)
101    }
102
103    /// Build a [`MemoizingDynamicInferer`].
104    fn build_memoizing(mut self, preload_sizes: &[usize]) -> Result<MemoizingDynamicInferer> {
105        let model = self.load()?;
106        MemoizingDynamicInferer::from_typed(model, preload_sizes)
107    }
108
109    /// Build a [`DynamicInferer`].
110    fn build_dynamic(mut self) -> Result<DynamicInferer> {
111        let model = self.load()?;
112        DynamicInferer::from_typed(model)
113    }
114}
115
116/// Utility function for creating an [`InfererBuilder`] for [`NnefData`].
117pub fn builder<T: Read>(read: T) -> InfererBuilder<NnefData<T>> {
118    InfererBuilder::new(NnefData(read))
119}