ark_r1cs_std/uint/select.rs

1use super::*;
2use crate::select::CondSelectGadget;
3
4impl<const N: usize, T: PrimUInt, ConstraintF: PrimeField> CondSelectGadget<ConstraintF>
5    for UInt<N, T, ConstraintF>
6{
7    #[tracing::instrument(target = "r1cs", skip(cond, true_value, false_value))]
8    fn conditionally_select(
9        cond: &Boolean<ConstraintF>,
10        true_value: &Self,
11        false_value: &Self,
12    ) -> Result<Self, SynthesisError> {
13        let selected_bits = true_value
14            .bits
15            .iter()
16            .zip(&false_value.bits)
17            .map(|(t, f)| cond.select(t, f));
18        let mut bits = [Boolean::FALSE; N];
19        for (result, new) in bits.iter_mut().zip(selected_bits) {
20            *result = new?;
21        }
22
23        let value = cond.value().ok().and_then(|cond| {
24            if cond {
25                true_value.value().ok()
26            } else {
27                false_value.value().ok()
28            }
29        });
30        Ok(Self { bits, value })
31    }
32}
33
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        alloc::{AllocVar, AllocationMode},
        prelude::EqGadget,
        uint::test_utils::{run_binary_exhaustive, run_binary_random},
    };
    use ark_ff::PrimeField;
    use ark_test_curves::bls12_381::Fr;

    /// Exercises `Boolean::select` on the pair `(a, b)` for both values of the
    /// condition.
    ///
    /// For each condition the helper does three things:
    /// - checks that the selected native value matches a freshly allocated
    ///   copy of the expected operand;
    /// - enforces equality between that copy and the selected result;
    /// - asserts that the constraint system is satisfied, unless both
    ///   operands are constant.
    fn uint_select<T: PrimUInt, const N: usize, F: PrimeField>(
        a: UInt<N, T, F>,
        b: UInt<N, T, F>,
    ) -> Result<(), SynthesisError> {
        let cs = a.cs().or(b.cs());
        // When both operands are constant, every allocation below is constant
        // too and the satisfiability check is skipped. Presumably no real
        // constraint system is attached in that case.
        let all_constant = a.is_constant() && b.is_constant();
        let mode = if all_constant {
            AllocationMode::Constant
        } else {
            AllocationMode::Witness
        };

        for choice in [true, false] {
            let expected = UInt::new_variable(
                cs.clone(),
                || {
                    let picked = if choice { a.value()? } else { b.value()? };
                    Ok(picked)
                },
                mode,
            )?;
            let condition = Boolean::new_variable(cs.clone(), || Ok(choice), mode)?;
            let computed = condition.select(&a, &b)?;

            assert_eq!(expected.value(), computed.value());
            expected.enforce_equal(&computed)?;
            if !all_constant {
                assert!(cs.is_satisfied().unwrap());
            }
        }
        Ok(())
    }

    #[test]
    fn u8_select() {
        let check = uint_select::<u8, 8, Fr>;
        run_binary_exhaustive(check).unwrap()
    }

    #[test]
    fn u16_select() {
        let check = uint_select::<u16, 16, Fr>;
        run_binary_random::<1000, 16, _, _>(check).unwrap()
    }

    #[test]
    fn u32_select() {
        let check = uint_select::<u32, 32, Fr>;
        run_binary_random::<1000, 32, _, _>(check).unwrap()
    }

    #[test]
    fn u64_select() {
        let check = uint_select::<u64, 64, Fr>;
        run_binary_random::<1000, 64, _, _>(check).unwrap()
    }

    #[test]
    fn u128_select() {
        let check = uint_select::<u128, 128, Fr>;
        run_binary_random::<1000, 128, _, _>(check).unwrap()
    }
}