use itertools::enumerate;
use crate::expression::{BinaryExpression, BooleanExpression, Expression};
use crate::field::{Element, Field};
use crate::gadget_builder::GadgetBuilder;
use crate::wire_values::WireValues;
impl<F: Field> GadgetBuilder<F> {
    /// Assert that `x < y`.
    pub fn assert_lt(&mut self, x: &Expression<F>, y: &Expression<F>) {
        let is_lt = self.lt(x, y);
        self.assert_true(&is_lt);
    }

    /// Assert that `x <= y`.
    pub fn assert_le(&mut self, x: &Expression<F>, y: &Expression<F>) {
        let is_le = self.le(x, y);
        self.assert_true(&is_le);
    }

    /// Assert that `x > y`.
    pub fn assert_gt(&mut self, x: &Expression<F>, y: &Expression<F>) {
        let is_gt = self.gt(x, y);
        self.assert_true(&is_gt);
    }

    /// Assert that `x >= y`.
    pub fn assert_ge(&mut self, x: &Expression<F>, y: &Expression<F>) {
        let is_ge = self.ge(x, y);
        self.assert_true(&is_ge);
    }

    /// Assert that `x < y`, where both operands are already given in binary form.
    ///
    /// # Panics
    /// Panics if `x` and `y` have different bit lengths.
    pub fn assert_lt_binary(&mut self, x: &BinaryExpression<F>, y: &BinaryExpression<F>) {
        let is_lt = self.lt_binary(x, y);
        self.assert_true(&is_lt);
    }

    /// Assert that `x <= y`, where both operands are already given in binary form.
    ///
    /// # Panics
    /// Panics if `x` and `y` have different bit lengths.
    pub fn assert_le_binary(&mut self, x: &BinaryExpression<F>, y: &BinaryExpression<F>) {
        let is_le = self.le_binary(x, y);
        self.assert_true(&is_le);
    }

    /// Assert that `x > y`, where both operands are already given in binary form.
    ///
    /// # Panics
    /// Panics if `x` and `y` have different bit lengths.
    pub fn assert_gt_binary(&mut self, x: &BinaryExpression<F>, y: &BinaryExpression<F>) {
        let is_gt = self.gt_binary(x, y);
        self.assert_true(&is_gt);
    }

    /// Assert that `x >= y`, where both operands are already given in binary form.
    ///
    /// # Panics
    /// Panics if `x` and `y` have different bit lengths.
    pub fn assert_ge_binary(&mut self, x: &BinaryExpression<F>, y: &BinaryExpression<F>) {
        let is_ge = self.ge_binary(x, y);
        self.assert_true(&is_ge);
    }

    /// Return a boolean expression that is 1 iff `x < y`.
    pub fn lt(&mut self, x: &Expression<F>, y: &Expression<F>) -> BooleanExpression<F> {
        self.cmp(x, y, true, true)
    }

    /// Return a boolean expression that is 1 iff `x <= y`.
    pub fn le(&mut self, x: &Expression<F>, y: &Expression<F>) -> BooleanExpression<F> {
        self.cmp(x, y, true, false)
    }

    /// Return a boolean expression that is 1 iff `x > y`.
    pub fn gt(&mut self, x: &Expression<F>, y: &Expression<F>) -> BooleanExpression<F> {
        self.cmp(x, y, false, true)
    }

    /// Return a boolean expression that is 1 iff `x >= y`.
    pub fn ge(&mut self, x: &Expression<F>, y: &Expression<F>) -> BooleanExpression<F> {
        self.cmp(x, y, false, false)
    }

    /// Return a boolean expression that is 1 iff `x < y` (binary operands).
    ///
    /// # Panics
    /// Panics if `x` and `y` have different bit lengths.
    pub fn lt_binary(
        &mut self, x: &BinaryExpression<F>, y: &BinaryExpression<F>,
    ) -> BooleanExpression<F> {
        self.cmp_binary(x, y, true, true)
    }

    /// Return a boolean expression that is 1 iff `x <= y` (binary operands).
    ///
    /// # Panics
    /// Panics if `x` and `y` have different bit lengths.
    pub fn le_binary(
        &mut self, x: &BinaryExpression<F>, y: &BinaryExpression<F>,
    ) -> BooleanExpression<F> {
        self.cmp_binary(x, y, true, false)
    }

    /// Return a boolean expression that is 1 iff `x > y` (binary operands).
    ///
    /// # Panics
    /// Panics if `x` and `y` have different bit lengths.
    pub fn gt_binary(
        &mut self, x: &BinaryExpression<F>, y: &BinaryExpression<F>,
    ) -> BooleanExpression<F> {
        self.cmp_binary(x, y, false, true)
    }

    /// Return a boolean expression that is 1 iff `x >= y` (binary operands).
    ///
    /// # Panics
    /// Panics if `x` and `y` have different bit lengths.
    pub fn ge_binary(
        &mut self, x: &BinaryExpression<F>, y: &BinaryExpression<F>,
    ) -> BooleanExpression<F> {
        self.cmp_binary(x, y, false, false)
    }

