cupcake/integer_arith/scalar.rs

// Copyright (c) Facebook, Inc. and its affiliates.
//
// This source code is licensed under the MIT license found in the
// LICENSE file in the root directory of this source tree.
5use crate::integer_arith::{ArithOperators, ArithUtils, SuperTrait};
6use modinverse::modinverse;
7use rand::rngs::{StdRng,ThreadRng};
8use rand::{FromEntropy}; 
9use super::Rng; 
10use ::std::ops;
11pub use std::sync::Arc;
12
13impl Rng for StdRng {}
14impl Rng for ThreadRng {}
15
/// Auxiliary data that speeds up modular reduction against one fixed modulus `q`.
///
/// Built by `Scalar::new_modulus`; currently it holds only the Barrett ratio.
#[derive(Clone, Debug, Eq, PartialEq)]
struct ScalarContext {
    /// `floor(2^128 / q)` split into `(low 64 bits, high 64 bits)`.
    barrett_ratio: (u64, u64),
}
21
22impl ScalarContext {
23    fn new(q: u64) -> Self {
24        let ratio = Self::compute_barrett_ratio(q);
25        ScalarContext {
26            barrett_ratio: ratio,
27        }
28    }
29
30    /// Compute floor(2^128/q) and put it in 2 u64s as (low-word, high-word)
31    fn compute_barrett_ratio(q: u64) -> (u64, u64) {
32        // 2^127 = s*q + t.
33        let a = 1u128 << 127;
34        let mut t = a % (q as u128);
35        let mut s = (a - t) / (q as u128);
36
37        s <<= 1;
38        t <<= 1;
39        if t >= (q as u128) {
40            s += 1;
41        }
42        (s as u64, (s >> 64) as u64)
43    }
44}
45
46/// The Scalar struct is a wrapper around u64 which has optional fast modular arithmetic through ScalarContext.
47#[derive(Debug, Clone)]
48pub struct Scalar {
49    context: Option<ScalarContext>,
50    rep: u64,
51    bit_count: usize,
52}
53
54impl Scalar {
55    /// Construct a new scalar from u64.
56    pub fn new(a: u64) -> Self {
57        Scalar {
58            rep: a,
59            context: None,
60            bit_count: 0,
61        }
62    }
63
64    pub fn rep(&self) -> u64{
65        self.rep
66    }
67}
68
69/// Trait implementations
70impl SuperTrait<Scalar> for Scalar {}
71
72impl PartialEq for Scalar {
73    fn eq(&self, other: &Self) -> bool {
74        self.rep == other.rep
75    }
76}
77
78// Conversions
79impl From<u32> for Scalar {
80    fn from(item: u32) -> Self {
81        Scalar {  context: None, rep: item as u64, bit_count: 0 }
82    }
83}
84
85impl From<u64> for Scalar {
86    fn from(item: u64) -> Self {
87        Scalar {  context: None, rep: item, bit_count: 0 }
88    }
89}
90
91impl From<Scalar> for u64{
92    fn from(item: Scalar) -> u64 {
93        item.rep
94    }
95}
96
97// Operators
98impl ops::Add<&Scalar> for Scalar {
99    type Output = Scalar;
100    fn add(self, v: &Scalar) -> Scalar {
101        Scalar::new(self.rep + v.rep)
102    }
103}
104
105impl ops::Add<Scalar> for Scalar {
106    type Output = Scalar;
107    fn add(self, v: Scalar) -> Scalar {
108        self + &v
109    }
110}
111
112impl ops::Sub<&Scalar> for Scalar {
113    type Output = Scalar;
114    fn sub(self, v: &Scalar) -> Scalar {
115         Scalar::new(self.rep - v.rep)
116    }
117}
118
119impl ops::Sub<Scalar> for Scalar {
120    type Output = Scalar;
121    fn sub(self, v: Scalar) -> Scalar {
122        self - &v
123    }
124}
125
126impl ops::Mul<u64> for Scalar {
127    type Output = Scalar;
128    fn mul(self, v: u64) -> Scalar {
129        Scalar::new(self.rep * v)
130    }
131}
132
133impl ArithOperators for Scalar{
134    fn add_u64(&mut self, a: u64){
135        self.rep += a;
136    }
137
138    fn sub_u64(&mut self, a: u64){
139        self.rep -= a;
140    }
141
142    fn rep(&self) -> u64{
143        self.rep
144    }
145}
146
147// Trait implementation
148impl ArithUtils<Scalar> for Scalar {
149    fn new_modulus(q: u64) -> Scalar {
150        Scalar {
151            rep: q,
152            context: Some(ScalarContext::new(q)),
153            bit_count: 64 - q.leading_zeros() as usize,
154        }
155    }
156
157    fn sub(a: &Scalar, b: &Scalar) -> Scalar {
158        Scalar::new(a.rep - b.rep)
159    }
160
161    fn div(a: &Scalar, b: &Scalar) -> Scalar {
162        Scalar::new(a.rep / b.rep)
163    }
164
165    fn add_mod(a: &Scalar, b: &Scalar, q: &Scalar) -> Scalar {
166        let mut sum = a.rep + b.rep;
167        if sum >= q.rep {
168            sum -= q.rep;
169        }
170        Scalar::new(sum)
171    }
172
173    fn sub_mod(a: &Scalar, b: &Scalar, q: &Scalar) -> Scalar {
174        Scalar::_sub_mod(a, b, q.rep)
175    }
176
177    fn mul_mod(a: &Scalar, b: &Scalar, q: &Scalar) -> Scalar {
178        let res = Scalar::_barret_multiply(a, b, q.context.as_ref().unwrap().barrett_ratio, q.rep);
179        Scalar::new(res)
180    }
181
182    fn inv_mod(a: &Scalar, q: &Scalar) -> Scalar {
183        Scalar::_inv_mod(a, q.rep)
184    }
185
186    fn from_u32(a: u32, q: &Scalar) -> Scalar {
187        Scalar::new((a as u64) % q.rep)
188    }
189
190    fn from_u32_raw(a: u32) -> Scalar {
191        Scalar::new(a as u64)
192    }
193
194    fn from_u64_raw(a: u64) -> Scalar {
195        Scalar::new(a)
196    }
197
198    fn pow_mod(base: &Scalar, b: &Scalar, q: &Scalar) -> Scalar {
199        let bits: Vec<bool> = b.get_bits();
200        let mut res = Self::one();
201        res = Self::modulus(&res, q);
202        let mut pow = Scalar::new(base.rep);
203        for bit in bits.iter() {
204            if *bit {
205                res = Self::mul_mod(&res, &pow, q);
206            }
207            pow = Self::mul_mod(&pow, &pow, q);
208        }
209        res
210    }
211
212    fn double(a: &Scalar) -> Scalar {
213        Scalar::new(a.rep << 1)
214    }
215
216    fn sample_blw(upper_bound: &Scalar) -> Scalar {
217        loop {
218            let n = Self::_sample(upper_bound.bit_count);
219            if n < upper_bound.rep {
220                return Scalar::new(n);
221            }
222        }
223    }
224
225    // sample below using a given rng.
226    fn sample_below_from_rng(upper_bound: &Scalar, rng: &mut dyn Rng) -> Self {
227        upper_bound.sample(rng)
228    }
229
230    fn modulus(a: &Scalar, q: &Scalar) -> Scalar {
231        match &q.context{
232            Some(context) => {Scalar::from(Scalar::_barret_reduce((a.rep(), 0), context.barrett_ratio, q.rep()))}
233            None => Scalar::new(a.rep % q.rep)
234        }
235    }
236
237    fn mul(a: &Scalar, b: &Scalar) -> Scalar {
238        Scalar::new(a.rep * b.rep)
239    }
240
241    fn to_u64(a: &Scalar) -> u64 {
242        a.rep
243    }
244
245    fn add(a: &Scalar, b: &Scalar) -> Scalar {
246        Scalar::new(a.rep + b.rep)
247    }
248}
249
250impl Scalar {
251    /// Bit length of this scalar.
252    fn bit_length(&self) -> usize {
253        64 - self.rep.leading_zeros() as usize
254    }
255
256    /// Return a vector of booleans representing the bits of this scalar, starting from the least significant bit.
257    fn get_bits(&self) -> Vec<bool> {
258        let len = self.bit_length();
259        let mut res = vec![];
260        let mut mask = 1u64;
261        for _ in 0..len {
262            res.push((self.rep & mask) != 0);
263            mask <<= 1;
264        }
265        res
266    }
267
268    fn sample(&self, rng: &mut dyn Rng) -> Scalar {
269        let max_multiple = self.rep() * (u64::MAX / self.rep() ); 
270        loop{
271            let a = rng.next_u64(); 
272            if a < max_multiple {
273                return Scalar::modulus(&Scalar::from(a), self);
274            }
275        }
276    }
277
278    fn _sample_from_rng(bit_size: usize, rng: &mut dyn Rng) -> u64 {
279        let bytes = (bit_size - 1) / 8 + 1;
280        let mut buf: Vec<u8> = vec![0; bytes];
281        rng.fill_bytes(&mut buf);
282
283        // from vector to u64.
284        let mut a = 0u64;
285        for x in buf.iter() {
286            a <<= 8;
287            a += *x as u64;
288        }
289        a >>= bytes * 8 - bit_size;
290        a
291    }
292
293    fn _sample(bit_size: usize) -> u64 {
294        let mut rng = StdRng::from_entropy();
295        Self::_sample_from_rng(bit_size, &mut rng)
296    }
297
298    fn _sub_mod(a: &Scalar, b: &Scalar, q: u64) -> Self {
299        let diff;
300        if a.rep >= b.rep {
301            diff = a.rep - b.rep;
302        } else {
303            diff = a.rep + q - b.rep;
304        }
305        Scalar::new(diff)
306    }
307
308    fn _slowmul_mod(a: &Scalar, b: &Scalar, q: u64) -> Self {
309        let res = (a.rep as u128) * (b.rep as u128);
310        Scalar::new((res % (q as u128)) as u64)
311    }
312
313    fn _multiply_u64(a: u64, b: u64) -> (u64, u64) {
314        let res = (a as u128) * (b as u128);
315        (res as u64, (res >> 64) as u64)
316    }
317
318    fn _add_u64(a: u64, b: u64) -> (u64, bool) {
319        let res = (a as u128 + b as u128) as u64;
320        (res, res < a)
321    }
322
323    fn _barret_reduce(a: (u64, u64), ratio: (u64, u64), q: u64) -> u64 {
324        // compute w = a*ratio >> 128.
325
326        // start with lw(a1r1)
327        let mut w = 0; 
328        if a.1 != 0{
329            w = a.1.wrapping_mul(ratio.1);
330        }
331        let a0r0 = Scalar::_multiply_u64(a.0, ratio.0);
332
333        let a0r1 = Scalar::_multiply_u64(a.0, ratio.1);
334
335        // w += hw(a0r1)
336        w += a0r1.1;
337
338        // compute hw(a0r0) + lw(a0r1), add carry into w. put result into tmp.
339        let (tmp, carry) = Scalar::_add_u64(a0r0.1, a0r1.0);
340        w += carry as u64;
341
342        // Round2
343        if a.1 != 0{
344            let a1r0 = Scalar::_multiply_u64(a.1, ratio.0);
345            w += a1r0.1;
346            // final carry
347            let (_, carry2) = Scalar::_add_u64(a1r0.0, tmp);
348            w += carry2 as u64;
349        }
350
351        // low = w*q mod 2^64.
352        // let low = Scalar::multiply_u64(w, q).0;
353        let low = w.wrapping_mul(q);
354
355        let mut res;
356        if a.0 >= low {
357            res = a.0 - low;
358        } else {
359            // res = a.0 + 2^64 - low.
360            res = a.0 + (!low) + 1;
361        }
362
363        if res >= q {
364            res -= q;
365        }
366        res
367    }
368
369    fn _inv_mod(a: &Scalar, q: u64) -> Self {
370        Scalar::new(modinverse(a.rep as i128, q as i128).unwrap() as u64)
371    }
372
373    fn _barret_multiply(a: &Scalar, b: &Scalar, ratio: (u64, u64), q: u64) -> u64 {
374        let prod = Scalar::_multiply_u64(a.rep, b.rep);
375        Scalar::_barret_reduce(prod, ratio, q)
376    }
377}
378
#[cfg(test)]
mod tests {
    use super::*;

