use crate::{
core::expressions::{
bit_expr::BitExpr,
curve_expr::CurveExpr,
domain::DomainElement,
expr::{EvalValue, Expr},
field_expr::FieldExpr,
},
utils::{
curve_point::CurvePoint,
field::{BaseField, ScalarField},
number::Number,
used_field::UsedField,
},
};
use serde::{Deserialize, Serialize};
use std::{
fmt::Debug,
hash::Hash,
ops::{Add, BitAnd, BitXor, Mul, Neg, Not},
};
/// Common interface for abstract "bounds" values: a description of which values
/// of type `T` an expression may take.
///
/// Implementors are small `Copy` values, and a single concrete `T` can be turned
/// into bounds through `From<T>`.
pub trait IsBounds<T>
where
    Self: Clone + Copy + Debug + PartialEq + Eq + Hash + From<T>,
{
    /// Returns whether `value` is allowed by these bounds.
    fn contains(self, value: T) -> bool;
    /// Returns bounds for the values allowed by both `self` and `other`.
    fn inter(self, other: Self) -> Self;
    /// Returns the single value these bounds allow, when they pin one down.
    fn as_constant(self) -> Option<T>;
    /// Returns whether these bounds allow no value at all.
    fn is_empty(self) -> bool;
}
/// Bounds on a field element of type `T`.
///
/// Membership in `Interval(min, max)` is decided with field subtraction (see
/// `contains`), so an interval may wrap past zero; helpers such as
/// `unsigned_min` treat `min > max` as such a wrapping interval. `All` places
/// no restriction and is the default; `Empty` allows no value.
#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum FieldBounds<T: UsedField> {
    /// Any field element.
    #[default]
    All,
    /// The (cyclic) interval from the first element to the second, inclusive.
    Interval(T, T),
    /// No possible value.
    Empty,
}
impl<T: UsedField> From<T> for FieldBounds<T> {
    /// Builds the degenerate bounds `[value, value]` through `new`.
    fn from(value: T) -> Self {
        Self::new(value, value)
    }
}
impl<T: UsedField> FieldBounds<T> {
    /// Builds the interval `[min, max]`.
    ///
    /// The interval is kept only when `max - min` passes `is_ge_zero`; otherwise the
    /// bounds widen to `All`.
    pub fn new(min: T, max: T) -> Self {
        if (max - min).is_ge_zero() {
            FieldBounds::Interval(min, max)
        } else {
            FieldBounds::All
        }
    }
    /// Returns bounds covering every value allowed by `self` or by `other`.
    ///
    /// `All` absorbs everything and `Empty` is the neutral element. Two intervals are
    /// merged into `[c1, d2]` (the smaller start and the larger end, as reported by
    /// `sort_pair`) when both checks below pass; otherwise the result widens to `All`.
    pub fn union(self, other: Self) -> Self {
        match (self, other) {
            (FieldBounds::All, _) => FieldBounds::All,
            (_, FieldBounds::All) => FieldBounds::All,
            (FieldBounds::Empty, y) => y,
            (y, FieldBounds::Empty) => y,
            (FieldBounds::Interval(a1, a2), FieldBounds::Interval(b1, b2)) => {
                // NOTE(review): assumes `sort_pair` returns (smaller, larger). Confirm which
                // ordering (signed or unsigned) it uses, since that decides how intervals
                // wrapping past zero are merged.
                let (c1, d1) = a1.sort_pair(b1);
                let (c2, d2) = a2.sort_pair(b2);
                if (c2 - c1).is_ge_zero() && (d2 - d1).is_ge_zero() {
                    FieldBounds::new(c1, d2)
                } else {
                    FieldBounds::All
                }
            }
        }
    }
    /// Dispatches a binary operation on bounds according to the operands' shapes.
    ///
    /// `Empty` on either side yields `Empty`. Otherwise two `All`s yield `both_all`.
    /// `All` on the left calls `left_all` with the right interval's endpoints, and `All`
    /// on the right calls `right_all` with the left interval's endpoints. Two intervals
    /// call `both_intervals` with each interval's `(min, max)`.
    fn binary_op_bounds(
        self,
        other: Self,
        both_all: Self,
        left_all: fn(T, T) -> Self,
        right_all: fn(T, T) -> Self,
        both_intervals: fn((T, T), (T, T)) -> Self,
    ) -> Self {
        match (self, other) {
            (_, FieldBounds::Empty) => FieldBounds::Empty,
            (FieldBounds::Empty, _) => FieldBounds::Empty,
            (FieldBounds::All, FieldBounds::All) => both_all,
            (FieldBounds::All, FieldBounds::Interval(min, max)) => left_all(min, max),
            (FieldBounds::Interval(min, max), FieldBounds::All) => right_all(min, max),
            (FieldBounds::Interval(min1, max1), FieldBounds::Interval(min2, max2)) => {
                both_intervals((min1, max1), (min2, max2))
            }
        }
    }
    /// Variant of `binary_op_bounds` for symmetric operations: `one_all` handles an
    /// `All` operand on either side.
    fn binary_sym_op_bounds(
        self,
        other: Self,
        both_all: Self,
        one_all: fn(T, T) -> Self,
        both_intervals: fn((T, T), (T, T)) -> Self,
    ) -> Self {
        self.binary_op_bounds(other, both_all, one_all, one_all, both_intervals)
    }
    /// Smallest allowed value in the unsigned view.
    ///
    /// Returns `ZERO` for `All` and for intervals whose `min > max` (presumably
    /// intervals that wrap through zero). Otherwise returns `min`.
    ///
    /// # Panics
    /// On `Empty` bounds.
    pub fn unsigned_min(&self) -> T {
        match self {
            FieldBounds::All => T::ZERO,
            FieldBounds::Interval(min, max) => {
                if *min > *max {
                    T::ZERO
                } else {
                    *min
                }
            }
            FieldBounds::Empty => {
                panic!("Empty Bounds do not have a minimum.")
            }
        }
    }
    /// Largest allowed value in the unsigned view.
    ///
    /// Returns `ZERO - ONE` (the largest field element) for `All` and for intervals
    /// whose `min > max`. Otherwise returns `max`.
    ///
    /// # Panics
    /// On `Empty` bounds.
    pub fn unsigned_max(&self) -> T {
        match self {
            FieldBounds::All => T::ZERO - T::ONE,
            FieldBounds::Interval(min, max) => {
                if *min > *max {
                    T::ZERO - T::ONE
                } else {
                    *max
                }
            }
            FieldBounds::Empty => {
                panic!("Empty Bounds do not have a maximum.")
            }
        }
    }
    /// Returns `(min, max)` in the signed view when `signed` is set, else in the
    /// unsigned view.
    pub fn min_and_max(&self, signed: bool) -> (T, T) {
        if signed {
            (self.signed_min(), self.signed_max())
        } else {
            (self.unsigned_min(), self.unsigned_max())
        }
    }
    /// Number of bits needed for `unsigned_max`.
    pub fn unsigned_bin_size(&self) -> usize {
        self.unsigned_max().unsigned_bits()
    }
    /// Smallest allowed value in the signed view.
    ///
    /// `T::TWO_INV` is treated as the smallest signed value. It is returned for `All`
    /// and for intervals containing it; otherwise the interval's start is returned.
    ///
    /// # Panics
    /// On `Empty` bounds.
    pub fn signed_min(&self) -> T {
        match self {
            FieldBounds::All => T::TWO_INV,
            FieldBounds::Interval(a, _) => {
                if self.contains(T::TWO_INV) {
                    T::TWO_INV
                } else {
                    *a
                }
            }
            FieldBounds::Empty => {
                panic!("Empty Bounds do not have a signed minimum.")
            }
        }
    }
    /// Largest allowed value in the signed view.
    ///
    /// `T::TWO_INV - T::ONE` is treated as the largest signed value. It is returned for
    /// `All` and for intervals containing it; otherwise the interval's end is returned.
    ///
    /// # Panics
    /// On `Empty` bounds.
    pub fn signed_max(&self) -> T {
        match self {
            FieldBounds::All => T::TWO_INV - T::ONE,
            FieldBounds::Interval(_, b) => {
                if self.contains(T::TWO_INV - T::ONE) {
                    T::TWO_INV - T::ONE
                } else {
                    *b
                }
            }
            FieldBounds::Empty => {
                panic!("Empty Bounds do not have a signed maximum.")
            }
        }
    }
    /// Smallest absolute value among allowed values.
    ///
    /// Returns `ZERO` for `All` and for intervals with `min > max`. Otherwise returns the
    /// smaller of `min` and `|max|`, with the `false` flag passed to `min`
    /// (NOTE(review): presumably selects unsigned comparison; confirm).
    ///
    /// # Panics
    /// On `Empty` bounds.
    pub fn min_abs(&self) -> T {
        match *self {
            FieldBounds::All => T::ZERO,
            FieldBounds::Interval(min, max) => {
                if min > max {
                    T::ZERO
                } else {
                    min.min(max.abs(), false)
                }
            }
            FieldBounds::Empty => {
                panic!("Empty Bounds do not have a minimum absolute value.")
            }
        }
    }
    /// Largest absolute value among allowed values.
    ///
    /// Returns `TWO_INV - ONE` for `All` and for intervals containing `TWO_INV`.
    /// Otherwise returns the larger of `|min|` and `|max|`.
    ///
    /// # Panics
    /// On `Empty` bounds.
    pub fn max_abs(&self) -> T {
        match *self {
            FieldBounds::All => T::TWO_INV - T::ONE,
            FieldBounds::Interval(min, max) => {
                if self.contains(T::TWO_INV) {
                    T::TWO_INV - T::ONE
                } else {
                    min.abs().max(max.abs(), false)
                }
            }
            FieldBounds::Empty => {
                panic!("Empty Bounds do not have a maximum absolute value.")
            }
        }
    }
    /// Whether some allowed value is strictly positive in the signed view.
    pub fn has_positives(&self) -> bool {
        !self.signed_max().is_le_zero()
    }
    /// Whether some allowed value is strictly negative in the signed view.
    pub fn has_negatives(&self) -> bool {
        !self.signed_min().is_ge_zero()
    }
    /// Bits of `max_abs` plus one sign bit.
    pub fn signed_bin_size(&self) -> usize {
        self.max_abs().unsigned_bits() + 1
    }
    /// Dispatches to `signed_bin_size` or `unsigned_bin_size`.
    pub fn bin_size(&self, signed: bool) -> usize {
        if signed {
            self.signed_bin_size()
        } else {
            self.unsigned_bin_size()
        }
    }
    /// Bit `pos` of the minimum and of the maximum, in the chosen (signed/unsigned) view.
    pub fn bits_in_pos(&self, pos: usize, signed: bool) -> (bool, bool) {
        let (min, max) = self.min_and_max(signed);
        if signed {
            (min.signed_bit(pos), max.signed_bit(pos))
        } else {
            (min.unsigned_bit(pos), max.unsigned_bit(pos))
        }
    }
    /// Bit position derived from the bit length `k` of `max - min`.
    ///
    /// Returns 0 when `min == max`. Returns `k - 1` when `min` and `max` differ in bit
    /// `k - 1`. Otherwise asserts that they differ in bit `k` and returns `k`.
    pub fn lowest_included_unique_power_of_2(self, signed: bool) -> usize {
        let (min, max) = self.min_and_max(signed);
        let gap = max - min;
        if gap == 0.into() {
            return 0;
        }
        let gap_bits = gap.unsigned_bits();
        let temp = self.bits_in_pos(gap_bits - 1, signed);
        if temp.0 == temp.1 {
            // The endpoints agree at bit `k - 1`, so they are expected to differ one bit higher.
            let a = self.bits_in_pos(gap_bits, signed);
            assert_ne!(a.0, a.1);
            gap_bits
        } else {
            gap_bits - 1
        }
    }
    /// Signed `(min, max)` converted to `Number`s.
    pub fn to_signed_number_pair(self) -> (Number, Number) {
        let (min, max) = self.min_and_max(true);
        (min.to_signed_number(), max.to_signed_number())
    }
    /// Unsigned `(min, max)` converted to `Number`s.
    pub fn to_unsigned_number_pair(self) -> (Number, Number) {
        let (min, max) = self.min_and_max(false);
        (min.to_unsigned_number(), max.to_unsigned_number())
    }
    /// Dispatches to the signed or unsigned `Number` pair.
    pub fn to_number_pair(self, signed: bool) -> (Number, Number) {
        if signed {
            self.to_signed_number_pair()
        } else {
            self.to_unsigned_number_pair()
        }
    }
    /// A representative value: the interval's start, or `ZERO` for `All` and `Empty`.
    fn get_sample_val(self) -> T {
        match self {
            FieldBounds::Interval(a, _) => a,
            _ => T::ZERO,
        }
    }
}
impl<T: UsedField> IsBounds<T> for FieldBounds<T> {
    /// `All` allows everything and `Empty` nothing. An interval allows `value` when both
    /// `value - min` and `max - value` pass `is_ge_zero`.
    fn contains(self, value: T) -> bool {
        if let FieldBounds::Interval(lo, hi) = self {
            return (value - lo).is_ge_zero() && (hi - value).is_ge_zero();
        }
        self == FieldBounds::All
    }
    /// Intersects two bounds: `Empty` wins, `All` is neutral, and two cyclic intervals
    /// are combined with `max_cyclic` / `min_cyclic` (yielding `Empty` when they are
    /// found not to overlap).
    fn inter(self, other: Self) -> Self {
        // `All` intersected with an interval keeps the interval unchanged.
        fn keep_interval<T: UsedField>(lo: T, hi: T) -> FieldBounds<T> {
            FieldBounds::Interval(lo, hi)
        }
        // Intersection of two intervals.
        fn intersect_intervals<T: UsedField>(
            (lo_a, hi_a): (T, T),
            (lo_b, hi_b): (T, T),
        ) -> FieldBounds<T> {
            // Candidate ends; the flags tell whether the second interval supplied them.
            let (lo, lo_from_b) = lo_a.max_cyclic(lo_b);
            let (hi, hi_from_b) = hi_a.min_cyclic(hi_b);
            // Start of the interval that did NOT supply `hi`, and end of the interval
            // that did NOT supply `lo`.
            let rival_lo = [lo_b, lo_a][hi_from_b as usize];
            let rival_hi = [hi_b, hi_a][lo_from_b as usize];
            let overlaps = (rival_hi - lo).is_ge_zero() && (hi - rival_lo).is_ge_zero();
            if overlaps {
                FieldBounds::Interval(lo, hi)
            } else {
                FieldBounds::Empty
            }
        }
        self.binary_sym_op_bounds(other, FieldBounds::All, keep_interval, intersect_intervals)
    }
    /// An interval whose endpoints coincide pins down a single value.
    fn as_constant(self) -> Option<T> {
        match self {
            FieldBounds::Interval(lo, hi) if (lo - hi).is_zero_vartime() => Some(lo),
            _ => None,
        }
    }
    fn is_empty(self) -> bool {
        matches!(self, FieldBounds::Empty)
    }
}
impl<T: UsedField> Add for FieldBounds<T> {
    type Output = Self;
    /// Bounds of a sum.
    ///
    /// Any `Empty` operand yields `Empty`; otherwise any `All` operand yields `All`.
    /// Two intervals are added endpoint-wise and passed through `new`, which widens to
    /// `All` when the resulting span no longer passes `is_ge_zero`.
    fn add(self, rhs: Self) -> Self::Output {
        fn unbounded<T: UsedField>(_lo: T, _hi: T) -> FieldBounds<T> {
            FieldBounds::All
        }
        fn add_intervals<T: UsedField>(
            (lo_a, hi_a): (T, T),
            (lo_b, hi_b): (T, T),
        ) -> FieldBounds<T> {
            let lo = lo_a + lo_b;
            let hi = hi_a + hi_b;
            FieldBounds::new(lo, hi)
        }
        self.binary_sym_op_bounds(rhs, FieldBounds::All, unbounded, add_intervals)
    }
}
impl<T: UsedField> Add<T> for FieldBounds<T> {
    type Output = Self;
    /// Shifts an interval by the constant `rhs`; `All` and `Empty` are unaffected.
    fn add(self, rhs: T) -> Self::Output {
        if let FieldBounds::Interval(lo, hi) = self {
            FieldBounds::Interval(lo + rhs, hi + rhs)
        } else {
            self
        }
    }
}
impl<T: UsedField> Neg for FieldBounds<T> {
    type Output = Self;
    /// Negates an interval, swapping its endpoints; `All` and `Empty` are unaffected.
    fn neg(self) -> Self::Output {
        match self {
            FieldBounds::Interval(lo, hi) => FieldBounds::Interval(-hi, -lo),
            unchanged => unchanged,
        }
    }
}
impl<T: UsedField> Mul for FieldBounds<T> {
    type Output = Self;
    /// Bounds of the product of a value from `self` and a value from `rhs`.
    ///
    /// `Empty` on either side gives `Empty`, and two `All`s give `All`. `All` times an
    /// interval is `All` unless the interval is exactly `[0, 0]`. For two intervals, each
    /// is split by `split_at_0` into a negated negative part and a non-negative part. The
    /// four sign combinations are multiplied as non-negative intervals, negated where the
    /// signs differ, and merged with `union`.
    fn mul(self, rhs: Self) -> Self::Output {
        // Bounds of `x * y` for `x` in `[0, a]` and `y` in `[0, b]`: `[0, a * b]`, or `All`
        // when `does_mul_overflow` reports that `a * b` overflows.
        fn mul_simple_interval_bounds<T: UsedField>(a: T, b: T) -> FieldBounds<T> {
            if a.does_mul_overflow(b) {
                FieldBounds::All
            } else {
                FieldBounds::new(T::ZERO, a * b)
            }
        }
        // Product bounds of two non-negative intervals. With `d1 = max1 - min1` and
        // `d2 = max2 - min2`, every product lies in
        // `min1 * min2 + [0, d1 * d2 + d1 * min2 + d2 * min1]`. Each of the three terms is
        // bounded separately (any overflow widens the sum to `All`), and the sum is then
        // shifted by `min1 * min2`.
        fn mul_ge_zero_interval_bounds<T: UsedField>(
            (min1, max1): (T, T),
            (min2, max2): (T, T),
        ) -> FieldBounds<T> {
            let (diff1, diff2) = (max1 - min1, max2 - min2);
            let close_to_res = mul_simple_interval_bounds(diff1, diff2)
                + mul_simple_interval_bounds(diff1, min2)
                + mul_simple_interval_bounds(diff2, min1);
            let min_prod = min1 * min2;
            close_to_res + min_prod
        }
        // Splits `[min, max]` into `[negative part, non-negative part]`. The negative part
        // is returned negated, so both parts are non-negative intervals; `None` marks an
        // absent part.
        fn split_at_0<T: UsedField>(min: T, max: T) -> [Option<(T, T)>; 2] {
            if !min.is_ge_zero() {
                if max.is_le_zero() {
                    // Entirely non-positive: only the negated interval `[-max, -min]`.
                    [Some((T::ZERO - max, T::ZERO - min)), None]
                } else {
                    // Straddles zero: negated part `[0, -min]` and non-negative part `[0, max]`.
                    [Some((T::ZERO, T::ZERO - min)), Some((T::ZERO, max))]
                }
            } else if !max.is_ge_zero() {
                // Starts non-negative but ends negative, so it crosses from the largest
                // positive values to the most negative ones. `ZERO - TWO_INV` is presumably
                // the largest positive value (`TWO_INV - ONE` in `signed_max`).
                [
                    Some((T::ZERO - max, T::ZERO - T::TWO_INV)),
                    Some((min, T::ZERO - T::TWO_INV)),
                ]
            } else {
                // Entirely non-negative.
                [None, Some((min, max))]
            }
        }
        // Multiplies two non-negative parts when both exist (else `Empty`), negating the
        // result when `neg` is set.
        fn handle_intervals<T: UsedField>(
            i1: Option<(T, T)>,
            i2: Option<(T, T)>,
            neg: bool,
        ) -> FieldBounds<T> {
            let i_res = if let (Some(i1), Some(i2)) = (i1, i2) {
                mul_ge_zero_interval_bounds(i1, i2)
            } else {
                FieldBounds::Empty
            };
            if neg {
                -i_res
            } else {
                i_res
            }
        }
        self.binary_sym_op_bounds(
            rhs,
            FieldBounds::All,
            // `All` times exactly zero is zero; anything else can be any value.
            |min, max| {
                if min.is_zero_vartime() && max.is_zero_vartime() {
                    FieldBounds::Interval(T::ZERO, T::ZERO)
                } else {
                    FieldBounds::All
                }
            },
            |(min1, max1), (min2, max2)| {
                // `n*` = negated negative part, `p*` = non-negative part. Like signs give
                // non-negative products; mixed signs are negated back.
                let [n1, p1] = split_at_0(min1, max1);
                let [n2, p2] = split_at_0(min2, max2);
                handle_intervals(n1, n2, false)
                    .union(handle_intervals(n1, p2, true))
                    .union(handle_intervals(p1, n2, true))
                    .union(handle_intervals(p1, p2, false))
            },
        )
    }
}
impl<T: UsedField> Mul<T> for FieldBounds<T> {
    type Output = Self;
    /// Bounds of a value from `self` multiplied by the constant `rhs`.
    ///
    /// - `Empty` stays `Empty` whatever `rhs` is. No value reaches this point, which
    ///   matches `Mul<FieldBounds<T>>`, `Add<T>` and `Neg`.
    /// - A zero factor collapses any other bounds to `[0, 0]`.
    /// - `All` times a non-zero factor stays `All`.
    /// - An interval is scaled endpoint-wise, with endpoints swapped for a negative factor.
    ///   If `does_mul_overflow` reports that its width times `|rhs|` overflows, the
    ///   result is `All`.
    fn mul(self, rhs: T) -> Self::Output {
        match self {
            // Unreachable stays unreachable, even when multiplied by zero.
            FieldBounds::Empty => FieldBounds::Empty,
            _ if rhs.is_zero_vartime() => FieldBounds::Interval(T::ZERO, T::ZERO),
            FieldBounds::All => FieldBounds::All,
            FieldBounds::Interval(min, max) => {
                let is_positive = rhs.is_ge_zero();
                let magnitude = if is_positive { rhs } else { -rhs };
                if (max - min).does_mul_overflow(magnitude) {
                    return FieldBounds::All;
                }
                let a = min * rhs;
                let b = max * rhs;
                // A negative factor reverses the order of the endpoints.
                if is_positive {
                    FieldBounds::new(a, b)
                } else {
                    FieldBounds::new(b, a)
                }
            }
        }
    }
}
impl<T: UsedField> From<BoolBounds> for FieldBounds<T> {
    /// Maps a bit's possible values into the field, with `false` as zero and `true` as one.
    /// A bit that can be neither becomes `Empty`.
    fn from(value: BoolBounds) -> Self {
        let lo = if value.can_be_false { T::ZERO } else { T::ONE };
        let hi = if value.can_be_true { T::ONE } else { T::ZERO };
        if value.can_be_false || value.can_be_true {
            FieldBounds::Interval(lo, hi)
        } else {
            FieldBounds::Empty
        }
    }
}
impl<T: UsedField> From<(T, T)> for FieldBounds<T> {
    /// Builds bounds from a `(min, max)` pair through `new`.
    fn from((min, max): (T, T)) -> Self {
        Self::new(min, max)
    }
}
impl<F: UsedField> From<i32> for FieldBounds<F> {
    /// Builds the single-point interval for a small integer constant.
    fn from(value: i32) -> Self {
        let point: F = value.into();
        FieldBounds::Interval(point, point)
    }
}
/// Abstract value of a single bit: which of the two booleans it may take.
///
/// Both flags set means "unknown"; neither set means no value is possible.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct BoolBounds {
    /// Whether the bit may be `false`.
    can_be_false: bool,
    /// Whether the bit may be `true`.
    can_be_true: bool,
}
impl BoolBounds {
    /// Creates bounds from the two "may be" flags.
    pub fn new(can_be_false: bool, can_be_true: bool) -> Self {
        Self {
            can_be_false,
            can_be_true,
        }
    }
    /// Field bounds of `bit * factor`, where `factor` is `F::power_of_two(exponent)` or,
    /// when `is_negative`, `F::negative_power_of_two(exponent)`.
    ///
    /// A bit that may be `true` contributes `factor`, and a bit that may be `false`
    /// contributes zero. When `true` is impossible the result is `[0, 0]`.
    pub fn multiply_by_power_of_two<F: UsedField>(
        self,
        exponent: usize,
        is_negative: bool,
    ) -> FieldBounds<F> {
        let factor = if is_negative {
            F::negative_power_of_two(exponent)
        } else {
            F::power_of_two(exponent)
        };
        match (self.can_be_false, self.can_be_true) {
            (_, false) => FieldBounds::new(F::ZERO, F::ZERO),
            (false, true) => FieldBounds::new(factor, factor),
            // Both outcomes: the interval spans zero and the factor, ordered by its sign.
            (true, true) if is_negative => FieldBounds::new(factor, F::ZERO),
            (true, true) => FieldBounds::new(F::ZERO, factor),
        }
    }
}
impl From<bool> for BoolBounds {
fn from(value: bool) -> Self {
if value {
BoolBounds {
can_be_false: false,
can_be_true: true,
}
} else {
BoolBounds {
can_be_false: true,
can_be_true: false,
}
}
}
}
impl IsBounds<bool> for BoolBounds {
fn contains(self, value: bool) -> bool {
if value {
self.can_be_true
} else {
self.can_be_false
}
}
fn inter(self, other: Self) -> Self {
BoolBounds {
can_be_false: self.can_be_false && other.can_be_false,
can_be_true: self.can_be_true && other.can_be_true,
}
}
fn as_constant(self) -> Option<bool> {
if self.can_be_true && !self.can_be_false {
Some(true)
} else if self.can_be_false && !self.can_be_true {
Some(false)
} else {
None
}
}
fn is_empty(self) -> bool {
!self.can_be_true && !self.can_be_false
}
}
/// Iterator over the booleans allowed by a `BoolBounds`, yielding `false` before `true`.
pub struct BoolBoundsIter {
    /// The bounds being enumerated.
    bool_bounds: BoolBounds,
    /// Progress marker: 0 before anything is yielded, 1 after `false`, 2 after `true`.
    curr: usize,
}
impl BoolBoundsIter {
fn new(bool_bounds: BoolBounds) -> Self {
BoolBoundsIter {
bool_bounds,
curr: 0,
}
}
}
impl Iterator for BoolBoundsIter {
type Item = bool;
fn next(&mut self) -> Option<Self::Item> {
if self.curr == 0 && self.bool_bounds.can_be_false {
self.curr = 1;
Some(false)
} else if self.curr <= 1 && self.bool_bounds.can_be_true {
self.curr = 2;
Some(true)
} else {
None
}
}
}
impl IntoIterator for BoolBounds {
    type Item = bool;
    type IntoIter = BoolBoundsIter;
    /// Enumerates the booleans these bounds allow, `false` first.
    fn into_iter(self) -> Self::IntoIter {
        BoolBoundsIter::new(self)
    }
}
impl BitAnd for BoolBounds {
    type Output = Self;
    /// Bounds of `a & b` for `a` allowed by `self` and `b` allowed by `other`.
    ///
    /// The result can be `true` only if both operands can be `true`. It can be `false`
    /// when one operand can be `false` and the other operand can take at least one value.
    /// If either operand is empty there is no pair to combine, so the result is empty
    /// too, consistent with `BitXor`.
    fn bitand(self, other: Self) -> Self {
        let self_possible = self.can_be_false || self.can_be_true;
        let other_possible = other.can_be_false || other.can_be_true;
        BoolBounds::new(
            (self.can_be_false && other_possible) || (other.can_be_false && self_possible),
            self.can_be_true && other.can_be_true,
        )
    }
}
impl BitXor for BoolBounds {
type Output = Self;
fn bitxor(self, other: Self) -> Self {
BoolBounds::new(
(self.can_be_false && other.can_be_false) || (self.can_be_true && other.can_be_true),
(self.can_be_false && other.can_be_true) || (self.can_be_true && other.can_be_false),
)
}
}
impl Not for BoolBounds {
type Output = Self;
fn not(self) -> Self {
BoolBounds::new(self.can_be_true, self.can_be_false)
}
}
/// Bounds on a curve point: unrestricted, a single known point, or empty.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum CurveBounds {
    /// Any curve point.
    All,
    /// Exactly this point.
    Constant(CurvePoint),
    /// No possible point.
    Empty,
}
impl CurveBounds {
    /// A representative point: the constant itself, or the identity for `All` and `Empty`.
    pub fn get_sample_val(self) -> CurvePoint {
        if let CurveBounds::Constant(point) = self {
            point
        } else {
            CurvePoint::identity()
        }
    }
}
impl From<CurvePoint> for CurveBounds {
fn from(value: CurvePoint) -> Self {
CurveBounds::Constant(value)
}
}
impl IsBounds<CurvePoint> for CurveBounds {
    /// `All` allows every point, `Constant(p)` only `p`, and `Empty` none.
    fn contains(self, value: CurvePoint) -> bool {
        self == CurveBounds::All || self == CurveBounds::Constant(value)
    }
    /// `All` is neutral, equal constants survive, and every other combination is `Empty`.
    fn inter(self, other: Self) -> Self {
        match (self, other) {
            (CurveBounds::All, rhs) => rhs,
            (lhs, CurveBounds::All) => lhs,
            (CurveBounds::Constant(a), CurveBounds::Constant(b)) if a == b => {
                CurveBounds::Constant(a)
            }
            _ => CurveBounds::Empty,
        }
    }
    fn as_constant(self) -> Option<CurvePoint> {
        match self {
            CurveBounds::Constant(point) => Some(point),
            CurveBounds::All | CurveBounds::Empty => None,
        }
    }
    fn is_empty(self) -> bool {
        matches!(self, CurveBounds::Empty)
    }
}
/// Bounds for any expression domain: bits, scalar-field elements, base-field elements
/// and curve points.
pub type Bounds = DomainElement<
    BoolBounds,
    FieldBounds<ScalarField>,
    FieldBounds<BaseField>,
    CurveBounds,
