use super::super::types::OptLevel;
use super::super::Session;
use crate::graph::{Attributes, Graph, Node, OpKind};
use crate::tensor::Tensor;
use std::collections::HashMap;
use oxionnx_core::{DType, TensorInfo};
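
/// Builds a one-node Relu graph whose input is `[batch_size, 3]`:
/// a dynamic leading dimension named "batch_size" and a static second
/// dimension of 3. Returns the graph and an empty weight map.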
fn build_dynamic_relu_graph() -> (Graph, HashMap<String, Tensor>) {
let node = Node {
op: OpKind::Relu,
name: "relu".to_string(),
inputs: vec!["input".to_string()],
outputs: vec!["output".to_string()],
attrs: Attributes::default(),
};
let graph = Graph {
nodes: vec![node],
input_names: vec!["input".to_string()],
output_names: vec!["output".to_string()],
input_infos: vec![TensorInfo {
name: "input".to_string(),
dtype: DType::F32,
shape: vec![None, Some(3)],
dim_params: vec![Some("batch_size".to_string()), None],
}],
output_infos: vec![TensorInfo {
name: "output".to_string(),
dtype: DType::F32,
shape: vec![None, Some(3)],
dim_params: vec![Some("batch_size".to_string()), None],
}],
..Default::default()
};
(graph, HashMap::new())
}
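
/// Builds a one-node Relu graph whose input is fully dynamic
/// (`[batch_size, seq_len]`) and has no declared output info, so both
/// dimensions must be resolved from the runtime input.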
fn build_dynamic_batch_seq_graph() -> (Graph, HashMap<String, Tensor>) {
let node = Node {
op: OpKind::Relu,
name: "relu".to_string(),
inputs: vec!["input".to_string()],
outputs: vec!["output".to_string()],
attrs: Attributes::default(),
};
let graph = Graph {
nodes: vec![node],
input_names: vec!["input".to_string()],
output_names: vec!["output".to_string()],
input_infos: vec![TensorInfo {
name: "input".to_string(),
dtype: DType::F32,
shape: vec![None, None],
dim_params: vec![Some("batch_size".to_string()), Some("seq_len".to_string())],
}],
output_infos: vec![],
..Default::default()
};
(graph, HashMap::new())
}
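
/// Builds a two-node graph whose inputs both use the symbolic dimension
/// "batch_size". Used to check that conflicting concrete values for the
/// same symbol are rejected.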
fn build_dual_input_dynamic_graph() -> (Graph, HashMap<String, Tensor>) {
let node_a = Node {
op: OpKind::Relu,
name: "relu_a".to_string(),
inputs: vec!["input_a".to_string()],
outputs: vec!["out_a".to_string()],
attrs: Attributes::default(),
};
let node_b = Node {
op: OpKind::Relu,
name: "relu_b".to_string(),
inputs: vec!["input_b".to_string()],
outputs: vec!["out_b".to_string()],
attrs: Attributes::default(),
};
let graph = Graph {
nodes: vec![node_a, node_b],
input_names: vec!["input_a".to_string(), "input_b".to_string()],
output_names: vec!["out_a".to_string(), "out_b".to_string()],
input_infos: vec![
TensorInfo {
name: "input_a".to_string(),
dtype: DType::F32,
shape: vec![None, Some(3)],
dim_params: vec![Some("batch_size".to_string()), None],
},
TensorInfo {
name: "input_b".to_string(),
dtype: DType::F32,
shape: vec![None, Some(5)],
dim_params: vec![Some("batch_size".to_string()), None],
},
],
output_infos: vec![],
..Default::default()
};
(graph, HashMap::new())
}
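
/// The same session should accept inputs with different batch sizes,
/// and `dynamic_dims` should reflect the most recent resolution.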
#[test]
fn test_dynamic_batch_size() {
let (graph, weights) = build_dynamic_relu_graph();
let session = Session::builder()
.with_optimization_level(OptLevel::None)
.build_from_graph(graph, weights)
.expect("build");
let input1 = Tensor::new(vec![-1.0, 2.0, 3.0], vec![1, 3]);
let out1 = session.run_one("input", input1).expect("batch=1 run");
let y1 = out1.get("output").expect("output");
assert_eq!(y1.shape, vec![1, 3]);
assert_eq!(y1.data, vec![0.0, 2.0, 3.0]);
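// Re-run the same session with batch=4; "batch_size" should rebind.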
let input4 = Tensor::new(
vec![
-1.0, 2.0, 3.0, 4.0, -5.0, 6.0, -7.0, 8.0, 9.0, 10.0, -11.0, 12.0,
],
vec![4, 3],
);
let out4 = session.run_one("input", input4).expect("batch=4 run");
let y4 = out4.get("output").expect("output");
assert_eq!(y4.shape, vec![4, 3]);
assert_eq!(
y4.data,
vec![0.0, 2.0, 3.0, 4.0, 0.0, 6.0, 0.0, 8.0, 9.0, 10.0, 0.0, 12.0]
);
let dims = session.dynamic_dims();
assert_eq!(dims.get("batch_size"), Some(&4));
}
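
/// Both symbolic dimensions ("batch_size" and "seq_len") are resolved
/// from the runtime input shape on every run.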
#[test]
fn test_dynamic_seq_len() {
let (graph, weights) = build_dynamic_batch_seq_graph();
let session = Session::builder()
.with_optimization_level(OptLevel::None)
.build_from_graph(graph, weights)
.expect("build");
let input_s2 = Tensor::new(vec![-1.0, 2.0, 3.0, -4.0], vec![2, 2]);
let out_s2 = session.run_one("input", input_s2).expect("seq=2 run");
let y_s2 = out_s2.get("output").expect("output");
assert_eq!(y_s2.shape, vec![2, 2]);
let input_s5 = Tensor::new(
vec![1.0, -2.0, 3.0, -4.0, 5.0, -6.0, 7.0, -8.0, 9.0, -10.0],
vec![2, 5],
);
let out_s5 = session.run_one("input", input_s5).expect("seq=5 run");
let y_s5 = out_s5.get("output").expect("output");
assert_eq!(y_s5.shape, vec![2, 5]);
assert_eq!(
y_s5.data,
vec![1.0, 0.0, 3.0, 0.0, 5.0, 0.0, 7.0, 0.0, 9.0, 0.0]
);
let dims = session.dynamic_dims();
assert_eq!(dims.get("batch_size"), Some(&2));
assert_eq!(dims.get("seq_len"), Some(&5));
}
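
/// A rank mismatch (3-D input against a declared 2-D shape) must fail
/// with an error that mentions the rank.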
#[test]
fn test_shape_validation_rank_error() {
let (graph, weights) = build_dynamic_relu_graph();
let session = Session::builder()
.with_optimization_level(OptLevel::None)
.build_from_graph(graph, weights)
.expect("build");
let bad_input = Tensor::new(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![1, 2, 3]);
let result = session.run_one("input", bad_input);
assert!(result.is_err());
let err_msg = format!("{}", result.expect_err("should error"));
assert!(
err_msg.contains("rank"),
"Error should mention rank: {err_msg}"
);
}
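
/// A mismatch on a static dimension (5 instead of the declared 3)
/// must fail with an error that mentions the static dim.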
#[test]
fn test_shape_validation_dim_error() {
let (graph, weights) = build_dynamic_relu_graph();
let session = Session::builder()
.with_optimization_level(OptLevel::None)
.build_from_graph(graph, weights)
.expect("build");
let bad_input = Tensor::new(
vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0],
vec![2, 5],
);
let result = session.run_one("input", bad_input);
assert!(result.is_err());
let err_msg = format!("{}", result.expect_err("should error"));
assert!(
err_msg.contains("static dim"),
"Error should mention static dim: {err_msg}"
);
}
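
/// Two inputs that bind the same symbol ("batch_size") to different
/// concrete values (2 vs. 3) must be rejected as inconsistent.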
#[test]
fn test_shape_validation_symbolic_consistency() {
let (graph, weights) = build_dual_input_dynamic_graph();
let session = Session::builder()
.with_optimization_level(OptLevel::None)
.build_from_graph(graph, weights)
.expect("build");
let input_a = Tensor::new(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]);
let input_b = Tensor::new(
vec![
1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0,
],
vec![3, 5],
);
let mut inputs = HashMap::new();
inputs.insert("input_a", input_a);
inputs.insert("input_b", input_b);
let result = session.run(&inputs);
assert!(result.is_err());
let err_msg = format!("{}", result.expect_err("should error"));
assert!(
err_msg.contains("inconsistent") || err_msg.contains("conflicting"),
"Error should mention inconsistency: {err_msg}"
);
}
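
/// `resolve_dynamic_shapes` maps each symbolic dimension name to the
/// concrete value taken from the corresponding runtime input.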
#[test]
fn test_resolve_dynamic_shapes_basic() {
let infos = vec![
TensorInfo {
name: "x".to_string(),
dtype: DType::F32,
shape: vec![None, Some(768)],
dim_params: vec![Some("batch_size".to_string()), None],
},
TensorInfo {
name: "y".to_string(),
dtype: DType::F32,
shape: vec![None, None],
dim_params: vec![Some("batch_size".to_string()), Some("seq_len".to_string())],
},
];
let tensor_x = Tensor::new(vec![0.0; 4 * 768], vec![4, 768]);
let tensor_y = Tensor::new(vec![0.0; 4 * 128], vec![4, 128]);
let mut inputs: HashMap<&str, &Tensor> = HashMap::new();
inputs.insert("x", &tensor_x);
inputs.insert("y", &tensor_y);
let dim_map = Session::resolve_dynamic_shapes(&infos, &inputs).expect("resolve should succeed");
assert_eq!(dim_map.get("batch_size"), Some(&4));
assert_eq!(dim_map.get("seq_len"), Some(&128));
}
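
/// With the memory pool enabled, the session must produce correct
/// results when the input shape changes between runs.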
#[test]
fn test_run_after_shape_change() {
let (graph, weights) = build_dynamic_relu_graph();
let session = Session::builder()
.with_optimization_level(OptLevel::None)
.with_memory_pool(true)
.build_from_graph(graph, weights)
.expect("build");
let input2 = Tensor::new(vec![-1.0, 2.0, 3.0, -4.0, 5.0, -6.0], vec![2, 3]);
let out2 = session.run_one("input", input2).expect("batch=2");
let y2 = out2.get("output").expect("output");
assert_eq!(y2.data, vec![0.0, 2.0, 3.0, 0.0, 5.0, 0.0]);
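// Shrink the batch; pooled buffers from the larger run must not leak
// stale values into the smaller output.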
let input1 = Tensor::new(vec![10.0, -20.0, 30.0], vec![1, 3]);
let out1 = session.run_one("input", input1).expect("batch=1");
let y1 = out1.get("output").expect("output");
assert_eq!(y1.data, vec![10.0, 0.0, 30.0]);
assert_eq!(y1.shape, vec![1, 3]);
let input3 = Tensor::new(
vec![1.0, -2.0, 3.0, -4.0, 5.0, -6.0, 7.0, -8.0, 9.0],
vec![3, 3],
);
let out3 = session.run_one("input", input3).expect("batch=3");
let y3 = out3.get("output").expect("output");
assert_eq!(y3.shape, vec![3, 3]);
assert_eq!(y3.data, vec![1.0, 0.0, 3.0, 0.0, 5.0, 0.0, 7.0, 0.0, 9.0]);
}
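
/// Growing from batch=1 to batch=100 forces the pool to reallocate;
/// shrinking back afterwards must still produce correct results.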
#[test]
fn test_dynamic_shape_memory_reallocation() {
let (graph, weights) = build_dynamic_relu_graph();
let session = Session::builder()
.with_optimization_level(OptLevel::None)
.with_memory_pool(true)
.build_from_graph(graph, weights)
.expect("build");
let small = Tensor::new(vec![1.0, 2.0, 3.0], vec![1, 3]);
let out_small = session.run_one("input", small).expect("small batch");
assert_eq!(out_small.get("output").expect("output").shape, vec![1, 3]);
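// Grow to batch=100; the output buffer must be reallocated to fit.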
let large_data: Vec<f32> = (0..300).map(|i| i as f32 - 150.0).collect();
let large = Tensor::new(large_data, vec![100, 3]);
let out_large = session.run_one("input", large).expect("large batch");
let y_large = out_large.get("output").expect("output");
assert_eq!(y_large.shape, vec![100, 3]);
for (i, &val) in y_large.data.iter().enumerate() {
let original = i as f32 - 150.0;
let expected = if original < 0.0 { 0.0 } else { original };
assert!(
(val - expected).abs() < 1e-6,
"Mismatch at index {i}: got {val}, expected {expected}"
);
}
let small2 = Tensor::new(vec![-5.0, 10.0, -15.0], vec![1, 3]);
let out_small2 = session.run_one("input", small2).expect("small batch 2");
let y2 = out_small2.get("output").expect("output");
assert_eq!(y2.data, vec![0.0, 10.0, 0.0]);
}