// onnx_helpers/lib.rs

1//! ONNX model construction helpers.
2
3pub mod builder;
4pub mod nodes;
5
6pub mod prelude {
7    pub use crate::builder;
8    pub use crate::nodes::ops::*;
9    pub use crate::nodes::*;
10}
11
#[cfg(test)]
mod tests {
    use super::*;

    use onnx_pb::open_model;
    use onnx_pb::tensor_proto::DataType;

    /// Regression check against a golden model.
    ///
    /// Loads the previously saved `tests/mean-reverse.onnx`, rebuilds the same
    /// graph with `builder::Graph`, and asserts the two models are equal. The
    /// fixture path is relative to the working directory (cargo runs tests
    /// from the package root).
    #[test]
    fn compare_with_prev_output() {
        let expected = open_model("tests/mean-reverse.onnx").unwrap();

        let mut graph_builder = builder::Graph::new("reverse");
        // Float input "X" declared with dims 1 and 6.
        let input = graph_builder
            .input("X")
            .typed(DataType::Float)
            .dim(1)
            .dim(6)
            .node();
        let factor = graph_builder.constant("two", 2.0f32);

        // Algebraically -(x - mean) * 2 + x == 2 * mean - x, i.e. each element
        // mirrored about the mean. NOTE(review): `mean(1, true)` presumably
        // reduces over axis 1 keeping the reduced dim — confirm in `nodes`.
        let reflected = -(&input - input.mean(1, true)) * factor + input;

        let finished = graph_builder.outputs_typed(reflected, DataType::Float);
        let actual = finished.model().build();
        assert_eq!(actual, expected);
    }
}