ark_r1cs_std/boolean/select.rs
1use super::*;
2
3impl<F: PrimeField> Boolean<F> {
4 /// Conditionally selects one of `first` and `second` based on the value of
5 /// `self`:
6 ///
7 /// If `self.is_eq(&Boolean::TRUE)`, this outputs `first`; else, it outputs
8 /// `second`.
9 /// ```
10 /// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> {
11 /// // We'll use the BLS12-381 scalar field for our constraints.
12 /// use ark_test_curves::bls12_381::Fr;
13 /// use ark_relations::r1cs::*;
14 /// use ark_r1cs_std::prelude::*;
15 ///
16 /// let cs = ConstraintSystem::<Fr>::new_ref();
17 ///
18 /// let a = Boolean::new_witness(cs.clone(), || Ok(true))?;
19 /// let b = Boolean::new_witness(cs.clone(), || Ok(false))?;
20 ///
21 /// let cond = Boolean::new_witness(cs.clone(), || Ok(true))?;
22 ///
23 /// cond.select(&a, &b)?.enforce_equal(&Boolean::TRUE)?;
24 /// cond.select(&b, &a)?.enforce_equal(&Boolean::FALSE)?;
25 ///
26 /// assert!(cs.is_satisfied().unwrap());
27 /// # Ok(())
28 /// # }
29 /// ```
30 #[tracing::instrument(target = "r1cs", skip(first, second))]
31 pub fn select<T: CondSelectGadget<F>>(
32 &self,
33 first: &T,
34 second: &T,
35 ) -> Result<T, SynthesisError> {
36 T::conditionally_select(&self, first, second)
37 }
38}
39impl<F: PrimeField> CondSelectGadget<F> for Boolean<F> {
40 #[tracing::instrument(target = "r1cs")]
41 fn conditionally_select(
42 cond: &Boolean<F>,
43 true_val: &Self,
44 false_val: &Self,
45 ) -> Result<Self, SynthesisError> {
46 use Boolean::*;
47 match cond {
48 Constant(true) => Ok(true_val.clone()),
49 Constant(false) => Ok(false_val.clone()),
50 cond @ Var(_) => match (true_val, false_val) {
51 (x, &Constant(false)) => Ok(cond & x),
52 (&Constant(false), x) => Ok((!cond) & x),
53 (&Constant(true), x) => Ok(cond | x),
54 (x, &Constant(true)) => Ok((!cond) | x),
55 (a, b) => {
56 let cs = cond.cs();
57 let result: Boolean<F> =
58 AllocatedBool::new_witness_without_booleanity_check(cs.clone(), || {
59 let cond = cond.value()?;
60 Ok(if cond { a.value()? } else { b.value()? })
61 })?
62 .into();
63 // a = self; b = other; c = cond;
64 //
65 // r = c * a + (1 - c) * b
66 // r = b + c * (a - b)
67 // c * (a - b) = r - b
68 //
69 // If a, b, cond are all boolean, so is r.
70 //
71 // self | other | cond | result
72 // -----|-------|----------------
73 // 0 | 0 | 1 | 0
74 // 0 | 1 | 1 | 0
75 // 1 | 0 | 1 | 1
76 // 1 | 1 | 1 | 1
77 // 0 | 0 | 0 | 0
78 // 0 | 1 | 0 | 1
79 // 1 | 0 | 0 | 0
80 // 1 | 1 | 0 | 1
81 cs.enforce_constraint(
82 cond.lc(),
83 lc!() + a.lc() - b.lc(),
84 lc!() + result.lc() - b.lc(),
85 )?;
86
87 Ok(result)
88 },
89 },
90 }
91 }
92}
93
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        alloc::{AllocVar, AllocationMode},
        boolean::test_utils::run_binary_exhaustive,
        prelude::EqGadget,
        R1CSVar,
    };
    use ark_test_curves::bls12_381::Fr;

    /// Checks `Boolean::select` against a natively computed expectation for
    /// every `(a, b)` pair supplied by `run_binary_exhaustive` and for both
    /// condition values.
    ///
    /// When at least one of `a`/`b` is a variable, the condition is allocated
    /// both as a constant and as a witness, so the constant-condition fast
    /// paths and the variable-condition paths are both exercised.
    #[test]
    fn select() {
        run_binary_exhaustive::<Fr>(|a, b| {
            let cs = a.cs().or(b.cs());
            let both_constant = a.is_constant() && b.is_constant();
            let expected_mode = if both_constant {
                AllocationMode::Constant
            } else {
                AllocationMode::Witness
            };
            // With two constant inputs there is no constraint system to
            // allocate a witness condition in, so only a constant one is used.
            let cond_modes: &[AllocationMode] = if both_constant {
                &[AllocationMode::Constant]
            } else {
                &[AllocationMode::Constant, AllocationMode::Witness]
            };
            for &cond_mode in cond_modes {
                for cond in [true, false] {
                    let expected = Boolean::new_variable(
                        cs.clone(),
                        || Ok(if cond { a.value()? } else { b.value()? }),
                        expected_mode,
                    )?;
                    let cond = Boolean::new_variable(cs.clone(), || Ok(cond), cond_mode)?;
                    let computed = cond.select(&a, &b)?;

                    assert_eq!(expected.value(), computed.value());
                    expected.enforce_equal(&computed)?;
                    if !both_constant {
                        assert!(cs.is_satisfied().unwrap());
                    }
                }
            }
            Ok(())
        })
        .unwrap()
    }
}