//! p2o 0.1.1
//!
//! A PaddlePaddle New IR (PIR) to ONNX model converter.
use crate::converter::Converter;
use crate::helper::{self, dt};
use crate::proto::onnx;
use anyhow::bail;
use serde_json::Value;

impl Converter {
    /// Lowers a Paddle `flip` op to ONNX by reversing each requested axis in turn.
    ///
    /// One ONNX node is emitted per (deduplicated) axis. Intermediate results are
    /// named `flip_axis_{out_id}_{i}` and the last node writes the op's output tensor.
    ///
    /// * Axes whose length is statically known and positive are reversed with
    ///   `Gather` over a constant `[len-1, ..., 0]` index initializer (no opset
    ///   requirement beyond `Gather` itself).
    /// * Axes whose length is unknown (or zero) are reversed with
    ///   `Slice(starts=[-1], ends=[INT64_MIN], axes=[axis], steps=[-1])`, which
    ///   needs opset >= 10.
    ///
    /// Negative axes are normalized against the input rank. Repeated axes are
    /// reversed only once, since reversing twice would cancel out.
    ///
    /// # Errors
    ///
    /// Fails when the op has no input, when an axis lies outside `[-rank, rank)`,
    /// when a negative axis is given but the input rank is unknown, or when an
    /// axis needs the `Slice` lowering and `require_opset(10, ..)` rejects it.
    pub fn op_flip(&mut self, op: &Value) -> anyhow::Result<()> {
        // Builds a 1-D INT64 initializer whose payload is little-endian `raw_data`.
        fn i64_initializer(name: &str, values: &[i64]) -> onnx::TensorProto {
            let mut tensor = onnx::TensorProto {
                name: name.to_string(),
                dims: vec![values.len() as i64],
                data_type: dt::INT64,
                ..Default::default()
            };
            for value in values {
                tensor.raw_data.extend_from_slice(&value.to_le_bytes());
            }
            tensor
        }

        let out_id = helper::op_out_id(op)?;
        // Entries of `axis` are accepted either wrapped as {"D": n} or as bare integers;
        // entries matching neither form are skipped.
        let axes = helper::attr(op, "axis")
            .and_then(|d| d.as_array())
            .map(|arr| {
                arr.iter()
                    .filter_map(|value| {
                        value
                            .get("D")
                            .and_then(|v| v.as_i64())
                            .or_else(|| value.as_i64())
                    })
                    .collect::<Vec<_>>()
            })
            .unwrap_or_default();
        let input_id = *helper::op_input_ids(op)
            .first()
            .ok_or_else(|| anyhow::anyhow!("flip missing input"))?;
        // Prefer the tracked input shape. Otherwise fall back to the output type recorded
        // on the op; flip does not change shape. `None` means the rank is unknown.
        let input_shape: Option<Vec<i64>> = self
            .state
            .tensor_shapes
            .get(&input_id)
            .cloned()
            .or_else(|| {
                op.get("O")
                    .and_then(|o| o.as_array())
                    .and_then(|o| o.first())
                    .and_then(|o| o.get("TT"))
                    .and_then(|tt| tt.get("D"))
                    .and_then(|d| d.as_array())
                    .and_then(|tt| tt.get(1))
                    .and_then(|shape| shape.as_array())
                    .map(|dims| {
                        dims.iter()
                            .map(|dim| dim.as_i64().unwrap_or(-1))
                            .collect::<Vec<_>>()
                    })
            });
        let rank = input_shape.as_ref().map(|shape| shape.len() as i64);

        // NOTE(review): an empty (or unparsable) `axis` attribute falls back to axis 0;
        // confirm this matches Paddle's semantics for an empty axis list.
        let requested_axes = if axes.is_empty() { vec![0] } else { axes };
        let mut normalized_axes: Vec<i64> = Vec::with_capacity(requested_axes.len());
        for axis in requested_axes {
            let normalized = match rank {
                Some(rank) => {
                    let normalized = if axis < 0 { axis + rank } else { axis };
                    if !(0..rank).contains(&normalized) {
                        bail!("flip axis {} is out of range for input of rank {}", axis, rank);
                    }
                    normalized
                }
                None if axis >= 0 => axis,
                None => bail!(
                    "flip cannot resolve negative axis {} without a known input rank",
                    axis
                ),
            };
            // Reversing the same axis twice would cancel out, so reverse it once.
            if !normalized_axes.contains(&normalized) {
                normalized_axes.push(normalized);
            }
        }

        // Static, positive length of `axis`, if known. Every axis here is non-negative,
        // so the `usize` cast cannot wrap.
        let static_len = |axis: i64| {
            input_shape
                .as_ref()
                .and_then(|shape| shape.get(axis as usize))
                .copied()
                .filter(|&len| len > 0)
        };
        // Check the opset up front so a failure leaves the graph untouched.
        if normalized_axes.iter().any(|&axis| static_len(axis).is_none()) {
            self.require_opset(10, "flip")?;
        }

        let axis_count = normalized_axes.len();
        let mut current_input = self.get_tensor_name(input_id)?;
        for (axis_index, axis) in normalized_axes.into_iter().enumerate() {
            let mut node = match static_len(axis) {
                Some(axis_len) => {
                    let indices_name = format!("flip_indices_{}_{}", out_id, axis_index);
                    let indices: Vec<i64> = (0..axis_len).rev().collect();
                    self.onnx_graph
                        .initializer
                        .push(i64_initializer(&indices_name, &indices));
                    let mut gather = onnx::NodeProto {
                        op_type: "Gather".to_string(),
                        input: vec![current_input, indices_name],
                        ..Default::default()
                    };
                    gather.attribute.push(helper::attr_int("axis", axis));
                    gather
                }
                None => {
                    let starts = format!("flip_starts_{}_{}", out_id, axis_index);
                    let ends = format!("flip_ends_{}_{}", out_id, axis_index);
                    let slice_axes = format!("flip_axes_{}_{}", out_id, axis_index);
                    let steps = format!("flip_steps_{}_{}", out_id, axis_index);
                    // With step -1, start -1 is the last element. Per the ONNX Slice
                    // spec, end INT64_MIN runs through index 0 for any runtime length.
                    for (name, value) in [
                        (&starts, -1),
                        (&ends, i64::MIN),
                        (&slice_axes, axis),
                        (&steps, -1),
                    ] {
                        self.onnx_graph
                            .initializer
                            .push(i64_initializer(name, &[value]));
                    }
                    onnx::NodeProto {
                        op_type: "Slice".to_string(),
                        input: vec![current_input, starts, ends, slice_axes, steps],
                        ..Default::default()
                    }
                }
            };

            let output_name = if axis_index + 1 == axis_count {
                self.get_tensor_name(out_id)?
            } else {
                format!("flip_axis_{}_{}", out_id, axis_index)
            };
            node.output.push(output_name.clone());
            self.onnx_graph.node.push(node);
            current_input = output_name;
        }
        Ok(())
    }

