// ensemble/ensemble.rs

1use onnx_helpers::prelude::*;
2use onnx_pb::{save_model, tensor_proto::DataType};
3
4fn main() {
5    let mut graph = builder::Graph::new("stddev");
6    let x = graph.input("X").typed(DataType::Float).dim(1).dim(6).node();
7    let std = stddev(&mut graph, &x);
8    let mrev = mean_reverse(&mut graph, &x);
9    let graph = graph
10        .outputs_typed(std.with_name("stddev"), DataType::Float)
11        .outputs_typed(mrev.with_name("mean_reverse"), DataType::Float);
12    let model = graph.model().build();
13    save_model("ensemble.onnx", &model).unwrap();
14}
15
16fn stddev(graph: &mut builder::Graph, x: &Node) -> Node {
17    let two = graph.constant("two", 2.0f32);
18    (x - x.mean(1, true)).abs().pow(two).mean(1, true).sqrt()
19}
20
21fn mean_reverse(graph: &mut builder::Graph, x: &Node) -> Node {
22    let two = graph.constant("two", 2.0f32);
23    -(x - x.mean(1, true)) * two + x
24}