use anyhow::Result;
use ronn_core::{DataType, Tensor, TensorLayout};
/// Builds a row-major tensor of the given `shape` whose elements count upward
/// from zero (`0.0, 1.0, 2.0, ...`) in flat index order.
///
/// The values are generated as `f32` and passed to `Tensor::from_data` along
/// with `dtype`; how `dtype` is applied to them is up to `from_data`.
///
/// # Errors
/// Propagates any error returned by `Tensor::from_data`.
pub fn create_sequential_tensor(shape: Vec<usize>, dtype: DataType) -> Result<Tensor> {
    let element_count = shape.iter().product::<usize>();
    let values = (0..element_count)
        .map(|idx| idx as f32)
        .collect::<Vec<f32>>();
    Tensor::from_data(values, shape, dtype, TensorLayout::RowMajor)
}
/// Builds a row-major tensor of the given `shape` and `dtype` via `Tensor::ones`.
///
/// # Errors
/// Propagates any error returned by `Tensor::ones`.
pub fn create_ones_tensor(shape: Vec<usize>, dtype: DataType) -> Result<Tensor> {
    let layout = TensorLayout::RowMajor;
    Tensor::ones(shape, dtype, layout)
}
/// Builds a row-major tensor of the given `shape` and `dtype` via `Tensor::zeros`.
///
/// # Errors
/// Propagates any error returned by `Tensor::zeros`.
pub fn create_zeros_tensor(shape: Vec<usize>, dtype: DataType) -> Result<Tensor> {
    let layout = TensorLayout::RowMajor;
    Tensor::zeros(shape, dtype, layout)
}
/// Builds a row-major tensor of the given `shape` and `dtype` filled by
/// `Tensor::rand`.
///
/// NOTE(review): the value distribution and seeding are whatever `Tensor::rand`
/// implements — confirm before relying on specific ranges.
///
/// # Errors
/// Propagates any error returned by `Tensor::rand`.
pub fn create_random_tensor(shape: Vec<usize>, dtype: DataType) -> Result<Tensor> {
    let layout = TensorLayout::RowMajor;
    Tensor::rand(shape, dtype, layout)
}
/// Asserts that two tensors have the same shape and that every pair of
/// corresponding elements differs by less than `epsilon`.
///
/// Elements that compare exactly equal are always accepted. This makes matching
/// infinities pass (their difference `inf - inf` is NaN, which would otherwise
/// fail the tolerance check). It also lets `epsilon == 0.0` act as an exact
/// comparison. NaN compares unequal to everything, so any NaN element fails.
///
/// # Panics
/// Panics in three cases:
/// - the shapes differ;
/// - the flattened data lengths differ;
/// - at the first index whose values are not within tolerance.
///
/// # Errors
/// Returns an error if either tensor's data cannot be read via `to_vec`.
pub fn assert_tensor_approx_eq(a: &Tensor, b: &Tensor, epsilon: f32) -> Result<()> {
    assert_eq!(a.shape(), b.shape(), "Tensor shapes don't match");
    let a_data = a.to_vec()?;
    let b_data = b.to_vec()?;
    // `zip` below would silently stop at the shorter side; make a mismatch loud.
    assert_eq!(
        a_data.len(),
        b_data.len(),
        "Tensor data lengths don't match: {} vs {}",
        a_data.len(),
        b_data.len()
    );
    for (i, (&a_val, &b_val)) in a_data.iter().zip(b_data.iter()).enumerate() {
        let diff = (a_val - b_val).abs();
        assert!(
            // Exact equality first: handles equal infinities and epsilon == 0.0.
            a_val == b_val || diff < epsilon,
            "Tensors differ at index {}: {} vs {} (diff: {})",
            i,
            a_val,
            b_val,
            diff
        );
    }
    Ok(())
}
/// Asserts that every element of `tensor` compares equal to `0.0`.
///
/// Uses exact float equality, so `-0.0` is accepted and any NaN is rejected.
///
/// # Panics
/// Panics at the first element that is not equal to zero.
///
/// # Errors
/// Returns an error if the tensor's data cannot be read via `to_vec`.
pub fn assert_tensor_all_zeros(tensor: &Tensor) -> Result<()> {
    let values = tensor.to_vec()?;
    for (idx, value) in values.iter().copied().enumerate() {
        assert_eq!(value, 0.0, "Non-zero value at index {}: {}", idx, value);
    }
    Ok(())
}
/// Asserts that every element of `tensor` is within `1e-6` of `1.0`.
///
/// # Panics
/// Panics at the first element outside that tolerance (including NaN).
///
/// # Errors
/// Returns an error if the tensor's data cannot be read via `to_vec`.
pub fn assert_tensor_all_ones(tensor: &Tensor) -> Result<()> {
    const TOLERANCE: f32 = 1e-6;
    let values = tensor.to_vec()?;
    for (idx, value) in values.iter().copied().enumerate() {
        let close_to_one = (value - 1.0).abs() < TOLERANCE;
        assert!(close_to_one, "Non-one value at index {}: {}", idx, value);
    }
    Ok(())
}
/// Asserts that `tensor`'s flattened data matches `expected` element-wise,
/// within an absolute tolerance of `1e-5`.
///
/// Elements that compare exactly equal are always accepted, so expected
/// infinities match (otherwise `inf - inf` yields NaN and the tolerance check
/// fails). NaN compares unequal to everything and is always rejected.
///
/// # Panics
/// Panics if the lengths differ, or at the first index whose values are not
/// within tolerance.
///
/// # Errors
/// Returns an error if the tensor's data cannot be read via `to_vec`.
pub fn assert_tensor_eq(tensor: &Tensor, expected: &[f32]) -> Result<()> {
    let data = tensor.to_vec()?;
    assert_eq!(
        data.len(),
        expected.len(),
        "Tensor size mismatch: {} vs {}",
        data.len(),
        expected.len()
    );
    // Loop binding is `want` (not `expected`) to avoid shadowing the parameter.
    for (i, (&actual, &want)) in data.iter().zip(expected.iter()).enumerate() {
        assert!(
            actual == want || (actual - want).abs() < 1e-5,
            "Value mismatch at index {}: {} vs {}",
            i,
            actual,
            want
        );
    }
    Ok(())
}
/// Builds a small linear graph: `Input -> Conv -> ReLU`.
///
/// Tensors wired between nodes:
/// - `input_tensor` carries Input -> Conv;
/// - `conv_output` carries Conv -> ReLU.
///
/// The graph's declared input is `input_tensor` and its declared output is
/// `relu_output`. The Conv node carries a `kernel_size` attribute of `[3, 3]`.
///
/// # Errors
/// Propagates any error from `GraphBuilder::connect` or `GraphBuilder::build`.
pub fn create_test_graph() -> Result<ronn_core::ModelGraph> {
    use ronn_core::{AttributeValue, GraphBuilder};

    let mut graph_builder = GraphBuilder::new();

    // Source node producing the tensor the rest of the graph consumes.
    let source = graph_builder.add_op("Input", Some(String::from("input_layer")));
    graph_builder.add_output(source, "input_tensor");

    let conv = graph_builder.add_op("Conv", Some(String::from("conv_layer")));
    graph_builder
        .add_input(conv, "input_tensor")
        .add_output(conv, "conv_output")
        .add_attribute(conv, "kernel_size", AttributeValue::IntArray(vec![3, 3]));

    let relu = graph_builder.add_op("ReLU", Some(String::from("relu_layer")));
    graph_builder
        .add_input(relu, "conv_output")
        .add_output(relu, "relu_output");

    // Explicit producer -> consumer edges, named by the tensor they carry.
    graph_builder.connect(source, conv, "input_tensor")?;
    graph_builder.connect(conv, relu, "conv_output")?;

    graph_builder
        .set_inputs(vec![String::from("input_tensor")])
        .set_outputs(vec![String::from("relu_output")]);
    graph_builder.build()
}
/// Builds a two-branch graph that fans out from one input and joins at an Add:
///
/// ```text
///            +-> conv1 (kernel 3x3) -> relu1 -+
/// input -----|                                +-> add -> output_tensor
///            +-> conv2 (kernel 5x5) -> relu2 -+
/// ```
///
/// The graph's declared input is `input_tensor` and its declared output is
/// `output_tensor`.
///
/// # Errors
/// Propagates any error from `GraphBuilder::connect` or `GraphBuilder::build`.
pub fn create_complex_test_graph() -> Result<ronn_core::ModelGraph> {
    use ronn_core::{AttributeValue, GraphBuilder};

    let mut graph_builder = GraphBuilder::new();

    let source = graph_builder.add_op("Input", Some(String::from("input")));
    graph_builder.add_output(source, "input_tensor");

    // Left branch: 3x3 convolution followed by ReLU.
    let left_conv = graph_builder.add_op("Conv", Some(String::from("conv1")));
    graph_builder
        .add_input(left_conv, "input_tensor")
        .add_output(left_conv, "conv1_out")
        .add_attribute(left_conv, "kernel_size", AttributeValue::IntArray(vec![3, 3]));

    let left_relu = graph_builder.add_op("ReLU", Some(String::from("relu1")));
    graph_builder
        .add_input(left_relu, "conv1_out")
        .add_output(left_relu, "relu1_out");

    // Right branch: 5x5 convolution followed by ReLU, reading the same input.
    let right_conv = graph_builder.add_op("Conv", Some(String::from("conv2")));
    graph_builder
        .add_input(right_conv, "input_tensor")
        .add_output(right_conv, "conv2_out")
        .add_attribute(right_conv, "kernel_size", AttributeValue::IntArray(vec![5, 5]));

    let right_relu = graph_builder.add_op("ReLU", Some(String::from("relu2")));
    graph_builder
        .add_input(right_relu, "conv2_out")
        .add_output(right_relu, "relu2_out");

    // Join: element-wise Add of both branch outputs.
    let join = graph_builder.add_op("Add", Some(String::from("add")));
    graph_builder
        .add_input(join, "relu1_out")
        .add_input(join, "relu2_out")
        .add_output(join, "output_tensor");

    graph_builder.connect(source, left_conv, "input_tensor")?;
    graph_builder.connect(left_conv, left_relu, "conv1_out")?;
    graph_builder.connect(source, right_conv, "input_tensor")?;
    graph_builder.connect(right_conv, right_relu, "conv2_out")?;
    graph_builder.connect(left_relu, join, "relu1_out")?;
    graph_builder.connect(right_relu, join, "relu2_out")?;

    graph_builder
        .set_inputs(vec![String::from("input_tensor")])
        .set_outputs(vec![String::from("output_tensor")]);
    graph_builder.build()
}
/// Runs `operation` once and returns its result together with the wall-clock
/// time it took, as measured by `std::time::Instant`.
pub fn measure_time<F, R>(operation: F) -> (R, std::time::Duration)
where
    F: FnOnce() -> R,
{
    let started_at = std::time::Instant::now();
    let output = operation();
    (output, started_at.elapsed())
}
/// Rank-1 shapes for tests, with lengths growing by powers of ten from 1 to 1000.
pub fn test_shapes_1d() -> Vec<Vec<usize>> {
    let lengths: [usize; 4] = [1, 10, 100, 1000];
    lengths.iter().map(|&len| vec![len]).collect()
}
/// Rank-2 shapes for tests: a single element, small square and non-square
/// shapes, and two larger non-square ones.
pub fn test_shapes_2d() -> Vec<Vec<usize>> {
    let dims: [(usize, usize); 5] = [(1, 1), (2, 3), (4, 4), (10, 20), (100, 50)];
    dims.iter().map(|&(rows, cols)| vec![rows, cols]).collect()
}
/// Rank-3 shapes for tests: a single element, a small shape with distinct
/// extents per axis, and a larger one.
pub fn test_shapes_3d() -> Vec<Vec<usize>> {
    const SHAPES: [[usize; 3]; 3] = [[1, 1, 1], [2, 3, 4], [10, 20, 30]];
    SHAPES.iter().map(|dims| dims.to_vec()).collect()
}
/// Rank-4 shapes for tests: a single element, a small shape with distinct
/// extents per axis, and a larger `[1, 3, 224, 224]` shape.
pub fn test_shapes_4d() -> Vec<Vec<usize>> {
    let single = vec![1, 1, 1, 1];
    let small = vec![2, 3, 4, 5];
    // Presumably sized like one 3-channel 224x224 image in NCHW order — confirm with callers.
    let image_like = vec![1, 3, 224, 224];
    vec![single, small, image_like]
}
/// The full list of data types exercised by dtype-parametrised tests, in a
/// fixed order:
/// - floats first: F32, F16, BF16;
/// - then integers: I8, I32, I64, U8, U32;
/// - then Bool;
/// - F64 last.
pub fn test_data_types() -> Vec<DataType> {
    Vec::from([
        DataType::F32,
        DataType::F16,
        DataType::BF16,
        DataType::I8,
        DataType::I32,
        DataType::I64,
        DataType::U8,
        DataType::U32,
        DataType::Bool,
        DataType::F64,
    ])
}
/// A smaller dtype subset (F32, F16, I32) for tests that don't need full coverage.
pub fn common_data_types() -> Vec<DataType> {
    Vec::from([DataType::F32, DataType::F16, DataType::I32])
}