Skip to main content

ark_r1cs_std/boolean/
mod.rs

1use ark_ff::{BitIteratorBE, Field, PrimeField};
2
3use crate::{fields::fp::FpVar, prelude::*, Vec};
4use ark_relations::gr1cs::{
5    ConstraintSystemRef, LinearCombination, Namespace, SynthesisError, Variable,
6};
7use core::borrow::Borrow;
8
9mod allocated;
10mod and;
11mod cmp;
12mod convert;
13mod eq;
14mod not;
15mod or;
16mod select;
17mod xor;
18
19pub use allocated::AllocatedBool;
20
21#[cfg(test)]
22mod test_utils;
23
24/// Represents a boolean value in the constraint system which is guaranteed
25/// to be either zero or one.
26#[derive(Clone, Debug, Eq, PartialEq)]
27#[must_use]
28pub enum Boolean<F: Field> {
29    /// A variable boolean value.
30    Var(AllocatedBool<F>),
31    /// A constant boolean value.
32    Constant(bool),
33}
34
35impl<F: Field> GR1CSVar<F> for Boolean<F> {
36    type Value = bool;
37
38    fn cs(&self) -> ConstraintSystemRef<F> {
39        match self {
40            Self::Var(a) => a.cs.clone(),
41            _ => ConstraintSystemRef::None,
42        }
43    }
44
45    fn value(&self) -> Result<Self::Value, SynthesisError> {
46        match self {
47            Boolean::Constant(c) => Ok(*c),
48            Boolean::Var(ref v) => v.value(),
49        }
50    }
51}
52
53impl<F: Field> Boolean<F> {
54    /// The constant `true`.
55    pub const TRUE: Self = Boolean::Constant(true);
56
57    /// The constant `false`.
58    pub const FALSE: Self = Boolean::Constant(false);
59
60    /// Constructs a `Boolean` vector from a slice of constant `u8`.
61    /// The `u8`s are decomposed in little-endian manner.
62    ///
63    /// This *does not* create any new variables or constraints.
64    ///
65    /// ```
66    /// # fn main() -> Result<(), ark_relations::gr1cs::SynthesisError> {
67    /// // We'll use the BLS12-381 scalar field for our constraints.
68    /// use ark_test_curves::bls12_381::Fr;
69    /// use ark_relations::gr1cs::*;
70    /// use ark_r1cs_std::prelude::*;
71    ///
72    /// let cs = ConstraintSystem::<Fr>::new_ref();
73    /// let t = Boolean::<Fr>::TRUE;
74    /// let f = Boolean::<Fr>::FALSE;
75    ///
76    /// let bits = vec![f, t];
77    /// let generated_bits = Boolean::constant_vec_from_bytes(&[2]);
78    /// bits[..2].enforce_equal(&generated_bits[..2])?;
79    /// assert!(cs.is_satisfied().unwrap());
80    /// # Ok(())
81    /// # }
82    /// ```
83    pub fn constant_vec_from_bytes(values: &[u8]) -> Vec<Self> {
84        let mut bits = vec![];
85        for byte in values {
86            for i in 0..8 {
87                bits.push(Self::Constant(((byte >> i) & 1u8) == 1u8));
88            }
89        }
90        bits
91    }
92
93    /// Constructs a constant `Boolean` with value `b`.
94    ///
95    /// This *does not* create any new variables or constraints.
96    /// ```
97    /// # fn main() -> Result<(), ark_relations::gr1cs::SynthesisError> {
98    /// // We'll use the BLS12-381 scalar field for our constraints.
99    /// use ark_test_curves::bls12_381::Fr;
100    /// use ark_r1cs_std::prelude::*;
101    ///
102    /// let true_var = Boolean::<Fr>::TRUE;
103    /// let false_var = Boolean::<Fr>::FALSE;
104    ///
105    /// true_var.enforce_equal(&Boolean::TRUE)?;
106    /// false_var.enforce_equal(&Boolean::constant(false))?;
107    /// # Ok(())
108    /// # }
109    /// ```
110    pub fn constant(b: bool) -> Self {
111        Boolean::Constant(b)
112    }
113
114    /// Constructs a `LinearCombination` from `Self`'s variables according
115    /// to the following map.
116    ///
117    /// * `Boolean::TRUE => lc!() + Variable::One`
118    /// * `Boolean::FALSE => lc!()`
119    /// * `Boolean::Var(v) => lc!() + v.variable()`
120    pub fn lc(&self) -> LinearCombination<F> {
121        match self {
122            &Boolean::Constant(false) => lc!(),
123            &Boolean::Constant(true) => Variable::One.into(),
124            Boolean::Var(v) => v.variable().into(),
125        }
126    }
127
128    /// Constructs a `Variable` from `Self`'s variables according
129    /// to the following map.
130    ///
131    /// * `Boolean::TRUE => Variable::One`
132    /// * `Boolean::FALSE => Variable::Zero``
133    /// * `Boolean::Var(v) => v.variable()`
134    pub fn variable(&self) -> Variable {
135        match self {
136            &Boolean::Constant(false) => Variable::Zero,
137            &Boolean::Constant(true) => Variable::One,
138            Boolean::Var(v) => v.variable(),
139        }
140    }
141
142    /// Convert a little-endian bitwise representation of a field element to
143    /// `FpVar<F>`
144    ///
145    /// Wraps around if the bit representation is larger than the field modulus.
146    #[tracing::instrument(target = "gr1cs", skip(bits))]
147    pub fn le_bits_to_fp(bits: &[Self]) -> Result<FpVar<F>, SynthesisError>
148    where
149        F: PrimeField,
150    {
151        // Compute the value of the `FpVar` variable via double-and-add.
152        let mut value = None;
153        let cs = bits.cs();
154        // Assign a value only when `cs` is in setup mode, or if we are constructing
155        // a constant.
156        let should_construct_value = (!cs.is_in_setup_mode()) || bits.is_constant();
157        if should_construct_value {
158            let bits = bits.iter().map(|b| b.value().unwrap()).collect::<Vec<_>>();
159            let bytes = bits
160                .chunks(8)
161                .map(|c| {
162                    let mut value = 0u8;
163                    for (i, &bit) in c.iter().enumerate() {
164                        value += (bit as u8) << i;
165                    }
166                    value
167                })
168                .collect::<Vec<_>>();
169            value = Some(F::from_le_bytes_mod_order(&bytes));
170        }
171
172        if bits.is_constant() {
173            Ok(FpVar::constant(value.unwrap()))
174        } else {
175            let mut power = F::one();
176            // Compute a linear combination for the new field variable, again
177            // via double and add.
178
179            let combined = bits
180                .iter()
181                .map(|b| {
182                    let result = FpVar::from(b.clone()) * power;
183                    power.double_in_place();
184                    result
185                })
186                .sum();
187            // If the number of bits is less than the size of the field,
188            // then we do not need to enforce that the element is less than
189            // the modulus.
190            if bits.len() >= F::MODULUS_BIT_SIZE as usize {
191                Self::enforce_in_field_le(bits)?;
192            }
193            Ok(combined)
194        }
195    }
196}
197
198impl<F: Field> From<AllocatedBool<F>> for Boolean<F> {
199    fn from(b: AllocatedBool<F>) -> Self {
200        Boolean::Var(b)
201    }
202}
203
204impl<F: Field> AllocVar<bool, F> for Boolean<F> {
205    fn new_variable<T: Borrow<bool>>(
206        cs: impl Into<Namespace<F>>,
207        f: impl FnOnce() -> Result<T, SynthesisError>,
208        mode: AllocationMode,
209    ) -> Result<Self, SynthesisError> {
210        if mode == AllocationMode::Constant {
211            Ok(Boolean::Constant(*f()?.borrow()))
212        } else {
213            AllocatedBool::new_variable(cs, f, mode).map(Boolean::Var)
214        }
215    }
216}
217
#[cfg(test)]
mod test {
    use super::Boolean;
    use crate::{convert::ToBytesGadget, prelude::*};
    use ark_ff::{
        AdditiveGroup, BitIteratorBE, BitIteratorLE, Field, One, PrimeField, UniformRand,
    };
    use ark_relations::gr1cs::{ConstraintSystem, SynthesisError};
    use ark_test_curves::bls12_381::Fr;

