use crate::{
error::{OnnxError, Result},
tensor::Tensor,
};
use std::collections::HashMap;
use std::str::FromStr;
#[cfg(all(feature = "blas", not(feature = "naive-conv")))]
extern crate blas_src;
/// ONNX operators this runtime knows how to dispatch (see `execute_operator`).
///
/// Each variant corresponds to the ONNX `op_type` string with the same spelling and is
/// produced by the `FromStr` implementation. `Eq` and `Hash` are derived so operator
/// kinds can be used as `HashMap` / `HashSet` keys (e.g. for per-operator statistics).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum OperatorType {
    // Element-wise / linear algebra.
    Add,
    Mul,
    MatMul,
    // Convolution and activations.
    Conv,
    Relu,
    Sigmoid,
    // Shape manipulation.
    Reshape,
    Transpose,
    Concat,
    Slice,
    // Resampling, pooling and normalisation.
    Upsample,
    MaxPool,
    Softmax,
    NonMaxSuppression,
    BatchNormalization,
    Split,
    Gather,
    ConstantOfShape,
    Cast,
    Shape,
    Unsqueeze,
    Squeeze,
    Pad,
    // Further element-wise math.
    Div,
    Sub,
    Exp,
    Sqrt,
    Pow,
    ReduceMean,
    Identity,
    Resize,
}
impl FromStr for OperatorType {
    type Err = OnnxError;

    /// Resolves an ONNX `op_type` string (case-sensitive) to its `OperatorType`.
    ///
    /// # Errors
    /// Any name not listed below yields `OnnxError::unsupported_operation(name)`.
    fn from_str(s: &str) -> Result<Self> {
        let op = match s {
            "Add" => Self::Add,
            "Mul" => Self::Mul,
            "MatMul" => Self::MatMul,
            "Conv" => Self::Conv,
            "Relu" => Self::Relu,
            "Sigmoid" => Self::Sigmoid,
            "Reshape" => Self::Reshape,
            "Transpose" => Self::Transpose,
            "Concat" => Self::Concat,
            "Slice" => Self::Slice,
            "Upsample" => Self::Upsample,
            "MaxPool" => Self::MaxPool,
            "Softmax" => Self::Softmax,
            "NonMaxSuppression" => Self::NonMaxSuppression,
            "BatchNormalization" => Self::BatchNormalization,
            "Split" => Self::Split,
            "Gather" => Self::Gather,
            "ConstantOfShape" => Self::ConstantOfShape,
            "Cast" => Self::Cast,
            "Shape" => Self::Shape,
            "Unsqueeze" => Self::Unsqueeze,
            "Squeeze" => Self::Squeeze,
            "Pad" => Self::Pad,
            "Div" => Self::Div,
            "Sub" => Self::Sub,
            "Exp" => Self::Exp,
            "Sqrt" => Self::Sqrt,
            "Pow" => Self::Pow,
            "ReduceMean" => Self::ReduceMean,
            "Identity" => Self::Identity,
            "Resize" => Self::Resize,
            unknown => return Err(OnnxError::unsupported_operation(unknown)),
        };
        Ok(op)
    }
}
/// Parses an integer-list attribute such as `"[1, 2, 3]"` or `"1,2,3"` into `Vec<i64>`.
///
/// Surrounding whitespace and square brackets are ignored, and whitespace inside the
/// brackets is trimmed too, so `""`, `"[]"` and `"[ ]"` all yield an empty list.
///
/// # Errors
/// Returns a runtime error naming the offending entry if any element is not a valid `i64`.
fn parse_int_array(attr_value: &str) -> Result<Vec<i64>> {
    // The final `trim` matters for "[ ]": without it the single blank entry fails to parse.
    let cleaned = attr_value.trim().trim_matches(['[', ']']).trim();
    if cleaned.is_empty() {
        return Ok(Vec::new());
    }
    cleaned
        .split(',')
        .map(|s| {
            s.trim().parse::<i64>().map_err(|e| {
                OnnxError::runtime_error(format!("Failed to parse integer '{s}': {e}"))
            })
        })
        .collect()
}
/// Runs a single operator of kind `op_type` on `inputs`, using `attributes` for
/// operators that read node attributes.
///
/// Pure dispatch: every variant maps to exactly one `*_op` implementation, and the
/// result (or error) of that implementation is returned unchanged.
pub fn execute_operator(
    op_type: &OperatorType,
    inputs: &[Tensor],
    attributes: &std::collections::HashMap<String, String>,
) -> Result<Vec<Tensor>> {
    let attrs = attributes;
    match op_type {
        // Binary arithmetic and linear algebra.
        OperatorType::Add => add_op(inputs),
        OperatorType::Sub => sub_op(inputs),
        OperatorType::Mul => mul_op(inputs),
        OperatorType::Div => div_op(inputs),
        OperatorType::Pow => pow_op(inputs),
        OperatorType::MatMul => matmul_op(inputs),
        // Unary math and activations.
        OperatorType::Exp => exp_op(inputs),
        OperatorType::Sqrt => sqrt_op(inputs),
        OperatorType::Relu => relu_op(inputs),
        OperatorType::Sigmoid => sigmoid_op(inputs),
        OperatorType::Softmax => softmax_op(inputs, attrs),
        // Shape, layout and type manipulation.
        OperatorType::Reshape => reshape_op(inputs),
        OperatorType::Transpose => transpose_op(inputs, attrs),
        OperatorType::Concat => concat_op(inputs, attrs),
        OperatorType::Slice => slice_op(inputs, attrs),
        OperatorType::Split => split_op(inputs, attrs),
        OperatorType::Gather => gather_op(inputs, attrs),
        OperatorType::Unsqueeze => unsqueeze_op(inputs, attrs),
        OperatorType::Squeeze => squeeze_op(inputs, attrs),
        OperatorType::Pad => pad_op(inputs, attrs),
        OperatorType::Shape => shape_op(inputs),
        OperatorType::ConstantOfShape => constant_of_shape_op(inputs, attrs),
        OperatorType::Cast => cast_op(inputs, attrs),
        OperatorType::Identity => identity_op(inputs),
        // Neural-network layers and reductions.
        OperatorType::Conv => conv_op(inputs, attrs),
        OperatorType::MaxPool => maxpool_op(inputs, attrs),
        OperatorType::BatchNormalization => batch_norm_op(inputs, attrs),
        OperatorType::ReduceMean => reduce_mean_op(inputs, attrs),
        // Resampling and detection post-processing.
        OperatorType::Upsample => upsample_op(inputs, attrs),
        OperatorType::Resize => resize_op(inputs, attrs),
        OperatorType::NonMaxSuppression => nms_op(inputs, attrs),
    }
}
/// ONNX `Add`: delegates to `Tensor::add` on exactly two inputs.
///
/// With `formal-verification` enabled, a wrong input count panics via the precondition
/// `assert!` before the error path is reached, and the output shape is debug-asserted
/// to equal the first input's shape.
/// NOTE(review): if `Tensor::add` broadcasts, that postcondition may fire for valid
/// inputs — confirm.
#[cfg_attr(
    feature = "formal-verification",
    doc = "This function is formally verified using Why3 specifications"
)]
fn add_op(inputs: &[Tensor]) -> Result<Vec<Tensor>> {
    #[cfg(feature = "formal-verification")]
    {
        assert!(inputs.len() == 2, "Precondition: exactly 2 inputs required");
    }
    let (lhs, rhs) = match inputs {
        [lhs, rhs] => (lhs, rhs),
        _ => {
            return Err(OnnxError::invalid_dimensions(format!(
                "Add operator requires exactly 2 inputs, got {}",
                inputs.len()
            )))
        }
    };
    let sum = lhs.add(rhs)?;
    #[cfg(feature = "formal-verification")]
    {
        debug_assert_eq!(
            sum.shape(),
            lhs.shape(),
            "Postcondition: result shape matches input shape"
        );
    }
    Ok(vec![sum])
}
/// ONNX `Mul`: delegates to `Tensor::mul` on exactly two inputs.
///
/// With `formal-verification` enabled, a wrong input count panics via the precondition
/// `assert!`, and the output shape is debug-asserted to equal the first input's shape.
/// NOTE(review): if `Tensor::mul` broadcasts, that postcondition may fire for valid
/// inputs — confirm.
#[cfg_attr(
    feature = "formal-verification",
    doc = "This function is formally verified using Why3 specifications"
)]
fn mul_op(inputs: &[Tensor]) -> Result<Vec<Tensor>> {
    #[cfg(feature = "formal-verification")]
    {
        assert!(inputs.len() == 2, "Precondition: exactly 2 inputs required");
    }
    let (lhs, rhs) = match inputs {
        [lhs, rhs] => (lhs, rhs),
        _ => {
            return Err(OnnxError::invalid_dimensions(format!(
                "Mul operator requires exactly 2 inputs, got {}",
                inputs.len()
            )))
        }
    };
    let product = lhs.mul(rhs)?;
    #[cfg(feature = "formal-verification")]
    {
        debug_assert_eq!(
            product.shape(),
            lhs.shape(),
            "Postcondition: result shape matches input shape"
        );
    }
    Ok(vec![product])
}
/// ONNX `MatMul`: delegates to `Tensor::matmul` on exactly two inputs.
///
/// With `formal-verification` enabled, the function asserts (panicking) that both
/// operands are 2-D with matching inner dimensions before any error handling, and
/// debug-asserts that the output is `(rows of A) x (cols of B)`.
#[cfg_attr(
    feature = "formal-verification",
    doc = "This function is formally verified using Why3 specifications"
)]
fn matmul_op(inputs: &[Tensor]) -> Result<Vec<Tensor>> {
    #[cfg(feature = "formal-verification")]
    {
        assert!(inputs.len() == 2, "Precondition: exactly 2 inputs required");
        assert!(
            inputs[0].ndim() == 2,
            "Precondition: first input must be 2D"
        );
        assert!(
            inputs[1].ndim() == 2,
            "Precondition: second input must be 2D"
        );
        assert_eq!(
            inputs[0].shape()[1],
            inputs[1].shape()[0],
            "Precondition: inner dimensions must match"
        );
    }
    let (a, b) = match inputs {
        [a, b] => (a, b),
        _ => {
            return Err(OnnxError::invalid_dimensions(format!(
                "MatMul operator requires exactly 2 inputs, got {}",
                inputs.len()
            )))
        }
    };
    let product = a.matmul(b)?;
    #[cfg(feature = "formal-verification")]
    {
        debug_assert_eq!(product.ndim(), 2, "Postcondition: result must be 2D");
        debug_assert_eq!(
            product.shape()[0],
            a.shape()[0],
            "Postcondition: output rows match first input rows"
        );
        debug_assert_eq!(
            product.shape()[1],
            b.shape()[1],
            "Postcondition: output cols match second input cols"
        );
    }
    Ok(vec![product])
}
/// Resolved geometry of one 2-D convolution, shared by the Conv backends.
///
/// All sizes are in elements. Input is NCHW, the kernel is OIHW, and the output is
/// `[batch_size, channels_out, height_out, width_out]`. Only the top/left pads are
/// stored because the backends derive bottom/right coverage from the output size.
/// `Debug` is derived for diagnostics; `Clone`/`Copy` because every field is a plain `usize`.
#[derive(Debug, Clone, Copy)]
struct ConvParams {
    /// N: number of images in the batch.
    batch_size: usize,
    /// C_in: input channels (must equal the kernel's second dimension).
    channels_in: usize,
    /// H of the unpadded input.
    height_in: usize,
    /// W of the unpadded input.
    width_in: usize,
    /// C_out: output channels (the kernel's first dimension).
    channels_out: usize,
    /// Kernel height.
    kernel_h: usize,
    /// Kernel width.
    kernel_w: usize,
    /// Vertical stride.
    stride_h: usize,
    /// Horizontal stride.
    stride_w: usize,
    /// Zero rows conceptually added above the input.
    pad_top: usize,
    /// Zero columns conceptually added left of the input.
    pad_left: usize,
    /// Output height.
    height_out: usize,
    /// Output width.
    width_out: usize,
}
/// ONNX `Conv` on an NCHW input `X` with OIHW weights `W` and an optional bias `B`.
///
/// Attributes:
/// - `strides`: `[stride_h, stride_w]`. One value applies to both axes; the default is 1.
///   Entries that fail to parse fall back to 1, and zero strides are rejected.
/// - `pads`: `[top, left, bottom, right]`. With fewer than four entries the first two are
///   mirrored to bottom/right. Entries that fail to parse fall back to 0.
/// - `dilations`: only all-ones is supported. Anything else is rejected rather than
///   silently producing wrong results.
///
/// NOTE(review): when `pads` is absent this pads each side by `(k - 1) / 2`, which is
/// "same"-style for odd kernels at stride 1. The ONNX default is zero padding, and
/// `auto_pad` is not consulted. Kept for compatibility — confirm intended behaviour.
///
/// The bias may have shape `[C_out]` or `[1, C_out, 1, 1]`; other shapes are skipped
/// with a warning. Grouped convolution is not supported: the input-channel check only
/// accepts weights whose second dimension equals C_in (`group == 1`).
///
/// # Errors
/// Returns an error when:
/// - there are fewer than two inputs, or either tensor is not 4-D;
/// - the input channels do not match the weights;
/// - a stride is zero, or `dilations` is unsupported;
/// - the kernel is larger than the padded input;
/// - the selected backend fails.
fn conv_op(inputs: &[Tensor], attrs: &HashMap<String, String>) -> Result<Vec<Tensor>> {
    if inputs.len() < 2 {
        return Err(OnnxError::invalid_dimensions(format!(
            "Conv operator requires at least 2 inputs (input, kernel), got {}",
            inputs.len()
        )));
    }
    let input = &inputs[0];
    let kernel = &inputs[1];
    let bias = inputs.get(2);
    if input.ndim() != 4 || kernel.ndim() != 4 {
        return Err(OnnxError::invalid_dimensions(
            "Conv operator requires 4D tensors (NCHW format)".to_string(),
        ));
    }
    let input_shape = input.shape();
    let kernel_shape = kernel.shape();
    let (batch_size, channels_in, height_in, width_in) =
        (input_shape[0], input_shape[1], input_shape[2], input_shape[3]);
    let (channels_out, channels_in_kernel, kernel_h, kernel_w) =
        (kernel_shape[0], kernel_shape[1], kernel_shape[2], kernel_shape[3]);
    if channels_in != channels_in_kernel {
        return Err(OnnxError::invalid_dimensions(format!(
            "Input channels ({channels_in}) must match kernel input channels ({channels_in_kernel})"
        )));
    }
    log::debug!("Conv attributes: {attrs:?}");

    let strides = attrs
        .get("strides")
        .map(|s| parse_conv_usize_list(s, 1))
        .unwrap_or_else(|| vec![1, 1]);
    let stride_h = strides.first().copied().unwrap_or(1);
    let stride_w = strides.get(1).copied().unwrap_or(stride_h);
    // A zero stride would divide by zero when computing the output size.
    if stride_h == 0 || stride_w == 0 {
        return Err(OnnxError::invalid_dimensions(format!(
            "Conv strides must be positive, got {strides:?}"
        )));
    }
    if let Some(raw) = attrs.get("dilations") {
        let dilations = parse_conv_usize_list(raw, 1);
        if dilations.iter().any(|&d| d != 1) {
            return Err(OnnxError::invalid_dimensions(format!(
                "Conv dilations {dilations:?} are not supported (only 1)"
            )));
        }
    }

    let pads = attrs
        .get("pads")
        .map(|s| parse_conv_usize_list(s, 0))
        .unwrap_or_else(|| {
            // Heuristic "same" padding; `saturating_sub` avoids underflow for k == 0, and
            // the horizontal pads follow the kernel width rather than its height.
            let pad_h = kernel_h.saturating_sub(1) / 2;
            let pad_w = kernel_w.saturating_sub(1) / 2;
            vec![pad_h, pad_w, pad_h, pad_w]
        });
    let (pad_top, pad_left, pad_bottom, pad_right) = match pads.as_slice() {
        [top, left, bottom, right, ..] => (*top, *left, *bottom, *right),
        _ => {
            let top = pads.first().copied().unwrap_or(0);
            let left = pads.get(1).copied().unwrap_or(top);
            (top, left, top, left)
        }
    };

    let padded_h = height_in + pad_top + pad_bottom;
    let padded_w = width_in + pad_left + pad_right;
    // Guard the subtraction below against usize underflow.
    if kernel_h > padded_h || kernel_w > padded_w {
        return Err(OnnxError::invalid_dimensions(format!(
            "Conv kernel {kernel_h}x{kernel_w} is larger than padded input {padded_h}x{padded_w}"
        )));
    }
    let height_out = (padded_h - kernel_h) / stride_h + 1;
    let width_out = (padded_w - kernel_w) / stride_w + 1;
    let output_shape = [batch_size, channels_out, height_out, width_out];
    log::debug!("Conv: input {input_shape:?}, kernel {kernel_shape:?} -> output {output_shape:?}");
    log::debug!(
        "Conv: stride={stride_h}x{stride_w}, pad=[{pad_top},{pad_left},{pad_bottom},{pad_right}]"
    );

    // The backends index the flat buffers as row-major NCHW / OIHW.
    // `as_slice_memory_order` would also hand back column-major or permuted buffers,
    // which would be read incorrectly, so force standard layout. This is a no-op borrow
    // when the data is already row-major.
    let input_data = input.data();
    let kernel_data = kernel.data();
    let input_std = input_data.as_standard_layout();
    let kernel_std = kernel_data.as_standard_layout();
    let input_flat = input_std
        .as_slice()
        .expect("standard-layout array is contiguous");
    let kernel_flat = kernel_std
        .as_slice()
        .expect("standard-layout array is contiguous");

    let p = ConvParams {
        batch_size,
        channels_in,
        height_in,
        width_in,
        channels_out,
        kernel_h,
        kernel_w,
        stride_h,
        stride_w,
        pad_top,
        pad_left,
        height_out,
        width_out,
    };
    #[cfg(feature = "naive-conv")]
    let mut output_data = conv_naive(&p, input_flat, kernel_flat);
    #[cfg(all(feature = "blas", not(feature = "naive-conv")))]
    let mut output_data = conv_blas(&p, input_flat, kernel_flat)?;
    #[cfg(not(any(feature = "naive-conv", feature = "blas")))]
    let mut output_data = conv_im2col(&p, input_flat, kernel_flat)?;
    log::debug!("Conv: computed {} output values", output_data.len());

    if let Some(bias) = bias {
        let bias_shape = bias.shape();
        log::debug!("Conv: applying bias with shape {bias_shape:?}");
        let supported_bias =
            bias_shape == [channels_out] || bias_shape == [1, channels_out, 1, 1];
        if supported_bias {
            // Both accepted shapes hold exactly C_out values in channel order.
            let bias_vals: Vec<f32> = bias.data().iter().copied().collect();
            let hw = height_out * width_out;
            if hw > 0 {
                // NCHW output: the k-th H*W plane belongs to channel k % C_out.
                for (plane_idx, plane) in output_data.chunks_exact_mut(hw).enumerate() {
                    let b = bias_vals[plane_idx % channels_out];
                    plane.iter_mut().for_each(|v| *v += b);
                }
            }
        } else {
            log::warn!("Conv: unsupported bias shape {bias_shape:?}, skipping bias addition");
        }
    }

    let final_output = Tensor::from_shape_vec(&output_shape, output_data)?;
    log::debug!("Conv: final output shape {:?}", final_output.shape());
    Ok(vec![final_output])
}

