Skip to main content

nova_snark/provider/
traits.rs

1//! Traits for provider implementations.
2use crate::traits::{commitment::ScalarMul, Group, TranscriptReprTrait};
3use core::{
4  fmt::Debug,
5  ops::{Add, AddAssign, Sub, SubAssign},
6};
7use halo2curves::{serde::SerdeObject, CurveAffine};
8use num_integer::Integer;
9use num_traits::ToPrimitive;
10use rayon::iter::{IntoParallelRefIterator, ParallelIterator};
11use serde::{Deserialize, Serialize};
12
/// Bundles the group-operation operator traits against a right-hand side `Rhs`:
/// `+` and `-` (both producing `Output`) plus the compound assignments `+=` and `-=`.
///
/// Every type providing those four operators gets this trait through the blanket impl below.
pub trait GroupOps<Rhs = Self, Output = Self>:
  Add<Rhs, Output = Output> + AddAssign<Rhs> + Sub<Rhs, Output = Output> + SubAssign<Rhs>
{
}

impl<Rhs, Output, T> GroupOps<Rhs, Output> for T where
  T: Add<Rhs, Output = Output> + AddAssign<Rhs> + Sub<Rhs, Output = Output> + SubAssign<Rhs>
{
}
23
24/// A helper trait for references with a group operation.
25pub trait GroupOpsOwned<Rhs = Self, Output = Self>: for<'r> GroupOps<&'r Rhs, Output> {}
26impl<T, Rhs, Output> GroupOpsOwned<Rhs, Output> for T where T: for<'r> GroupOps<&'r Rhs, Output> {}
27
28/// A helper trait for references implementing group scalar multiplication.
29pub trait ScalarMulOwned<Rhs, Output = Self>: for<'r> ScalarMul<&'r Rhs, Output> {}
30impl<T, Rhs, Output> ScalarMulOwned<Rhs, Output> for T where T: for<'r> ScalarMul<&'r Rhs, Output> {}
31
32/// A trait that defines the core discrete logarithm group functionality
33pub trait DlogGroup:
34  Group
35  + Serialize
36  + for<'de> Deserialize<'de>
37  + GroupOps
38  + GroupOpsOwned
39  + ScalarMul<<Self as Group>::Scalar>
40  + ScalarMulOwned<<Self as Group>::Scalar>
41{
42  /// A type representing preprocessed group element
43  type AffineGroupElement: Clone
44    + Debug
45    + PartialEq
46    + Eq
47    + Send
48    + Sync
49    + Serialize
50    + for<'de> Deserialize<'de>
51    + TranscriptReprTrait<Self>
52    + CurveAffine
53    + SerdeObject
54    + crate::traits::evm_serde::CustomSerdeTrait;
55
56  /// Produce a vector of group elements using a static label
57  fn from_label(label: &'static [u8], n: usize) -> Vec<Self::AffineGroupElement>;
58
59  /// Produces a preprocessed element
60  fn affine(&self) -> Self::AffineGroupElement;
61
62  /// Returns a group element from a preprocessed group element
63  fn group(p: &Self::AffineGroupElement) -> Self;
64
65  /// Returns an element that is the additive identity of the group
66  fn zero() -> Self;
67
68  /// Returns the generator of the group
69  fn gen() -> Self;
70
71  /// Returns the affine coordinates (x, y, infinity) for the point
72  fn to_coordinates(&self) -> (<Self as Group>::Base, <Self as Group>::Base, bool);
73}
74
75/// Extension trait for DlogGroup that provides multi-scalar multiplication operations
76pub trait DlogGroupExt: DlogGroup {
77  /// A method to compute a multiexponentation
78  fn vartime_multiscalar_mul(scalars: &[Self::Scalar], bases: &[Self::AffineGroupElement]) -> Self;
79
80  /// A method to compute a batch of multiexponentations
81  fn batch_vartime_multiscalar_mul(
82    scalars: &[Vec<Self::Scalar>],
83    bases: &[Self::AffineGroupElement],
84  ) -> Vec<Self> {
85    scalars
86      .par_iter()
87      .map(|scalar| Self::vartime_multiscalar_mul(scalar, &bases[..scalar.len()]))
88      .collect::<Vec<_>>()
89  }
90
91  /// A method to compute a multiexponentation with small scalars
92  fn vartime_multiscalar_mul_small<T: Integer + Into<u64> + Copy + Sync + ToPrimitive>(
93    scalars: &[T],
94    bases: &[Self::AffineGroupElement],
95  ) -> Self;
96
97  /// A method to compute a multiexponentation with small scalars
98  fn vartime_multiscalar_mul_small_with_max_num_bits<
99    T: Integer + Into<u64> + Copy + Sync + ToPrimitive,
100  >(
101    scalars: &[T],
102    bases: &[Self::AffineGroupElement],
103    max_num_bits: usize,
104  ) -> Self;
105
106  /// A method to compute a batch of multiexponentations with small scalars
107  fn batch_vartime_multiscalar_mul_small<T: Integer + Into<u64> + Copy + Sync + ToPrimitive>(
108    scalars: &[Vec<T>],
109    bases: &[Self::AffineGroupElement],
110  ) -> Vec<Self> {
111    scalars
112      .par_iter()
113      .map(|scalar| Self::vartime_multiscalar_mul_small(scalar, &bases[..scalar.len()]))
114      .collect::<Vec<_>>()
115  }
116}
117
118/// A trait that defines extensions to the DlogGroup trait, to be implemented for
119/// elliptic curve groups that are pairing friendly
120pub trait PairingGroup: DlogGroupExt {
121  /// A type representing the second group
122  type G2: DlogGroup<Scalar = Self::Scalar, Base = Self::Base>;
123
124  /// A type representing the target group
125  type GT: PartialEq + Eq;
126
127  /// A method to compute a pairing
128  fn pairing(p: &Self, q: &Self::G2) -> Self::GT;
129}
130
/// Implements Nova's traits except `DlogGroupExt`, so that the MSM can be implemented
/// differently.
///
/// Arguments:
/// - `$name`: module exposing the `Point`, `Affine`, `Base` and `Scalar` types.
/// - `$name_curve`: curve type whose `hash_to_curve` seeds `DlogGroup::from_label`; its
///   values are batch-normalized into affine form.
/// - `$name_curve_affine`: affine curve type whose `identity()` is used as a placeholder and
///   for point-at-infinity checks.
/// - `$order_str`: hexadecimal literal returned as the `order` component of `group_params`.
/// - `$base_str`: hexadecimal literal returned as the `base` component of `group_params`.
///
/// The generated code refers to names that must be in scope at the invocation site. Examples:
/// `Group`, `DlogGroup`, `PrimeFieldExt`, `TranscriptReprTrait`, `Curve`, `CurveAffine`,
/// `BigInt` (with `from_str_radix`), `Shake256`, `digest` and rayon's parallel-iterator traits.
#[macro_export]
macro_rules! impl_traits_no_dlog_ext {
  (
    $name:ident,
    $name_curve:ident,
    $name_curve_affine:ident,
    $order_str:literal,
    $base_str:literal
  ) => {
    impl Group for $name::Point {
      type Base = $name::Base;
      type Scalar = $name::Scalar;

      fn group_params() -> (Self::Base, Self::Base, BigInt, BigInt) {
        let a = $name::Point::a();
        let b = $name::Point::b();
        // The hex strings are fixed by the macro invocation, so a parse failure is a bug in
        // the invocation rather than a recoverable runtime condition.
        let order = BigInt::from_str_radix($order_str, 16)
          .expect("order literal passed to impl_traits_no_dlog_ext! is not valid hex");
        let base = BigInt::from_str_radix($base_str, 16)
          .expect("base literal passed to impl_traits_no_dlog_ext! is not valid hex");

        (a, b, order, base)
      }
    }

    impl $crate::traits::evm_serde::CustomSerdeTrait for $name::Scalar {
      /// Serializes the scalar as its byte representation reversed (big-endian).
      #[cfg(feature = "evm")]
      fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
        use ff::PrimeField;
        use serde::Serialize;
        let mut bytes = self.to_repr();
        bytes.as_mut().reverse(); // big-endian
        bytes.serialize(serializer)
      }

      /// Deserializes 32 big-endian bytes, rejecting encodings that are not a canonical
      /// field element.
      #[cfg(feature = "evm")]
      fn deserialize<'de, D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
        use ff::PrimeField;
        use serde::de::Error;
        use serde::Deserialize;
        let mut bytes = <[u8; 32]>::deserialize(deserializer)?;
        bytes.reverse(); // big-endian
        Option::from(Self::from_repr(bytes.into()))
          .ok_or_else(|| D::Error::custom("deserialized bytes don't encode a valid field element"))
      }
    }

    impl $crate::traits::evm_serde::CustomSerdeTrait for $name::Affine {
      /// Serializes the point as the pair `(x, y)`, each coordinate EVM-encoded.
      #[cfg(feature = "evm")]
      fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
        use serde::{Deserialize, Serialize};
        use serde_with::serde_as;
        use $crate::traits::evm_serde::EvmCompatSerde;

        #[serde_as]
        #[derive(Deserialize, Serialize)]
        struct HelperAffine(
          #[serde_as(as = "EvmCompatSerde")] $name::Base,
          #[serde_as(as = "EvmCompatSerde")] $name::Base,
        );

        let affine = HelperAffine(self.x, self.y);
        affine.serialize(serializer)
      }

      /// Deserializes an EVM-encoded `(x, y)` pair.
      ///
      /// Returns an error unless the coordinates form a point on the curve or the identity.
      /// The input may be untrusted, so an off-curve point must never be constructed
      /// silently.
      #[cfg(feature = "evm")]
      fn deserialize<'de, D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
        use serde::de::Error;
        use serde::{Deserialize, Serialize};
        use serde_with::serde_as;
        use $crate::traits::evm_serde::EvmCompatSerde;

        #[serde_as]
        #[derive(Deserialize, Serialize)]
        struct HelperAffine(
          #[serde_as(as = "EvmCompatSerde")] $name::Base,
          #[serde_as(as = "EvmCompatSerde")] $name::Base,
        );

        let affine = HelperAffine::deserialize(deserializer)?;
        let point = $name::Affine {
          x: affine.0,
          y: affine.1,
        };
        // Accept the identity explicitly: `serialize` emits its stored coordinates, and that
        // must round-trip regardless of how `is_on_curve` treats it.
        if $name_curve_affine::identity() == point || bool::from(point.is_on_curve()) {
          Ok(point)
        } else {
          Err(D::Error::custom(
            "deserialized coordinates don't encode a valid curve point",
          ))
        }
      }
    }

    impl $crate::traits::evm_serde::CustomSerdeTrait for $name::Point {
      /// Serializes the point through its affine representation.
      #[cfg(feature = "evm")]
      fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
        use $crate::traits::evm_serde::CustomSerdeTrait;
        <$name::Affine as CustomSerdeTrait>::serialize(&self.to_affine(), serializer)
      }

      /// Deserializes an affine point (validated by the `Affine` impl) and converts it.
      #[cfg(feature = "evm")]
      fn deserialize<'de, D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
        use $crate::traits::evm_serde::CustomSerdeTrait;
        Ok(Self::from(
          <$name::Affine as CustomSerdeTrait>::deserialize(deserializer)?,
        ))
      }
    }

    impl DlogGroup for $name::Point {
      type AffineGroupElement = $name::Affine;

      fn affine(&self) -> Self::AffineGroupElement {
        self.to_affine()
      }

      fn group(p: &Self::AffineGroupElement) -> Self {
        $name::Point::from(*p)
      }

      /// Derives `n` affine points from `label`.
      ///
      /// The label is absorbed into SHAKE256, `n` 32-byte seeds are squeezed in order, each
      /// seed is hashed to the curve, and the results are batch-normalized to affine form.
      fn from_label(label: &'static [u8], n: usize) -> Vec<Self::AffineGroupElement> {
        let mut shake = Shake256::default();
        shake.update(label);
        let mut reader = shake.finalize_xof();
        // Squeeze the seeds sequentially; the XOF stream order fixes the output order.
        let uniform_bytes_vec: Vec<[u8; 32]> = (0..n)
          .map(|_| {
            let mut uniform_bytes = [0u8; 32];
            digest::XofReader::read(&mut reader, &mut uniform_bytes);
            uniform_bytes
          })
          .collect();

        let gens_proj: Vec<$name_curve> = (0..n)
          .into_par_iter()
          .map(|i| {
            // The hasher is built per element inside the parallel closure.
            // NOTE(review): hoisting it would require the returned hasher to be `Sync`.
            let hash = $name_curve::hash_to_curve("from_uniform_bytes");
            hash(&uniform_bytes_vec[i])
          })
          .collect();

        let num_threads = rayon::current_num_threads();
        if gens_proj.len() > num_threads {
          // Split into one contiguous chunk per thread and normalize each chunk
          // independently. The collected output keeps the original element order.
          let chunk = (gens_proj.len() as f64 / num_threads as f64).ceil() as usize;
          (0..num_threads)
            .into_par_iter()
            .flat_map(|i| {
              let start = i * chunk;
              let end = if i == num_threads - 1 {
                gens_proj.len()
              } else {
                core::cmp::min((i + 1) * chunk, gens_proj.len())
              };
              if end > start {
                let mut gens = vec![$name_curve_affine::identity(); end - start];
                <Self as Curve>::batch_normalize(&gens_proj[start..end], &mut gens);
                gens
              } else {
                vec![]
              }
            })
            .collect()
        } else {
          let mut gens = vec![$name_curve_affine::identity(); n];
          <Self as Curve>::batch_normalize(&gens_proj, &mut gens);
          gens
        }
      }

      fn zero() -> Self {
        $name::Point::identity()
      }

      fn gen() -> Self {
        $name::Point::generator()
      }

      /// Returns `(x, y, false)` for a finite point and `(0, 0, true)` otherwise.
      fn to_coordinates(&self) -> (Self::Base, Self::Base, bool) {
        // Convert to affine once and reuse it for both the coordinate and identity checks.
        let affine = self.affine();
        let coordinates = affine.coordinates();
        if coordinates.is_some().unwrap_u8() == 1 && ($name_curve_affine::identity() != affine) {
          let coordinates = coordinates.unwrap();
          (*coordinates.x(), *coordinates.y(), false)
        } else {
          (Self::Base::zero(), Self::Base::zero(), true)
        }
      }
    }

    impl PrimeFieldExt for $name::Scalar {
      /// Maps a 64-byte input to a scalar via `from_uniform_bytes`.
      ///
      /// # Panics
      /// Panics if `bytes` is not exactly 64 bytes long.
      fn from_uniform(bytes: &[u8]) -> Self {
        let bytes_arr: [u8; 64] = bytes
          .try_into()
          .expect("from_uniform expects exactly 64 bytes");
        $name::Scalar::from_uniform_bytes(&bytes_arr)
      }
    }

    impl<G: Group> TranscriptReprTrait<G> for $name::Scalar {
      /// Returns the scalar's bytes: as-is, or reversed when the `evm` feature is enabled.
      fn to_transcript_bytes(&self) -> Vec<u8> {
        #[cfg(not(feature = "evm"))]
        {
          self.to_bytes().into_iter().collect()
        }
        #[cfg(feature = "evm")]
        {
          self.to_bytes().into_iter().rev().collect()
        }
      }
    }

    impl<G: DlogGroup> TranscriptReprTrait<G> for $name::Affine {
      /// Returns the bytes of `x` followed by the bytes of `y`. Each coordinate is reversed
      /// when the `evm` feature is enabled.
      ///
      /// When `coordinates()` is none (e.g. for the identity), both coordinates are encoded as
      /// zero instead of panicking. This matches the zero coordinates that
      /// `DlogGroup::to_coordinates` reports for the point at infinity.
      /// NOTE(review): this assumes (0, 0) is not a valid affine point on the curve, so the
      /// encoding stays unambiguous.
      fn to_transcript_bytes(&self) -> Vec<u8> {
        let coords = self.coordinates();
        let (x, y) = if bool::from(coords.is_some()) {
          let coords = coords.unwrap();
          (*coords.x(), *coords.y())
        } else {
          ($name::Base::zero(), $name::Base::zero())
        };
        let x_bytes = x.to_bytes().into_iter();
        let y_bytes = y.to_bytes().into_iter();
        #[cfg(not(feature = "evm"))]
        {
          x_bytes.chain(y_bytes).collect()
        }
        #[cfg(feature = "evm")]
        {
          x_bytes.rev().chain(y_bytes.rev()).collect()
        }
      }
    }
  };
}
346
/// Implements all of Nova's traits for a curve, including `DlogGroupExt`.
///
/// First expands `impl_traits_no_dlog_ext!` with the same arguments. It then implements
/// `DlogGroupExt` by forwarding each method to one of the free functions `msm`, `msm_small`
/// or `msm_small_with_max_num_bits`, which must be in scope at the invocation site.
#[macro_export]
macro_rules! impl_traits {
  (
    $name:ident,
    $name_curve:ident,
    $name_curve_affine:ident,
    $order_str:literal,
    $base_str:literal
  ) => {
    // Every trait except the MSM extension.
    $crate::impl_traits_no_dlog_ext!(
      $name, $name_curve, $name_curve_affine, $order_str, $base_str
    );

    // MSM extension: each method is a thin forwarding wrapper.
    impl DlogGroupExt for $name::Point {
      /// Multiexponentiation with small integer scalars; forwards to `msm_small`.
      fn vartime_multiscalar_mul_small<T: Integer + Into<u64> + Copy + Sync + ToPrimitive>(
        scalars: &[T],
        bases: &[Self::AffineGroupElement],
      ) -> Self {
        msm_small(scalars, bases)
      }

      /// Small-scalar multiexponentiation taking a `max_num_bits` hint; forwards to
      /// `msm_small_with_max_num_bits`.
      fn vartime_multiscalar_mul_small_with_max_num_bits<
        T: Integer + Into<u64> + Copy + Sync + ToPrimitive,
      >(
        scalars: &[T],
        bases: &[Self::AffineGroupElement],
        max_num_bits: usize,
      ) -> Self {
        msm_small_with_max_num_bits(scalars, bases, max_num_bits)
      }

      /// Multiexponentiation with full field-element scalars; forwards to `msm`.
      fn vartime_multiscalar_mul(
        scalars: &[Self::Scalar],
        bases: &[Self::AffineGroupElement],
      ) -> Self {
        msm(scalars, bases)
      }
    }
  };
}