zyga 0.5.1

ZYGA zero-knowledge proof system - CLI and library for generating ZK proofs
Documentation
use clap::{Parser, Subcommand};
use std::collections::HashMap;
use std::fs;
use std::path::PathBuf;
use zyga::{
    code_generation::generate_proving_key_file,
    compile_constraints, create_proving_key, generate_proof, generate_trusted_setup,
    Expression,
    ProvingKey,
};

// Top-level command-line interface of the `zyga` binary.
// Plain `//` comments are used in this struct on purpose: clap's derive can
// turn `///` doc comments into help text, and the help text here is set
// explicitly through the `#[command(...)]` attribute below.
#[derive(Parser, Debug)]
#[command(author, version, about = "ZYGA Zero-Knowledge Proof System CLI", long_about = None)]
struct Cli {
    // The selected subcommand together with its arguments; `main` dispatches on it.
    #[command(subcommand)]
    command: Commands,
}

// Subcommands accepted by the `zyga` binary.
//
// NOTE: the `///` lines inside this enum are not only documentation. clap's
// derive renders them as `--help` text, so editing them changes program
// output. Implementation notes therefore use `//`.
//
// A bare `#[arg(short, long)]` derives `-<first letter of field>` and
// `--<field-name-in-kebab-case>` from the field name.
#[derive(Subcommand, Debug)]
enum Commands {
    // Handled by `setup_command`.
    // Flags: -c/--constraint-file, -o/--output, -p/--prefix, -s/--seed, -v/--verbose.
    /// Generate proving key and verification files from constraints
    Setup {
        /// Path to the AOA constraint file
        #[arg(short, long)]
        constraint_file: PathBuf,

        // NOTE(review): only the `.zyga` proving key is written at this base
        // path; the generated `.rs` files go next to the constraint file.
        /// Output file base name (without extension)
        #[arg(short, long)]
        output: PathBuf,

        /// Prefix for generated variables and functions
        #[arg(short = 'p', long)]
        prefix: Option<String>,

        // NOTE(review): this default seed is public. A setup derived from it
        // is presumably reproducible by anyone, so a secret seed should be
        // supplied outside of testing.
        /// Random seed for trusted setup (default: 12345)
        #[arg(short, long, default_value = "12345")]
        seed: u64,

        /// Enable verbose output
        #[arg(short, long)]
        verbose: bool,
    },

    // Handled by `prove_command`.
    // Flags: -s/--setup, -w/--witness-file, -o/--output, --debug-matrices,
    // -f/--force, -v/--verbose. Note that `-s` means `--setup` here but
    // `--seed` under `setup`.
    /// Generate a proof using proving key and witness
    Prove {
        /// Path to the proving key file (.zyga)
        #[arg(short, long)]
        setup: PathBuf,

        /// Path to the witness JSON file
        #[arg(short, long)]
        witness_file: PathBuf,

        /// Path to output proof JSON file
        #[arg(short, long, default_value = "proof.json")]
        output: PathBuf,

        /// Include debug information in output
        #[arg(long)]
        debug_matrices: bool,

        /// Force proof generation even when constraints are not satisfied
        #[arg(short = 'f', long)]
        force: bool,

        /// Enable verbose output
        #[arg(short, long)]
        verbose: bool,
    },
}


/// Optional debug payload that `prove_command` embeds under the
/// `"debug_info"` key of the proof JSON when `--debug-matrices` is given.
#[derive(serde::Serialize)]
struct DebugInfo {
    /// Copy of the proving key's `a_matrix`.
    a_matrix: Vec<Vec<f64>>,
    /// Copy of the proving key's `b_matrix`.
    b_matrix: Vec<Vec<f64>>,
    /// Copy of the proving key's `c_matrix`.
    c_matrix: Vec<Vec<f64>>,
    /// Witness variable names stored in the proving key.
    witnesses: Vec<String>,
    /// One JSON array per top-level entry of the witness input file (scalar
    /// entries become one-element arrays). The input names are not recorded,
    /// so these entries generally do not line up one-to-one with `witnesses`.
    witness_values: Vec<serde_json::Value>,
    /// The proving key's `env_dict`, with each value converted to JSON.
    env_dict: HashMap<String, serde_json::Value>,
    /// Currently always `None`; `prove_command` never populates it.
    hardcoded_claims: Option<serde_json::Value>,
}

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let cli = Cli::parse();

    match cli.command {
        Commands::Setup {
            constraint_file,
            output,
            prefix,
            seed,
            verbose,
        } => {
            setup_command(constraint_file, output, prefix, seed, verbose)
        }
        Commands::Prove {
            setup,
            witness_file,
            output,
            debug_matrices,
            force,
            verbose,
        } => {
            prove_command(setup, witness_file, output, debug_matrices, force, verbose)
        }
    }
}

