// noah_algebra/traits.rs

1use crate::{
2    errors::AlgebraError,
3    prelude::*,
4    rand::{CryptoRng, RngCore},
5};
6use ark_ff::FftField;
7use ark_std::fmt::Debug;
8use digest::{generic_array::typenum::U64, Digest};
9use num_bigint::BigUint;
10use serde::{Deserialize, Serialize};
11
12/// The trait for scalars
13pub trait Scalar:
14    Copy
15    + Default
16    + Debug
17    + PartialEq
18    + Eq
19    + Serialize
20    + for<'de> Deserialize<'de>
21    + Into<BigUint>
22    + for<'a> From<&'a BigUint>
23    + Clone
24    + One
25    + Zero
26    + Sized
27    + Add<Self, Output = Self>
28    + Mul<Self, Output = Self>
29    + Sum<Self>
30    + for<'a> Add<&'a Self, Output = Self>
31    + for<'a> AddAssign<&'a Self>
32    + for<'a> Mul<&'a Self, Output = Self>
33    + for<'a> MulAssign<&'a Self>
34    + for<'a> Sub<&'a Self, Output = Self>
35    + for<'a> SubAssign<&'a Self>
36    + for<'a> Sum<&'a Self>
37    + From<u32>
38    + From<u64>
39    + Neg<Output = Self>
40    + Sync
41    + Send
42{
43    /// Return a random scalar
44    fn random<R: CryptoRng + RngCore>(rng: &mut R) -> Self;
45
46    /// Sample a scalar based on a hash value
47    fn from_hash<D>(hash: D) -> Self
48    where
49        D: Digest<OutputSize = U64> + Default;
50
51    /// Return multiplicative generator of order r,
52    /// which is also required to be a quadratic nonresidue
53    fn multiplicative_generator() -> Self;
54
55    /// Return the capacity.
56    fn capacity() -> usize;
57
58    /// Return the little-endian byte representations of the field size
59    fn get_field_size_le_bytes() -> Vec<u8>;
60
61    /// Return the field size as a BigUint
62    fn get_field_size_biguint() -> BigUint;
63
64    /// Return the little-endian byte representation of `(field_size - 1) / 2`,
65    /// assuming that `field_size` is odd
66    fn field_size_minus_one_half() -> Vec<u8> {
67        let mut q_minus_1_half_le = Self::get_field_size_le_bytes();
68        // divide by 2 by shifting, first bit is one since F is odd prime
69        shift_u8_vec(&mut q_minus_1_half_le);
70        q_minus_1_half_le
71    }
72
73    /// Return a representation of the scalar as a vector of u64 in the little-endian order
74    fn get_little_endian_u64(&self) -> Vec<u64>;
75
76    /// Return the len of the byte representation
77    fn bytes_len() -> usize;
78
79    /// Convert to bytes
80    fn to_bytes(&self) -> Vec<u8>;
81
82    /// Convert from bytes
83    fn from_bytes(bytes: &[u8]) -> Result<Self>;
84
85    /// Return the modular inverse of the scalar if it exists
86    fn inv(&self) -> Result<Self>;
87
88    /// Return the square of the field element
89    fn square(&self) -> Self;
90
91    /// exponent form: least significant limb first, with u64 limbs
92    fn pow(&self, exponent: &[u64]) -> Self {
93        let mut base = self.clone();
94        let mut result = Self::one();
95        for exp_u64 in exponent {
96            let mut e = *exp_u64;
97            // we have to square the base for 64 times.
98            for _ in 0..64 {
99                if e % 2 == 1 {
100                    result.mul_assign(&base);
101                }
102                base = base.mul(&base);
103                e >>= 1;
104            }
105        }
106        result
107    }
108}
109
110/// The trait for domain.
111pub trait Domain: Scalar {
112    /// The field that is able to be used in FFTs.
113    type Field: FftField;
114
115    /// Return fft field.
116    fn get_field(&self) -> Self::Field;
117
118    /// Sample a domain based on a fft field.
119    fn from_field(field: Self::Field) -> Self;
120}
121
122/// The trait for group elements
123pub trait Group:
124    Debug
125    + Default
126    + Copy
127    + Sized
128    + PartialEq
129    + Eq
130    + Clone
131    + for<'a> Add<&'a Self, Output = Self>
132    + for<'a> Mul<&'a Self::ScalarType, Output = Self>
133    + for<'a> Sub<&'a Self, Output = Self>
134    + for<'a> AddAssign<&'a Self>
135    + for<'a> SubAssign<&'a Self>
136    + Serialize
137    + Neg
138    + for<'de> Deserialize<'de>
139{
140    /// The scalar type
141    type ScalarType: Scalar;
142
143    // The base type
144    // type BaseType: Scalar;
145
146    /// The number of bytes for a compressed representation of a group element
147    const COMPRESSED_LEN: usize;
148
149    /// Return the doubling of the group element
150    fn double(&self) -> Self;
151
152    /// Return the identity element (i.e., 0 * G)
153    fn get_identity() -> Self;
154
155    /// Return the base element (i.e., 1 * G)
156    fn get_base() -> Self;
157
158    /// Return a random element
159    fn random<R: CryptoRng + RngCore>(rng: &mut R) -> Self;
160
161    /// Convert to bytes in the compressed representation
162    fn to_compressed_bytes(&self) -> Vec<u8>;
163
164    /// Convert from bytes in the compressed representation
165    fn from_compressed_bytes(bytes: &[u8]) -> Result<Self>;
166
167    /// Convert to bytes in the unchecked representation
168    fn to_unchecked_bytes(&self) -> Vec<u8>;
169
170    /// Convert from bytes in the unchecked representation
171    fn from_unchecked_bytes(bytes: &[u8]) -> Result<Self>;
172
173    /// Return the size of unchecked bytes.
174    fn unchecked_size() -> usize;
175
176    /// Sample a group element based on a hash value
177    fn from_hash<D>(hash: D) -> Self
178    where
179        D: Digest<OutputSize = U64> + Default;
180
181    /// Compute the multiscalar multiplication
182    #[inline]
183    fn multi_exp(scalars: &[&Self::ScalarType], points: &[&Self]) -> Self {
184        if scalars.is_empty() {
185            Self::get_identity()
186        } else {
187            pippenger(scalars, points).unwrap()
188        }
189    }
190}
191
192/// Trait for Pedersen commitment.
193pub trait PedersenCommitment<G: Group>: Default {
194    /// Return the generator for the value part.
195    fn generator(&self) -> G;
196    /// Return the generator for the blinding part.
197    fn blinding_generator(&self) -> G;
198    /// Compute the Pedersen commitment over the Ristretto group.
199    fn commit(&self, value: G::ScalarType, blinding: G::ScalarType) -> G;
200}
201
202/// The trait for a pair of groups for pairing
203pub trait Pairing {
204    /// The scalar type
205    type ScalarField: Scalar;
206
207    /// The first group
208    type G1: Group<ScalarType = Self::ScalarField>;
209
210    /// The second group
211    type G2: Group<ScalarType = Self::ScalarField>;
212
213    /// The target group
214    type Gt: Group<ScalarType = Self::ScalarField>;
215
216    /// The pairing operation
217    fn pairing(a: &Self::G1, b: &Self::G2) -> Self::Gt;
218
219    /// The product of pairing operation
220    fn product_of_pairings(a: &[Self::G1], b: &[Self::G2]) -> Self::Gt;
221}
222
223/// Convert the scalar into a vector of small chunks, each of size `w`
224pub fn scalar_to_radix_2_power_w<S: Scalar>(scalar: &S, w: usize) -> Vec<i8> {
225    assert!(w <= 7);
226    if *scalar == S::from(0u32) {
227        return vec![0i8];
228    }
229    let scalar64 = scalar.get_little_endian_u64();
230
231    let radix: u64 = 1 << (w as u64);
232    let window_mask: u64 = radix - 1;
233
234    let mut carry = 0u64;
235    let mut digits = vec![];
236
237    let mut i = 0;
238    loop {
239        // Construct a buffer of bits of the scalar, starting at `bit_offset`.
240        let bit_offset = i * w;
241        let u64_idx = bit_offset / 64;
242        let bit_idx = bit_offset % 64;
243        if u64_idx >= scalar64.len() {
244            digits.push(carry as i8);
245            break;
246        }
247        let is_last = u64_idx == scalar64.len() - 1;
248
249        // Read the bits from the scalar
250        let bit_buf = if bit_idx < 64 - w || is_last {
251            // This window's bits are contained in a single u64,
252            scalar64[u64_idx] >> (bit_idx as u64)
253        } else {
254            // Combine the current u64's bits with the bits from the next u64
255            (scalar64[u64_idx] >> bit_idx) | (scalar64[1 + u64_idx] << (64 - bit_idx))
256        };
257
258        // Read the actual coefficient value from the window
259        let coef = carry + (bit_buf & window_mask); // coef = [0, 2^r)
260
261        // Recenter coefficients from [0,2^w) to [-2^w/2, 2^w/2)
262        carry = (coef + (radix / 2)) >> w;
263        digits.push(((coef as i64) - (carry << w) as i64) as i8);
264        i += 1;
265    }
266
267    while digits.len() > 1 && *digits.last().unwrap() == 0i8 {
268        // safe unwrap
269        digits.pop();
270    }
271    digits
272}
273
274/// Run the pippenger algorithm to compute multiscalar multiplication
275pub fn pippenger<G: Group>(scalars: &[&G::ScalarType], elems: &[&G]) -> Result<G> {
276    let size = scalars.len();
277
278    if size == 0 {
279        return Err(eg!(AlgebraError::ParameterError));
280    }
281
282    let w = if size < 500 {
283        6
284    } else if size < 800 {
285        7
286    } else {
287        8
288    };
289
290    let two_power_w: usize = 1 << w;
291    let digits_vec: Vec<Vec<i8>> = scalars
292        .iter()
293        .map(|s| scalar_to_radix_2_power_w::<G::ScalarType>(s, w))
294        .collect();
295
296    let mut digits_count = 0;
297    for digits in digits_vec.iter() {
298        if digits.len() > digits_count {
299            digits_count = digits.len();
300        }
301    }
302
303    // init all the buckets
304    let mut buckets: Vec<_> = (0..two_power_w / 2).map(|_| G::get_identity()).collect();
305
306    let mut cols = (0..digits_count).rev().map(|index| {
307        // empty each bucket
308        for b in buckets.iter_mut() {
309            *b = G::get_identity();
310        }
311        for (digits, elem) in digits_vec.iter().zip(elems) {
312            if index >= digits.len() {
313                continue;
314            }
315            let digit = digits[index];
316            if digit > 0 {
317                let b_index = (digit - 1) as usize;
318                buckets[b_index].add_assign(elem);
319            }
320            if digit < 0 {
321                let b_index = (-(digit + 1)) as usize;
322                buckets[b_index].sub_assign(elem);
323            }
324        }
325        let mut intermediate_sum = buckets[buckets.len() - 1].clone();
326        let mut sum = buckets[buckets.len() - 1].clone();
327        for i in (0..buckets.len() - 1).rev() {
328            intermediate_sum = intermediate_sum.add(&buckets[i]);
329            sum = sum.add(&intermediate_sum);
330        }
331        sum
332    });
333
334    let two_power_w_int = G::ScalarType::from(two_power_w as u64);
335    // This unwrap is safe as the list of scalars is non empty at this point.
336    let hi_col = cols.next().unwrap();
337    let res = cols.fold(hi_col, |total, p| total.mul(&two_power_w_int).add(&p));
338    Ok(res)
339}
340
#[cfg(test)]
pub(crate) mod group_tests {
    use crate::traits::{scalar_to_radix_2_power_w, Scalar};

