1use anyhow::Result;
23use cervo_core::prelude::{
24 BasicInferer, DynamicInferer, FixedBatchInferer, InfererBuilder, InfererProvider,
25 MemoizingDynamicInferer,
26};
27use std::{
28 ffi::OsStr,
29 io::Read,
30 path::{Path, PathBuf},
31};
32use tract_nnef::{framework::Nnef, prelude::*};
33
34lazy_static::lazy_static! {
35 static ref NNEF: Nnef = {
36 tract_nnef::nnef().with_tract_core()
37 };
38}
39
40pub fn init() {
47 use lazy_static::LazyStatic;
48 NNEF::initialize(&NNEF);
49}
50
/// Returns `true` when `path` names a file ending in `.nnef.tar`.
///
/// The outer extension must be exactly `tar`. The extension of the remaining file stem must be
/// exactly `nnef`. Both comparisons are case-sensitive. A path whose extension or file stem is
/// not valid UTF-8 is rejected.
pub fn is_nnef_tar(path: &Path) -> bool {
    // Guard: the outermost extension has to be `tar`.
    if path.extension().and_then(OsStr::to_str) != Some("tar") {
        return false;
    }

    // Re-interpret the stem (e.g. `model.nnef`) as a path and check its own extension.
    let inner_ext = path
        .file_stem()
        .and_then(OsStr::to_str)
        .map(Path::new)
        .and_then(Path::extension)
        .and_then(OsStr::to_str);

    inner_ext == Some("nnef")
}
70
71fn model_for_reader(reader: &mut dyn Read) -> Result<TypedModel> {
72 NNEF.model_for_read(reader)
73}
74
/// Inference-provider input that wraps a reader holding a serialized NNEF model.
///
/// The wrapped value is only read when one of the `InfererProvider` build methods is called.
/// NOTE(review): the expected stream format is presumably an NNEF tar archive (see
/// `is_nnef_tar`); confirm against the tract NNEF loader.
///
/// The `Read` requirement lives on the impls that actually read from `T`, not on the type
/// itself, per the Rust API Guidelines (C-STRUCT-BOUNDS). Every previously valid use still
/// compiles.
pub struct NnefData<T>(pub T);
77
78impl<T> NnefData<T>
79where
80 T: Read,
81{
82 fn load(&mut self) -> Result<TypedModel> {
83 model_for_reader(&mut self.0)
84 }
85}
86
87impl<T> InfererProvider for NnefData<T>
88where
89 T: Read,
90{
91 fn build_basic(mut self) -> Result<BasicInferer> {
93 let model = self.load()?;
94 BasicInferer::from_typed(model)
95 }
96
97 fn build_fixed(mut self, sizes: &[usize]) -> Result<FixedBatchInferer> {
99 let model = self.load()?;
100 FixedBatchInferer::from_typed(model, sizes)
101 }
102
103 fn build_memoizing(mut self, preload_sizes: &[usize]) -> Result<MemoizingDynamicInferer> {
105 let model = self.load()?;
106 MemoizingDynamicInferer::from_typed(model, preload_sizes)
107 }
108
109 fn build_dynamic(mut self) -> Result<DynamicInferer> {
111 let model = self.load()?;
112 DynamicInferer::from_typed(model)
113 }
114}
115
116pub fn builder<T: Read>(read: T) -> InfererBuilder<NnefData<T>> {
118 InfererBuilder::new(NnefData(read))
119}