use crate::onnx::constant_folding::{
ConstantEvaluator as EvaluatorTrait, ConstantFoldingContext, ConstantTensor, TensorData,
};
use crate::onnx::convert::OnnxError;
use crate::protos::onnx::{NodeProto, TensorProto_DataType};
/// Constant-folding evaluator for the ONNX `Shape` operator.
///
/// The type is stateless: all information needed for folding comes from the
/// node and the folding context passed to the evaluator trait methods.
#[derive(Debug, Clone, Copy, Default)]
pub struct ShapeEvaluator;
/// Folds `Shape` nodes whose first input is a known constant.
impl EvaluatorTrait for ShapeEvaluator {
    /// Returns the operator name this evaluator handles (`"Shape"`).
    fn op_type(&self) -> &str {
        "Shape"
    }

    /// Reports whether `node` can be folded.
    ///
    /// Returns `true` only when `node` is a `Shape` node and its first input
    /// is registered as a constant in `ctx`; a node without inputs is rejected.
    ///
    /// NOTE(review): node attributes are not inspected. ONNX `Shape` (opset 15+)
    /// defines optional `start`/`end` attributes that slice the result — confirm
    /// whether nodes carrying them can reach this evaluator.
    fn can_evaluate(&self, node: &NodeProto, ctx: &ConstantFoldingContext) -> bool {
        node.op_type.as_str() == "Shape"
            && matches!(
                node.input.as_slice().first(),
                Some(input_name) if ctx.is_constant(input_name.as_str())
            )
    }

    /// Evaluates `node` by emitting the shape of its first input as a tensor.
    ///
    /// The result is a single 1-D `Int64` tensor whose values are the
    /// dimensions of the constant input and whose length equals its rank.
    ///
    /// # Errors
    ///
    /// * `OnnxError::MissingAttribute` when `node` has no inputs.
    /// * `OnnxError::ShapeInference` when the first input is not available
    ///   as a constant in `ctx`.
    fn evaluate(
        &self,
        node: &NodeProto,
        ctx: &ConstantFoldingContext,
    ) -> Result<Vec<ConstantTensor>, OnnxError> {
        let input_name = node
            .input
            .as_slice()
            .first()
            .ok_or_else(|| OnnxError::MissingAttribute {
                attr: "input".to_string(),
                op: "Shape".to_string(),
            })?;

        let input_tensor = ctx.get_constant(input_name.as_str()).ok_or_else(|| {
            OnnxError::ShapeInference(format!(
                "Input tensor '{}' not found in constants",
                input_name
            ))
        })?;

        // Copy the dimensions once; the length is taken before the vector is
        // moved into the output so no second clone is needed.
        let dims: Vec<i64> = input_tensor.shape.clone();
        // A tensor rank always fits in i64, so this cast cannot truncate.
        let rank = dims.len() as i64;

        Ok(vec![ConstantTensor {
            data: TensorData::Int64(dims),
            shape: vec![rank],
            data_type: TensorProto_DataType::Int64.into(),
        }])
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protos::onnx::TensorProto;
    use std::collections::HashMap;

    /// Folding `Shape` over a constant 2x3x4 float initializer must yield the
    /// 1-D Int64 tensor `[2, 3, 4]`.
    #[test]
    fn test_shape_evaluator() {
        // 2 * 3 * 4 elements, 4 bytes per element.
        let payload_len = 4 * 2 * 3 * 4;

        // The proto is leaked so the map can hold a `'static` reference to it.
        let initializer: &'static TensorProto = Box::leak(Box::new(TensorProto {
            name: "test_input".to_string(),
            data_type: TensorProto_DataType::Float.into(),
            dims: vec![2, 3, 4],
            raw_data: vec![0u8; payload_len],
            ..Default::default()
        }));

        let initializers = HashMap::from([("test_input".to_string(), initializer)]);
        let ctx = ConstantFoldingContext::new(&initializers).unwrap();

        let shape_node = NodeProto {
            op_type: "Shape".to_string(),
            input: vec!["test_input".to_string()],
            output: vec!["test_output".to_string()],
            ..Default::default()
        };

        let evaluator = ShapeEvaluator;
        assert!(evaluator.can_evaluate(&shape_node, &ctx));

        let folded = evaluator.evaluate(&shape_node, &ctx).unwrap();
        assert_eq!(folded.len(), 1);

        let shape_tensor = &folded[0];
        assert_eq!(shape_tensor.shape, vec![3]);
        assert_eq!(shape_tensor.data_type, TensorProto_DataType::Int64 as i32);

        match &shape_tensor.data {
            TensorData::Int64(values) => assert_eq!(values, &vec![2, 3, 4]),
            _ => panic!("Expected Int64 data"),
        }
    }
}