1use crate::integer_arith::{ArithOperators, ArithUtils, SuperTrait};
6use modinverse::modinverse;
7use rand::rngs::{StdRng,ThreadRng};
8use rand::{FromEntropy};
9use super::Rng;
10use ::std::ops;
11pub use std::sync::Arc;
12
// Marker impls: both standard RNGs satisfy this crate's `Rng` trait bound
// (re-exported via `super::Rng`), so they can be passed as `&mut dyn Rng`.
impl Rng for StdRng {}
impl Rng for ThreadRng {}
15
/// Precomputed data attached to a modulus `q`: the Barrett ratio
/// `floor(2^128 / q)`, stored as `(low 64 bits, high 64 bits)`.
#[derive(Debug, PartialEq, Eq, Clone)]
struct ScalarContext {
    barrett_ratio: (u64, u64),
}

impl ScalarContext {
    /// Builds the context for modulus `q`, precomputing its Barrett ratio.
    fn new(q: u64) -> Self {
        ScalarContext {
            barrett_ratio: Self::compute_barrett_ratio(q),
        }
    }

    /// Computes `floor(2^128 / q)`, returned as `(low word, high word)`.
    ///
    /// `2^128` itself does not fit in a `u128`, so the quotient is derived
    /// from `floor(2^127 / q)`: double it, then add one more when the
    /// doubled remainder reaches `q`.
    fn compute_barrett_ratio(q: u64) -> (u64, u64) {
        let half = 1u128 << 127;
        let rem = half % (q as u128);
        let quot = half / (q as u128);

        let mut ratio = quot << 1;
        if (rem << 1) >= (q as u128) {
            ratio += 1;
        }
        (ratio as u64, (ratio >> 64) as u64)
    }
}
45
/// A `u64` scalar, optionally carrying precomputed modulus data.
///
/// Plain values (built via `new` / `From`) have `context: None` and
/// `bit_count: 0`; moduli built via `new_modulus` carry a Barrett-reduction
/// context and the bit length of the modulus.
#[derive(Debug, Clone)]
pub struct Scalar {
    // Barrett precomputation; `Some` only for moduli built via `new_modulus`.
    context: Option<ScalarContext>,
    // Raw numeric value.
    rep: u64,
    // Bit length of `rep` when used as a modulus; 0 for plain values.
    bit_count: usize,
}
53
54impl Scalar {
55 pub fn new(a: u64) -> Self {
57 Scalar {
58 rep: a,
59 context: None,
60 bit_count: 0,
61 }
62 }
63
64 pub fn rep(&self) -> u64{
65 self.rep
66 }
67}
68
// Marker impl: `Scalar` satisfies the crate's blanket `SuperTrait` bound.
impl SuperTrait<Scalar> for Scalar {}

impl PartialEq for Scalar {
    // Equality compares only the numeric value; the modulus `context` and
    // `bit_count` are deliberately ignored (so e.g. `Scalar::new(q)` equals
    // `Scalar::new_modulus(q)`).
    fn eq(&self, other: &Self) -> bool {
        self.rep == other.rep
    }
}
77
78impl From<u32> for Scalar {
80 fn from(item: u32) -> Self {
81 Scalar { context: None, rep: item as u64, bit_count: 0 }
82 }
83}
84
85impl From<u64> for Scalar {
86 fn from(item: u64) -> Self {
87 Scalar { context: None, rep: item, bit_count: 0 }
88 }
89}
90
91impl From<Scalar> for u64{
92 fn from(item: Scalar) -> u64 {
93 item.rep
94 }
95}
96
97impl ops::Add<&Scalar> for Scalar {
99 type Output = Scalar;
100 fn add(self, v: &Scalar) -> Scalar {
101 Scalar::new(self.rep + v.rep)
102 }
103}
104
105impl ops::Add<Scalar> for Scalar {
106 type Output = Scalar;
107 fn add(self, v: Scalar) -> Scalar {
108 self + &v
109 }
110}
111
112impl ops::Sub<&Scalar> for Scalar {
113 type Output = Scalar;
114 fn sub(self, v: &Scalar) -> Scalar {
115 Scalar::new(self.rep - v.rep)
116 }
117}
118
119impl ops::Sub<Scalar> for Scalar {
120 type Output = Scalar;
121 fn sub(self, v: Scalar) -> Scalar {
122 self - &v
123 }
124}
125
126impl ops::Mul<u64> for Scalar {
127 type Output = Scalar;
128 fn mul(self, v: u64) -> Scalar {
129 Scalar::new(self.rep * v)
130 }
131}
132
impl ArithOperators for Scalar{
    // In-place raw addition (no modular reduction; overflow panics in debug).
    fn add_u64(&mut self, a: u64){
        self.rep += a;
    }

