onnx_shape_inference/
lib.rs

//! ONNX shape inference helper binding.
//!
//! References used for this implementation:
//!  * https://github.com/onnx/onnx/blob/master/onnx/cpp2py_export.cc#L295
5
6extern crate libc;
7
8#[cfg(feature = "proto")]
9extern crate onnx_pb;
10#[cfg(feature = "proto")]
11extern crate prost;
12
13use libc::size_t;
14
15#[cfg(feature = "proto")]
16mod proto;
17#[cfg(feature = "proto")]
18pub use self::proto::*;
19
20#[link(name = "onnx", kind = "static")]
21extern "C" {
22    fn onnx_proto_shape_inference(buffer: *const u8, size: size_t, out: *mut u8) -> size_t;
23}
24
/// Size of the output buffer reserved for the native shape-inference call, expressed as a
/// multiple of the input length (the native call is given no output-capacity argument).
const OUTPUT_SIZE_MULTIPLIER: usize = 10;
26
27/// Infers model shapes accepting and returning protocol buffers model.
28pub fn shape_inference_proto(body: &[u8]) -> Vec<u8> {
29    let capacity = body.len() * OUTPUT_SIZE_MULTIPLIER;
30    let mut output = Vec::with_capacity(capacity);
31    unsafe {
32        output.set_len(capacity);
33        let out_size = onnx_proto_shape_inference(body.as_ptr(), body.len(), output.as_mut_ptr());
34        output.truncate(out_size);
35    }
36    output
37}
38
#[cfg(test)]
mod tests {
    use super::*;

    /// `(input model, expected shape-inferred model)` fixture pairs shared by the tests.
    const CASES: &[(&str, &str)] = &[
        ("tests/model.onnx", "tests/model-inferred.onnx"),
        ("tests/mean-reverse.onnx", "tests/mean-reverse-inferred.onnx"),
    ];

    /// Shape inference through the decoded `ModelProto` API must match the fixtures.
    #[cfg(feature = "proto")]
    #[test]
    fn inference() {
        fn open_model(path: &str) -> onnx_pb::ModelProto {
            use prost::Message;
            let body = read_buf(path);
            onnx_pb::ModelProto::decode(body.as_slice())
                .unwrap_or_else(|err| panic!("failed to decode {}: {:?}", path, err))
        }

        for &(input, expected) in CASES {
            let buffer = open_model(input);
            let inferred = open_model(expected);
            let output = shape_inference(&buffer).unwrap();
            assert_eq!(output, inferred, "inferred model mismatch for {}", input);
        }
    }

    /// Shape inference on raw protobuf bytes must reproduce the fixture bytes exactly.
    #[test]
    fn inference_proto() {
        for &(input, expected) in CASES {
            let buffer = read_buf(input);
            let inferred = read_buf(expected);
            let output = shape_inference_proto(&buffer);
            // Compare without `assert_eq!` so a failure does not dump kilobytes of bytes.
            assert!(
                output == inferred,
                "inferred bytes for {} differ from {} ({} vs {} bytes)",
                input,
                expected,
                output.len(),
                inferred.len()
            );
        }
    }

    /// Reads a whole fixture file, panicking with the offending path if it cannot be read.
    fn read_buf<P: AsRef<std::path::Path>>(path: P) -> Vec<u8> {
        let path = path.as_ref();
        std::fs::read(path)
            .unwrap_or_else(|err| panic!("failed to read {}: {}", path.display(), err))
    }
}