    #[test]
    fn test_boolean_to_byte() -> Result<(), SynthesisError> {
        for val in [true, false].iter() {
            let cs = ConstraintSystem::<Fr>::new_ref();
            let a = Boolean::new_witness(cs.clone(), || Ok(*val))?;
            let bytes = a.to_bytes_le()?;
            assert_eq!(bytes.len(), 1);
            let byte = &bytes[0];
            assert_eq!(byte.value()?, *val as u8);

            for (i, bit) in byte.bits.iter().enumerate() {
                assert_eq!(bit.value()?, (byte.value()? >> i) & 1 == 1);
            }
        }
        Ok(())
    }

    #[test]
    fn test_smaller_than_or_equal_to() -> Result<(), SynthesisError> {
        let mut rng = ark_std::test_rng();
        for _ in 0..1000 {
            let mut r = Fr::rand(&mut rng);
            let mut s = Fr::rand(&mut rng);
            if r > s {
                core::mem::swap(&mut r, &mut s)
            }

            let cs = ConstraintSystem::<Fr>::new_ref();

            let native_bits: Vec<_> = BitIteratorLE::new(r.into_bigint()).collect();
            let bits = Vec::new_witness(cs.clone(), || Ok(native_bits))?;
            Boolean::enforce_smaller_or_equal_than_le(&bits, s.into_bigint())?;

            assert!(cs.is_satisfied().unwrap());
        }

        for _ in 0..1000 {
            let r = Fr::rand(&mut rng);
            if r == -Fr::one() {
                continue;
            }
            let s = r + Fr::one();
            let s2 = r.double();
            let cs = ConstraintSystem::<Fr>::new_ref();

            let native_bits: Vec<_> = BitIteratorLE::new(r.into_bigint()).collect();
            let bits = Vec::new_witness(cs.clone(), || Ok(native_bits))?;
            Boolean::enforce_smaller_or_equal_than_le(&bits, s.into_bigint())?;
            if r < s2 {
                Boolean::enforce_smaller_or_equal_than_le(&bits, s2.into_bigint())?;
            }

            assert!(cs.is_satisfied().unwrap());
        }
        Ok(())
    }

