// sigma_compiler/rangeutils.rs

//! A module containing utility functions for the runtime processing of
//! range statements.
3
4use group::ff::PrimeField;
5use sigma_proofs::errors::Error;
6use subtle::Choice;
7
8/// Convert a [`Scalar`] to an [`u128`], assuming it fits in an [`i128`]
9/// and is nonnegative.  Also output the number of bits of the
10/// [`Scalar`].  This version assumes that `s` is public, and so does
11/// not need to run in constant time.
12///
13/// [`Scalar`]: https://docs.rs/group/0.13.0/group/trait.Group.html#associatedtype.Scalar
14pub fn bit_decomp_vartime<S: PrimeField>(mut s: S) -> Option<(u128, u32)> {
15    let mut val = 0u128;
16    let mut bitnum = 0u32;
17    let mut bitval = 1u128; // Invariant: bitval = 2^bitnum
18    while bitnum < 127 && !s.is_zero_vartime() {
19        if s.is_odd().into() {
20            val += bitval;
21            s -= S::ONE;
22        }
23        bitnum += 1;
24        bitval <<= 1;
25        s *= S::TWO_INV;
26    }
27    if s.is_zero_vartime() {
28        Some((val, bitnum))
29    } else {
30        None
31    }
32}
33
34/// Convert the low `nbits` bits of the given [`Scalar`] to a vector of
35/// [`Choice`].  The first element of the vector is the low bit.  This
36/// version runs in constant time.
37///
38/// [`Scalar`]: https://docs.rs/group/0.13.0/group/trait.Group.html#associatedtype.Scalar
39pub fn bit_decomp<S: PrimeField>(mut s: S, nbits: u32) -> Vec<Choice> {
40    let mut bits = Vec::with_capacity(nbits as usize);
41    let mut bitnum = 0u32;
42    while bitnum < nbits && bitnum < 127 {
43        let lowbit = s.is_odd();
44        s -= S::conditional_select(&S::ZERO, &S::ONE, lowbit);
45        s *= S::TWO_INV;
46        bits.push(lowbit);
47        bitnum += 1;
48    }
49    bits
50}
51
52/// Given a [`Scalar`] `upper` (strictly greater than 1), make a vector
53/// of [`Scalar`]s with the property that a [`Scalar`] `x` can be
54/// written as a sum of zero or more (distinct) elements of this vector
55/// if and only if `0 <= x < upper`.
56///
57/// The strategy is to write x as a sequence of `nbits` bits, with one
58/// twist: the low bits represent 2^0, 2^1, 2^2, etc., as usual.  But
59/// the highest bit represents `upper-2^{nbits-1}` instead of the usual
60/// `2^{nbits-1}`.  `nbits` will be the largest value for which
61/// `2^{nbits-1}` is strictly less than `upper`.  For example, if
62/// `upper` is 100, the bits represent 1, 2, 4, 8, 16, 32, 36.  A number
63/// x can be represented as a sum of 0 or more elements of this sequence
64/// if and only if `0 <= x < upper`.
65///
66/// It is assumed that `upper` is public, and so this function is not
67/// constant time.
68///
69/// [`Scalar`]: https://docs.rs/group/0.13.0/group/trait.Group.html#associatedtype.Scalar
70pub fn bitrep_scalars_vartime<S: PrimeField>(upper: S) -> Result<Vec<S>, Error> {
71    // Get the `u128` value of `upper`, and its number of bits `nbits`
72    let (upper_val, mut nbits) = bit_decomp_vartime(upper).ok_or(Error::VerificationFailure)?;
73
74    // Ensure `nbits` is at least 2.
75    if nbits < 2 {
76        return Err(Error::VerificationFailure);
77    }
78
79    // If upper is exactly a power of 2, use one fewer bit
80    if upper_val == 1u128 << (nbits - 1) {
81        nbits -= 1;
82    }
83
84    // Make the vector of Scalars containing the represented value of
85    // the bits
86    Ok((0..nbits)
87        .map(|i| {
88            if i < nbits - 1 {
89                S::from_u128(1u128 << i)
90            } else {
91                // Compute the represented value of the highest bit
92                S::from_u128(upper_val - (1u128 << (nbits - 1)))
93            }
94        })
95        .collect())
96}
97
98/// Given a vector of [`Scalar`]s as output by
99/// [`bitrep_scalars_vartime`] and a private [`Scalar`] `x`, output a
100/// vector of [`Choice`] (of the same length as the given
101/// `bitrep_scalars` vector) such that `x` is the sum of the chosen
102/// elements of `bitrep_scalars`.  This function should be constant time
103/// in the value of `x`.  If `x` is not less than the `upper` used by
104/// [`bitrep_scalars_vartime`] to generate `bitrep_scalars`, then `x`
105/// will not (and indeed cannot) equal the sum of the chosen elements of
106/// `bitrep_scalars`.
107///
108/// [`Scalar`]: https://docs.rs/group/0.13.0/group/trait.Group.html#associatedtype.Scalar
109pub fn compute_bitrep<S: PrimeField>(mut x: S, bitrep_scalars: &[S]) -> Vec<Choice> {
110    // We know the length of bitrep_scalars is at most 127.
111    let nbits: u32 = bitrep_scalars.len().try_into().unwrap();
112
113    // Decompose `x` as a normal `nbit`-bit vector.  This only looks at
114    // the low `nbits` bits of `x`, so the resulting bit vector forces
115    // `x < 2^{nbits}`.
116    let x_raw_bits = bit_decomp(x, nbits);
117    let high_bit = x_raw_bits[(nbits as usize) - 1];
118
119    // Conditionally subtract the last represented value in the
120    // vector, depending on whether the high bit of x is set.  That is,
121    // if `x < 2^{nbits-1}`, then we don't subtract from x.  If `x >=
122    // 2^{nbits-1}`, then we will subtract `upper - 2^{nbits-1}` from
123    // `x`.  In either case, the remaining value is non-negative, and
124    // strictly less than 2^{nbits-1}.
125    x -= S::conditional_select(&S::ZERO, &bitrep_scalars[(nbits as usize) - 1], high_bit);
126
127    // Now get the `nbits-1` bits of the result in the usual way
128    let mut x_bits = bit_decomp(x, nbits - 1);
129
130    // and tack on the high bit
131    x_bits.push(high_bit);
132
133    x_bits
134}
135
#[cfg(test)]
mod tests {
    use super::*;
    use curve25519_dalek::scalar::Scalar;
    use std::ops::Neg;
    use subtle::ConditionallySelectable;