    /// Exercise the arithmetic of `S` on small values: addition and subtraction (by
    /// value and in place), a carry across 32 bits, negation, inversion, `pow`, and
    /// the field-size accessors.
    pub(crate) fn test_scalar_operations<S: Scalar>() {
        // 40 + 60 = 100
        let forty = S::from(40u32);
        let sixty = S::from(60u32);
        let hundred = S::from(100u32);
        assert_eq!(forty.add(&sixty), hundred);

        let mut acc = S::from(0u32);
        acc.add_assign(&forty);
        acc.add_assign(&sixty);
        assert_eq!(acc, hundred);

        // 10 * 40 = 400
        let ten = S::from(10u32);
        let forty = S::from(40u32);
        let four_hundred = S::from(400u32);
        assert_eq!(ten.mul(&forty), four_hundred);

        let mut prod = S::from(1u32);
        prod.mul_assign(&ten);
        prod.mul_assign(&forty);
        assert_eq!(prod, four_hundred);

        // (2^32 - 1) + 1 = 2^32, and multiplying by one is a no-op.
        let max_u32 = S::from(0xFFFFFFFFu32);
        let one = S::from(1u32);
        assert_eq!(max_u32.add(&one), S::from(0x100000000u64));

        let max_u32 = S::from(0xFFFFFFFFu32);
        let one = S::from(1u32);
        assert_eq!(max_u32.mul(&one), S::from(0xFFFFFFFFu32));

        // 60 - 40 = 20 and 120 - 60 - 40 = 20
        let forty = S::from(40u32);
        let sixty = S::from(60u32);
        let twenty = S::from(20u32);
        assert_eq!(sixty.sub(&forty), twenty);

        let mut rem = S::from(120u32);
        rem.sub_assign(&sixty);
        rem.sub_assign(&forty);
        assert_eq!(rem, twenty);

        // -40 + 40 = 0
        let forty = S::from(40u32);
        let minus_forty = forty.neg();
        assert_eq!(minus_forty.add(&forty), S::from(0u32));

        // 40^-1 * 40 = 1
        let forty = S::from(40u32);
        let forty_inv = forty.inv().unwrap();
        assert_eq!(forty_inv.mul(&forty), S::from(1u32));

        // 3^20 = 3486784401
        let three = S::from(3u32);
        assert_eq!(three.pow(&[20]), S::from(3486784401u64));

        let from_biguint = S::get_field_size_biguint().to_bytes_le();
        assert_eq!(from_biguint, S::get_field_size_le_bytes());
    }

