use bellperson::{
gadgets::{
boolean::{AllocatedBit, Boolean},
num::AllocatedNum,
},
ConstraintSystem, SynthesisError,
};
use ff::PrimeField;
/// Returns `elements.len() + 1` values in which `element` sits at the position encoded by
/// `bits` (little-endian: `bits[0]` is the least-significant bit) and the original `elements`
/// keep their relative order around it.
///
/// Sizes 2, 4 and 8 are routed to the specialized `insert_2` / `insert_4` / `insert_8`
/// circuits. Any other size falls back to one [`select`] per output position over every
/// possible outcome for that position.
///
/// # Panics
/// Panics if `1 << bits.len()` does not equal `elements.len() + 1`.
pub fn insert<Scalar: PrimeField, CS: ConstraintSystem<Scalar>>(
    cs: &mut CS,
    element: &AllocatedNum<Scalar>,
    bits: &[Boolean],
    elements: &[AllocatedNum<Scalar>],
) -> Result<Vec<AllocatedNum<Scalar>>, SynthesisError> {
    let size = elements.len() + 1;
    assert_eq!(1 << bits.len(), size);

    match size {
        2 => insert_2(cs, element, bits, elements),
        4 => insert_4(cs, element, bits, elements),
        8 => insert_8(cs, element, bits, elements),
        _ => (0..size)
            .map(|pos| {
                // Entry `idx` is the value that lands at `pos` when `element` is inserted at
                // index `idx`:
                // - earlier slots keep their element,
                // - slot `idx` takes `element`,
                // - later slots are shifted right by one.
                let candidates: Vec<AllocatedNum<Scalar>> = (0..size)
                    .map(|idx| {
                        if pos < idx {
                            elements[pos].clone()
                        } else if pos == idx {
                            element.clone()
                        } else {
                            elements[pos - 1].clone()
                        }
                    })
                    .collect();
                select(cs.namespace(|| format!("choice at {}", pos)), &candidates, bits)
            })
            .collect(),
    }
}
/// Specialized [`insert`] for two slots.
///
/// With `bits[0] == 0` the result is `[element, elements[0]]`; with `bits[0] == 1` it is
/// `[elements[0], element]`. Costs exactly two picks.
///
/// # Panics
/// Panics if `elements.len() != 1` or `bits.len() != 1`.
pub fn insert_2<Scalar: PrimeField, CS: ConstraintSystem<Scalar>>(
    cs: &mut CS,
    element: &AllocatedNum<Scalar>,
    bits: &[Boolean],
    elements: &[AllocatedNum<Scalar>],
) -> Result<Vec<AllocatedNum<Scalar>>, SynthesisError> {
    assert_eq!(elements.len() + 1, 2);
    assert_eq!(bits.len(), 1);

    let (bit, existing) = (&bits[0], &elements[0]);
    // A set bit pushes `element` to the second slot, so `existing` moves to the front.
    let first = pick(cs.namespace(|| "binary insert 0"), bit, existing, element)?;
    let second = pick(cs.namespace(|| "binary insert 1"), bit, element, existing)?;
    Ok(vec![first, second])
}
/// Specialized [`insert`] for four slots.
///
/// Places `element` among the three `elements` at the index `bits[0] + 2 * bits[1]`.
/// Costs 8 picks: two per output slot, first on the low bit and then on the high bit.
///
/// # Panics
/// Panics if `elements.len() != 3` or `bits.len() != 2`.
pub fn insert_4<Scalar: PrimeField, CS: ConstraintSystem<Scalar>>(
    cs: &mut CS,
    element: &AllocatedNum<Scalar>,
    bits: &[Boolean],
    elements: &[AllocatedNum<Scalar>],
) -> Result<Vec<AllocatedNum<Scalar>>, SynthesisError> {
    assert_eq!(elements.len() + 1, 4);
    assert_eq!(bits.len(), 2);

    let (lo, hi) = (&bits[0], &bits[1]);
    let (ins, e0, e1, e2) = (element, &elements[0], &elements[1], &elements[2]);

    // `mux(name, cond, x, y)` allocates `if cond { x } else { y }` under namespace `name`.
    let mut mux = |name: &'static str,
                   cond: &Boolean,
                   if_true: &AllocatedNum<Scalar>,
                   if_false: &AllocatedNum<Scalar>| {
        pick(cs.namespace(|| name), cond, if_true, if_false)
    };

    // Slot 0: `ins` only for index 0, otherwise e0.
    let p0_x0 = mux("p0_x0", lo, e0, ins)?;
    let p0 = mux("p0", hi, e0, &p0_x0)?;
    // Slot 1: e0 for index 0, `ins` for index 1, e1 for indices 2 and 3.
    let p1_x0 = mux("p1_x0", lo, ins, e0)?;
    let p1 = mux("p1", hi, e1, &p1_x0)?;
    // Slot 2: e1 for indices 0 and 1, `ins` for index 2, e2 for index 3.
    let p2_x1 = mux("p2_x1", lo, e2, ins)?;
    let p2 = mux("p2", hi, &p2_x1, e1)?;
    // Slot 3: `ins` only for index 3, otherwise e2.
    let p3_x1 = mux("p3_x1", lo, ins, e2)?;
    let p3 = mux("p3", hi, &p3_x1, e2)?;

    Ok(vec![p0, p1, p2, p3])
}
#[allow(clippy::many_single_char_names)]
pub fn insert_8<Scalar: PrimeField, CS: ConstraintSystem<Scalar>>(
cs: &mut CS,
element: &AllocatedNum<Scalar>,
bits: &[Boolean],
elements: &[AllocatedNum<Scalar>],
) -> Result<Vec<AllocatedNum<Scalar>>, SynthesisError> {
assert_eq!(elements.len() + 1, 8);
assert_eq!(bits.len(), 3);
let (b0, b1, b2) = (&bits[0], &bits[1], &bits[2]);
let (a, b, c, d, e, f, g, h) = (
&element,
&elements[0],
&elements[1],
&elements[2],
&elements[3],
&elements[4],
&elements[5],
&elements[6],
);
let b0_nor_b1 = match (b0, b1) {
(Boolean::Is(ref b0), Boolean::Is(ref b1)) => {
Boolean::Is(AllocatedBit::nor(cs.namespace(|| "b0 nor b1"), b0, b1)?)
}
_ => panic!("bits must be allocated and unnegated"),
};
let b0_and_b1 = match (&bits[0], &bits[1]) {
(Boolean::Is(ref b0), Boolean::Is(ref b1)) => {
Boolean::Is(AllocatedBit::and(cs.namespace(|| "b0 and b1"), b0, b1)?)
}
_ => panic!("bits must be allocated and unnegated"),
};
macro_rules! witness {
( $var:ident <== if $cond:ident { $a:expr } else { $b:expr }) => {
let $var = pick(cs.namespace(|| stringify!($var)), $cond, $a, $b)?;
};
( $var:ident <== if &$cond:ident { $a:expr } else { $b:expr }) => {
let $var = pick(cs.namespace(|| stringify!($var)), &$cond, $a, $b)?;
};
}
witness!(p0_xx0 <== if &b0_nor_b1 { a } else { b });
witness!(p0 <== if b2 { b } else { &p0_xx0 });
witness!(p1_x00 <== if b0 { a } else { b });
witness!(p1_xx0 <== if b1 { c } else { &p1_x00 });
witness!(p1 <== if b2 { c } else { &p1_xx0 });
witness!(p2_x10 <== if b0 { d } else { a });
witness!(p2_xx0 <== if b1 { &p2_x10 } else { c });
witness!(p2 <== if b2 { d } else { &p2_xx0 });
witness!(p3_xx0 <== if &b0_and_b1 { a } else { d });
witness!(p3 <== if b2 { e } else { &p3_xx0 });
witness!(p4_xx1 <== if &b0_nor_b1 { a } else { f });
witness!(p4 <== if b2 { &p4_xx1 } else { e });
witness!(p5_x01 <== if b0 { a } else { f });
witness!(p5_xx1 <== if b1 { g } else { &p5_x01 });
witness!(p5 <== if b2 { &p5_xx1 } else { f });
witness!(p6_x11 <== if b0 { h } else { a });
witness!(p6_xx1 <== if b1 { &p6_x11 } else { g });
witness!(p6 <== if b2 { &p6_xx1 } else { g });
witness!(p7_xx1 <== if &b0_and_b1 { a } else { h });
witness!(p7 <== if b2 { &p7_xx1 } else { h });
Ok(vec![p0, p1, p2, p3, p4, p5, p6, p7])
}
/// Returns `from[index]`, where `index` is the little-endian integer encoded by `path_bits`
/// (`path_bits[0]` is the least-significant bit).
///
/// Builds a binary multiplexer tree. Starting from the most-significant bit, each bit halves
/// the candidate list by picking the upper half when set and the lower half otherwise. The
/// whole tree costs `from.len() - 1` picks.
///
/// # Panics
/// Panics if `from.len() != 1 << path_bits.len()`.
pub fn select<Scalar: PrimeField, CS: ConstraintSystem<Scalar>>(
    mut cs: CS,
    from: &[AllocatedNum<Scalar>],
    path_bits: &[Boolean],
) -> Result<AllocatedNum<Scalar>, SynthesisError> {
    assert_eq!(1 << path_bits.len(), from.len());

    let mut layer: Vec<AllocatedNum<Scalar>> = from.to_vec();
    for (level, bit) in path_bits.iter().rev().enumerate() {
        let (lower, upper) = layer.split_at(layer.len() / 2);
        layer = lower
            .iter()
            .zip(upper)
            .enumerate()
            .map(|(j, (lo, hi))| {
                pick(cs.namespace(|| format!("pick {}, {}", level, j)), bit, hi, lo)
            })
            .collect::<Result<Vec<_>, _>>()?;
    }
    // After consuming every bit exactly one candidate remains.
    Ok(layer.swap_remove(0))
}
/// Allocates and returns `c = if condition { a } else { b }`.
///
/// The selection is enforced by the single constraint `(b - a) * condition = b - c`.
/// When `condition` is 1 this forces `c = a`; when it is 0 it forces `c = b`.
///
/// # Errors
/// Returns [`SynthesisError::AssignmentMissing`] from the witness computation if the value of
/// `condition`, or of the operand it selects, is unknown.
pub fn pick<Scalar: PrimeField, CS: ConstraintSystem<Scalar>>(
    mut cs: CS,
    condition: &Boolean,
    a: &AllocatedNum<Scalar>,
    b: &AllocatedNum<Scalar>,
) -> Result<AllocatedNum<Scalar>, SynthesisError> {
    let c = AllocatedNum::alloc(cs.namespace(|| "pick result"), || {
        let chosen = if condition
            .get_value()
            .ok_or(SynthesisError::AssignmentMissing)?
        {
            a
        } else {
            b
        };
        chosen.get_value().ok_or(SynthesisError::AssignmentMissing)
    })?;
    cs.enforce(
        || "pick",
        |lc| lc + b.get_variable() - a.get_variable(),
        |_| condition.lc(CS::one(), Scalar::one()),
        |lc| lc + b.get_variable() - c.get_variable(),
    );
    Ok(c)
}
#[cfg(test)]
mod tests {
    use super::*;
    use bellperson::util_cs::test_cs::TestConstraintSystem;
    use blstrs::Scalar as Fr;
    use ff::Field;
    use rand::SeedableRng;
    use rand_xorshift::XorShiftRng;
    use crate::TEST_SEED;

