use crate::{
core::{
bounds::{Bounds, IsBounds},
compile::Compiler,
expressions::{
bit_expr::BitExpr,
conversion_expr::ConversionExpr,
curve_expr::CurveExpr,
domain::DomainElement,
expr::{EvalFailure, EvalValue, Expr, UndefinedBehavior},
field_expr::FieldExpr,
macro_uses::{BoundUnFold, EvalValueUnwrap},
},
instruction::ArcisInstruction,
tracking::Tracking,
},
utils::{
field::{BaseField, ScalarField},
number::Number,
used_field::UsedField,
},
};
use ff::Field;
use indexmap::IndexSet;
use rand::Rng;
use rustc_hash::FxHashMap;
use serde::{Deserialize, Serialize};
use std::{env, fmt, fmt::Formatter, io::Write};
/// A program as a flat list of expressions that refer to each other by index,
/// together with the program outputs and per-expression metadata.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct IntermediateRepresentation {
    /// The expressions; every `usize` operand is an index into this vector.
    exprs: Vec<Expr<usize>>,
    /// Indices into `exprs` of the values the program returns.
    outputs: Vec<usize>,
    /// Bounds of each expression; same length as `exprs`.
    bounds: Vec<Bounds>,
    /// Per-expression `is_plaintext` flag; same length as `exprs`.
    is_plaintext: Vec<bool>,
    /// Tracking information carried alongside the IR.
    tracking: Tracking,
}
impl IntermediateRepresentation {
    /// Creates an IR with a default [`Tracking`].
    ///
    /// `outputs` are indices into `exprs`; `bounds` and `is_plaintext` are expected to
    /// have one entry per expression.
    pub fn new(
        exprs: Vec<Expr<usize>>,
        outputs: Vec<usize>,
        bounds: Vec<Bounds>,
        is_plaintext: Vec<bool>,
    ) -> IntermediateRepresentation {
        Self::new_with_tracking(exprs, outputs, bounds, is_plaintext, Tracking::default())
    }

    /// Creates an IR with an explicit [`Tracking`].
    ///
    /// `outputs` are indices into `exprs`; `bounds` and `is_plaintext` are expected to
    /// have one entry per expression.
    pub fn new_with_tracking(
        exprs: Vec<Expr<usize>>,
        outputs: Vec<usize>,
        bounds: Vec<Bounds>,
        is_plaintext: Vec<bool>,
        tracking: Tracking,
    ) -> IntermediateRepresentation {
        IntermediateRepresentation {
            exprs,
            outputs,
            bounds,
            is_plaintext,
            tracking,
        }
    }

    /// Consumes the IR and returns its parts as
    /// `(exprs, outputs, bounds, is_plaintext, tracking)`.
    // The tuple is the public shape of this API; a type alias would add a new public name.
    #[allow(clippy::type_complexity)]
    pub fn destructure(
        self,
    ) -> (
        Vec<Expr<usize>>, // the usize are indices in exprs
        Vec<usize>,       // the items are indices in exprs
        Vec<Bounds>,      // same size as exprs
        Vec<bool>,        // same size as exprs
        Tracking,
    ) {
        (
            self.exprs,
            self.outputs,
            self.bounds,
            self.is_plaintext,
            self.tracking,
        )
    }

    /// All expressions, in evaluation order.
    pub fn get_exprs(&self) -> &[Expr<usize>] {
        &self.exprs
    }

    /// Indices into [`Self::get_exprs`] of the program outputs.
    pub fn get_outputs(&self) -> &[usize] {
        &self.outputs
    }

    /// Bounds of each expression (same indexing as [`Self::get_exprs`]).
    pub fn get_bounds(&self) -> &[Bounds] {
        &self.bounds
    }

    /// `is_plaintext` flag of each expression (same indexing as [`Self::get_exprs`]).
    pub fn get_is_plaintext(&self) -> &[bool] {
        &self.is_plaintext
    }

