ark_r1cs_std/boolean/
mod.rs

1use ark_ff::{BitIteratorBE, Field, PrimeField};
2
3use crate::{fields::fp::FpVar, prelude::*, Vec};
4use ark_relations::r1cs::{
5    ConstraintSystemRef, LinearCombination, Namespace, SynthesisError, Variable,
6};
7use core::borrow::Borrow;
8
9mod allocated;
10mod and;
11mod cmp;
12mod convert;
13mod eq;
14mod not;
15mod or;
16mod select;
17mod xor;
18
19pub use allocated::AllocatedBool;
20
21#[cfg(test)]
22mod test_utils;
23
24/// Represents a boolean value in the constraint system which is guaranteed
25/// to be either zero or one.
26#[derive(Clone, Debug, Eq, PartialEq)]
27#[must_use]
28pub enum Boolean<F: Field> {
29    Var(AllocatedBool<F>),
30    Constant(bool),
31}
32
33impl<F: Field> R1CSVar<F> for Boolean<F> {
34    type Value = bool;
35
36    fn cs(&self) -> ConstraintSystemRef<F> {
37        match self {
38            Self::Var(a) => a.cs.clone(),
39            _ => ConstraintSystemRef::None,
40        }
41    }
42
43    fn value(&self) -> Result<Self::Value, SynthesisError> {
44        match self {
45            Boolean::Constant(c) => Ok(*c),
46            Boolean::Var(ref v) => v.value(),
47        }
48    }
49}
50
51impl<F: Field> Boolean<F> {
52    /// The constant `true`.
53    pub const TRUE: Self = Boolean::Constant(true);
54
55    /// The constant `false`.
56    pub const FALSE: Self = Boolean::Constant(false);
57
58    /// Constructs a `Boolean` vector from a slice of constant `u8`.
59    /// The `u8`s are decomposed in little-endian manner.
60    ///
61    /// This *does not* create any new variables or constraints.
62    ///
63    /// ```
64    /// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> {
65    /// // We'll use the BLS12-381 scalar field for our constraints.
66    /// use ark_test_curves::bls12_381::Fr;
67    /// use ark_relations::r1cs::*;
68    /// use ark_r1cs_std::prelude::*;
69    ///
70    /// let cs = ConstraintSystem::<Fr>::new_ref();
71    /// let t = Boolean::<Fr>::TRUE;
72    /// let f = Boolean::<Fr>::FALSE;
73    ///
74    /// let bits = vec![f, t];
75    /// let generated_bits = Boolean::constant_vec_from_bytes(&[2]);
76    /// bits[..2].enforce_equal(&generated_bits[..2])?;
77    /// assert!(cs.is_satisfied().unwrap());
78    /// # Ok(())
79    /// # }
80    /// ```
81    pub fn constant_vec_from_bytes(values: &[u8]) -> Vec<Self> {
82        let mut bits = vec![];
83        for byte in values {
84            for i in 0..8 {
85                bits.push(Self::Constant(((byte >> i) & 1u8) == 1u8));
86            }
87        }
88        bits
89    }
90
91    /// Constructs a constant `Boolean` with value `b`.
92    ///
93    /// This *does not* create any new variables or constraints.
94    /// ```
95    /// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> {
96    /// // We'll use the BLS12-381 scalar field for our constraints.
97    /// use ark_test_curves::bls12_381::Fr;
98    /// use ark_r1cs_std::prelude::*;
99    ///
100    /// let true_var = Boolean::<Fr>::TRUE;
101    /// let false_var = Boolean::<Fr>::FALSE;
102    ///
103    /// true_var.enforce_equal(&Boolean::TRUE)?;
104    /// false_var.enforce_equal(&Boolean::constant(false))?;
105    /// # Ok(())
106    /// # }
107    /// ```
108    pub fn constant(b: bool) -> Self {
109        Boolean::Constant(b)
110    }
111
112    /// Constructs a `LinearCombination` from `Self`'s variables according
113    /// to the following map.
114    ///
115    /// * `Boolean::TRUE => lc!() + Variable::One`
116    /// * `Boolean::FALSE => lc!()`
117    /// * `Boolean::Var(v) => lc!() + v.variable()`
118    pub fn lc(&self) -> LinearCombination<F> {
119        match self {
120            &Boolean::Constant(false) => lc!(),
121            &Boolean::Constant(true) => lc!() + Variable::One,
122            Boolean::Var(v) => v.variable().into(),
123        }
124    }
125
126    /// Convert a little-endian bitwise representation of a field element to
127    /// `FpVar<F>`
128    ///
129    /// Wraps around if the bit representation is larger than the field modulus.
130    #[tracing::instrument(target = "r1cs", skip(bits))]
131    pub fn le_bits_to_fp(bits: &[Self]) -> Result<FpVar<F>, SynthesisError>
132    where
133        F: PrimeField,
134    {
135        // Compute the value of the `FpVar` variable via double-and-add.
136        let mut value = None;
137        let cs = bits.cs();
138        // Assign a value only when `cs` is in setup mode, or if we are constructing
139        // a constant.
140        let should_construct_value = (!cs.is_in_setup_mode()) || bits.is_constant();
141        if should_construct_value {
142            let bits = bits.iter().map(|b| b.value().unwrap()).collect::<Vec<_>>();
143            let bytes = bits
144                .chunks(8)
145                .map(|c| {
146                    let mut value = 0u8;
147                    for (i, &bit) in c.iter().enumerate() {
148                        value += (bit as u8) << i;
149                    }
150                    value
151                })
152                .collect::<Vec<_>>();
153            value = Some(F::from_le_bytes_mod_order(&bytes));
154        }
155
156        if bits.is_constant() {
157            Ok(FpVar::constant(value.unwrap()))
158        } else {
159            let mut power = F::one();
160            // Compute a linear combination for the new field variable, again
161            // via double and add.
162
163            let combined = bits
164                .iter()
165                .map(|b| {
166                    let result = FpVar::from(b.clone()) * power;
167                    power.double_in_place();
168                    result
169                })
170                .sum();
171            // If the number of bits is less than the size of the field,
172            // then we do not need to enforce that the element is less than
173            // the modulus.
174            if bits.len() >= F::MODULUS_BIT_SIZE as usize {
175                Self::enforce_in_field_le(bits)?;
176            }
177            Ok(combined)
178        }
179    }
180}
181
182impl<F: Field> From<AllocatedBool<F>> for Boolean<F> {
183    fn from(b: AllocatedBool<F>) -> Self {
184        Boolean::Var(b)
185    }
186}
187
188impl<F: Field> AllocVar<bool, F> for Boolean<F> {
189    fn new_variable<T: Borrow<bool>>(
190        cs: impl Into<Namespace<F>>,
191        f: impl FnOnce() -> Result<T, SynthesisError>,
192        mode: AllocationMode,
193    ) -> Result<Self, SynthesisError> {
194        if mode == AllocationMode::Constant {
195            Ok(Boolean::Constant(*f()?.borrow()))
196        } else {
197            AllocatedBool::new_variable(cs, f, mode).map(Boolean::Var)
198        }
199    }
200}
201
#[cfg(test)]
mod test {
    use super::Boolean;
    use crate::convert::ToBytesGadget;
    use crate::prelude::*;
    use ark_ff::{
        AdditiveGroup, BitIteratorBE, BitIteratorLE, Field, One, PrimeField, UniformRand,
    };
    use ark_relations::r1cs::{ConstraintSystem, SynthesisError};
    use ark_test_curves::bls12_381::Fr;

