mod common;
use oxionnx::{Attributes, OpKind, Tensor};
use common::{assert_close, assert_shape, run_op};
/// Conformance check for a 1x1 `Conv` with unit strides/dilations and no padding.
///
/// A 1x1 kernel mixes channels independently at every pixel, so each output
/// channel is a weighted sum of the input channels at the same location.
#[test]
fn conformance_conv2d_1x1() {
    let attrs = {
        let mut a = Attributes::default();
        for (key, list) in [
            ("strides", vec![1, 1]),
            ("pads", vec![0, 0, 0, 0]),
            ("dilations", vec![1, 1]),
        ] {
            a.int_lists.insert(key.to_string(), list);
        }
        a.ints.insert("group".to_string(), 1);
        a
    };

    // Shape [1, 2, 2, 2]: two 2x2 input channels.
    let x = Tensor::new(
        vec![
            1.0, 2.0, 3.0, 4.0, // ch0
            5.0, 6.0, 7.0, 8.0, // ch1
        ],
        vec![1, 2, 2, 2],
    );

    // Shape [3, 2, 1, 1]: three 1x1 filters, each with one weight per input channel.
    let w = Tensor::new(
        vec![
            1.0, 0.0, // out0 = ch0
            0.0, 1.0, // out1 = ch1
            1.0, 1.0, // out2 = ch0 + ch1
        ],
        vec![3, 2, 1, 1],
    );

    // The kernel goes through the sixth argument. Presumably run_op treats it as an
    // initializer rather than a runtime input; confirm against common::run_op.
    let outputs = run_op(
        OpKind::Conv,
        vec!["input", "kernel"],
        vec!["out"],
        vec!["input"],
        vec![("input", x)],
        vec![("kernel", w)],
        attrs,
    );

    let y = outputs.get("out").unwrap();
    assert_shape(y, &[1, 3, 2, 2], "conv2d_1x1");
    assert_close(
        &y.data,
        &[
            1.0, 2.0, 3.0, 4.0, // out0: copy of ch0
            5.0, 6.0, 7.0, 8.0, // out1: copy of ch1
            6.0, 8.0, 10.0, 12.0, // out2: ch0 + ch1
        ],
        1e-5,
        "conv2d_1x1",
    );
}
/// Conformance check for a 3x3 `Conv` with one pixel of zero padding on every side.
///
/// Both the image and the kernel are all ones. Each output value is therefore the
/// number of kernel taps that land inside the unpadded 3x3 image.
#[test]
fn conformance_conv2d_3x3_pad1() {
    let attrs = {
        let mut a = Attributes::default();
        for (key, list) in [
            ("strides", vec![1, 1]),
            ("pads", vec![1, 1, 1, 1]),
            ("dilations", vec![1, 1]),
        ] {
            a.int_lists.insert(key.to_string(), list);
        }
        a.ints.insert("group".to_string(), 1);
        a
    };

    // The image and the kernel share the same all-ones [1, 1, 3, 3] tensor.
    let ones_3x3 = || Tensor::new(vec![1.0; 9], vec![1, 1, 3, 3]);

    let outputs = run_op(
        OpKind::Conv,
        vec!["input", "kernel"],
        vec!["out"],
        vec!["input"],
        vec![("input", ones_3x3())],
        vec![("kernel", ones_3x3())],
        attrs,
    );

    let y = outputs.get("out").unwrap();
    assert_shape(y, &[1, 1, 3, 3], "conv2d_3x3_pad1");
    // Tap counts: corners see 2x2 = 4 pixels, edges 2x3 = 6, and the centre 3x3 = 9.
    assert_close(
        &y.data,
        &[
            4.0, 6.0, 4.0, //
            6.0, 9.0, 6.0, //
            4.0, 6.0, 4.0,
        ],
        1e-5,
        "conv2d_3x3_pad1",
    );
}
/// Conformance check for `MaxPool` with non-overlapping 2x2 windows (stride 2, no padding).
#[test]
fn conformance_maxpool_2x2() {
    let attrs = {
        let mut a = Attributes::default();
        for (key, list) in [
            ("kernel_shape", vec![2, 2]),
            ("strides", vec![2, 2]),
            ("pads", vec![0, 0, 0, 0]),
        ] {
            a.int_lists.insert(key.to_string(), list);
        }
        a
    };

    // A 4x4 image holding 1..=16 in row-major order:
    //    1  2  3  4
    //    5  6  7  8
    //    9 10 11 12
    //   13 14 15 16
    let ramp: Vec<f32> = (1u8..=16).map(f32::from).collect();
    let x = Tensor::new(ramp, vec![1, 1, 4, 4]);

    let outputs = run_op(
        OpKind::MaxPool,
        vec!["input"],
        vec!["out"],
        vec!["input"],
        vec![("input", x)],
        vec![],
        attrs,
    );

    let y = outputs.get("out").unwrap();
    assert_shape(y, &[1, 1, 2, 2], "maxpool_2x2");
    // The values increase left-to-right and top-to-bottom, so each window's maximum
    // is its bottom-right element.
    assert_close(&y.data, &[6.0, 8.0, 14.0, 16.0], 1e-5, "maxpool_2x2");
}