use anyhow::{Result, anyhow};
use ark_bn254::Fr;
use ark_ff::{BigInteger, PrimeField};
use std::path::Path;
use crate::groth16::witness::{WitnessCalculator, MembershipInputs};
use crate::groth16::rapidsnark_ffi::prove_with_library;
/// Outcome of a successful `MembershipProver::generate_proof` call.
#[derive(Debug, Clone)]
pub struct ProofResult {
/// UTF-8 bytes of the proof JSON string returned by `prove_with_library`.
pub proof_bytes: Vec<u8>,
/// First public signal, hex-encoded as a zero-padded 32-byte big-endian value.
pub merkle_root: String,
/// Public signals at indices 1..5, treated as four limbs (least significant
/// first) and combined into one hex-encoded 32-byte big-endian value.
pub npub_x: String,
/// Public signals at indices 5..9, encoded the same way as `npub_x`.
pub npub_y: String,
/// Time spent inside `generate_proof`, in milliseconds.
pub proof_time_ms: u64,
}
/// File-system locations of the artifacts needed to produce a membership proof.
///
/// All paths are stored as plain strings and are not checked for existence
/// when the configuration is built.
#[derive(Debug, Clone)]
pub struct ProverConfig {
    /// Path handed to `WitnessCalculator::new` as the witness library.
    pub witness_lib_path: String,
    /// Path handed to `WitnessCalculator::new` as the accompanying `.dat` file.
    pub dat_path: String,
    /// Path handed to `prove_with_library` as the rapidsnark library.
    pub rapidsnark_lib_path: String,
    /// Path handed to `prove_with_library` as the proving key (`.zkey`).
    pub zkey_path: String,
}
impl ProverConfig {
pub fn from_circuits_dir<P: AsRef<Path>>(dir: P) -> Self {
let dir = dir.as_ref();
Self {
witness_lib_path: dir.join("libmembership_witness.so").to_string_lossy().to_string(),
dat_path: dir.join("membership.dat").to_string_lossy().to_string(),
rapidsnark_lib_path: dir.join("librapidsnark.so").to_string_lossy().to_string(),
zkey_path: dir.join("membership_final.zkey").to_string_lossy().to_string(),
}
}
pub fn new(
witness_lib_path: String,
dat_path: String,
rapidsnark_lib_path: String,
zkey_path: String,
) -> Self {
Self { witness_lib_path, dat_path, rapidsnark_lib_path, zkey_path }
}
}
/// Merkle path data supplied to `MembershipProver::generate_proof`.
///
/// `generate_proof` rejects the witness unless both vectors hold exactly
/// 20 entries, and otherwise forwards them unchanged to
/// `MembershipInputs::from_field_elements`.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct MerkleWitness {
/// One sibling value per tree level.
/// NOTE(review): presumably decimal field-element strings — confirm against
/// `MembershipInputs::from_field_elements`.
pub siblings: Vec<String>,
/// One path selector per tree level.
/// NOTE(review): presumably 0/1 left/right flags — confirm against the circuit.
pub path_bits: Vec<u8>,
}
/// Produces membership proofs using the artifact paths held in a `ProverConfig`.
pub struct MembershipProver {
/// Artifact paths read by `generate_proof` on every call.
config: ProverConfig,
}
impl MembershipProver {
pub fn new(config: ProverConfig) -> Self {
Self { config }
}
pub fn generate_proof(
&self,
leaf_secret: &[Fr; 5],
witness: &MerkleWitness,
) -> Result<ProofResult> {
let start = std::time::Instant::now();
if witness.siblings.len() != 20 {
return Err(anyhow!("Expected 20 siblings, got {}", witness.siblings.len()));
}
if witness.path_bits.len() != 20 {
return Err(anyhow!("Expected 20 path_bits, got {}", witness.path_bits.len()));
}
let leaf_secret_strs: [String; 5] = [
fr_to_decimal(&leaf_secret[0]),
fr_to_decimal(&leaf_secret[1]),
fr_to_decimal(&leaf_secret[2]),
fr_to_decimal(&leaf_secret[3]),
fr_to_decimal(&leaf_secret[4]),
];
let inputs = MembershipInputs::from_field_elements(
leaf_secret_strs,
&witness.siblings,
&witness.path_bits,
).map_err(|e| anyhow!("Failed to create inputs: {}", e))?;
eprintln!("[prover] Step 1: Calculating witness...");
let witness_calc = WitnessCalculator::new(&self.config.witness_lib_path, &self.config.dat_path);
if !witness_calc.is_available() {
return Err(anyhow!(
"Witness calculator not available. Check paths:\n lib: {}\n dat: {}",
self.config.witness_lib_path,
self.config.dat_path
));
}
let witness_bytes = witness_calc.calculate_to_buffer(&inputs)
.map_err(|e| anyhow!("Witness calculation failed: {}", e))?;
eprintln!("[prover] Witness generated: {} bytes", witness_bytes.len());
eprintln!("[prover] Step 2: Generating Groth16 proof...");
let (proof_json, public_json) = prove_with_library(
&self.config.rapidsnark_lib_path,
&self.config.zkey_path,
&witness_bytes,
).map_err(|e| anyhow!("Proof generation failed: {}", e))?;
eprintln!("[prover] Proof generated, parsing public signals...");
let public_signals: Vec<String> = serde_json::from_str(&public_json)
.map_err(|e| anyhow!("Failed to parse public signals: {}", e))?;
if public_signals.len() != 9 {
return Err(anyhow!(
"Expected 9 public signals (1 root + 4 npub_x + 4 npub_y), got {}",
public_signals.len()
));
}
let merkle_root = decimal_to_hex_32(&public_signals[0])?;
let npub_x = limbs_to_hex_256(&public_signals[1..5])?;
let npub_y = limbs_to_hex_256(&public_signals[5..9])?;
let elapsed = start.elapsed();
eprintln!("[prover] Total proof time: {:?}", elapsed);
Ok(ProofResult {
proof_bytes: proof_json.into_bytes(),
merkle_root,
npub_x,
npub_y,
proof_time_ms: elapsed.as_millis() as u64,
})
}
}
/// Renders a BN254 `Fr` element as a base-10 integer string.
///
/// The element is converted with `into_bigint`, serialized to big-endian
/// bytes, and re-read as a `BigUint` purely to obtain decimal formatting.
fn fr_to_decimal(fr: &Fr) -> String {
    let be_bytes = fr.into_bigint().to_bytes_be();
    num_bigint::BigUint::from_bytes_be(&be_bytes).to_string()
}
/// Converts a base-10 unsigned integer string into the hex encoding of its
/// 32-byte, zero-padded, big-endian representation (64 hex characters).
///
/// # Errors
///
/// Returns an error when `decimal` cannot be parsed as a base-10 integer, or
/// when the parsed value needs more than 32 bytes.
fn decimal_to_hex_32(decimal: &str) -> Result<String> {
    const WIDTH: usize = 32;

    let value = num_bigint::BigUint::parse_bytes(decimal.as_bytes(), 10)
        .ok_or_else(|| anyhow!("Invalid decimal: {}", decimal))?;
    let be_bytes = value.to_bytes_be();

    // Refuse oversized values: dropping the high-order bytes would yield a
    // well-formed but wrong 32-byte result with no indication of the loss.
    if be_bytes.len() > WIDTH {
        return Err(anyhow!(
            "Decimal value does not fit in {} bytes: {}",
            WIDTH,
            decimal
        ));
    }

    // Right-align the magnitude so the leading bytes stay zero.
    let mut padded = [0u8; WIDTH];
    padded[WIDTH - be_bytes.len()..].copy_from_slice(&be_bytes);
    Ok(hex::encode(padded))
}
/// Combines four decimal `u64` limbs, least significant limb first, into the
/// hex encoding of one 32-byte big-endian value.
///
/// # Errors
///
/// Returns an error when `limbs` does not contain exactly four entries or
/// when any entry fails to parse as a `u64`.
fn limbs_to_hex_256(limbs: &[String]) -> Result<String> {
    let limb_count = limbs.len();
    if limb_count != 4 {
        return Err(anyhow!("Expected 4 limbs, got {}", limb_count));
    }

    let mut be_bytes = [0u8; 32];
    // limbs[0] is the least significant limb, so it fills the last 8-byte
    // chunk. Walking the chunks back to front while reading the limbs in
    // order keeps the parse (and error) order at limb 0, 1, 2, 3.
    for (slot, text) in be_bytes.chunks_exact_mut(8).rev().zip(limbs) {
        let limb: u64 = text
            .parse()
            .map_err(|e| anyhow!("Invalid limb '{}': {}", text, e))?;
        slot.copy_from_slice(&limb.to_be_bytes());
    }

    Ok(hex::encode(be_bytes))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Small values are right-aligned in 32 zero-padded bytes.
    #[test]
    fn test_decimal_to_hex_32() {
        let hex = decimal_to_hex_32("255").unwrap();
        assert_eq!(hex.len(), 64);
        assert!(hex.ends_with("ff"));
        // Pin the exact padding, not just the suffix.
        assert_eq!(hex, format!("{}ff", "0".repeat(62)));
        assert_eq!(decimal_to_hex_32("0").unwrap(), "0".repeat(64));
    }

    /// Text that is not a base-10 number is reported as an error.
    #[test]
    fn test_decimal_to_hex_32_rejects_garbage() {
        assert!(decimal_to_hex_32("12x").is_err());
    }

    /// All-zero limbs encode to 64 zero characters.
    #[test]
    fn test_limbs_to_hex_256() {
        let limbs = vec!["0".to_string(), "0".to_string(), "0".to_string(), "0".to_string()];
        let hex = limbs_to_hex_256(&limbs).unwrap();
        assert_eq!(hex, "0".repeat(64));
    }

    /// Distinct limbs pin the ordering: limb 0 is least significant and the
    /// output is big-endian, so the limbs appear reversed in the hex string.
    #[test]
    fn test_limbs_to_hex_256_limb_order() {
        let limbs: Vec<String> = ["1", "2", "3", "4"].iter().map(|s| s.to_string()).collect();
        let hex = limbs_to_hex_256(&limbs).unwrap();
        let expected = [
            "0000000000000004",
            "0000000000000003",
            "0000000000000002",
            "0000000000000001",
        ]
        .concat();
        assert_eq!(hex, expected);
    }

    /// A wrong limb count or an unparseable limb is an error, not a panic.
    #[test]
    fn test_limbs_to_hex_256_rejects_bad_input() {
        let too_few = vec!["0".to_string(); 3];
        assert!(limbs_to_hex_256(&too_few).is_err());

        let not_a_number = vec![
            "0".to_string(),
            "abc".to_string(),
            "0".to_string(),
            "0".to_string(),
        ];
        assert!(limbs_to_hex_256(&not_a_number).is_err());
    }

    /// Every artifact path is derived from the given directory.
    #[test]
    fn test_prover_config_from_dir() {
        let config = ProverConfig::from_circuits_dir("/tmp/circuits");
        assert!(config.witness_lib_path.contains("libmembership_witness.so"));
        assert!(config.dat_path.contains("membership.dat"));
        assert!(config.rapidsnark_lib_path.contains("librapidsnark.so"));
        assert!(config.zkey_path.contains("membership_final.zkey"));
    }
}