    #[test]
    fn test_boolean_to_byte() -> Result<(), SynthesisError> {
        for val in [true, false].iter() {
            let cs = ConstraintSystem::<Fr>::new_ref();
            let a = Boolean::new_witness(cs.clone(), || Ok(*val))?;
            let bytes = a.to_bytes_le()?;
            assert_eq!(bytes.len(), 1);
            let byte = &bytes[0];
            assert_eq!(byte.value()?, *val as u8);

            for (i, bit) in byte.bits.iter().enumerate() {
                assert_eq!(bit.value()?, (byte.value()? >> i) & 1 == 1);
            }
        }
        Ok(())
    }

    #[test]
    fn test_smaller_than_or_equal_to() -> Result<(), SynthesisError> {
        let mut rng = ark_std::test_rng();
        for _ in 0..1000 {
            let mut r = Fr::rand(&mut rng);
            let mut s = Fr::rand(&mut rng);
            if r > s {
                core::mem::swap(&mut r, &mut s)
            }

            let cs = ConstraintSystem::<Fr>::new_ref();

            let native_bits: Vec<_> = BitIteratorLE::new(r.into_bigint()).collect();
            let bits = Vec::new_witness(cs.clone(), || Ok(native_bits))?;
            Boolean::enforce_smaller_or_equal_than_le(&bits, s.into_bigint())?;

            assert!(cs.is_satisfied().unwrap());
        }

        for _ in 0..1000 {
            let r = Fr::rand(&mut rng);
            if r == -Fr::one() {
                continue;
            }
            let s = r + Fr::one();
            let s2 = r.double();
            let cs = ConstraintSystem::<Fr>::new_ref();

            let native_bits: Vec<_> = BitIteratorLE::new(r.into_bigint()).collect();
            let bits = Vec::new_witness(cs.clone(), || Ok(native_bits))?;
            Boolean::enforce_smaller_or_equal_than_le(&bits, s.into_bigint())?;
            if r < s2 {
                Boolean::enforce_smaller_or_equal_than_le(&bits, s2.into_bigint())?;
            }

            assert!(cs.is_satisfied().unwrap());
        }
        Ok(())
    }