    /// Run [`bit_decomp`] on `s` with `nbits` and compare the result,
    /// rendered as a string of '0' and '1' characters (low bit first),
    /// against `expect_bitstr`.
    fn bit_decomp_tester(s: Scalar, nbits: u32, expect_bitstr: &str) {
        let bitstr: String = bit_decomp(s, nbits)
            .into_iter()
            .map(|c| char::from(u8::conditional_select(&b'0', &b'1', c)))
            .collect();
        assert_eq!(bitstr, expect_bitstr);
    }

    #[test]
    fn bit_decomp_test() {
        // Small values: (input, (value, bit length)).
        for (input, expected) in [
            (0u32, (0u128, 0u32)),
            (1, (1, 1)),
            (2, (2, 2)),
            (3, (3, 2)),
            (4, (4, 3)),
            (5, (5, 3)),
            (6, (6, 3)),
            (7, (7, 3)),
            (8, (8, 4)),
        ] {
            assert_eq!(bit_decomp_vartime(Scalar::from(input)), Some(expected));
        }

        // -1 is a huge field element, far beyond 127 bits.
        assert_eq!(bit_decomp_vartime(Scalar::from(1u32).neg()), None);

        // The boundary at 2^127: the largest accepted values, then the
        // smallest rejected one.
        assert_eq!(
            bit_decomp_vartime(Scalar::from((1u128 << 127) - 2)),
            Some(((i128::MAX - 1) as u128, 127))
        );
        assert_eq!(
            bit_decomp_vartime(Scalar::from((1u128 << 127) - 1)),
            Some((i128::MAX as u128, 127))
        );
        assert_eq!(bit_decomp_vartime(Scalar::from(1u128 << 127)), None);

        bit_decomp_tester(Scalar::from(0u32), 0, "");
        bit_decomp_tester(Scalar::from(0u32), 5, "00000");
        bit_decomp_tester(Scalar::from(1u32), 0, "");
        bit_decomp_tester(Scalar::from(1u32), 1, "1");
        bit_decomp_tester(Scalar::from(2u32), 1, "0");
        bit_decomp_tester(Scalar::from(2u32), 2, "01");
        bit_decomp_tester(Scalar::from(3u32), 1, "1");
        bit_decomp_tester(Scalar::from(3u32), 2, "11");
        bit_decomp_tester(Scalar::from(5u32), 8, "10100000");
        // The order of this Scalar group is
        // 0x1000000000000000000000000000000014def9dea2f79cd65812631a5cf5d3ed
        bit_decomp_tester(
            Scalar::from(1u32).neg(),
            32,
            "00110111110010111010111100111010",
        );

        // 2^127 - 2: a zero low bit followed by 126 one bits.
        bit_decomp_tester(
            Scalar::from((1u128 << 127) - 2),
            127,
            &format!("0{}", "1".repeat(126)),
        );
        // 2^127 - 1: 127 one bits.
        bit_decomp_tester(Scalar::from((1u128 << 127) - 1), 127, &"1".repeat(127));
        // 2^127: the low 127 bits are all zero.
        bit_decomp_tester(Scalar::from(1u128 << 127), 127, &"0".repeat(127));
        // Asking for 128 bits still yields only 127, since `bit_decomp`
        // caps its output at 127 bits.
        bit_decomp_tester(Scalar::from(1u128 << 127), 128, &"0".repeat(127));
    }