    /// Check that `to_bytes` followed by `from_bytes` round-trips a scalar.
    pub(crate) fn test_scalar_serialization<S: Scalar>() {
        let original = S::from(100u32);
        let encoded = original.to_bytes();
        let decoded = S::from_bytes(encoded.as_slice()).unwrap();
        assert_eq!(original, decoded);
    }

    /// Check `scalar_to_radix_2_power_w` against hand-computed signed digits.
    pub(crate) fn test_to_radix<S: Scalar>() {
        // (value, window size, expected digits, least significant first)
        let cases: [(u32, usize, &[i8]); 3] = [
            (41, 2, &[1, -2, -1, 1]), // 41 = 1 + -2*4 + -1*16 + 64
            (0, 2, &[0]),
            (1000, 6, &[-24, 16]), // 1000 = -24 + 16*64
        ];
        for &(value, w, expected) in cases.iter() {
            let digits = scalar_to_radix_2_power_w(&S::from(value), w);
            assert_eq!(digits.as_slice(), expected);
        }
    }
}
440
#[cfg(test)]
mod multi_exp_tests {
    use crate::bls12_381::{BLSGt, BLSG1, BLSG2};
    use crate::ristretto::RistrettoPoint;
    use crate::traits::Group;

    #[test]
    fn test_multiexp_ristretto() {
        run_multiexp_test::<RistrettoPoint>();
    }
    #[test]
    fn test_multiexp_blsg1() {
        run_multiexp_test::<BLSG1>();
    }
    #[test]
    fn test_multiexp_blsg2() {
        run_multiexp_test::<BLSG2>();
    }
    #[test]
    fn test_multiexp_blsgt() {
        run_multiexp_test::<BLSGt>();
    }

