use crate::core::{
circuits::{
boolean::{boolean_value::BooleanValue, byte::Byte},
f64::utils::F64,
},
expressions::expr::EvalFailure,
global_value::global_expr_store::with_local_expr_store_as_global,
ir_builder::IRBuilder,
};
use std::fmt::Debug;
/// A floating-point computation with two implementations: a reference evaluation on native
/// `f64`s ([`F64Circuit::eval`]) and a circuit over bit-level [`F64`] values
/// ([`F64Circuit::run`]).
pub trait F64Circuit: Debug {
    /// Reference evaluation of the computation on native `f64` inputs.
    ///
    /// # Errors
    /// Returns an [`EvalFailure`] when the computation cannot be evaluated on `x`.
    // `dead_code` is allowed because, within this file, only the test harness calls this.
    #[allow(dead_code)]
    fn eval(&self, x: Vec<f64>) -> Result<Vec<f64>, EvalFailure>;

    /// Relative tolerance allowed between [`F64Circuit::eval`] and the circuit output when the
    /// two are compared; the default of `0.0` demands an exact match.
    #[allow(dead_code)]
    fn rtol(&self) -> f64 {
        0f64
    }

    /// Applies the circuit to `vals` and returns its output values.
    #[allow(dead_code)]
    fn run(&self, vals: Vec<F64>) -> Vec<F64>;