    /// Split both field-element operands into bits, then compare them with `cmp_binary`.
    ///
    /// `less` selects `<`/`<=` (true) versus `>`/`>=` (false); `strict` selects the strict form.
    ///
    /// The operand expected to be the smaller one is split with `split_allowing_ambiguity`, and the
    /// other is split canonically.
    /// NOTE(review): presumably this is safe for the `assert_*` callers, because a non-canonical
    /// encoding can only make that operand compare larger, which makes the asserted relation fail.
    /// It may not be enough to pin down the boolean result of `lt`/`le`/`gt`/`ge` against an
    /// adversarial prover — confirm.
    fn cmp(
        &mut self, x: &Expression<F>, y: &Expression<F>, less: bool, strict: bool,
    ) -> BooleanExpression<F> {
        let (x_bin, y_bin) = if less {
            (self.split_allowing_ambiguity(x), self.split(y))
        } else {
            (self.split(x), self.split_allowing_ambiguity(y))
        };
        self.cmp_binary(&x_bin, &y_bin, less, strict)
    }

    /// Compare two equal-length binary operands chunk by chunk.
    ///
    /// Both operands are cut into chunks of `cmp_chunk_bits(len)` bits. A one-hot (or all-zero)
    /// mask selects the most significant chunk at which they differ. Only that chunk pair is
    /// then compared, via `cmp_subtractive`.
    ///
    /// # Panics
    /// Panics if `x_bits` and `y_bits` have different lengths.
    fn cmp_binary(
        &mut self,
        x_bits: &BinaryExpression<F>,
        y_bits: &BinaryExpression<F>,
        less: bool,
        strict: bool,
    ) -> BooleanExpression<F> {
        assert_eq!(x_bits.len(), y_bits.len());
        let operand_bits = x_bits.len();

        let chunk_bits = Self::cmp_chunk_bits(operand_bits);
        let x_chunks: Vec<Expression<F>> = x_bits.chunks(chunk_bits)
            .iter().map(BinaryExpression::join).collect();
        let y_chunks: Vec<Expression<F>> = y_bits.chunks(chunk_bits)
            .iter().map(BinaryExpression::join).collect();
        let chunks = x_chunks.len();

        // mask[i] is meant to be 1 iff chunk i is the most significant chunk where x and y
        // differ; if no chunk differs, every mask bit is 0.
        let mask = self.wires(chunks);
        for &m in &mask {
            self.assert_boolean(&Expression::from(m));
        }
        // The mask sum must be 0 or 1, so at most one mask bit is set.
        let diff_exists = self.assert_boolean(&Expression::sum_of_wires(&mask));

        {
            let x_chunks = x_chunks.clone();
            let y_chunks = y_chunks.clone();
            let mask = mask.clone();
            self.generator(
                [x_bits.dependencies(), y_bits.dependencies()].concat(),
                move |values: &mut WireValues<F>| {
                    // Walk from the most significant chunk down; flag the first difference.
                    let mut seen_diff: bool = false;
                    for (i, &mask_bit) in enumerate(&mask).rev() {
                        let x_chunk_value = x_chunks[i].evaluate(values);
                        let y_chunk_value = y_chunks[i].evaluate(values);
                        let diff = x_chunk_value != y_chunk_value;
                        let mask_bit_value = diff && !seen_diff;
                        seen_diff |= diff;
                        values.set(mask_bit, mask_bit_value.into());
                    }
                },
            );
        }

        // Use the mask to keep only (x_k - y_k) for the selected chunk k (0 if none is selected).
        let mut diff_chunk = Expression::zero();
        for i in 0..chunks {
            let diff = &x_chunks[i] - &y_chunks[i];
            diff_chunk += self.product(&Expression::from(mask[i]), &diff);
        }

        // Chunk i must equal its counterpart unless the selected chunk is at index i or below:
        //   (1 - sum_{j >= i} mask[j]) * (x_i - y_i) = 0.
        // This forces every chunk above the selected one to be equal. It also forces ALL chunks
        // to be equal when no mask bit is set, so an all-zero mask cannot be used to skip the
        // comparison.
        let one: Expression<F> = Expression::from(1u8);
        let mut mask_at_or_above: Expression<F> = Expression::zero();
        for i in (0..chunks).rev() {
            mask_at_or_above += Expression::from(mask[i]);
            let must_be_equal = &one - &mask_at_or_above;
            self.assert_product(&must_be_equal,
                                &(&x_chunks[i] - &y_chunks[i]),
                                &Expression::zero());
        }

        // If a mask bit is set, the selected chunks must actually differ. This is required for
        // strict comparisons as well: otherwise the mask could sit on an equal chunk and force a
        // `false` result. When no bit is set, 42 (arbitrary) stands in as a trivially non-zero value.
        let nonzero = self.selection(&diff_exists, &diff_chunk, &Expression::from(42u8));
        self.assert_nonzero(&nonzero);

        self.cmp_subtractive(diff_chunk, less, strict, chunk_bits)
    }