/// Parses a Conv integer-list attribute (e.g. `"[1, 1]"`) leniently: every entry that is
/// not a valid `usize` is replaced by `fallback`.
fn parse_conv_usize_list(raw: &str, fallback: usize) -> Vec<usize> {
    raw.trim_start_matches('[')
        .trim_end_matches(']')
        .split(',')
        .map(|part| part.trim().parse::<usize>().unwrap_or(fallback))
        .collect()
}
/// Reference direct convolution, compiled only with the `naive-conv` feature.
///
/// `input` is a row-major NCHW buffer and `kernel` a row-major OIHW buffer matching `p`.
/// Returns the row-major NCHW output without bias. Taps that fall into padding
/// contribute nothing.
#[cfg(feature = "naive-conv")]
fn conv_naive(p: &ConvParams, input: &[f32], kernel: &[f32]) -> Vec<f32> {
    // Row-major element strides.
    let in_plane = p.height_in * p.width_in;
    let in_batch = p.channels_in * in_plane;
    let k_plane = p.kernel_h * p.kernel_w;
    let k_filter = p.channels_in * k_plane;
    let out_plane = p.height_out * p.width_out;
    let out_batch = p.channels_out * out_plane;

    // Maps a padded coordinate back to the unpadded input; `None` when it lies in padding.
    let unpad = |padded: usize, pad: usize, extent: usize| -> Option<usize> {
        padded.checked_sub(pad).filter(|&v| v < extent)
    };

    let mut out = vec![0.0f32; p.batch_size * out_batch];
    for n in 0..p.batch_size {
        for c_out in 0..p.channels_out {
            for h_out in 0..p.height_out {
                for w_out in 0..p.width_out {
                    let mut acc = 0.0f32;
                    for c_in in 0..p.channels_in {
                        for kh in 0..p.kernel_h {
                            if let Some(h_in) =
                                unpad(h_out * p.stride_h + kh, p.pad_top, p.height_in)
                            {
                                for kw in 0..p.kernel_w {
                                    if let Some(w_in) =
                                        unpad(w_out * p.stride_w + kw, p.pad_left, p.width_in)
                                    {
                                        let i_idx = n * in_batch
                                            + c_in * in_plane
                                            + h_in * p.width_in
                                            + w_in;
                                        let k_idx = c_out * k_filter
                                            + c_in * k_plane
                                            + kh * p.kernel_w
                                            + kw;
                                        acc += input[i_idx] * kernel[k_idx];
                                    }
                                }
                            }
                        }
                    }
                    out[n * out_batch + c_out * out_plane + h_out * p.width_out + w_out] = acc;
                }
            }
        }
    }
    out
}
/// im2col + GEMM convolution backend, used when neither `naive-conv` nor `blas` is enabled.
///
/// `input` must be a row-major NCHW buffer and `kernel` a row-major OIHW buffer matching
/// `p`. Returns the row-major NCHW output (bias is not applied here).
///
/// # Errors
/// Returns `invalid_dimensions` when the buffers do not match the sizes in `p`.
#[cfg(not(any(feature = "naive-conv", feature = "blas")))]
fn conv_im2col(p: &ConvParams, input: &[f32], kernel: &[f32]) -> Result<Vec<f32>> {
    use ndarray::{Array2, ArrayView2};
    let &ConvParams {
        batch_size,
        channels_in,
        height_in,
        width_in,
        channels_out,
        kernel_h,
        kernel_w,
        stride_h,
        stride_w,
        pad_top,
        pad_left,
        height_out,
        width_out,
    } = p;
    // One row per output position, one column per (c_in, kh, kw) kernel tap.
    let patch_len = channels_in * kernel_h * kernel_w;
    let n_patches = batch_size * height_out * width_out;
    let hw = height_out * width_out;
    let mut col = vec![0.0f32; n_patches * patch_len];
    let in_c_stride = height_in * width_in;
    let in_n_stride = channels_in * in_c_stride;
    for n in 0..batch_size {
        for h_out in 0..height_out {
            for w_out in 0..width_out {
                let row = n * hw + h_out * width_out + w_out;
                let row_base = row * patch_len;
                let mut col_idx = 0;
                for c_in in 0..channels_in {
                    for kh in 0..kernel_h {
                        let h_in_padded = h_out * stride_h + kh;
                        for kw in 0..kernel_w {
                            let w_in_padded = w_out * stride_w + kw;
                            // Taps landing in padding keep their 0.0 initial value.
                            if h_in_padded >= pad_top
                                && h_in_padded < height_in + pad_top
                                && w_in_padded >= pad_left
                                && w_in_padded < width_in + pad_left
                            {
                                let h_in = h_in_padded - pad_top;
                                let w_in = w_in_padded - pad_left;
                                col[row_base + col_idx] = input
                                    [n * in_n_stride + c_in * in_c_stride + h_in * width_in + w_in];
                            }
                            col_idx += 1;
                        }
                    }
                }
            }
        }
    }
    let col_mat = Array2::from_shape_vec((n_patches, patch_len), col)
        .map_err(|e| OnnxError::invalid_dimensions(format!("im2col reshape failed: {e}")))?;
    if kernel.len() != channels_out * patch_len {
        return Err(OnnxError::invalid_dimensions(format!(
            "kernel reshape failed: expected {} values, got {}",
            channels_out * patch_len,
            kernel.len()
        )));
    }
    // Borrow the kernel as a (C_out, patch_len) matrix instead of copying it.
    let w_mat = ArrayView2::from_shape((channels_out, patch_len), kernel)
        .map_err(|e| OnnxError::invalid_dimensions(format!("kernel reshape failed: {e}")))?;
    // (n_patches x patch_len) . (patch_len x C_out) -> (n_patches x C_out)
    let out_mat = col_mat.dot(&w_mat.t());
    // `dot` chooses its output memory order from the operands' strides and can return a
    // column-major matrix (e.g. when patch_len == 1), for which `as_slice` is `None`.
    // Normalise to row-major before reading it as a flat slice.
    let out_std = out_mat.as_standard_layout();
    let out_slice = out_std
        .as_slice()
        .expect("standard-layout array is contiguous");
    // Scatter the (position, channel) matrix back into NCHW order.
    let mut output = vec![0.0f32; batch_size * channels_out * hw];
    for n in 0..batch_size {
        for h in 0..height_out {
            for w in 0..width_out {
                let row = n * hw + h * width_out + w;
                for c_out in 0..channels_out {
                    output[n * channels_out * hw + c_out * hw + h * width_out + w] =
                        out_slice[row * channels_out + c_out];
                }
            }
        }
    }
    Ok(output)
}
/// im2col + BLAS `sgemm` convolution backend (feature `blas`, unless `naive-conv` is set).
///
/// `input` must be a row-major NCHW buffer and `kernel` a row-major OIHW buffer matching
/// `p`. Returns the row-major NCHW output (bias is not applied here).
///
/// # Errors
/// Returns `invalid_dimensions` if the kernel buffer length does not match `p`, or if a
/// GEMM dimension does not fit the `i32` that CBLAS requires.
#[cfg(all(feature = "blas", not(feature = "naive-conv")))]
fn conv_blas(p: &ConvParams, input: &[f32], kernel: &[f32]) -> Result<Vec<f32>> {
    let &ConvParams {
        batch_size,
        channels_in,
        height_in,
        width_in,
        channels_out,
        kernel_h,
        kernel_w,
        stride_h,
        stride_w,
        pad_top,
        pad_left,
        height_out,
        width_out,
    } = p;
    let patch_len = channels_in * kernel_h * kernel_w;
    let n_patches = batch_size * height_out * width_out;
    let hw = height_out * width_out;

    // BLAS trusts the dimensions it is given, so verify the buffer before the FFI call.
    if kernel.len() != channels_out * patch_len {
        return Err(OnnxError::invalid_dimensions(format!(
            "Conv: kernel has {} values, expected {}",
            kernel.len(),
            channels_out * patch_len
        )));
    }
    // Degenerate GEMM: the result is all zeros, and BLAS rejects leading dimensions of 0.
    if n_patches == 0 || channels_out == 0 || patch_len == 0 {
        return Ok(vec![0.0f32; batch_size * channels_out * hw]);
    }
    // `as i32` would silently truncate oversized dimensions.
    let to_blas_int = |value: usize, what: &str| -> Result<i32> {
        if value > i32::MAX as usize {
            Err(OnnxError::invalid_dimensions(format!(
                "Conv: {what} ({value}) exceeds the BLAS i32 range"
            )))
        } else {
            Ok(value as i32)
        }
    };
    let m = to_blas_int(n_patches, "output positions")?;
    let n = to_blas_int(channels_out, "output channels")?;
    let k = to_blas_int(patch_len, "patch length")?;

    let mut col = vec![0.0f32; n_patches * patch_len];
    let in_c_stride = height_in * width_in;
    let in_n_stride = channels_in * in_c_stride;
    for n in 0..batch_size {
        for h_out in 0..height_out {
            for w_out in 0..width_out {
                let row = n * hw + h_out * width_out + w_out;
                let row_base = row * patch_len;
                let mut col_idx = 0;
                for c_in in 0..channels_in {
                    for kh in 0..kernel_h {
                        let h_in_padded = h_out * stride_h + kh;
                        for kw in 0..kernel_w {
                            let w_in_padded = w_out * stride_w + kw;
                            // Taps landing in padding keep their 0.0 initial value.
                            if h_in_padded >= pad_top
                                && h_in_padded < height_in + pad_top
                                && w_in_padded >= pad_left
                                && w_in_padded < width_in + pad_left
                            {
                                let h_in = h_in_padded - pad_top;
                                let w_in = w_in_padded - pad_left;
                                col[row_base + col_idx] = input
                                    [n * in_n_stride + c_in * in_c_stride + h_in * width_in + w_in];
                            }
                            col_idx += 1;
                        }
                    }
                }
            }
        }
    }

    let mut out_flat = vec![0.0f32; n_patches * channels_out];
    // SAFETY: computes row-major C (m x n) = col (m x k) * kernel^T, where the kernel is
    // stored row-major as (n x k). `col` holds m*k values (allocated above), `kernel` holds
    // n*k values (checked above) and `out_flat` holds m*n values. m, n and k are all >= 1,
    // so lda = ldb = k and ldc = n are valid leading dimensions.
    unsafe {
        cblas::sgemm(
            cblas::Layout::RowMajor,
            cblas::Transpose::None,
            cblas::Transpose::Ordinary,
            m,
            n,
            k,
            1.0f32,
            &col,
            k,
            kernel,
            k,
            0.0f32,
            &mut out_flat,
            n,
        );
    }

    // Scatter the (position, channel) matrix back into NCHW order.
    let mut output = vec![0.0f32; batch_size * channels_out * hw];
    for n in 0..batch_size {
        for h in 0..height_out {
            for w in 0..width_out {
                let row = n * hw + h * width_out + w;
                for c_out in 0..channels_out {
                    output[n * channels_out * hw + c_out * hw + h * width_out + w] =
                        out_flat[row * channels_out + c_out];
                }
            }
        }
    }
    Ok(output)
}
/// ONNX `Relu`, delegated to `Tensor::relu` on exactly one input.
///
/// With `formal-verification` enabled, a wrong input count panics via the precondition
/// `assert!`. The output is debug-asserted to keep the input's shape and to contain no
/// negative values.
#[cfg_attr(
    feature = "formal-verification",
    doc = "This function is formally verified using Why3 specifications"
)]
fn relu_op(inputs: &[Tensor]) -> Result<Vec<Tensor>> {
    #[cfg(feature = "formal-verification")]
    {
        assert!(inputs.len() == 1, "Precondition: exactly 1 input required");
    }
    let x = match inputs {
        [x] => x,
        _ => {
            return Err(OnnxError::invalid_dimensions(format!(
                "Relu operator requires exactly 1 input, got {}",
                inputs.len()
            )))
        }
    };
    let activated = x.relu()?;
    #[cfg(feature = "formal-verification")]
    {
        debug_assert_eq!(
            activated.shape(),
            x.shape(),
            "Postcondition: result shape matches input shape"
        );
        for &value in activated.data() {
            debug_assert!(
                value >= 0.0,
                "Postcondition: all values must be non-negative"
            );
        }
    }
    Ok(vec![activated])
}
/// ONNX `Sigmoid`, delegated to `Tensor::sigmoid` on exactly one input.
///
/// With `formal-verification` enabled, a wrong input count panics via the precondition
/// `assert!`. The output is debug-asserted to keep the input's shape and to lie in the
/// closed interval `[0, 1]` (NaN is allowed through).
#[cfg_attr(
    feature = "formal-verification",
    doc = "This function is formally verified using Why3 specifications"
)]
fn sigmoid_op(inputs: &[Tensor]) -> Result<Vec<Tensor>> {
    #[cfg(feature = "formal-verification")]
    {
        assert!(inputs.len() == 1, "Precondition: exactly 1 input required");
    }
    if inputs.len() != 1 {
        return Err(OnnxError::invalid_dimensions(format!(
            "Sigmoid operator requires exactly 1 input, got {}",
            inputs.len()
        )));
    }
    let result = inputs[0].sigmoid()?;
    #[cfg(feature = "formal-verification")]
    {
        debug_assert_eq!(
            result.shape(),
            inputs[0].shape(),
            "Postcondition: result shape matches input shape"
        );
        for &value in result.data() {
            // The mathematical range is the open interval (0, 1). In f32, however,
            // 1/(1+e^-x) rounds to exactly 1.0 once x is above about 17 and underflows
            // to 0.0 for very negative x (presumably how `Tensor::sigmoid` computes it).
            // A NaN input also yields NaN. So only a closed-interval or NaN check is
            // safe for valid inputs.
            debug_assert!(
                value.is_nan() || (0.0..=1.0).contains(&value),
                "Postcondition: sigmoid output must be in [0,1], got {value}"
            );
        }
    }
    Ok(vec![result])
}
fn reshape_op(inputs: &[Tensor]) -> Result<Vec<Tensor>> {
if inputs.len() != 2 {
return Err(OnnxError::invalid_dimensions(format!(
"Reshape operator requires exactly 2 inputs (data, shape), got {}",
inputs.len()
)));
}
let data = &inputs[0];
let shape_tensor = &inputs[1];
let raw_shape: Vec<i64> = shape_tensor.data().iter().map(|&x| x as i64).collect();
log::debug!(
"Reshape: input shape {:?}, target shape {:?}",
data.shape(),
raw_shape
);
let input_shape = data.shape();
let total_elements: usize = input_shape.iter().product();
let mut new_shape = Vec::new();
let mut infer_dim_index = None;
let mut inferred_elements = 1;
for (i, &dim) in raw_shape.iter().enumerate() {
if dim == 0 {
if i < input_shape.len() {
new_shape.push(input_shape[i]);
inferred_elements *= input_shape[i];
} else {
return Err(OnnxError::invalid_dimensions(format!(
"Cannot copy dimension {i} from input shape {input_shape:?} (index out of bounds)"
)));
}
} else if dim == -1 {
if infer_dim_index.is_some() {
return Err(OnnxError::invalid_dimensions(
"Only one dimension can be inferred (-1) in reshape".to_string(),
));
}
infer_dim_index = Some(i);
new_shape.push(0); } else if dim > 0 {
new_shape.push(dim as usize);
inferred_elements *= dim as usize;
} else {
return Err(OnnxError::invalid_dimensions(format!(
"Invalid dimension {dim} in reshape"
)));
}
}
if let Some(infer_idx) = infer_dim_index {
if inferred_elements == 0 || total_elements % inferred_elements != 0 {
return Err(OnnxError::invalid_dimensions(format!(
"Cannot infer dimension: total elements {total_elements} is not divisible by product of known dimensions {inferred_elements}"
)));
}
new_shape[infer_idx] = total_elements / inferred_elements;
}
log::debug!("Reshape: computed new shape {new_shape:?}");
let result = data.reshape(&new_shape)?;
Ok(vec![result])
}
/// ONNX `Transpose`: permutes the axes of a single input.
///
/// The optional `perm` attribute is parsed as a comma-separated list, with surrounding
/// brackets allowed. If it cannot be parsed, a warning is logged and `None` is passed on,
/// so `Tensor::transpose_with_perm` applies its default permutation.
///
/// # Errors
/// Returns an error for a wrong input count, or a failure from `Tensor::transpose_with_perm`.
fn transpose_op(
    inputs: &[Tensor],
    attributes: &std::collections::HashMap<String, String>,
) -> Result<Vec<Tensor>> {
    let input = match inputs {
        [only] => only,
        _ => {
            return Err(OnnxError::invalid_dimensions(format!(
                "Transpose operator requires exactly 1 input, got {}",
                inputs.len()
            )))
        }
    };
    let perm: Option<Vec<usize>> = attributes.get("perm").and_then(|perm_str| {
        let parsed: std::result::Result<Vec<usize>, std::num::ParseIntError> = perm_str
            .trim_matches(['[', ']'])
            .split(',')
            .map(|axis| axis.trim().parse::<usize>())
            .collect();
        if let Ok(p) = parsed {
            log::debug!("Transpose: using perm attribute {p:?}");
            Some(p)
        } else {
            log::warn!("Transpose: invalid perm attribute '{perm_str}', using default");
            None
        }
    });
    let result = input.transpose_with_perm(perm.as_deref())?;
    log::debug!(
        "Transpose: input {:?} -> output {:?} (perm: {:?})",
        input.shape(),
        result.shape(),
        perm
    );
    Ok(vec![result])
}
/// ONNX `Concat`: joins all inputs along `axis`.
///
/// `axis` comes from the attribute, defaults to 0, and counts from the end when negative.
/// An unparsable `axis` attribute falls back to 0. A single input is returned as a clone.
///
/// # Errors
/// Returns an error when there are no inputs, when `axis` is outside `[-rank, rank)`
/// (rank taken from the first input), or when `Tensor::concat` fails (e.g. mismatched
/// shapes).
fn concat_op(
    inputs: &[Tensor],
    attributes: &std::collections::HashMap<String, String>,
) -> Result<Vec<Tensor>> {
    if inputs.is_empty() {
        return Err(OnnxError::invalid_dimensions(
            "Concat operator requires at least 1 input".to_string(),
        ));
    }
    let axis = attributes
        .get("axis")
        .and_then(|s| s.parse::<i32>().ok())
        .unwrap_or(0);
    // Normalise in i64 so an axis below -rank is reported as an error. A plain
    // `as usize` cast would instead wrap it to a huge index.
    let rank = inputs[0].ndim() as i64;
    let signed_axis = if axis < 0 {
        i64::from(axis) + rank
    } else {
        i64::from(axis)
    };
    if !(0..rank).contains(&signed_axis) {
        return Err(OnnxError::invalid_dimensions(format!(
            "Concat axis {axis} out of bounds for tensor with {rank} dimensions"
        )));
    }
    let normalized_axis = signed_axis as usize;
    if inputs.len() == 1 {
        return Ok(vec![inputs[0].clone()]);
    }
    let tensor_refs: Vec<&Tensor> = inputs.iter().collect();
    let result = Tensor::concat(&tensor_refs, normalized_axis)?;
    log::debug!(
        "Concat: axis {} shapes {:?} -> output {:?}",
        normalized_axis,
        inputs.iter().map(|t| t.shape()).collect::<Vec<_>>(),
        result.shape()
    );
    Ok(vec![result])
}
/// ONNX `Slice`.
///
/// - Opset >= 10 form (three or more inputs): `starts`, `ends` and the optional `axes` /
///   `steps` come from inputs 1-4. They are stored as float tensors and converted to `i64`.
/// - Older form: the same lists are read from the `starts` / `ends` / `axes` / `steps`
///   attributes. Surrounding brackets are accepted, as with every other list attribute
///   in this module. An empty or unparsable list counts as absent.
/// - If neither form supplies both `starts` and `ends`, the input is returned unchanged
///   with a warning.
///
/// NOTE(review): in the tensor form, `ends` entries near -1 and entries above
/// ~0.9 * i64::MAX (the ONNX "slice to the end" idiom, INT64_MAX) are both mapped to -1.
/// Under plain ONNX semantics -1 means "stop before the last element". This is only
/// correct if `Tensor::slice` treats -1 as "to the end" — confirm.
///
/// # Errors
/// Returns an error when there are no inputs, or when `Tensor::slice` fails.
fn slice_op(
    inputs: &[Tensor],
    attributes: &std::collections::HashMap<String, String>,
) -> Result<Vec<Tensor>> {
    if inputs.is_empty() {
        return Err(OnnxError::invalid_dimensions(
            "Slice operator requires at least 1 input".to_string(),
        ));
    }
    let data = &inputs[0];
    if inputs.len() >= 3 {
        let starts_tensor = &inputs[1];
        let ends_tensor = &inputs[2];
        let starts: Vec<i64> = starts_tensor
            .data()
            .iter()
            .map(|&x| x.round() as i64)
            .collect();
        let ends: Vec<i64> = ends_tensor
            .data()
            .iter()
            .map(|&x| {
                if x < -0.5 && x > -1.5 {
                    -1
                } else if x > (i64::MAX as f32 * 0.9) {
                    -1
                } else {
                    x.round() as i64
                }
            })
            .collect();
        log::debug!("Slice: raw starts tensor data: {:?}", starts_tensor.data());
        log::debug!("Slice: raw ends tensor data: {:?}", ends_tensor.data());
        log::debug!("Slice: parsed starts: {starts:?}");
        log::debug!("Slice: parsed ends: {ends:?}");
        let axes = inputs
            .get(3)
            .map(|t| t.data().iter().map(|&x| x as i64).collect::<Vec<_>>());
        let steps = inputs
            .get(4)
            .map(|t| t.data().iter().map(|&x| x as i64).collect::<Vec<_>>());
        let result = data.slice(&starts, &ends, axes.as_deref(), steps.as_deref())?;
        log::debug!(
            "Slice: input {:?} -> output {:?} (starts: {:?}, ends: {:?}, axes: {:?})",
            data.shape(),
            result.shape(),
            starts,
            ends,
            axes
        );
        Ok(vec![result])
    } else {
        // Attribute lists may be written as "[a,b]" or "a,b". Before, brackets made the
        // parse fail and the input was silently returned unsliced.
        let parse_list = |key: &str| -> Option<Vec<i64>> {
            let raw = attributes.get(key)?;
            let cleaned = raw.trim().trim_matches(['[', ']']).trim();
            if cleaned.is_empty() {
                return None;
            }
            cleaned
                .split(',')
                .map(|s| s.trim().parse::<i64>().ok())
                .collect::<Option<Vec<_>>>()
        };
        match (parse_list("starts"), parse_list("ends")) {
            (Some(starts), Some(ends)) => {
                let axes = parse_list("axes");
                let steps = parse_list("steps");
                let result = data.slice(&starts, &ends, axes.as_deref(), steps.as_deref())?;
                Ok(vec![result])
            }
            _ => {
                log::warn!("Slice operator missing required attributes - returning input tensor");
                Ok(vec![data.clone()])
            }
        }
    }
}
/// ONNX `Upsample` placeholder: currently returns the first input unchanged.
///
/// The `mode` attribute (default `"nearest"`) is read but not used, and scales are
/// ignored, so the output shape is NOT upsampled. A warning is logged on every call.
///
/// # Errors
/// Returns an error only when there are no inputs.
fn upsample_op(
    inputs: &[Tensor],
    attributes: &std::collections::HashMap<String, String>,
) -> Result<Vec<Tensor>> {
    match inputs.first() {
        None => Err(OnnxError::invalid_dimensions(
            "Upsample operator requires at least 1 input".to_string(),
        )),
        Some(first) => {
            let _mode = attributes
                .get("mode")
                .map_or("nearest", |mode| mode.as_str());
            log::warn!("Upsample operator is simplified - returning input tensor");
            Ok(vec![first.clone()])
        }
    }
}
fn maxpool_op(
inputs: &[Tensor],
attributes: &std::collections::HashMap<String, String>,
) -> Result<Vec<Tensor>> {
if inputs.len() != 1 {
return Err(OnnxError::invalid_dimensions(format!(
"MaxPool operator requires exactly 1 input, got {}",
inputs.len()
)));
}
let input = &inputs[0];
if input.ndim() != 4 {
return Err(OnnxError::invalid_dimensions(
"MaxPool operator requires 4D input tensor (NCHW format)".to_string(),
));
}
let kernel_shape = parse_int_array(
attributes
.get("kernel_shape")
.unwrap_or(&"[2,2]".to_string()),
)?;
let strides = parse_int_array(attributes.get("strides").unwrap_or(&"[1,1]".to_string()))?;
let pads = parse_int_array(attributes.get("pads").unwrap_or(&"[0,0,0,0]".to_string()))?;
let kernel_h = kernel_shape.first().copied().unwrap_or(2) as usize;
let kernel_w = kernel_shape.get(1).copied().unwrap_or(2) as usize;
let stride_h = strides.first().copied().unwrap_or(1) as usize;
let stride_w = strides.get(1).copied().unwrap_or(1) as usize;
let pad_top = pads.first().copied().unwrap_or(0) as usize;
let pad_left = pads.get(1).copied().unwrap_or(0) as usize;
let pad_bottom = pads.get(2).copied().unwrap_or(0) as usize;
let pad_right = pads.get(3).copied().unwrap_or(0) as usize;
let input_shape = input.shape();
let batch_size = input_shape[0];
let channels = input_shape[1];
let input_h = input_shape[2];
let input_w = input_shape[3];
let output_h = (input_h + pad_top + pad_bottom - kernel_h) / stride_h + 1;
let output_w = (input_w + pad_left + pad_right - kernel_w) / stride_w + 1;
let output_shape = [batch_size, channels, output_h, output_w];
log::debug!("MaxPool: {input_h}x{input_w} -> {output_h}x{output_w}, kernel={kernel_h}x{kernel_w}, stride={stride_h}x{stride_w}, pad=[{pad_top},{pad_left},{pad_bottom},{pad_right}]");
let mut output_data = Vec::with_capacity(output_shape.iter().product());
let input_data = input.data();
for batch in 0..batch_size {
for channel in 0..channels {
for out_h in 0..output_h {
for out_w in 0..output_w {
let mut max_val = f32::NEG_INFINITY;
for kh in 0..kernel_h {
for kw in 0..kernel_w {
let in_h = out_h * stride_h + kh;
let in_w = out_w * stride_w + kw;
if in_h >= pad_top
&& in_h < input_h + pad_top
&& in_w >= pad_left
&& in_w < input_w + pad_left
{
let actual_h = in_h - pad_top;
let actual_w = in_w - pad_left;
if actual_h < input_h && actual_w < input_w {
let val = input_data[[batch, channel, actual_h, actual_w]];
max_val = max_val.max(val);
}
}
}
}
if max_val == f32::NEG_INFINITY {
max_val = 0.0;
}
output_data.push(max_val);
}
}
}
}
let result = Tensor::from_shape_vec(&output_shape, output_data)?;
log::debug!("MaxPool: output shape {:?}", result.shape());
Ok(vec![result])
}
fn softmax_op(
inputs: &[Tensor],
attributes: &std::collections::HashMap<String, String>,
) -> Result<Vec<Tensor>> {
if inputs.len() != 1 {
return Err(OnnxError::invalid_dimensions(format!(
"Softmax operator requires exactly 1 input, got {}",
inputs.len()
)));
}
let input = &inputs[0];
let axis = attributes
.get("axis")
.and_then(|s| s.parse::<i32>().ok())
.unwrap_or(-1);
let normalized_axis = if axis < 0 {
(input.ndim() as i32 + axis) as usize
} else {
axis as usize
};
if normalized_axis >= input.ndim() {
return Err(OnnxError::invalid_dimensions(format!(
"Softmax axis {} out of bounds for tensor with {} dimensions",
axis,
input.ndim()
)));
}
log::debug!(
"Softmax: input shape {:?}, axis {}",
input.shape(),
normalized_axis
);
if input.ndim() <= 2 || normalized_axis == input.ndim() - 1 {
let result = input.softmax()?;
log::debug!(
"Softmax: used tensor method, output shape {:?}",
result.shape()
);
Ok(vec![result])
} else {
let input_shape = input.shape();
let total_elements = input_shape.iter().product::<usize>();
let axis_size = input_shape[normalized_axis];
log::debug!("Softmax: complex case - shape {input_shape:?}, axis {normalized_axis}, axis_size {axis_size}");
let mut output = crate::tensor::Tensor::zeros(input_shape);
let num_slices = total_elements / axis_size;
let mut strides = vec![1; input.ndim()];
for i in (0..input.ndim() - 1).rev() {
strides[i] = strides[i + 1] * input_shape[i + 1];
}
let axis_stride = strides[normalized_axis];
for slice_idx in 0..num_slices {
let mut base_idx = 0;
let mut remaining = slice_idx;
for dim in 0..input.ndim() {
if dim != normalized_axis {
let dim_size = input_shape[dim];
let coord = remaining % dim_size;
remaining /= dim_size;
base_idx += coord * strides[dim];
}
}
let mut axis_values = Vec::with_capacity(axis_size);
for i in 0..axis_size {
let idx = base_idx + i * axis_stride;
let mut coords = vec![0; input.ndim()];
let mut temp_idx = idx;
for dim in (0..input.ndim()).rev() {
coords[dim] = temp_idx % input_shape[dim];
temp_idx /= input_shape[dim];
}
let value = input.data()[&*coords];
axis_values.push(value);
}
let max_val = axis_values.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
let exp_values: Vec<f32> = axis_values.iter().map(|&x| (x - max_val).exp()).collect();
let sum_exp: f32 = exp_values.iter().sum();
if sum_exp == 0.0 || !sum_exp.is_finite() {
let uniform_val = 1.0 / axis_size as f32;
for i in 0..axis_size {
let idx = base_idx + i * axis_stride;
let mut coords = vec![0; input.ndim()];
let mut temp_idx = idx;
for dim in (0..input.ndim()).rev() {
coords[dim] = temp_idx % input_shape[dim];
temp_idx /= input_shape[dim];
}
output.data_mut()[&*coords] = uniform_val;
}
} else {
for (i, &exp_val) in exp_values.iter().enumerate().take(axis_size) {
let idx = base_idx + i * axis_stride;
let mut coords = vec![0; input.ndim()];
let mut temp_idx = idx;
for dim in (0..input.ndim()).rev() {
coords[dim] = temp_idx % input_shape[dim];
temp_idx /= input_shape[dim];
}
let softmax_val = exp_val / sum_exp;
output.data_mut()[&*coords] = softmax_val;
}
}
}
log::debug!("Softmax: output shape {:?}", output.shape());
Ok(vec![output])
}
}
/// ONNX `NonMaxSuppression` placeholder.
///
/// Checks that at least `boxes` and `scores` are supplied, logs a warning, and returns an
/// empty `[0, 3]` tensor, i.e. no selected indices. Attributes are ignored.
///
/// # Errors
/// Returns an error when fewer than two inputs are given.
fn nms_op(
    inputs: &[Tensor],
    _attributes: &std::collections::HashMap<String, String>,
) -> Result<Vec<Tensor>> {
    match inputs.len() {
        provided if provided < 2 => Err(OnnxError::invalid_dimensions(format!(
            "NonMaxSuppression operator requires at least 2 inputs (boxes, scores), got {}",
            provided
        ))),
        _ => {
            log::warn!("NonMaxSuppression operator is simplified - returning empty tensor");
            Ok(vec![Tensor::zeros(&[0, 3])])
        }
    }
}
/// ONNX `BatchNormalization` (inference mode) on an `(N, C, D1, ..., Dn)` input.
///
/// Computes `y = scale * (x - mean) / sqrt(var + epsilon) + bias` per channel. `epsilon`
/// comes from the attribute and defaults to `1e-5`. Every axis after C is treated as
/// spatial, so 2-D `(N, C)`, 4-D NCHW and higher-rank inputs are all accepted.
///
/// # Errors
/// Returns an error when fewer than five inputs are given, when the input rank is below
/// 2, or when any parameter tensor's shape is not `[C]`.
fn batch_norm_op(
    inputs: &[Tensor],
    attributes: &std::collections::HashMap<String, String>,
) -> Result<Vec<Tensor>> {
    if inputs.len() < 5 {
        return Err(OnnxError::invalid_dimensions(format!(
            "BatchNormalization requires 5 inputs (input, scale, bias, mean, var), got {}",
            inputs.len()
        )));
    }
    let input = &inputs[0];
    let scale = &inputs[1];
    let bias = &inputs[2];
    let mean = &inputs[3];
    let variance = &inputs[4];
    let epsilon = attributes
        .get("epsilon")
        .and_then(|s| s.parse::<f32>().ok())
        .unwrap_or(1e-5);
    log::debug!(
        "BatchNorm: input shape {:?}, epsilon {}",
        input.shape(),
        epsilon
    );
    let input_shape = input.shape();
    if input.ndim() < 2 {
        return Err(OnnxError::invalid_dimensions(
            "BatchNormalization requires an input of rank >= 2 (N, C, ...)".to_string(),
        ));
    }
    let channels = input_shape[1];
    // Number of elements per (n, c) plane, e.g. H * W for NCHW.
    let spatial: usize = input_shape[2..].iter().product();
    let expected_param_shape = [channels];
    if scale.shape() != expected_param_shape
        || bias.shape() != expected_param_shape
        || mean.shape() != expected_param_shape
        || variance.shape() != expected_param_shape
    {
        return Err(OnnxError::invalid_dimensions(format!(
            "BatchNorm parameters must have shape [{channels}], got scale: {:?}, bias: {:?}, mean: {:?}, var: {:?}",
            scale.shape(), bias.shape(), mean.shape(), variance.shape()
        )));
    }

    // Parameters are small 1-D tensors. Copying them in logical order works for any
    // memory layout (`as_slice().unwrap()` would panic on non-standard layouts).
    let to_vec = |t: &Tensor| -> Vec<f32> { t.data().iter().copied().collect() };
    let scale_vals = to_vec(scale);
    let bias_vals = to_vec(bias);
    let mean_vals = to_vec(mean);
    let var_vals = to_vec(variance);

    // Row-major view of the input (borrowed when it already is row-major).
    let input_data = input.data();
    let input_std = input_data.as_standard_layout();
    let x = input_std
        .as_slice()
        .expect("standard-layout array is contiguous");

    let mut output_data = Vec::with_capacity(x.len());
    if spatial > 0 {
        // In row-major order each run of `spatial` values is one (n, c) plane, with
        // c = plane_index % C.
        for (plane_idx, plane) in x.chunks_exact(spatial).enumerate() {
            let c = plane_idx % channels;
            let (scale_val, bias_val, mean_val) = (scale_vals[c], bias_vals[c], mean_vals[c]);
            let inv_std = 1.0 / (var_vals[c] + epsilon).sqrt();
            output_data.extend(
                plane
                    .iter()
                    .map(|&v| scale_val * ((v - mean_val) * inv_std) + bias_val),
            );
        }
    }
    let result = Tensor::from_shape_vec(input_shape, output_data)?;
    log::debug!("BatchNorm: output shape {:?}", result.shape());
    Ok(vec![result])
}
fn split_op(
inputs: &[Tensor],
attributes: &std::collections::HashMap<String, String>,
) -> Result<Vec<Tensor>> {
if inputs.is_empty() {
return Err(OnnxError::invalid_dimensions(
"Split requires at least 1 input".to_string(),
));
}
let input = &inputs[0];
let axis = attributes
.get("axis")
.and_then(|s| s.parse::<i32>().ok())
.unwrap_or(0);
let normalized_axis = if axis < 0 {
(input.ndim() as i32 + axis) as usize
} else {
axis as usize
};
if normalized_axis >= input.ndim() {
return Err(OnnxError::invalid_dimensions(format!(
"Split axis {} out of bounds for tensor with {} dimensions",
axis,
input.ndim()
)));
}
let axis_size = input.shape()[normalized_axis];
let split_sizes = if inputs.len() >= 2 {
let split_tensor = &inputs[1];
split_tensor
.data()
.iter()
.map(|&x| x as usize)
.collect::<Vec<_>>()
} else if let Some(split_attr) = attributes.get("split") {
let cleaned = split_attr.trim_matches(['[', ']']);
cleaned
.split(',')
.map(|s| s.trim().parse::<usize>().unwrap_or(1))
.collect()
} else {
let num_splits = 2;
let chunk_size = axis_size / num_splits;
let remainder = axis_size % num_splits;
let mut sizes = vec![chunk_size; num_splits];
for size in sizes.iter_mut().take(remainder) {
*size += 1;
}
sizes
};
let total_size: usize = split_sizes.iter().sum();
if total_size != axis_size {
return Err(OnnxError::invalid_dimensions(format!(
"Split sizes sum ({total_size}) must equal axis size ({axis_size})"
)));
}
log::debug!(
"Split: input shape {:?}, axis {}, sizes {:?}",
input.shape(),
normalized_axis,
split_sizes
);
let mut results = Vec::new();
let mut current_offset = 0;
for &split_size in &split_sizes {
if split_size == 0 {
continue; }
let mut starts = vec![0i64; input.ndim()];
let mut ends = vec![0i64; input.ndim()];
for i in 0..input.ndim() {
if i == normalized_axis {
starts[i] = current_offset as i64;
ends[i] = (current_offset + split_size) as i64;
} else {
starts[i] = 0;
ends[i] = input.shape()[i] as i64;
}
}
let split_result = input.slice(&starts, &ends, None, None)?;
results.push(split_result);
current_offset += split_size;
}
log::debug!("Split: created {} output tensors", results.len());
Ok(results)
}
/// ONNX `Gather`: selects slices of `data` along `axis` using the (float-encoded) `indices`.
///
/// Output shape is `data.shape[..axis] ++ indices.shape ++ data.shape[axis + 1..]`.
/// Negative indices count from the end of the axis. A scalar (0-d) `indices` tensor drops
/// the axis entirely.
///
/// # Errors
/// - not exactly 2 inputs;
/// - `axis` out of range;
/// - any index outside `[-axis_size, axis_size)`. The ONNX spec makes this an error;
///   previously such indices were silently clamped.
fn gather_op(
    inputs: &[Tensor],
    attributes: &std::collections::HashMap<String, String>,
) -> Result<Vec<Tensor>> {
    if inputs.len() != 2 {
        return Err(OnnxError::invalid_dimensions(format!(
            "Gather requires exactly 2 inputs (data, indices), got {}",
            inputs.len()
        )));
    }
    let data = &inputs[0];
    let indices = &inputs[1];
    let axis = attributes
        .get("axis")
        .and_then(|s| s.parse::<i32>().ok())
        .unwrap_or(0);
    // A too-negative axis wraps to a huge usize and is rejected by the bounds check below.
    let normalized_axis = if axis < 0 {
        (data.ndim() as i32 + axis) as usize
    } else {
        axis as usize
    };
    if normalized_axis >= data.ndim() {
        return Err(OnnxError::invalid_dimensions(format!(
            "Gather axis {} out of bounds for tensor with {} dimensions",
            axis,
            data.ndim()
        )));
    }
    let data_shape = data.shape();
    let indices_shape = indices.shape();
    let axis_size = data_shape[normalized_axis];
    log::debug!("Gather: data shape {data_shape:?}, indices shape {indices_shape:?}, axis {normalized_axis}");
    let mut output_shape = Vec::with_capacity(data_shape.len() - 1 + indices_shape.len());
    output_shape.extend_from_slice(&data_shape[..normalized_axis]);
    output_shape.extend_from_slice(indices_shape);
    output_shape.extend_from_slice(&data_shape[normalized_axis + 1..]);
    log::debug!("Gather: output shape {output_shape:?}");
    // Resolve negative indices and reject out-of-range ones (also covers axis_size == 0).
    let indices_int: Vec<usize> = indices
        .data()
        .iter()
        .map(|&x| {
            let idx = x as i64;
            let resolved = if idx < 0 { idx + axis_size as i64 } else { idx };
            if (0..axis_size as i64).contains(&resolved) {
                Ok(resolved as usize)
            } else {
                Err(OnnxError::invalid_dimensions(format!(
                    "Gather index {idx} out of bounds for axis {normalized_axis} with size {axis_size}"
                )))
            }
        })
        .collect::<Result<Vec<usize>>>()?;
    // Borrow the contiguous buffer when possible; otherwise copy out in logical order
    // instead of panicking on a non-contiguous layout.
    let data_owned: Vec<f32>;
    let data_vals: &[f32] = match data.data().as_slice() {
        Some(contiguous) => contiguous,
        None => {
            data_owned = data.data().iter().copied().collect();
            &data_owned
        }
    };
    // View data as [outer, axis_size, inner]; the output is then [outer, n_indices, inner],
    // and each (outer, index) pair copies one contiguous run of `inner` values.
    let outer: usize = data_shape[..normalized_axis].iter().product();
    let inner: usize = data_shape[normalized_axis + 1..].iter().product();
    let mut output_data = Vec::with_capacity(outer * indices_int.len() * inner);
    for o in 0..outer {
        let block_base = o * axis_size;
        for &idx in &indices_int {
            let start = (block_base + idx) * inner;
            output_data.extend_from_slice(&data_vals[start..start + inner]);
        }
    }
    let result = Tensor::from_shape_vec(&output_shape, output_data)?;
    log::debug!("Gather: result shape {:?}", result.shape());
    Ok(vec![result])
}
/// ONNX `ConstantOfShape`: builds a tensor whose shape is read from the single input and
/// whose every element equals the `value` attribute.
///
/// `value` is parsed as `f32`; a missing or unparsable value means `0.0`. Shape entries
/// are converted with `as usize`, so negative entries saturate to 0.
///
/// # Errors
/// Returns an error unless exactly one input is supplied, or if the tensor cannot be built.
fn constant_of_shape_op(
    inputs: &[Tensor],
    attributes: &std::collections::HashMap<String, String>,
) -> Result<Vec<Tensor>> {
    let shape_source = match inputs {
        [only] => only,
        _ => {
            return Err(OnnxError::invalid_dimensions(format!(
                "ConstantOfShape requires exactly 1 input, got {}",
                inputs.len()
            )))
        }
    };
    let fill = attributes
        .get("value")
        .and_then(|raw| raw.parse::<f32>().ok())
        .unwrap_or(0.0);
    let dims: Vec<usize> = shape_source.data().iter().map(|&d| d as usize).collect();
    // 0 and 1 use the dedicated constructors; anything else is filled explicitly.
    let filled = match fill {
        v if v == 0.0 => Tensor::zeros(&dims),
        v if v == 1.0 => Tensor::ones(&dims),
        v => {
            let element_count: usize = dims.iter().product();
            Tensor::from_shape_vec(&dims, vec![v; element_count])?
        }
    };
    Ok(vec![filled])
}
fn cast_op(
inputs: &[Tensor],
_attributes: &std::collections::HashMap<String, String>,
) -> Result<Vec<Tensor>> {
if inputs.len() != 1 {
return Err(OnnxError::invalid_dimensions(format!(
"Cast requires exactly 1 input, got {}",
inputs.len()
)));
}
log::warn!("Cast operator is simplified - returning input tensor");
Ok(vec![inputs[0].clone()])
}
/// ONNX `Shape`: returns the dimensions of the single input as a 1-D tensor. Dimensions are
/// stored as `f32`, like all tensor data here.
///
/// # Errors
/// Returns an error unless exactly one input is supplied.
fn shape_op(inputs: &[Tensor]) -> Result<Vec<Tensor>> {
    match inputs {
        [tensor] => {
            let dims = tensor.shape();
            let mut dim_values: Vec<f32> = Vec::with_capacity(dims.len());
            dim_values.extend(dims.iter().map(|&d| d as f32));
            let shape_tensor = Tensor::from_shape_vec(&[dim_values.len()], dim_values)?;
            Ok(vec![shape_tensor])
        }
        _ => Err(OnnxError::invalid_dimensions(format!(
            "Shape requires exactly 1 input, got {}",
            inputs.len()
        ))),
    }
}
/// ONNX `Unsqueeze`: inserts size-1 dimensions at the given positions of the output.
///
/// Axes come from the second input (opset 13+) or the `axes` attribute (older opsets, e.g.
/// `"[0, 2]"`). Negative axes count from the end of the *output* rank. An empty axes list
/// returns the input unchanged.
///
/// # Errors
/// - no inputs, or axes supplied by neither an input nor the attribute;
/// - an `axes` attribute entry that is not an integer. Previously such an entry silently
///   became axis 0;
/// - an axis outside `[-out_rank, out_rank)`, or a repeated axis.
fn unsqueeze_op(
    inputs: &[Tensor],
    attributes: &std::collections::HashMap<String, String>,
) -> Result<Vec<Tensor>> {
    if inputs.is_empty() {
        return Err(OnnxError::invalid_dimensions(
            "Unsqueeze requires at least 1 input".to_string(),
        ));
    }
    let input = &inputs[0];
    let axes: Vec<i64> = if let Some(axes_tensor) = inputs.get(1) {
        axes_tensor.data().iter().map(|&x| x as i64).collect()
    } else if let Some(axes_attr) = attributes.get("axes") {
        axes_attr
            .trim()
            .trim_matches(['[', ']'])
            .split(',')
            .map(str::trim)
            .filter(|s| !s.is_empty())
            .map(|s| {
                s.parse::<i64>().map_err(|_| {
                    OnnxError::invalid_dimensions(format!(
                        "Unsqueeze: invalid axis '{s}' in axes attribute"
                    ))
                })
            })
            .collect::<Result<Vec<i64>>>()?
    } else {
        return Err(OnnxError::invalid_dimensions(
            "Unsqueeze requires axes to be specified".to_string(),
        ));
    };
    if axes.is_empty() {
        return Ok(vec![input.clone()]);
    }
    let input_shape = input.shape();
    let output_ndim = (input.ndim() + axes.len()) as i64;
    let mut normalized_axes: Vec<usize> = Vec::with_capacity(axes.len());
    for &axis in &axes {
        let normalized = if axis < 0 { output_ndim + axis } else { axis };
        if !(0..output_ndim).contains(&normalized) {
            return Err(OnnxError::invalid_dimensions(format!(
                "Unsqueeze axis {axis} out of bounds for output with {output_ndim} dimensions"
            )));
        }
        normalized_axes.push(normalized as usize);
    }
    normalized_axes.sort_unstable();
    if normalized_axes.windows(2).any(|pair| pair[0] == pair[1]) {
        return Err(OnnxError::invalid_dimensions(
            "Unsqueeze axes must be unique".to_string(),
        ));
    }
    log::debug!("Unsqueeze: input shape {input_shape:?}, axes {normalized_axes:?}");
    // Every position not listed in `normalized_axes` takes the next input dimension; the axes
    // are unique and in range, so exactly `input.ndim()` such positions exist.
    let mut input_dims = input_shape.iter();
    let output_shape: Vec<usize> = (0..output_ndim as usize)
        .map(|d| {
            if normalized_axes.binary_search(&d).is_ok() {
                1
            } else {
                *input_dims
                    .next()
                    .expect("non-inserted positions equal the input rank")
            }
        })
        .collect();
    log::debug!("Unsqueeze: output shape {output_shape:?}");
    let result = input.reshape(&output_shape)?;
    Ok(vec![result])
}
/// ONNX `Squeeze`: removes size-1 dimensions.
///
/// Axes come from the second input (opset 13+) or the `axes` attribute (older opsets).
/// When axes are absent *or empty*, every size-1 dimension is removed (ONNX Runtime treats an
/// empty axes list the same way). Squeezing every dimension of a one-element tensor
/// yields a 0-d scalar.
///
/// # Errors
/// - no inputs;
/// - an `axes` attribute entry that is not an integer. Previously such an entry silently
///   became axis 0;
/// - an axis out of range, or an axis whose size is not 1.
fn squeeze_op(
    inputs: &[Tensor],
    attributes: &std::collections::HashMap<String, String>,
) -> Result<Vec<Tensor>> {
    if inputs.is_empty() {
        return Err(OnnxError::invalid_dimensions(
            "Squeeze requires at least 1 input".to_string(),
        ));
    }
    let input = &inputs[0];
    let input_shape = input.shape();
    let axes: Option<Vec<i64>> = if let Some(axes_tensor) = inputs.get(1) {
        Some(axes_tensor.data().iter().map(|&x| x as i64).collect())
    } else if let Some(axes_attr) = attributes.get("axes") {
        Some(
            axes_attr
                .trim()
                .trim_matches(['[', ']'])
                .split(',')
                .map(str::trim)
                .filter(|s| !s.is_empty())
                .map(|s| {
                    s.parse::<i64>().map_err(|_| {
                        OnnxError::invalid_dimensions(format!(
                            "Squeeze: invalid axis '{s}' in axes attribute"
                        ))
                    })
                })
                .collect::<Result<Vec<i64>>>()?,
        )
    } else {
        None
    };
    // An empty list means "not provided".
    let axes = axes.filter(|list| !list.is_empty());
    log::debug!("Squeeze: input shape {input_shape:?}, axes {axes:?}");
    let input_ndim = input.ndim() as i64;
    let axes_to_squeeze: Vec<usize> = match axes {
        Some(specified_axes) => {
            let mut normalized_axes = Vec::with_capacity(specified_axes.len());
            for &axis in &specified_axes {
                let normalized = if axis < 0 { input_ndim + axis } else { axis };
                if !(0..input_ndim).contains(&normalized) {
                    return Err(OnnxError::invalid_dimensions(format!(
                        "Squeeze axis {} out of bounds for tensor with {} dimensions",
                        axis,
                        input.ndim()
                    )));
                }
                let normalized_axis = normalized as usize;
                if input_shape[normalized_axis] != 1 {
                    return Err(OnnxError::invalid_dimensions(format!(
                        "Cannot squeeze axis {} with size {}",
                        axis, input_shape[normalized_axis]
                    )));
                }
                normalized_axes.push(normalized_axis);
            }
            normalized_axes
        }
        None => input_shape
            .iter()
            .enumerate()
            .filter(|&(_, &size)| size == 1)
            .map(|(idx, _)| idx)
            .collect(),
    };
    let output_shape: Vec<usize> = input_shape
        .iter()
        .enumerate()
        .filter(|(idx, _)| !axes_to_squeeze.contains(idx))
        .map(|(_, &size)| size)
        .collect();
    log::debug!("Squeeze: output shape {output_shape:?}");
    let result = if output_shape.is_empty() {
        if input.len() != 1 {
            return Err(OnnxError::invalid_dimensions(
                "Cannot squeeze to scalar: input must have exactly 1 element".to_string(),
            ));
        }
        let scalar_value = *input
            .data()
            .iter()
            .next()
            .expect("length checked to be exactly 1");
        Tensor::from_shape_vec(&[], vec![scalar_value])?
    } else {
        input.reshape(&output_shape)?
    };
    Ok(vec![result])
}
/// ONNX `Pad` (opset 11+ input layout: `data`, `pads`, optional `constant_value`).
///
/// `pads` is `[x1_begin, x2_begin, ..., x1_end, x2_end, ...]`, one begin and one end amount
/// per axis. Negative amounts crop that side instead of padding it. The optional third input
/// supplies the fill value for `constant` mode (default `0.0`).
///
/// Supported `mode` values:
/// - `constant` (default);
/// - `edge`: repeat the border value;
/// - `reflect`: mirror without repeating the border;
/// - `wrap`: periodic.
///
/// Any other mode logs a warning and returns the input unchanged. For non-constant modes, an
/// input axis of length 0 cannot be sampled, so positions along it fall back to the constant
/// value.
///
/// # Errors
/// - fewer than 2 inputs;
/// - `pads` length not equal to `2 * rank`;
/// - pads that would make an output dimension negative.
fn pad_op(
    inputs: &[Tensor],
    attributes: &std::collections::HashMap<String, String>,
) -> Result<Vec<Tensor>> {
    if inputs.len() < 2 {
        return Err(OnnxError::invalid_dimensions(
            "Pad requires at least 2 inputs (data, pads)".to_string(),
        ));
    }
    let input = &inputs[0];
    let pads_tensor = &inputs[1];
    let constant_value = if inputs.len() >= 3 {
        inputs[2].data().iter().next().copied().unwrap_or(0.0)
    } else {
        0.0
    };
    let mode = attributes
        .get("mode")
        .map(|s| s.as_str())
        .unwrap_or("constant");
    // Keep the sign: negative pads crop. A `usize` cast would silently turn them into 0.
    let pads: Vec<i64> = pads_tensor.data().iter().map(|&x| x as i64).collect();
    let input_shape = input.shape();
    let ndim = input.ndim();
    if pads.len() != 2 * ndim {
        return Err(OnnxError::invalid_dimensions(format!(
            "Pads length ({}) must be 2 * input dimensions ({})",
            pads.len(),
            2 * ndim
        )));
    }
    let (begin_pads, end_pads) = pads.split_at(ndim);
    log::debug!("Pad: input shape {input_shape:?}, begin_pads {begin_pads:?}, end_pads {end_pads:?}, mode {mode}, value {constant_value}");
    if !matches!(mode, "constant" | "edge" | "reflect" | "wrap") {
        log::warn!("Pad: unsupported mode {mode}, returning input unchanged");
        return Ok(vec![input.clone()]);
    }
    let output_shape: Vec<usize> = input_shape
        .iter()
        .zip(begin_pads.iter().zip(end_pads))
        .map(|(&size, (&begin, &end))| {
            usize::try_from(size as i64 + begin + end).map_err(|_| {
                OnnxError::invalid_dimensions(format!(
                    "Pad: pads ({begin}, {end}) make dimension of size {size} negative"
                ))
            })
        })
        .collect::<Result<Vec<usize>>>()?;
    log::debug!("Pad: output shape {output_shape:?}");
    // Map a source coordinate (possibly outside [0, size)) onto the input axis according to
    // `mode`. `None` means "use constant_value".
    let resolve = |coord: i64, size: usize| -> Option<usize> {
        let n = size as i64;
        if (0..n).contains(&coord) {
            return Some(coord as usize);
        }
        if n == 0 {
            return None;
        }
        let mapped = match mode {
            "edge" => coord.clamp(0, n - 1),
            "wrap" => coord.rem_euclid(n),
            "reflect" if n == 1 => 0,
            "reflect" => {
                let period = 2 * (n - 1);
                let m = coord.rem_euclid(period);
                if m < n {
                    m
                } else {
                    period - m
                }
            }
            _ => return None,
        };
        Some(mapped as usize)
    };
    // Borrow the contiguous buffer when possible; otherwise copy out in logical order.
    let input_owned: Vec<f32>;
    let input_data: &[f32] = match input.data().as_slice() {
        Some(contiguous) => contiguous,
        None => {
            input_owned = input.data().iter().copied().collect();
            &input_owned
        }
    };
    // Row-major strides; `(1..ndim)` is empty for rank 0 or 1, so there is no underflow.
    let mut input_strides = vec![1usize; ndim];
    for d in (1..ndim).rev() {
        input_strides[d - 1] = input_strides[d] * input_shape[d];
    }
    let total: usize = output_shape.iter().product();
    let mut output_data = Vec::with_capacity(total);
    // Odometer over output coordinates in row-major order.
    let mut out_coords = vec![0usize; ndim];
    for _ in 0..total {
        let mut src = Some(0usize);
        for (d, &oc) in out_coords.iter().enumerate() {
            src = src.and_then(|acc| {
                resolve(oc as i64 - begin_pads[d], input_shape[d])
                    .map(|ic| acc + ic * input_strides[d])
            });
        }
        output_data.push(src.map_or(constant_value, |i| input_data[i]));
        for d in (0..ndim).rev() {
            out_coords[d] += 1;
            if out_coords[d] < output_shape[d] {
                break;
            }
            out_coords[d] = 0;
        }
    }
    let result = Tensor::from_shape_vec(&output_shape, output_data)?;
    log::debug!("Pad: result shape {:?}", result.shape());
    Ok(vec![result])
}
/// ONNX `Div`: element-wise `A / B`. Broadcasting and shape validation are delegated to
/// `Tensor::div`.
///
/// # Errors
/// Returns an error unless exactly two inputs are supplied, or if the division fails.
fn div_op(inputs: &[Tensor]) -> Result<Vec<Tensor>> {
    match inputs {
        [numerator, denominator] => {
            log::debug!(
                "Div: input[0] shape {:?}, input[1] shape {:?}",
                numerator.shape(),
                denominator.shape()
            );
            Ok(vec![numerator.div(denominator)?])
        }
        _ => Err(OnnxError::invalid_dimensions(format!(
            "Div requires exactly 2 inputs, got {}",
            inputs.len()
        ))),
    }
}
/// ONNX `Sub`: element-wise `A - B`. Broadcasting and shape validation are delegated to
/// `Tensor::sub`.
///
/// # Errors
/// Returns an error unless exactly two inputs are supplied, or if the subtraction fails.
fn sub_op(inputs: &[Tensor]) -> Result<Vec<Tensor>> {
    match inputs {
        [minuend, subtrahend] => Ok(vec![minuend.sub(subtrahend)?]),
        _ => Err(OnnxError::invalid_dimensions(format!(
            "Sub requires exactly 2 inputs, got {}",
            inputs.len()
        ))),
    }
}
/// ONNX `Exp`: element-wise natural exponential, via `Tensor::exp`.
///
/// # Errors
/// Returns an error unless exactly one input is supplied, or if `Tensor::exp` fails.
fn exp_op(inputs: &[Tensor]) -> Result<Vec<Tensor>> {
    match inputs {
        [operand] => Ok(vec![operand.exp()?]),
        _ => Err(OnnxError::invalid_dimensions(format!(
            "Exp requires exactly 1 input, got {}",
            inputs.len()
        ))),
    }
}
/// ONNX `Sqrt`: element-wise square root, via `Tensor::sqrt`.
///
/// # Errors
/// Returns an error unless exactly one input is supplied, or if `Tensor::sqrt` fails.
fn sqrt_op(inputs: &[Tensor]) -> Result<Vec<Tensor>> {
    match inputs {
        [operand] => Ok(vec![operand.sqrt()?]),
        _ => Err(OnnxError::invalid_dimensions(format!(
            "Sqrt requires exactly 1 input, got {}",
            inputs.len()
        ))),
    }
}
/// ONNX `Pow`: element-wise `base ^ exponent`, delegated to `Tensor::pow`.
///
/// # Errors
/// Returns an error unless exactly two inputs are supplied, or if `Tensor::pow` fails.
fn pow_op(inputs: &[Tensor]) -> Result<Vec<Tensor>> {
    match inputs {
        [base, exponent] => Ok(vec![base.pow(exponent)?]),
        _ => Err(OnnxError::invalid_dimensions(format!(
            "Pow requires exactly 2 inputs, got {}",
            inputs.len()
        ))),
    }
}
fn reduce_mean_op(
inputs: &[Tensor],
attributes: &std::collections::HashMap<String, String>,
) -> Result<Vec<Tensor>> {
if inputs.len() != 1 {
return Err(OnnxError::invalid_dimensions(format!(
"ReduceMean requires exactly 1 input, got {}",
inputs.len()
)));
}
let _axes = attributes.get("axes");
let _keepdims = attributes
.get("keepdims")
.and_then(|s| s.parse::<i32>().ok())
.unwrap_or(1);
let input = &inputs[0];
let mean_value = input.data().iter().sum::<f32>() / input.data().len() as f32;
let result = Tensor::from_shape_vec(&[1], vec![mean_value])?;
Ok(vec![result])
}
/// ONNX `Identity`: returns a clone of its single input.
///
/// # Errors
/// Returns an error unless exactly one input is supplied.
fn identity_op(inputs: &[Tensor]) -> Result<Vec<Tensor>> {
    if let [only] = inputs {
        return Ok(vec![only.clone()]);
    }
    Err(OnnxError::invalid_dimensions(format!(
        "Identity requires exactly 1 input, got {}",
        inputs.len()
    )))
}
/// ONNX `Resize`, nearest-neighbour only, for 4-D NCHW tensors.
///
/// The target size is resolved in this order:
/// 1. the `sizes` input (index 3 in the opset-11+ layout `(X, roi, scales, sizes)`);
/// 2. the `scales` input: index 2 when three or more inputs are present, otherwise index 1
///    (the opset-10 layout `(X, scales)`). Previously index 1 was always used, which in the
///    opset-11+ layout is `roi`;
/// 3. the `scales` attribute.
///
/// Only the H and W entries of scales/sizes are used; N and C are kept unchanged. Each output
/// pixel samples `floor(out_coord / scale)`, i.e. "asymmetric" coordinates with floor
/// rounding. The `mode`, `coordinate_transformation_mode` and `nearest_mode` attributes are
/// not consulted.
///
/// If the input is not 4-D, or no usable scales/sizes are found, the input is returned
/// unchanged with a warning.
///
/// # Errors
/// - no inputs;
/// - a non-empty output requested from an input with zero height or width.
fn resize_op(
    inputs: &[Tensor],
    attributes: &std::collections::HashMap<String, String>,
) -> Result<Vec<Tensor>> {
    if inputs.is_empty() {
        return Err(OnnxError::invalid_dimensions(
            "Resize requires at least 1 input".to_string(),
        ));
    }
    let input = &inputs[0];
    log::debug!(
        "Resize: input shape {:?}, num_inputs: {}",
        input.shape(),
        inputs.len()
    );
    for (key, value) in attributes {
        log::debug!("Resize attribute: {key} = {value}");
    }
    for (i, inp) in inputs.iter().enumerate() {
        let first_few: Vec<f32> = inp.data().iter().take(8).copied().collect();
        log::debug!(
            "Resize input[{}]: shape {:?}, first few values: {:?}",
            i,
            inp.shape(),
            first_few
        );
    }
    if input.ndim() == 4 {
        let shape = input.shape();
        let (height, width) = (shape[2], shape[3]);
        let tensor_values = |t: &Tensor| -> Vec<f32> { t.data().iter().copied().collect() };

        // 1. Explicit output sizes.
        if let Some(sizes) = inputs.get(3).map(tensor_values).filter(|v| v.len() >= 4) {
            log::debug!("Resize sizes from input: {sizes:?}");
            let (new_height, new_width) = (sizes[2] as usize, sizes[3] as usize);
            let scale_h = new_height as f32 / height as f32;
            let scale_w = new_width as f32 / width as f32;
            let result = resize_nearest_nchw(input, new_height, new_width, scale_h, scale_w)?;
            log::debug!(
                "Resize: input {:?} -> output {:?} (sizes: {:?})",
                input.shape(),
                result.shape(),
                sizes
            );
            return Ok(vec![result]);
        }

        // 2. Scales supplied as an input tensor.
        let scales_index = if inputs.len() >= 3 { 2 } else { 1 };
        if let Some(scales) = inputs
            .get(scales_index)
            .map(tensor_values)
            .filter(|v| v.len() >= 4)
        {
            log::debug!("Resize scales from input: {scales:?}");
            let new_height = (height as f32 * scales[2]) as usize;
            let new_width = (width as f32 * scales[3]) as usize;
            let result = resize_nearest_nchw(input, new_height, new_width, scales[2], scales[3])?;
            log::debug!(
                "Resize: input {:?} -> output {:?} (scales: {:?})",
                input.shape(),
                result.shape(),
                scales
            );
            return Ok(vec![result]);
        }

        // 3. Scales supplied as an attribute, e.g. "[1.0, 1.0, 2.0, 2.0]".
        if let Some(scales_str) = attributes.get("scales") {
            log::debug!("Resize scales from attributes: {scales_str}");
            let parsed: std::result::Result<Vec<f32>, _> = scales_str
                .trim_matches(['[', ']'])
                .split(',')
                .map(|s| s.trim().parse::<f32>())
                .collect();
            if let Ok(scales) = parsed {
                if scales.len() >= 4 {
                    let new_height = (height as f32 * scales[2]) as usize;
                    let new_width = (width as f32 * scales[3]) as usize;
                    let result =
                        resize_nearest_nchw(input, new_height, new_width, scales[2], scales[3])?;
                    log::debug!(
                        "Resize: input {:?} -> output {:?} (scales: {:?})",
                        input.shape(),
                        result.shape(),
                        scales
                    );
                    return Ok(vec![result]);
                }
            }
        }
    }
    log::warn!("Resize operator using simplified implementation - returning input tensor");
    log::debug!(
        "Resize fallback: input shape {:?}, attributes: {:?}",
        input.shape(),
        attributes
    );
    Ok(vec![input.clone()])
}