    // In-place raw subtraction (no modular reduction; underflow panics in debug).
    fn sub_u64(&mut self, a: u64){
        self.rep -= a;
    }

    // Raw u64 representation; same value as the inherent `rep` accessor.
    fn rep(&self) -> u64{
        self.rep
    }
}
146
147impl ArithUtils<Scalar> for Scalar {
149 fn new_modulus(q: u64) -> Scalar {
150 Scalar {
151 rep: q,
152 context: Some(ScalarContext::new(q)),
153 bit_count: 64 - q.leading_zeros() as usize,
154 }
155 }
156
157 fn sub(a: &Scalar, b: &Scalar) -> Scalar {
158 Scalar::new(a.rep - b.rep)
159 }
160
161 fn div(a: &Scalar, b: &Scalar) -> Scalar {
162 Scalar::new(a.rep / b.rep)
163 }
164
165 fn add_mod(a: &Scalar, b: &Scalar, q: &Scalar) -> Scalar {
166 let mut sum = a.rep + b.rep;
167 if sum >= q.rep {
168 sum -= q.rep;
169 }
170 Scalar::new(sum)
171 }
172
173 fn sub_mod(a: &Scalar, b: &Scalar, q: &Scalar) -> Scalar {
174 Scalar::_sub_mod(a, b, q.rep)
175 }
176
177 fn mul_mod(a: &Scalar, b: &Scalar, q: &Scalar) -> Scalar {
178 let res = Scalar::_barret_multiply(a, b, q.context.as_ref().unwrap().barrett_ratio, q.rep);
179 Scalar::new(res)
180 }
181
182 fn inv_mod(a: &Scalar, q: &Scalar) -> Scalar {
183 Scalar::_inv_mod(a, q.rep)
184 }
185
186 fn from_u32(a: u32, q: &Scalar) -> Scalar {
187 Scalar::new((a as u64) % q.rep)
188 }
189
190 fn from_u32_raw(a: u32) -> Scalar {
191 Scalar::new(a as u64)
192 }
193
194 fn from_u64_raw(a: u64) -> Scalar {
195 Scalar::new(a)
196 }
197
198 fn pow_mod(base: &Scalar, b: &Scalar, q: &Scalar) -> Scalar {
199 let bits: Vec<bool> = b.get_bits();
200 let mut res = Self::one();
201 res = Self::modulus(&res, q);
202 let mut pow = Scalar::new(base.rep);
203 for bit in bits.iter() {
204 if *bit {
205 res = Self::mul_mod(&res, &pow, q);
206 }
207 pow = Self::mul_mod(&pow, &pow, q);
208 }
209 res
210 }
211
212 fn double(a: &Scalar) -> Scalar {
213 Scalar::new(a.rep << 1)
214 }
215
216 fn sample_blw(upper_bound: &Scalar) -> Scalar {
217 loop {
218 let n = Self::_sample(upper_bound.bit_count);
219 if n < upper_bound.rep {
220 return Scalar::new(n);
221 }
222 }
223 }
224
225 fn sample_below_from_rng(upper_bound: &Scalar, rng: &mut dyn Rng) -> Self {
227 upper_bound.sample(rng)
228 }
229
230 fn modulus(a: &Scalar, q: &Scalar) -> Scalar {
231 match &q.context{
232 Some(context) => {Scalar::from(Scalar::_barret_reduce((a.rep(), 0), context.barrett_ratio, q.rep()))}
233 None => Scalar::new(a.rep % q.rep)
234 }
235 }
236
237 fn mul(a: &Scalar, b: &Scalar) -> Scalar {
238 Scalar::new(a.rep * b.rep)
239 }
240
241 fn to_u64(a: &Scalar) -> u64 {
242 a.rep
243 }
244
245 fn add(a: &Scalar, b: &Scalar) -> Scalar {
246 Scalar::new(a.rep + b.rep)
247 }
248}
249
impl Scalar {
    // Number of significant bits in `rep` (0 when rep == 0).
    fn bit_length(&self) -> usize {
        64 - self.rep.leading_zeros() as usize
    }

