use clap::{Parser, Subcommand};
use std::collections::HashMap;
use std::fs;
use std::path::PathBuf;
use zyga::{
code_generation::generate_proving_key_file,
compile_constraints, create_proving_key, generate_proof, generate_trusted_setup,
Expression,
ProvingKey,
};
// Top-level command-line parser for the `zyga` binary.
// Plain `//` comments are used deliberately here: `///` doc comments on clap-derive
// items are turned into `--help` text and could interact with the explicit `about`.
#[derive(Parser, Debug)]
#[command(author, version, about = "ZYGA Zero-Knowledge Proof System CLI", long_about = None)]
struct Cli {
// The subcommand chosen on the command line; `main` dispatches on it.
#[command(subcommand)]
command: Commands,
}
/// Subcommands of the `zyga` CLI.
///
/// The `///` comments on variants and fields double as the `--help` text shown by clap.
#[derive(Subcommand, Debug)]
enum Commands {
    /// Compile a constraint file and generate the proving key plus generated Rust source files
    Setup {
        /// Path to the constraint file to compile
        #[arg(short, long)]
        constraint_file: PathBuf,
        /// Output base path; the proving key is written with a .zyga extension
        #[arg(short, long)]
        output: PathBuf,
        /// Optional prefix for generated Rust identifiers and file names
        #[arg(short = 'p', long)]
        prefix: Option<String>,
        /// Seed for the trusted-setup random number generator
        #[arg(short, long, default_value = "12345")]
        seed: u64,
        /// Print detailed progress information
        #[arg(short, long)]
        verbose: bool,
    },
    /// Generate a proof from a .zyga proving key and a JSON witness file
    Prove {
        /// Path to the .zyga proving key produced by the setup subcommand
        #[arg(short, long)]
        setup: PathBuf,
        /// JSON file mapping input names to numbers or arrays of numbers
        #[arg(short, long)]
        witness_file: PathBuf,
        /// Path of the proof JSON file to write
        #[arg(short, long, default_value = "proof.json")]
        output: PathBuf,
        /// Include the A/B/C matrices and witness data in the output JSON
        #[arg(long)]
        debug_matrices: bool,
        /// Generate a proof even if the constraints are not satisfied
        #[arg(short = 'f', long)]
        force: bool,
        /// Print detailed progress information
        #[arg(short, long)]
        verbose: bool,
    },
}
/// Diagnostic payload serialized under the `debug_info` key of the proof JSON
/// when the prove subcommand runs with `--debug-matrices`.
#[derive(serde::Serialize)]
struct DebugInfo {
/// Copy of the proving key's `a_matrix`.
a_matrix: Vec<Vec<f64>>,
/// Copy of the proving key's `b_matrix`.
b_matrix: Vec<Vec<f64>>,
/// Copy of the proving key's `c_matrix`.
c_matrix: Vec<Vec<f64>>,
/// Witness variable names taken from the proving key's `witness_names`.
witnesses: Vec<String>,
/// Raw values read from the witness file: one JSON array per top-level entry.
/// Entry names are not included, so this is not index-aligned with `witnesses`.
witness_values: Vec<serde_json::Value>,
/// The proving key's `env_dict`, with each value converted to JSON.
env_dict: HashMap<String, serde_json::Value>,
/// Always `None` when built by this CLI.
hardcoded_claims: Option<serde_json::Value>,
}
fn main() -> Result<(), Box<dyn std::error::Error>> {
let cli = Cli::parse();
match cli.command {
Commands::Setup {
constraint_file,
output,
prefix,
seed,
verbose,
} => {
setup_command(constraint_file, output, prefix, seed, verbose)
}
Commands::Prove {
setup,
witness_file,
output,
debug_matrices,
force,
verbose,
} => {
prove_command(setup, witness_file, output, debug_matrices, force, verbose)
}
}
}
/// Runs the setup phase.
///
/// Steps:
/// 1. Compile `constraint_file`.
/// 2. Generate a trusted setup from an RNG seeded with `seed`.
/// 3. Create the proving and verification keys.
/// 4. Save the proving key next to `output` with a `.zyga` extension.
/// 5. Generate Rust source files in the constraint file's directory.
///
/// `prefix`, when given, is passed through to key creation and file generation.
/// It is validated first so it can form Rust identifiers.
///
/// # Errors
/// Returns an error if `prefix` is invalid, the constraint file cannot be read, or any of
/// compilation, trusted-setup generation, key creation, or file writing fails.
fn setup_command(
    constraint_file: PathBuf,
    output: PathBuf,
    prefix: Option<String>,
    seed: u64,
    verbose: bool,
) -> Result<(), Box<dyn std::error::Error>> {
    if let Some(p) = prefix.as_deref() {
        validate_prefix(p)?;
    }

    println!("ZYGA Setup Phase");
    println!("Constraint file: {}", constraint_file.display());
    println!("Output base: {}", output.display());
    if let Some(ref p) = prefix {
        println!("Prefix: {}", p);
    }
    println!("Seed: {}", seed);

    let constraint_content = fs::read_to_string(&constraint_file)?;
    println!(
        "\nConstraint file loaded: {} bytes",
        constraint_content.len()
    );

    println!("\n=== Compiling Constraints ===");
    if verbose {
        println!("Starting constraint compilation...");
        println!(
            "Constraint content length: {} bytes",
            constraint_content.len()
        );
    }
    let compilation_result = compile_constraints(&constraint_content, verbose)?;
    if verbose {
        println!("Compilation complete!");
        println!(
            "Number of constraints: {}",
            compilation_result.a_matrix.len()
        );
        println!(
            "Number of variables: {}",
            compilation_result.witnesses.len()
        );
    }

    println!("\n=== Generating Trusted Setup ===");
    if verbose {
        println!("Generating trusted setup with seed {}...", seed);
    }
    use rand::SeedableRng;
    // A fixed seed makes the setup reproducible for the same constraint file.
    let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
    let trusted_setup = generate_trusted_setup(&mut rng)?;
    println!("Trusted setup generated");

    println!("\n=== Creating Proving Key and Verification Key ===");
    if verbose {
        println!("Computing Lagrange basis evaluations...");
    }
    let (proving_key, verification_key) = create_proving_key(
        &compilation_result,
        &trusted_setup,
        Some(&constraint_file),
        prefix.as_deref(),
    )?;
    println!("Created proving key with {} constraints", proving_key.num_constraints);
    println!("Created verification key");

    let zyga_path = output.with_extension("zyga");
    println!("\n=== Saving Proving Key ===");
    proving_key.save_to_file(&zyga_path)?;
    println!("Saved proving key to: {}", zyga_path.display());

    let output_dir = constraint_file.parent().unwrap_or(std::path::Path::new("."));
    generate_proving_key_file(&verification_key, output_dir, prefix.as_deref())?;

    println!("\n=== Setup Complete ===");
    println!("Generated files:");
    println!(" - {}", zyga_path.display());
    // Join instead of formatting "{dir}/{file}". For a bare file name, `parent()` is `Some("")`,
    // and string formatting would print a misleading root path such as "/proving_key.rs".
    for stem in ["public_coefficients.rs", "proving_key.rs"] {
        let file_name = prefixed_file_name(prefix.as_deref(), stem);
        println!(" - {}", output_dir.join(file_name).display());
    }
    println!("\nUse 'zyga prove' with the .zyga file to generate proofs");
    Ok(())
}

