mx-proto 0.1.1

Protobuf and gRPC bindings for MultiversX network protocols.
Documentation
use regex::Regex;
use std::{
    collections::HashMap,
    env, fs,
    path::{Path, PathBuf},
};

/// Build-script entry point: produces the protobuf (and optionally tonic gRPC)
/// bindings in `OUT_DIR`.
///
/// Two modes:
/// * **Checked-in bindings**: when `DOCS_RS` is set, or when `../../proto/raw`
///   (relative to the crate manifest) does not exist, the `.rs` files under
///   `generated/` are copied into `OUT_DIR`. If the `tonic` feature is off,
///   `*_client`/`*_server` modules in them are gated behind `feature = "tonic"`.
/// * **Fresh generation**: otherwise every `.proto` under `proto/raw` is
///   sanitized (gogoproto annotations stripped, import paths remapped through
///   `proto/raw/paths.json` if present), compiled with the vendored `protoc`,
///   and indexed in a generated `mod.rs`.
///
/// # Errors
/// Returns an error in any of these cases:
/// * a required Cargo environment variable is missing;
/// * any I/O or JSON error occurs;
/// * two `.proto` files with the same file name would produce different output
///   after flattening;
/// * code generation fails.
fn main() -> Result<(), Box<dyn std::error::Error>> {
    let manifest_dir = PathBuf::from(env::var("CARGO_MANIFEST_DIR")?);
    let tonic_enabled = env::var_os("CARGO_FEATURE_TONIC").is_some();
    let proto_root = manifest_dir.join("../../proto");
    let raw_root = proto_root.join("raw");
    let generated_root = manifest_dir.join("generated");
    let out_dir = PathBuf::from(env::var("OUT_DIR")?);
    fs::create_dir_all(&out_dir)?;

    // Fall back to the pre-generated bindings on docs.rs, or whenever the raw
    // proto tree is absent (presumably a build from the published crate, where
    // `../../proto` is not shipped).
    if env::var_os("DOCS_RS").is_some() || !raw_root.exists() {
        copy_checked_in_bindings(&generated_root, &out_dir)?;
        if !tonic_enabled {
            gate_generated_tonic_modules(&out_dir)?;
        }

        if raw_root.exists() {
            println!("cargo:warning=using checked-in protobuf bindings for docs.rs build");
        } else {
            println!(
                "cargo:warning=proto raw directory missing, using checked-in protobuf bindings"
            );
        }

        return Ok(());
    }

    // Cargo scans a directory given to `rerun-if-changed`, so newly added
    // `.proto` files also trigger a rebuild; the per-file directives alone
    // would only notice edits to files that already existed.
    println!("cargo:rerun-if-changed={}", raw_root.display());

    let import_map = load_import_map(&raw_root.join("paths.json"))?;

    let mut raw_files = Vec::new();
    collect_proto_files(&raw_root, &mut raw_files)?;
    raw_files.sort();
    raw_files.dedup();

    if raw_files.is_empty() {
        fs::write(
            out_dir.join("mod.rs"),
            "// generated protobuf modules will appear here when proto files are added.\n",
        )?;
        println!("cargo:warning=no proto files discovered, skipping prost build");
        return Ok(());
    }

    // Start from an empty staging directory so protos removed from `proto/raw`
    // do not linger from a previous build.
    let processed_root = out_dir.join("processed");
    if processed_root.exists() {
        fs::remove_dir_all(&processed_root)?;
    }
    fs::create_dir_all(&processed_root)?;

    let sanitized_files = write_sanitized_protos(&raw_files, &processed_root, &import_map)?;

    let mut config = prost_build::Config::new();
    config.protoc_executable(protoc_bin_vendored::protoc_bin_path()?);
    config.out_dir(&out_dir);
    config.bytes(["."]); // use bytes::Bytes for all bytes fields to minimize copies
    config.message_attribute(
        ".*",
        "#[cfg_attr(feature = \"serde\", derive(serde::Serialize, serde::Deserialize))]",
    );

    let proto_paths: Vec<&Path> = sanitized_files.iter().map(PathBuf::as_path).collect();
    compile_bindings(config, &proto_paths, &processed_root, tonic_enabled)?;

    if !tonic_enabled {
        gate_generated_tonic_modules(&out_dir)?;
    }

    write_mod_file(&out_dir)?;

    Ok(())
}