>;
impl Bounds {
    /// Unwraps scalar-field bounds.
    ///
    /// # Panics
    /// If `self` belongs to another domain.
    pub fn to_scalar(self) -> FieldBounds<ScalarField> {
        match self {
            Bounds::Scalar(inner) => inner,
            _ => panic!("Can't get scalar bounds"),
        }
    }
    /// Unwraps bit bounds.
    ///
    /// # Panics
    /// If `self` belongs to another domain.
    pub fn to_bit(self) -> BoolBounds {
        match self {
            Bounds::Bit(inner) => inner,
            _ => panic!("Can't get bool bounds"),
        }
    }
    /// Unwraps curve bounds.
    ///
    /// # Panics
    /// If `self` belongs to another domain.
    pub fn to_curve(self) -> CurveBounds {
        match self {
            Bounds::Curve(inner) => inner,
            _ => panic!("Can't get curve bounds"),
        }
    }
    /// When these bounds pin down a single value, returns it as a literal expression.
    pub fn as_constant_expr(self) -> Option<Expr<usize>> {
        let value = self.as_constant()?;
        Some(match value {
            EvalValue::Scalar(s) => Expr::Scalar(FieldExpr::Val(s)),
            EvalValue::Bit(b) => Expr::Bit(BitExpr::Val(b)),
            EvalValue::Base(s) => Expr::Base(FieldExpr::Val(s)),
            EvalValue::Curve(c) => Expr::Curve(CurveExpr::Val(c)),
        })
    }
    /// A representative value. A bit is `true` only when it cannot be `false` but can
    /// be `true`; the other domains defer to their own `get_sample_val`.
    pub fn get_sample_val(self) -> EvalValue {
        match self {
            Bounds::Bit(bits) => EvalValue::Bit(bits.can_be_true && !bits.can_be_false),
            Bounds::Scalar(inner) => EvalValue::Scalar(inner.get_sample_val()),
            Bounds::Base(inner) => EvalValue::Base(inner.get_sample_val()),
            Bounds::Curve(inner) => EvalValue::Curve(inner.get_sample_val()),
        }
    }
}
impl From<EvalValue> for Bounds {
fn from(value: EvalValue) -> Self {
match value {
EvalValue::Scalar(v) => Bounds::Scalar(v.into()),
EvalValue::Bit(v) => Bounds::Bit(v.into()),
EvalValue::Base(v) => Bounds::Base(v.into()),
EvalValue::Curve(v) => Bounds::Curve(v.into()),
}
}
}
impl IsBounds<EvalValue> for Bounds {
    /// A value from a different domain is never contained.
    fn contains(self, value: EvalValue) -> bool {
        match self {
            Bounds::Bit(b) => matches!(value, EvalValue::Bit(v) if b.contains(v)),
            Bounds::Scalar(b) => matches!(value, EvalValue::Scalar(v) if b.contains(v)),
            Bounds::Base(b) => matches!(value, EvalValue::Base(v) if b.contains(v)),
            Bounds::Curve(b) => matches!(value, EvalValue::Curve(v) if b.contains(v)),
        }
    }
    /// Intersects bounds of the same domain.
    ///
    /// # Panics
    /// If the two bounds belong to different domains.
    fn inter(self, other: Self) -> Self {
        match (self, other) {
            (Bounds::Bit(lhs), Bounds::Bit(rhs)) => Bounds::Bit(lhs.inter(rhs)),
            (Bounds::Scalar(lhs), Bounds::Scalar(rhs)) => Bounds::Scalar(lhs.inter(rhs)),
            (Bounds::Base(lhs), Bounds::Base(rhs)) => Bounds::Base(lhs.inter(rhs)),
            (Bounds::Curve(lhs), Bounds::Curve(rhs)) => Bounds::Curve(lhs.inter(rhs)),
            _ => panic!("Intersecting bounds in different domains."),
        }
    }
    /// Delegates to the domain-specific `as_constant` and tags the result with its domain.
    fn as_constant(self) -> Option<EvalValue> {
        let constant = match self {
            Bounds::Bit(b) => EvalValue::Bit(b.as_constant()?),
            Bounds::Scalar(b) => EvalValue::Scalar(b.as_constant()?),
            Bounds::Base(b) => EvalValue::Base(b.as_constant()?),
            Bounds::Curve(b) => EvalValue::Curve(b.as_constant()?),
        };
        Some(constant)
    }
    /// Delegates to the domain-specific `is_empty`.
    fn is_empty(self) -> bool {
        match self {
            Bounds::Scalar(inner) => inner.is_empty(),
            Bounds::Base(inner) => inner.is_empty(),
            Bounds::Bit(inner) => inner.is_empty(),
            Bounds::Curve(inner) => inner.is_empty(),
        }
    }
}
/// Bounds `[0, 2^u - 1]` for values that fit in `u` unsigned bits.
///
/// Falls back to `All` when `u` exceeds `F::CAPACITY`. `new` may also widen the
/// result to `All` when the span does not pass `is_ge_zero`.
pub fn below_power_of_two<F: UsedField>(u: usize) -> FieldBounds<F> {
    let capacity = F::CAPACITY as usize;
    if u > capacity {
        return FieldBounds::All;
    }
    let upper = F::power_of_two(u) - F::ONE;
    FieldBounds::new(F::ZERO, upper)
}
impl<F: UsedField> FieldBounds<F> {
    /// Whether every allowed value is within `[0, 1]`.
    ///
    /// Empty bounds vacuously qualify, and `All` never does. An interval qualifies when
    /// both of its endpoints lie in `[0, 1]`.
    pub fn is_arithmetic_boolean(self) -> bool {
        match self {
            FieldBounds::Empty => true,
            FieldBounds::All => false,
            FieldBounds::Interval(lo, hi) => {
                let zero_or_one = FieldBounds::new(F::ZERO, F::ONE);
                zero_or_one.contains(lo) && zero_or_one.contains(hi)
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::expressions::{
        field_expr::InputInfo,
        macro_uses::{BoundSampler, BoundWrap},
        random_expr::{ExprGenHelper, ExprGenerator},
        InputKind,
    };
    use ff::Field;
    use rand::Rng;
    use std::rc::Rc;
    // Test-only helpers on `Bounds`.
    impl Bounds {
        /// Whether these are field bounds within `[0, 1]`; bit and curve bounds never are.
        pub fn is_arithmetic_boolean(self) -> bool {
            match self {
                Bounds::Bit(_) => false,
                Bounds::Curve(_) => false,
                Bounds::Scalar(b) => b.is_arithmetic_boolean(),
                Bounds::Base(b) => b.is_arithmetic_boolean(),
            }
        }
        /// Whether these are field bounds containing zero; always `false` for bit and
        /// curve bounds.
        pub fn contains_field_zero(self) -> bool {
            match self {
                Bounds::Bit(_) => false,
                Bounds::Curve(_) => false,
                Bounds::Scalar(b) => b.contains(0.into()),
                Bounds::Base(b) => b.contains(0.into()),
            }
        }
    }
    // Test-only random generation for field bounds.
    impl<F: UsedField> FieldBounds<F> {
        /// Random interval whose endpoints are two samples from `bounds`, passed to `new`
        /// in the order whose difference passes `is_ge_zero`.
        fn gen_interval<R: Rng + ?Sized>(rng: &mut R, bounds: FieldBounds<F>) -> FieldBounds<F> {
            let a = bounds.sample(rng);
            let b = bounds.sample(rng);
            if (b - a).is_ge_zero() {
                FieldBounds::new(a, b)
            } else {
                FieldBounds::new(b, a)
            }
        }
        /// Random sub-bounds of `bounds`. When `bounds` is `All`, a coin flip may keep it
        /// unchanged; otherwise a random interval is drawn from it.
        pub fn gen_bounds<R: Rng + ?Sized>(rng: &mut R, bounds: FieldBounds<F>) -> FieldBounds<F> {
            // The coin is only flipped (and the rng only advanced) when `bounds` is `All`.
            if bounds == FieldBounds::All && rng.gen_bool(0.5) {
                FieldBounds::All
            } else {
                Self::gen_interval(rng, bounds)
            }
        }
        /// The interval `[-v, v]`, built through `new`.
        fn symmetric_around_zero(v: F) -> Self {
            FieldBounds::new(-v, v)
        }
        /// Draws one value allowed by these bounds.
        ///
        /// # Panics
        /// On `Empty` bounds.
        pub fn sample<R: Rng + ?Sized>(self, rng: &mut R) -> F {
            match self {
                FieldBounds::All => F::random(rng),
                FieldBounds::Interval(min, max) => F::gen_inclusive_range(rng, min, max),
                FieldBounds::Empty => {
                    panic!("Cannot generate from empty field bounds")
                }
            }
        }
        /// Converts these bounds into an `InputInfo` of the given `kind`, with `All`
        /// mapped to `min = 0`, `max = -1`. The remaining fields use `InputInfo::default()`.
        ///
        /// # Panics
        /// On `Empty` bounds.
        pub fn as_input_info(self, kind: InputKind) -> Rc<InputInfo<F>> {
            let info = match self {
                FieldBounds::All => InputInfo {
                    kind,
                    min: F::ZERO,
                    max: -F::ONE,
                    ..InputInfo::default()
                },
                FieldBounds::Interval(min, max) => InputInfo {
                    kind,
                    min,
                    max,
                    ..InputInfo::default()
                },
                FieldBounds::Empty => {
                    panic!("Cannot generate input_info from empty field bounds")
                }
            };
            Rc::new(info)
        }
    }
    // Test-only random generation for bit bounds.
    impl BoolBounds {
        /// Random non-empty sub-bounds. Fully unknown bounds become unknown, always-false
        /// or always-true with probabilities 1/2, 1/4 and 1/4; any other bounds are
        /// returned unchanged.
        pub fn gen_bounds<R: Rng + ?Sized>(rng: &mut R, bounds: BoolBounds) -> BoolBounds {
            if bounds.can_be_false && bounds.can_be_true {
                if rng.gen_bool(0.5) {
                    BoolBounds::new(true, true)
                } else if rng.gen_bool(0.5) {
                    BoolBounds::new(true, false)
                } else {
                    BoolBounds::new(false, true)
                }
            } else {
                bounds
            }
        }
        /// Draws one allowed boolean (a fair coin when both are allowed).
        ///
        /// # Panics
        /// On empty bounds.
        pub fn sample<R: Rng + ?Sized>(self, rng: &mut R) -> bool {
            if self.can_be_false && self.can_be_true {
                rng.gen_bool(0.5)
            } else if self.can_be_false {
                false
            } else if self.can_be_true {
                true
            } else {
                panic!("Cannot generate in empty bool bounds")
            }
        }
    }
    // Test-only random generation for curve bounds.
    impl CurveBounds {
        /// Random sub-bounds. `All` becomes a random constant point half of the time;
        /// other bounds are returned unchanged.
        pub fn gen_bounds<R: Rng + ?Sized>(rng: &mut R, bounds: CurveBounds) -> CurveBounds {
            if bounds == CurveBounds::All {
                let is_constant = rng.gen_bool(0.5);
                if is_constant {
                    CurveBounds::Constant(R::gen(rng))
                } else {
                    CurveBounds::All
                }
            } else {
                bounds
            }
        }
        /// Draws one allowed point (random for `All`).
        ///
        /// # Panics
        /// On `Empty` bounds.
        pub fn sample<R: Rng + ?Sized>(self, rng: &mut R) -> CurvePoint {
            match self {
                CurveBounds::All => R::gen(rng),
                CurveBounds::Constant(a) => a,
                CurveBounds::Empty => {
                    panic!("Cannot generate in empty curve bounds")
                }
            }
        }
    }
    /// Generator of random expressions whose leaves are random bounds, with field
    /// magnitudes limited by the `Number` passed to `new`.
    struct BoundExprGenHelper {
        bool_bounds: BoolBounds,
        scalar_bounds: FieldBounds<ScalarField>,
        scalar_cond_bounds: FieldBounds<ScalarField>,
        scalar_pos_bounds: FieldBounds<ScalarField>,
        base_field_bounds: FieldBounds<BaseField>,
        base_field_cond_bounds: FieldBounds<BaseField>,
        base_field_pos_bounds: FieldBounds<BaseField>,
        // Largest bit width used for `*_eda` leaves: the bit length of the scalar bound.
        max_eda_size: usize,
    }
    impl BoundExprGenHelper {
        /// Builds the generator. Magnitudes are capped at `min(number, modulus / 2)` per
        /// field: general leaves use `[-bound, bound]`, conditions use `[0, 1]`, and
        /// positive leaves use `[1, bound]`.
        fn new(number: &Number) -> BoundExprGenHelper {
            let bool_bounds = BoolBounds::new(true, true);
            let scalar_bound = number.clone().min(ScalarField::modulus() / 2).into();
            let scalar_bounds = FieldBounds::symmetric_around_zero(scalar_bound);
            let scalar_cond_bounds = FieldBounds::new(ScalarField::ZERO, ScalarField::ONE);
            let scalar_pos_bounds = FieldBounds::new(ScalarField::ONE, scalar_bound);
            // `new` must not have widened the positive range to `All`.
            assert_ne!(scalar_pos_bounds, FieldBounds::All);
            let base_field_bound = number.clone().min(BaseField::modulus() / 2).into();
            let base_field_bounds = FieldBounds::symmetric_around_zero(base_field_bound);
            let base_field_cond_bounds = FieldBounds::new(BaseField::ZERO, BaseField::ONE);
            let base_field_pos_bounds = FieldBounds::new(BaseField::ONE, base_field_bound);
            assert_ne!(base_field_pos_bounds, FieldBounds::All);
            let max_eda_size = scalar_bound.unsigned_bits();
            BoundExprGenHelper {
                bool_bounds,
                scalar_bounds,
                scalar_cond_bounds,
                scalar_pos_bounds,
                base_field_bounds,
                base_field_cond_bounds,
                base_field_pos_bounds,
                max_eda_size,
            }
        }
    }
    impl ExprGenHelper for BoundExprGenHelper {
        type ScalarType = FieldBounds<ScalarField>;
        type BitType = BoolBounds;
        type BaseType = FieldBounds<BaseField>;
        type CurveType = CurveBounds;
        fn scalar<R: Rng + ?Sized>(&self, rng: &mut R) -> FieldBounds<ScalarField> {
            FieldBounds::gen_bounds(rng, self.scalar_bounds)
        }
        fn bit<R: Rng + ?Sized>(&self, rng: &mut R) -> BoolBounds {
            BoolBounds::gen_bounds(rng, self.bool_bounds)
        }
        fn base<R: Rng + ?Sized>(&self, rng: &mut R) -> Self::BaseType {
            FieldBounds::gen_bounds(rng, self.base_field_bounds)
        }
        fn curve_point<R: Rng + ?Sized>(&self, rng: &mut R) -> Self::CurveType {
            CurveBounds::gen_bounds(rng, CurveBounds::All)
        }
        fn curve_val<R: Rng + ?Sized>(&self, rng: &mut R) -> CurvePoint {
            R::gen(rng)
        }
        fn scalar_cond<R: Rng + ?Sized>(&self, rng: &mut R) -> FieldBounds<ScalarField> {
            FieldBounds::gen_bounds(rng, self.scalar_cond_bounds)
        }
        fn scalar_pos<R: Rng + ?Sized>(&self, rng: &mut R) -> FieldBounds<ScalarField> {
            FieldBounds::gen_bounds(rng, self.scalar_pos_bounds)
        }
        // `[0, 2^k - 1]` for a random `k` in `1..=max_eda_size`.
        fn scalar_eda<R: Rng + ?Sized>(&self, rng: &mut R) -> FieldBounds<ScalarField> {
            let eda_size = rng.gen_range(1..=self.max_eda_size);
            below_power_of_two(eda_size)
        }
        fn scalar_int<R: Rng + ?Sized>(&self, rng: &mut R) -> ScalarField {
            ScalarField::random(rng)
        }
        fn base_field_cond<R: Rng + ?Sized>(&self, rng: &mut R) -> FieldBounds<BaseField> {
            FieldBounds::gen_bounds(rng, self.base_field_cond_bounds)
        }
        fn base_field_pos<R: Rng + ?Sized>(&self, rng: &mut R) -> FieldBounds<BaseField> {
            FieldBounds::gen_bounds(rng, self.base_field_pos_bounds)
        }
        // Same bit-width range as `scalar_eda` (both use `max_eda_size`).
        fn base_field_eda<R: Rng + ?Sized>(&self, rng: &mut R) -> FieldBounds<BaseField> {
            let eda_size = rng.gen_range(1..=self.max_eda_size);
            below_power_of_two(eda_size)
        }
        fn base_field_int<R: Rng + ?Sized>(&self, rng: &mut R) -> BaseField {
            BaseField::random(rng)
        }
    }
    /// Randomized check of the bounds analysis. For 4096 random expressions per magnitude
    /// cap, the computed bounds must contain the value obtained by evaluating the
    /// expression on sampled leaves. When every dependency has constant bounds and the
    /// expression is deterministic in its dependencies, the bounds must also be constant.
    #[test]
    fn bounds_test() {
        let rng = &mut crate::utils::test_rng::get();
        for bound in [Number::from(1), 4.into(), Number::power_of_two(255)] {
            let mut gen_helper = BoundExprGenHelper::new(&bound);
            for _ in 0..4096 {
                let expr = gen_helper.expr(rng);
                // NOTE(review): presumably `BoundWrap` turns each bound leaf into a
                // dependency carrying that bound; confirm in `macro_uses`.
                let deps_are_all_constant = expr
                    .clone()
                    .apply_2(&mut BoundWrap)
                    .get_deps()
                    .iter()
                    .all(|x| x.as_constant().is_some());
                // Replace every leaf bound by a value sampled from it, then evaluate.
                let val_expr = expr.clone().apply_2(&mut BoundSampler(rng));
                let val = val_expr.clone().eval();
                // Expressions whose evaluation fails are not checked.
                if val.is_err() {
                    continue;
                }
                let bounds = expr.clone().bounds();
                let val = val.unwrap();
                // Soundness: the computed bounds contain the concrete value.
                if !bounds.contains(val) {
                    panic!("{bounds:?} do not contain {val:?}, from {val_expr:?}, from {expr:?}");
                }
                // Precision: constant, deterministic inputs must give constant bounds.
                if deps_are_all_constant
                    && expr.is_eval_deterministic_fn_from_deps()
                    && bounds.as_constant().is_none()
                {
                    panic!("{bounds:?} are not constant {val:?}, from {val_expr:?}, from {expr:?}");
                }
            }
        }
    }
}