use anyhow::{Context, Result};

use cervo_core::prelude::{
    BasicInferer, DynamicInferer, FixedBatchInferer, MemoizingDynamicInferer,
    {InfererBuilder, InfererProvider},
};
use std::io::Read;
use tract_onnx::{prelude::*, tract_hir::infer::Factoid};

#[doc(hidden)]
pub use tract_onnx;
38fn model_for_reader(reader: &mut dyn Read) -> Result<InferenceModel> {
39 let onnx = tract_onnx::onnx().with_ignore_output_shapes(false);
40 onnx.model_for_read(reader)
41}
42
/// Newtype wrapper around any [`Read`] source that yields ONNX model bytes.
pub struct OnnxData<T: Read>(pub T);
45
46impl<T> OnnxData<T>
47where
48 T: Read,
49{
50 fn load(&mut self) -> Result<InferenceModel> {
51 model_for_reader(&mut self.0)
52 }
53}
54
55impl<T> InfererProvider for OnnxData<T>
56where
57 T: Read,
58{
59 fn build_basic(mut self) -> Result<BasicInferer> {
61 let model = self.load()?;
62 BasicInferer::from_model(model)
63 }
64
65 fn build_fixed(mut self, sizes: &[usize]) -> Result<FixedBatchInferer> {
67 let model = self.load()?;
68 FixedBatchInferer::from_model(model, sizes)
69 }
70
71 fn build_memoizing(mut self, preload_sizes: &[usize]) -> Result<MemoizingDynamicInferer> {
73 let model = self.load()?;
74 MemoizingDynamicInferer::from_model(model, preload_sizes)
75 }
76
77 fn build_dynamic(mut self) -> Result<DynamicInferer> {
79 let model = self.load()?;
80 DynamicInferer::from_model(model)
81 }
82}
83
84pub fn builder<T: Read>(read: T) -> InfererBuilder<OnnxData<T>> {
86 InfererBuilder::new(OnnxData(read))
87}
88
89pub fn to_nnef(reader: &mut dyn Read, batch_size: Option<usize>) -> Result<Vec<u8>> {
91 let mut model = model_for_reader(reader)?;
92 let outlets = model.output_outlets().unwrap().len();
93 for output in 0..outlets {
94 model.set_output_fact(output, Default::default())?;
95 }
96
97 let symbol = model.symbols.sym("N");
98 let input_outlets = model.input_outlets()?.to_vec();
99
100 let batch = batch_size
101 .map(|b| b.to_dim())
102 .unwrap_or(symbol.clone().to_dim());
103
104 for input_outlet in input_outlets {
105 let input_shape = &model.input_fact(input_outlet.node)?.shape;
106
107 let mut shape: Vec<_> = input_shape
108 .dims()
109 .skip(1)
110 .map(|fact| fact.concretize().unwrap())
111 .collect();
112
113 shape.insert(0, batch.clone());
114
115 model.set_input_fact(
116 input_outlet.node,
117 InferenceFact::dt_shape(DatumType::F32, &shape),
118 )?;
119 }
120
121 let model = model.into_typed()?.into_decluttered()?;
122 let mut output = vec![];
123 let nnef = tract_nnef::nnef().with_tract_core().with_onnx();
124
125 nnef.write_to_tar_with_config(&model, &mut output, false, true)?;
127 Ok(output)
128}