// p2o 0.1.1
//
// A PaddlePaddle New IR (PIR) to ONNX model converter.
// Documentation
use crate::converter::Converter;
use crate::helper::{self, dt};
use crate::proto::onnx;
use anyhow::bail;
use serde_json::Value;

impl Converter {
    /// Lowers a `gather` op to a single ONNX `Gather` node.
    ///
    /// The first two op inputs become the node inputs, in that order. If a
    /// third input exists it is treated as the axis: its id is looked up in
    /// `self.state.constants` and the first stored value, cast to `i64`, is
    /// written to the node's `axis` attribute. With no third input the axis
    /// is `0`.
    ///
    /// # Errors
    ///
    /// Fails when the op has fewer than two inputs, when a third input is
    /// present but no constant value is recorded for it, or when the output
    /// id or any tensor name cannot be resolved.
    pub fn op_gather(&mut self, op: &Value) -> anyhow::Result<()> {
        let out_id = helper::op_out_id(op)?;
        let inputs = helper::op_input_ids(op);
        if inputs.len() < 2 {
            bail!("gather missing inputs");
        }

        // The axis is only usable when it was recorded as a constant, because
        // it has to be baked into a node attribute.
        let axis = match inputs.get(2) {
            None => 0,
            Some(&axis_id) => {
                let recorded = self
                    .state
                    .constants
                    .get(&axis_id)
                    .and_then(|vals| vals.first())
                    .copied();
                match recorded {
                    Some(value) => value as i64,
                    None => bail!("gather requires constant axis input"),
                }
            }
        };

        let data_name = self.get_tensor_name(inputs[0])?;
        let index_name = self.get_tensor_name(inputs[1])?;
        let out_name = self.get_tensor_name(out_id)?;

        self.onnx_graph.node.push(onnx::NodeProto {
            op_type: String::from("Gather"),
            input: vec![data_name, index_name],
            output: vec![out_name],
            attribute: vec![helper::attr_int("axis", axis)],
            ..Default::default()
        });
        Ok(())
    }

    /// Lowers an `index_select` op to a single ONNX `Gather` node.
    ///
    /// The first two op inputs become the node inputs, in that order. The
    /// node's `axis` attribute is taken from the op's `axis` attribute when it
    /// is present and readable as an `i64`; otherwise it is `0`.
    ///
    /// # Errors
    ///
    /// Fails when the op has fewer than two inputs, or when the output id or
    /// any tensor name cannot be resolved.
    pub fn op_index_select(&mut self, op: &Value) -> anyhow::Result<()> {
        let out_id = helper::op_out_id(op)?;
        let inputs = helper::op_input_ids(op);
        if inputs.len() < 2 {
            bail!("index_select missing inputs");
        }

        let axis = match helper::attr(op, "axis").and_then(|d| d.as_i64()) {
            Some(value) => value,
            None => 0,
        };

        let data_name = self.get_tensor_name(inputs[0])?;
        let index_name = self.get_tensor_name(inputs[1])?;
        let out_name = self.get_tensor_name(out_id)?;

        self.onnx_graph.node.push(onnx::NodeProto {
            op_type: String::from("Gather"),
            input: vec![data_name, index_name],
            output: vec![out_name],
            attribute: vec![helper::attr_int("axis", axis)],
            ..Default::default()
        });
        Ok(())
    }

    /// Lowers an `embedding` op to a single ONNX `Gather` node with `axis = 0`.
    ///
    /// The op inputs are swapped when wired into the node: the op's second
    /// input becomes the first `Gather` input and the op's first input becomes
    /// the second (presumably weight table first, ids second — confirm against
    /// the Paddle op definition).
    ///
    /// NOTE(review): no op attribute is read here, so any padding-index
    /// setting on the op has no effect on the emitted node — confirm that this
    /// is intended.
    ///
    /// # Errors
    ///
    /// Fails when the op has fewer than two inputs, or when the output id or
    /// any tensor name cannot be resolved.
    pub fn op_embedding(&mut self, op: &Value) -> anyhow::Result<()> {
        let out_id = helper::op_out_id(op)?;
        let inputs = helper::op_input_ids(op);
        if inputs.len() < 2 {
            bail!("embedding missing inputs");
        }

        // Resolution order matches the node's input order: op input 1 first,
        // then op input 0.
        let first_name = self.get_tensor_name(inputs[1])?;
        let second_name = self.get_tensor_name(inputs[0])?;
        let out_name = self.get_tensor_name(out_id)?;

        self.onnx_graph.node.push(onnx::NodeProto {
            op_type: String::from("Gather"),
            input: vec![first_name, second_name],
            output: vec![out_name],
            attribute: vec![helper::attr_int("axis", 0)],
            ..Default::default()
        });
        Ok(())
    }