    /// 54-bit modulus shared by several tests.
    const Q54: u64 = 18014398492704769;
    /// `floor(2^128 / Q54)` split as `(low word, high word)`.
    const Q54_RATIO: (u64, u64) = (17592185012223, 1024);

    #[test]
    fn test_bitlength() {
        let cases = [
            (Scalar::from(2u32), 2),
            (Scalar::from(16u32), 5),
            (Scalar::from_u64_raw(Q54), 54),
        ];
        for (value, expected_bits) in cases.iter() {
            assert_eq!(value.bit_length(), *expected_bits);
        }
    }

    #[test]
    fn test_getbits() {
        let cases = [
            (1u64, vec![true]),
            (2, vec![false, true]),
            (5, vec![true, false, true]),
            (127, vec![true; 7]),
        ];
        for (value, expected) in cases.iter() {
            assert_eq!(Scalar::from_u64_raw(*value).get_bits(), *expected);
        }
    }

    #[test]
    fn test_sample_bitsize() {
        let bit_size = 54;
        let bound = 1u64 << bit_size;
        for _ in 0..10 {
            let sample = Scalar::_sample(bit_size);
            assert!(sample < bound, "sample {} exceeds {} bits", sample, bit_size);
        }
    }

    #[test]
    fn test_sample_below() {
        let modulus = Scalar::new_modulus(Q54);
        for _ in 0..10 {
            let drawn = Scalar::sample_blw(&modulus);
            assert!(drawn.rep < Q54);
        }
    }

