1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
extern crate libc;
#[cfg(feature = "proto")]
extern crate onnx_pb;
#[cfg(feature = "proto")]
extern crate prost;
use libc::size_t;
#[cfg(feature = "proto")]
mod proto;
#[cfg(feature = "proto")]
pub use self::proto::*;
// Foreign declarations resolved from the statically linked `onnx` library.
#[link(name = "onnx", kind = "static")]
extern "C" {
/// Native shape-inference entry point.
///
/// `shape_inference_proto` calls this with a pointer to the input bytes (`buffer`),
/// their length (`size`), and a pointer to a writable output buffer (`out`); it uses
/// the returned value as the number of output bytes to keep.
///
/// NOTE(review): no output capacity is passed, so presumably the callee assumes `out`
/// is large enough for whatever it writes — confirm against the native implementation.
fn onnx_proto_shape_inference(buffer: *const u8, size: size_t, out: *mut u8) -> size_t;
}
/// Factor applied to the input length to size the output buffer handed to the native
/// call in `shape_inference_proto` (capacity = `body.len() * OUTPUT_SIZE_MULTIPLIER`).
///
/// NOTE(review): 10 looks like a heuristic upper bound on how much the model can grow
/// during shape inference — confirm it is sufficient for all expected inputs.
const OUTPUT_SIZE_MULTIPLIER: usize = 10;
/// Runs the native shape inference on serialized model bytes and returns the
/// serialized result.
///
/// `body` (presumably a serialized ONNX `ModelProto` — the tests feed it `.onnx`
/// files) is handed to `onnx_proto_shape_inference` together with a scratch buffer of
/// `body.len() * OUTPUT_SIZE_MULTIPLIER` bytes. The returned vector holds the leading
/// bytes of that buffer, cut to the length reported by the native call.
///
/// The scratch buffer is zero-filled before the call, so the returned vector never
/// exposes uninitialized memory, even if the native routine reports more bytes than
/// it actually wrote or more than the buffer holds.
///
/// # Panics
///
/// Panics if `body.len() * OUTPUT_SIZE_MULTIPLIER` overflows `usize`, rather than
/// silently wrapping to an undersized buffer in release builds.
pub fn shape_inference_proto(body: &[u8]) -> Vec<u8> {
    let capacity = body
        .len()
        .checked_mul(OUTPUT_SIZE_MULTIPLIER)
        .expect("output buffer size overflows usize");
    // Zero-initialize instead of `with_capacity` + `set_len`: every byte the caller can
    // observe is initialized regardless of what the native side wrote.
    let mut output = vec![0u8; capacity];

    // SAFETY: `body.as_ptr()` is valid for reads of `body.len()` bytes and
    // `output.as_mut_ptr()` is valid for writes of `capacity` bytes for the duration of
    // the call. NOTE(review): the native routine is not told `capacity`; this relies on
    // it never writing more than `capacity` bytes — confirm against the C side.
    let out_size =
        unsafe { onnx_proto_shape_inference(body.as_ptr(), body.len(), output.as_mut_ptr()) };

    // `truncate` is a no-op when `out_size >= capacity`, so the result never grows past
    // the allocated (and initialized) buffer.
    output.truncate(out_size);
    output
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixture pairs: (model to run inference on, expected model after inference).
    const FIXTURES: &[(&str, &str)] = &[
        ("tests/model.onnx", "tests/model-inferred.onnx"),
        ("tests/mean-reverse.onnx", "tests/mean-reverse-inferred.onnx"),
    ];

    /// Reads the whole file at `path` into memory, panicking if it cannot be read.
    fn read_buf<P: AsRef<std::path::Path>>(path: P) -> Vec<u8> {
        std::fs::read(path).unwrap()
    }

    /// Reads the file at `path` and decodes it as an `onnx_pb::ModelProto`.
    #[cfg(feature = "proto")]
    fn open_model<P: AsRef<std::path::Path>>(path: P) -> onnx_pb::ModelProto {
        use prost::Message;

        let bytes = read_buf(path);
        onnx_pb::ModelProto::decode(bytes.as_slice()).unwrap()
    }

    /// Typed API: inference on each decoded fixture must equal the decoded expectation.
    #[cfg(feature = "proto")]
    #[test]
    fn inference() {
        for &(source, expected) in FIXTURES.iter() {
            let model = open_model(source);
            let want = open_model(expected);
            let got = shape_inference(&model).unwrap();
            assert_eq!(got, want);
        }
    }

    /// Raw-bytes API: inference output must match the expected file byte for byte.
    #[test]
    fn inference_proto() {
        for &(source, expected) in FIXTURES.iter() {
            let raw = read_buf(source);
            let want = read_buf(expected);
            let got = shape_inference_proto(raw.as_slice());
            assert_eq!(got, want);
        }
    }
}