    /// Lowers an `eye` op by comparing a row-index range with a column-index
    /// range and casting the boolean result to the requested dtype.
    ///
    /// Emitted graph pieces, in order:
    /// 1. two scalar INT64 initializers holding `0` and `1`;
    /// 2. two `Range` nodes fed with `(zero, rows, one)` and `(zero, cols, one)`,
    ///    where `rows` / `cols` come from the op's first and second inputs via
    ///    `ensure_scalar_i64_input`;
    /// 3. an unsqueeze of the row range on axis `1` and of the column range on
    ///    axis `0` (presumably giving `[rows, 1]` and `[1, cols]` — confirm
    ///    against `add_unsqueeze_node`);
    /// 4. an `Equal` node over the two unsqueezed tensors;
    /// 5. a cast of the `Equal` result into the op's output tensor, using the
    ///    op's `dtype` attribute mapped through `helper::paddle_dtype_to_onnx`,
    ///    or `dt::FLOAT` when the attribute is absent or unmapped.
    ///
    /// # Errors
    ///
    /// Fails when `require_opset(11, "eye")` fails, when the op has fewer than
    /// two inputs, when either size input cannot be obtained as a scalar, or
    /// when the output id or output tensor name cannot be resolved.
    pub fn op_eye(&mut self, op: &Value) -> anyhow::Result<()> {
        self.require_opset(11, "eye")?;
        let out_id = helper::op_out_id(op)?;
        let inputs = helper::op_input_ids(op);
        if inputs.len() < 2 {
            bail!("eye missing inputs");
        }

        let rows_name = self.ensure_scalar_i64_input(inputs[0], out_id, "eye_rows")?;
        let cols_name = self.ensure_scalar_i64_input(inputs[1], out_id, "eye_cols")?;

        // Shared start (0) and step (1) operands for both Range nodes, stored
        // as rank-0 INT64 initializers with little-endian raw bytes.
        let zero_name = format!("eye_zero_{}", out_id);
        let one_name = format!("eye_one_{}", out_id);
        let scalar_i64 = |name: String, value: i64| onnx::TensorProto {
            name,
            dims: Vec::new(),
            data_type: dt::INT64,
            raw_data: value.to_le_bytes().to_vec(),
            ..Default::default()
        };
        self.onnx_graph.initializer.push(scalar_i64(zero_name.clone(), 0));
        self.onnx_graph.initializer.push(scalar_i64(one_name.clone(), 1));

        // Both index ranges use the same start/step and differ only in limit.
        let range_node = |limit: String, output: String| onnx::NodeProto {
            op_type: String::from("Range"),
            input: vec![zero_name.clone(), limit, one_name.clone()],
            output: vec![output],
            ..Default::default()
        };
        let rows_range = format!("eye_rows_range_{}", out_id);
        let cols_range = format!("eye_cols_range_{}", out_id);
        self.onnx_graph
            .node
            .push(range_node(rows_name, rows_range.clone()));
        self.onnx_graph
            .node
            .push(range_node(cols_name, cols_range.clone()));

        // Row indices are unsqueezed on axis 1 and column indices on axis 0 so
        // the Equal below compares them pairwise.
        let rows_unsqueezed = format!("eye_rows_unsqueezed_{}", out_id);
        let cols_unsqueezed = format!("eye_cols_unsqueezed_{}", out_id);
        self.add_unsqueeze_node(
            rows_range,
            rows_unsqueezed.clone(),
            &[1],
            format!("eye_rows_axes_{}", out_id),
        );
        self.add_unsqueeze_node(
            cols_range,
            cols_unsqueezed.clone(),
            &[0],
            format!("eye_cols_axes_{}", out_id),
        );

        let equal_name = format!("eye_equal_{}", out_id);
        self.add_binary_node("Equal", rows_unsqueezed, cols_unsqueezed, equal_name.clone());

        // Fall back to FLOAT when the op carries no usable dtype attribute.
        let target_dtype = match helper::attr(op, "dtype").and_then(|d| d.as_str()) {
            Some(name) => helper::paddle_dtype_to_onnx(name).unwrap_or(dt::FLOAT),
            None => dt::FLOAT,
        };
        let out_name = self.get_tensor_name(out_id)?;
        self.add_cast_node(equal_name, out_name, target_dtype);
        Ok(())
    }