    // Bits of `rep`, least-significant first, up to the top set bit.
    // Returns an empty Vec when rep == 0.
    fn get_bits(&self) -> Vec<bool> {
        let len = self.bit_length();
        let mut res = vec![];
        let mut mask = 1u64;
        for _ in 0..len {
            res.push((self.rep & mask) != 0);
            mask <<= 1;
        }
        res
    }

    // Uniform sample in [0, self.rep) by rejection: draw full u64s and keep
    // only those below the largest multiple of `self.rep` that fits, so the
    // final reduction is unbiased. Divides by `self.rep()`, so this panics
    // when the bound is zero.
    fn sample(&self, rng: &mut dyn Rng) -> Scalar {
        let max_multiple = self.rep() * (u64::MAX / self.rep() );
        loop{
            let a = rng.next_u64();
            if a < max_multiple {
                return Scalar::modulus(&Scalar::from(a), self);
            }
        }
    }

    // Draws `bit_size` uniform random bits from `rng`: fills just enough
    // bytes, assembles them big-endian, then right-shifts away the excess
    // bits. The `bit_size - 1` underflows for bit_size == 0, so callers must
    // pass 1..=64.
    fn _sample_from_rng(bit_size: usize, rng: &mut dyn Rng) -> u64 {
        let bytes = (bit_size - 1) / 8 + 1;
        let mut buf: Vec<u8> = vec![0; bytes];
        rng.fill_bytes(&mut buf);

        let mut a = 0u64;
        for x in buf.iter() {
            a <<= 8;
            a += *x as u64;
        }
        a >>= bytes * 8 - bit_size;
        a
    }

    // Same as `_sample_from_rng`, but with a fresh entropy-seeded StdRng.
    fn _sample(bit_size: usize) -> u64 {
        let mut rng = StdRng::from_entropy();
        Self::_sample_from_rng(bit_size, &mut rng)
    }

    // (a - b) mod q; assumes both inputs are already reduced mod q, so a
    // single conditional add of q restores the range.
    fn _sub_mod(a: &Scalar, b: &Scalar, q: u64) -> Self {
        let diff;
        if a.rep >= b.rep {
            diff = a.rep - b.rep;
        } else {
            diff = a.rep + q - b.rep;
        }
        Scalar::new(diff)
    }

    // Reference modular multiply: full 128-bit product followed by `%`.
    // Slower than Barrett reduction but needs no precomputed context.
    fn _slowmul_mod(a: &Scalar, b: &Scalar, q: u64) -> Self {
        let res = (a.rep as u128) * (b.rep as u128);
        Scalar::new((res % (q as u128)) as u64)
    }

    // Full 64x64 -> 128-bit product, returned as (low word, high word).
    fn _multiply_u64(a: u64, b: u64) -> (u64, u64) {
        let res = (a as u128) * (b as u128);
        (res as u64, (res >> 64) as u64)
    }

    // 64-bit addition returning (wrapped sum, carry-out flag).
    fn _add_u64(a: u64, b: u64) -> (u64, bool) {
        let res = (a as u128 + b as u128) as u64;
        (res, res < a)
    }

