use ordered_float::OrderedFloat;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use crate::debug_println;
/// Identifier of a node in an [`ExpressionDAG`]: its index into `ExpressionDAG::nodes`.
pub type ExprId = usize;
/// A store of [`Expression`] nodes in which operands are referenced by
/// [`ExprId`] (an index into `nodes`).
///
/// Expressions inserted through [`ExpressionDAG::add`] are deduplicated:
/// adding a structurally equal expression again returns the existing id, so
/// common sub-expressions share a single node.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExpressionDAG {
/// All nodes, indexed by `ExprId`.
/// NOTE(review): this field is `pub`; code that pushes or edits nodes
/// directly bypasses `dedup_map`, so later `add` calls may return stale or
/// duplicate ids — confirm no caller mutates it directly.
pub nodes: Vec<Expression>,
/// Reverse index from an expression to the id it was stored under; used by
/// `add` to return existing ids.
/// NOTE(review): the key type is an enum, not a string; self-describing serde
/// formats such as JSON reject non-string map keys — confirm which format
/// this struct is serialized with.
dedup_map: HashMap<Expression, ExprId>,
}
/// A single node of an [`ExpressionDAG`].
///
/// Leaves are either a `Constant` or a named variable (`Private`, `Public`,
/// `Deferred`); `Add`, `Sub` (left minus right) and `Mul` combine two operand
/// nodes referenced by [`ExprId`]. Constants are wrapped in `OrderedFloat`
/// because plain `f64` implements neither `Eq` nor `Hash`, and the enum is
/// used as a `HashMap` key for deduplication.
///
/// Within this file the three variable kinds are evaluated identically by
/// name lookup; only `collect_public_inputs` and the panic messages of
/// `evaluate` tell them apart. NOTE(review): their wider meaning is defined by
/// callers outside this file — confirm before relying on it.
/// NOTE(review): serde passes variant indices to formats, and index-based
/// formats depend on variant order, so do not reorder the variants.
#[derive(Debug, Clone, Hash, PartialEq, Eq, Serialize, Deserialize)]
pub enum Expression {
Private(String), Public(String), Deferred(String), Constant(OrderedFloat<f64>), Add(ExprId, ExprId),
Sub(ExprId, ExprId),
Mul(ExprId, ExprId),
}
impl ExpressionDAG {
pub fn new() -> Self {
ExpressionDAG {
nodes: Vec::new(),
dedup_map: HashMap::new(),
}
}
pub fn add(&mut self, expr: Expression) -> ExprId {
if let Some(&id) = self.dedup_map.get(&expr) {
return id; }
let id = self.nodes.len();
self.nodes.push(expr.clone());
self.dedup_map.insert(expr, id);
id
}
pub fn get(&self, id: ExprId) -> &Expression {
&self.nodes[id]
}
pub fn can_evaluate(&self, id: ExprId) -> bool {
match &self.nodes[id] {
Expression::Constant(_) => true,
Expression::Private(_) | Expression::Public(_) | Expression::Deferred(_) => false,
Expression::Add(l, r) | Expression::Sub(l, r) | Expression::Mul(l, r) => {
self.can_evaluate(*l) && self.can_evaluate(*r)
}
}
}
pub fn evaluate(&self, id: ExprId) -> f64 {
match &self.nodes[id] {
Expression::Constant(v) => v.0,
Expression::Private(_) => panic!("Cannot evaluate private expression without witness"),
Expression::Public(_) => panic!("Cannot evaluate public expression without witness"),
Expression::Deferred(_) => panic!("Cannot evaluate deferred expression without witness"),
Expression::Add(l, r) => self.evaluate(*l) + self.evaluate(*r),
Expression::Sub(l, r) => self.evaluate(*l) - self.evaluate(*r),
Expression::Mul(l, r) => self.evaluate(*l) * self.evaluate(*r),
}
}
pub fn evaluate_with_env(&self, id: ExprId, env: &HashMap<String, f64>) -> Result<f64, String> {
match &self.nodes[id] {
Expression::Constant(v) => Ok(v.0),
Expression::Private(name) | Expression::Public(name) | Expression::Deferred(name) => {
env
.get(name)
.copied()
.ok_or_else(|| format!("Unknown variable: {}", name))
}
Expression::Add(l, r) => {
let l_val = self.evaluate_with_env(*l, env)?;
let r_val = self.evaluate_with_env(*r, env)?;
Ok(l_val + r_val)
}
Expression::Sub(l, r) => {
let l_val = self.evaluate_with_env(*l, env)?;
let r_val = self.evaluate_with_env(*r, env)?;
Ok(l_val - r_val)
}
Expression::Mul(l, r) => {
let l_val = self.evaluate_with_env(*l, env)?;
let r_val = self.evaluate_with_env(*r, env)?;
Ok(l_val * r_val)
}
}
}
pub fn contains_deferred(&self, id: ExprId) -> bool {
match &self.nodes[id] {
Expression::Private(_) | Expression::Public(_) | Expression::Deferred(_) => true,
Expression::Constant(_) => false,
Expression::Add(l, r) | Expression::Sub(l, r) | Expression::Mul(l, r) => {
self.contains_deferred(*l) || self.contains_deferred(*r)
}
}
}
pub fn is_zero(&self, id: ExprId) -> bool {
match &self.nodes[id] {
Expression::Constant(v) => v.0 == 0.0,
Expression::Private(s) | Expression::Public(s) | Expression::Deferred(s) => s == "0",
_ => false,
}
}
pub fn to_string(&self, id: ExprId) -> String {
match &self.nodes[id] {
Expression::Constant(v) => {
if v.0.fract() == 0.0 {
format!("{:.0}", v.0)
} else {
format!("{}", v.0)
}
}
Expression::Private(s) | Expression::Public(s) | Expression::Deferred(s) => s.clone(),
Expression::Add(l, r) => format!("({} + {})", self.to_string(*l), self.to_string(*r)),
Expression::Sub(l, r) => format!("({} - {})", self.to_string(*l), self.to_string(*r)),
Expression::Mul(l, r) => format!("({} * {})", self.to_string(*l), self.to_string(*r)),
}
}
pub fn collect_public_inputs(&self, id: ExprId, public_inputs: &mut HashMap<String, f64>) {
match &self.nodes[id] {
Expression::Deferred(name) => {
public_inputs.entry(name.clone()).or_insert(0.0);
}
Expression::Add(l, r) | Expression::Sub(l, r) | Expression::Mul(l, r) => {
self.collect_public_inputs(*l, public_inputs);
self.collect_public_inputs(*r, public_inputs);
}
_ => {}
}
}
pub fn extend_witness(&self, witness_ids: &[ExprId], witness_names: &[String], input_witness: &HashMap<String, f64>) -> Result<HashMap<String, f64>, String> {
let mut extended_witness = input_witness.clone();
extended_witness.insert("1".to_string(), 1.0);
let mut changed = true;
let max_iterations = 100; let mut iteration = 0;
while changed && iteration < max_iterations {
changed = false;
iteration += 1;
for (i, &expr_id) in witness_ids.iter().enumerate() {
if i >= witness_names.len() {
break;
}
let var_name = &witness_names[i];
if extended_witness.contains_key(var_name) {
continue;
}
match self.evaluate_with_env(expr_id, &extended_witness) {
Ok(value) => {
extended_witness.insert(var_name.clone(), value);
changed = true;
if var_name == "no_borrow" {
debug_println!("Computing no_borrow: expr_id={}, expr={:?}", expr_id, &self.nodes[expr_id]);
if let Expression::Sub(left, right) = &self.nodes[expr_id] {
debug_println!(" Left expr ({}): {:?}", left, &self.nodes[*left]);
debug_println!(" Right expr ({}): {:?}", right, &self.nodes[*right]);
let left_val = self.evaluate_with_env(*left, &extended_witness).unwrap_or(-999.0);
let right_val = self.evaluate_with_env(*right, &extended_witness).unwrap_or(-999.0);
debug_println!(" Left value: {}, Right value: {}", left_val, right_val);
debug_println!(" Expected: {} - {} = {}", left_val, right_val, left_val - right_val);
}
}
debug_println!("Computed intermediate variable {} = {}", var_name, value);
}
Err(_) => {
}
}
}
}
for var_name in witness_names {
if !extended_witness.contains_key(var_name) {
if var_name == "1" {
extended_witness.insert("1".to_string(), 1.0);
continue;
}
return Err(format!("Could not compute witness variable: {}", var_name));
}
}
Ok(extended_witness)
}
}
impl Default for ExpressionDAG {
fn default() -> Self {
Self::new()
}
}