use super::*;
use crate::optimizer::shape_inference::{infer_shapes, infer_shapes_with_diagnostics};
use crate::optimizer::test_utils::make_node;
/// Builds a `name -> shape` lookup from `(name, dims)` pairs.
///
/// Each name is copied into an owned `String` and each shape is cloned.
/// If a name appears more than once, the last pair wins.
fn shapes_map(pairs: &[(&str, Vec<usize>)]) -> HashMap<String, Vec<usize>> {
    let mut shapes = HashMap::with_capacity(pairs.len());
    for (name, dims) in pairs {
        shapes.insert((*name).to_owned(), dims.clone());
    }
    shapes
}
/// Builds a `name -> Tensor` map of constant initializers from `(name, data, shape)` triples.
///
/// Each tensor is created with `Tensor::new(data, shape)` from cloned vectors.
/// If a name appears more than once, the last triple wins.
fn weights_map(pairs: &[(&str, Vec<f32>, Vec<usize>)]) -> HashMap<String, Tensor> {
    let mut weights = HashMap::with_capacity(pairs.len());
    for (name, data, shape) in pairs {
        let tensor = Tensor::new(data.clone(), shape.clone());
        weights.insert((*name).to_owned(), tensor);
    }
    weights
}
/// GlobalAveragePool keeps the batch and channel dims and reduces each spatial dim to 1.
#[test]
fn test_shape_inference_global_avg_pool() {
    let input_shapes = shapes_map(&[("x", vec![1, 3, 8, 8])]);
    let weights = HashMap::new();
    let nodes = vec![make_node(
        OpKind::GlobalAveragePool,
        "gap",
        vec!["x"],
        vec!["y"],
    )];
    let inferred = infer_shapes(&nodes, &weights, &input_shapes);
    assert_eq!(inferred.get("y"), Some(&vec![1, 3, 1, 1]));
}
/// DepthToSpace with blocksize 2 moves channel blocks into space:
/// [1, 12, 2, 3] is expected to become [1, 3, 4, 6].
#[test]
fn test_shape_inference_depth_to_space() {
    let mut d2s = make_node(OpKind::DepthToSpace, "d2s", vec!["x"], vec!["y"]);
    d2s.attrs.ints.insert(String::from("blocksize"), 2);

    let input_shapes = shapes_map(&[("x", vec![1, 12, 2, 3])]);
    let weights = HashMap::new();
    let nodes = vec![d2s];

    let inferred = infer_shapes(&nodes, &weights, &input_shapes);
    assert_eq!(inferred.get("y"), Some(&vec![1, 3, 4, 6]));
}
/// Comparison ops (exercised here through `Equal`) broadcast their operands:
/// [2, 3] against [3] is expected to yield [2, 3].
#[test]
fn test_shape_inference_comparison_ops() {
    let input_shapes = shapes_map(&[("a", vec![2, 3]), ("b", vec![3])]);
    let weights = HashMap::new();
    let nodes = vec![make_node(OpKind::Equal, "eq", vec!["a", "b"], vec!["c"])];
    let inferred = infer_shapes(&nodes, &weights, &input_shapes);
    assert_eq!(inferred.get("c"), Some(&vec![2, 3]));
}
/// Dropout's data output and its mask output are both expected to mirror the input shape.
#[test]
fn test_shape_inference_dropout() {
    let input_shapes = shapes_map(&[("x", vec![4, 5, 6])]);
    let weights = HashMap::new();
    let nodes = vec![make_node(
        OpKind::Dropout,
        "drop",
        vec!["x"],
        vec!["y", "mask"],
    )];
    let inferred = infer_shapes(&nodes, &weights, &input_shapes);
    assert_eq!(inferred.get("y"), Some(&vec![4, 5, 6]));
    assert_eq!(inferred.get("mask"), Some(&vec![4, 5, 6]));
}
/// LSTM output shapes.
///
/// With x = [10, 2, 8] and w/r carrying 16 = 4 * 4 rows, the expected shapes are:
/// Y = [seq, num_directions, batch, hidden] = [10, 1, 2, 4],
/// and Y_h / Y_c = [num_directions, batch, hidden] = [1, 2, 4].
#[test]
fn test_shape_inference_lstm() {
    let input_shapes = shapes_map(&[
        ("x", vec![10, 2, 8]),
        ("w", vec![1, 16, 8]),
        ("r", vec![1, 16, 4]),
    ]);
    let weights = HashMap::new();
    let lstm = make_node(
        OpKind::LSTM,
        "lstm",
        vec!["x", "w", "r"],
        vec!["y", "y_h", "y_c"],
    );
    let nodes = vec![lstm];

    let inferred = infer_shapes(&nodes, &weights, &input_shapes);

    assert_eq!(inferred.get("y"), Some(&vec![10, 1, 2, 4]));
    assert_eq!(inferred.get("y_h"), Some(&vec![1, 2, 4]));
    assert_eq!(inferred.get("y_c"), Some(&vec![1, 2, 4]));
}
/// ReduceMean over axes 1 and 2 with keepdims = 1 keeps the reduced axes as size 1.
#[test]
fn test_shape_inference_reduce_mean() {
    let mut reduce = make_node(OpKind::ReduceMean, "rm", vec!["x"], vec!["y"]);
    reduce.attrs.int_lists.insert(String::from("axes"), vec![1, 2]);
    reduce.attrs.ints.insert(String::from("keepdims"), 1);

    let input_shapes = shapes_map(&[("x", vec![2, 3, 4, 5])]);
    let weights = HashMap::new();
    let nodes = vec![reduce];

    let inferred = infer_shapes(&nodes, &weights, &input_shapes);
    assert_eq!(inferred.get("y"), Some(&vec![2, 1, 1, 5]));
}
/// ReduceSum over axis 1 with keepdims = 0 drops that axis entirely.
#[test]
fn test_shape_inference_reduce_no_keepdims() {
    let mut reduce = make_node(OpKind::ReduceSum, "rs", vec!["x"], vec!["y"]);
    reduce.attrs.int_lists.insert(String::from("axes"), vec![1]);
    reduce.attrs.ints.insert(String::from("keepdims"), 0);

    let input_shapes = shapes_map(&[("x", vec![2, 3, 4])]);
    let weights = HashMap::new();
    let nodes = vec![reduce];

    let inferred = infer_shapes(&nodes, &weights, &input_shapes);
    assert_eq!(inferred.get("y"), Some(&vec![2, 4]));
}
/// MaxPool with a 2x2 kernel, stride 2 and no padding halves an 8x8 spatial extent to 4x4.
#[test]
fn test_shape_inference_pool() {
    let mut pool = make_node(OpKind::MaxPool, "mp", vec!["x"], vec!["y"]);
    for (key, value) in [
        ("kernel_shape", vec![2, 2]),
        ("strides", vec![2, 2]),
        ("pads", vec![0, 0, 0, 0]),
    ] {
        pool.attrs.int_lists.insert(key.to_string(), value);
    }

    let input_shapes = shapes_map(&[("x", vec![1, 3, 8, 8])]);
    let weights = HashMap::new();
    let nodes = vec![pool];

    let inferred = infer_shapes(&nodes, &weights, &input_shapes);
    assert_eq!(inferred.get("y"), Some(&vec![1, 3, 4, 4]));
}
/// SpaceToDepth with blocksize 2 is the inverse of the DepthToSpace case:
/// [1, 3, 4, 6] is expected to become [1, 12, 2, 3].
#[test]
fn test_shape_inference_space_to_depth() {
    let mut s2d = make_node(OpKind::SpaceToDepth, "s2d", vec!["x"], vec!["y"]);
    s2d.attrs.ints.insert(String::from("blocksize"), 2);

    let input_shapes = shapes_map(&[("x", vec![1, 3, 4, 6])]);
    let weights = HashMap::new();
    let nodes = vec![s2d];

    let inferred = infer_shapes(&nodes, &weights, &input_shapes);
    assert_eq!(inferred.get("y"), Some(&vec![1, 12, 2, 3]));
}
/// A variadic op (`VariadicMax`) broadcasts all three inputs together:
/// [2, 1, 4], [1, 3, 1] and [2, 1, 1] are expected to yield [2, 3, 4].
#[test]
fn test_shape_inference_variadic_broadcast() {
    let input_shapes = shapes_map(&[
        ("a", vec![2, 1, 4]),
        ("b", vec![1, 3, 1]),
        ("c", vec![2, 1, 1]),
    ]);
    let weights = HashMap::new();
    let nodes = vec![make_node(
        OpKind::VariadicMax,
        "vmax",
        vec!["a", "b", "c"],
        vec!["y"],
    )];
    let inferred = infer_shapes(&nodes, &weights, &input_shapes);
    assert_eq!(inferred.get("y"), Some(&vec![2, 3, 4]));
}
/// Einsum with the matrix-product equation "ij,jk->ik" maps [2, 3] x [3, 4] to [2, 4].
#[test]
fn test_shape_inference_einsum() {
    let mut einsum = make_node(OpKind::Einsum, "ein", vec!["a", "b"], vec!["y"]);
    einsum
        .attrs
        .strings
        .insert(String::from("equation"), String::from("ij,jk->ik"));

    let input_shapes = shapes_map(&[("a", vec![2, 3]), ("b", vec![3, 4])]);
    let weights = HashMap::new();
    let nodes = vec![einsum];

    let inferred = infer_shapes(&nodes, &weights, &input_shapes);
    assert_eq!(inferred.get("y"), Some(&vec![2, 4]));
}
/// ConvTranspose output shape.
///
/// Spatial size is expected to be (4 - 1) * 2 + 3 - 1 - 1 = 7
/// (stride 2, kernel 3, pads 1/1).
/// Output channels (3) are expected to come from w's second dim.
#[test]
fn test_shape_inference_conv_transpose() {
    let mut deconv = make_node(OpKind::ConvTranspose, "ct", vec!["x", "w"], vec!["y"]);
    for (key, value) in [
        ("kernel_shape", vec![3, 3]),
        ("strides", vec![2, 2]),
        ("pads", vec![1, 1, 1, 1]),
    ] {
        deconv.attrs.int_lists.insert(key.to_string(), value);
    }

    let input_shapes = shapes_map(&[("x", vec![1, 16, 4, 4]), ("w", vec![16, 3, 3, 3])]);
    let weights = HashMap::new();
    let nodes = vec![deconv];

    let inferred = infer_shapes(&nodes, &weights, &input_shapes);
    assert_eq!(inferred.get("y"), Some(&vec![1, 3, 7, 7]));
}
/// QuantizeLinear with scalar-like scale / zero-point keeps the data input's shape.
#[test]
fn test_shape_inference_quantize_dequantize() {
    let input_shapes = shapes_map(&[
        ("x", vec![2, 3, 4]),
        ("scale", vec![1]),
        ("zp", vec![1]),
    ]);
    let weights = HashMap::new();
    let quantize = make_node(
        OpKind::QuantizeLinear,
        "ql",
        vec!["x", "scale", "zp"],
        vec!["y"],
    );
    let nodes = vec![quantize];

    let inferred = infer_shapes(&nodes, &weights, &input_shapes);
    assert_eq!(inferred.get("y"), Some(&vec![2, 3, 4]));
}
/// An input with no known shape must produce a diagnostic.
///
/// The first diagnostic must name the offending node ("gap") and mention the
/// missing tensor name ("missing_input").
#[test]
fn test_shape_inference_diagnostics_ext() {
    let input_shapes = HashMap::new();
    let weights = HashMap::new();
    let nodes = vec![make_node(
        OpKind::GlobalAveragePool,
        "gap",
        vec!["missing_input"],
        vec!["y"],
    )];

    let (_, diagnostics) = infer_shapes_with_diagnostics(&nodes, &weights, &input_shapes);

    assert!(!diagnostics.is_empty());
    let first = &diagnostics[0];
    assert_eq!(first.node_name, "gap");
    assert!(first.message.contains("missing_input"));
}
/// GatherND output shape is indices.shape[:-1] followed by the trailing data dims.
///
/// Indices [2, 2] index the first 2 dims of data [2, 3, 4], so the expected
/// result is [2] + [4] = [2, 4].
#[test]
fn test_shape_inference_gather_nd() {
    let input_shapes = shapes_map(&[("data", vec![2, 3, 4]), ("indices", vec![2, 2])]);
    let weights = HashMap::new();
    let nodes = vec![make_node(
        OpKind::GatherND,
        "gnd",
        vec!["data", "indices"],
        vec!["y"],
    )];
    let inferred = infer_shapes(&nodes, &weights, &input_shapes);
    assert_eq!(inferred.get("y"), Some(&vec![2, 4]));
}
/// Range with constant start = 0, limit = 5, delta = 1 is expected to produce 5 elements.
#[test]
fn test_range_shape_inference() {
    // All three operands are constants, so no runtime input shapes are supplied.
    let input_shapes = HashMap::new();
    let weights = weights_map(&[
        ("start", vec![0.0], vec![1]),
        ("limit", vec![5.0], vec![1]),
        ("delta", vec![1.0], vec![1]),
    ]);
    let range = make_node(
        OpKind::Range,
        "range",
        vec!["start", "limit", "delta"],
        vec!["y"],
    );
    let nodes = vec![range];

    let inferred = infer_shapes(&nodes, &weights, &input_shapes);
    assert_eq!(inferred.get("y"), Some(&vec![5]));
}
/// HannWindow whose `size` is a constant 8 is expected to yield a 1-D window of length 8.
#[test]
fn test_hann_window_shape_inference() {
    let input_shapes = HashMap::new();
    let weights = weights_map(&[("size", vec![8.0], vec![1])]);
    let nodes = vec![make_node(OpKind::HannWindow, "hann", vec!["size"], vec!["y"])];
    let inferred = infer_shapes(&nodes, &weights, &input_shapes);
    assert_eq!(inferred.get("y"), Some(&vec![8]));
}
/// One-sided STFT of a length-16 signal with frame_step 4 and frame_length 8.
///
/// Expected shape is [batch, frames, bins, 2]:
/// frames = (16 - 8) / 4 + 1 = 3, and one-sided bins = 8 / 2 + 1 = 5.
#[test]
fn test_stft_shape_inference() {
    // The empty-string input marks the optional window operand as omitted.
    let mut stft = make_node(
        OpKind::STFT,
        "stft",
        vec!["signal", "frame_step", "", "frame_length"],
        vec!["y"],
    );
    stft.attrs.ints.insert(String::from("onesided"), 1);

    let input_shapes = shapes_map(&[("signal", vec![1, 16])]);
    let weights = weights_map(&[
        ("frame_step", vec![4.0], vec![1]),
        ("frame_length", vec![8.0], vec![1]),
    ]);
    let nodes = vec![stft];

    let inferred = infer_shapes(&nodes, &weights, &input_shapes);
    assert_eq!(inferred.get("y"), Some(&vec![1, 3, 5, 2]));
}
/// MelWeightMatrix with 40 mel bins and DFT length 512 is expected to be
/// [512 / 2 + 1, 40] = [257, 40].
#[test]
fn test_mel_weight_matrix_shape_inference() {
    let input_shapes = HashMap::new();
    let weights = weights_map(&[
        ("num_mel", vec![40.0], vec![1]),
        ("dft_len", vec![512.0], vec![1]),
    ]);
    let mel = make_node(
        OpKind::MelWeightMatrix,
        "mel",
        vec!["num_mel", "dft_len"],
        vec!["y"],
    );
    let nodes = vec![mel];

    let inferred = infer_shapes(&nodes, &weights, &input_shapes);
    assert_eq!(inferred.get("y"), Some(&vec![257, 40]));
}
/// ReduceL1 over axis 1 with keepdims = 0 collapses [2, 3] to [2].
#[test]
fn test_reduce_l1_shape_inference() {
    let mut reduce = make_node(OpKind::ReduceL1, "rl1", vec!["x"], vec!["y"]);
    reduce.attrs.int_lists.insert(String::from("axes"), vec![1]);
    reduce.attrs.ints.insert(String::from("keepdims"), 0);

    let input_shapes = shapes_map(&[("x", vec![2, 3])]);
    let weights = HashMap::new();
    let nodes = vec![reduce];

    let inferred = infer_shapes(&nodes, &weights, &input_shapes);
    assert_eq!(inferred.get("y"), Some(&vec![2]));
}
/// OneHot with axis = -1 and constant depth 5 appends a depth dim: [3, 4] -> [3, 4, 5].
#[test]
fn test_shape_inference_onehot() {
    let mut one_hot = make_node(
        OpKind::OneHot,
        "oh",
        vec!["indices", "depth", "values"],
        vec!["y"],
    );
    one_hot.attrs.ints.insert(String::from("axis"), -1);

    let input_shapes = shapes_map(&[("indices", vec![3, 4]), ("values", vec![2])]);
    let weights = weights_map(&[("depth", vec![5.0], vec![1])]);
    let nodes = vec![one_hot];

    let inferred = infer_shapes(&nodes, &weights, &input_shapes);
    assert_eq!(inferred.get("y"), Some(&vec![3, 4, 5]));
}
/// GRU output shapes.
///
/// With x = [8, 2, 6] and w/r carrying 12 = 3 * 4 rows, the expected shapes are
/// Y = [8, 1, 2, 4] and Y_h = [1, 2, 4].
#[test]
fn test_shape_inference_gru() {
    let input_shapes = shapes_map(&[
        ("x", vec![8, 2, 6]),
        ("w", vec![1, 12, 6]),
        ("r", vec![1, 12, 4]),
    ]);
    let weights = HashMap::new();
    let nodes = vec![make_node(
        OpKind::GRU,
        "gru",
        vec!["x", "w", "r"],
        vec!["y", "y_h"],
    )];

    let inferred = infer_shapes(&nodes, &weights, &input_shapes);

    assert_eq!(inferred.get("y"), Some(&vec![8, 1, 2, 4]));
    assert_eq!(inferred.get("y_h"), Some(&vec![1, 2, 4]));
}