    /// Compare using the signed difference `diff = x - y` of two values of at most `bits` bits.
    ///
    /// The function computes `z = 2^bits - strict - diff` when `less` is true, or
    /// `z = 2^bits - strict + diff` otherwise. It returns bit `bits` (the top bit) of a
    /// `(bits + 1)`-bit decomposition of `z`. For `less`, that bit is 1 iff `y - x >= strict`,
    /// i.e. `x < y` (strict) or `x <= y` (non-strict). Requires `|diff| < 2^bits`.
    fn cmp_subtractive(&mut self, diff: Expression<F>,
                       less: bool, strict: bool, bits: usize) -> BooleanExpression<F> {
        let base = Expression::from(
            (Element::one() << bits) - Element::from(strict));
        let z = base + if less { -diff } else { diff };
        self.split_bounded(&z, bits + 1).bits[bits].clone()
    }

    /// Rough estimate of the constraint count of `cmp_binary` for the given chunk size.
    /// Only used by `cmp_chunk_bits` to pick a chunk size, so constant offsets do not matter.
    fn cmp_constraints(operand_bits: usize, chunk_bits: usize) -> usize {
        let chunks = (operand_bits + chunk_bits - 1) / chunk_bits;
        3 * chunks + 2 + chunk_bits
    }

    /// Pick the chunk size in `1..max_bits` that minimizes `cmp_constraints`.
    /// Ties keep the smaller chunk size.
    fn cmp_chunk_bits(operand_bits: usize) -> usize {
        let mut best_chunk_bits = 1;
        let mut best_constraints = Self::cmp_constraints(operand_bits, 1);
        for chunk_bits in 2..Element::<F>::max_bits() {
            let constraints = Self::cmp_constraints(operand_bits, chunk_bits);
            if constraints < best_constraints {
                best_chunk_bits = chunk_bits;
                best_constraints = constraints;
            }
        }
        best_chunk_bits
    }
}
#[cfg(test)]
mod tests {
    use crate::expression::Expression;
    use crate::field::{Bn128, Element};
    use crate::gadget_builder::GadgetBuilder;
    use crate::test_util::assert_eq_false;
    use crate::test_util::assert_eq_true;

    /// Exercises all four boolean comparison gadgets on less-than, equal, greater-than,
    /// and multi-chunk (large) operand pairs.
    #[test]
    fn comparisons() {
        let mut builder = GadgetBuilder::<Bn128>::new();
        let (x, y) = (builder.wire(), builder.wire());
        let x_exp = Expression::from(x);
        let y_exp = Expression::from(y);
        let is_lt = builder.lt(&x_exp, &y_exp);
        let is_le = builder.le(&x_exp, &y_exp);
        let is_gt = builder.gt(&x_exp, &y_exp);
        let is_ge = builder.ge(&x_exp, &y_exp);
        let gadget = builder.build();

        // x < y
        let mut values_42_63 = values!(x => 42u8.into(), y => 63u8.into());
        assert!(gadget.execute(&mut values_42_63));
        assert_eq_true(&is_lt, &values_42_63);
        assert_eq_true(&is_le, &values_42_63);
        assert_eq_false(&is_gt, &values_42_63);
        assert_eq_false(&is_ge, &values_42_63);

        // x == y
        let mut values_42_42 = values!(x => 42u8.into(), y => 42u8.into());
        assert!(gadget.execute(&mut values_42_42));
        assert_eq_false(&is_lt, &values_42_42);
        assert_eq_true(&is_le, &values_42_42);
        assert_eq_false(&is_gt, &values_42_42);
        assert_eq_true(&is_ge, &values_42_42);

        // x > y
        let mut values_42_41 = values!(x => 42u8.into(), y => 41u8.into());
        assert!(gadget.execute(&mut values_42_41));
        assert_eq_false(&is_lt, &values_42_41);
        assert_eq_false(&is_le, &values_42_41);
        assert_eq_true(&is_gt, &values_42_41);
        assert_eq_true(&is_ge, &values_42_41);

        // Large operands whose difference lies in a high chunk.
        let mut values_large_lt = values!(
            x => Element::from(1u128 << 80 | 1u128),
            y => Element::from(1u128 << 81));
        assert!(gadget.execute(&mut values_large_lt));
        assert_eq_true(&is_lt, &values_large_lt);
        assert_eq_true(&is_le, &values_large_lt);
        assert_eq_false(&is_gt, &values_large_lt);
        assert_eq_false(&is_ge, &values_large_lt);
    }

    /// `assert_le` must be satisfied for x <= y and unsatisfied for x > y.
    #[test]
    fn assert_le_rejects_greater() {
        let mut builder = GadgetBuilder::<Bn128>::new();
        let (x, y) = (builder.wire(), builder.wire());
        builder.assert_le(&Expression::from(x), &Expression::from(y));
        let gadget = builder.build();

        let mut values_ok = values!(x => 3u8.into(), y => 5u8.into());
        assert!(gadget.execute(&mut values_ok));

        let mut values_bad = values!(x => 5u8.into(), y => 3u8.into());
        assert!(!gadget.execute(&mut values_bad));
    }
}