mod common;
use std::collections::HashMap;
use oxionnx::{Attributes, Graph, OpKind, OptLevel, Session, Tensor};
use common::{assert_close, assert_shape, make_node_with_attrs, run_op};
/// Reshape a flat `[6]` vector into `[2, 3]`; the row-major element order must not change.
///
/// `shape` is a node input but not a graph input: it is handed to `run_op` through its
/// second tensor list (presumably as an initializer).
#[test]
fn conformance_reshape() {
    let expected: [f32; 6] = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
    let input = Tensor::new(expected.to_vec(), vec![6]);
    let target_shape = Tensor::new(vec![2.0, 3.0], vec![2]);

    let outputs = run_op(
        OpKind::Reshape,
        vec!["x", "shape"],
        vec!["out"],
        vec!["x"],
        vec![("x", input)],
        vec![("shape", target_shape)],
        Attributes::default(),
    );

    let reshaped = outputs.get("out").unwrap();
    assert_shape(reshaped, &[2, 3], "reshape");
    assert_close(&reshaped.data, &expected, 1e-5, "reshape");
}
/// Transpose a 2x3 matrix with `perm = [1, 0]`, expecting its 3x2 transpose.
#[test]
fn conformance_transpose() {
    let attrs = {
        let mut a = Attributes::default();
        a.int_lists.insert("perm".into(), vec![1, 0]);
        a
    };

    // Row-major [[1, 2, 3], [4, 5, 6]].
    let matrix = Tensor::new(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]);

    let outputs = run_op(
        OpKind::Transpose,
        vec!["x"],
        vec!["out"],
        vec!["x"],
        vec![("x", matrix)],
        vec![],
        attrs,
    );

    let transposed = outputs.get("out").unwrap();
    // Input columns become output rows: [[1, 4], [2, 5], [3, 6]].
    let expected: [f32; 6] = [1.0, 4.0, 2.0, 5.0, 3.0, 6.0];
    assert_shape(transposed, &[3, 2], "transpose");
    assert_close(&transposed.data, &expected, 1e-5, "transpose");
}
/// Concatenate a 2x2 and a 2x1 tensor along `axis = 1`, expecting a 2x3 result.
///
/// Unlike the `run_op`-based tests, this one builds the graph and session explicitly:
/// both operands are graph inputs fed at run time, and optimizations are disabled.
#[test]
fn conformance_concat_axis1() {
    let mut attrs = Attributes::default();
    attrs.ints.insert(String::from("axis"), 1);

    let left = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]);
    let right = Tensor::new(vec![5.0, 6.0], vec![2, 1]);

    let graph = Graph {
        nodes: vec![make_node_with_attrs(
            OpKind::Concat,
            "op0",
            &["a", "b"],
            &["out"],
            attrs,
        )],
        input_names: vec![String::from("a"), String::from("b")],
        output_names: vec![String::from("out")],
        ..Default::default()
    };

    let session = Session::builder()
        .with_optimization_level(OptLevel::None)
        .build_from_graph(graph, HashMap::new())
        .expect("build session");

    let feed: HashMap<&str, Tensor> = vec![("a", left), ("b", right)].into_iter().collect();
    let outputs = session.run(&feed).expect("run");

    let joined = outputs.get("out").unwrap();
    assert_shape(joined, &[2, 3], "concat_axis1");
    // Each output row is the matching row of `a` followed by the row of `b`.
    let expected: [f32; 6] = [1.0, 2.0, 5.0, 3.0, 4.0, 6.0];
    assert_close(&joined.data, &expected, 1e-5, "concat_axis1");
}
/// Squeeze the two size-1 axes (0 and 2) out of a `[1, 3, 1]` tensor, leaving `[3]`.
#[test]
fn conformance_squeeze() {
    let mut attrs = Attributes::default();
    attrs.int_lists.insert("axes".to_owned(), vec![0, 2]);

    let expected: [f32; 3] = [1.0, 2.0, 3.0];
    let input = Tensor::new(expected.to_vec(), vec![1, 3, 1]);

    let outputs = run_op(
        OpKind::Squeeze,
        vec!["x"],
        vec!["out"],
        vec!["x"],
        vec![("x", input)],
        vec![],
        attrs,
    );

    let squeezed = outputs.get("out").unwrap();
    assert_shape(squeezed, &[3], "squeeze");
    assert_close(&squeezed.data, &expected, 1e-5, "squeeze");
}
/// Insert size-1 axes at positions 0 and 2 of a `[3]` tensor, producing `[1, 3, 1]`.
#[test]
fn conformance_unsqueeze() {
    let mut attrs = Attributes::default();
    attrs.int_lists.insert("axes".to_owned(), vec![0, 2]);

    let expected: [f32; 3] = [1.0, 2.0, 3.0];
    let input = Tensor::new(expected.to_vec(), vec![3]);

    let outputs = run_op(
        OpKind::Unsqueeze,
        vec!["x"],
        vec!["out"],
        vec!["x"],
        vec![("x", input)],
        vec![],
        attrs,
    );

    let expanded = outputs.get("out").unwrap();
    assert_shape(expanded, &[1, 3, 1], "unsqueeze");
    assert_close(&expanded.data, &expected, 1e-5, "unsqueeze");
}
/// Flatten a `[2, 3, 4]` tensor with `axis = 1` into `[2, 12]`; element order is unchanged.
#[test]
fn conformance_flatten() {
    let mut attrs = Attributes::default();
    attrs.ints.insert("axis".to_owned(), 1);

    // Ramp 0.0, 1.0, ..., 23.0 so any reordering would be visible.
    let ramp: Vec<f32> = (0..24u8).map(f32::from).collect();
    let input = Tensor::new(ramp.clone(), vec![2, 3, 4]);

    let outputs = run_op(
        OpKind::Flatten,
        vec!["x"],
        vec!["out"],
        vec!["x"],
        vec![("x", input)],
        vec![],
        attrs,
    );

    let flat = outputs.get("out").unwrap();
    assert_shape(flat, &[2, 12], "flatten");
    assert_close(&flat.data, &ramp, 1e-5, "flatten");
}
/// Slice `x[1:7:2]` on a length-8 vector, expecting the elements `[1, 3, 5]` with shape `[3]`.
///
/// `starts`, `ends`, `axes` and `steps` are node inputs but not graph inputs: they are
/// handed to `run_op` through its second tensor list (presumably as initializers).
#[test]
fn conformance_slice() {
    let x = Tensor::new(vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0], vec![8]);
    let starts = Tensor::new(vec![1.0], vec![1]);
    let ends = Tensor::new(vec![7.0], vec![1]);
    let axes = Tensor::new(vec![0.0], vec![1]);
    let steps = Tensor::new(vec![2.0], vec![1]);
    let out = run_op(
        OpKind::Slice,
        vec!["x", "starts", "ends", "axes", "steps"],
        vec!["out"],
        vec!["x"],
        vec![("x", x)],
        vec![
            ("starts", starts),
            ("ends", ends),
            ("axes", axes),
            ("steps", steps),
        ],
        Attributes::default(),
    );
    let t = out.get("out").unwrap();
    // Like the other conformance tests, pin the output shape as well as the values:
    // a wrong extent from bad end/step handling must fail here, whatever
    // `assert_close` does with length mismatches.
    assert_shape(t, &[3], "slice");
    assert_close(&t.data, &[1.0, 3.0, 5.0], 1e-5, "slice");
}