// webnn-graph 0.3.0
//
// Simple DSL for WebNN graphs
// Documentation
// Shape operation evaluator
// Extracts the shape of a tensor as a 1D int64 tensor

use crate::onnx::constant_folding::{
    ConstantEvaluator as EvaluatorTrait, ConstantFoldingContext, ConstantTensor, TensorData,
};
use crate::onnx::convert::OnnxError;
use crate::protos::onnx::{NodeProto, TensorProto_DataType};

/// Constant-folding evaluator for the ONNX `Shape` operator.
///
/// The type carries no state: everything it needs comes from the node and the
/// folding context passed to each call, so it is freely copyable.
#[derive(Debug, Clone, Copy, Default)]
pub struct ShapeEvaluator;

impl EvaluatorTrait for ShapeEvaluator {
    /// The ONNX operator type handled by this evaluator.
    fn op_type(&self) -> &str {
        "Shape"
    }

    /// Returns `true` when `node` is a `Shape` node whose first input is a
    /// known constant in `ctx`.
    fn can_evaluate(&self, node: &NodeProto, ctx: &ConstantFoldingContext) -> bool {
        // Only the input's *shape* is used. This evaluator reads that shape from
        // the folded constant tensor, so the input must already be a constant.
        node.op_type.as_str() == "Shape"
            && node
                .input
                .as_slice()
                .first()
                .map_or(false, |name| ctx.is_constant(name.as_str()))
    }

    /// Folds a `Shape` node into a 1-D int64 constant holding the input's
    /// dimensions.
    ///
    /// The optional `start`/`end` attributes (ONNX opset 15+) are honoured.
    /// Only dimensions in the half-open range `[start, end)` are emitted.
    /// Negative values count from the last axis, and out-of-range values are
    /// clamped to `[0, rank]`. Without the attributes, the full shape is
    /// returned.
    ///
    /// # Errors
    /// - `OnnxError::MissingAttribute` if the node has no input.
    /// - `OnnxError::ShapeInference` if the input is not a known constant.
    fn evaluate(
        &self,
        node: &NodeProto,
        ctx: &ConstantFoldingContext,
    ) -> Result<Vec<ConstantTensor>, OnnxError> {
        let input_name = node
            .input
            .as_slice()
            .first()
            .ok_or_else(|| OnnxError::MissingAttribute {
                attr: "input".to_string(),
                op: "Shape".to_string(),
            })?;

        let input_tensor = ctx.get_constant(input_name.as_str()).ok_or_else(|| {
            OnnxError::ShapeInference(format!(
                "Input tensor '{}' not found in constants",
                input_name
            ))
        })?;

        // Select the requested slice of the dimensions (full range by default).
        let dims = &input_tensor.shape;
        let range = resolve_shape_range(
            dims.len(),
            int_attribute(node, "start"),
            int_attribute(node, "end"),
        );
        let shape_values: Vec<i64> = dims[range].to_vec();
        let len = shape_values.len() as i64;

        // Output is a 1-D int64 tensor with one element per selected dimension.
        Ok(vec![ConstantTensor {
            data: TensorData::Int64(shape_values),
            shape: vec![len],
            data_type: TensorProto_DataType::Int64.into(),
        }])
    }
}

/// Returns the integer value of the attribute called `name` on `node`, if present.
fn int_attribute(node: &NodeProto, name: &str) -> Option<i64> {
    node.attribute
        .iter()
        .find(|attr| attr.name == name)
        .map(|attr| attr.i)
}

/// Resolves `Shape`'s `start`/`end` attributes against a tensor of rank `rank`.
///
/// Uses Python-slice semantics, matching the ONNX reference implementation:
/// - Negative values have `rank` added to them.
/// - Both ends are then clamped to `[0, rank]`.
/// - A `start` past `end` yields an empty range rather than panicking.
///
/// Missing bounds default to `0` and `rank`.
fn resolve_shape_range(
    rank: usize,
    start: Option<i64>,
    end: Option<i64>,
) -> std::ops::Range<usize> {
    let rank_i = rank as i64;
    let normalize = |v: i64| -> usize {
        let v = if v < 0 { v + rank_i } else { v };
        v.clamp(0, rank_i) as usize
    };
    let lo = start.map_or(0, normalize);
    let hi = end.map_or(rank, normalize);
    lo..hi.max(lo)
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::protos::onnx::TensorProto;
    use std::collections::HashMap;

    /// Folding `Shape` over a constant float tensor of dims `[2, 3, 4]` must
    /// yield the 1-D int64 tensor `[2, 3, 4]`.
    #[test]
    fn test_shape_evaluator() {
        // 2 * 3 * 4 = 24 float32 elements of 4 bytes each, all zero.
        let input = TensorProto {
            name: "test_input".to_string(),
            data_type: TensorProto_DataType::Float.into(),
            dims: vec![2, 3, 4],
            raw_data: vec![0u8; 4 * 2 * 3 * 4],
            ..Default::default()
        };

        // The initializer map is built from a `'static` reference, so leak the
        // tensor for the duration of the test. In production the model owns it.
        let input: &'static TensorProto = Box::leak(Box::new(input));

        let mut initializers = HashMap::new();
        initializers.insert("test_input".to_string(), input);

        let ctx = ConstantFoldingContext::new(&initializers).unwrap();

        let node = NodeProto {
            op_type: "Shape".to_string(),
            input: vec!["test_input".to_string()],
            output: vec!["test_output".to_string()],
            ..Default::default()
        };

        let evaluator = ShapeEvaluator;
        assert!(evaluator.can_evaluate(&node, &ctx));

        let outputs = evaluator.evaluate(&node, &ctx).unwrap();
        assert_eq!(outputs.len(), 1);

        let folded = &outputs[0];
        // One element per input dimension, stored as int64.
        assert_eq!(folded.shape, vec![3]);
        assert_eq!(folded.data_type, TensorProto_DataType::Int64 as i32);

        match &folded.data {
            TensorData::Int64(dims) => assert_eq!(dims, &vec![2, 3, 4]),
            _ => panic!("Expected Int64 data"),
        }
    }
}