//! export-aptos-verifier 0.1.1
//!
//! CLI for exporting Groth16 artifacts to Aptos Move verifier packages.
use clap::{Parser, Subcommand, ValueEnum};
use regex::Regex;
use std::fmt;
use std::path::PathBuf;
use std::process::Command as ProcessCommand;

use export_aptos_verifier_core::curves::{create_adapter, PointFormat};
use export_aptos_verifier_core::error::{Error, Result};
use export_aptos_verifier_core::formats::{
    load_compact_bundle, load_snarkjs_json_inputs_with_curve_hint,
};
use export_aptos_verifier_core::local_verify;
use export_aptos_verifier_core::movegen::{
    generate_move_package, GenerateMovePackageOptions, MovegenMode,
};
use export_aptos_verifier_core::CurveKind;

// Top-level command-line interface of the `export-aptos-verifier` binary.
//
// Plain `//` comments are used on purpose: clap's derive macros turn `///`
// doc comments into help text, which would change the CLI's `--help` output.
#[derive(Parser)]
#[command(
    name = "export-aptos-verifier",
    version,
    about = "Export Groth16 artifacts to an Aptos Move verifier package"
)]
struct Cli {
    // The subcommand selected on the command line.
    #[command(subcommand)]
    command: CliCommand,
}

// Subcommands understood by the CLI; `main` dispatches on this enum.
// (Plain `//` comments: clap would turn `///` docs into help text.)
#[derive(Subcommand)]
enum CliCommand {
    // Executed by `run_generate`.
    Generate(GenerateArgs),
}

// Arguments of the `generate` subcommand, consumed by `run_generate`.
//
// Comments are plain `//` rather than `///` because clap's derive turns doc
// comments into help text, which would alter the CLI's `--help` output.
#[derive(clap::Args)]
struct GenerateArgs {
    // Verification key path; required when `--bundle` is not given.
    #[arg(long)]
    vk: Option<PathBuf>,
    // Proof path; required when `--bundle` is not given.
    #[arg(long)]
    proof: Option<PathBuf>,
    // Optional path forwarded to the snarkjs JSON loader together with
    // `--vk`/`--proof` (presumably the public inputs file — confirm against
    // the loader).
    #[arg(long)]
    public: Option<PathBuf>,
    // Output location: passed to the Move package generator and used as
    // `--package-dir` for the optional `aptos move test` run.
    #[arg(long)]
    out: PathBuf,
    // Must match `[A-Za-z_][A-Za-z0-9_]*` (checked by `validate_names`).
    #[arg(long)]
    package_name: String,
    // Must match `[A-Za-z_][A-Za-z0-9_]*` (checked by `validate_names`).
    #[arg(long)]
    module_name: String,
    // Must match `0x[0-9a-fA-F]{1,64}` (checked by `validate_account_address`).
    #[arg(long)]
    account_address: String,

    // `auto` sends no curve hint to the loaders and adopts the curve reported
    // by the loaded inputs; an explicit value must agree with the inputs or
    // generation fails with a curve mismatch error.
    #[arg(long, default_value_t = CurveArg::Auto)]
    curve: CurveArg,
    // `auto` selects the compact bundle loader when `--bundle` is set and the
    // snarkjs JSON loader otherwise.
    #[arg(long, default_value_t = InputFormatArg::Auto)]
    input_format: InputFormatArg,
    // Point format checked when the curve is BN254; it must equal the curve
    // adapter's default point format.
    #[arg(long, default_value_t = PointFormatArg::Uncompressed)]
    bn254_format: PointFormatArg,
    // Point format checked when the curve is BLS12-381; it must equal the
    // curve adapter's default point format.
    #[arg(long, default_value_t = PointFormatArg::Compressed)]
    bls_format: PointFormatArg,
    // Converted to a `MovegenMode` and handed to the Move package generator.
    #[arg(long, default_value_t = ModeArg::Entry)]
    mode: ModeArg,
    // When set, `aptos move test` is run on the generated package afterwards.
    #[arg(long, default_value_t = false)]
    run_aptos_test: bool,
    // Forwarded unchanged as the generator's `force` option
    // (presumably allows overwriting existing output — TODO confirm).
    #[arg(long, default_value_t = false)]
    force: bool,
    // When set, the local proof verification step is skipped.
    #[arg(long, default_value_t = false)]
    skip_local_verify: bool,
    // Currently always rejected with `Error::PreparedNotImplemented`.
    #[arg(long, default_value_t = false)]
    prepared: bool,
    // Compact bundle path; when present the inputs come from the bundle
    // loader and `--vk`/`--proof` are not consulted.
    #[arg(long)]
    bundle: Option<PathBuf>,
}

