zyga 0.5.1

ZYGA zero-knowledge proof system - CLI and library for generating ZK proofs
Documentation
use crate::{ZkError, VerificationKey};
use std::collections::HashMap;
use std::fmt::Write;
use std::fs;

/// Generate the `public_coefficients.rs` file for on-chain verification.
///
/// The generated (no-std) Rust file lists every public input together with the integer
/// coefficients that map it to its contributions to `a2`, `b2` and `c2`. It also defines
/// `compute_{prefix_}public_coefficients` and `get_{prefix_}input_index` helpers over them.
///
/// Public inputs are the union of the keys of the three coefficient maps, sorted by name.
/// An input missing from one map gets coefficient `0` in that column. Coefficients are
/// rounded to the nearest integer because the generated code stores them as `i64`.
///
/// The file is named `public_coefficients.rs`, or `{prefix}_public_coefficients.rs` with the
/// prefix lower-cased. It is written into the parent directory of `constraint_file_path`, or
/// `.` when that path has no parent. Generated constant names are prefixed with the
/// upper-cased prefix and function names with the lower-cased prefix.
///
/// # Errors
///
/// Returns `ZkError::ComputationError` in any of these cases:
/// - `constraint_file_path` is `None`;
/// - a coefficient is not finite or does not fit in an `i64`;
/// - the file cannot be written.
pub fn generate_public_coefficients_file(
    public_coeffs_a: &HashMap<String, f64>,
    public_coeffs_b: &HashMap<String, f64>,
    public_coeffs_c: &HashMap<String, f64>,
    constraint_file_path: Option<&std::path::Path>,
    prefix: Option<&str>,
) -> Result<(), ZkError> {
    // The output location is derived from the constraint file. Without it there is nowhere
    // sensible to write, so report an error rather than panicking in a `Result`-returning fn.
    let constraint_file_path = constraint_file_path.ok_or_else(|| {
        ZkError::ComputationError(
            "Constraint file path is required to locate the public coefficients output"
                .to_string(),
        )
    })?;

    // A BTreeSet both deduplicates and sorts, giving a stable input order.
    let public_input_names: std::collections::BTreeSet<&str> = public_coeffs_a
        .keys()
        .chain(public_coeffs_b.keys())
        .chain(public_coeffs_c.keys())
        .map(String::as_str)
        .collect();

    // Look up and convert every coefficient exactly once. The rows feed both the generated
    // file and the console summary.
    let mut rows: Vec<(&str, [i64; 3])> = Vec::with_capacity(public_input_names.len());
    for name in public_input_names {
        let coeffs = [
            coefficient_for(public_coeffs_a, name, "A")?,
            coefficient_for(public_coeffs_b, name, "B")?,
            coefficient_for(public_coeffs_c, name, "C")?,
        ];
        rows.push((name, coeffs));
    }

    let upper_prefix = prefix.map(str::to_uppercase);
    let lower_prefix = prefix.map(str::to_lowercase);
    let const_prefix = upper_prefix
        .as_deref()
        .map(|p| format!("{}_", p))
        .unwrap_or_default();
    let fn_prefix = lower_prefix
        .as_deref()
        .map(|p| format!("{}_", p))
        .unwrap_or_default();
    let filename = match lower_prefix.as_deref() {
        Some(p) => format!("{}_public_coefficients.rs", p),
        None => "public_coefficients.rs".to_string(),
    };

    let content = render_public_coefficients_source(&rows, prefix, &const_prefix, &fn_prefix);

    // Console summary of what is about to be written.
    println!("\n=== Writing {} ===", filename);
    println!("Found {} public inputs:", rows.len());
    for (i, (name, [a, b, c])) in rows.iter().enumerate() {
        println!("  [{}] {}: A={}, B={}, C={}", i, name, a, b, c);
    }

    // Write next to the constraint file, falling back to "." when the path has no parent
    // (e.g. a filesystem root).
    let output_path = constraint_file_path
        .parent()
        .unwrap_or(std::path::Path::new("."))
        .join(&filename);

    fs::write(&output_path, content).map_err(|e| {
        ZkError::ComputationError(format!("Failed to write {}: {}", filename, e))
    })?;

    println!("Generated {}", output_path.display());

    Ok(())
}