/// Nearest-neighbour resize of a 4-D NCHW `input` to `new_height` x `new_width`.
///
/// Output row `h` samples input row `floor(h / scale_h)`, clamped to the last row; columns
/// work the same way with `scale_w`. The caller must pass a 4-D tensor.
///
/// # Errors
/// Returns an error if a non-empty output is requested from an input whose height or width
/// is 0, where no source pixel exists to sample.
fn resize_nearest_nchw(
    input: &Tensor,
    new_height: usize,
    new_width: usize,
    scale_h: f32,
    scale_w: f32,
) -> Result<Tensor> {
    let shape = input.shape();
    let (batch_size, channels, height, width) = (shape[0], shape[1], shape[2], shape[3]);
    log::debug!(
        "Resize: {}x{} -> {}x{} (scales: {:.2}x{:.2})",
        height,
        width,
        new_height,
        new_width,
        scale_h,
        scale_w
    );
    let output_shape = [batch_size, channels, new_height, new_width];
    let output_len: usize = output_shape.iter().product();
    if output_len > 0 && (height == 0 || width == 0) {
        return Err(OnnxError::invalid_dimensions(format!(
            "Resize cannot produce {new_height}x{new_width} output from empty {height}x{width} input"
        )));
    }
    // Source row/column for each output row/column, computed once instead of per pixel.
    // `saturating_sub` only matters when the output is empty (guarded above otherwise).
    let src_rows: Vec<usize> = (0..new_height)
        .map(|h| ((h as f32 / scale_h) as usize).min(height.saturating_sub(1)))
        .collect();
    let src_cols: Vec<usize> = (0..new_width)
        .map(|w| ((w as f32 / scale_w) as usize).min(width.saturating_sub(1)))
        .collect();
    let input_data = input.data();
    let mut output_data = Vec::with_capacity(output_len);
    for n in 0..batch_size {
        for c in 0..channels {
            for &h in &src_rows {
                for &w in &src_cols {
                    output_data.push(input_data[[n, c, h, w]]);
                }
            }
        }
    }
    let result = Tensor::from_shape_vec(&output_shape, output_data)?;
    Ok(result)
}
#[cfg(test)]
mod tests {
use super::*;
use crate::Tensor;
use ndarray::{Array1, Array2, Array3, Array4};
use std::collections::HashMap;
#[test]
fn test_operator_type_from_str() {
assert_eq!("Add".parse::<OperatorType>().unwrap(), OperatorType::Add);
assert_eq!("Mul".parse::<OperatorType>().unwrap(), OperatorType::Mul);
assert_eq!(
"MatMul".parse::<OperatorType>().unwrap(),
OperatorType::MatMul
);
assert_eq!("Conv".parse::<OperatorType>().unwrap(), OperatorType::Conv);
assert_eq!("Relu".parse::<OperatorType>().unwrap(), OperatorType::Relu);
assert_eq!(
"Sigmoid".parse::<OperatorType>().unwrap(),
OperatorType::Sigmoid
);
assert_eq!(
"Reshape".parse::<OperatorType>().unwrap(),
OperatorType::Reshape
);
assert_eq!(
"Transpose".parse::<OperatorType>().unwrap(),
OperatorType::Transpose
);
let result = "Unknown".parse::<OperatorType>();
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("Unknown"));
assert!("add".parse::<OperatorType>().is_err());
assert!("ADD".parse::<OperatorType>().is_err());
}
#[test]
fn test_operator_type_debug() {
assert_eq!(format!("{:?}", OperatorType::Add), "Add");
assert_eq!(format!("{:?}", OperatorType::Conv), "Conv");
}
#[test]
fn test_operator_type_clone_eq() {
let op1 = OperatorType::Add;
let op2 = op1.clone();
assert_eq!(op1, op2);
assert_ne!(OperatorType::Add, OperatorType::Mul);
}
#[test]
fn test_add_op() {
let a = Tensor::from_array(Array1::from_vec(vec![1.0, 2.0, 3.0]));
let b = Tensor::from_array(Array1::from_vec(vec![4.0, 5.0, 6.0]));
let inputs = vec![a, b];
let attrs = HashMap::new();
let result = execute_operator(&OperatorType::Add, &inputs, &attrs).unwrap();
assert_eq!(result.len(), 1);
let expected = [5.0, 7.0, 9.0];
for (actual, &expected) in result[0].data().iter().zip(expected.iter()) {
assert!((actual - expected).abs() < 1e-6);
}
}
#[test]
fn test_add_op_wrong_inputs() {
let a = Tensor::from_array(Array1::from_vec(vec![1.0, 2.0, 3.0]));
let inputs = vec![a]; let attrs = HashMap::new();
#[cfg(feature = "formal-verification")]
{
let result =
std::panic::catch_unwind(|| execute_operator(&OperatorType::Add, &inputs, &attrs));
assert!(
result.is_err(),
"Should panic with formal verification enabled"
);
}
#[cfg(not(feature = "formal-verification"))]
{
let result = execute_operator(&OperatorType::Add, &inputs, &attrs);
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("exactly 2 inputs"));
}
let a = Tensor::from_array(Array1::from_vec(vec![1.0, 2.0, 3.0]));
let b = Tensor::from_array(Array1::from_vec(vec![4.0, 5.0, 6.0]));
let c = Tensor::from_array(Array1::from_vec(vec![7.0, 8.0, 9.0]));
let inputs = vec![a, b, c]; let attrs = HashMap::new();
#[cfg(feature = "formal-verification")]
{
let result =
std::panic::catch_unwind(|| execute_operator(&OperatorType::Add, &inputs, &attrs));
assert!(
result.is_err(),
"Should panic with formal verification enabled"
);
}
#[cfg(not(feature = "formal-verification"))]
{
let result = execute_operator(&OperatorType::Add, &inputs, &attrs);
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("exactly 2 inputs"));
}
}
#[test]
fn test_mul_op() {
let a = Tensor::from_array(Array1::from_vec(vec![2.0, 3.0, 4.0]));
let b = Tensor::from_array(Array1::from_vec(vec![5.0, 6.0, 7.0]));
let inputs = vec![a, b];
let attrs = HashMap::new();
let result = execute_operator(&OperatorType::Mul, &inputs, &attrs).unwrap();
assert_eq!(result.len(), 1);
let expected = [10.0, 18.0, 28.0];
for (actual, &expected) in result[0].data().iter().zip(expected.iter()) {
assert!((actual - expected).abs() < 1e-6);
}
}
#[test]
fn test_mul_op_wrong_inputs() {
let a = Tensor::from_array(Array1::from_vec(vec![1.0, 2.0, 3.0]));
let inputs = vec![a]; let attrs = HashMap::new();
#[cfg(feature = "formal-verification")]
{
let result =
std::panic::catch_unwind(|| execute_operator(&OperatorType::Mul, &inputs, &attrs));
assert!(
result.is_err(),
"Should panic with formal verification enabled"
);
}
#[cfg(not(feature = "formal-verification"))]
{
let result = execute_operator(&OperatorType::Mul, &inputs, &attrs);
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("exactly 2 inputs"));
}
}
#[test]
fn test_matmul_op() {
let a = Tensor::from_array(
Array2::from_shape_vec((2, 3), vec![1., 2., 3., 4., 5., 6.]).unwrap(),
);
let b = Tensor::from_array(
Array2::from_shape_vec((3, 2), vec![1., 2., 3., 4., 5., 6.]).unwrap(),
);
let inputs = vec![a, b];
let attrs = HashMap::new();
let result = execute_operator(&OperatorType::MatMul, &inputs, &attrs).unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].shape(), &[2, 2]);
}
#[test]
fn test_matmul_op_wrong_inputs() {
let a = Tensor::from_array(Array1::from_vec(vec![1.0, 2.0, 3.0]));
let inputs = vec![a]; let attrs = HashMap::new();
#[cfg(feature = "formal-verification")]
{
let result = std::panic::catch_unwind(|| {
execute_operator(&OperatorType::MatMul, &inputs, &attrs)
});
assert!(
result.is_err(),
"Should panic with formal verification enabled"
);
}
#[cfg(not(feature = "formal-verification"))]
{
let result = execute_operator(&OperatorType::MatMul, &inputs, &attrs);
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("exactly 2 inputs"));
}
}
#[test]
fn test_conv_op() {
let input = Tensor::zeros(&[1, 1, 3, 3]); let kernel = Tensor::zeros(&[1, 1, 2, 2]); let inputs = vec![input, kernel];
let attrs = HashMap::new();
let result = execute_operator(&OperatorType::Conv, &inputs, &attrs).unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].shape(), &[1, 1, 2, 2]);
}
#[test]
fn test_conv_op_insufficient_inputs() {
let input = Tensor::zeros(&[1, 1, 3, 3]);
let inputs = vec![input]; let attrs = HashMap::new();
let result = execute_operator(&OperatorType::Conv, &inputs, &attrs);
assert!(result.is_err());
assert!(result
.unwrap_err()
.to_string()
.contains("at least 2 inputs"));
}
#[test]
fn test_conv_op_wrong_dimensions() {
let input = Tensor::zeros(&[1, 3, 3]); let kernel = Tensor::zeros(&[1, 1, 2, 2]); let inputs = vec![input, kernel];
let attrs = HashMap::new();
let result = execute_operator(&OperatorType::Conv, &inputs, &attrs);
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("4D tensors"));
let input = Tensor::zeros(&[1, 1, 3, 3]); let kernel = Tensor::zeros(&[1, 2, 2]); let inputs = vec![input, kernel];
let attrs = HashMap::new();
let result = execute_operator(&OperatorType::Conv, &inputs, &attrs);
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("4D tensors"));
}
#[test]
fn test_relu_op() {
let a = Tensor::from_array(Array1::from_vec(vec![-1.0, 0.0, 1.0, 2.0]));
let inputs = vec![a];
let attrs = HashMap::new();
let result = execute_operator(&OperatorType::Relu, &inputs, &attrs).unwrap();
assert_eq!(result.len(), 1);
let expected = [0.0, 0.0, 1.0, 2.0];
for (actual, &expected) in result[0].data().iter().zip(expected.iter()) {
assert!((actual - expected).abs() < 1e-6);
}
}
#[test]
fn test_relu_op_wrong_inputs() {
let a = Tensor::from_array(Array1::from_vec(vec![1.0, 2.0, 3.0]));
let b = Tensor::from_array(Array1::from_vec(vec![4.0, 5.0, 6.0]));
let inputs = vec![a, b]; let attrs = HashMap::new();
#[cfg(feature = "formal-verification")]
{
let result =
std::panic::catch_unwind(|| execute_operator(&OperatorType::Relu, &inputs, &attrs));
assert!(
result.is_err(),
"Should panic with formal verification enabled"
);
}
#[cfg(not(feature = "formal-verification"))]
{
let result = execute_operator(&OperatorType::Relu, &inputs, &attrs);
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("exactly 1 input"));
}
}
#[test]
fn test_sigmoid_op() {
let a = Tensor::from_array(Array1::from_vec(vec![0.0]));
let inputs = vec![a];
let attrs = HashMap::new();
let result = execute_operator(&OperatorType::Sigmoid, &inputs, &attrs).unwrap();
assert_eq!(result.len(), 1);
assert!((result[0].data()[0] - 0.5).abs() < 1e-6);
}
#[test]
fn test_sigmoid_op_wrong_inputs() {
let a = Tensor::from_array(Array1::from_vec(vec![1.0, 2.0, 3.0]));
let b = Tensor::from_array(Array1::from_vec(vec![4.0, 5.0, 6.0]));
let inputs = vec![a, b]; let attrs = HashMap::new();
#[cfg(feature = "formal-verification")]
{
let result = std::panic::catch_unwind(|| {
execute_operator(&OperatorType::Sigmoid, &inputs, &attrs)
});
assert!(
result.is_err(),
"Should panic with formal verification enabled"
);
}
#[cfg(not(feature = "formal-verification"))]
{
let result = execute_operator(&OperatorType::Sigmoid, &inputs, &attrs);
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("exactly 1 input"));
}
}
#[test]
fn test_reshape_op() {
let data = Tensor::from_shape_vec(&[2, 3], vec![1., 2., 3., 4., 5., 6.]).unwrap();
let shape = Tensor::from_array(Array1::from_vec(vec![3.0, 2.0])); let inputs = vec![data, shape];
let attrs = HashMap::new();
let result = execute_operator(&OperatorType::Reshape, &inputs, &attrs).unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].shape(), &[3, 2]);
}
#[test]
fn test_reshape_op_wrong_inputs() {
let data = Tensor::from_shape_vec(&[2, 3], vec![1., 2., 3., 4., 5., 6.]).unwrap();
let inputs = vec![data]; let attrs = HashMap::new();
let result = execute_operator(&OperatorType::Reshape, &inputs, &attrs);
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("exactly 2 inputs"));
let data = Tensor::from_shape_vec(&[2, 3], vec![1., 2., 3., 4., 5., 6.]).unwrap();
let shape = Tensor::from_array(Array1::from_vec(vec![3.0, 2.0]));
let extra = Tensor::from_array(Array1::from_vec(vec![1.0]));
let inputs = vec![data, shape, extra]; let attrs = HashMap::new();
let result = execute_operator(&OperatorType::Reshape, &inputs, &attrs);
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("exactly 2 inputs"));
}
#[test]
fn test_transpose_op() {
let a = Tensor::from_shape_vec(&[2, 3], vec![1., 2., 3., 4., 5., 6.]).unwrap();
let inputs = vec![a];
let attrs = HashMap::new();
let result = execute_operator(&OperatorType::Transpose, &inputs, &attrs).unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].shape(), &[3, 2]);
}
#[test]
fn test_transpose_op_wrong_inputs() {
let a = Tensor::from_array(Array1::from_vec(vec![1.0, 2.0, 3.0]));
let b = Tensor::from_array(Array1::from_vec(vec![4.0, 5.0, 6.0]));
let inputs = vec![a, b]; let attrs = HashMap::new();
let result = execute_operator(&OperatorType::Transpose, &inputs, &attrs);
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("exactly 1 input"));
}
#[test]
fn test_execute_operator_with_attributes() {
let a = Tensor::from_array(Array1::from_vec(vec![1.0, 2.0, 3.0]));
let inputs = vec![a];
let mut attrs = HashMap::new();
attrs.insert("test_attr".to_string(), "test_value".to_string());
let result = execute_operator(&OperatorType::Relu, &inputs, &attrs);
assert!(result.is_ok());
}
#[test]
fn test_all_operator_types_execute() {
let tensor_1d = Tensor::from_array(Array1::from_vec(vec![1.0, 2.0, 3.0]));
let tensor_2d = Tensor::from_array(
Array2::from_shape_vec((2, 3), vec![1., 2., 3., 4., 5., 6.]).unwrap(),
);
let tensor_4d = Tensor::zeros(&[1, 1, 2, 2]);
let shape_tensor = Tensor::from_array(Array1::from_vec(vec![3.0, 2.0]));
let attrs = HashMap::new();
assert!(execute_operator(&OperatorType::Relu, &[tensor_1d.clone()], &attrs).is_ok());
assert!(execute_operator(&OperatorType::Sigmoid, &[tensor_1d.clone()], &attrs).is_ok());
assert!(execute_operator(&OperatorType::Transpose, &[tensor_2d.clone()], &attrs).is_ok());
assert!(execute_operator(
&OperatorType::Add,
&[tensor_1d.clone(), tensor_1d.clone()],
&attrs
)
.is_ok());
assert!(execute_operator(
&OperatorType::Mul,
&[tensor_1d.clone(), tensor_1d.clone()],
&attrs
)
.is_ok());
assert!(execute_operator(
&OperatorType::MatMul,
&[tensor_2d.clone(), tensor_2d.transpose().unwrap()],
&attrs
)
.is_ok());
assert!(execute_operator(
&OperatorType::Reshape,
&[tensor_2d.clone(), shape_tensor],
&attrs
)
.is_ok());
assert!(execute_operator(
&OperatorType::Conv,
&[tensor_4d.clone(), tensor_4d.clone()],
&attrs
)
.is_ok());
}
#[test]
fn test_formal_addition_identity() {
    // Additive identity: x + 0 == x.
    let x = Tensor::from_shape_vec(&[3], vec![1.0, 2.0, 3.0]).unwrap();
    let x_plus_zero = x.add(&Tensor::zeros(&[3])).unwrap();
    assert_eq!(x_plus_zero.data(), x.data());
}
#[test]
fn test_formal_addition_commutativity() {
    // a + b == b + a, compared exactly (IEEE addition is commutative).
    let (a, b) = (
        Tensor::from_shape_vec(&[3], vec![1.0, 2.0, 3.0]).unwrap(),
        Tensor::from_shape_vec(&[3], vec![4.0, 5.0, 6.0]).unwrap(),
    );
    let (ab, ba) = (a.add(&b).unwrap(), b.add(&a).unwrap());
    assert_eq!(ab.data(), ba.data());
}
#[test]
fn test_formal_multiplication_commutativity() {
    // a * b == b * a, compared exactly (IEEE multiplication is commutative).
    let (a, b) = (
        Tensor::from_shape_vec(&[3], vec![2.0, 3.0, 4.0]).unwrap(),
        Tensor::from_shape_vec(&[3], vec![5.0, 6.0, 7.0]).unwrap(),
    );
    let (ab, ba) = (a.mul(&b).unwrap(), b.mul(&a).unwrap());
    assert_eq!(ab.data(), ba.data());
}
#[test]
fn test_formal_relu_non_negativity() {
    // ReLU clamps negatives to zero, so no output element may be negative.
    let input = Tensor::from_shape_vec(&[5], vec![-2.0, -1.0, 0.0, 1.0, 2.0]).unwrap();
    let activated = input.relu().unwrap();
    for value in activated.data().iter().copied() {
        assert!(
            value >= 0.0,
            "ReLU output must be non-negative, got {value}"
        );
    }
}
#[test]
fn test_formal_relu_idempotency() {
    // relu(relu(x)) == relu(x).
    let input = Tensor::from_shape_vec(&[5], vec![-2.0, -1.0, 0.0, 1.0, 2.0]).unwrap();
    let once = input.relu().unwrap();
    let twice = once.relu().unwrap();
    assert_eq!(once.data(), twice.data());
}
#[test]
fn test_formal_sigmoid_bounded() {
    // For these moderate inputs the sigmoid stays strictly inside (0, 1).
    let input = Tensor::from_shape_vec(&[5], vec![-10.0, -1.0, 0.0, 1.0, 10.0]).unwrap();
    let squashed = input.sigmoid().unwrap();
    for value in squashed.data().iter().copied() {
        assert!(
            value > 0.0 && value < 1.0,
            "Sigmoid output must be in (0,1), got {value}"
        );
    }
}
#[test]
fn test_formal_matmul_dimensions() {
    // (2x2) x (2x2) -> (2x2).
    let lhs = Tensor::from_shape_vec(&[2, 2], vec![1.0, 2.0, 3.0, 4.0]).unwrap();
    let rhs = Tensor::from_shape_vec(&[2, 2], vec![5.0, 6.0, 7.0, 8.0]).unwrap();
    let product = lhs.matmul(&rhs).unwrap();
    assert_eq!(product.shape(), [2, 2]);
}
#[test]
fn test_formal_matmul_rectangular() {
    // (1x3) x (3x2) -> (1x2).
    let row = Tensor::from_shape_vec(&[1, 3], vec![1.0, 2.0, 3.0]).unwrap();
    let tall = Tensor::from_shape_vec(&[3, 2], vec![4.0, 5.0, 6.0, 7.0, 8.0, 9.0]).unwrap();
    let product = row.matmul(&tall).unwrap();
    assert_eq!(product.shape(), [1, 2]);
}
#[test]
fn test_concat_op() {
    // Joining two length-2 vectors along axis 0 yields one length-4 vector.
    let halves = [
        Tensor::from_array(Array1::from_vec(vec![1.0, 2.0])),
        Tensor::from_array(Array1::from_vec(vec![3.0, 4.0])),
    ];
    let attrs = HashMap::from([("axis".to_string(), "0".to_string())]);
    let outputs = execute_operator(&OperatorType::Concat, &halves, &attrs).unwrap();
    assert_eq!(outputs.len(), 1);
    let joined = &outputs[0];
    assert_eq!(joined.shape(), &[4]);
    assert_eq!(joined.data().as_slice().unwrap(), &[1.0, 2.0, 3.0, 4.0]);
}
#[test]
fn test_concat_op_empty_inputs() {
    // Concat with no inputs must be rejected with an arity error.
    let outcome = execute_operator(&OperatorType::Concat, &[], &HashMap::new());
    assert!(outcome.is_err());
    let message = outcome.unwrap_err().to_string();
    assert!(message.contains("at least 1 input"));
}
#[test]
fn test_slice_op() {
    // starts=1, ends=3 on [1, 2, 3, 4] selects the elements at indices 1 and 2.
    let source = [Tensor::from_array(Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0]))];
    let attrs = HashMap::from([
        ("starts".to_string(), "1".to_string()),
        ("ends".to_string(), "3".to_string()),
    ]);
    let outputs = execute_operator(&OperatorType::Slice, &source, &attrs).unwrap();
    assert_eq!(outputs.len(), 1);
    let sliced = &outputs[0];
    assert_eq!(sliced.shape(), &[2]);
    assert_eq!(sliced.data().as_slice().unwrap(), &[2.0, 3.0]);
}
#[test]
fn test_upsample_op() {
    // With only `mode` set and no scales, the spatial size is expected to be kept.
    let image = Tensor::from_shape_vec(&[1, 1, 2, 2], vec![1.0, 2.0, 3.0, 4.0]).unwrap();
    let attrs = HashMap::from([("mode".to_string(), "nearest".to_string())]);
    let outputs = execute_operator(&OperatorType::Upsample, &[image], &attrs).unwrap();
    assert_eq!(outputs.len(), 1);
    assert_eq!(outputs[0].shape(), &[1, 1, 2, 2]);
}
#[test]
fn test_maxpool_op() {
    // A 2x2 window over a 4x4 input is expected to give 3x3 (stride 1, no padding).
    let attrs = HashMap::from([("kernel_shape".to_string(), "2,2".to_string())]);
    let outputs =
        execute_operator(&OperatorType::MaxPool, &[Tensor::zeros(&[1, 1, 4, 4])], &attrs)
            .unwrap();
    assert_eq!(outputs.len(), 1);
    assert_eq!(outputs[0].shape(), &[1, 1, 3, 3]);
}
#[test]
fn test_maxpool_op_wrong_dimensions() {
    // A 2D input must be rejected; the error names the 4D requirement.
    let flat = [Tensor::zeros(&[4, 4])];
    let outcome = execute_operator(&OperatorType::MaxPool, &flat, &HashMap::new());
    assert!(outcome.is_err());
    let message = outcome.unwrap_err().to_string();
    assert!(message.contains("4D input tensor"));
}
#[test]
fn test_softmax_op() {
    // Softmax output is a probability distribution: all positive, summing to 1.
    let logits = [Tensor::from_array(Array1::from_vec(vec![1.0, 2.0, 3.0]))];
    let outputs = execute_operator(&OperatorType::Softmax, &logits, &HashMap::new()).unwrap();
    assert_eq!(outputs.len(), 1);
    let probabilities = outputs[0].data();
    let total: f32 = probabilities.iter().sum();
    assert!((total - 1.0).abs() < 1e-6);
    assert!(probabilities.iter().all(|&p| p > 0.0));
}
#[test]
fn test_softmax_op_wrong_inputs() {
    // Softmax is unary, so a second input must be rejected.
    let pair = [
        Tensor::from_array(Array1::from_vec(vec![1.0, 2.0])),
        Tensor::from_array(Array1::from_vec(vec![3.0, 4.0])),
    ];
    let outcome = execute_operator(&OperatorType::Softmax, &pair, &HashMap::new());
    assert!(outcome.is_err());
    let message = outcome.unwrap_err().to_string();
    assert!(message.contains("exactly 1 input"));
}
#[test]
fn test_nms_op() {
    // All-zero (degenerate) boxes with unit scores are expected to select nothing,
    // giving an empty table with 3 columns.
    let boxes = Tensor::zeros(&[1, 4, 4]);
    let scores = Tensor::ones(&[1, 1, 4]);
    let outputs =
        execute_operator(&OperatorType::NonMaxSuppression, &[boxes, scores], &HashMap::new())
            .unwrap();
    assert_eq!(outputs.len(), 1);
    assert_eq!(outputs[0].shape(), &[0, 3]);
}
#[test]
fn test_nms_op_insufficient_inputs() {
    // Boxes without scores must be rejected.
    let only_boxes = [Tensor::zeros(&[1, 4, 4])];
    let outcome =
        execute_operator(&OperatorType::NonMaxSuppression, &only_boxes, &HashMap::new());
    assert!(outcome.is_err());
    let message = outcome.unwrap_err().to_string();
    assert!(message.contains("at least 2 inputs"));
}
#[test]
fn test_extended_operator_types_from_str() {
    // Each op_type string must parse to its matching OperatorType variant.
    let expected = [
        ("Concat", OperatorType::Concat),
        ("Slice", OperatorType::Slice),
        ("Upsample", OperatorType::Upsample),
        ("MaxPool", OperatorType::MaxPool),
        ("Softmax", OperatorType::Softmax),
        ("NonMaxSuppression", OperatorType::NonMaxSuppression),
    ];
    for (name, variant) in expected {
        assert_eq!(name.parse::<OperatorType>().unwrap(), variant);
    }
}
#[test]
fn test_all_extended_operators_execute() {
    // Smoke test: the extended operators accept minimal inputs. Only success is checked.
    let vector = Tensor::from_array(Array1::from_vec(vec![1.0, 2.0, 3.0]));
    let image = Tensor::zeros(&[1, 1, 2, 2]);
    let defaults = HashMap::new();
    let slice_attrs = HashMap::from([
        ("starts".to_string(), "0".to_string()),
        ("ends".to_string(), "1".to_string()),
    ]);
    let succeeds = |op: OperatorType, inputs: &[Tensor], attrs: &HashMap<String, String>| {
        execute_operator(&op, inputs, attrs).is_ok()
    };

    assert!(succeeds(OperatorType::Concat, &[vector.clone()], &defaults));
    assert!(succeeds(OperatorType::Slice, &[vector.clone()], &slice_attrs));
    assert!(succeeds(OperatorType::Upsample, &[image.clone()], &defaults));
    assert!(succeeds(OperatorType::MaxPool, &[image.clone()], &defaults));
    assert!(succeeds(OperatorType::Softmax, &[vector.clone()], &defaults));
    // NOTE(review): a 4D "boxes" tensor and a 1D "scores" tensor are unusual NMS
    // inputs; this only checks that the operator tolerates them.
    assert!(succeeds(
        OperatorType::NonMaxSuppression,
        &[image.clone(), vector.clone()],
        &defaults
    ));
}
#[test]
fn test_sub_op() {
    // Element-wise subtraction: [4, 6, 8] - [1, 2, 3] = [3, 4, 5].
    let minuend = Tensor::from_array(Array1::from_vec(vec![4.0, 6.0, 8.0]));
    let subtrahend = Tensor::from_array(Array1::from_vec(vec![1.0, 2.0, 3.0]));
    let outputs =
        execute_operator(&OperatorType::Sub, &[minuend, subtrahend], &HashMap::new()).unwrap();
    assert_eq!(outputs.len(), 1);
    let diff = outputs[0].data();
    let close = |got: f32, want: f32| (got - want).abs() < 1e-6;
    assert!(close(diff[0], 3.0));
    assert!(close(diff[1], 4.0));
    assert!(close(diff[2], 5.0));
}
#[test]
fn test_sub_op_wrong_inputs() {
    // Sub needs two operands, and lengths 2 vs 3 must be rejected.
    let short = Tensor::from_array(Array1::from_vec(vec![1.0, 2.0]));
    let long = Tensor::from_array(Array1::from_vec(vec![1.0, 2.0, 3.0]));
    let run = |inputs: &[Tensor]| execute_operator(&OperatorType::Sub, inputs, &HashMap::new());
    assert!(run(&[]).is_err());
    assert!(run(&[short.clone()]).is_err());
    assert!(run(&[short, long]).is_err());
}
#[test]
fn test_div_op() {
    // Element-wise division: [8, 12, 16] / [2, 3, 4] = [4, 4, 4].
    let dividend = Tensor::from_array(Array1::from_vec(vec![8.0, 12.0, 16.0]));
    let divisor = Tensor::from_array(Array1::from_vec(vec![2.0, 3.0, 4.0]));
    let outputs =
        execute_operator(&OperatorType::Div, &[dividend, divisor], &HashMap::new()).unwrap();
    assert_eq!(outputs.len(), 1);
    let quotient = outputs[0].data();
    let close = |got: f32, want: f32| (got - want).abs() < 1e-6;
    assert!(close(quotient[0], 4.0));
    assert!(close(quotient[1], 4.0));
    assert!(close(quotient[2], 4.0));
}
#[test]
fn test_div_op_wrong_inputs() {
    // Div needs two operands, and lengths 2 vs 3 must be rejected.
    let short = Tensor::from_array(Array1::from_vec(vec![1.0, 2.0]));
    let long = Tensor::from_array(Array1::from_vec(vec![1.0, 2.0, 3.0]));
    let run = |inputs: &[Tensor]| execute_operator(&OperatorType::Div, inputs, &HashMap::new());
    assert!(run(&[]).is_err());
    assert!(run(&[short.clone()]).is_err());
    assert!(run(&[short, long]).is_err());
}
#[test]
fn test_pow_op() {
    // Element-wise power: [2, 3, 4] ^ 2 = [4, 9, 16].
    let bases = Tensor::from_array(Array1::from_vec(vec![2.0, 3.0, 4.0]));
    let exponents = Tensor::from_array(Array1::from_vec(vec![2.0, 2.0, 2.0]));
    let outputs =
        execute_operator(&OperatorType::Pow, &[bases, exponents], &HashMap::new()).unwrap();
    assert_eq!(outputs.len(), 1);
    let powers = outputs[0].data();
    let close = |got: f32, want: f32| (got - want).abs() < 1e-6;
    assert!(close(powers[0], 4.0));
    assert!(close(powers[1], 9.0));
    assert!(close(powers[2], 16.0));
}
#[test]
fn test_pow_op_wrong_inputs() {
    // Pow needs both a base and an exponent.
    let lone = Tensor::from_array(Array1::from_vec(vec![1.0, 2.0]));
    let run = |inputs: &[Tensor]| execute_operator(&OperatorType::Pow, inputs, &HashMap::new());
    assert!(run(&[]).is_err());
    assert!(run(&[lone]).is_err());
}
#[test]
fn test_sqrt_op() {
    // sqrt([4, 9, 16]) = [2, 3, 4].
    let squares = Tensor::from_array(Array1::from_vec(vec![4.0, 9.0, 16.0]));
    let outputs = execute_operator(&OperatorType::Sqrt, &[squares], &HashMap::new()).unwrap();
    assert_eq!(outputs.len(), 1);
    let roots = outputs[0].data();
    let close = |got: f32, want: f32| (got - want).abs() < 1e-6;
    assert!(close(roots[0], 2.0));
    assert!(close(roots[1], 3.0));
    assert!(close(roots[2], 4.0));
}
#[test]
fn test_sqrt_op_wrong_inputs() {
    // Sqrt with no operand must fail.
    assert!(execute_operator(&OperatorType::Sqrt, &[], &HashMap::new()).is_err());
    // Negative operand: whether this yields NaN or an error is not pinned down.
    // The call only has to complete without panicking (smoke test).
    let with_negative = Tensor::from_array(Array1::from_vec(vec![-1.0, 2.0]));
    let _ = execute_operator(&OperatorType::Sqrt, &[with_negative], &HashMap::new());
}
#[test]
fn test_exp_op() {
    // exp(0) = 1, exp(1) = e, exp(2) = e^2 (looser tolerance for the larger value).
    let exponents = Tensor::from_array(Array1::from_vec(vec![0.0, 1.0, 2.0]));
    let outputs = execute_operator(&OperatorType::Exp, &[exponents], &HashMap::new()).unwrap();
    assert_eq!(outputs.len(), 1);
    let powers = outputs[0].data();
    let e = std::f32::consts::E;
    let near = |got: f32, want: f32, tol: f32| (got - want).abs() < tol;
    assert!(near(powers[0], 1.0, 1e-6));
    assert!(near(powers[1], e, 1e-6));
    assert!(near(powers[2], e * e, 1e-5));
}
#[test]
fn test_exp_op_wrong_inputs() {
    // Exp with no operand must fail.
    assert!(execute_operator(&OperatorType::Exp, &[], &HashMap::new()).is_err());
}
#[test]
fn test_cast_op() {
    // In ONNX's TensorProto.DataType enum, `to = 1` is FLOAT, so this is a
    // float-to-float cast. The other Cast tests in this module assert success even
    // for "int32" and "complex128" targets, so success is expected here too.
    let input = Tensor::from_array(Array1::from_vec(vec![1.0, 2.5, 3.7]));
    let attrs = HashMap::from([("to".to_string(), "1".to_string())]);
    let result = execute_operator(&OperatorType::Cast, &[input], &attrs);
    assert!(result.is_ok());
    // A cast changes the element type only, never the shape.
    let outputs = result.unwrap();
    assert_eq!(outputs[0].shape(), &[3]);
}
#[test]
fn test_cast_op_wrong_inputs() {
    // Cast without an input tensor must fail.
    assert!(execute_operator(&OperatorType::Cast, &[], &HashMap::new()).is_err());
}
#[test]
fn test_constant_of_shape_op() {
    // The input's values [2, 3] become the output's dimensions. This mirrors
    // `test_constant_of_shape_with_value`, which pins the same behaviour for [3, 4].
    let shape = Tensor::from_array(Array1::from_vec(vec![2.0, 3.0]));
    let attrs = HashMap::from([("value".to_string(), "5.0".to_string())]);
    let result = execute_operator(&OperatorType::ConstantOfShape, &[shape], &attrs);
    assert!(result.is_ok());
    let outputs = result.unwrap();
    assert_eq!(outputs[0].shape(), &[2, 3]);
}
#[test]
fn test_constant_of_shape_op_wrong_inputs() {
    // ConstantOfShape needs the shape tensor.
    assert!(execute_operator(&OperatorType::ConstantOfShape, &[], &HashMap::new()).is_err());
}
#[test]
fn test_shape_op() {
    // Smoke test only: the output of Shape on a (2, 3, 4) tensor is not asserted.
    // The call must merely complete without panicking.
    let cube = Tensor::from_array(
        Array3::from_shape_vec((2, 3, 4), (0u8..24).map(f32::from).collect()).unwrap(),
    );
    let _ = execute_operator(&OperatorType::Shape, &[cube], &HashMap::new());
}
#[test]
fn test_shape_op_wrong_inputs() {
    // Shape with no input must fail.
    assert!(execute_operator(&OperatorType::Shape, &[], &HashMap::new()).is_err());
}
#[test]
fn test_split_op() {
    // Split a (2, 4) matrix along axis 1 into two (2, 2) halves.
    let matrix = Tensor::from_array(
        Array2::from_shape_vec((2, 4), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap(),
    );
    let attrs = HashMap::from([
        ("axis".to_string(), "1".to_string()),
        ("split".to_string(), "2,2".to_string()),
    ]);
    let result = execute_operator(&OperatorType::Split, &[matrix], &attrs);
    assert!(result.is_ok());
    let halves = result.unwrap();
    assert_eq!(halves.len(), 2);
    for half in &halves {
        assert_eq!(half.shape(), &[2, 2]);
    }
}
#[test]
fn test_split_op_wrong_inputs() {
    // Split with no input must fail.
    assert!(execute_operator(&OperatorType::Split, &[], &HashMap::new()).is_err());
}
#[test]
fn test_gather_op() {
    // Gather rows 0 and 2 of a 3x3 matrix along axis 0. The output shape is
    // indices.shape ++ data.shape[1..] = [2, 3].
    let data = Tensor::from_array(
        Array2::from_shape_vec((3, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0])
            .unwrap(),
    );
    let indices = Tensor::from_array(Array1::from_vec(vec![0.0, 2.0]));
    let attrs = HashMap::from([("axis".to_string(), "0".to_string())]);
    let result = execute_operator(&OperatorType::Gather, &[data, indices], &attrs);
    assert!(result.is_ok());
    let outputs = result.unwrap();
    assert_eq!(outputs[0].shape(), &[2, 3]);
}
#[test]
fn test_gather_op_wrong_inputs() {
    // Gather needs both data and indices.
    let data = Tensor::from_array(Array1::from_vec(vec![1.0, 2.0]));
    let run = |inputs: &[Tensor]| execute_operator(&OperatorType::Gather, inputs, &HashMap::new());
    assert!(run(&[]).is_err());
    assert!(run(&[data]).is_err());
}
#[test]
fn test_unsqueeze_op() {
    // Inserting a new axis at position 0 turns a [3] vector into a [1, 3] row.
    let vector = Tensor::from_array(Array1::from_vec(vec![1.0, 2.0, 3.0]));
    let attrs = HashMap::from([("axes".to_string(), "0".to_string())]);
    let result = execute_operator(&OperatorType::Unsqueeze, &[vector], &attrs);
    assert!(result.is_ok());
    let outputs = result.unwrap();
    assert_eq!(outputs[0].shape(), &[1, 3]);
}
#[test]
fn test_unsqueeze_op_wrong_inputs() {
    // Unsqueeze with no input must fail.
    assert!(execute_operator(&OperatorType::Unsqueeze, &[], &HashMap::new()).is_err());
}
#[test]
fn test_squeeze_op() {
    // Smoke test only: squeezing (1, 3, 1) with no `axes` attribute is not asserted
    // to succeed or to produce any particular shape. It must merely not panic.
    let column = Tensor::from_array(Array3::from_shape_vec((1, 3, 1), vec![1.0, 2.0, 3.0]).unwrap());
    let _ = execute_operator(&OperatorType::Squeeze, &[column], &HashMap::new());
}
#[test]
fn test_squeeze_op_wrong_inputs() {
    // Squeeze with no input must fail.
    assert!(execute_operator(&OperatorType::Squeeze, &[], &HashMap::new()).is_err());
}
#[test]
fn test_batch_normalization_op() {
    // The supplied statistics match the data. Channel 0 holds 1..=4 (mean 2.5,
    // variance 1.25) and channel 1 holds 5..=8 (mean 6.5, variance 1.25).
    // `test_batch_norm_op_all_parameters` asserts success for the same shapes.
    let input = Tensor::from_array(
        Array4::from_shape_vec((1, 2, 2, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0])
            .unwrap(),
    );
    let per_channel = |a: f32, b: f32| Tensor::from_array(Array1::from_vec(vec![a, b]));
    let (scale, bias) = (per_channel(1.0, 1.0), per_channel(0.0, 0.0));
    let (mean, var) = (per_channel(2.5, 6.5), per_channel(1.25, 1.25));
    let result = execute_operator(
        &OperatorType::BatchNormalization,
        &[input, scale, bias, mean, var],
        &HashMap::new(),
    );
    assert!(result.is_ok());
    // Normalization is applied element-wise per channel, so the shape is preserved.
    let outputs = result.unwrap();
    assert_eq!(outputs[0].shape(), &[1, 2, 2, 2]);
}
#[test]
fn test_batch_normalization_op_wrong_inputs() {
    // BatchNormalization needs its statistics inputs, not just the data tensor.
    let input = Tensor::from_array(Array1::from_vec(vec![1.0, 2.0]));
    let run = |inputs: &[Tensor]| {
        execute_operator(&OperatorType::BatchNormalization, inputs, &HashMap::new())
    };
    assert!(run(&[]).is_err());
    assert!(run(&[input]).is_err());
}
#[test]
fn test_pad_op() {
    // Smoke test only: pads are given via the `pads` attribute rather than as a
    // second input, and acceptance of that form is not asserted. The call must
    // merely complete without panicking.
    let grid =
        Tensor::from_array(Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap());
    let attrs = HashMap::from([("pads".to_string(), "1,1,1,1".to_string())]);
    let _ = execute_operator(&OperatorType::Pad, &[grid], &attrs);
}
#[test]
fn test_pad_op_wrong_inputs() {
    // Pad with no input must fail.
    assert!(execute_operator(&OperatorType::Pad, &[], &HashMap::new()).is_err());
}
#[test]
fn test_reduce_mean_op() {
    // Average a 2x3 matrix along axis 1. `test_reduce_mean_op_with_axes` shows an
    // `axes` attribute is accepted, so success is expected.
    let matrix = Tensor::from_array(
        Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap(),
    );
    let attrs = HashMap::from([("axes".to_string(), "1".to_string())]);
    let result = execute_operator(&OperatorType::ReduceMean, &[matrix], &attrs);
    assert!(result.is_ok());
}
#[test]
fn test_reduce_mean_op_wrong_inputs() {
    // ReduceMean with no input must fail.
    assert!(execute_operator(&OperatorType::ReduceMean, &[], &HashMap::new()).is_err());
}
#[test]
fn test_identity_op() {
    // Identity must return the input values unchanged.
    let original = Tensor::from_array(Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0]));
    let outputs =
        execute_operator(&OperatorType::Identity, &[original], &HashMap::new()).unwrap();
    assert_eq!(outputs.len(), 1);
    let echoed = outputs[0].data();
    let close = |got: f32, want: f32| (got - want).abs() < 1e-6;
    assert!(close(echoed[0], 1.0));
    assert!(close(echoed[1], 2.0));
    assert!(close(echoed[2], 3.0));
    assert!(close(echoed[3], 4.0));
}
#[test]
fn test_identity_op_wrong_inputs() {
    // Identity with no input must fail.
    assert!(execute_operator(&OperatorType::Identity, &[], &HashMap::new()).is_err());
}
#[test]
fn test_resize_op() {
    // Smoke test only: resizing with `scales` but no `mode` attribute is not
    // asserted to succeed. The call must merely complete without panicking.
    let image = Tensor::from_array(
        Array4::from_shape_vec((1, 1, 2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap(),
    );
    let attrs = HashMap::from([("scales".to_string(), "1.0,1.0,2.0,2.0".to_string())]);
    let _ = execute_operator(&OperatorType::Resize, &[image], &attrs);
}
#[test]
fn test_resize_op_wrong_inputs() {
    // Resize with no input must fail.
    assert!(execute_operator(&OperatorType::Resize, &[], &HashMap::new()).is_err());
}
#[test]
fn test_parse_int_array_empty_string() {
    // An empty attribute string parses to an empty list, not an error.
    let parsed = parse_int_array("");
    assert!(parsed.is_ok());
    assert!(parsed.unwrap().is_empty());
}
#[test]
fn test_parse_int_array_whitespace_only() {
    // Whitespace-only input is treated like an empty string.
    let parsed = parse_int_array(" ");
    assert!(parsed.is_ok());
    assert!(parsed.unwrap().is_empty());
}
#[test]
fn test_parse_int_array_bracketed() {
    // Surrounding square brackets are stripped.
    let parsed = parse_int_array("[1,2,3]");
    assert!(parsed.is_ok());
    assert_eq!(parsed.unwrap(), [1, 2, 3]);
}
#[test]
fn test_parse_int_array_with_spaces() {
    // Spaces after the commas are tolerated.
    let parsed = parse_int_array("1, 2, 3");
    assert!(parsed.is_ok());
    assert_eq!(parsed.unwrap(), [1, 2, 3]);
}
#[test]
fn test_parse_int_array_invalid() {
    // A non-numeric element produces a descriptive parse error.
    let parsed = parse_int_array("1,invalid,3");
    assert!(parsed.is_err());
    let message = parsed.unwrap_err().to_string();
    assert!(message.contains("Failed to parse integer"));
}
#[test]
fn test_conv_op_with_custom_strides() {
    // Input [1, 2, 4, 4] with weights [3, 2, 3, 3], explicit strides and pads.
    let input = Tensor::from_shape_vec(&[1, 2, 4, 4], vec![1.0; 32]).unwrap();
    let weights = Tensor::from_shape_vec(&[3, 2, 3, 3], vec![0.1; 54]).unwrap();
    let attrs = HashMap::from([
        ("strides".to_string(), "[2,2]".to_string()),
        ("pads".to_string(), "[1,1,1,1]".to_string()),
    ]);
    assert!(execute_operator(&OperatorType::Conv, &[input, weights], &attrs).is_ok());
}
#[test]
fn test_conv_op_with_bias() {
    // Optional third input: a one-element bias for the single output kernel.
    let input = Tensor::from_shape_vec(&[1, 2, 3, 3], vec![1.0; 18]).unwrap();
    let weights = Tensor::from_shape_vec(&[1, 2, 2, 2], vec![0.5; 8]).unwrap();
    let bias = Tensor::from_shape_vec(&[1], vec![0.1]).unwrap();
    let outcome = execute_operator(&OperatorType::Conv, &[input, weights, bias], &HashMap::new());
    assert!(outcome.is_ok());
}
#[test]
fn test_conv_op_default_strides_and_pads() {
    // No attributes at all: the operator's defaults apply.
    let input = Tensor::from_shape_vec(&[1, 1, 3, 3], vec![1.0; 9]).unwrap();
    let weights = Tensor::from_shape_vec(&[1, 1, 2, 2], vec![0.25; 4]).unwrap();
    let outcome = execute_operator(&OperatorType::Conv, &[input, weights], &HashMap::new());
    assert!(outcome.is_ok());
}
#[test]
fn test_conv_op_malformed_strides() {
    // Pins current lenient behaviour: an unparsable `strides` value is tolerated
    // rather than rejected (presumably falling back to defaults).
    // NOTE(review): consider whether this should be an error instead.
    let input = Tensor::from_shape_vec(&[1, 1, 3, 3], vec![1.0; 9]).unwrap();
    let weights = Tensor::from_shape_vec(&[1, 1, 2, 2], vec![0.25; 4]).unwrap();
    let attrs = HashMap::from([("strides".to_string(), "invalid_stride".to_string())]);
    assert!(execute_operator(&OperatorType::Conv, &[input, weights], &attrs).is_ok());
}
#[test]
fn test_conv_op_partial_pads() {
    // Only two pad values are supplied; the operator tolerates the short list.
    let input = Tensor::from_shape_vec(&[1, 1, 4, 4], vec![1.0; 16]).unwrap();
    let weights = Tensor::from_shape_vec(&[1, 1, 2, 2], vec![0.25; 4]).unwrap();
    let attrs = HashMap::from([("pads".to_string(), "[1,2]".to_string())]);
    assert!(execute_operator(&OperatorType::Conv, &[input, weights], &attrs).is_ok());
}
#[test]
fn test_conv_op_unsupported_bias_shape() {
    // Pins current lenient behaviour: a [3, 3] bias that does not match the two
    // output kernels is still accepted. NOTE(review): arguably should be rejected.
    let input = Tensor::from_shape_vec(&[1, 2, 3, 3], vec![1.0; 18]).unwrap();
    let weights = Tensor::from_shape_vec(&[2, 2, 2, 2], vec![0.5; 16]).unwrap();
    let bias = Tensor::from_shape_vec(&[3, 3], vec![0.1; 9]).unwrap();
    let outcome = execute_operator(&OperatorType::Conv, &[input, weights, bias], &HashMap::new());
    assert!(outcome.is_ok());
}
#[test]
fn test_conv_op_4d_bias() {
    // A bias supplied as a broadcastable [1, 2, 1, 1] tensor is accepted.
    let input = Tensor::from_shape_vec(&[1, 1, 3, 3], vec![1.0; 9]).unwrap();
    let weights = Tensor::from_shape_vec(&[2, 1, 2, 2], vec![0.5; 8]).unwrap();
    let bias = Tensor::from_shape_vec(&[1, 2, 1, 1], vec![0.1, 0.2]).unwrap();
    let outcome = execute_operator(&OperatorType::Conv, &[input, weights, bias], &HashMap::new());
    assert!(outcome.is_ok());
}
#[test]
fn test_transpose_op_with_custom_perm() {
    // perm = [2, 0, 1] moves the last axis to the front: [2, 3, 4] -> [4, 2, 3].
    let cube = Tensor::from_shape_vec(&[2, 3, 4], vec![1.0; 24]).unwrap();
    let attrs = HashMap::from([("perm".to_string(), "[2,0,1]".to_string())]);
    let result = execute_operator(&OperatorType::Transpose, &[cube], &attrs);
    assert!(result.is_ok());
    let permuted = result.unwrap();
    assert_eq!(permuted[0].shape(), &[4, 2, 3]);
}
#[test]
fn test_transpose_op_invalid_perm() {
    // Axis 2 does not exist on a 2D tensor, so this permutation must be rejected.
    let matrix = Tensor::from_shape_vec(&[2, 3], vec![1.0; 6]).unwrap();
    let attrs = HashMap::from([("perm".to_string(), "[0,2]".to_string())]);
    assert!(execute_operator(&OperatorType::Transpose, &[matrix], &attrs).is_err());
}
#[test]
fn test_softmax_op_with_axis() {
    // Softmax of a 2x3 matrix along axis 1; only acceptance is checked.
    let rows = Tensor::from_shape_vec(&[2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
    let attrs = HashMap::from([("axis".to_string(), "1".to_string())]);
    assert!(execute_operator(&OperatorType::Softmax, &[rows], &attrs).is_ok());
}
#[test]
fn test_softmax_op_negative_axis() {
    // A negative axis (-1) must be accepted; only acceptance is checked.
    let uniform = Tensor::from_shape_vec(&[2, 3], vec![1.0; 6]).unwrap();
    let attrs = HashMap::from([("axis".to_string(), "-1".to_string())]);
    assert!(execute_operator(&OperatorType::Softmax, &[uniform], &attrs).is_ok());
}
#[test]
fn test_gather_op_with_indices() {
    // Select rows 0 and 2 of a 3x2 matrix along axis 0.
    let table = Tensor::from_shape_vec(&[3, 2], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
    let rows = Tensor::from_shape_vec(&[2], vec![0.0, 2.0]).unwrap();
    let attrs = HashMap::from([("axis".to_string(), "0".to_string())]);
    assert!(execute_operator(&OperatorType::Gather, &[table, rows], &attrs).is_ok());
}
#[test]
fn test_gather_op_out_of_bounds_indices() {
    // NOTE(review): despite the name, index 1 is *in* bounds for a dimension of
    // size 2. It is the last valid index, so this exercises the upper boundary.
    // Both axes of the data have size 2, so the call must succeed whichever axis
    // the default resolves to. A genuinely out-of-range index is not covered here.
    let data = Tensor::from_shape_vec(&[2, 2], vec![1.0, 2.0, 3.0, 4.0]).unwrap();
    let last_index = Tensor::from_shape_vec(&[1], vec![1.0]).unwrap();
    let result = execute_operator(&OperatorType::Gather, &[data, last_index], &HashMap::new());
    assert!(result.is_ok());
}
#[test]
fn test_split_op_with_custom_axis() {
    // Split a [4, 2] tensor along axis 0 into two chunks of two rows each.
    let tall = Tensor::from_shape_vec(&[4, 2], vec![1.0; 8]).unwrap();
    let attrs = HashMap::from([
        ("axis".to_string(), "0".to_string()),
        ("split".to_string(), "[2,2]".to_string()),
    ]);
    let result = execute_operator(&OperatorType::Split, &[tall], &attrs);
    assert!(result.is_ok());
    let chunks = result.unwrap();
    assert_eq!(chunks.len(), 2);
}
#[test]
fn test_split_op_uneven_split() {
    // Unequal chunk sizes [2, 3] that together cover all 5 elements.
    let sequence = Tensor::from_shape_vec(&[5], vec![1.0, 2.0, 3.0, 4.0, 5.0]).unwrap();
    let attrs = HashMap::from([("split".to_string(), "[2,3]".to_string())]);
    assert!(execute_operator(&OperatorType::Split, &[sequence], &attrs).is_ok());
}
#[test]
fn test_pad_op_constant_mode() {
    // Pads passed as a second input tensor, with constant mode.
    let grid = Tensor::from_shape_vec(&[2, 2], vec![1.0, 2.0, 3.0, 4.0]).unwrap();
    let amounts = Tensor::from_shape_vec(&[4], vec![1.0, 1.0, 1.0, 1.0]).unwrap();
    let attrs = HashMap::from([("mode".to_string(), "constant".to_string())]);
    assert!(execute_operator(&OperatorType::Pad, &[grid, amounts], &attrs).is_ok());
}
#[test]
fn test_pad_op_reflect_mode() {
    // Pads passed as a second input tensor, with reflect mode.
    let grid = Tensor::from_shape_vec(&[3, 3], vec![1.0; 9]).unwrap();
    let amounts = Tensor::from_shape_vec(&[4], vec![1.0, 1.0, 0.0, 0.0]).unwrap();
    let attrs = HashMap::from([("mode".to_string(), "reflect".to_string())]);
    assert!(execute_operator(&OperatorType::Pad, &[grid, amounts], &attrs).is_ok());
}
#[test]
fn test_cast_op_different_types() {
    // The target type is given by name rather than by numeric code.
    let values = Tensor::from_shape_vec(&[3], vec![1.5, 2.7, 3.9]).unwrap();
    let attrs = HashMap::from([("to".to_string(), "int32".to_string())]);
    assert!(execute_operator(&OperatorType::Cast, &[values], &attrs).is_ok());
}
#[test]
fn test_cast_op_unsupported_type() {
    // Pins current lenient behaviour: an unsupported target type is accepted
    // rather than rejected. NOTE(review): arguably this should be an error.
    let values = Tensor::from_shape_vec(&[2], vec![1.0, 2.0]).unwrap();
    let attrs = HashMap::from([("to".to_string(), "complex128".to_string())]);
    assert!(execute_operator(&OperatorType::Cast, &[values], &attrs).is_ok());
}
#[test]
fn test_unsqueeze_op_multiple_axes() {
    // Insert new axes at positions 0 and 2 of a [2, 3] tensor.
    let matrix = Tensor::from_shape_vec(&[2, 3], vec![1.0; 6]).unwrap();
    let attrs = HashMap::from([("axes".to_string(), "[0,2]".to_string())]);
    assert!(execute_operator(&OperatorType::Unsqueeze, &[matrix], &attrs).is_ok());
}
#[test]
fn test_squeeze_op_specific_axes() {
    // Remove the explicitly listed size-1 axes 0 and 2 from a [1, 3, 1] tensor.
    let padded = Tensor::from_shape_vec(&[1, 3, 1], vec![1.0, 2.0, 3.0]).unwrap();
    let attrs = HashMap::from([("axes".to_string(), "[0,2]".to_string())]);
    assert!(execute_operator(&OperatorType::Squeeze, &[padded], &attrs).is_ok());
}
#[test]
fn test_constant_of_shape_with_value() {
    // The input's values [3, 4] become the output's dimensions.
    let dims = Tensor::from_shape_vec(&[2], vec![3.0, 4.0]).unwrap();
    let attrs = HashMap::from([("value".to_string(), "5.0".to_string())]);
    let result = execute_operator(&OperatorType::ConstantOfShape, &[dims], &attrs);
    assert!(result.is_ok());
    let filled = result.unwrap();
    assert_eq!(filled[0].shape(), &[3, 4]);
}
#[test]
fn test_maxpool_op_with_strides_and_pads() {
    // 2x2 window with strides [2,2] and pads [1,1,1,1] on a 4x4 input.
    let image = Tensor::from_shape_vec(&[1, 1, 4, 4], vec![1.0; 16]).unwrap();
    let attrs = HashMap::from([
        ("kernel_shape".to_string(), "[2,2]".to_string()),
        ("strides".to_string(), "[2,2]".to_string()),
        ("pads".to_string(), "[1,1,1,1]".to_string()),
    ]);
    assert!(execute_operator(&OperatorType::MaxPool, &[image], &attrs).is_ok());
}
#[test]
fn test_nms_op_with_all_parameters() {
    // Four boxes in two nearby pairs with descending scores, plus explicit IoU and
    // score thresholds. Only acceptance is checked.
    let corners = vec![
        0.0, 0.0, 1.0, 1.0, //
        0.1, 0.1, 1.1, 1.1, //
        2.0, 2.0, 3.0, 3.0, //
        2.1, 2.1, 3.1, 3.1,
    ];
    let boxes = Tensor::from_shape_vec(&[4, 4], corners).unwrap();
    let scores = Tensor::from_shape_vec(&[4], vec![0.9, 0.8, 0.7, 0.6]).unwrap();
    let attrs = HashMap::from([
        ("iou_threshold".to_string(), "0.5".to_string()),
        ("score_threshold".to_string(), "0.1".to_string()),
    ]);
    let outcome = execute_operator(&OperatorType::NonMaxSuppression, &[boxes, scores], &attrs);
    assert!(outcome.is_ok());
}
#[test]
fn test_batch_norm_op_all_parameters() {
    // Input plus per-channel scale, bias, mean and variance for 2 channels.
    let input = Tensor::from_shape_vec(&[1, 2, 2, 2], vec![1.0; 8]).unwrap();
    let pair = |a: f32, b: f32| Tensor::from_shape_vec(&[2], vec![a, b]).unwrap();
    let inputs = [
        input,
        pair(1.0, 1.0),
        pair(0.0, 0.0),
        pair(0.5, 0.5),
        pair(0.25, 0.25),
    ];
    let outcome = execute_operator(&OperatorType::BatchNormalization, &inputs, &HashMap::new());
    assert!(outcome.is_ok());
}
#[test]
fn test_resize_op_with_scales() {
    // Nearest-neighbour resize with bracketed per-axis scales.
    let image = Tensor::from_shape_vec(&[1, 1, 2, 2], vec![1.0, 2.0, 3.0, 4.0]).unwrap();
    let attrs = HashMap::from([
        ("scales".to_string(), "[1.0,1.0,2.0,2.0]".to_string()),
        ("mode".to_string(), "nearest".to_string()),
    ]);
    assert!(execute_operator(&OperatorType::Resize, &[image], &attrs).is_ok());
}
#[test]
fn test_reduce_mean_op_with_axes() {
    // Reduce over axes 1 and 2 of a [2, 3, 4] tensor with keepdims requested.
    let volume = Tensor::from_shape_vec(&[2, 3, 4], vec![1.0; 24]).unwrap();
    let attrs = HashMap::from([
        ("axes".to_string(), "[1,2]".to_string()),
        ("keepdims".to_string(), "true".to_string()),
    ]);
    assert!(execute_operator(&OperatorType::ReduceMean, &[volume], &attrs).is_ok());
}
#[test]
fn test_slice_op_with_steps() {
    // Slice the whole 8-element range with a step of 2.
    let ramp = Tensor::from_shape_vec(&[8], vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]).unwrap();
    let attrs = HashMap::from([
        ("starts".to_string(), "0".to_string()),
        ("ends".to_string(), "8".to_string()),
        ("steps".to_string(), "2".to_string()),
    ]);
    assert!(execute_operator(&OperatorType::Slice, &[ramp], &attrs).is_ok());
}
#[test]
fn test_all_extended_operators_basic_execution() {
    // Binary and unary element-wise operators are listed separately. Each one is
    // therefore guaranteed to run with the right arity, and there is no catch-all
    // arm that would silently skip an operator added to the wrong list.
    let binary_ops = [OperatorType::Div, OperatorType::Sub, OperatorType::Pow];
    let unary_ops = [OperatorType::Exp, OperatorType::Sqrt];

    for op in binary_ops {
        let tensor1 = Tensor::from_shape_vec(&[2], vec![4.0, 9.0]).unwrap();
        let tensor2 = Tensor::from_shape_vec(&[2], vec![2.0, 3.0]).unwrap();
        let result = execute_operator(&op, &[tensor1, tensor2], &HashMap::new());
        assert!(result.is_ok(), "Failed for operator {op:?}");
    }
    for op in unary_ops {
        // Positive inputs keep Sqrt in its real-valued domain.
        let tensor = Tensor::from_shape_vec(&[2], vec![1.0, 4.0]).unwrap();
        let result = execute_operator(&op, &[tensor], &HashMap::new());
        assert!(result.is_ok(), "Failed for operator {op:?}");
    }
}
}