    #[test]
    fn test_sample_below_prng() {
        use rand::thread_rng;
        let modulus = Scalar::new_modulus(Q54);
        let mut rng = thread_rng();
        for _ in 0..10 {
            let drawn = Scalar::sample_below_from_rng(&modulus, &mut rng);
            assert!(drawn.rep < Q54);
        }
    }

    #[test]
    fn test_equality() {
        let lhs = Scalar::zero();
        let rhs = Scalar::zero();
        assert_eq!(lhs, rhs);
    }

    #[test]
    fn test_subtraction() {
        let wrapped = Scalar::_sub_mod(&Scalar::zero(), &Scalar::one(), 12289);
        assert_eq!(wrapped.rep, 12288);
    }

    #[test]
    fn test_inverse() {
        let inverse = Scalar::inv_mod(&Scalar::new(2), &Scalar::new(11));
        assert_eq!(inverse.rep, 6);
    }

    #[test]
    fn test_mul_mod() {
        let four = Scalar::new(4);
        assert_eq!(Scalar::_slowmul_mod(&four, &four, 11).rep, 5);
    }

    #[test]
    fn test_pow_mod() {
        let modulus = Scalar::new_modulus(11);
        let four = Scalar::new(4);
        assert_eq!(Scalar::pow_mod(&four, &four, &modulus).rep, 3);
    }

