use crate::{ZkError, VerificationKey};
use std::collections::HashMap;
use std::fmt::Write;
use std::fs;
/// Generates the public-coefficients Rust module used for on-chain verification and writes it
/// next to the constraint file.
///
/// Every key appearing in `public_coeffs_a`, `public_coeffs_b` or `public_coeffs_c` becomes one
/// public input. Names are de-duplicated and sorted. A name missing from one of the maps gets a
/// coefficient of `0` for that map.
///
/// The generated module contains:
/// - `<PREFIX>_NUM_PUBLIC_INPUTS`, `<PREFIX>_PUBLIC_INPUT_NAMES` and
///   `<PREFIX>_PUBLIC_COEFFICIENTS` constants (prefix upper-cased and followed by `_`, omitted
///   when `prefix` is `None`),
/// - an `InputCoefficients` struct,
/// - `compute_<prefix>_public_coefficients`, which folds the inputs into `(a2, b2, c2)` using
///   overflow-checked `i64` arithmetic, and
/// - `get_<prefix>_input_index` (prefix lower-cased in function names).
///
/// The file is named `<prefix>_public_coefficients.rs` (prefix lower-cased), or
/// `public_coefficients.rs` when no prefix is given. It is written into the parent directory of
/// `constraint_file_path` (`.` if the path has no parent). A summary is printed to stdout.
///
/// # Errors
/// Returns [`ZkError::ComputationError`] if:
/// - `constraint_file_path` is `None`,
/// - any coefficient is not an integer representable as `i64` (fractional, NaN, infinite or out
///   of range), or
/// - the file cannot be written.
pub fn generate_public_coefficients_file(
    public_coeffs_a: &HashMap<String, f64>,
    public_coeffs_b: &HashMap<String, f64>,
    public_coeffs_c: &HashMap<String, f64>,
    constraint_file_path: Option<&std::path::Path>,
    prefix: Option<&str>,
) -> Result<(), ZkError> {
    // Fail before doing any work or printing anything (this used to be a panic).
    let constraint_file_path = constraint_file_path.ok_or_else(|| {
        ZkError::ComputationError(
            "Constraint file path is required to locate the public coefficients output"
                .to_string(),
        )
    })?;

    // A BTreeSet de-duplicates names shared between the A/B/C maps and iterates them in sorted
    // order, which fixes the input order baked into the generated constants.
    let names: std::collections::BTreeSet<&str> = public_coeffs_a
        .keys()
        .chain(public_coeffs_b.keys())
        .chain(public_coeffs_c.keys())
        .map(String::as_str)
        .collect();

    // Convert every coefficient once, up front, so nothing is written if any value is invalid.
    let mut rows: Vec<(&str, [i64; 3])> = Vec::with_capacity(names.len());
    for name in names {
        let coeffs = [
            public_coefficient_as_i64(public_coeffs_a, name, 'A')?,
            public_coefficient_as_i64(public_coeffs_b, name, 'B')?,
            public_coefficient_as_i64(public_coeffs_c, name, 'C')?,
        ];
        rows.push((name, coeffs));
    }

    let content = render_public_coefficients_source(&rows, prefix)
        .expect("formatting into a String never fails");

    let filename = match prefix {
        Some(p) => format!("{}_public_coefficients.rs", p.to_lowercase()),
        None => "public_coefficients.rs".to_string(),
    };

    println!("\n=== Writing {} ===", filename);
    println!("Found {} public inputs:", rows.len());
    for (i, (name, [a, b, c])) in rows.iter().enumerate() {
        println!(" [{}] {}: A={}, B={}, C={}", i, name, a, b, c);
    }

    let output_path = constraint_file_path
        .parent()
        .unwrap_or_else(|| std::path::Path::new("."))
        .join(&filename);
    fs::write(&output_path, content).map_err(|e| {
        // Report the path actually written, not a hard-coded file name.
        ZkError::ComputationError(format!("Failed to write {}: {}", output_path.display(), e))
    })?;
    println!("Generated {}", output_path.display());
    Ok(())
}

