p2o 0.1.1

A PaddlePaddle New IR (PIR) to ONNX model converter.
Documentation
use crate::converter::Converter;
use crate::helper;
use crate::proto::onnx;
use anyhow::bail;
use serde_json::Value;

impl Converter {
    /// Lowers a Paddle `batch_norm_` op to a single ONNX `BatchNormalization` node.
    ///
    /// Only inference-style lowering is supported. The op is rejected when its
    /// `is_test` attribute is explicitly `false` (a missing attribute is accepted),
    /// and when `data_format` is anything other than `"NCHW"` (which is assumed when
    /// the attribute is absent).
    ///
    /// Inputs are reordered to ONNX's `(X, scale, B, input_mean, input_var)`. The
    /// mapping below assumes the Paddle op lists them as
    /// `(x, mean, variance, scale, bias)`; confirm this against the PIR op definition
    /// if that ever changes. Attributes under the op's `"A"` key are translated by
    /// `extract_attributes`. Only the single output id returned by `helper::op_out_id`
    /// is connected to the ONNX node.
    ///
    /// # Errors
    ///
    /// Fails if `is_test` is `false`, `data_format` is not `"NCHW"`, fewer than five
    /// inputs are present, or an input/output tensor name cannot be resolved.
    pub(crate) fn op_batch_norm(&mut self, op: &Value) -> anyhow::Result<()> {
        if matches!(
            helper::attr(op, "is_test").and_then(|d| d.as_bool()),
            Some(false)
        ) {
            bail!("batch_norm_ only supports inference-style lowering with is_test=true");
        }
        let data_format = helper::attr(op, "data_format")
            .and_then(|d| d.as_str())
            .unwrap_or("NCHW");
        if data_format != "NCHW" {
            bail!("batch_norm_ only supports NCHW");
        }

        let inputs = helper::op_input_ids(op);
        if inputs.len() < 5 {
            bail!("batch_norm_ missing inputs");
        }
        let mut onnx_node = onnx::NodeProto {
            op_type: "BatchNormalization".to_string(),
            // ONNX order: X, scale, B, input_mean, input_var.
            input: vec![
                self.get_tensor_name(inputs[0])?,
                self.get_tensor_name(inputs[3])?,
                self.get_tensor_name(inputs[4])?,
                self.get_tensor_name(inputs[1])?,
                self.get_tensor_name(inputs[2])?,
            ],
            ..Default::default()
        };
        if let Some(attrs) = op.get("A") {
            onnx_node.attribute = self.extract_attributes("batch_norm_", attrs);
        }
        let out_id = helper::op_out_id(op)?;
        onnx_node.output.push(self.get_tensor_name(out_id)?);
        self.onnx_graph.node.push(onnx_node);
        Ok(())
    }