// Values accepted by `--curve`.
// (Plain `//` comments: clap's `ValueEnum` derive would turn `///` docs on
// variants into help text for the possible values.)
#[derive(Copy, Clone, Debug, Eq, PartialEq, ValueEnum)]
enum CurveArg {
    // No curve hint is sent to the loaders; the curve reported by the loaded
    // inputs is used.
    Auto,
    // Displayed as "bn254"; must match `CurveKind::Bn254` in the inputs.
    Bn254,
    // Displayed as "bls12381"; must match `CurveKind::Bls12_381` in the inputs.
    Bls12381,
}

// Values accepted by `--input-format`.
// (Plain `//` comments: clap's `ValueEnum` derive would turn `///` docs on
// variants into help text for the possible values.)
#[derive(Copy, Clone, Debug, Eq, PartialEq, ValueEnum)]
enum InputFormatArg {
    // Compact bundle loader when `--bundle` is given, snarkjs JSON loader
    // otherwise.
    Auto,
    // Requires `--vk` and `--proof`; rejected when `--bundle` is given.
    SnarkjsJson,
    // Requires `--bundle`.
    ArkworksCompact,
}

// Values accepted by `--bn254-format` and `--bls-format`; converted to the
// core crate's `PointFormat` by `map_point_format`.
// (Plain `//` comments: clap's `ValueEnum` derive would turn `///` docs on
// variants into help text for the possible values.)
#[derive(Copy, Clone, Debug, Eq, PartialEq, ValueEnum)]
enum PointFormatArg {
    Compressed,
    Uncompressed,
}

// Values accepted by `--mode`; each variant maps to the `MovegenMode` variant
// of the same name (see `ModeArg::into_move_mode`).
// (Plain `//` comments: clap's `ValueEnum` derive would turn `///` docs on
// variants into help text for the possible values.)
#[derive(Copy, Clone, Debug, Eq, PartialEq, ValueEnum)]
enum ModeArg {
    Library,
    Entry,
    Test,
}

impl ModeArg {
    fn into_move_mode(self) -> MovegenMode {
        match self {
            Self::Library => MovegenMode::Library,
            Self::Entry => MovegenMode::Entry,
            Self::Test => MovegenMode::Test,
        }
    }
}

fn main() {
    let cli = Cli::parse();
    let result = match cli.command {
        CliCommand::Generate(args) => run_generate(args),
    };
    if let Err(error) = result {
        eprintln!("{error}");
        std::process::exit(1);
    }
}

