1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
extern crate libc;
use libc::size_t;
// Foreign declarations resolved against the statically linked `onnx` library.
#[link(name = "onnx", kind = "static")]
extern "C" {
/// Native shape-inference entry point.
///
/// As called from `infer_shapes_proto`: `buffer`/`size` are the pointer and length of the
/// input bytes, `out` points at a caller-allocated output buffer, and the returned value is
/// used as the number of output bytes to keep.
///
/// NOTE(review): the signature has no output-capacity parameter, so the native side cannot
/// know how large `out` is — confirm how it bounds its writes.
fn onnx_proto_shape_inference(buffer: *const u8, size: size_t, out: *mut u8) -> size_t;
}
/// Factor by which the input length is multiplied to size the output buffer handed to the
/// native shape-inference call in `infer_shapes_proto`.
// NOTE(review): presumably a heuristic upper bound on output growth — confirm it is sufficient.
const OUTPUT_SIZE_MULTIPLIER: usize = 10;
/// Runs the native `onnx_proto_shape_inference` routine over `body` and returns the bytes it
/// produced.
///
/// The output buffer is allocated with `body.len() * OUTPUT_SIZE_MULTIPLIER` bytes, passed to
/// the native function, and then truncated to the length that function reports.
///
/// # Panics
///
/// Panics if `body.len() * OUTPUT_SIZE_MULTIPLIER` does not fit in `usize`. Previously this
/// would wrap in release builds and hand an undersized buffer to native code.
pub fn infer_shapes_proto(body: &[u8]) -> Vec<u8> {
    let capacity = body
        .len()
        .checked_mul(OUTPUT_SIZE_MULTIPLIER)
        .expect("output buffer size overflows usize");

    // Zero-initialise instead of `with_capacity` + `set_len`: every byte that can end up in
    // the returned vector is initialised even if the native call writes fewer bytes than it
    // reports.
    let mut output = vec![0u8; capacity];

    // SAFETY: `body.as_ptr()` is valid for reads of `body.len()` bytes and `output.as_mut_ptr()`
    // is valid for writes of `capacity` bytes for the duration of the call; neither buffer is
    // moved or freed while the call runs.
    // NOTE(review): the FFI signature carries no output capacity, so this relies on the native
    // side never writing more than `capacity` bytes — that cannot be enforced from here.
    let out_size =
        unsafe { onnx_proto_shape_inference(body.as_ptr(), body.len(), output.as_mut_ptr()) };

    // Keep only the bytes the native call reported; `truncate` is a no-op if the reported
    // size is not smaller than the buffer.
    output.truncate(out_size);
    output
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Reads the entire file at `path` into a byte vector, panicking on any I/O error.
    fn read_buf<P: AsRef<std::path::Path>>(path: P) -> Vec<u8> {
        std::fs::read(path).unwrap()
    }

    /// Runs shape inference on the fixture model and compares the result byte-for-byte with
    /// the expected inferred model fixture.
    #[test]
    fn read_proto() {
        let model = read_buf("tests/model.onnx");
        let expected = read_buf("tests/model-inferred.onnx");
        let actual = infer_shapes_proto(&model);
        assert_eq!(actual, expected);
    }
}