    /// Lowers a Paddle `cumsum` op to a single ONNX `CumSum` node (requires opset >= 11).
    ///
    /// The first input is the data tensor. The second input is forwarded unchanged as
    /// `CumSum`'s axis input. The boolean `exclusive` and `reverse` attributes (default
    /// `false`) are always emitted as 0/1 integer attributes.
    ///
    /// # Errors
    ///
    /// Fails when the opset is too low, when `flatten=true` (unsupported), or when
    /// fewer than two inputs are present.
    pub fn op_cumsum(&mut self, op: &Value) -> anyhow::Result<()> {
        self.require_opset(11, "cumsum")?;
        if helper::attr(op, "flatten")
            .and_then(|d| d.as_bool())
            .unwrap_or(false)
        {
            bail!("cumsum with flatten=true is not supported");
        }

        let out_id = helper::op_out_id(op)?;
        let inputs = helper::op_input_ids(op);
        if inputs.len() < 2 {
            bail!("cumsum missing inputs");
        }
        let mut node = onnx::NodeProto {
            op_type: "CumSum".to_string(),
            input: vec![
                self.get_tensor_name(inputs[0])?,
                self.get_tensor_name(inputs[1])?,
            ],
            output: vec![self.get_tensor_name(out_id)?],
            ..Default::default()
        };
        for (name, enabled) in [
            (
                "exclusive",
                helper::attr(op, "exclusive")
                    .and_then(|d| d.as_bool())
                    .unwrap_or(false),
            ),
            (
                "reverse",
                helper::attr(op, "reverse")
                    .and_then(|d| d.as_bool())
                    .unwrap_or(false),
            ),
        ] {
            node.attribute
                .push(helper::attr_int(name, i64::from(enabled)));
        }
        self.onnx_graph.node.push(node);
        Ok(())
    }

