use crate::dag::{ExprId, Expression, ExpressionDAG};
use crate::errors::ZkError;
use ordered_float::OrderedFloat;
use std::collections::{HashMap, HashSet};
/// The `(a_row, b_row, c_row)` coefficient rows for a single constraint, as returned by
/// `process_constraint_row`; each row has one entry per witness column.
type ConstraintRow = (Vec<f64>, Vec<f64>, Vec<f64>);
/// How a declared variable is exposed, taken from the second token of a `decl` line.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum VariableVisibility {
    /// Declared with the `private` keyword.
    Private,
    /// Declared with the `public` keyword.
    Public,
    /// Declared with the `deferred` keyword.
    Deferred,
}
/// A constraint described solely by the name of its operator.
#[derive(Debug)]
pub struct Constraint {
    /// Operator name supplied by the caller.
    pub op: String,
}

impl Constraint {
    /// Creates a [`Constraint`] holding the given operator name.
    pub fn new(op: String) -> Self {
        Self { op }
    }
}
/// Everything produced by `compile_constraints`.
#[derive(Debug)]
pub struct CompilationResult {
    /// `A` matrix: one row per emitted constraint, one column per witness slot.
    pub a_matrix: Vec<Vec<f64>>,
    /// `B` matrix, laid out like `a_matrix`.
    pub b_matrix: Vec<Vec<f64>>,
    /// `C` matrix, laid out like `a_matrix`.
    pub c_matrix: Vec<Vec<f64>>,
    /// Expression DAG holding every node created while compiling.
    pub dag: ExpressionDAG,
    /// For each witness column, the DAG node that defines its value.
    pub witness_ids: Vec<ExprId>,
    /// Declaration environment handed to `add_decl_to_env` (not populated by it in this module).
    pub env_dict: HashMap<String, serde_json::Value>,
    /// Witness names in column order; entry 0 is the constant `"1"`.
    pub witnesses: Vec<String>,
    /// Names declared `public` or `deferred`, plus the constant `"1"`.
    pub public_variables: HashSet<String>,
}
pub fn compile_constraints(
constr_code: &str,
verbose: bool,
) -> Result<CompilationResult, ZkError> {
let mut dag = ExpressionDAG::new();
let mut var_map: HashMap<String, ExprId> = HashMap::new();
let mut dws: HashMap<String, usize> = HashMap::new();
let mut env_dict = HashMap::new();
let mut public_variables = HashSet::new();
let mut witness_definitions: HashMap<String, ExprId> = HashMap::new();
let mut a_matrix = Vec::new();
let mut b_matrix = Vec::new();
let mut c_matrix = Vec::new();
let mut witnesses = vec!["1".to_string()];
let one_id = dag.add(Expression::Deferred("1".to_string()));
var_map.insert("1".to_string(), one_id);
dws.insert("1".to_string(), 0);
public_variables.insert("1".to_string());
let n_vars = quick_lex(constr_code)?;
if verbose {
println!("Total variables detected: {}", n_vars);
}
let mut pos = 1;
let start = std::time::Instant::now();
let mut last_time = start;
let mut line_count = 0;
for line in constr_code.lines() {
let line = line.split('#').next().unwrap_or("").trim();
if line.is_empty() {
continue;
}
line_count += 1;
if line_count % 50 == 0 && verbose {
let now = std::time::Instant::now();
let total_elapsed = now.duration_since(start).as_secs_f64();
let last_elapsed = now.duration_since(last_time).as_millis();
eprintln!("Line {}/{}. Time: {:.2}s total, {}ms for last 50. Matrices: A={}, witnesses={}, vars={}",
line_count, constr_code.lines().count(), total_elapsed, last_elapsed,
a_matrix.len(), witnesses.len(), var_map.len());
last_time = now;
}
if line.contains("decl") {
let (array_name, witness_names, visibility) =
add_decl_to_env(&mut env_dict, line)?;
if matches!(
visibility,
VariableVisibility::Public | VariableVisibility::Deferred
) {
for name in &witness_names {
public_variables.insert(name.clone());
}
}
if let Some(arr_name) = array_name {
for (i, name) in witness_names.iter().enumerate() {
witnesses.push(name.clone());
dws.insert(name.clone(), pos);
pos += 1;
match visibility {
VariableVisibility::Private => {
let expr_id = dag.add(Expression::Private(name.clone()));
var_map.insert(name.clone(), expr_id);
}
VariableVisibility::Public => {
let expr_id = dag.add(Expression::Public(name.clone()));
var_map.insert(name.clone(), expr_id);
}
VariableVisibility::Deferred => {
let expr_id = dag.add(Expression::Deferred(name.clone()));
var_map.insert(name.clone(), expr_id);
}
}
}
} else {
let witness_name = witness_names[0].clone();
witnesses.push(witness_name.clone());
dws.insert(witness_name.clone(), pos);
pos += 1;
match visibility {
VariableVisibility::Private => {
let expr_id = dag.add(Expression::Private(witness_name.clone()));
var_map.insert(witness_name.clone(), expr_id);
}
VariableVisibility::Public => {
let expr_id = dag.add(Expression::Public(witness_name.clone()));
var_map.insert(witness_name.clone(), expr_id);
}
VariableVisibility::Deferred => {
let expr_id = dag.add(Expression::Deferred(witness_name.clone()));
var_map.insert(witness_name.clone(), expr_id);
}
}
}
} else if line.contains("==") {
let parts: Vec<&str> = line.split("==").collect();
if parts.len() != 2 {
eprintln!("Warning: malformed equality constraint: {}", line);
continue;
}
let left_expr = parts[0].trim();
let right_expr = parts[1].trim();
let c = left_expr.to_string();
let equality_witness = format!("__eq_{}__", c);
if !dws.contains_key(&equality_witness) {
witnesses.push(equality_witness.clone());
dws.insert(equality_witness.clone(), pos);
pos += 1;
}
let left_id = build_symbolic_expression(left_expr, &mut dag, &var_map);
let right_id = build_symbolic_expression(right_expr, &mut dag, &var_map);
let diff_id = dag.add(Expression::Sub(left_id, right_id));
if verbose {
eprintln!("DEBUG: Equality constraint {} == {}", left_expr, right_expr);
eprintln!(" left_id: {:?}", dag.get(left_id));
eprintln!(" right_id: {:?}", dag.get(right_id));
eprintln!(" diff: {:?}", dag.get(diff_id));
}
let one_const_id = dag.add(Expression::Constant(OrderedFloat(1.0)));
let zero_const_id = dag.add(Expression::Constant(OrderedFloat(0.0)));
let (a_row, b_row, c_row) = process_constraint_row(
&dag,
diff_id,
one_const_id,
zero_const_id,
n_vars,
&dws,
&public_variables,
)?;
a_matrix.push(a_row);
b_matrix.push(b_row);
c_matrix.push(c_row);
let zero_id = dag.add(Expression::Constant(OrderedFloat(0.0)));
var_map.insert(c, zero_id);
} else {
let (c, expr) = parse_constraint(line)?;
if expr.contains('*') {
if !dws.contains_key(&c) {
witnesses.push(c.clone());
dws.insert(c.clone(), pos);
pos += 1;
}
let parts: Vec<&str> = expr.split('*').collect();
if parts.len() == 2 {
let operand1 = parts[0].trim();
let operand2 = parts[1].trim();
let expr1_id = if dws.contains_key(operand1) {
dag.add(Expression::Deferred(operand1.to_string()))
} else {
build_symbolic_expression(operand1, &mut dag, &var_map)
};
let expr2_id = if dws.contains_key(operand2) {
dag.add(Expression::Deferred(operand2.to_string()))
} else {
build_symbolic_expression(operand2, &mut dag, &var_map)
};
if verbose {
eprintln!(
"DEBUG: Multiplication constraint {} = {} * {}",
c,
parts[0].trim(),
parts[1].trim()
);
eprintln!(" expr1: {:?}", dag.get(expr1_id));
eprintln!(" expr2: {:?}", dag.get(expr2_id));
}
let c_witness = dag.add(Expression::Private(c.clone()));
var_map.insert(c.clone(), c_witness);
let result_id = dag.add(Expression::Mul(expr1_id, expr2_id));
witness_definitions.insert(c.clone(), result_id);
let c_witness_ref = dag.add(Expression::Deferred(c.clone()));
let (a_row, b_row, c_row) = process_constraint_row(
&dag,
expr1_id,
expr2_id,
c_witness_ref,
n_vars,
&dws,
&public_variables,
)?;
a_matrix.push(a_row);
b_matrix.push(b_row);
c_matrix.push(c_row);
} else {
let c_witness = dag.add(Expression::Private(c.clone()));
var_map.insert(c.clone(), c_witness);
let expr_id = build_symbolic_expression(&expr, &mut dag, &var_map);
witness_definitions.insert(c, expr_id);
}
} else {
if !dws.contains_key(&c) {
witnesses.push(c.clone());
dws.insert(c.clone(), pos);
pos += 1;
}
let witness_id = dag.add(Expression::Private(c.clone()));
let expr_id = build_symbolic_expression(&expr, &mut dag, &var_map);
var_map.insert(c.clone(), witness_id);
witness_definitions.insert(c.clone(), expr_id);
}
}
}
let mut witness_ids = vec![one_id; n_vars];
for (name, &def_id) in &witness_definitions {
if let Some(&pos) = dws.get(name) {
witness_ids[pos] = def_id;
}
}
for (name, &id) in &var_map {
if let Some(&pos) = dws.get(name) {
if witness_ids[pos] == one_id && !witness_definitions.contains_key(name) {
witness_ids[pos] = id;
}
}
}
let mut has_public_b = false;
for (i, &witness_id) in witness_ids.iter().enumerate() {
if i >= b_matrix[0].len() {
break;
}
let witness_expr = dag.get(witness_id);
if matches!(witness_expr, Expression::Deferred(_) | Expression::Public(_)) {
for b_row in &b_matrix {
if i < b_row.len() && b_row[i] != 0.0 {
has_public_b = true;
break;
}
}
}
if has_public_b {
break;
}
}
if !has_public_b && !public_variables.is_empty() {
if verbose {
eprintln!("WARNING: No public/deferred variables in B position detected.");
eprintln!("Adding dummy constraint to ensure non-trivial pairing.");
}
let first_public = public_variables.iter().next().cloned();
if !public_variables.is_empty() {
let mut dummy_a_row = vec![0.0; n_vars];
let mut dummy_b_row = vec![0.0; n_vars];
let mut dummy_c_row = vec![0.0; n_vars];
let first_public = public_variables.iter()
.find(|&name| name != "1")
.or_else(|| public_variables.iter().next())
.cloned();
if let Some(first_public_name) = first_public {
if let Some(&pos) = dws.get(&first_public_name) {
dummy_a_row[pos] = 1.0;
dummy_c_row[pos] = 1.0; }
}
dummy_b_row[0] = 1.0;
a_matrix.push(dummy_a_row);
b_matrix.push(dummy_b_row);
c_matrix.push(dummy_c_row);
if verbose {
eprintln!("Added dummy constraint with public '1' in B position");
eprintln!("This ensures b2 will always be at least 1 (non-zero)");
}
}
}
Ok(CompilationResult {
a_matrix,
b_matrix,
c_matrix,
dag,
witness_ids,
env_dict,
witnesses,
public_variables,
})
}
/// Counts the witness columns a constraint program needs, including column 0 for the
/// constant `"1"`; `compile_constraints` uses the result as the width of every matrix row.
///
/// After stripping `#` comments, each line contributes:
/// - `decl <vis> array name[N]` — `N` columns (0 if `N` cannot be parsed);
/// - any other four-token `decl` line — 1 column;
/// - any other line containing `=` (assignment or `==` equality) — 1 column.
///
/// The bracket search looks for `]` *after* `[`, so malformed names such as `]x[2` no longer
/// panic on an inverted slice range.
///
/// # Errors
///
/// Never fails at present; the `Result` is kept for interface compatibility.
fn quick_lex(constr_code: &str) -> Result<usize, ZkError> {
    let mut n_vars = 1;
    for line in constr_code.lines() {
        let line = line.split('#').next().unwrap_or("").trim();
        if line.is_empty() {
            continue;
        }
        let parts: Vec<&str> = line.split_whitespace().collect();
        match parts.as_slice() {
            ["decl", _, "array", value_name] => {
                let size = value_name.find('[').and_then(|open| {
                    let inner = &value_name[open + 1..];
                    inner
                        .find(']')
                        .and_then(|close| inner[..close].parse::<usize>().ok())
                });
                n_vars += size.unwrap_or(0);
            }
            ["decl", _, _, _] => n_vars += 1,
            // `==` lines also contain '=', so equalities and assignments share this arm.
            _ if line.contains('=') => n_vars += 1,
            _ => {}
        }
    }
    Ok(n_vars)
}
/// Resolves a single (whitespace-trimmed) term to a DAG node.
///
/// A known variable returns its existing node without touching the DAG; otherwise a numeric
/// literal becomes a `Constant` node and anything else a `Deferred` placeholder named after
/// the trimmed text.
fn parse_term(term: &str, dag: &mut ExpressionDAG, var_map: &HashMap<String, ExprId>) -> ExprId {
    let name = term.trim();
    if let Some(&known) = var_map.get(name) {
        return known;
    }
    let leaf = match name.parse::<f64>() {
        Ok(value) => Expression::Constant(OrderedFloat(value)),
        Err(_) => Expression::Deferred(name.to_string()),
    };
    dag.add(leaf)
}
/// Lowers an expression string into DAG nodes using a deliberately small grammar.
///
/// - If `expr` contains exactly one `*`, the result is the product of the two sides, each
///   resolved with [`parse_term`]. When both factors can be evaluated, the product is folded
///   into a constant. Because `*` is checked first, `a * b + c` yields `a` times the single
///   opaque term `b + c`.
/// - Otherwise `expr` is a sum of terms separated by `+`/`-`, folded left to right
///   (`a - b + c` is `(a - b) + c`). A sign at the start of the expression or right after
///   another operator is unary, and a leading `-` is encoded as `0 - term`. Signs directly
///   after a numeric mantissa's `e`/`E` stay inside the literal (`1e-5`).
/// - A lone term is resolved with [`parse_term`].
///
/// Every `+`/`-` separated term is kept as its own node. Splitting only at the first
/// operator kind would turn the remainder into an opaque `Deferred("b - c")` leaf, which
/// matches no witness column and is silently dropped from constraint rows.
fn build_symbolic_expression(
    expr: &str,
    dag: &mut ExpressionDAG,
    var_map: &HashMap<String, ExprId>,
) -> ExprId {
    if expr.contains('*') {
        let parts: Vec<&str> = expr.split('*').collect();
        if parts.len() == 2 {
            let left_id = parse_term(parts[0], dag, var_map);
            let right_id = parse_term(parts[1], dag, var_map);
            if dag.can_evaluate(left_id) && dag.can_evaluate(right_id) {
                let result = dag.evaluate(left_id) * dag.evaluate(right_id);
                return dag.add(Expression::Constant(OrderedFloat(result)));
            }
            return dag.add(Expression::Mul(left_id, right_id));
        }
    }

    let mut acc: Option<ExprId> = None;
    for (negated, text) in split_signed_terms(expr) {
        acc = Some(match acc {
            // A leading minus is expressed as `0 - term`.
            None if negated => {
                let zero = dag.add(Expression::Constant(OrderedFloat(0.0)));
                let term = parse_term(text, dag, var_map);
                dag.add(Expression::Sub(zero, term))
            }
            None => parse_term(text, dag, var_map),
            Some(prev) => {
                let term = parse_term(text, dag, var_map);
                let node = if negated {
                    Expression::Sub(prev, term)
                } else {
                    Expression::Add(prev, term)
                };
                dag.add(node)
            }
        });
    }
    // `split_signed_terms` always yields at least one term, so the fallback is unreachable
    // in practice; it mirrors the single-term case.
    acc.unwrap_or_else(|| parse_term(expr, dag, var_map))
}