    /// Obliviously test whether x is in 0..upper (that is, 0 <= x <
    /// upper) using bit decomposition.  `upper` is considered public,
    /// but `x` is private.  `upper` must be at least 2.
    ///
    /// Returns an error if `bitrep_scalars_vartime` rejects `upper`, or
    /// if the outcome ("x equals the sum of the chosen scalars") does
    /// not match `expected`.
    fn bitrep_tester(upper: Scalar, x: Scalar, expected: bool) -> Result<(), Error> {
        let rep_scalars = bitrep_scalars_vartime(upper)?;
        let bitrep = compute_bitrep(x, &rep_scalars);
        assert_eq!(bitrep.len(), rep_scalars.len());

        // Recombine: add up exactly the scalars whose bit is chosen.
        let x_out = rep_scalars
            .iter()
            .zip(&bitrep)
            .fold(Scalar::ZERO, |acc, (rep, &bit)| {
                acc + Scalar::conditional_select(&Scalar::ZERO, rep, bit)
            });

        if (x == x_out) != expected {
            return Err(Error::VerificationFailure);
        }

        Ok(())
    }

    #[test]
    fn bitrep_test() {
        // `upper` below 2 is rejected outright, whatever `expected` is.
        bitrep_tester(Scalar::from(0u32), Scalar::from(0u32), false).unwrap_err();
        bitrep_tester(Scalar::from(1u32), Scalar::from(0u32), true).unwrap_err();

        // (upper, x, whether x is expected to lie in 0..upper)
        for (upper, x, expected) in [
            (2u32, 0u32, true),
            (2, 1, true),
            (2, 2, false),
            (3, 1, true),
            (100, 0, true),
            (100, 99, true),
            (100, 100, false),
            (127, 126, true),
            (128, 127, true),
            (128, 128, false),
            (129, 128, true),
            (129, 0, true),
            (129, 129, false),
        ] {
            bitrep_tester(Scalar::from(upper), Scalar::from(x), expected).unwrap_or_else(|_| {
                panic!(
                    "bitrep mismatch for upper={}, x={}, expected={}",
                    upper, x, expected
                )
            });
        }
    }
}