/// Runs `zyga setup`.
///
/// Compiles the constraint file and builds a trusted setup from a seeded
/// `StdRng`. It then saves the proving key as `<output>.zyga` and asks
/// `generate_proving_key_file` to emit the generated Rust sources.
///
/// * `constraint_file` - AOA constraint source to compile.
/// * `output` - base path of the proving key; the extension is replaced by `.zyga`.
/// * `prefix` - optional prefix for generated identifiers. When set, the
///   reported generated file names are `<prefix lower-cased>_<name>.rs`.
/// * `seed` - seed for the RNG that drives the trusted setup.
/// * `verbose` - print extra progress information.
///
/// The generated `.rs` files are placed in the constraint file's directory,
/// or in `.` when the constraint path has no directory component, and not
/// next to `output`.
///
/// # Errors
/// Returns an error if the prefix is invalid, the constraint file cannot be
/// read, or compilation, setup, key creation, saving or code generation fails.
fn setup_command(
    constraint_file: PathBuf,
    output: PathBuf,
    prefix: Option<String>,
    seed: u64,
    verbose: bool,
) -> Result<(), Box<dyn std::error::Error>> {
    // Reject prefixes that would produce invalid identifiers before doing any work.
    if let Some(p) = prefix.as_deref() {
        validate_prefix(p)?;
    }

    println!("ZYGA Setup Phase");
    println!("Constraint file: {}", constraint_file.display());
    println!("Output base: {}", output.display());
    if let Some(ref p) = prefix {
        println!("Prefix: {}", p);
    }
    println!("Seed: {}", seed);

    // Read constraint file
    let constraint_content = fs::read_to_string(&constraint_file)?;
    println!(
        "\nConstraint file loaded: {} bytes",
        constraint_content.len()
    );

    // Compile constraints. Compilation is purely structural, so no witness is needed here.
    println!("\n=== Compiling Constraints ===");
    if verbose {
        println!("Starting constraint compilation...");
        println!(
            "Constraint content length: {} bytes",
            constraint_content.len()
        );
    }
    let compilation_result = compile_constraints(&constraint_content, verbose)?;
    if verbose {
        println!("Compilation complete!");
        println!(
            "Number of constraints: {}",
            compilation_result.a_matrix.len()
        );
        println!(
            "Number of variables: {}",
            compilation_result.witnesses.len()
        );
    }

    // Generate trusted setup
    println!("\n=== Generating Trusted Setup ===");
    if verbose {
        println!("Generating trusted setup with seed {}...", seed);
    }
    // NOTE(review): the setup randomness is fully determined by `seed`, and the
    // CLI default (12345) is public. Anyone who knows the seed can presumably
    // reproduce the setup secrets, so use a secret seed outside of testing.
    use rand::SeedableRng;
    let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
    let trusted_setup = generate_trusted_setup(&mut rng)?;
    println!("Trusted setup generated");

    // Create proving key and verification key
    println!("\n=== Creating Proving Key and Verification Key ===");
    if verbose {
        println!("Computing Lagrange basis evaluations...");
    }
    let (proving_key, verification_key) = create_proving_key(
        &compilation_result,
        &trusted_setup,
        Some(&constraint_file),
        prefix.as_deref(),
    )?;

    println!("Created proving key with {} constraints", proving_key.num_constraints);
    println!("Created verification key");

    // Save proving key to .zyga file
    let zyga_path = output.with_extension("zyga");
    println!("\n=== Saving Proving Key ===");
    proving_key.save_to_file(&zyga_path)?;
    println!("Saved proving key to: {}", zyga_path.display());

    // Generate the Rust sources next to the constraint file. `Path::parent`
    // returns `Some("")` (not `None`) for a bare file name such as "c.aoa",
    // so the empty parent has to be mapped to "." explicitly.
    let output_dir = constraint_file
        .parent()
        .filter(|dir| !dir.as_os_str().is_empty())
        .unwrap_or_else(|| std::path::Path::new("."));
    generate_proving_key_file(&verification_key, output_dir, prefix.as_deref())?;

    println!("\n=== Setup Complete ===");
    println!("Generated files:");
    println!("  - {}", zyga_path.display());
    // `Path::join` renders correctly for any directory; "{dir}/{file}" printed
    // a bogus root path such as "/proving_key.rs" when the directory was empty.
    for base_name in ["public_coefficients.rs", "proving_key.rs"] {
        let file_name = generated_file_name(prefix.as_deref(), base_name);
        println!("  - {}", output_dir.join(file_name).display());
    }

    println!("\nUse 'zyga prove' with the .zyga file to generate proofs");

    Ok(())
}