/// Splits `expr` at binary `+`/`-` operators into `(negated, term_text)` pairs.
///
/// Always returns at least one pair. Term text is not trimmed.
fn split_signed_terms(expr: &str) -> Vec<(bool, &str)> {
    let mut terms = Vec::new();
    let mut negated = false;
    let mut term_start = 0;
    // True until a non-whitespace character of the current term is seen; a sign seen while
    // this holds is unary rather than a binary operator.
    let mut expecting_term = true;
    for (idx, ch) in expr.char_indices() {
        match ch {
            '+' | '-' if expecting_term => {
                if ch == '-' {
                    negated = !negated;
                }
                term_start = idx + 1;
            }
            '+' | '-' => {
                if is_exponent_prefix(&expr[term_start..idx]) {
                    // Part of a literal such as `1e-5`, not an operator.
                    continue;
                }
                terms.push((negated, &expr[term_start..idx]));
                negated = ch == '-';
                term_start = idx + 1;
                expecting_term = true;
            }
            c if c.is_whitespace() => {}
            _ => expecting_term = false,
        }
    }
    terms.push((negated, &expr[term_start..]));
    terms
}

/// Returns `true` when `text` (ignoring leading whitespace) is a number immediately followed
/// by an exponent marker, e.g. `"1e"` or `" 2.5E"`.
fn is_exponent_prefix(text: &str) -> bool {
    match text
        .trim_start()
        .strip_suffix(|c: char| c == 'e' || c == 'E')
    {
        Some(mantissa) => !mantissa.is_empty() && mantissa.parse::<f64>().is_ok(),
        None => false,
    }
}
/// Parses a `decl <visibility> <type> <name>` line into its witness names.
///
/// `<visibility>` must be `private`, `public` or `deferred`. For `<type>` equal to `array`,
/// `<name>` must have the form `base[N]`. The result is `(Some(base), ["base[0]", …,
/// "base[N-1]"], visibility)`. Any other type yields `(None, [name], visibility)`.
///
/// `_env` is accepted for interface compatibility and is not modified.
///
/// # Errors
///
/// Returns `ZkError::InvalidParameters` when the line does not have exactly four tokens, the
/// visibility is unknown, or an array name is not exactly `base[N]` with a `usize` `N`.
/// Previously `x[` panicked on an inverted slice, and any final character was dropped as if
/// it were `]` (so `x[3)` was accepted).
fn add_decl_to_env(
    _env: &mut HashMap<String, serde_json::Value>,
    decl: &str,
) -> Result<(Option<String>, Vec<String>, VariableVisibility), ZkError> {
    let parts: Vec<&str> = decl.split_whitespace().collect();
    let (visibility_str, value_type, value_name) = match parts.as_slice() {
        [_, visibility_str, value_type, value_name] => (*visibility_str, *value_type, *value_name),
        _ => return Err(ZkError::InvalidParameters),
    };
    let visibility = match visibility_str {
        "private" => VariableVisibility::Private,
        "public" => VariableVisibility::Public,
        "deferred" => VariableVisibility::Deferred,
        _ => return Err(ZkError::InvalidParameters),
    };
    if value_type != "array" {
        return Ok((None, vec![value_name.to_string()], visibility));
    }
    let (array_name, rest) = value_name
        .split_once('[')
        .ok_or(ZkError::InvalidParameters)?;
    let size = rest
        .strip_suffix(']')
        .ok_or(ZkError::InvalidParameters)?
        .parse::<usize>()
        .map_err(|_| ZkError::InvalidParameters)?;
    let witness_names: Vec<String> = (0..size)
        .map(|i| format!("{}[{}]", array_name, i))
        .collect();
    Ok((Some(array_name.to_string()), witness_names, visibility))
}
/// Splits a constraint line into its trimmed left- and right-hand sides.
///
/// If the line contains `==`, it is split at the first `==` and both characters of the
/// operator are dropped. Previously the right side kept the `==` prefix. Otherwise it is
/// split at the first `=`.
///
/// # Errors
///
/// Returns `ZkError::InvalidParameters` if the line contains no `=`.
fn parse_constraint(line: &str) -> Result<(String, String), ZkError> {
    let (left_part, right_part) = line
        .split_once("==")
        .or_else(|| line.split_once('='))
        .ok_or(ZkError::InvalidParameters)?;
    Ok((left_part.trim().to_string(), right_part.trim().to_string()))
}
/// Lowers the three linear expressions of one constraint into `(a_row, b_row, c_row)`.
///
/// Each row has `n_vars` entries and is filled by `fill_row_from_dag`, in the order
/// `expr1_id`, `expr2_id`, `expr3_id`. The first error encountered is returned.
fn process_constraint_row(
    dag: &ExpressionDAG,
    expr1_id: ExprId,
    expr2_id: ExprId,
    expr3_id: ExprId,
    n_vars: usize,
    dws: &HashMap<String, usize>,
    public_variables: &HashSet<String>,
) -> Result<ConstraintRow, ZkError> {
    let lower = |expr_id: ExprId| -> Result<Vec<f64>, ZkError> {
        let mut row = vec![0.0; n_vars];
        fill_row_from_dag(dag, expr_id, &mut row, dws, public_variables)?;
        Ok(row)
    };
    let a_row = lower(expr1_id)?;
    let b_row = lower(expr2_id)?;
    let c_row = lower(expr3_id)?;
    Ok((a_row, b_row, c_row))
}
/// Adds the linear coefficients of the DAG expression `expr_id` into `row`.
///
/// Constants contribute their value to column 0, the constant-one witness. `Private`,
/// `Public` and `Deferred` leaves contribute `1` to their column in `dws`; a leaf with no
/// column is skipped. `Add`/`Sub` nodes combine their children.
///
/// Coefficients are accumulated in place with a running sign, and callers pass a zeroed
/// row. This avoids allocating and summing a full scratch row for every `Add`/`Sub` node,
/// so the cost is linear in the number of visited nodes rather than nodes × `row.len()`.
///
/// `_public_variables` is currently unused.
///
/// # Errors
///
/// Returns `ZkError::InvalidParameters` if the expression contains a `Mul` node.
///
/// # Panics
///
/// Panics if `row` is empty while a constant is present, or if a leaf's column from `dws`
/// is out of range for `row`.
fn fill_row_from_dag(
    dag: &ExpressionDAG,
    expr_id: ExprId,
    row: &mut [f64],
    dws: &HashMap<String, usize>,
    _public_variables: &HashSet<String>,
) -> Result<(), ZkError> {
    accumulate_linear_row(dag, expr_id, 1.0, row, dws)
}

/// Recursive worker for `fill_row_from_dag`: adds `sign` times the coefficients of
/// `expr_id` into `row`.
fn accumulate_linear_row(
    dag: &ExpressionDAG,
    expr_id: ExprId,
    sign: f64,
    row: &mut [f64],
    dws: &HashMap<String, usize>,
) -> Result<(), ZkError> {
    match dag.get(expr_id) {
        Expression::Constant(val) => row[0] += sign * val.0,
        Expression::Private(name) | Expression::Public(name) | Expression::Deferred(name) => {
            if let Some(&pos) = dws.get(name) {
                row[pos] += sign;
            }
        }
        Expression::Add(left_id, right_id) => {
            accumulate_linear_row(dag, *left_id, sign, row, dws)?;
            accumulate_linear_row(dag, *right_id, sign, row, dws)?;
        }
        Expression::Sub(left_id, right_id) => {
            accumulate_linear_row(dag, *left_id, sign, row, dws)?;
            // The subtrahend's coefficients enter with the opposite sign.
            accumulate_linear_row(dag, *right_id, -sign, row, dws)?;
        }
        // A product is not linear and cannot be expressed as a single row.
        Expression::Mul(_, _) => return Err(ZkError::InvalidParameters),
    }
    Ok(())
}