/// Look up `name` in one coefficient column (absent means `0`) and convert it to `i64`.
///
/// The value is rounded to the nearest integer. A plain `as` cast would cause three problems,
/// each of which silently produces a wrong verifier:
/// - values such as `0.9999999999` would be truncated to `0`;
/// - NaN would become `0`;
/// - infinities would saturate.
///
/// Non-finite or out-of-range values are therefore reported as errors.
fn coefficient_for(
    column: &HashMap<String, f64>,
    name: &str,
    column_label: &str,
) -> Result<i64, ZkError> {
    let raw = column.get(name).copied().unwrap_or(0.0);
    let rounded = raw.round();
    // -2^63 is exactly representable as f64. `i64::MAX as f64` rounds up to 2^63, which is
    // therefore the first out-of-range value.
    if !rounded.is_finite() || rounded < i64::MIN as f64 || rounded >= i64::MAX as f64 {
        return Err(ZkError::ComputationError(format!(
            "Public coefficient {} for input '{}' cannot be represented as i64: {}",
            column_label, name, raw
        )));
    }
    Ok(rounded as i64)
}

/// Render the source text of the generated public-coefficients file.
///
/// `rows` holds `(input name, [coeff_a, coeff_b, coeff_c])` in output order.
/// `const_prefix` and `fn_prefix` are prepended to the generated constant and function names;
/// both are empty when no prefix is used. `prefix` is echoed verbatim in the header comment.
fn render_public_coefficients_source(
    rows: &[(&str, [i64; 3])],
    prefix: Option<&str>,
    const_prefix: &str,
    fn_prefix: &str,
) -> String {
    let num_inputs = rows.len();
    let mut out = String::new();
    // `fmt::Write` into a `String` never fails, so each `unwrap` below is infallible.

    writeln!(out, "// Auto-generated public coefficients for on-chain verification").unwrap();
    if let Some(p) = prefix {
        writeln!(out, "// Prefix: {}", p).unwrap();
    }
    writeln!(out, "// Generated from constraint compilation").unwrap();
    writeln!(out, "// DO NOT EDIT MANUALLY").unwrap();
    writeln!(out, "// This file is no-std compatible for Solana BPF\n").unwrap();

    writeln!(out, "/// Number of public inputs").unwrap();
    writeln!(
        out,
        "pub const {}NUM_PUBLIC_INPUTS: usize = {};",
        const_prefix, num_inputs
    )
    .unwrap();
    writeln!(out).unwrap();

    writeln!(out, "/// Public input names in order").unwrap();
    writeln!(
        out,
        "pub const {}PUBLIC_INPUT_NAMES: [&str; {}] = [",
        const_prefix, num_inputs
    )
    .unwrap();
    for (name, _) in rows {
        // `{:?}` emits a correctly escaped Rust string literal. For ordinary names this is
        // the same as wrapping the name in quotes.
        writeln!(out, "    {:?},", name).unwrap();
    }
    writeln!(out, "];\n").unwrap();

    writeln!(out, "/// Coefficient structure for a single public input").unwrap();
    writeln!(out, "#[derive(Clone, Copy, Debug)]").unwrap();
    writeln!(out, "pub struct InputCoefficients {{").unwrap();
    writeln!(out, "    pub coeff_a: i64,").unwrap();
    writeln!(out, "    pub coeff_b: i64,").unwrap();
    writeln!(out, "    pub coeff_c: i64,").unwrap();
    writeln!(out, "}}\n").unwrap();

    writeln!(out, "/// All public coefficients indexed by input position").unwrap();
    writeln!(
        out,
        "pub const {}PUBLIC_COEFFICIENTS: [InputCoefficients; {}] = [",
        const_prefix, num_inputs
    )
    .unwrap();
    for (name, [coeff_a, coeff_b, coeff_c]) in rows {
        // Escape the name so a newline or other control character cannot end the line comment.
        writeln!(out, "    InputCoefficients {{ // {}", name.escape_debug()).unwrap();
        writeln!(out, "        coeff_a: {},", coeff_a).unwrap();
        writeln!(out, "        coeff_b: {},", coeff_b).unwrap();
        writeln!(out, "        coeff_c: {},", coeff_c).unwrap();
        writeln!(out, "    }},").unwrap();
    }
    writeln!(out, "];\n").unwrap();

    // NOTE(review): the generated accumulation uses plain i64 `*`/`+`. These panic or wrap
    // on overflow depending on the build's overflow-checks setting. Confirm this matches the
    // off-chain computation before switching the generated code to checked arithmetic.
    writeln!(out, "/// Compute a2, b2, c2 from public inputs").unwrap();
    writeln!(
        out,
        "/// inputs should be provided in the same order as {}PUBLIC_INPUT_NAMES",
        const_prefix
    )
    .unwrap();
    writeln!(out, "pub fn compute_{}public_coefficients(public_inputs: &[i64]) -> Result<(i64, i64, i64), &'static str> {{", fn_prefix).unwrap();
    writeln!(
        out,
        "    if public_inputs.len() != {}NUM_PUBLIC_INPUTS {{",
        const_prefix
    )
    .unwrap();
    writeln!(out, "        return Err(\"Invalid number of public inputs\");").unwrap();
    writeln!(out, "    }}").unwrap();
    writeln!(out, "    ").unwrap();
    writeln!(out, "    let mut a2 = 0i64;").unwrap();
    writeln!(out, "    let mut b2 = 0i64;").unwrap();
    writeln!(out, "    let mut c2 = 0i64;").unwrap();
    writeln!(out, "    ").unwrap();
    writeln!(
        out,
        "    for (i, &input_value) in public_inputs.iter().enumerate() {{"
    )
    .unwrap();
    writeln!(
        out,
        "        let coeffs = &{}PUBLIC_COEFFICIENTS[i];",
        const_prefix
    )
    .unwrap();
    writeln!(out, "        a2 += coeffs.coeff_a * input_value;").unwrap();
    writeln!(out, "        b2 += coeffs.coeff_b * input_value;").unwrap();
    writeln!(out, "        c2 += coeffs.coeff_c * input_value;").unwrap();
    writeln!(out, "    }}").unwrap();
    writeln!(out, "    ").unwrap();
    writeln!(out, "    Ok((a2, b2, c2))").unwrap();
    writeln!(out, "}}\n").unwrap();

    // Helper for looking up an input's position by name.
    writeln!(out, "/// Get the index of a public input by name").unwrap();
    writeln!(
        out,
        "pub fn get_{}input_index(name: &str) -> Option<usize> {{",
        fn_prefix
    )
    .unwrap();
    writeln!(
        out,
        "    {}PUBLIC_INPUT_NAMES.iter().position(|&n| n == name)",
        const_prefix
    )
    .unwrap();
    writeln!(out, "}}\n").unwrap();

    out
}

/// Write the `proving_key.rs` file for on-chain verification.
///
/// The Rust source comes from `VerificationKey::to_rust_code(prefix)`. It is written to
/// `output_dir/proving_key.rs`, or to `output_dir/{prefix}_proving_key.rs` (prefix
/// lower-cased) when a prefix is supplied. The created path is printed on success.
///
/// # Errors
///
/// Returns `ZkError::IoError` if the file cannot be written.
pub fn generate_proving_key_file(
    verification_key: &VerificationKey,
    output_dir: &std::path::Path,
    prefix: Option<&str>,
) -> Result<(), ZkError> {
    let source = verification_key.to_rust_code(prefix);

    let file_name = prefix.map_or_else(
        || String::from("proving_key.rs"),
        |p| format!("{}_proving_key.rs", p.to_lowercase()),
    );
    let destination = output_dir.join(&file_name);

    if let Err(err) = fs::write(&destination, source) {
        return Err(ZkError::IoError(format!(
            "Failed to write {}: {}",
            file_name, err
        )));
    }

    println!("Generated {}", destination.display());
    Ok(())
}