    /// Allocates `count` random field elements drawn from `rng`, under the namespaces
    /// `element 0`, `element 1`, ...
    fn alloc_elements(
        cs: &mut TestConstraintSystem<Fr>,
        rng: &mut XorShiftRng,
        count: usize,
    ) -> Vec<AllocatedNum<Fr>> {
        (0..count)
            .map(|i| {
                AllocatedNum::alloc(cs.namespace(|| format!("element {}", i)), || {
                    Ok(Fr::random(&mut *rng))
                })
                .expect("alloc failed")
            })
            .collect()
    }

    /// Allocates the `bit_count` low bits of `index`, least-significant first, under the
    /// namespaces `index bit 0`, `index bit 1`, ...
    fn alloc_index_bits(
        cs: &mut TestConstraintSystem<Fr>,
        index: usize,
        bit_count: usize,
    ) -> Vec<Boolean> {
        (0..bit_count)
            .map(|i| {
                let bit = AllocatedBit::alloc(
                    cs.namespace(|| format!("index bit {}", i)),
                    Some(((index >> i) & 1) == 1),
                )
                .expect("alloc failed");
                Boolean::from(bit)
            })
            .collect()
    }

    #[test]
    fn test_select() {
        for log_size in 1..5 {
            let size = 1 << log_size;
            for index in 0..size {
                let mut rng = XorShiftRng::from_seed(TEST_SEED);
                let mut cs = TestConstraintSystem::new();
                let elements = alloc_elements(&mut cs, &mut rng, size);
                let path_bits = alloc_index_bits(&mut cs, index, log_size);

                // So far only the index bits have added constraints: one per bit.
                let before = cs.num_constraints();
                assert_eq!(log_size, before);

                let selected = select(cs.namespace(|| "select"), &elements, &path_bits)
                    .expect("select failed");
                assert!(cs.is_satisfied());
                assert_eq!(elements[index].get_value(), selected.get_value());

                // A mux tree over `size` leaves costs `size - 1` picks.
                assert_eq!(size - 1, cs.num_constraints() - before);
            }
        }
    }