    /// Check `Group::multi_exp` on small inputs whose result is known in closed form.
    fn run_multiexp_test<G: Group>() {
        let base = G::get_base();
        let identity = G::get_identity();
        let zero = G::ScalarType::from(0u32);
        let one = G::ScalarType::from(1u32);

        // The empty sum is the identity.
        assert_eq!(G::multi_exp(&[], &[]), identity);

        // 0 * G = identity and 1 * G = G.
        assert_eq!(G::multi_exp(&[&zero], &[&base]), identity);
        assert_eq!(G::multi_exp(&[&one], &[&base]), base);

        // A zero coefficient contributes nothing.
        assert_eq!(G::multi_exp(&[&one, &zero], &[&base, &base]), base);

        // 1000*G + 2*(2G) + 3*(500G) = 2504*G
        let double_base = base.add(&base);
        let five_hundred_base = base.mul(&G::ScalarType::from(500u32));
        let thousand = G::ScalarType::from(1000u32);
        let two = G::ScalarType::from(2u32);
        let three = G::ScalarType::from(3u32);
        let got = G::multi_exp(
            &[&thousand, &two, &three],
            &[&base, &double_base, &five_hundred_base],
        );
        let expected = base.mul(&G::ScalarType::from((1000 + 4 + 1500) as u32));
        assert_eq!(got, expected);
    }
}