/// Checks that `prefix` is non-empty, does not start with a digit, and contains only
/// ASCII alphanumerics and underscores, so it can be used inside Rust identifiers.
fn validate_prefix(prefix: &str) -> Result<(), String> {
    if prefix.is_empty() {
        return Err("Prefix cannot be empty".to_string());
    }
    if prefix.starts_with(|c: char| c.is_ascii_digit()) {
        return Err(format!(
            "Prefix '{}' cannot start with a digit (would create invalid Rust identifiers)",
            prefix
        ));
    }
    if !prefix.chars().all(|c| c.is_ascii_alphanumeric() || c == '_') {
        return Err(format!(
            "Prefix '{}' can only contain alphanumeric characters and underscores",
            prefix
        ));
    }
    Ok(())
}

/// Returns `stem`, preceded by the lowercased `prefix` and an underscore when a prefix is given.
fn prefixed_file_name(prefix: Option<&str>, stem: &str) -> String {
    match prefix {
        Some(p) => format!("{}_{}", p.to_lowercase(), stem),
        None => stem.to_string(),
    }
}
/// Runs the prove phase.
///
/// Steps:
/// 1. Load a `.zyga` proving key from `setup`.
/// 2. Read the JSON witness from `witness_file`.
/// 3. Extend the witness through the key's witness DAG.
/// 4. Generate a proof and write it to `output` as pretty-printed JSON.
///
/// The output always contains `pairing_proof` and `public_inputs`. When `debug_matrices` is
/// set, it also contains a `debug_info` object. `force` is forwarded to `generate_proof`.
///
/// # Errors
/// Returns an error if:
/// - any file cannot be read, parsed, or written;
/// - the witness is not a JSON object, or one of its arrays contains a non-numeric element;
/// - witness extension fails;
/// - `generate_proof` fails.
fn prove_command(
    setup: PathBuf,
    witness_file: PathBuf,
    output: PathBuf,
    debug_matrices: bool,
    force: bool,
    verbose: bool,
) -> Result<(), Box<dyn std::error::Error>> {
    println!("ZYGA Prove Phase");
    println!("Setup file: {}", setup.display());
    println!("Witness file: {}", witness_file.display());
    println!("Output: {}", output.display());

    println!("\n=== Loading Proving Key ===");
    let proving_key = ProvingKey::load_from_file(&setup)?;
    println!("Loaded proving key with {} constraints", proving_key.num_constraints);

    println!("\n=== Loading Witness ===");
    let witness_content = fs::read_to_string(&witness_file)?;
    let witness_json: serde_json::Value = serde_json::from_str(&witness_content)?;
    let (witness_values, expanded_witness) = parse_witness(&witness_json)?;
    if verbose {
        println!("Loaded witness values:");
        for (key, values) in &witness_values {
            println!(" {}: {:?}", key, values);
        }
    }

    println!("\n=== Computing Full Witness ===");
    let extended_witness = proving_key
        .witness_dag
        .extend_witness(
            &proving_key.witness_ids,
            &proving_key.witness_names,
            &expanded_witness,
        )
        .map_err(|e| format!("Failed to compute witness: {}", e))?;
    println!(
        "Computed {} witness variables from {} input values",
        extended_witness.len(),
        expanded_witness.len()
    );

    println!("\n=== Generating Proof ===");
    if force {
        println!("WARNING: Force flag enabled - proof will be generated even if constraints are not satisfied!");
    }
    let pairing_proof = generate_proof(&proving_key, &extended_witness, force)?;
    println!("Generated single proof element");

    // Classify the public/deferred variables.
    // - Indexed names such as "arr[3]" mark the whole array "arr" as public.
    // - The variable named "1" is always reported with value 1.0.
    // - Other scalars start from their witness-file value, defaulting to 0.0.
    let mut public_inputs: HashMap<String, f64> = HashMap::new();
    let mut public_scalars = std::collections::HashSet::new();
    let mut public_arrays = std::collections::HashSet::new();
    for expr_id in proving_key.witness_ids.iter() {
        if let Expression::Public(name) | Expression::Deferred(name) =
            proving_key.witness_dag.get(*expr_id)
        {
            if let Some((array_name, _)) = name.split_once('[') {
                public_arrays.insert(array_name.to_string());
            } else if name == "1" {
                public_inputs.insert(name.clone(), 1.0);
            } else {
                let input_value = expanded_witness.get(name).copied().unwrap_or(0.0);
                public_inputs.insert(name.clone(), input_value);
                public_scalars.insert(name.clone());
            }
        }
    }
    // Report the values from the extended witness, which is what the proof was generated from.
    // This adds every element of public arrays. It also overrides scalar public values, which
    // would otherwise read 0.0 for variables that are absent from the witness file but present
    // after witness extension.
    for (key, value) in &extended_witness {
        let is_public = match key.split_once('[') {
            Some((array_name, _)) => key.contains(']') && public_arrays.contains(array_name),
            None => public_scalars.contains(key),
        };
        if is_public {
            public_inputs.insert(key.clone(), *value);
        }
    }

    let output_json = if debug_matrices {
        let debug_info = DebugInfo {
            a_matrix: proving_key.a_matrix.clone(),
            b_matrix: proving_key.b_matrix.clone(),
            c_matrix: proving_key.c_matrix.clone(),
            witnesses: proving_key.witness_names.clone(),
            // `witness_values` is a BTreeMap, so the arrays come out sorted by entry name and
            // the debug output is deterministic across runs.
            witness_values: witness_values
                .into_values()
                .map(|v| serde_json::json!(v))
                .collect(),
            env_dict: proving_key
                .env_dict
                .iter()
                .map(|(k, v)| (k.clone(), serde_json::json!(*v)))
                .collect(),
            hardcoded_claims: None,
        };
        serde_json::json!({
            "debug_info": debug_info,
            "pairing_proof": pairing_proof,
            "public_inputs": public_inputs,
        })
    } else {
        serde_json::json!({
            "pairing_proof": pairing_proof,
            "public_inputs": public_inputs,
        })
    };

    let output_str = serde_json::to_string_pretty(&output_json)?;
    fs::write(&output, output_str)?;
    println!("\n=== Proof Generation Complete ===");
    println!("Proof written to: {}", output.display());
    println!("Public inputs: {} values", public_inputs.len());
    Ok(())
}