    #[test]
    fn test_insert() {
        for log_size in 1..=4 {
            let size = 1 << log_size;
            for index in 0..size {
                let mut rng = XorShiftRng::from_seed(TEST_SEED);
                let mut cs = TestConstraintSystem::new();
                let elements = alloc_elements(&mut cs, &mut rng, size - 1);
                let to_insert = AllocatedNum::<Fr>::alloc(cs.namespace(|| "insert"), || {
                    Ok(Fr::random(&mut rng))
                })
                .expect("alloc failed");
                let index_bits = alloc_index_bits(&mut cs, index, log_size);

                let before = cs.num_constraints();
                assert_eq!(log_size, before);

                let mut inserted = insert(&mut cs, &to_insert, &index_bits, &elements)
                    .expect("insert failed");
                assert!(cs.is_satisfied());

                // Removing the inserted slot must give back the original sequence.
                let extracted = inserted.remove(index);
                assert_eq!(to_insert.get_value(), extracted.get_value());
                for (i, original) in elements.iter().enumerate() {
                    assert_eq!(original.get_value(), inserted[i].get_value());
                }

                // Sizes 4 and 8 use specialized circuits; other sizes run one `select`
                // (size - 1 picks) per output slot.
                let expected = match size {
                    8 => 22,
                    4 => 8,
                    _ => size * (size - 1),
                };
                assert_eq!(expected, cs.num_constraints() - before);
            }
        }
    }
}