/// Loads the original→flattened import-path mapping from `mapping_path`
/// (`paths.json`). Returns an empty map when the file does not exist.
///
/// # Errors
/// Propagates read errors and JSON errors. The JSON must be an object whose
/// values are strings.
fn load_import_map(
    mapping_path: &Path,
) -> Result<HashMap<String, String>, Box<dyn std::error::Error>> {
    if !mapping_path.exists() {
        return Ok(HashMap::new());
    }
    println!("cargo:rerun-if-changed={}", mapping_path.display());
    let contents = fs::read_to_string(mapping_path)?;
    Ok(serde_json::from_str(&contents)?)
}

/// Runs `sanitize_proto` on each file in `raw_files` and writes the result into
/// `processed_root` under its bare file name. Returns the written paths.
///
/// The directory layout under `proto/raw` is flattened. If two sources share a
/// file name, a duplicate whose sanitized output is identical is skipped. A
/// duplicate with different output is rejected instead of silently
/// overwriting the earlier file. Files whose names are not valid UTF-8 are
/// skipped.
///
/// # Errors
/// Fails on I/O errors, or on a conflicting file-name collision.
fn write_sanitized_protos(
    raw_files: &[PathBuf],
    processed_root: &Path,
    import_map: &HashMap<String, String>,
) -> Result<Vec<PathBuf>, Box<dyn std::error::Error>> {
    let mut sources_by_name: HashMap<String, PathBuf> = HashMap::new();
    let mut sanitized_files = Vec::with_capacity(raw_files.len());

    for raw in raw_files {
        println!("cargo:rerun-if-changed={}", raw.display());

        let Some(file_name) = raw.file_name().and_then(|name| name.to_str()) else {
            continue;
        };

        let dest = processed_root.join(file_name);
        let sanitized = sanitize_proto(&fs::read_to_string(raw)?, import_map);

        if let Some(first) = sources_by_name.get(file_name) {
            if fs::read_to_string(&dest)? != sanitized {
                return Err(format!(
                    "proto file name collision: `{}` and `{}` both flatten to `{file_name}` with different contents",
                    first.display(),
                    raw.display()
                )
                .into());
            }
            // Identical duplicate: already written and queued for protoc once.
            continue;
        }

        fs::write(&dest, &sanitized)?;
        sources_by_name.insert(file_name.to_owned(), raw.clone());
        sanitized_files.push(dest);
    }

    Ok(sanitized_files)
}

/// Runs code generation for `proto_paths`, using `include_path` as the only
/// include directory.
///
/// When `tonic_enabled` is true, the prost `config` is handed to
/// `tonic_prost_build` so gRPC client and server stubs are emitted alongside
/// the message types. Otherwise plain prost generation runs.
///
/// # Errors
/// Propagates any code-generation failure.
#[cfg(feature = "tonic")]
fn compile_bindings(
    mut config: prost_build::Config,
    proto_paths: &[&Path],
    include_path: &Path,
    tonic_enabled: bool,
) -> Result<(), Box<dyn std::error::Error>> {
    let includes = [include_path];

    if !tonic_enabled {
        config.compile_protos(proto_paths, &includes)?;
        return Ok(());
    }

    tonic_prost_build::configure()
        .build_client(true)
        .build_server(true)
        .compile_with_config(config, proto_paths, &includes)?;
    Ok(())
}

/// Runs prost code generation for `proto_paths`, using `include_path` as the
/// only include directory. This variant never generates gRPC stubs.
///
/// `tonic_enabled` is the `CARGO_FEATURE_TONIC` flag computed in `main`.
/// This variant is compiled only without the `tonic` feature, so the flag is
/// expected to be `false`; that is asserted in debug builds only.
///
/// # Errors
/// Propagates any code-generation failure.
#[cfg(not(feature = "tonic"))]
fn compile_bindings(
    mut config: prost_build::Config,
    proto_paths: &[&Path],
    include_path: &Path,
    tonic_enabled: bool,
) -> Result<(), Box<dyn std::error::Error>> {
    debug_assert!(!tonic_enabled);
    config
        .compile_protos(proto_paths, &[include_path])
        .map_err(Into::into)
}

