cervo_onnx/lib.rs

/*! Contains utilities for using cervo with ONNX.

## Loading an inference model
```no_run
# fn load_bytes(s: &str) -> std::io::Cursor<Vec<u8>> { std::io::Cursor::new(vec![]) }
use cervo_core::prelude::InfererExt;

let model_data = load_bytes("model.onnx");
let model = cervo_onnx::builder(model_data)
    .build_memoizing(&[])?
    .with_default_epsilon("epsilon");
# Ok::<(), Box<dyn std::error::Error>>(())
```

## Converting to NNEF
```no_run
# fn load_bytes(s: &str) -> std::io::Cursor<Vec<u8>> { std::io::Cursor::new(vec![]) }

let mut onnx_data = load_bytes("model.onnx");
let nnef_data = cervo_onnx::to_nnef(&mut onnx_data, None)?;
# Ok::<(), Box<dyn std::error::Error>>(())
```
 */

use anyhow::Result;

use cervo_core::prelude::{
    BasicInferer, DynamicInferer, FixedBatchInferer, MemoizingDynamicInferer,
    {InfererBuilder, InfererProvider},
};
use std::io::Read;
use tract_onnx::{prelude::*, tract_hir::infer::Factoid};

#[doc(hidden)]
pub use tract_onnx;

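/// Parse ONNX data from a reader into a tract `InferenceModel`, keeping the
/// output shapes declared in the file rather than ignoring them.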
fn model_for_reader(reader: &mut dyn Read) -> Result<InferenceModel> {
    let onnx = tract_onnx::onnx().with_ignore_output_shapes(false);
    onnx.model_for_read(reader)
}

/// Wrapper for a reader providing ONNX data.
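///
/// The wrapped value only needs to implement [`Read`]; both files and
/// in-memory buffers such as [`std::io::Cursor`] work.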
pub struct OnnxData<T: Read>(pub T);

impl<T> OnnxData<T>
where
    T: Read,
{
    fn load(&mut self) -> Result<InferenceModel> {
        model_for_reader(&mut self.0)
    }
}

impl<T> InfererProvider for OnnxData<T>
where
    T: Read,
{
    /// Build a [`BasicInferer`].
    fn build_basic(mut self) -> Result<BasicInferer> {
        let model = self.load()?;
        BasicInferer::from_model(model)
    }

    /// Build a [`FixedBatchInferer`].
    fn build_fixed(mut self, sizes: &[usize]) -> Result<FixedBatchInferer> {
        let model = self.load()?;
        FixedBatchInferer::from_model(model, sizes)
    }

    /// Build a [`MemoizingDynamicInferer`].
    fn build_memoizing(mut self, preload_sizes: &[usize]) -> Result<MemoizingDynamicInferer> {
        let model = self.load()?;
        MemoizingDynamicInferer::from_model(model, preload_sizes)
    }

    /// Build a [`DynamicInferer`].
    fn build_dynamic(mut self) -> Result<DynamicInferer> {
        let model = self.load()?;
        DynamicInferer::from_model(model)
    }
}

/// Utility function for creating an [`InfererBuilder`] for [`OnnxData`].
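///
/// A minimal sketch reading a model from disk (the path is illustrative):
///
/// ```no_run
/// let file = std::fs::File::open("model.onnx")?;
/// let inferer = cervo_onnx::builder(file).build_memoizing(&[])?;
/// # Ok::<(), Box<dyn std::error::Error>>(())
/// ```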
pub fn builder<T: Read>(read: T) -> InfererBuilder<OnnxData<T>> {
    InfererBuilder::new(OnnxData(read))
}

/// Convert an ONNX model to an NNEF model.
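///
/// When `batch_size` is `None` the batch dimension is left symbolic (`N`);
/// with `Some(n)` every input is pinned to batch size `n`.
///
/// A minimal sketch that fixes the batch size to 1 and writes the resulting
/// archive to disk (file names are illustrative):
///
/// ```no_run
/// # fn load_bytes(s: &str) -> std::io::Cursor<Vec<u8>> { std::io::Cursor::new(vec![]) }
/// let mut onnx_data = load_bytes("model.onnx");
/// let nnef_data = cervo_onnx::to_nnef(&mut onnx_data, Some(1))?;
/// std::fs::write("model.nnef.tar", &nnef_data)?;
/// # Ok::<(), Box<dyn std::error::Error>>(())
/// ```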
pub fn to_nnef(reader: &mut dyn Read, batch_size: Option<usize>) -> Result<Vec<u8>> {
    let mut model = model_for_reader(reader)?;
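    // Reset every declared output fact to the default (unknown) fact so tract
    // infers the output shapes from the graph instead.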
    let output_count = model.output_outlets()?.len();
    for output in 0..output_count {
        model.set_output_fact(output, Default::default())?;
    }

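    // The batch dimension is either the requested fixed size or the symbolic "N".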
    let symbol = model.symbols.sym("N");
    let input_outlets = model.input_outlets()?.to_vec();

    let batch = batch_size
        .map(|b| b.to_dim())
        .unwrap_or(symbol.clone().to_dim());

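    // Rewrite each input fact: keep the concrete non-batch dimensions, prepend
    // the chosen batch dimension, and fix the element type to f32.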
    for input_outlet in input_outlets {
        let input_shape = &model.input_fact(input_outlet.node)?.shape;

        let mut shape: Vec<_> = input_shape
            .dims()
            .skip(1)
            .map(|fact| fact.concretize().unwrap())
            .collect();

        shape.insert(0, batch.clone());

        model.set_input_fact(
            input_outlet.node,
            InferenceFact::dt_shape(DatumType::F32, &shape),
        )?;
    }

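    // Type-check and declutter the graph, then serialize it as an NNEF tar archive.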
    let model = model.into_typed()?.into_decluttered()?;
    let mut output = vec![];
    let nnef = tract_nnef::nnef().with_tract_core().with_onnx();

    // The two bools are `compress` and `deterministic`: write an uncompressed,
    // deterministic archive.
    nnef.write_to_tar_with_config(&model, &mut output, false, true)?;
    Ok(output)
}