    /// Lowers a Paddle `group_norm` op to an ONNX subgraph built around
    /// `InstanceNormalization`. Requires opset >= 6.
    ///
    /// The emitted subgraph:
    /// 1. `Shape` of the input, and a `Slice` of its first element (the batch dim `N`).
    /// 2. `Reshape` of the input to `[N, groups, -1]`.
    /// 3. `InstanceNormalization` with identity scale (ones) and bias (zeros) of
    ///    length `groups`, which normalizes each `(n, group)` slice independently.
    /// 4. `Reshape` back to the original input shape.
    /// 5. `Mul` by the op's scale and `Add` of its bias. Both are unsqueezed on axes
    ///    `[0, 2, 3, ..., rank-1]` so that, assuming they are 1-D `[C]`, they
    ///    broadcast as `[1, C, 1, ..., 1]`.
    ///
    /// # Errors
    ///
    /// Fails if the opset is below 6, fewer than three inputs are present, the output
    /// list or output id is missing, `groups` is missing or not positive,
    /// `data_format` is not `"NCHW"`, the input's rank metadata is missing or below 3,
    /// a tensor name cannot be resolved, or the slice node cannot be created.
    pub fn op_group_norm(&mut self, op: &Value) -> anyhow::Result<()> {
        self.require_opset(6, "group_norm")?;

        let inputs = helper::op_input_ids(op);
        if inputs.len() < 3 {
            bail!("group_norm missing inputs");
        }
        let outputs = op
            .get("O")
            .and_then(|o| o.as_array())
            .ok_or_else(|| anyhow::anyhow!("group_norm missing outputs"))?;
        // `.first()` rather than `[0]`: an empty output list must be reported as an
        // error, not cause an index-out-of-bounds panic.
        let out_id = outputs
            .first()
            .and_then(|o| o.get("%"))
            .and_then(|id| id.as_i64())
            .ok_or_else(|| anyhow::anyhow!("group_norm missing output id"))?;
        let groups = helper::attr(op, "groups")
            .and_then(|d| d.as_i64())
            .ok_or_else(|| anyhow::anyhow!("group_norm missing groups"))?;
        // A non-positive group count would wrap around in a `usize` cast (and panic
        // when allocating the initializers) or produce a degenerate `[N, 0, -1]` reshape.
        if groups <= 0 {
            bail!("group_norm expects groups > 0, got {}", groups);
        }
        let group_count = usize::try_from(groups)?;
        let epsilon = helper::attr(op, "epsilon")
            .and_then(|d| d.as_f64())
            .unwrap_or(1e-5) as f32;
        let data_format = helper::attr(op, "data_format")
            .and_then(|d| d.as_str())
            .unwrap_or("NCHW");
        if data_format != "NCHW" {
            bail!("group_norm only supports NCHW");
        }

        let input_id = inputs[0];
        let input_rank = self
            .state
            .tensor_shapes
            .get(&input_id)
            .map(|shape| shape.len())
            .ok_or_else(|| anyhow::anyhow!("group_norm missing input rank metadata"))?;
        if input_rank < 3 {
            bail!("group_norm expects rank >= 3, got {}", input_rank);
        }

        // Resolve every tensor name before emitting any node, so that a failed lookup
        // does not leave a half-built subgraph in `onnx_graph`.
        let input_name = self.get_tensor_name(input_id)?;
        let scale_input_name = self.get_tensor_name(inputs[1])?;
        let bias_input_name = self.get_tensor_name(inputs[2])?;
        let output_name = self.get_tensor_name(out_id)?;

        let input_shape_name = format!("group_norm_input_shape_{}", out_id);
        self.onnx_graph.node.push(onnx::NodeProto {
            op_type: "Shape".to_string(),
            input: vec![input_name.clone()],
            output: vec![input_shape_name.clone()],
            ..Default::default()
        });

        // Batch dimension N = shape[0:1].
        let batch_dim_name = format!("group_norm_batch_dim_{}", out_id);
        self.add_slice_node(
            input_shape_name.clone(),
            batch_dim_name.clone(),
            &[0],
            &[1],
            Some(&[0]),
            None,
            &format!("group_norm_batch_dim_{}", out_id),
        )?;

        let groups_name = format!("group_norm_groups_{}", out_id);
        self.push_i64_initializer(groups_name.clone(), vec![1], &[groups]);
        let minus_one_name = format!("group_norm_minus_one_{}", out_id);
        self.push_i64_initializer(minus_one_name.clone(), vec![1], &[-1]);

        // Target shape [N, groups, -1]. This presumably relies on C being divisible by
        // `groups` (a Paddle precondition that is not re-checked here).
        let reshape_shape_name = format!("group_norm_reshape_shape_{}", out_id);
        let mut reshape_shape = onnx::NodeProto {
            op_type: "Concat".to_string(),
            input: vec![batch_dim_name, groups_name, minus_one_name],
            output: vec![reshape_shape_name.clone()],
            ..Default::default()
        };
        reshape_shape.attribute.push(helper::attr_int("axis", 0));
        self.onnx_graph.node.push(reshape_shape);

        let reshaped_name = format!("group_norm_reshaped_{}", out_id);
        self.onnx_graph.node.push(onnx::NodeProto {
            op_type: "Reshape".to_string(),
            input: vec![input_name, reshape_shape_name],
            output: vec![reshaped_name.clone()],
            ..Default::default()
        });

        // Identity affine for InstanceNormalization. The real per-channel affine is
        // applied after reshaping back.
        // NOTE(review): these initializers are always f32. Confirm that non-f32 inputs
        // (e.g. fp16) never reach this path, or insert a Cast.
        let scale_name = format!("group_norm_scale_{}", out_id);
        let bias_name = format!("group_norm_bias_{}", out_id);
        self.push_f32_initializer(scale_name.clone(), vec![groups], &vec![1.0; group_count]);
        self.push_f32_initializer(bias_name.clone(), vec![groups], &vec![0.0; group_count]);

        let normalized_name = format!("group_norm_normalized_{}", out_id);
        let mut instance_norm = onnx::NodeProto {
            op_type: "InstanceNormalization".to_string(),
            input: vec![reshaped_name, scale_name, bias_name],
            output: vec![normalized_name.clone()],
            ..Default::default()
        };
        instance_norm
            .attribute
            .push(helper::attr_float("epsilon", epsilon));
        self.onnx_graph.node.push(instance_norm);

        let restored_name = format!("group_norm_restored_{}", out_id);
        self.onnx_graph.node.push(onnx::NodeProto {
            op_type: "Reshape".to_string(),
            input: vec![normalized_name, input_shape_name],
            output: vec![restored_name.clone()],
            ..Default::default()
        });

        // Unsqueeze on every axis except the channel axis (1), so scale/bias broadcast
        // against the restored NCHW-style tensor.
        let affine_axes = std::iter::once(0)
            .chain(2..input_rank as i64)
            .collect::<Vec<_>>();
        let scale_broadcast_name = format!("group_norm_affine_scale_{}", out_id);
        let bias_broadcast_name = format!("group_norm_affine_bias_{}", out_id);
        self.add_unsqueeze_node(
            scale_input_name,
            scale_broadcast_name.clone(),
            &affine_axes,
            format!("group_norm_affine_scale_axes_{}", out_id),
        );
        self.add_unsqueeze_node(
            bias_input_name,
            bias_broadcast_name.clone(),
            &affine_axes,
            format!("group_norm_affine_bias_axes_{}", out_id),
        );

        let scaled_name = format!("group_norm_scaled_{}", out_id);
        self.add_binary_node(
            "Mul",
            restored_name,
            scale_broadcast_name,
            scaled_name.clone(),
        );
        self.add_binary_node("Add", scaled_name, bias_broadcast_name, output_name);
        Ok(())
    }
}