/// Copies every checked-in `*.rs` binding file from `generated_root` into
/// `out_dir`, overwriting files with the same name.
///
/// Only regular files directly inside `generated_root` are copied.
/// Subdirectories, even ones whose name ends in `.rs`, and files with other
/// extensions are ignored.
///
/// # Errors
/// Returns an error if `generated_root` is not an existing directory, or on
/// any I/O failure while listing or copying.
fn copy_checked_in_bindings(
    generated_root: &Path,
    out_dir: &Path,
) -> Result<(), Box<dyn std::error::Error>> {
    if !generated_root.is_dir() {
        return Err(format!(
            "checked-in bindings directory missing: {}",
            generated_root.display()
        )
        .into());
    }

    // Watching the directory as well lets Cargo notice binding files that are
    // added later, not just edits to the files copied below.
    println!("cargo:rerun-if-changed={}", generated_root.display());

    for entry in fs::read_dir(generated_root)? {
        let path = entry?.path();

        // `fs::copy` fails on directories, so require a regular file.
        let is_rust_file = path.is_file() && path.extension().is_some_and(|ext| ext == "rs");
        if !is_rust_file {
            continue;
        }
        let Some(file_name) = path.file_name() else {
            continue;
        };

        println!("cargo:rerun-if-changed={}", path.display());
        fs::copy(&path, out_dir.join(file_name))?;
    }

    Ok(())
}

/// Recursively walks `dir` and appends to `acc` every file whose extension is
/// `proto`, compared case-insensitively via `has_proto_extension`.
///
/// Entries are visited in `fs::read_dir` order, which is platform-dependent.
/// Callers that need a stable order must sort `acc`.
///
/// # Errors
/// Propagates any I/O error from reading a directory or one of its entries.
fn collect_proto_files(
    dir: &Path,
    acc: &mut Vec<PathBuf>,
) -> Result<(), Box<dyn std::error::Error>> {
    fs::read_dir(dir)?.try_for_each(|entry| -> Result<(), Box<dyn std::error::Error>> {
        let candidate = entry?.path();
        if candidate.is_dir() {
            return collect_proto_files(&candidate, acc);
        }
        if has_proto_extension(&candidate) {
            acc.push(candidate);
        }
        Ok(())
    })
}

/// Reports whether `path` has a `proto` extension, compared ASCII
/// case-insensitively. Paths with no extension, or a non-UTF-8 one, never
/// match.
fn has_proto_extension(path: &Path) -> bool {
    match path.extension().and_then(|ext| ext.to_str()) {
        Some(ext) => ext.eq_ignore_ascii_case("proto"),
        None => false,
    }
}

fn sanitize_proto(source: &str, import_map: &HashMap<String, String>) -> String {
    let mut buffer = String::with_capacity(source.len());
    for line in source.lines() {
        let trimmed = line.trim_start();
        if trimmed.starts_with("import")
            && trimmed.contains("github.com/gogo/protobuf/gogoproto/gogo.proto")
        {
            continue;
        }
        if trimmed.starts_with("option (gogoproto.") {
            continue;
        }
        buffer.push_str(line);
        buffer.push('\n');
    }

    let square_re = Regex::new(r"\[[^\]]*gogoproto[^\]]*\]").unwrap();
    let paren_re = Regex::new(r"\(gogoproto\.[^)]*\)").unwrap();
    let semicolon_space = Regex::new(r"\s+;").unwrap();

    let mut buffer = square_re.replace_all(&buffer, "").into_owned();
    buffer = paren_re.replace_all(&buffer, "").into_owned();
    buffer = semicolon_space.replace_all(&buffer, ";").into_owned();

    for (original, flattened) in import_map {
        let needle = format!("\"{original}\"");
        let replacement = format!("\"{flattened}\"");
        buffer = buffer.replace(&needle, &replacement);
    }

    let mut cleaned = String::with_capacity(buffer.len());
    for line in buffer.lines() {
        cleaned.push_str(line.trim_end());
        cleaned.push('\n');
    }

    cleaned
}