    /// Domain of each output, derived from that output's bounds.
    pub fn get_output_domains(&self) -> Vec<DomainElement<(), (), (), ()>> {
        self.outputs
            .iter()
            .map(|&out| self.bounds[out].to_domain())
            .collect()
    }

    /// The expression at index `id`.
    ///
    /// # Panics
    /// Panics if `id` is out of range.
    pub fn get_expr(&self, id: usize) -> &Expr<usize> {
        &self.exprs[id]
    }

    /// Tracking information attached to this IR.
    pub fn get_tracking(&self) -> &Tracking {
        &self.tracking
    }

    /// Checks that no input id is used by more than one expression.
    ///
    /// # Errors
    /// Returns [`ProgramError::DuplicateInputId`] if two expressions report the same id
    /// through `Expr::get_input`.
    pub fn check_for_integrity(&self) -> Result<(), ProgramError> {
        let mut input_ids: Vec<usize> = self.exprs.iter().filter_map(Expr::get_input).collect();
        // After sorting, any duplicated ids end up adjacent.
        input_ids.sort_unstable();
        if input_ids.windows(2).any(|pair| pair[0] == pair[1]) {
            return Err(ProgramError::DuplicateInputId);
        }
        Ok(())
    }

    /// Runs `Compiler::optimize_into_circuitable` on this IR.
    pub fn optimize_into_circuitable(self) -> IntermediateRepresentation {
        Compiler::optimize_into_circuitable(self)
    }

    /// Lowers this IR with `Compiler::ir_to_async_mpc_circuit`, improves the result with
    /// `Compiler::improve_async_mpc_circuit`, and bundles it with the produced metadata.
    pub fn to_async_mpc_circuit(&self) -> ArcisInstruction {
        let (unimproved_circuit, metadata, _) = Compiler::ir_to_async_mpc_circuit(self);
        let circuit = Compiler::improve_async_mpc_circuit(unimproved_circuit);
        ArcisInstruction { circuit, metadata }
    }

    /// For every `BitExpr::Random` expression of `opt_ir` (in order), returns the position
    /// of its id among the distinct `BitExpr::Random` ids of `self` (numbered in order of
    /// first occurrence), or `None` when `self` has no random bit with that id.
    ///
    /// NOTE(review): presumably `opt_ir` is an optimized version of `self`; confirm with
    /// callers.
    pub fn generated_bools_filter(
        &self,
        opt_ir: &IntermediateRepresentation,
    ) -> Vec<Option<usize>> {
        let mut old_ids = IndexSet::new();
        for expr in self.get_exprs() {
            if let Expr::Bit(BitExpr::Random(id)) = expr {
                old_ids.insert(id.clone());
            }
        }
        opt_ir
            .get_exprs()
            .iter()
            .filter_map(|expr| match expr {
                Expr::Bit(BitExpr::Random(id)) => Some(old_ids.get_index_of(id)),
                _ => None,
            })
            .collect()
    }

    /// Serializes this IR with `bincode`.
    ///
    /// # Panics
    /// Panics if `bincode` fails to serialize the IR.
    pub fn to_bytes(&self) -> Vec<u8> {
        bincode::serialize(self)
            .expect("bincode serialization of an in-memory IntermediateRepresentation failed")
    }

    /// Deserializes an IR previously produced by [`Self::to_bytes`].
    ///
    /// # Errors
    /// Returns an [`std::io::ErrorKind::InvalidData`] error when `bytes` cannot be decoded.
    pub fn from_bytes(bytes: &[u8]) -> Result<IntermediateRepresentation, std::io::Error> {
        bincode::deserialize(bytes).map_err(|e| {
            std::io::Error::new(
                std::io::ErrorKind::InvalidData,
                format!("Error while reading IntermediateRepresentation: {}.\nThis is probably caused by incompatible arcis versions used by the interpreter (in encrypted-ixs) and the CLI (the `arcium` binary in `arcium build`).", e),
            )
        })
    }

