// p2o 0.1.1
//
// A PaddlePaddle New IR (PIR) to ONNX model converter.
// Documentation
use serde::Deserialize;
use std::env;
use std::error::Error;
use std::fs;
use std::path::PathBuf;

/// Top-level schema of `codegen/ops.yaml`.
///
/// Both sections are mandatory: neither field has a serde default, so a
/// missing key makes deserialization fail.
#[derive(Deserialize, Debug)]
struct CodegenConfig {
    /// Rows rendered into `build_pass2_op_dispatch` (written to `pass2_dispatch.rs`).
    pass2_dispatch: Vec<DispatchEntry>,
    /// Rows rendered into the `generated_op_type` match (written to `op_type_lookup.rs`).
    op_type_map: Vec<OpTypeEntry>,
}

/// One `pass2_dispatch` row: how a single Paddle op is routed to its handler.
#[derive(Deserialize, Debug)]
struct DispatchEntry {
    /// Paddle op name; becomes the dispatch-table key.
    op: String,
    /// Which dispatch form the generator emits for this row.
    kind: DispatchKind,
    /// Handler name, spliced verbatim into the generated code. Serde requires
    /// it even for `noop` rows, where the generator ignores it.
    method: String,
    /// Extra arguments for `arg` rows, spliced verbatim and comma-separated.
    /// May be omitted (treated as no arguments); unused by other kinds.
    args: Option<Vec<String>>,
}

/// Dispatch form of a `pass2_dispatch` row, spelled in lowercase in the YAML `kind` field.
#[derive(Deserialize, Debug)]
enum DispatchKind {
    /// `kind: noop`: registers `pass2_noop`; the row's `method` is ignored.
    #[serde(rename = "noop")]
    Noop,
    /// `kind: simple`: emits `simple_dispatch!(op => method)`.
    #[serde(rename = "simple")]
    Simple,
    /// `kind: typed`: emits `typed_dispatch!(op => method)`.
    #[serde(rename = "typed")]
    Typed,
    /// `kind: arg`: emits `arg_dispatch!(op => method(args...))`.
    #[serde(rename = "arg")]
    Arg,
}

/// One `op_type_map` row: each Paddle op name in `paddle` maps to the ONNX op type `onnx`.
#[derive(Deserialize, Debug)]
struct OpTypeEntry {
    /// Paddle op names; the generator joins them with `|` into one match arm.
    paddle: Vec<String>,
    /// ONNX op type returned for any of the names above.
    onnx: String,
}

/// Build-script entry point.
///
/// Compiles `proto/framework.proto` and `proto/onnx.proto` with `prost_build`,
/// then reads `codegen/ops.yaml` and writes two generated sources into `OUT_DIR`:
///
/// * `pass2_dispatch.rs`, produced by `generate_pass2_dispatch`;
/// * `op_type_lookup.rs`, produced by `generate_op_type_lookup`.
///
/// # Errors
///
/// Fails the build if protobuf compilation fails, `codegen/ops.yaml` cannot be
/// read or parsed, `OUT_DIR` is not set, or a generated file cannot be written.
/// Read, parse, env and write errors name the offending path or variable, so
/// the build log points at the culprit.
fn main() -> Result<(), Box<dyn Error>> {
    // Single source of truth, so the rerun triggers cannot drift from what is compiled.
    const PROTO_FILES: [&str; 2] = ["proto/framework.proto", "proto/onnx.proto"];
    const PROTO_INCLUDE_DIR: &str = "proto/";
    const OPS_CONFIG: &str = "codegen/ops.yaml";

    for proto in PROTO_FILES.iter() {
        println!("cargo:rerun-if-changed={}", proto);
    }
    println!("cargo:rerun-if-changed={}", OPS_CONFIG);

    prost_build::compile_protos(&PROTO_FILES, &[PROTO_INCLUDE_DIR])?;

    let yaml = fs::read_to_string(OPS_CONFIG)
        .map_err(|err| format!("failed to read {}: {}", OPS_CONFIG, err))?;
    let config: CodegenConfig = serde_yaml::from_str(&yaml)
        .map_err(|err| format!("failed to parse {}: {}", OPS_CONFIG, err))?;
    let out_dir = PathBuf::from(
        env::var("OUT_DIR").map_err(|err| format!("OUT_DIR is not available: {}", err))?,
    );

    let generated = [
        ("pass2_dispatch.rs", generate_pass2_dispatch(&config.pass2_dispatch)),
        ("op_type_lookup.rs", generate_op_type_lookup(&config.op_type_map)),
    ];
    for (file_name, contents) in generated.iter() {
        let path = out_dir.join(file_name);
        fs::write(&path, contents)
            .map_err(|err| format!("failed to write {}: {}", path.display(), err))?;
    }

    Ok(())
}

