#[cfg(test)]
mod test {
use std::collections::HashMap;
use std::iter::FromIterator;
use rand::distributions::{Distribution, Standard};
use rand::thread_rng;
use crate::eval::{evaluate_composite_program, largest_wires, smallest_wires};
use crate::has_io::HasIO;
use crate::translatable::Translatable;
use crate::{CombineOperation, OpType, Operation, WireValue};
/// Checks that randomly generated `Operation`s report, through `inputs()`,
/// `outputs()` and `dst()`, exactly the wires they were built from, and that
/// converting them into a `CombineOperation` leaves the reported wires unchanged.
#[test]
fn test_io_operations() {
    /// Converts `gate` into a `CombineOperation` and asserts that the wrapper
    /// yields exactly `collected_inputs` and `collected_outputs`.
    fn check_combine<T: WireValue>(
        gate: Operation<T>,
        collected_inputs: Vec<usize>,
        collected_outputs: Vec<usize>,
    ) where
        CombineOperation: From<Operation<T>>,
    {
        let as_combine: CombineOperation = gate.into();
        let collected_inputs_combine: Vec<usize> = as_combine.inputs().collect();
        let collected_outputs_combine: Vec<usize> = as_combine.outputs().collect();
        assert_eq!(collected_inputs, collected_inputs_combine);
        assert_eq!(collected_outputs, collected_outputs_combine);
    }

    /// Builds one gate of a random variant over `T` from random wire indices
    /// (and a random constant where the variant takes one) and verifies its
    /// reported wires.
    fn do_gate_test<T: WireValue>()
    where
        Standard:
            Distribution<usize> + Distribution<(usize, T)> + Distribution<(usize, usize, T)>,
        CombineOperation: From<Operation<T>>,
    {
        // Shared verification for every variant: the gate must report exactly
        // the expected wires, `dst()` must agree with the (at most one)
        // expected output, and the `CombineOperation` wrapper must agree too.
        // A closure (rather than a nested fn) is used so it can reuse `T` and
        // the bounds of the enclosing function.
        let verify = |gate: Operation<T>,
                      expected_inputs: Vec<usize>,
                      expected_outputs: Vec<usize>| {
            let collected_inputs: Vec<usize> = gate.inputs().collect();
            let collected_outputs: Vec<usize> = gate.outputs().collect();
            assert_eq!(collected_inputs, expected_inputs);
            assert_eq!(collected_outputs, expected_outputs);
            match expected_outputs.first() {
                Some(&out) => assert_eq!(gate.dst().unwrap(), out),
                None => assert!(gate.dst().is_none()),
            }
            check_combine::<T>(gate, collected_inputs, collected_outputs);
        };

        match Operation::<T>::random_variant(&mut thread_rng()) {
            OpType::Binary(ty) => {
                let (out, in1, in2): (usize, usize, usize) = rand::random();
                verify(ty(out, in1, in2), vec![in1, in2], vec![out]);
            }
            OpType::BinaryConst(ty) => {
                // The constant operand is not a wire, so only `in1` is an input.
                let (out, in1, constant): (usize, usize, T) = rand::random();
                verify(ty(out, in1, constant), vec![in1], vec![out]);
            }
            OpType::Input(ty) => {
                let out: usize = rand::random();
                verify(ty(out), vec![], vec![out]);
            }
            OpType::InputConst(ty) => {
                let (out, constant): (usize, T) = rand::random();
                verify(ty(out, constant), vec![], vec![out]);
            }
            OpType::Output(ty) => {
                // Gates of this kind have an input wire but no destination.
                let in1: usize = rand::random();
                verify(ty(in1), vec![in1], vec![]);
            }
        }
    }

    for _ in 0..1000 {
        do_gate_test::<bool>();
        do_gate_test::<u64>();
    }
}
/// Checks the wires reported by the `CombineOperation`-only variants: `B2A`
/// must list the 64 consecutive wires starting at its second argument as
/// inputs and its first argument as the sole output, while `SizeHint` must
/// report no wires at all.
#[test]
fn test_io_combine_operations() {
    for _ in 0..10 {
        let (dst, first_bit): (usize, usize) = rand::random();

        // B2A: 64 consecutive input wires, one output wire.
        let b2a = CombineOperation::B2A(dst, first_bit);
        let want_inputs: Vec<usize> = (first_bit..(first_bit + 64)).collect();
        let got_inputs: Vec<usize> = b2a.inputs().collect();
        let got_outputs: Vec<usize> = b2a.outputs().collect();
        assert_eq!(got_inputs, want_inputs);
        assert_eq!(got_outputs, vec![dst]);

        // SizeHint: neither reads nor writes any wire.
        let hint = CombineOperation::SizeHint(dst, first_bit);
        assert!(hint.inputs().next().is_none());
        assert!(hint.outputs().next().is_none());
    }
}
/// Exercises wire translation of `Operation`s (and of the same gates wrapped
/// in `CombineOperation`) through `translate`, `translate_from_hashmap` and
/// `translate_from_fn`.
#[test]
fn test_translation_operations() {
    /// Builds a random-variant gate over `T` plus a second gate of the same
    /// variant and constant on different random wires, then checks that every
    /// translation entry point maps the first gate onto the expected result.
    fn do_gate_test<T: WireValue>()
    where
        CombineOperation: From<Operation<T>>,
        Standard: Distribution<(usize, usize, usize, T)>,
    {
        let variant = Operation::<T>::random_variant(&mut thread_rng());
        let (original_out, original_in1, original_in2, original_c): (usize, usize, usize, T) =
            rand::random();
        // Only the wires change under translation; the constant is reused.
        let (translated_out, translated_in1, translated_in2): (usize, usize, usize) =
            rand::random();

        let gate = Operation::<T>::construct(
            variant,
            [original_in1, original_in2].iter().copied(),
            [original_out].iter().copied(),
            Some(original_c),
        );
        let translation_target = Operation::<T>::construct(
            variant,
            [translated_in1, translated_in2].iter().copied(),
            [translated_out].iter().copied(),
            Some(original_c),
        );

        // Translating onto its own wires must be a no-op; translating onto the
        // target's wires must reproduce the target.
        let identity = gate
            .translate(gate.inputs(), gate.outputs())
            .expect("Failed to perform identity translation");
        let translated = gate
            .translate(translation_target.inputs(), translation_target.outputs())
            .expect("Failed to perform actual translation");
        assert_eq!(gate, identity);
        assert_eq!(translation_target, translated);

        // Same translation, expressed as an old-wire -> new-wire map.
        let translated_via_hashmap = gate
            .translate_from_hashmap(HashMap::<usize, usize>::from_iter(
                IntoIterator::into_iter([
                    (original_out, translated_out),
                    (original_in1, translated_in1),
                    (original_in2, translated_in2),
                ]),
            ))
            .expect("Hashmap Translation Failed");
        assert_eq!(translation_target, translated_via_hashmap);

        // Function-based translation: inputs shifted by 1, outputs by 2.
        // NOTE: these additions overflow only if a random index lands at the
        // very top of the `usize` range.
        let incremented = Operation::<T>::construct(
            variant,
            [original_in1 + 1, original_in2 + 1].iter().copied(),
            [original_out + 2].iter().copied(),
            Some(original_c),
        );
        let translated_via_fn = gate
            .translate_from_fn(|x| x + 1, |x| x + 2)
            .expect("Function translation failed");
        assert_eq!(incremented, translated_via_fn);

        // The `CombineOperation` wrappers must translate the same way. The
        // conversions consume the gates, so they come last.
        let as_combine: CombineOperation = gate.into();
        let target_as_combine: CombineOperation = translation_target.into();
        let identity_combine = as_combine
            .translate(as_combine.inputs(), as_combine.outputs())
            .unwrap();
        let translated_combine = as_combine
            .translate(target_as_combine.inputs(), target_as_combine.outputs())
            .unwrap();
        assert_eq!(as_combine, identity_combine);
        assert_eq!(target_as_combine, translated_combine);
    }

    // Fixed-value sanity checks. They do not depend on any random input, so
    // they run once here instead of on every iteration of the loop below.
    assert_eq!(
        Operation::<bool>::Add(4, 3, 2)
            .translate([7, 8].iter().copied(), [0].iter().copied())
            .unwrap(),
        Operation::<bool>::Add(0, 7, 8)
    );
    assert_eq!(
        CombineOperation::GF2(Operation::<bool>::Add(4, 3, 2))
            .translate([7, 8].iter().copied(), [0].iter().copied())
            .unwrap(),
        CombineOperation::GF2(Operation::<bool>::Add(0, 7, 8))
    );

    for _ in 0..1000 {
        do_gate_test::<bool>();
        do_gate_test::<u64>();
    }
}
/// Checks translation of the `CombineOperation`-only variants: a `B2A` gate
/// translates onto itself and onto another `B2A`'s wires, whereas translating
/// a `SizeHint` always yields `None`.
#[test]
fn test_translation_combine_operations() {
    for _ in 0..10 {
        let (src_out, src_low): (usize, usize) = rand::random();
        let (dst_out, dst_low): (usize, usize) = rand::random();

        // B2A: both the identity and the real translation must succeed.
        {
            let source = CombineOperation::B2A(src_out, src_low);
            let target = CombineOperation::B2A(dst_out, dst_low);
            let unchanged = source
                .translate(source.inputs(), source.outputs())
                .unwrap();
            let moved = source
                .translate(target.inputs(), target.outputs())
                .unwrap();
            assert_eq!(source, unchanged);
            assert_eq!(moved, target);
        }

        // SizeHint: translation is expected to produce nothing.
        {
            let source = CombineOperation::SizeHint(src_out, src_low);
            let target = CombineOperation::SizeHint(dst_out, dst_low);
            let unchanged = source.translate(source.inputs(), source.outputs());
            let moved = source.translate(target.inputs(), target.outputs());
            assert_eq!(unchanged, None);
            assert_eq!(moved, None);
        }
    }
}
/// Evaluates a constant-only composite program (no external inputs): in each
/// of the two fields a constant is loaded, a second constant is added, the
/// expected total is subtracted, and the result is asserted to be zero.
#[test]
fn test_simple_eval() {
    // GF(2) half: true + false - true.
    let boolean_ops: Vec<Operation<bool>> = vec![
        Operation::Const(0, true),
        Operation::AddConst(1, 0, false),
        Operation::SubConst(2, 1, true),
        Operation::AssertZero(2),
    ];
    // Z64 half: 15 + 14 - (14 + 15).
    let arithmetic_ops: Vec<Operation<u64>> = vec![
        Operation::Const(0, 15),
        Operation::AddConst(1, 0, 14),
        Operation::SubConst(2, 1, 14 + 15),
        Operation::AssertZero(2),
    ];

    // Boolean gates first, then arithmetic gates.
    let program: Vec<CombineOperation> = boolean_ops
        .into_iter()
        .map(CombineOperation::GF2)
        .chain(arithmetic_ops.into_iter().map(CombineOperation::Z64))
        .collect();

    evaluate_composite_program(&program, &[], &[]);
}
/// Evaluates a composite program that reads two inputs per field, multiplies
/// them, and asserts (via subtract-constant + assert-zero) that both inputs
/// and the product carry the expected values.
#[test]
fn test_with_inputs() {
    /// Gate sequence over one field: read wires 0 and 1, put their `Mul`
    /// result on wire 2, then check wire 0 against `lhs`, wire 1 against
    /// `rhs` and wire 2 against `product`.
    fn multiplication_check<T: WireValue>(lhs: T, rhs: T, product: T) -> Vec<Operation<T>> {
        vec![
            Operation::Input(0),
            Operation::Input(1),
            Operation::Mul(2, 1, 0),
            Operation::SubConst(3, 0, lhs),
            Operation::AssertZero(3),
            Operation::SubConst(4, 1, rhs),
            Operation::AssertZero(4),
            Operation::SubConst(5, 2, product),
            Operation::AssertZero(5),
        ]
    }

    // Boolean gates first, then arithmetic gates.
    let program: Vec<CombineOperation> = multiplication_check::<bool>(true, true, true)
        .into_iter()
        .map(CombineOperation::GF2)
        .chain(
            multiplication_check::<u64>(14, 15, 14 * 15)
                .into_iter()
                .map(CombineOperation::Z64),
        )
        .collect();

    evaluate_composite_program(&program, &[true, true], &[14, 15]);
}
/// Evaluates a program that feeds the low eight bits of `expected` onto GF(2)
/// wires 0-7 (bits 0-3 as inputs, bits 4-7 as constants), applies `B2A(1, 0)`,
/// and asserts on the Z64 side that wire 1 equals both the Z64 input and
/// `expected`.
#[test]
fn test_b_to_a() {
    let expected: u64 = 0b11011101;
    // Whether bit `index` of `expected` is set.
    let bit_set = |index: usize| (expected & (1 << index)) != 0;

    let mut program = vec![CombineOperation::SizeHint(1, 64)];
    // Bits 0..4 arrive as boolean inputs.
    program.extend((0..4).map(|wire| CombineOperation::GF2(Operation::Input(wire))));
    // Bits 4..8 are baked in as boolean constants.
    program.extend(
        (4..8).map(|wire| CombineOperation::GF2(Operation::Const(wire, bit_set(wire)))),
    );
    program.push(CombineOperation::B2A(1, 0));
    // Arithmetic side: compare the B2A output (wire 1) with the Z64 input
    // (wire 2) and compare both with `expected`.
    program.extend(vec![
        CombineOperation::Z64(Operation::Input(2)),
        CombineOperation::Z64(Operation::Sub(3, 1, 2)),
        CombineOperation::Z64(Operation::AssertZero(3)),
        CombineOperation::Z64(Operation::SubConst(4, 1, expected)),
        CombineOperation::Z64(Operation::AssertZero(4)),
        CombineOperation::Z64(Operation::SubConst(5, 2, expected)),
        CombineOperation::Z64(Operation::AssertZero(5)),
    ]);

    evaluate_composite_program(
        &program,
        &[bit_set(0), bit_set(1), bit_set(2), bit_set(3)],
        &[expected],
    );
}
/// Checks `smallest_wires` / `largest_wires` on a two-gate program and that a
/// prepended `SizeHint` changes what `largest_wires` reports. Judging by the
/// expected values, each tuple lists the Z64 figure first and the GF2 figure
/// second.
#[test]
fn test_size_hinting() {
    let gf2_input = CombineOperation::GF2(Operation::Input(99));
    let z64_input = CombineOperation::Z64(Operation::Input(199));
    let mut program = vec![gf2_input, z64_input];

    // Without a hint: smallest is the wire index itself, largest is one past it.
    assert_eq!(smallest_wires(&program), (199, 99));
    assert_eq!(largest_wires(&program), (200, 100));

    // With a hint in front, the hinted sizes are reported instead.
    program.insert(0, CombineOperation::SizeHint(400, 300));
    assert_eq!(largest_wires(&program), (400, 300));
}
}