zyga 0.5.1

ZYGA zero-knowledge proof system - CLI and library for generating ZK proofs
Documentation
use ordered_float::OrderedFloat;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use crate::debug_println;

/// Index of a node in [`ExpressionDAG::nodes`]; operands of binary
/// [`Expression`] variants refer to other nodes by this id.
pub type ExprId = usize;

/// Arena of [`Expression`] nodes with structural deduplication.
///
/// Nodes are appended in insertion order and referenced by their index
/// ([`ExprId`]). [`ExpressionDAG::add`] consults `dedup_map` so that
/// structurally identical nodes share a single id, which lets subexpressions
/// be shared (making the structure a DAG rather than a tree).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExpressionDAG {
    /// All nodes; `nodes[id]` is the node with that id. Operand ids stored in
    /// `Add`/`Sub`/`Mul` index into this same vector.
    pub nodes: Vec<Expression>,
    /// Reverse index from a node to its id, used by `add` to deduplicate.
    // NOTE(review): `nodes` is `pub`, so pushing to it directly bypasses this
    // map and can desynchronize the two. Also, this map is keyed by a
    // data-carrying enum, which formats requiring string map keys (e.g.
    // serde_json) cannot serialize — confirm which serde format is used.
    dedup_map: HashMap<Expression, ExprId>,
}

/// A single node in an [`ExpressionDAG`].
///
/// Leaf variants name a symbol or hold a literal; the binary variants refer to
/// their left and right operands by [`ExprId`] within the owning DAG.
#[derive(Debug, Clone, Hash, PartialEq, Eq, Serialize, Deserialize)]
pub enum Expression {
    /// Private witness, known only to the prover.
    Private(String),
    /// Public variable, known to all but still symbolic during compilation.
    Public(String),
    /// Public input/output that varies per proof.
    Deferred(String),
    /// Literal constant from constraints (0, 1, 2, etc.). Wrapped in
    /// `OrderedFloat` because the derived `Eq`/`Hash` (needed for
    /// deduplication) are not available on a bare `f64`.
    Constant(OrderedFloat<f64>),
    /// `left + right`, operands given as ids in the owning DAG.
    Add(ExprId, ExprId),
    /// `left - right`, operands given as ids in the owning DAG.
    Sub(ExprId, ExprId),
    /// `left * right`, operands given as ids in the owning DAG.
    Mul(ExprId, ExprId),
}

impl ExpressionDAG {
    pub fn new() -> Self {
        ExpressionDAG {
            nodes: Vec::new(),
            dedup_map: HashMap::new(),
        }
    }

    pub fn add(&mut self, expr: Expression) -> ExprId {
        // Check for existing identical expression
        if let Some(&id) = self.dedup_map.get(&expr) {
            return id; // Reuse existing
        }

        // Add new
        let id = self.nodes.len();
        self.nodes.push(expr.clone());
        self.dedup_map.insert(expr, id);
        id
    }

    pub fn get(&self, id: ExprId) -> &Expression {
        &self.nodes[id]
    }

    pub fn can_evaluate(&self, id: ExprId) -> bool {
        match &self.nodes[id] {
            Expression::Constant(_) => true,
            Expression::Private(_) | Expression::Public(_) | Expression::Deferred(_) => false,
            Expression::Add(l, r) | Expression::Sub(l, r) | Expression::Mul(l, r) => {
                self.can_evaluate(*l) && self.can_evaluate(*r)
            }
        }
    }

    pub fn evaluate(&self, id: ExprId) -> f64 {
        match &self.nodes[id] {
            Expression::Constant(v) => v.0,
            Expression::Private(_) => panic!("Cannot evaluate private expression without witness"),
            Expression::Public(_) => panic!("Cannot evaluate public expression without witness"),
            Expression::Deferred(_) => panic!("Cannot evaluate deferred expression without witness"),
            Expression::Add(l, r) => self.evaluate(*l) + self.evaluate(*r),
            Expression::Sub(l, r) => self.evaluate(*l) - self.evaluate(*r),
            Expression::Mul(l, r) => self.evaluate(*l) * self.evaluate(*r),
        }
    }

    pub fn evaluate_with_env(&self, id: ExprId, env: &HashMap<String, f64>) -> Result<f64, String> {
        match &self.nodes[id] {
            Expression::Constant(v) => Ok(v.0),
            Expression::Private(name) | Expression::Public(name) | Expression::Deferred(name) => {
                env
                    .get(name)
                    .copied()
                    .ok_or_else(|| format!("Unknown variable: {}", name))
            }
            Expression::Add(l, r) => {
                let l_val = self.evaluate_with_env(*l, env)?;
                let r_val = self.evaluate_with_env(*r, env)?;
                Ok(l_val + r_val)
            }
            Expression::Sub(l, r) => {
                let l_val = self.evaluate_with_env(*l, env)?;
                let r_val = self.evaluate_with_env(*r, env)?;
                Ok(l_val - r_val)
            }
            Expression::Mul(l, r) => {
                let l_val = self.evaluate_with_env(*l, env)?;
                let r_val = self.evaluate_with_env(*r, env)?;
                Ok(l_val * r_val)
            }
        }
    }