    /// Directory, relative to the current working directory, used by
    /// [`Self::write_bytes`] when no `out_dir` is given.
    const DEFAULT_CIRCUIT_OUT_DIR: &'static str = "build";

    /// Writes [`Self::to_bytes`] to `<current dir>/<out_dir>/<circuit_name>.arcis.ir`,
    /// creating the directory if needed. `out_dir` defaults to `build`.
    ///
    /// Returns the path of the written file.
    ///
    /// # Errors
    /// Returns an I/O error if the current directory cannot be determined, if the
    /// resulting path is not valid UTF-8, or if the directory or file cannot be created
    /// or written.
    pub fn write_bytes(
        &self,
        circuit_name: &str,
        out_dir: Option<String>,
    ) -> Result<String, std::io::Error> {
        let circuits_dir = env::current_dir()?
            .join(out_dir.as_deref().unwrap_or(Self::DEFAULT_CIRCUIT_OUT_DIR));
        let file_path = circuits_dir.join(format!("{circuit_name}.arcis.ir"));
        // Report a non-UTF-8 path as an error (instead of panicking), before any
        // filesystem side effects happen.
        let res = file_path.to_str().map(String::from).ok_or_else(|| {
            std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                format!(
                    "circuit output path {} is not valid UTF-8",
                    file_path.display()
                ),
            )
        })?;
        std::fs::create_dir_all(&circuits_dir)?;
        let mut file = std::fs::File::create(&file_path)?;
        file.write_all(&self.to_bytes())?;
        Ok(res)
    }

    /// Prints `vals` as one `index: value` line per computed expression.
    fn log_eval_vals(vals: &[EvalValue]) {
        println!("vals: ");
        for (idx, val) in vals.iter().enumerate() {
            println!("{idx}: {val:?}");
        }
    }

    /// Evaluates the IR expression by expression and returns the values of the outputs.
    ///
    /// * `rng` — randomness for missing inputs, random field values, edaBit conversions
    ///   and random bits not covered by `generated_bits`.
    /// * `input_vals` — input values keyed by input id. Missing inputs are sampled
    ///   (field inputs within their `min`/`max` via `gen_inclusive_range`) and inserted,
    ///   so afterwards the map holds every input that was used.
    /// * `always_recover_from_ub` — on undefined behavior, use the sample value of the
    ///   expression's bounds instead of returning the error.
    /// * `do_log` — print all computed values at the end, and before returning an
    ///   undefined-behavior error.
    /// * `skip_bound_check` — do not check computed values against their bounds.
    /// * `generated_bits` — values used, in order, for `BitExpr::Random` expressions;
    ///   once exhausted, bits are drawn from `rng`.
    ///
    /// # Errors
    /// Returns the first [`UndefinedBehavior`] encountered, unless
    /// `always_recover_from_ub` is set.
    ///
    /// # Panics
    /// Panics when an expression fails for a reason other than undefined behavior, or
    /// (unless `skip_bound_check`) when a value lies outside its bounds.
    pub fn eval_with_log<R: Rng + ?Sized>(
        &self,
        rng: &mut R,
        input_vals: &mut FxHashMap<usize, EvalValue>,
        always_recover_from_ub: bool,
        do_log: bool,
        skip_bound_check: bool,
        mut generated_bits: impl Iterator<Item = bool>,
    ) -> Result<Vec<EvalValue>, UndefinedBehavior> {
        let mut vals = Vec::<EvalValue>::with_capacity(self.get_exprs().len());
        for (i, expr) in self.get_exprs().iter().enumerate() {
            let val: EvalValue = match expr {
                // Inputs: reuse the caller's value, or sample one and record it.
                Expr::Scalar(FieldExpr::Input(input_id, info)) => {
                    *input_vals.entry(*input_id).or_insert_with(|| {
                        EvalValue::Scalar(ScalarField::gen_inclusive_range(
                            rng, info.min, info.max,
                        ))
                    })
                }
                Expr::Base(FieldExpr::Input(input_id, info)) => {
                    *input_vals.entry(*input_id).or_insert_with(|| {
                        EvalValue::Base(BaseField::gen_inclusive_range(rng, info.min, info.max))
                    })
                }
                Expr::Bit(BitExpr::Input(input_id, _)) => *input_vals
                    .entry(*input_id)
                    .or_insert_with(|| EvalValue::Bit(R::gen(rng))),
                Expr::Curve(CurveExpr::Input(input_id, _)) => *input_vals
                    .entry(*input_id)
                    .or_insert_with(|| EvalValue::Curve(R::gen(rng))),
                Expr::Scalar(FieldExpr::RandomVal(_)) => {
                    EvalValue::Scalar(ScalarField::random(&mut *rng))
                }
                Expr::Base(FieldExpr::RandomVal(_)) => {
                    EvalValue::Base(BaseField::random(&mut *rng))
                }
                Expr::Bit(BitExpr::Random(_)) => {
                    // Draw from `rng` only when no pre-generated bit is left, so that
                    // supplied bits do not silently consume randomness.
                    EvalValue::Bit(generated_bits.next().unwrap_or_else(|| rng.gen()))
                }
                Expr::ScalarConversion(ConversionExpr::EdaBit(_, k, _)) => {
                    let v = Number::gen_range(rng, &0.into(), &Number::power_of_two(*k));
                    EvalValue::Scalar(v.into())
                }
                Expr::BaseConversion(ConversionExpr::EdaBit(_, k, _)) => {
                    let v = Number::gen_range(rng, &0.into(), &Number::power_of_two(*k));
                    EvalValue::Base(v.into())
                }
                _ => {
                    let res = expr
                        .clone()
                        .apply(|x| vals[x])
                        .apply_2(&mut EvalValueUnwrap)
                        .eval();
                    match res {
                        Ok(n) => n,
                        Err(EvalFailure::UndefinedBehavior(ub)) => {
                            if always_recover_from_ub {
                                self.get_bounds()[i].get_sample_val()
                            } else {
                                if do_log {
                                    Self::log_eval_vals(&vals);
                                }
                                return Err(ub);
                            }
                        }
                        Err(e) => panic!("Error at expr {expr:?}: {e:?}"),
                    }
                }
            };
            if !skip_bound_check && !self.get_bounds()[i].contains(val) {
                let expr_bound = expr
                    .clone()
                    .apply(|x| self.get_bounds()[x])
                    .apply_2(&mut BoundUnFold)
                    .bounds();
                panic!(
                    "{val:?} is not in {:?} for {expr:?}, or {expr_bound:?}: whose bounds are\n{:?}",
                    self.get_bounds()[i],
                    expr_bound
                );
            }
            vals.push(val);
        }
        if do_log {
            Self::log_eval_vals(&vals);
        }
        Ok(self.get_outputs().iter().map(|&out| vals[out]).collect())
    }

    /// Evaluates the IR without logging, bound-check skipping, UB recovery or
    /// pre-generated bits. See [`Self::eval_with_log`].
    pub fn eval<R: Rng + ?Sized>(
        &self,
        rng: &mut R,
        input_vals: &mut FxHashMap<usize, EvalValue>,
    ) -> Result<Vec<EvalValue>, UndefinedBehavior> {
        self.eval_with_log(rng, input_vals, false, false, false, std::iter::empty())
    }

    /// Evaluates the IR with `inputs[k]` as the value of input id `k`. Undefined behavior
    /// is recovered from, and `generated_bits` feeds the random bits in order.
    /// See [`Self::eval_with_log`].
    pub fn eval_vec<R: Rng + ?Sized>(
        &self,
        inputs: Vec<EvalValue>,
        rng: &mut R,
        generated_bits: &[bool],
    ) -> Result<Vec<EvalValue>, UndefinedBehavior> {
        let mut inputs: FxHashMap<usize, EvalValue> = inputs.into_iter().enumerate().collect();
        self.eval_with_log(
            rng,
            &mut inputs,
            true,
            false,
            false,
            generated_bits.iter().copied(),
        )
    }
}
/// Lists every expression as `index: expr;`, then every output as
/// `index: Output(expr_index);`. Output lines continue the numbering after the
/// expressions.
impl fmt::Display for IntermediateRepresentation {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        for (idx, gate) in self.exprs.iter().enumerate() {
            writeln!(f, "{idx}: {gate:?};")?;
        }
        let first_output_line = self.exprs.len();
        for (offset, output) in self.outputs.iter().enumerate() {
            let line_no = first_output_line + offset;
            writeln!(f, "{line_no}: Output({output});")?;
        }
        Ok(())
    }
}
/// Errors reported when validating an `IntermediateRepresentation`
/// (see `IntermediateRepresentation::check_for_integrity`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ProgramError {
    /// Two or more expressions report the same input id.
    DuplicateInputId,
}

