1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
/// The `attrs` module; all of its public items are glob re-exported by [`prelude`].
pub mod attrs;
/// Builder types (`Value`, `Node`, `Graph`, `Model`) used to assemble a model with
/// chained method calls.
pub mod builder;
/// The `nodes` module and its `ops` submodule; the public items of both are glob
/// re-exported by [`prelude`].
pub mod nodes;
pub mod prelude {
pub use crate::attrs::*;
pub use crate::builder;
pub use crate::nodes::ops::*;
pub use crate::nodes::*;
}
#[cfg(test)]
mod tests {
use super::*;
use prost::Message;
use onnx_pb::{tensor_proto::DataType, ModelProto};
/// Builds a one-node `ReduceMean` model through the builder API and checks that it
/// is equal to the reference model decoded from `tests/model.onnx` (going by the
/// test name, that file was generated from Python).
#[test]
fn compare_with_py_output() {
    // Reference model, decoded from the serialized fixture.
    let expected = ModelProto::decode(&read_buf("tests/model.onnx")).unwrap();

    // Same graph expressed with the builders: one float input `X` of shape [1, 10]
    // feeding a `ReduceMean` over axis 1 that writes to `Z`.
    let graph = builder::Graph::new("reduce-mean")
        .node(
            builder::Node::new("ReduceMean")
                .input("X")
                .output("Z")
                .attribute("axes", vec![1i64]),
        )
        .input(
            builder::Value::new("X")
                .typed(DataType::Float)
                .shape(vec![1, 10]),
        );
    let actual = builder::Model::new(graph).producer_name("reducer").build();

    assert_eq!(actual, expected);
}
/// Reads the entire file at `path` into memory and returns its bytes.
///
/// # Panics
///
/// Panics if the file cannot be opened or read; the panic message names the
/// offending path together with the underlying I/O error.
fn read_buf<P: AsRef<std::path::Path>>(path: P) -> Vec<u8> {
    let path = path.as_ref();
    // `fs::read` opens the file and fills a buffer in a single call, replacing the
    // manual `File::open` + `read_to_end` sequence.
    std::fs::read(path)
        .unwrap_or_else(|err| panic!("failed to read {}: {}", path.display(), err))
}
}