    #[test]
    fn test_enforce_in_field() -> Result<(), SynthesisError> {
        {
            let cs = ConstraintSystem::<Fr>::new_ref();

            let mut bits = vec![];
            for b in BitIteratorBE::new(Fr::characteristic()).skip(1) {
                bits.push(Boolean::new_witness(cs.clone(), || Ok(b))?);
            }
            bits.reverse();

            Boolean::enforce_in_field_le(&bits)?;

            assert!(!cs.is_satisfied().unwrap());
        }

        let mut rng = ark_std::test_rng();

        for _ in 0..1000 {
            let r = Fr::rand(&mut rng);
            let cs = ConstraintSystem::<Fr>::new_ref();

            let mut bits = vec![];
            for b in BitIteratorBE::new(r.into_bigint()).skip(1) {
                bits.push(Boolean::new_witness(cs.clone(), || Ok(b))?);
            }
            bits.reverse();

            Boolean::enforce_in_field_le(&bits)?;

            assert!(cs.is_satisfied().unwrap());
        }
        Ok(())
    }

    #[test]
    fn test_bits_to_fp() -> Result<(), SynthesisError> {
        use AllocationMode::*;
        let rng = &mut ark_std::test_rng();
        let cs = ConstraintSystem::<Fr>::new_ref();

        let modes = [Input, Witness, Constant];
        for &mode in modes.iter() {
            for _ in 0..1000 {
                let f = Fr::rand(rng);
                let bits = BitIteratorLE::new(f.into_bigint()).collect::<Vec<_>>();
                let bits: Vec<_> =
                    AllocVar::new_variable(cs.clone(), || Ok(bits.as_slice()), mode)?;
                let f = AllocVar::new_variable(cs.clone(), || Ok(f), mode)?;
                let claimed_f = Boolean::le_bits_to_fp(&bits)?;
                claimed_f.enforce_equal(&f)?;
            }

            for _ in 0..1000 {
                let f = Fr::from(u64::rand(rng));
                let bits = BitIteratorLE::new(f.into_bigint()).collect::<Vec<_>>();
                let bits: Vec<_> =
                    AllocVar::new_variable(cs.clone(), || Ok(bits.as_slice()), mode)?;
                let f = AllocVar::new_variable(cs.clone(), || Ok(f), mode)?;
                let claimed_f = Boolean::le_bits_to_fp(&bits)?;
                claimed_f.enforce_equal(&f)?;
            }
            assert!(cs.is_satisfied().unwrap());
        }

        Ok(())
    }

    /// Checks the constant path of `le_bits_to_fp`. The result must itself be
    /// constant, an empty slice maps to zero, and a constant bit string longer
    /// than the modulus wraps around modulo the field modulus.
    #[test]
    fn test_constant_bits_to_fp() -> Result<(), SynthesisError> {
        let empty = Boolean::<Fr>::le_bits_to_fp(&[])?;
        assert!(empty.is_constant());
        assert_eq!(empty.value()?, Fr::from(0u64));

        let all_ones = vec![Boolean::<Fr>::TRUE; 256];
        let wrapped = Boolean::le_bits_to_fp(&all_ones)?;
        assert!(wrapped.is_constant());
        assert_eq!(
            wrapped.value()?,
            Fr::from_le_bytes_mod_order(&[0xFF; 32])
        );
        Ok(())
    }
}