    /// Lowers a Paddle `einsum` op to a single ONNX `Einsum` node (requires opset >= 12).
    ///
    /// If the first input has an entry in `state.combines`, that entry's tensor ids
    /// become the node's operands in order. Otherwise the first input is the sole
    /// operand. The `equation` string attribute is copied through unchanged.
    ///
    /// # Errors
    ///
    /// Fails when the opset is too low, when the op has no inputs, or when
    /// `equation` is missing or not a string.
    pub fn op_einsum(&mut self, op: &Value) -> anyhow::Result<()> {
        self.require_opset(12, "einsum")?;
        let out_id = helper::op_out_id(op)?;
        let inputs = helper::op_input_ids(op);
        if inputs.is_empty() {
            bail!("einsum missing inputs");
        }

        let mut node_inputs = Vec::new();
        if let Some(expanded) = self.state.combines.get(&inputs[0]) {
            for &input_id in expanded {
                node_inputs.push(self.get_tensor_name(input_id)?);
            }
        } else {
            node_inputs.push(self.get_tensor_name(inputs[0])?);
        }

        let equation = helper::attr(op, "equation")
            .and_then(|d| d.as_str())
            .ok_or_else(|| anyhow::anyhow!("einsum missing equation"))?;

        let mut node = onnx::NodeProto {
            op_type: "Einsum".to_string(),
            input: node_inputs,
            output: vec![self.get_tensor_name(out_id)?],
            ..Default::default()
        };
        node.attribute.push(helper::attr_str("equation", equation));
        self.onnx_graph.node.push(node);
        Ok(())
    }

    /// Lowers a Paddle `meshgrid` op with Shape/Concat/Unsqueeze/Expand nodes.
    ///
    /// The k inputs come from the `0.combine` metadata of the first input. The k
    /// outputs come from the `0.split` metadata of the op output. The grid shape is
    /// the concatenation of every input's `Shape`. For output i, input i is unsqueezed
    /// on every axis except i and then `Expand`ed to the grid shape, so input i varies
    /// along grid axis i.
    ///
    /// NOTE(review): this assumes each input is 1-D (one `Shape` element per input);
    /// 0-D inputs would yield a shorter grid shape.
    ///
    /// # Errors
    ///
    /// Fails when inputs or the combine/split metadata are missing, or when the input
    /// and output counts differ or are zero.
    pub fn op_meshgrid(&mut self, op: &Value) -> anyhow::Result<()> {
        let out_id = helper::op_out_id(op)?;
        let inputs = helper::op_input_ids(op);
        if inputs.is_empty() {
            bail!("meshgrid missing inputs");
        }
        let input_ids = self
            .state
            .combines
            .get(&inputs[0])
            .cloned()
            .ok_or_else(|| {
                anyhow::anyhow!("meshgrid missing 0.combine metadata for {}", inputs[0])
            })?;
        let output_ids =
            self.state.splits.get(&out_id).cloned().ok_or_else(|| {
                anyhow::anyhow!("meshgrid missing 0.split metadata for {}", out_id)
            })?;
        if input_ids.len() != output_ids.len() {
            bail!(
                "meshgrid expects matching input/output counts, got {} inputs and {} outputs",
                input_ids.len(),
                output_ids.len()
            );
        }
        if input_ids.is_empty() {
            bail!("meshgrid requires at least one input");
        }

        // Grid shape = concat(Shape(input_0), ..., Shape(input_{k-1})).
        let grid_shape = format!("meshgrid_shape_{}", out_id);
        let mut shape_inputs = Vec::with_capacity(input_ids.len());
        for (index, input_id) in input_ids.iter().enumerate() {
            let shape_name = format!("meshgrid_shape_part_{}_{}", out_id, index);
            self.onnx_graph.node.push(onnx::NodeProto {
                op_type: "Shape".to_string(),
                input: vec![self.get_tensor_name(*input_id)?],
                output: vec![shape_name.clone()],
                ..Default::default()
            });
            shape_inputs.push(shape_name);
        }
        let mut concat = onnx::NodeProto {
            op_type: "Concat".to_string(),
            input: shape_inputs,
            output: vec![grid_shape.clone()],
            ..Default::default()
        };
        concat.attribute.push(helper::attr_int("axis", 0));
        self.onnx_graph.node.push(concat);

        let rank = input_ids.len() as i64;
        for (index, (&input_id, &output_id)) in input_ids.iter().zip(output_ids.iter()).enumerate()
        {
            // Insert size-1 dims on every axis except `index`, so input `index` lines up
            // with grid axis `index` before broadcasting.
            let unsqueezed = format!("meshgrid_unsqueezed_{}_{}", out_id, index);
            let axes = (0..rank)
                .filter(|&axis| axis != index as i64)
                .collect::<Vec<_>>();
            self.add_unsqueeze_node(
                self.get_tensor_name(input_id)?,
                unsqueezed.clone(),
                &axes,
                format!("meshgrid_axes_{}_{}", out_id, index),
            );
            self.onnx_graph.node.push(onnx::NodeProto {
                op_type: "Expand".to_string(),
                input: vec![unsqueezed, grid_shape.clone()],
                output: vec![self.get_tensor_name(output_id)?],
                ..Default::default()
            });
        }
        Ok(())
    }
}