/// Runs the `generate` subcommand from start to finish.
///
/// Steps, in order:
/// 1. validate the package name, module name and account address;
/// 2. load the inputs from a compact bundle or from snarkjs JSON files;
/// 3. check an explicitly requested curve against the curve of the inputs;
/// 4. reject `--prepared`;
/// 5. check the selected point format against the curve adapter;
/// 6. verify the proof locally unless `--skip-local-verify` is set;
/// 7. generate the Move package;
/// 8. run `aptos move test` on it when `--run-aptos-test` is set.
///
/// # Errors
///
/// Returns the first error raised by any step, including
/// `Error::MissingInput`, `Error::CurveMismatch`,
/// `Error::PreparedNotImplemented` and `Error::LocalProofVerificationFailed`.
fn run_generate(args: GenerateArgs) -> Result<()> {
    validate_names(&args.package_name, "package_name")?;
    validate_names(&args.module_name, "module_name")?;
    validate_account_address(&args.account_address)?;

    // `auto` means "no hint": the loaders receive `None` in that case.
    let curve_hint = match args.curve {
        CurveArg::Auto => None,
        explicit => Some(explicit.to_string()),
    };
    let hint = curve_hint.as_deref();

    let inputs = match (args.bundle.as_ref(), args.input_format) {
        // Contradictory or incomplete combinations are rejected up front.
        (Some(_), InputFormatArg::SnarkjsJson) => {
            return Err(Error::MissingInput(
                "snarkjs-json mode requires --vk and --proof inputs".to_string(),
            ));
        }
        (None, InputFormatArg::ArkworksCompact) => {
            return Err(Error::MissingInput(
                "arkworks-compact mode requires --bundle".to_string(),
            ));
        }
        (Some(bundle_path), InputFormatArg::Auto | InputFormatArg::ArkworksCompact) => {
            load_compact_bundle(bundle_path, hint)?
        }
        (None, InputFormatArg::Auto | InputFormatArg::SnarkjsJson) => {
            let vk_path = match args.vk.as_ref() {
                Some(path) => path,
                None => {
                    return Err(Error::MissingInput(
                        "--vk is required unless --bundle is used".to_string(),
                    ));
                }
            };
            let proof_path = match args.proof.as_ref() {
                Some(path) => path,
                None => {
                    return Err(Error::MissingInput(
                        "--proof is required unless --bundle is used".to_string(),
                    ));
                }
            };
            load_snarkjs_json_inputs_with_curve_hint(
                vk_path,
                proof_path,
                args.public.as_deref(),
                hint,
            )?
        }
    };

    // An explicit `--curve` must agree with what the inputs report.
    let requested_curve = match args.curve {
        CurveArg::Auto => inputs.curve.canonical_name().to_string(),
        CurveArg::Bn254 if inputs.curve != CurveKind::Bn254 => {
            return Err(Error::CurveMismatch(
                "requested curve bn254 does not match input curve metadata".to_string(),
            ));
        }
        CurveArg::Bn254 => "bn254".to_string(),
        CurveArg::Bls12381 if inputs.curve != CurveKind::Bls12_381 => {
            return Err(Error::CurveMismatch(
                "requested curve bls12381 does not match input curve metadata".to_string(),
            ));
        }
        CurveArg::Bls12381 => "bls12381".to_string(),
    };

    if args.prepared {
        return Err(Error::PreparedNotImplemented);
    }

    let adapter = create_adapter(&requested_curve)?;
    validate_point_format(adapter.as_ref(), &requested_curve, &args)?;

    // Short-circuits: verification is not attempted when the skip flag is set.
    if !args.skip_local_verify && !local_verify(adapter.as_ref(), &inputs)? {
        return Err(Error::LocalProofVerificationFailed(
            "local arkworks verification returned false".to_string(),
        ));
    }

    let package_options = GenerateMovePackageOptions {
        package_name: &args.package_name,
        module_name: &args.module_name,
        account_address: &args.account_address,
        mode: args.mode.into_move_mode(),
        force: args.force,
    };
    generate_move_package(&args.out, adapter.as_ref(), &inputs, &package_options)?;

    if args.run_aptos_test {
        run_aptos_test(&args.out)?;
    }

    Ok(())
}

/// Checks that the point format selected on the command line equals the
/// adapter's default point format for the requested curve.
///
/// `requested_curve` is canonicalized with `canonicalize_curve` first. For
/// `bn254` the `--bn254-format` value is checked, for `bls12381` the
/// `--bls-format` value; any other curve name passes without a check.
///
/// # Errors
///
/// Returns `Error::UnsupportedCurve` naming the expected format when the
/// selected format differs from the adapter's default.
fn validate_point_format(
    adapter: &dyn export_aptos_verifier_core::curves::CurveAdapter,
    requested_curve: &str,
    args: &GenerateArgs,
) -> Result<()> {
    match canonicalize_curve(requested_curve).as_str() {
        "bn254" => {
            let expected = adapter.default_point_format();
            if expected != map_point_format(&args.bn254_format) {
                return Err(Error::UnsupportedCurve(format!(
                    "unsupported BN254 format, expected {:?}",
                    expected
                )));
            }
        }
        "bls12381" => {
            let expected = adapter.default_point_format();
            if expected != map_point_format(&args.bls_format) {
                return Err(Error::UnsupportedCurve(format!(
                    "unsupported BLS12-381 format, expected {:?}",
                    expected
                )));
            }
        }
        // Unknown curve names are not validated here.
        _ => {}
    }
    Ok(())
}