    /// Lowers a `take_along_axis` op to a single ONNX `GatherElements` node.
    ///
    /// The first two op inputs become the node inputs, in that order. The
    /// node's `axis` attribute is taken from the op's `axis` attribute when it
    /// is present and readable as an `i64`; otherwise it is `0`.
    ///
    /// NOTE(review): only the `axis` attribute is read; any broadcast-related
    /// attribute on the op is not consulted — confirm that the indices already
    /// have the shape `GatherElements` expects.
    ///
    /// # Errors
    ///
    /// Fails when `require_opset(11, "take_along_axis")` fails, when the op
    /// has fewer than two inputs, or when the output id or any tensor name
    /// cannot be resolved.
    pub fn op_take_along_axis(&mut self, op: &Value) -> anyhow::Result<()> {
        // `GatherElements` first appears in ONNX opset 11, so gate on the
        // target opset up front, the same way `op_eye` does, instead of
        // emitting a node the target opset does not define.
        self.require_opset(11, "take_along_axis")?;
        let out_id = helper::op_out_id(op)?;
        let inputs = helper::op_input_ids(op);
        if inputs.len() < 2 {
            bail!("take_along_axis missing inputs");
        }
        let mut node = onnx::NodeProto {
            op_type: "GatherElements".to_string(),
            input: vec![
                self.get_tensor_name(inputs[0])?,
                self.get_tensor_name(inputs[1])?,
            ],
            output: vec![self.get_tensor_name(out_id)?],
            ..Default::default()
        };
        // A missing or non-integer `axis` attribute falls back to 0.
        let axis = helper::attr(op, "axis")
            .and_then(|d| d.as_i64())
            .unwrap_or(0);
        node.attribute.push(helper::attr_int("axis", axis));
        self.onnx_graph.node.push(node);
        Ok(())
    }

    /// Lowers a `not_equal` op as an `Equal` node followed by a `Not` node.
    ///
    /// The `Equal` node compares the op's first two inputs and writes to an
    /// intermediate tensor named `not_equal_eq_<out_id>`; the `Not` node reads
    /// that tensor and writes to the op's output tensor.
    ///
    /// # Errors
    ///
    /// Fails when the op has fewer than two inputs, or when the output id or
    /// any tensor name cannot be resolved. The output name is resolved after
    /// the `Equal` node has been added.
    pub fn op_not_equal(&mut self, op: &Value) -> anyhow::Result<()> {
        let out_id = helper::op_out_id(op)?;
        let inputs = helper::op_input_ids(op);
        if inputs.len() < 2 {
            bail!("not_equal missing inputs");
        }

        let eq_tmp = format!("not_equal_eq_{}", out_id);
        let lhs_name = self.get_tensor_name(inputs[0])?;
        let rhs_name = self.get_tensor_name(inputs[1])?;
        self.add_binary_node("Equal", lhs_name, rhs_name, eq_tmp.clone());

        // Negate the equality result into the op's real output.
        let out_name = self.get_tensor_name(out_id)?;
        let negate = onnx::NodeProto {
            op_type: String::from("Not"),
            input: vec![eq_tmp],
            output: vec![out_name],
            ..Default::default()
        };
        self.onnx_graph.node.push(negate);
        Ok(())
    }
}