impl fmt::Display for ProgramError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            ProgramError::DuplicateInputId => write!(
                f,
                "duplicate input id: two or more input expressions share the same id"
            ),
        }
    }
}

impl std::error::Error for ProgramError {}
#[cfg(test)]
mod tests {
    use crate::core::{expressions::expr::EvalValue, ir::IntermediateRepresentation};
    use rand::Rng;
    use rustc_hash::FxHashMap;

    impl IntermediateRepresentation {
        /// Evaluates `ctrl_ir`, then `test_ir`, on the same `input_vals` and asserts that
        /// both results are equal.
        ///
        /// The comparison is skipped when `ctrl_ir` returns an error. On a mismatch, if
        /// `test_ir` has fewer than 65536 expressions, both IRs, the inputs and both
        /// results are printed, and `test_ir` is re-evaluated with logging before the
        /// assertion fails.
        pub fn test_eq_with_vals<R: Rng + ?Sized>(
            rng: &mut R,
            ctrl_ir: &IntermediateRepresentation,
            test_ir: &IntermediateRepresentation,
            input_vals: &mut FxHashMap<usize, EvalValue>,
        ) {
            // A control run that errors has no reference result to compare against.
            let ctrl_res = match ctrl_ir.eval(rng, input_vals) {
                Err(_) => return,
                ok => ok,
            };
            let test_res = test_ir.eval(rng, input_vals);
            if ctrl_res == test_res {
                return;
            }
            let printable = test_ir.get_exprs().len() < 65536;
            if printable {
                println!("ctrl: {ctrl_ir}");
                println!("test: {test_ir}");
                println!("input_vals: {input_vals:?}");
                println!("ctrl_res: {ctrl_res:?}");
                println!("test_res: {test_res:?}");
                // Re-run with logging so every intermediate value is visible.
                let _ = test_ir.eval_with_log(
                    rng,
                    input_vals,
                    false,
                    true,
                    false,
                    std::iter::empty(),
                );
            }
            assert_eq!(ctrl_res, test_res);
        }

        /// Runs [`Self::test_eq_with_vals`] `n_tests` times, each with fresh inputs.
        pub fn test_eq<R: Rng + ?Sized>(
            rng: &mut R,
            ctrl_ir: &IntermediateRepresentation,
            test_ir: &IntermediateRepresentation,
            n_tests: usize,
        ) {
            let mut remaining = n_tests;
            while remaining > 0 {
                let mut fresh_inputs: FxHashMap<usize, EvalValue> = FxHashMap::default();
                Self::test_eq_with_vals(rng, ctrl_ir, test_ir, &mut fresh_inputs);
                remaining -= 1;
            }
        }
    }
}