/// Parses a witness JSON object into two maps.
///
/// Returns:
/// - a map from entry name to its list of values, sorted by name;
/// - a flat map in which array entries are expanded into `name[i]` keys and numeric scalars
///   keep their own name.
///
/// Non-numeric scalar entries are skipped with a warning on stderr.
///
/// # Errors
/// Returns an error if the top-level value is not a JSON object, or if an array contains a
/// non-numeric element. Dropping such an element would silently shift the `name[i]` index of
/// every element after it.
fn parse_witness(
    witness_json: &serde_json::Value,
) -> Result<(std::collections::BTreeMap<String, Vec<f64>>, HashMap<String, f64>), String> {
    let obj = witness_json
        .as_object()
        .ok_or("Witness file must contain a JSON object mapping names to values")?;
    let mut witness_values = std::collections::BTreeMap::new();
    let mut expanded_witness = HashMap::new();
    for (key, value) in obj {
        if let Some(arr) = value.as_array() {
            // `as_f64` accepts every JSON number (integer or float), so no separate
            // integer path is needed.
            let values = arr
                .iter()
                .enumerate()
                .map(|(i, v)| {
                    v.as_f64()
                        .ok_or_else(|| format!("Witness '{}[{}]' is not a number: {}", key, i, v))
                })
                .collect::<Result<Vec<f64>, String>>()?;
            for (i, &val) in values.iter().enumerate() {
                expanded_witness.insert(format!("{}[{}]", key, i), val);
            }
            witness_values.insert(key.clone(), values);
        } else if let Some(num) = value.as_f64() {
            witness_values.insert(key.clone(), vec![num]);
            expanded_witness.insert(key.clone(), num);
        } else {
            eprintln!("Warning: ignoring non-numeric witness entry '{}'", key);
        }
    }
    Ok((witness_values, expanded_witness))
}