/// Looks up `name` in `coeffs` (absent => `0`) and converts it to an exact `i64`.
///
/// `matrix` ('A', 'B' or 'C') is only used to make the error message precise.
fn public_coefficient_as_i64(
    coeffs: &HashMap<String, f64>,
    name: &str,
    matrix: char,
) -> Result<i64, ZkError> {
    // 2^63: the exclusive upper bound of i64; -2^63 is exactly i64::MIN.
    const I64_LIMIT: f64 = 9_223_372_036_854_775_808.0;
    let value = coeffs.get(name).copied().unwrap_or(0.0);
    // A plain `as` cast would silently truncate fractions, turn NaN into 0 and saturate
    // out-of-range values. Any of those would bake a wrong constant into the verifier.
    // NaN and infinities fail the range check, so `fract` is only evaluated on finite values.
    if (-I64_LIMIT..I64_LIMIT).contains(&value) && value.fract() == 0.0 {
        Ok(value as i64)
    } else {
        Err(ZkError::ComputationError(format!(
            "Public coefficient {} for input '{}' must be an integer in i64 range, got {}",
            matrix, name, value
        )))
    }
}

/// Renders the source text of the generated public-coefficients module.
///
/// `rows` holds `(input name, [coeff_a, coeff_b, coeff_c])` in the order used by the generated
/// constants. Errors can only come from `fmt::Write`, which never fails for `String`.
fn render_public_coefficients_source(
    rows: &[(&str, [i64; 3])],
    prefix: Option<&str>,
) -> Result<String, std::fmt::Error> {
    // Constants use an upper-cased prefix, functions a lower-cased one.
    let const_prefix = prefix
        .map(|p| format!("{}_", p.to_uppercase()))
        .unwrap_or_default();
    let fn_prefix = prefix
        .map(|p| format!("{}_", p.to_lowercase()))
        .unwrap_or_default();
    let n = rows.len();
    let mut out = String::new();

    writeln!(out, "// Auto-generated public coefficients for on-chain verification")?;
    if let Some(p) = prefix {
        writeln!(out, "// Prefix: {}", p.escape_debug())?;
    }
    writeln!(out, "// Generated from constraint compilation")?;
    writeln!(out, "// DO NOT EDIT MANUALLY")?;
    writeln!(out, "// This file is no-std compatible for Solana BPF\n")?;

    writeln!(out, "/// Number of public inputs")?;
    writeln!(out, "pub const {}NUM_PUBLIC_INPUTS: usize = {};", const_prefix, n)?;
    writeln!(out)?;

    writeln!(out, "/// Public input names in order")?;
    writeln!(out, "pub const {}PUBLIC_INPUT_NAMES: [&str; {}] = [", const_prefix, n)?;
    for (name, _) in rows {
        // `{:?}` emits a correctly escaped Rust string literal, so names containing quotes or
        // backslashes cannot break the generated source.
        writeln!(out, "    {:?},", name)?;
    }
    writeln!(out, "];\n")?;

    writeln!(out, "/// Coefficient structure for a single public input")?;
    writeln!(out, "#[derive(Clone, Copy, Debug)]")?;
    writeln!(out, "pub struct InputCoefficients {{")?;
    writeln!(out, "    pub coeff_a: i64,")?;
    writeln!(out, "    pub coeff_b: i64,")?;
    writeln!(out, "    pub coeff_c: i64,")?;
    writeln!(out, "}}\n")?;

    writeln!(out, "/// All public coefficients indexed by input position")?;
    writeln!(
        out,
        "pub const {}PUBLIC_COEFFICIENTS: [InputCoefficients; {}] = [",
        const_prefix, n
    )?;
    for (name, [a, b, c]) in rows {
        // escape_debug keeps newlines in a name from terminating the line comment early.
        writeln!(out, "    InputCoefficients {{ // {}", name.escape_debug())?;
        writeln!(out, "        coeff_a: {},", a)?;
        writeln!(out, "        coeff_b: {},", b)?;
        writeln!(out, "        coeff_c: {},", c)?;
        writeln!(out, "    }},")?;
    }
    writeln!(out, "];\n")?;

    writeln!(out, "/// Compute a2, b2, c2 from public inputs")?;
    writeln!(
        out,
        "/// inputs should be provided in the same order as {}PUBLIC_INPUT_NAMES",
        const_prefix
    )?;
    writeln!(
        out,
        "pub fn compute_{}public_coefficients(public_inputs: &[i64]) -> Result<(i64, i64, i64), &'static str> {{",
        fn_prefix
    )?;
    writeln!(out, "    if public_inputs.len() != {}NUM_PUBLIC_INPUTS {{", const_prefix)?;
    writeln!(out, "        return Err(\"Invalid number of public inputs\");")?;
    writeln!(out, "    }}")?;
    writeln!(out)?;
    writeln!(out, "    let mut a2 = 0i64;")?;
    writeln!(out, "    let mut b2 = 0i64;")?;
    writeln!(out, "    let mut c2 = 0i64;")?;
    writeln!(out)?;
    writeln!(
        out,
        "    for (coeffs, &input_value) in {}PUBLIC_COEFFICIENTS.iter().zip(public_inputs) {{",
        const_prefix
    )?;
    // Checked arithmetic: plain `+=` / `*` would panic or silently wrap on overflow depending on
    // the build profile, and a wrapped value must never reach the verifier.
    for (acc, field) in [("a2", "coeff_a"), ("b2", "coeff_b"), ("c2", "coeff_c")] {
        writeln!(
            out,
            "        {0} = coeffs.{1}.checked_mul(input_value).and_then(|term| {0}.checked_add(term)).ok_or(\"Arithmetic overflow in public coefficients\")?;",
            acc, field
        )?;
    }
    writeln!(out, "    }}")?;
    writeln!(out)?;
    writeln!(out, "    Ok((a2, b2, c2))")?;
    writeln!(out, "}}\n")?;

    writeln!(out, "/// Get the index of a public input by name")?;
    writeln!(out, "pub fn get_{}input_index(name: &str) -> Option<usize> {{", fn_prefix)?;
    writeln!(
        out,
        "    {}PUBLIC_INPUT_NAMES.iter().position(|&n| n == name)",
        const_prefix
    )?;
    writeln!(out, "}}\n")?;

    Ok(out)
}
/// Writes the Rust source returned by `verification_key.to_rust_code(prefix)` into `output_dir`.
///
/// The output file is named `<prefix>_proving_key.rs` (prefix lower-cased), or `proving_key.rs`
/// when `prefix` is `None`. `prefix` itself is forwarded unchanged to `to_rust_code`. On success
/// the written path is printed to stdout.
///
/// NOTE(review): the file is named `proving_key` although its contents come from a
/// `VerificationKey` — confirm the naming is intentional.
///
/// # Errors
/// Returns [`ZkError::IoError`] if the file cannot be written.
pub fn generate_proving_key_file(
    verification_key: &VerificationKey,
    output_dir: &std::path::Path,
    prefix: Option<&str>,
) -> Result<(), ZkError> {
    let source = verification_key.to_rust_code(prefix);

    let file_name = match prefix {
        None => String::from("proving_key.rs"),
        Some(p) => format!("{}_proving_key.rs", p.to_lowercase()),
    };
    let destination = output_dir.join(&file_name);

    if let Err(err) = fs::write(&destination, source) {
        return Err(ZkError::IoError(format!(
            "Failed to write {}: {}",
            file_name, err
        )));
    }

    println!("Generated {}", destination.display());
    Ok(())
}