/*! Contains utilities for using cervo with NNEF.

If you're going to defer loading NNEF files to runtime, consider
running [`init`] ahead of time to remove some overhead from the first
load call.

## Loading an inference model

```no_run
# fn load_bytes(s: &str) -> std::io::Cursor<Vec<u8>> { std::io::Cursor::new(vec![]) }
use cervo_core::prelude::InfererExt;

let model_data = load_bytes("model.nnef");
let model = cervo_nnef::builder(model_data)
    .build_fixed(&[2])?
    .with_default_epsilon("epsilon");
# Ok::<(), Box<dyn std::error::Error>>(())
```
*/

use anyhow::Result;
use cervo_core::prelude::{
    BasicInferer, DynamicInferer, FixedBatchInferer, InfererBuilder, InfererProvider,
    MemoizingDynamicInferer,
};
use std::{
    ffi::OsStr,
    io::Read,
    path::{Path, PathBuf},
};
use tract_nnef::{framework::Nnef, prelude::*};

lazy_static::lazy_static! {
    static ref NNEF: Nnef = {
        tract_nnef::nnef().with_tract_core()
    };
}

/// Initialize the global NNEF instance.
///
/// To ensure fast loading cervo uses a shared instance of the
/// tract NNEF framework. If you don't want to pay for initialization
/// on first-time load you can call this earlier to ensure it's set up
/// ahead of time.
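///
/// A minimal sketch, assuming this crate is available as `cervo_nnef`:
///
/// ```no_run
/// // During startup, before the first model load:
/// cervo_nnef::init();
/// ```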
pub fn init() {
    use lazy_static::LazyStatic;
    NNEF::initialize(&NNEF);
}

/// Utility function to check if a file name is `.nnef.tar`.
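///
/// A quick sketch of the expected behavior (file names are hypothetical):
///
/// ```no_run
/// use std::path::Path;
///
/// assert!(cervo_nnef::is_nnef_tar(Path::new("model.nnef.tar")));
/// assert!(!cervo_nnef::is_nnef_tar(Path::new("model.tar")));
/// assert!(!cervo_nnef::is_nnef_tar(Path::new("model.nnef")));
/// ```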
pub fn is_nnef_tar(path: &Path) -> bool {
    if let Some(ext) = path.extension().and_then(OsStr::to_str) {
        if ext != "tar" {
            return false;
        }

        let stem = match path.file_stem().and_then(OsStr::to_str).map(PathBuf::from) {
            Some(p) => p,
            None => return false,
        };

        if let Some(ext) = stem.extension().and_then(OsStr::to_str) {
            return ext == "nnef";
        }
    }

    false
}

fn model_for_reader(reader: &mut dyn Read) -> Result<TypedModel> {
    NNEF.model_for_read(reader)
}

/// A reader for providing NNEF data.
pub struct NnefData<T: Read>(pub T);

impl<T> NnefData<T>
where
    T: Read,
{
    fn load(&mut self) -> Result<TypedModel> {
        model_for_reader(&mut self.0)
    }
}

impl<T> InfererProvider for NnefData<T>
where
    T: Read,
{
    /// Build a [`BasicInferer`].
    fn build_basic(mut self) -> Result<BasicInferer> {
        let model = self.load()?;
        BasicInferer::from_typed(model)
    }

    /// Build a [`FixedBatchInferer`].
    fn build_fixed(mut self, sizes: &[usize]) -> Result<FixedBatchInferer> {
        let model = self.load()?;
        FixedBatchInferer::from_typed(model, sizes)
    }

    /// Build a [`MemoizingDynamicInferer`].
    fn build_memoizing(mut self, preload_sizes: &[usize]) -> Result<MemoizingDynamicInferer> {
        let model = self.load()?;
        MemoizingDynamicInferer::from_typed(model, preload_sizes)
    }

    /// Build a [`DynamicInferer`].
    fn build_dynamic(mut self) -> Result<DynamicInferer> {
        let model = self.load()?;
        DynamicInferer::from_typed(model)
    }
}

/// Utility function for creating an [`InfererBuilder`] for [`NnefData`].
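///
/// A short sketch, assuming a `model.nnef` file on disk (hypothetical path):
///
/// ```no_run
/// let file = std::fs::File::open("model.nnef")?;
/// let inferer = cervo_nnef::builder(file).build_basic()?;
/// # Ok::<(), Box<dyn std::error::Error>>(())
/// ```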
pub fn builder<T: Read>(read: T) -> InfererBuilder<NnefData<T>> {
    InfererBuilder::new(NnefData(read))
}