onnx_shape_inference/
lib.rs1extern crate libc;
7
8#[cfg(feature = "proto")]
9extern crate onnx_pb;
10#[cfg(feature = "proto")]
11extern crate prost;
12
13use libc::size_t;
14
15#[cfg(feature = "proto")]
16mod proto;
17#[cfg(feature = "proto")]
18pub use self::proto::*;
19
// Foreign declarations resolved against a statically linked native library named `onnx`.
#[link(name = "onnx", kind = "static")]
extern "C" {
    /// Native routine that reads `size` bytes starting at `buffer` and writes its output
    /// into `out`. The returned value is treated by `shape_inference_proto` as the number
    /// of bytes written to `out`.
    ///
    /// NOTE(review): the signature carries no output-capacity argument, so the callee
    /// cannot bound its writes to `out`. Callers must allocate enough room before calling.
    /// Presumably `buffer` holds a serialized ONNX `ModelProto` and the routine performs
    /// shape inference on it (going by its name). Confirm against the C implementation.
    fn onnx_proto_shape_inference(buffer: *const u8, size: size_t, out: *mut u8) -> size_t;
}
24
25const OUTPUT_SIZE_MULTIPLIER: usize = 10;
26
27pub fn shape_inference_proto(body: &[u8]) -> Vec<u8> {
29 let capacity = body.len() * OUTPUT_SIZE_MULTIPLIER;
30 let mut output = Vec::with_capacity(capacity);
31 unsafe {
32 output.set_len(capacity);
33 let out_size = onnx_proto_shape_inference(body.as_ptr(), body.len(), output.as_mut_ptr());
34 output.truncate(out_size);
35 }
36 output
37}
38
#[cfg(test)]
mod tests {
    use super::*;

    /// Pairs of (input model, expected shape-inferred model) fixture paths. The paths are
    /// relative to the working directory the tests run from.
    const FIXTURES: &[(&str, &str)] = &[
        ("tests/model.onnx", "tests/model-inferred.onnx"),
        ("tests/mean-reverse.onnx", "tests/mean-reverse-inferred.onnx"),
    ];

    #[cfg(feature = "proto")]
    #[test]
    fn inference() {
        /// Reads a fixture file and decodes it as an `onnx_pb::ModelProto`.
        fn open_model<P: AsRef<std::path::Path>>(path: P) -> onnx_pb::ModelProto {
            use prost::Message;
            let path = path.as_ref();
            let body = read_buf(path);
            onnx_pb::ModelProto::decode(body.as_slice())
                .unwrap_or_else(|e| panic!("failed to decode {}: {}", path.display(), e))
        }

        for &(input, expected) in FIXTURES {
            let buffer = open_model(input);
            let inferred = open_model(expected);
            let output = shape_inference(&buffer).unwrap();
            assert_eq!(output, inferred, "decoded output mismatch for {}", input);
        }
    }

    #[test]
    fn inference_proto() {
        for &(input, expected) in FIXTURES {
            let buffer = read_buf(input);
            let inferred = read_buf(expected);
            let output = shape_inference_proto(buffer.as_slice());
            assert_eq!(output, inferred, "serialized output mismatch for {}", input);
        }
    }

    /// Reads a whole fixture file into memory, panicking with the offending path on failure.
    fn read_buf<P: AsRef<std::path::Path>>(path: P) -> Vec<u8> {
        let path = path.as_ref();
        std::fs::read(path).unwrap_or_else(|e| panic!("failed to read {}: {}", path.display(), e))
    }
}