/// Checks that `prefix` can be spliced into generated Rust identifiers:
/// it must be non-empty, must not start with a digit, and may contain only
/// ASCII alphanumerics and underscores.
fn validate_prefix(prefix: &str) -> Result<(), String> {
    match prefix.chars().next() {
        None => Err("Prefix cannot be empty".to_string()),
        Some(first) if first.is_ascii_digit() => Err(format!(
            "Prefix '{}' cannot start with a digit (would create invalid Rust identifiers)",
            prefix
        )),
        Some(_) if !prefix.chars().all(|c| c.is_ascii_alphanumeric() || c == '_') => Err(format!(
            "Prefix '{}' can only contain alphanumeric characters and underscores",
            prefix
        )),
        Some(_) => Ok(()),
    }
}

/// Name of a generated source file as reported by `setup_command`:
/// `base_name` itself, or `<prefix lower-cased>_<base_name>` when a prefix is set.
fn generated_file_name(prefix: Option<&str>, base_name: &str) -> String {
    match prefix {
        Some(p) => format!("{}_{}", p.to_lowercase(), base_name),
        None => base_name.to_string(),
    }
}

/// Runs `zyga prove`.
///
/// Loads a proving key and reads the witness JSON, then extends the witness
/// with the key's witness DAG. It generates a proof and writes the proof
/// JSON to `output`.
///
/// * `setup` - proving key file produced by `zyga setup` (`.zyga`).
/// * `witness_file` - JSON object mapping input names to numbers or to
///   arrays of numbers. Array `a` supplies the inputs `a[0]`, `a[1]`, ...
/// * `output` - destination of the proof JSON.
/// * `debug_matrices` - also embed a `debug_info` section (matrices, witness data).
/// * `force` - forwarded to `generate_proof`. Per the CLI help, it forces a
///   proof even when the constraints are not satisfied.
/// * `verbose` - print the loaded witness values.
///
/// The output object holds `pairing_proof` and `public_inputs`. The latter
/// maps every public or deferred variable (scalars and array elements such
/// as `b[0]`) to its value in the computed witness. The constant `"1"` is
/// always reported as 1.
///
/// # Errors
/// Returns an error if any file cannot be read, parsed or written. It also
/// fails if the witness JSON is not an object or contains a non-numeric
/// array element, or if witness extension or proof generation fails.
fn prove_command(
    setup: PathBuf,
    witness_file: PathBuf,
    output: PathBuf,
    debug_matrices: bool,
    force: bool,
    verbose: bool,
) -> Result<(), Box<dyn std::error::Error>> {
    println!("ZYGA Prove Phase");
    println!("Setup file: {}", setup.display());
    println!("Witness file: {}", witness_file.display());
    println!("Output: {}", output.display());

    // Load proving key
    println!("\n=== Loading Proving Key ===");
    let proving_key = ProvingKey::load_from_file(&setup)?;
    println!("Loaded proving key with {} constraints", proving_key.num_constraints);

    // Read and parse witness file
    println!("\n=== Loading Witness ===");
    let witness_content = fs::read_to_string(&witness_file)?;
    let witness_json: serde_json::Value = serde_json::from_str(&witness_content)?;
    let ParsedWitness {
        by_name: witness_values,
        flattened: expanded_witness,
    } = parse_witness(&witness_json)?;

    if verbose {
        println!("Loaded witness values:");
        for (key, values) in &witness_values {
            println!("  {}: {:?}", key, values);
        }
    }

    // Extend witness to compute intermediate variables
    println!("\n=== Computing Full Witness ===");
    let extended_witness = proving_key.witness_dag.extend_witness(
        &proving_key.witness_ids,
        &proving_key.witness_names,
        &expanded_witness
    ).map_err(|e| format!("Failed to compute witness: {}", e))?;
    println!("Computed {} witness variables from {} input values",
             extended_witness.len(),
             expanded_witness.len());

    // Generate proof
    println!("\n=== Generating Proof ===");
    if force {
        println!("WARNING: Force flag enabled - proof will be generated even if constraints are not satisfied!");
    }
    let pairing_proof = generate_proof(&proving_key, &extended_witness, force)?;
    println!("Generated single proof element");

    // Collect the names of public and deferred variables (both are needed for
    // verification). An array reference like "b[0]" marks the whole array "b".
    let mut public_arrays: std::collections::HashSet<String> = std::collections::HashSet::new();
    let mut public_scalars: std::collections::HashSet<String> = std::collections::HashSet::new();
    for expr_id in proving_key.witness_ids.iter() {
        if let Expression::Public(name) | Expression::Deferred(name) =
            proving_key.witness_dag.get(*expr_id)
        {
            if let Some(bracket_pos) = name.find('[') {
                public_arrays.insert(name[..bracket_pos].to_string());
            } else {
                public_scalars.insert(name.clone());
            }
        }
    }

    // Read every public value from the *computed* witness so that deferred
    // (derived) scalars get their real value, just like array elements.
    // A BTreeMap keeps the serialized `public_inputs` order deterministic.
    let mut public_inputs: std::collections::BTreeMap<String, f64> =
        std::collections::BTreeMap::new();
    for (key, value) in &extended_witness {
        let is_public = match key.find('[') {
            Some(bracket_pos) => key.contains(']') && public_arrays.contains(&key[..bracket_pos]),
            None => public_scalars.contains(key.as_str()),
        };
        if is_public {
            public_inputs.insert(key.clone(), *value);
        }
    }

    // The constant "1" is always 1. A scalar missing from the computed witness
    // falls back to the raw input value, or to 0 with a warning.
    for name in public_scalars {
        if name == "1" {
            public_inputs.insert(name, 1.0);
        } else if !public_inputs.contains_key(&name) {
            let value = match expanded_witness.get(&name) {
                Some(&v) => v,
                None => {
                    eprintln!("WARNING: public value '{}' not found in witness; using 0", name);
                    0.0
                }
            };
            public_inputs.insert(name, value);
        }
    }

    // Create output structure
    let output_json = if debug_matrices {
        // Include debug information
        let debug_info = DebugInfo {
            a_matrix: proving_key.a_matrix.clone(),
            b_matrix: proving_key.b_matrix.clone(),
            c_matrix: proving_key.c_matrix.clone(),
            witnesses: proving_key.witness_names.clone(),
            // Sorted by input name (BTreeMap), so the output is reproducible.
            witness_values: witness_values
                .into_values()
                .map(|v| serde_json::json!(v))
                .collect(),
            env_dict: proving_key.env_dict.iter()
                .map(|(k, v)| (k.clone(), serde_json::json!(*v)))
                .collect(),
            hardcoded_claims: None,
        };

        serde_json::json!({
            "debug_info": debug_info,
            "pairing_proof": pairing_proof,
            "public_inputs": public_inputs,
        })
    } else {
        serde_json::json!({
            "pairing_proof": pairing_proof,
            "public_inputs": public_inputs,
        })
    };

    // Write output
    let output_str = serde_json::to_string_pretty(&output_json)?;
    fs::write(&output, output_str)?;

    println!("\n=== Proof Generation Complete ===");
    println!("Proof written to: {}", output.display());
    println!("Public inputs: {} values", public_inputs.len());

    Ok(())
}