    // Barrett reduction of a 128-bit value `a` = (low, high) modulo `q`,
    // where `ratio` = floor(2^128 / q) as (low, high) words.
    //
    // It accumulates in `w` the low 64 bits of floor(a * ratio / 2^128) —
    // an estimate of floor(a / q) — from the partial products of the
    // 128x128 multiply: a1*r1 contributes directly (position 2^128), the
    // high words of a0*r1 and a1*r0 sit at 2^128, and the carries out of the
    // 2^64-position column (a0r0.hi + a0r1.lo + a1r0.lo) propagate up.
    fn _barret_reduce(a: (u64, u64), ratio: (u64, u64), q: u64) -> u64 {
        let mut w = 0;
        if a.1 != 0{
            // Only the low 64 bits of the quotient estimate are needed.
            w = a.1.wrapping_mul(ratio.1);
        }
        let a0r0 = Scalar::_multiply_u64(a.0, ratio.0);

        let a0r1 = Scalar::_multiply_u64(a.0, ratio.1);

        w += a0r1.1;

        // Carry out of the 2^64 column (a0r0.hi + a0r1.lo).
        let (tmp, carry) = Scalar::_add_u64(a0r0.1, a0r1.0);
        w += carry as u64;

        if a.1 != 0{
            let a1r0 = Scalar::_multiply_u64(a.1, ratio.0);
            w += a1r0.1;
            // Second carry out of the 2^64 column (a1r0.lo + previous sum).
            let (_, carry2) = Scalar::_add_u64(a1r0.0, tmp);
            w += carry2 as u64;
        }

        // res = a - w*q; only the low 64 bits matter.
        let low = w.wrapping_mul(q);

        let mut res;
        if a.0 >= low {
            res = a.0 - low;
        } else {
            // Two's-complement form of a.0 - low; written with `+ !low + 1`
            // so the intermediate never overflows u64 when a.0 < low.
            res = a.0 + (!low) + 1;
        }

        // Final correction: the quotient estimate may be short by one.
        // NOTE(review): this assumes the estimate is off by at most one q —
        // consistent with the test vectors below, but not proven here.
        if res >= q {
            res -= q;
        }
        res
    }

    // Modular inverse via the `modinverse` crate (extended Euclid at i128
    // width); the unwrap panics when gcd(a, q) != 1.
    fn _inv_mod(a: &Scalar, q: u64) -> Self {
        Scalar::new(modinverse(a.rep as i128, q as i128).unwrap() as u64)
    }

    // (a * b) mod q: widening multiply, then Barrett reduction.
    fn _barret_multiply(a: &Scalar, b: &Scalar, ratio: (u64, u64), q: u64) -> u64 {
        let prod = Scalar::_multiply_u64(a.rep, b.rep);
        Scalar::_barret_reduce(prod, ratio, q)
    }
}
378
#[cfg(test)]
mod tests {
    use super::*;
    // 18014398492704769 is a 54-bit modulus used as the shared test vector.
    #[test]
    fn test_bitlength() {
        assert_eq!(Scalar::from(2u32).bit_length(), 2);
        assert_eq!(Scalar::from(16u32).bit_length(), 5);
        assert_eq!(Scalar::from_u64_raw(18014398492704769u64).bit_length(), 54);
    }

    // get_bits is least-significant-bit first.
    #[test]
    fn test_getbits() {
        assert_eq!(Scalar::from(1u32).get_bits(), vec![true]);
        assert_eq!(Scalar::from(2u32).get_bits(), vec![false, true]);
        assert_eq!(Scalar::from(5u32).get_bits(), vec![true, false, true]);
        assert_eq!(
            Scalar::from_u64_raw(127).get_bits(),
            vec![true, true, true, true, true, true, true]
        );
    }

    // _sample must never exceed the requested bit width.
    #[test]
    fn test_sample_bitsize() {
        let bit_size = 54;
        let bound = 1u64 << bit_size;
        for _ in 0..10 {
            let a = Scalar::_sample(bit_size);
            assert!(a < bound);
        }
    }