    #[test]
    fn test_pow_mod_large() {
        let modulus = Scalar::new_modulus(12289);
        let two = Scalar::new(2);
        let mut acc = Scalar::modulus(&Scalar::from_u64_raw(3), &modulus);
        for _ in 0..10 {
            acc = Scalar::pow_mod(&acc, &two, &modulus);
            assert!(acc.rep < modulus.rep);
        }
    }

    #[test]
    fn test_barret_ratio() {
        assert_eq!(ScalarContext::compute_barrett_ratio(Q54), Q54_RATIO);
    }

    #[test]
    fn test_barret_reduction() {
        let cases = [((1u64, 0u64), 1u64), ((Q54, 0), 0), ((0, 1), 17179868160)];
        for &(input, expected) in cases.iter() {
            assert_eq!(Scalar::_barret_reduce(input, Q54_RATIO, Q54), expected);
        }
    }

    #[test]
    fn test_barret_multiply() {
        let lhs = Scalar::new(Q54 - 2);
        let rhs = Scalar::new(Q54 - 3);
        assert_eq!(Scalar::_barret_multiply(&lhs, &rhs, Q54_RATIO, Q54), 6);
    }

    #[test]
    fn test_operator_add() {
        let sum = Scalar::new(123) + &Scalar::new(123);
        assert_eq!(u64::from(sum), 246u64);
    }
}