/// Witness inputs decoded from the witness JSON file.
struct ParsedWitness {
    /// Values grouped by top-level JSON key; scalars become one-element vectors.
    by_name: std::collections::BTreeMap<String, Vec<f64>>,
    /// Flat map handed to witness extension. Scalars are keyed by name and
    /// array elements by `name[index]`.
    flattened: HashMap<String, f64>,
}

/// Decodes the witness JSON into grouped and flattened numeric inputs.
///
/// Numbers (integer or float) and arrays of numbers are accepted. Other
/// top-level values (strings, booleans, null, objects) are skipped with a
/// warning on stderr, matching the previous best-effort behaviour.
///
/// # Errors
/// Returns an error if the root is not a JSON object, or if an array holds a
/// non-numeric element. Silently dropping such an element would shift every
/// later `name[i]` index onto the wrong value.
fn parse_witness(witness_json: &serde_json::Value) -> Result<ParsedWitness, String> {
    let entries = witness_json
        .as_object()
        .ok_or("Witness file must contain a JSON object mapping input names to values")?;

    let mut by_name = std::collections::BTreeMap::new();
    let mut flattened = HashMap::new();
    for (key, value) in entries {
        // `Value::as_f64` already accepts integer JSON numbers, so no separate
        // integer branch is needed.
        if let Some(elements) = value.as_array() {
            let values = elements
                .iter()
                .enumerate()
                .map(|(i, element)| {
                    element.as_f64().ok_or_else(|| {
                        format!("Witness entry '{}[{}]' is not a number: {}", key, i, element)
                    })
                })
                .collect::<Result<Vec<f64>, String>>()?;
            for (i, &val) in values.iter().enumerate() {
                flattened.insert(format!("{}[{}]", key, i), val);
            }
            by_name.insert(key.clone(), values);
        } else if let Some(num) = value.as_f64() {
            by_name.insert(key.clone(), vec![num]);
            flattened.insert(key.clone(), num);
        } else {
            eprintln!("WARNING: ignoring non-numeric witness entry '{}': {}", key, value);
        }
    }

    Ok(ParsedWitness { by_name, flattened })
}