Skip to main content

ark_r1cs_std/uint/
select.rs

1use super::*;
2
3impl<const N: usize, T: PrimUInt, ConstraintF: PrimeField> CondSelectGadget<ConstraintF>
4    for UInt<N, T, ConstraintF>
5{
6    #[tracing::instrument(target = "gr1cs", skip(cond, true_value, false_value))]
7    fn conditionally_select(
8        cond: &Boolean<ConstraintF>,
9        true_value: &Self,
10        false_value: &Self,
11    ) -> Result<Self, SynthesisError> {
12        let selected_bits = true_value
13            .bits
14            .iter()
15            .zip(&false_value.bits)
16            .map(|(t, f)| cond.select(t, f));
17        let mut bits = [Boolean::FALSE; N];
18        for (result, new) in bits.iter_mut().zip(selected_bits) {
19            *result = new?;
20        }
21
22        let value = cond.value().ok().and_then(|cond| {
23            if cond {
24                true_value.value().ok()
25            } else {
26                false_value.value().ok()
27            }
28        });
29        Ok(Self { bits, value })
30    }
31}
32
#[cfg(test)]
mod tests {
    use super::*;
    use crate::uint::test_utils::{run_binary_exhaustive, run_binary_random};
    use ark_test_curves::bls12_381::Fr;

    /// Checks `cond.select(&a, &b)` against an independently allocated expected
    /// value, for both `cond = true` and `cond = false`.
    ///
    /// The condition is exercised as a constant and, whenever a constraint
    /// system exists, also as a witness. This covers the constant-condition
    /// path of `Boolean::select` even when the operands are variables.
    fn uint_select<T: PrimUInt, const N: usize, F: PrimeField>(
        a: UInt<N, T, F>,
        b: UInt<N, T, F>,
    ) -> Result<(), SynthesisError> {
        let cs = a.cs().or(b.cs());
        let both_constant = a.is_constant() && b.is_constant();
        let expected_mode = if both_constant {
            AllocationMode::Constant
        } else {
            AllocationMode::Witness
        };
        // A witness condition needs a constraint system. When both operands are
        // constants there is none, so only a constant condition can be tested.
        let cond_modes: &[AllocationMode] = if both_constant {
            &[AllocationMode::Constant]
        } else {
            &[AllocationMode::Constant, AllocationMode::Witness]
        };
        for &cond_mode in cond_modes {
            for cond in [true, false] {
                let expected = UInt::new_variable(
                    cs.clone(),
                    || Ok(if cond { a.value()? } else { b.value()? }),
                    expected_mode,
                )?;
                let cond = Boolean::new_variable(cs.clone(), || Ok(cond), cond_mode)?;
                let computed = cond.select(&a, &b)?;

                assert_eq!(expected.value(), computed.value());
                expected.enforce_equal(&computed)?;
                // Satisfiability can only be queried when a constraint system exists.
                if !both_constant {
                    assert!(cs.is_satisfied().unwrap());
                }
            }
        }
        Ok(())
    }

    #[test]
    fn u8_select() {
        run_binary_exhaustive(uint_select::<u8, 8, Fr>).unwrap()
    }

    #[test]
    fn u16_select() {
        run_binary_random::<1000, 16, _, _>(uint_select::<u16, 16, Fr>).unwrap()
    }

    #[test]
    fn u32_select() {
        run_binary_random::<1000, 32, _, _>(uint_select::<u32, 32, Fr>).unwrap()
    }

    #[test]
    fn u64_select() {
        run_binary_random::<1000, 64, _, _>(uint_select::<u64, 64, Fr>).unwrap()
    }

    #[test]
    fn u128_select() {
        run_binary_random::<1000, 128, _, _>(uint_select::<u128, 128, Fr>).unwrap()
    }
}