ark_r1cs_std/boolean/
select.rs

1use super::*;
2
3impl<F: PrimeField> Boolean<F> {
4    /// Conditionally selects one of `first` and `second` based on the value of
5    /// `self`:
6    ///
7    /// If `self.is_eq(&Boolean::TRUE)`, this outputs `first`; else, it outputs
8    /// `second`.
9    /// ```
10    /// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> {
11    /// // We'll use the BLS12-381 scalar field for our constraints.
12    /// use ark_test_curves::bls12_381::Fr;
13    /// use ark_relations::r1cs::*;
14    /// use ark_r1cs_std::prelude::*;
15    ///
16    /// let cs = ConstraintSystem::<Fr>::new_ref();
17    ///
18    /// let a = Boolean::new_witness(cs.clone(), || Ok(true))?;
19    /// let b = Boolean::new_witness(cs.clone(), || Ok(false))?;
20    ///
21    /// let cond = Boolean::new_witness(cs.clone(), || Ok(true))?;
22    ///
23    /// cond.select(&a, &b)?.enforce_equal(&Boolean::TRUE)?;
24    /// cond.select(&b, &a)?.enforce_equal(&Boolean::FALSE)?;
25    ///
26    /// assert!(cs.is_satisfied().unwrap());
27    /// # Ok(())
28    /// # }
29    /// ```
30    #[tracing::instrument(target = "r1cs", skip(first, second))]
31    pub fn select<T: CondSelectGadget<F>>(
32        &self,
33        first: &T,
34        second: &T,
35    ) -> Result<T, SynthesisError> {
36        T::conditionally_select(&self, first, second)
37    }
38}
39impl<F: PrimeField> CondSelectGadget<F> for Boolean<F> {
40    #[tracing::instrument(target = "r1cs")]
41    fn conditionally_select(
42        cond: &Boolean<F>,
43        true_val: &Self,
44        false_val: &Self,
45    ) -> Result<Self, SynthesisError> {
46        use Boolean::*;
47        match cond {
48            Constant(true) => Ok(true_val.clone()),
49            Constant(false) => Ok(false_val.clone()),
50            cond @ Var(_) => match (true_val, false_val) {
51                (x, &Constant(false)) => Ok(cond & x),
52                (&Constant(false), x) => Ok((!cond) & x),
53                (&Constant(true), x) => Ok(cond | x),
54                (x, &Constant(true)) => Ok((!cond) | x),
55                (a, b) => {
56                    let cs = cond.cs();
57                    let result: Boolean<F> =
58                        AllocatedBool::new_witness_without_booleanity_check(cs.clone(), || {
59                            let cond = cond.value()?;
60                            Ok(if cond { a.value()? } else { b.value()? })
61                        })?
62                        .into();
63                    // a = self; b = other; c = cond;
64                    //
65                    // r = c * a + (1  - c) * b
66                    // r = b + c * (a - b)
67                    // c * (a - b) = r - b
68                    //
69                    // If a, b, cond are all boolean, so is r.
70                    //
71                    // self | other | cond | result
72                    // -----|-------|----------------
73                    //   0  |   0   |   1  |    0
74                    //   0  |   1   |   1  |    0
75                    //   1  |   0   |   1  |    1
76                    //   1  |   1   |   1  |    1
77                    //   0  |   0   |   0  |    0
78                    //   0  |   1   |   0  |    1
79                    //   1  |   0   |   0  |    0
80                    //   1  |   1   |   0  |    1
81                    cs.enforce_constraint(
82                        cond.lc(),
83                        lc!() + a.lc() - b.lc(),
84                        lc!() + result.lc() - b.lc(),
85                    )?;
86
87                    Ok(result)
88                },
89            },
90        }
91    }
92}
93
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        alloc::{AllocVar, AllocationMode},
        boolean::test_utils::run_binary_exhaustive,
        prelude::EqGadget,
        R1CSVar,
    };
    use ark_relations::r1cs::ConstraintSystem;
    use ark_test_curves::bls12_381::Fr;

    /// Exhaustively checks `Boolean::select` against a directly computed
    /// expected value.
    ///
    /// Covers every allocation mode and value of the two branches (via
    /// `run_binary_exhaustive`), both values of the condition, and both a
    /// constant and a witness condition. The witness-condition-with-constant-
    /// branches case exercises the AND/OR short-circuit arms on constants.
    #[test]
    fn select() {
        run_binary_exhaustive::<Fr>(|a, b| {
            let both_constant = a.is_constant() && b.is_constant();
            for cond_mode in [AllocationMode::Constant, AllocationMode::Witness] {
                let cond_is_constant = matches!(cond_mode, AllocationMode::Constant);
                // The selection is only a constant when the branches and the
                // condition all are; otherwise there is a constraint system to check.
                let all_constant = both_constant && cond_is_constant;

                let cs = a.cs().or(b.cs());
                // Two constant branches carry no constraint system, so a witness
                // condition needs a fresh one to be allocated in.
                let cs = if cs.is_none() && !cond_is_constant {
                    ConstraintSystem::<Fr>::new_ref()
                } else {
                    cs
                };
                let expected_mode = if all_constant {
                    AllocationMode::Constant
                } else {
                    AllocationMode::Witness
                };

                for cond in [true, false] {
                    let expected = Boolean::new_variable(
                        cs.clone(),
                        || Ok(if cond { a.value()? } else { b.value()? }),
                        expected_mode,
                    )?;
                    let cond = Boolean::new_variable(cs.clone(), || Ok(cond), cond_mode)?;
                    let computed = cond.select(&a, &b)?;

                    assert_eq!(expected.value(), computed.value());
                    expected.enforce_equal(&computed)?;
                    if !all_constant {
                        assert!(cs.is_satisfied().unwrap());
                    }
                }
            }
            Ok(())
        })
        .unwrap()
    }
}