/// Renders the source of `fn build_pass2_op_dispatch()`, which builds a
/// `HashMap` from Paddle op name to `Pass2OpHandler`.
///
/// Each entry becomes one `HashMap::from` element, in input order:
///
/// * `noop`   → `("op", pass2_noop as Pass2OpHandler)`
/// * `simple` → `simple_dispatch!("op" => method)`
/// * `typed`  → `typed_dispatch!("op" => method)`
/// * `arg`    → `arg_dispatch!("op" => method(arg1, arg2, ...))` (no args → `method()`)
///
/// The op name is emitted as an escaped Rust string literal, so names that
/// contain `"` or `\` still produce valid code. `method` and `args` are
/// spliced in verbatim because they are meant to be Rust code.
///
/// NOTE(review): the generated code presumably gets pulled in with `include!`
/// and relies on `HashMap`, `Pass2OpHandler`, `pass2_noop` and the three
/// `*_dispatch!` macros being in scope there. Confirm at the include site.
fn generate_pass2_dispatch(entries: &[DispatchEntry]) -> String {
    use std::fmt::Write as _;

    let mut output = String::from(
        "fn build_pass2_op_dispatch() -> HashMap<&'static str, Pass2OpHandler> {\n    HashMap::from([\n",
    );
    for entry in entries {
        // `{:?}` on a `str` yields a quoted literal with `"`, `\` and control chars escaped.
        let op = entry.op.as_str();
        let written = match entry.kind {
            DispatchKind::Noop => writeln!(
                output,
                "        ({:?}, pass2_noop as Pass2OpHandler),",
                op
            ),
            DispatchKind::Simple => writeln!(
                output,
                "        simple_dispatch!({:?} => {}),",
                op, entry.method
            ),
            DispatchKind::Typed => writeln!(
                output,
                "        typed_dispatch!({:?} => {}),",
                op, entry.method
            ),
            DispatchKind::Arg => {
                let args = entry.args.as_deref().unwrap_or(&[]).join(", ");
                writeln!(
                    output,
                    "        arg_dispatch!({:?} => {}({})),",
                    op, entry.method, args
                )
            }
        };
        written.expect("writing to a String never fails");
    }
    output.push_str("    ])\n}\n");
    output
}

/// Renders the source of `fn generated_op_type(paddle_op: &str) -> Option<&'static str>`,
/// a `match` that maps Paddle op names to ONNX op types.
///
/// Each entry becomes one arm, `"a" | "b" => Some("Onnx")`, in input order,
/// followed by a final `_ => None`. Every name is emitted as an escaped Rust
/// string literal, so `"` or `\` in the YAML cannot break the generated code.
///
/// Entries whose `paddle` list is empty are skipped. They map no names, and
/// rendering them would produce a match arm with no pattern, which does not
/// compile.
fn generate_op_type_lookup(entries: &[OpTypeEntry]) -> String {
    use std::fmt::Write as _;

    let mut output = String::from(
        "fn generated_op_type(paddle_op: &str) -> Option<&'static str> {\n    match paddle_op {\n",
    );
    for entry in entries.iter().filter(|entry| !entry.paddle.is_empty()) {
        // `{:?}` quotes and escapes each name as a valid Rust string literal.
        let patterns = entry
            .paddle
            .iter()
            .map(|name| format!("{:?}", name))
            .collect::<Vec<_>>()
            .join(" | ");
        writeln!(output, "        {} => Some({:?}),", patterns, entry.onnx)
            .expect("writing to a String never fails");
    }
    output.push_str("        _ => None,\n    }\n}\n");
    output
}