    #[test]
    fn test_sample_below() {
        let q: u64 = 18014398492704769;
        let q_scalar = Scalar::new_modulus(q);
        for _ in 0..10 {
            assert!(Scalar::sample_blw(&q_scalar).rep < q);
        }
    }

    #[test]
    fn test_sample_below_prng() {
        use rand::{thread_rng};
        let q: u64 = 18014398492704769;
        let q_scalar = Scalar::new_modulus(q);
        let mut rng = thread_rng();
        for _ in 0..10 {
            assert!(Scalar::sample_below_from_rng(&q_scalar, &mut rng).rep < q);
        }
    }
    #[test]
    fn test_equality() {
        assert_eq!(Scalar::zero(), Scalar::zero());
    }

    // 0 - 1 mod 12289 wraps to 12288.
    #[test]
    fn test_subtraction() {
        let a = Scalar::zero();
        let b = Scalar::one();
        let c = Scalar::_sub_mod(&a, &b, 12289);
        assert_eq!(c.rep, 12288);
    }

    // 2 * 6 = 12 = 1 (mod 11), so inv(2) = 6.
    #[test]
    fn test_inverse() {
        let q = Scalar::new(11);
        let c = Scalar::new(2);
        let a = Scalar::inv_mod(&c, &q);
        assert_eq!(a.rep, 6);
    }

    // 4 * 4 = 16 = 5 (mod 11), via the slow reference multiply.
    #[test]
    fn test_mul_mod() {
        let q = 11u64;
        let c = Scalar::new(4);
        let a = Scalar::_slowmul_mod(&c, &c, q);
        assert_eq!(a.rep, 5);
    }

    // 4^4 = 256 = 3 (mod 11).
    #[test]
    fn test_pow_mod() {
        let q = Scalar::new_modulus(11);
        let c = Scalar::new(4);
        let a = Scalar::pow_mod(&c, &c, &q);
        assert_eq!(a.rep, 3);
    }

    // Repeated squaring must keep the result reduced below q.
    #[test]
    fn test_pow_mod_large() {
        let q = Scalar::new_modulus(12289);
        let two = Scalar::new(2);
        let mut a: Scalar = Scalar::from_u64_raw(3);
        a = Scalar::modulus(&a, &q);

        for _ in 0..10 {
            a = Scalar::pow_mod(&a, &two, &q);
            assert!(a.rep < q.rep);
        }
    }

    // Known value of floor(2^128 / q) for the 54-bit modulus.
    #[test]
    fn test_barret_ratio() {
        let q = 18014398492704769u64;
        assert_eq!(
            ScalarContext::compute_barrett_ratio(q),
            (17592185012223u64, 1024u64)
        );
    }

    // Edge inputs for the reduction: 1, q itself (reduces to 0), and 2^64.
    #[test]
    fn test_barret_reduction() {
        let q = 18014398492704769;
        let ratio = (17592185012223u64, 1024u64);

        let a: (u64, u64) = (1, 0);
        let b = Scalar::_barret_reduce(a, ratio, q);
        assert_eq!(b, 1);

        let a: (u64, u64) = (q, 0);
        let b = Scalar::_barret_reduce(a, ratio, q);
        assert_eq!(b, 0);

        let a: (u64, u64) = (0, 1);
        let b = Scalar::_barret_reduce(a, ratio, q);
        assert_eq!(b, 17179868160);
    }

    // (q-2)(q-3) = (-2)(-3) = 6 (mod q).
    #[test]
    fn test_barret_multiply() {
        let q: u64 = 18014398492704769;
        let ratio = (17592185012223u64, 1024u64);

        let a = Scalar::new(q - 2);
        let b = Scalar::new(q - 3);
        let c = Scalar::_barret_multiply(&a, &b, ratio, q);

        assert_eq!(c, 6);
    }

    #[test]
    fn test_operator_add(){
        let a = Scalar::new(123);
        let b = Scalar::new(123);
        let c = a + &b;
        assert_eq!(u64::from(c), 246u64);
    }
}