impl fmt::Display for CurveArg {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let s = match self {
            Self::Auto => "auto",
            Self::Bn254 => "bn254",
            Self::Bls12381 => "bls12381",
        };
        write!(f, "{s}")
    }
}

impl fmt::Display for InputFormatArg {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let s = match self {
            Self::Auto => "auto",
            Self::SnarkjsJson => "snarkjs-json",
            Self::ArkworksCompact => "arkworks-compact",
        };
        write!(f, "{s}")
    }
}

impl fmt::Display for PointFormatArg {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let s = match self {
            Self::Compressed => "compressed",
            Self::Uncompressed => "uncompressed",
        };
        write!(f, "{s}")
    }
}

impl fmt::Display for ModeArg {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let s = match self {
            Self::Library => "library",
            Self::Entry => "entry",
            Self::Test => "test",
        };
        write!(f, "{s}")
    }
}

/// Normalizes a curve name for comparison: lowercases it and drops every
/// `-` and `_`, so e.g. `BLS12-381` and `bls12_381` both become `bls12381`.
fn canonicalize_curve(name: &str) -> String {
    name.to_lowercase()
        .chars()
        .filter(|&ch| ch != '-' && ch != '_')
        .collect()
}

/// Translates the CLI point-format flag into the core crate's [`PointFormat`]
/// variant of the same name.
fn map_point_format(value: &PointFormatArg) -> PointFormat {
    match *value {
        PointFormatArg::Uncompressed => PointFormat::Uncompressed,
        PointFormatArg::Compressed => PointFormat::Compressed,
    }
}

fn validate_names(value: &str, field: &str) -> Result<()> {
    let re = Regex::new(r"^[A-Za-z_][A-Za-z0-9_]*$").unwrap();
    if !re.is_match(value) {
        if field == "module_name" {
            return Err(Error::InvalidModuleName(format!(
                "{field} must match [A-Za-z_][A-Za-z0-9_]*"
            )));
        }
        return Err(Error::InvalidPackageName(format!(
            "{field} must match [A-Za-z_][A-Za-z0-9_]*"
        )));
    }
    Ok(())
}

fn validate_account_address(value: &str) -> Result<()> {
    let re = Regex::new(r"^0[xX][0-9a-fA-F]{1,64}$").unwrap();
    if !re.is_match(value) {
        return Err(Error::InvalidAccountAddress(
            "account_address must match 0x[0-9a-fA-F]{1,64}".to_string(),
        ));
    }
    Ok(())
}

fn run_aptos_test(out_dir: &std::path::Path) -> Result<()> {
    let aptos = ProcessCommand::new("aptos")
        .arg("move")
        .arg("test")
        .arg("--package-dir")
        .arg(out_dir)
        .output();

    match aptos {
        Ok(out) => {
            if !out.status.success() {
                let stdout = String::from_utf8_lossy(&out.stdout);
                let stderr = String::from_utf8_lossy(&out.stderr);
                return Err(Error::AptosTestFailed(format!(
                    "ERR_APTOS_TEST_FAILED: {}\nstdout:\n{}\nstderr:\n{}",
                    out.status, stdout, stderr
                )));
            }
        }
        Err(err) if err.kind() == std::io::ErrorKind::NotFound => {
            return Err(Error::AptosTestFailed(
                "ERR_APTOS_CLI_NOT_FOUND: install Aptos CLI or run without --run-aptos-test"
                    .to_string(),
            ));
        }
        Err(err) => {
            return Err(Error::AptosTestFailed(err.to_string()));
        }
    }

    Ok(())
}