    #[test]
    fn test_enforce_in_field() -> Result<(), SynthesisError> {
        {
            let cs = ConstraintSystem::<Fr>::new_ref();

            let mut bits = vec![];
            for b in BitIteratorBE::new(Fr::characteristic()).skip(1) {
                bits.push(Boolean::new_witness(cs.clone(), || Ok(b))?);
            }
            bits.reverse();

            Boolean::enforce_in_field_le(&bits)?;

            assert!(!cs.is_satisfied().unwrap());
        }

        let mut rng = ark_std::test_rng();

        for _ in 0..1000 {
            let r = Fr::rand(&mut rng);
            let cs = ConstraintSystem::<Fr>::new_ref();

            let mut bits = vec![];
            for b in BitIteratorBE::new(r.into_bigint()).skip(1) {
                bits.push(Boolean::new_witness(cs.clone(), || Ok(b))?);
            }
            bits.reverse();

            Boolean::enforce_in_field_le(&bits)?;

            assert!(cs.is_satisfied().unwrap());
        }
        Ok(())
    }

    #[test]
    fn test_bits_to_fp() -> Result<(), SynthesisError> {
        use AllocationMode::*;
        let rng = &mut ark_std::test_rng();
        let cs = ConstraintSystem::<Fr>::new_ref();

        let modes = [Input, Witness, Constant];
        for &mode in modes.iter() {
            for _ in 0..1000 {
                let f = Fr::rand(rng);
                let bits = BitIteratorLE::new(f.into_bigint()).collect::<Vec<_>>();
                let bits: Vec<_> =
                    AllocVar::new_variable(cs.clone(), || Ok(bits.as_slice()), mode)?;
                let f = AllocVar::new_variable(cs.clone(), || Ok(f), mode)?;
                let claimed_f = Boolean::le_bits_to_fp(&bits)?;
                claimed_f.enforce_equal(&f)?;
            }

            for _ in 0..1000 {
                let f = Fr::from(u64::rand(rng));
                let bits = BitIteratorLE::new(f.into_bigint()).collect::<Vec<_>>();
                let bits: Vec<_> =
                    AllocVar::new_variable(cs.clone(), || Ok(bits.as_slice()), mode)?;
                let f = AllocVar::new_variable(cs.clone(), || Ok(f), mode)?;
                let claimed_f = Boolean::le_bits_to_fp(&bits)?;
                claimed_f.enforce_equal(&f)?;
            }
            assert!(cs.is_satisfied().unwrap());
        }

        Ok(())
    }

    /// `le_bits_to_fp` on an empty slice and on a mix of constant and
    /// witness bits.
    #[test]
    fn test_bits_to_fp_edge_cases() -> Result<(), SynthesisError> {
        // No bits at all encode zero, as a constant.
        let empty = Boolean::<Fr>::le_bits_to_fp(&[])?;
        assert!(empty.is_constant());
        assert_eq!(empty.value()?, Fr::from(0u64));

        // Mixed constant and witness bits, little-endian: 0b1011 = 11.
        let cs = ConstraintSystem::<Fr>::new_ref();
        let bits = vec![
            Boolean::new_witness(cs.clone(), || Ok(true))?,
            Boolean::TRUE,
            Boolean::FALSE,
            Boolean::new_witness(cs.clone(), || Ok(true))?,
        ];
        let packed = Boolean::le_bits_to_fp(&bits)?;
        assert!(!packed.is_constant());
        assert_eq!(packed.value()?, Fr::from(11u64));
        assert!(cs.is_satisfied().unwrap());
        Ok(())
    }

    /// `constant_vec_from_bytes` emits each byte's bits LSB-first, in byte
    /// order, without touching any constraint system.
    #[test]
    fn test_constant_vec_from_bytes() -> Result<(), SynthesisError> {
        let bits = Boolean::<Fr>::constant_vec_from_bytes(&[0b0000_0110, 0xFF]);
        assert_eq!(bits.len(), 16);
        assert!(bits.iter().all(|b| b.is_constant()));

        let values = bits
            .iter()
            .map(|b| b.value())
            .collect::<Result<Vec<_>, _>>()?;
        let expected = [false, true, true, false, false, false, false, false]
            .iter()
            .copied()
            .chain(core::iter::repeat(true).take(8))
            .collect::<Vec<_>>();
        assert_eq!(values, expected);

        assert!(Boolean::<Fr>::constant_vec_from_bytes(&[]).is_empty());
        Ok(())
    }
}