p2o 0.1.1

A PaddlePaddle New IR (PIR) to ONNX model converter.
Documentation
use serde_json::Value;

use crate::helper::{self, paddle_tt};
use crate::proto::onnx;

impl super::Converter {
    pub fn onnx_elem_type_from_paddle(&self, elem_type_str: &str) -> anyhow::Result<i32> {
        helper::paddle_elem_type_to_onnx(elem_type_str)
            .ok_or_else(|| anyhow::anyhow!("Unsupported Paddle element type: {}", elem_type_str))
    }

    pub fn build_value_info_from_meta(
        &self,
        name: String,
        elem_type_str: &str,
        dims: &[i64],
    ) -> anyhow::Result<onnx::ValueInfoProto> {
        let mut vi = onnx::ValueInfoProto {
            name,
            ..Default::default()
        };

        let mut shape_proto = onnx::TensorShapeProto { dim: vec![] };
        for &dim in dims {
            let mut dim_proto = onnx::tensor_shape_proto::Dimension {
                denotation: "".to_string(),
                ..Default::default()
            };
            if dim >= 0 {
                dim_proto.value = Some(onnx::tensor_shape_proto::dimension::Value::DimValue(dim));
            }
            shape_proto.dim.push(dim_proto);
        }

        let tensor_type = onnx::type_proto::Tensor {
            elem_type: self.onnx_elem_type_from_paddle(elem_type_str)?,
            shape: Some(shape_proto),
        };

        vi.r#type = Some(onnx::TypeProto {
            denotation: "".to_string(),
            value: Some(onnx::type_proto::Value::TensorType(tensor_type)),
        });
        Ok(vi)
    }

    pub fn build_value_info_for_id(
        &self,
        id: i64,
        name: String,
    ) -> anyhow::Result<onnx::ValueInfoProto> {
        let elem_type = self
            .state
            .tensor_types
            .get(&id)
            .map(String::as_str)
            .unwrap_or(paddle_tt::F32);
        let dims = self
            .state
            .tensor_shapes
            .get(&id)
            .cloned()
            .unwrap_or_default();
        self.build_value_info_from_meta(name, elem_type, &dims)
    }

    pub fn build_value_info(
        &mut self,
        op: &Value,
        is_input: bool,
    ) -> anyhow::Result<onnx::ValueInfoProto> {
        let default_name = if is_input {
            helper::op_out_id(op)
                .map(|id| format!("input_{}", id))
                .unwrap_or_else(|_| "input".to_string())
        } else {
            helper::op_input_ids(op)
                .first()
                .map(|id| format!("output_{}", id))
                .unwrap_or_else(|| "output".to_string())
        };
        let name = helper::attr(op, "name")
            .and_then(|d| d.as_str())
            .filter(|name| !name.is_empty())
            .map(str::to_owned)
            .unwrap_or(default_name);

        let mut vi = onnx::ValueInfoProto {
            name,
            ..Default::default()
        };

        let tt_d = op
            .get("O")
            .and_then(|o| o.as_array())
            .and_then(|o| o.first())
            .and_then(|o| o.get("TT"))
            .and_then(|tt| tt.get("D"))
            .and_then(|d| d.as_array());

        if let Some(tt) = tt_d
            && tt.len() >= 2
            && let Some(elem_type_str) = tt[0].get("#").and_then(|t| t.as_str())
        {
            let dims = tt[1]
                .as_array()
                .map(|dims| {
                    dims.iter()
                        .filter_map(|dim| dim.as_i64())
                        .collect::<Vec<_>>()
                })
                .unwrap_or_default();
            vi = self.build_value_info_from_meta(vi.name.clone(), elem_type_str, &dims)?;
        }

        Ok(vi)
    }
}