    /// Runs [`F64Circuit::run`] on raw bit ids instead of [`F64`] values.
    ///
    /// `vals` is read as consecutive groups of 64 bit ids, one group per `f64`. Each group is
    /// split into eight runs of 8 bits, each wrapped in a [`Byte`], and the eight bytes are
    /// combined with [`F64::from_le_bytes`]. `run` is invoked inside
    /// [`with_local_expr_store_as_global`] with `expr_store`, and its outputs are flattened back
    /// into bit ids (via `to_le_bytes`, `get_bits` and `get_id`) using the same layout.
    ///
    /// # Panics
    /// Panics if `vals.len()` is not a multiple of 64.
    #[allow(dead_code)]
    fn run_usize(&self, vals: &[usize], expr_store: &mut IRBuilder) -> Vec<usize> {
        // Check the precondition before installing the expression store. Otherwise a trailing
        // partial group would surface as an obscure "length 8" panic from inside the closure.
        assert!(
            vals.len() % 64 == 0,
            "run_usize expects a multiple of 64 input bits (found {})",
            vals.len()
        );
        with_local_expr_store_as_global(
            || {
                let inputs = vals
                    .iter()
                    .map(|val| BooleanValue::new(*val))
                    .collect::<Vec<BooleanValue>>()
                    .chunks_exact(64)
                    .map(|chunk| {
                        // `chunks_exact` guarantees 8 groups of 8 bits here, so the
                        // conversions below cannot actually fail.
                        F64::from_le_bytes(
                            chunk
                                .chunks_exact(8)
                                .map(|bits| {
                                    Byte::new(bits.to_vec().try_into().unwrap_or_else(
                                        |v: Vec<BooleanValue>| {
                                            panic!(
                                                "Expected a Vec of length 8 (found {})",
                                                v.len()
                                            )
                                        },
                                    ))
                                })
                                .collect::<Vec<Byte<BooleanValue>>>()
                                .try_into()
                                .unwrap_or_else(|v: Vec<Byte<BooleanValue>>| {
                                    panic!("Expected a Vec of length 8 (found {})", v.len())
                                }),
                        )
                    })
                    .collect::<Vec<F64>>();
                self.run(inputs)
                    .into_iter()
                    .flat_map(|x| {
                        x.to_le_bytes()
                            .into_iter()
                            .flat_map(|byte| {
                                byte.get_bits()
                                    .into_iter()
                                    .map(|bit| bit.get_id())
                                    .collect::<Vec<usize>>()
                            })
                            .collect::<Vec<usize>>()
                    })
                    .collect::<Vec<usize>>()
            },
            expr_store,
        )
    }
}
#[cfg(test)]
pub mod tests {
use super::*;
use crate::{
core::{
circuits::traits::SAVE_CIRC_TEST_FOLDER_ENV_VAR,
expressions::{
bit_expr::{BitExpr, BitInputInfo},
domain::Domain,
expr::EvalValue,
InputKind,
},
ir_builder::ExprStore,
},
utils::field::BaseField,
};
use rand::Rng;
use std::rc::Rc;
/// Packs a flat list of bits into `f64`s.
///
/// Each run of 64 bits becomes one `f64`. The run is split into eight groups of 8 bits, each
/// group is turned into a `u8` through [`Byte`], and the eight bytes are read in little-endian
/// order.
///
/// Panics if `bits.len()` is not a multiple of 64.
fn bits_to_f64s(bits: Vec<bool>) -> Vec<f64> {
    let mut floats = Vec::with_capacity(bits.len() / 64);
    for word in bits.chunks(64) {
        let mut le_bytes = Vec::with_capacity(8);
        for octet in word.chunks(8) {
            let byte = Byte::new(octet.to_vec().try_into().unwrap_or_else(|v: Vec<bool>| {
                panic!("Expected a Vec of length 8 (found {})", v.len())
            }));
            le_bytes.push(u8::from(byte));
        }
        let le_bytes: [u8; 8] = le_bytes.try_into().unwrap_or_else(|v: Vec<u8>| {
            panic!("Expected a Vec of length 8 (found {})", v.len())
        });
        floats.push(f64::from_le_bytes(le_bytes));
    }
    floats
}
fn desc_file_path() -> String {
let folder = std::env::var(SAVE_CIRC_TEST_FOLDER_ENV_VAR).unwrap();
let binding = std::thread::current();
let test_name = binding.name().unwrap();
format!("{folder}/{}.desc", test_name)
}
fn run_file_path() -> String {
let folder = std::env::var(SAVE_CIRC_TEST_FOLDER_ENV_VAR).unwrap();
let binding = std::thread::current();
let test_name = binding.name().unwrap();
format!("{folder}/{}.run", test_name)
}
/// Returns whether the circuit output `actual` agrees with the reference value `expected`
/// to within the relative tolerance `rtol`.
///
/// Identical values (including equal infinities and `0.0` vs `-0.0`) always agree, and a NaN
/// agrees with any other NaN. Any other non-finite value is a mismatch. With an infinite
/// operand, `|expected - actual|` and `rtol * |expected|` become `inf`/NaN. That would make
/// the relative check fail for equal infinities, or pass for e.g. `expected = inf,
/// actual = 1.0` whenever `rtol > 0`.
fn within_rtol(expected: f64, actual: f64, rtol: f64) -> bool {
    if expected == actual || (expected.is_nan() && actual.is_nan()) {
        return true;
    }
    expected.is_finite()
        && actual.is_finite()
        && (expected - actual).abs() <= rtol * expected.abs()
}

/// Runs one randomized check of the circuit `desc`.
///
/// The check proceeds as follows:
/// - Draw `64 * desc.gen_n_inputs(rng)` random input bits and interpret them as `f64`s for
///   the reference [`F64Circuit::eval`].
/// - Build the circuit over secret input bits via [`F64Circuit::run_usize`] and evaluate the
///   resulting IR on the same bits.
/// - Assert that both sides produce the same number of outputs and that every output agrees
///   within `desc.rtol()`.
///
/// If the reference evaluation fails, the circuit is still built and evaluated but no
/// comparison is made. Otherwise `desc.extra_checks` runs last.
fn test<R: Rng + ?Sized, C: TestedF64Circuit>(rng: &mut R, desc: &C) {
    let n_inputs = 64 * desc.gen_n_inputs(rng);
    let input_vals_bool = (0..n_inputs)
        .map(|_| rng.gen_bool(0.5))
        .collect::<Vec<bool>>();
    let input_vals_f64 = bits_to_f64s(input_vals_bool.clone());
    let eval_result = desc.eval(input_vals_f64.clone());

    // One secret input bit expression per input bit, numbered 0..n_inputs.
    let mut expr_store = IRBuilder::new(false);
    let input_ids = (0..n_inputs)
        .map(|i| {
            <IRBuilder as ExprStore<BaseField>>::push_bit(
                &mut expr_store,
                BitExpr::Input(
                    i,
                    Rc::new(BitInputInfo {
                        kind: InputKind::Secret,
                        ..BitInputInfo::default()
                    }),
                ),
            )
        })
        .collect::<Vec<usize>>();
    let outputs = desc.run_usize(&input_ids, &mut expr_store);
    let test_ir = expr_store.into_ir(outputs);
    let mut input_vals_map = input_vals_bool
        .into_iter()
        .map(EvalValue::Bit)
        .enumerate()
        .collect();
    let test_result = bits_to_f64s(
        test_ir
            .eval(rng, &mut input_vals_map)
            .map(|x| x.into_iter().map(bool::unwrap).collect::<Vec<bool>>())
            .unwrap(),
    );

    // The circuit above is built and evaluated regardless, so it must not crash even on
    // inputs the reference rejects; only the comparison is skipped.
    let eval_result = match eval_result {
        Ok(values) => values,
        Err(_) => return,
    };

    // `zip` below would silently drop unmatched outputs, so check the counts explicitly.
    assert_eq!(
        eval_result.len(),
        test_result.len(),
        "\neval produced {} outputs but the circuit produced {}. Inputs were: {:?}.\n",
        eval_result.len(),
        test_result.len(),
        input_vals_f64
    );
    let rtol = desc.rtol();
    for (eval_res, test_res) in eval_result.iter().zip(test_result) {
        assert!(
            within_rtol(*eval_res, test_res, rtol),
            "\nRelative difference between eval_res: {:?} and test_res: {:?} exceeds rtol: {:?}. Inputs were: {:?}.\n",
            *eval_res,
            test_res,
            rtol,
            input_vals_f64
        );
    }
    desc.extra_checks(input_vals_f64, eval_result)
}
/// An [`F64Circuit`] that can be exercised by the randomized test driver in this module.
pub trait TestedF64Circuit: F64Circuit + Clone + 'static {
    /// Draws a random circuit description.
    fn gen_desc<R: Rng + ?Sized>(rng: &mut R) -> Self;