    pub fn contains_deferred(&self, id: ExprId) -> bool {
        match &self.nodes[id] {
            Expression::Private(_) | Expression::Public(_) | Expression::Deferred(_) => true,
            Expression::Constant(_) => false,
            Expression::Add(l, r) | Expression::Sub(l, r) | Expression::Mul(l, r) => {
                self.contains_deferred(*l) || self.contains_deferred(*r)
            }
        }
    }

    pub fn is_zero(&self, id: ExprId) -> bool {
        match &self.nodes[id] {
            Expression::Constant(v) => v.0 == 0.0,
            Expression::Private(s) | Expression::Public(s) | Expression::Deferred(s) => s == "0",
            _ => false,
        }
    }

    pub fn to_string(&self, id: ExprId) -> String {
        match &self.nodes[id] {
            Expression::Constant(v) => {
                if v.0.fract() == 0.0 {
                    format!("{:.0}", v.0)
                } else {
                    format!("{}", v.0)
                }
            }
            Expression::Private(s) | Expression::Public(s) | Expression::Deferred(s) => s.clone(),
            Expression::Add(l, r) => format!("({} + {})", self.to_string(*l), self.to_string(*r)),
            Expression::Sub(l, r) => format!("({} - {})", self.to_string(*l), self.to_string(*r)),
            Expression::Mul(l, r) => format!("({} * {})", self.to_string(*l), self.to_string(*r)),
        }
    }

    pub fn collect_public_inputs(&self, id: ExprId, public_inputs: &mut HashMap<String, f64>) {
        match &self.nodes[id] {
            Expression::Deferred(name) => {
                // Track deferred variables as they're public at verification
                public_inputs.entry(name.clone()).or_insert(0.0);
            }
            Expression::Add(l, r) | Expression::Sub(l, r) | Expression::Mul(l, r) => {
                self.collect_public_inputs(*l, public_inputs);
                self.collect_public_inputs(*r, public_inputs);
            }
            _ => {}
        }
    }

    /// Extend witness by computing intermediate variables from input witness
    pub fn extend_witness(&self, witness_ids: &[ExprId], witness_names: &[String], input_witness: &HashMap<String, f64>) -> Result<HashMap<String, f64>, String> {
        let mut extended_witness = input_witness.clone();

        // Add constant "1" to the witness (used for padding)
        extended_witness.insert("1".to_string(), 1.0);

        let mut changed = true;
        let max_iterations = 100; // Prevent infinite loops
        let mut iteration = 0;

        while changed && iteration < max_iterations {
            changed = false;
            iteration += 1;

            // Try to compute each witness variable
            // witness_ids and witness_names should be the same length
            for (i, &expr_id) in witness_ids.iter().enumerate() {
                // Check bounds
                if i >= witness_names.len() {
                    break;
                }

                let var_name = &witness_names[i];

                // Skip if we already have this variable
                if extended_witness.contains_key(var_name) {
                    continue;
                }

                // Try to evaluate this expression with current witness
                match self.evaluate_with_env(expr_id, &extended_witness) {
                    Ok(value) => {
                        extended_witness.insert(var_name.clone(), value);
                        changed = true;

                        // Debug specific variables
                        if var_name == "no_borrow" {
                            debug_println!("Computing no_borrow: expr_id={}, expr={:?}", expr_id, &self.nodes[expr_id]);
                            // Check if this is a Sub expression
                            if let Expression::Sub(left, right) = &self.nodes[expr_id] {
                                debug_println!("  Left expr ({}): {:?}", left, &self.nodes[*left]);
                                debug_println!("  Right expr ({}): {:?}", right, &self.nodes[*right]);
                                let left_val = self.evaluate_with_env(*left, &extended_witness).unwrap_or(-999.0);
                                let right_val = self.evaluate_with_env(*right, &extended_witness).unwrap_or(-999.0);
                                debug_println!("  Left value: {}, Right value: {}", left_val, right_val);
                                debug_println!("  Expected: {} - {} = {}", left_val, right_val, left_val - right_val);
                            }
                        }

                        debug_println!("Computed intermediate variable {} = {}", var_name, value);
                    }
                    Err(_) => {
                        // Can't compute yet - might depend on other intermediates
                    }
                }
            }
        }

        // Check that all witness variables were computed
        for var_name in witness_names {
            if !extended_witness.contains_key(var_name) {
                // Skip the special "1" constant name
                if var_name == "1" {
                    extended_witness.insert("1".to_string(), 1.0);
                    continue;
                }
                return Err(format!("Could not compute witness variable: {}", var_name));
            }
        }

        Ok(extended_witness)
    }
}

impl Default for ExpressionDAG {
    fn default() -> Self {
        Self::new()
    }
}