/// Prefixes every top-level `pub mod <name>_client {` and `pub mod <name>_server {`
/// header in the `*.rs` files of `out_dir` with `#[cfg(feature = "tonic")]`.
/// This way the tonic-generated gRPC modules are compiled out when the
/// feature is disabled.
///
/// Files that never mention `tonic::` are left untouched. A file is rewritten
/// only when at least one header matched.
///
/// # Errors
/// Propagates regex construction errors and any I/O error from listing,
/// reading, or writing files.
fn gate_generated_tonic_modules(out_dir: &Path) -> Result<(), Box<dyn std::error::Error>> {
    // `$0` re-emits the matched header right after the inserted attribute line.
    const GATED_HEADER: &str = "#[cfg(feature = \"tonic\")]\n$0";
    let grpc_module_header = Regex::new(r"(?m)^pub mod [A-Za-z0-9_]+_(?:client|server) \{$")?;

    for listing in fs::read_dir(out_dir)? {
        let file = listing?.path();
        let is_rust_source = file.extension().is_some_and(|ext| ext == "rs");
        if !is_rust_source {
            continue;
        }

        let source = fs::read_to_string(&file)?;
        if !source.contains("tonic::") {
            continue;
        }

        let gated = grpc_module_header.replace_all(&source, GATED_HEADER);
        if gated != source {
            fs::write(&file, gated.as_bytes())?;
        }
    }

    Ok(())
}

/// Writes `out_dir/mod.rs`. For each generated `*.rs` file in `out_dir`
/// (excluding `mod.rs` itself) it declares one public module and glob
/// re-exports that module's contents.
///
/// * Modules are emitted in file-name order so the output is reproducible;
///   `fs::read_dir` order is platform-dependent.
/// * Module names are derived from the file stem, with every character outside
///   `[A-Za-z0-9_]` replaced by `_`. For example, a dotted package file
///   `foo.bar.rs` becomes `pub mod foo_bar`. The `include!` still references
///   the real file name.
/// * Lint suppressions are attached to the `pub mod` item, which contains the
///   generated code. They are also kept on the `pub use` re-export.
/// * If no generated files are present, a placeholder `mod.rs` is written so a
///   stale one from an earlier build is not picked up.
///
/// # Errors
/// Propagates I/O errors from listing `out_dir` or writing `mod.rs`.
fn write_mod_file(out_dir: &Path) -> Result<(), Box<dyn std::error::Error>> {
    use std::fmt::Write as _;

    /// Lint suppressions applied to generated code (module and re-export).
    const GENERATED_LINT_ALLOWS: &str = "#[allow(clippy::all)]\n\
        #[allow(non_camel_case_types)]\n\
        #[allow(non_snake_case)]\n\
        #[allow(non_upper_case_globals)]\n\
        #[allow(dead_code)]\n\
        #[allow(unused_imports)]\n\
        #[allow(unused_variables)]\n";

    /// Maps a file stem to a Rust identifier by replacing characters outside
    /// `[A-Za-z0-9_]` with `_`.
    fn module_ident(stem: &str) -> String {
        stem.chars()
            .map(|c| if c.is_ascii_alphanumeric() || c == '_' { c } else { '_' })
            .collect()
    }

    let mut file_names = Vec::new();
    for entry in fs::read_dir(out_dir)? {
        let path = entry?.path();
        if !path.extension().is_some_and(|ext| ext == "rs") {
            continue;
        }
        if let Some(name) = path.file_name().and_then(|n| n.to_str()) {
            if name != "mod.rs" {
                file_names.push(name.to_owned());
            }
        }
    }
    file_names.sort();

    if file_names.is_empty() {
        fs::write(
            out_dir.join("mod.rs"),
            "// generated protobuf modules will appear here when proto files are added.\n",
        )?;
        return Ok(());
    }

    let mut mod_rs = String::new();
    for file_name in &file_names {
        let stem = file_name.strip_suffix(".rs").unwrap_or(file_name.as_str());
        let module_name = module_ident(stem);

        mod_rs.push_str(GENERATED_LINT_ALLOWS);
        writeln!(
            mod_rs,
            "pub mod {module_name} {{ include!(concat!(env!(\"OUT_DIR\"), \"/{file_name}\")); }}"
        )?;
        mod_rs.push_str("#[allow(ambiguous_glob_reexports)]\n");
        mod_rs.push_str(GENERATED_LINT_ALLOWS);
        writeln!(mod_rs, "pub use {module_name}::*;")?;
    }

    fs::write(out_dir.join("mod.rs"), mod_rs)?;

    Ok(())
}