ark_r1cs_std/uint/
shr.rs

1use ark_ff::PrimeField;
2use ark_relations::r1cs::SynthesisError;
3use ark_std::{ops::Shr, ops::ShrAssign};
4
5use crate::boolean::Boolean;
6
7use super::{PrimUInt, UInt};
8
9impl<const N: usize, T: PrimUInt, F: PrimeField> UInt<N, T, F> {
10    fn _shr_u128(&self, other: u128) -> Result<Self, SynthesisError> {
11        if other < N as u128 {
12            let mut bits = [Boolean::FALSE; N];
13            for (a, b) in bits.iter_mut().zip(&self.bits[other as usize..]) {
14                *a = b.clone();
15            }
16
17            let value = self.value.and_then(|a| Some(a >> other));
18            Ok(Self { bits, value })
19        } else {
20            panic!("attempt to shift right with overflow")
21        }
22    }
23}
24
25impl<const N: usize, T: PrimUInt, F: PrimeField, T2: PrimUInt> Shr<T2> for UInt<N, T, F> {
26    type Output = Self;
27
28    /// Output `self >> other`.
29    ///
30    /// If at least one of `self` and `other` are constants, then this method
31    /// *does not* create any constraints or variables.
32    ///
33    /// ```
34    /// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> {
35    /// // We'll use the BLS12-381 scalar field for our constraints.
36    /// use ark_test_curves::bls12_381::Fr;
37    /// use ark_relations::r1cs::*;
38    /// use ark_r1cs_std::prelude::*;
39    ///
40    /// let cs = ConstraintSystem::<Fr>::new_ref();
41    /// let a = UInt8::new_witness(cs.clone(), || Ok(16))?;
42    /// let b = 1u8;
43    /// let c = UInt8::new_witness(cs.clone(), || Ok(16 >> 1))?;
44    ///
45    /// (a >> b).enforce_equal(&c)?;
46    /// assert!(cs.is_satisfied().unwrap());
47    /// # Ok(())
48    /// # }
49    /// ```
50    #[tracing::instrument(target = "r1cs", skip(self, other))]
51    fn shr(self, other: T2) -> Self::Output {
52        self._shr_u128(other.into()).unwrap()
53    }
54}
55
56impl<'a, const N: usize, T: PrimUInt, F: PrimeField, T2: PrimUInt> Shr<T2> for &'a UInt<N, T, F> {
57    type Output = UInt<N, T, F>;
58
59    #[tracing::instrument(target = "r1cs", skip(self, other))]
60    fn shr(self, other: T2) -> Self::Output {
61        self._shr_u128(other.into()).unwrap()
62    }
63}
64
65impl<const N: usize, T: PrimUInt, F: PrimeField, T2: PrimUInt> ShrAssign<T2> for UInt<N, T, F> {
66    /// Sets `self = self >> other`.
67    ///
68    /// If at least one of `self` and `other` are constants, then this method
69    /// *does not* create any constraints or variables.
70    ///
71    /// ```
72    /// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> {
73    /// // We'll use the BLS12-381 scalar field for our constraints.
74    /// use ark_test_curves::bls12_381::Fr;
75    /// use ark_relations::r1cs::*;
76    /// use ark_r1cs_std::prelude::*;
77    ///
78    /// let cs = ConstraintSystem::<Fr>::new_ref();
79    /// let mut a = UInt8::new_witness(cs.clone(), || Ok(16))?;
80    /// let b = 1u8;
81    /// let c = UInt8::new_witness(cs.clone(), || Ok(16 >> 1))?;
82    ///
83    /// a >>= b;
84    /// a.enforce_equal(&c)?;
85    /// assert!(cs.is_satisfied().unwrap());
86    /// # Ok(())
87    /// # }
88    /// ```
89    #[tracing::instrument(target = "r1cs", skip(self, other))]
90    fn shr_assign(&mut self, other: T2) {
91        let result = self._shr_u128(other.into()).unwrap();
92        *self = result;
93    }
94}
95
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        alloc::{AllocVar, AllocationMode},
        prelude::EqGadget,
        uint::test_utils::{run_binary_exhaustive_native_only, run_binary_random_native_only},
        R1CSVar,
    };
    use ark_ff::PrimeField;
    use ark_relations::r1cs::ConstraintSystem;
    use ark_test_curves::bls12_381::Fr;

    /// Checks `&a >> (b mod N)` against a `UInt` allocated from the native
    /// `a.value() >> (b mod N)`.
    ///
    /// The shift is reduced modulo `N` because shifting by `N` or more panics;
    /// that path is covered by the `*_panics` tests below. The expected value
    /// is allocated as a constant when `a` is constant, otherwise as a witness.
    /// Satisfiability is only checked in the non-constant case.
    fn uint_shr<T: PrimUInt, const N: usize, F: PrimeField>(
        a: UInt<N, T, F>,
        b: T,
    ) -> Result<(), SynthesisError> {
        let cs = a.cs();
        let b = b.into() % (N as u128);
        let computed = &a >> b;
        let expected_mode = if a.is_constant() {
            AllocationMode::Constant
        } else {
            AllocationMode::Witness
        };
        let expected =
            UInt::<N, T, F>::new_variable(cs.clone(), || Ok(a.value()? >> b), expected_mode)?;
        assert_eq!(expected.value(), computed.value());
        expected.enforce_equal(&computed)?;
        if !a.is_constant() {
            assert!(cs.is_satisfied().unwrap());
        }
        Ok(())
    }

    #[test]
    fn u8_shr() {
        run_binary_exhaustive_native_only(uint_shr::<u8, 8, Fr>).unwrap()
    }

    #[test]
    fn u16_shr() {
        run_binary_random_native_only::<1000, 16, _, _>(uint_shr::<u16, 16, Fr>).unwrap()
    }

    #[test]
    fn u32_shr() {
        run_binary_random_native_only::<1000, 32, _, _>(uint_shr::<u32, 32, Fr>).unwrap()
    }

    #[test]
    fn u64_shr() {
        run_binary_random_native_only::<1000, 64, _, _>(uint_shr::<u64, 64, Fr>).unwrap()
    }

    #[test]
    fn u128_shr() {
        run_binary_random_native_only::<1000, 128, _, _>(uint_shr::<u128, 128, Fr>).unwrap()
    }

    /// Shifting by exactly the bit width must panic, like a primitive shift.
    #[test]
    #[should_panic(expected = "attempt to shift right with overflow")]
    fn shr_by_bit_width_panics() {
        let cs = ConstraintSystem::<Fr>::new_ref();
        let a = UInt::<8, u8, Fr>::new_witness(cs, || Ok(16u8)).unwrap();
        let _ = a >> 8u8;
    }

    /// `>>=` shares the same overflow check as `>>`.
    #[test]
    #[should_panic(expected = "attempt to shift right with overflow")]
    fn shr_assign_by_bit_width_panics() {
        let cs = ConstraintSystem::<Fr>::new_ref();
        let mut a = UInt::<8, u8, Fr>::new_witness(cs, || Ok(16u8)).unwrap();
        a >>= 8u8;
    }
}