    /// Draws the number of `f64` inputs for one run; each one contributes 64 input bits.
    fn gen_n_inputs<R: Rng + ?Sized>(&self, rng: &mut R) -> usize;

    /// Hook for circuit-specific assertions on the inputs and reference outputs of a run
    /// whose reference evaluation succeeded. Does nothing by default.
    #[allow(unused_variables)]
    fn extra_checks(&self, inputs: Vec<f64>, outputs: Vec<f64>) {}

    /// Generates `n_desc` random descriptions and runs the randomized check
    /// `n_runs_per_desc` times for each one.
    ///
    /// When the `SAVE_CIRC_TEST_FOLDER_ENV_VAR` environment variable is set, the RNG is
    /// written via `test_rng::save_to_file`:
    /// - to the `.desc` file before each description is generated;
    /// - to the `.run` file before each run.
    ///
    /// Presumably this is so a failing case can be replayed; TODO confirm against `test_rng`.
    fn test(n_desc: usize, n_runs_per_desc: usize) {
        let rng = &mut crate::utils::test_rng::get();
        let save_paths = if std::env::var(SAVE_CIRC_TEST_FOLDER_ENV_VAR).is_err() {
            None
        } else {
            let desc_path = desc_file_path();
            println!("saving the circuit description at {}", desc_path);
            let run_path = run_file_path();
            println!("saving the circuit run at {}", run_path);
            Some((desc_path, run_path))
        };
        for _ in 0..n_desc {
            if let Some((desc_path, _)) = &save_paths {
                crate::utils::test_rng::save_to_file(rng, desc_path);
            }
            let desc = Self::gen_desc(rng);
            for _ in 0..n_runs_per_desc {
                if let Some((_, run_path)) = &save_paths {
                    crate::utils::test_rng::save_to_file(rng, run_path);
                }
                test(rng, &desc);
            }
        }
    }
}
}