1use crate::traits::{commitment::ScalarMul, Group, TranscriptReprTrait};
3use core::{
4 fmt::Debug,
5 ops::{Add, AddAssign, Sub, SubAssign},
6};
7use halo2curves::{serde::SerdeObject, CurveAffine};
8use num_integer::Integer;
9use num_traits::ToPrimitive;
10use rayon::iter::{IntoParallelRefIterator, ParallelIterator};
11use serde::{Deserialize, Serialize};
12
/// Arithmetic operators shared by group elements.
///
/// A type satisfies `GroupOps<Rhs, Output>` when it supports `+` and `-` with a
/// right-hand side of type `Rhs` (producing `Output`), together with the
/// compound-assignment forms `+=` and `-=` taking the same `Rhs`.
pub trait GroupOps<Rhs = Self, Output = Self>:
    AddAssign<Rhs> + SubAssign<Rhs> + Add<Rhs, Output = Output> + Sub<Rhs, Output = Output>
{
}

/// Blanket implementation: every type providing the four operators is `GroupOps`.
impl<T, Rhs, Output> GroupOps<Rhs, Output> for T where
    T: AddAssign<Rhs> + SubAssign<Rhs> + Add<Rhs, Output = Output> + Sub<Rhs, Output = Output>
{
}
23
24pub trait GroupOpsOwned<Rhs = Self, Output = Self>: for<'r> GroupOps<&'r Rhs, Output> {}
26impl<T, Rhs, Output> GroupOpsOwned<Rhs, Output> for T where T: for<'r> GroupOps<&'r Rhs, Output> {}
27
28pub trait ScalarMulOwned<Rhs, Output = Self>: for<'r> ScalarMul<&'r Rhs, Output> {}
30impl<T, Rhs, Output> ScalarMulOwned<Rhs, Output> for T where T: for<'r> ScalarMul<&'r Rhs, Output> {}
31
32pub trait DlogGroup:
34 Group
35 + Serialize
36 + for<'de> Deserialize<'de>
37 + GroupOps
38 + GroupOpsOwned
39 + ScalarMul<<Self as Group>::Scalar>
40 + ScalarMulOwned<<Self as Group>::Scalar>
41{
42 type AffineGroupElement: Clone
44 + Debug
45 + PartialEq
46 + Eq
47 + Send
48 + Sync
49 + Serialize
50 + for<'de> Deserialize<'de>
51 + TranscriptReprTrait<Self>
52 + CurveAffine
53 + SerdeObject
54 + crate::traits::evm_serde::CustomSerdeTrait;
55
56 fn from_label(label: &'static [u8], n: usize) -> Vec<Self::AffineGroupElement>;
58
59 fn affine(&self) -> Self::AffineGroupElement;
61
62 fn group(p: &Self::AffineGroupElement) -> Self;
64
65 fn zero() -> Self;
67
68 fn gen() -> Self;
70
71 fn to_coordinates(&self) -> (<Self as Group>::Base, <Self as Group>::Base, bool);
73}
74
75pub trait DlogGroupExt: DlogGroup {
77 fn vartime_multiscalar_mul(scalars: &[Self::Scalar], bases: &[Self::AffineGroupElement]) -> Self;
79
80 fn batch_vartime_multiscalar_mul(
82 scalars: &[Vec<Self::Scalar>],
83 bases: &[Self::AffineGroupElement],
84 ) -> Vec<Self> {
85 scalars
86 .par_iter()
87 .map(|scalar| Self::vartime_multiscalar_mul(scalar, &bases[..scalar.len()]))
88 .collect::<Vec<_>>()
89 }
90
91 fn vartime_multiscalar_mul_small<T: Integer + Into<u64> + Copy + Sync + ToPrimitive>(
93 scalars: &[T],
94 bases: &[Self::AffineGroupElement],
95 ) -> Self;
96
97 fn vartime_multiscalar_mul_small_with_max_num_bits<
99 T: Integer + Into<u64> + Copy + Sync + ToPrimitive,
100 >(
101 scalars: &[T],
102 bases: &[Self::AffineGroupElement],
103 max_num_bits: usize,
104 ) -> Self;
105
106 fn batch_vartime_multiscalar_mul_small<T: Integer + Into<u64> + Copy + Sync + ToPrimitive>(
108 scalars: &[Vec<T>],
109 bases: &[Self::AffineGroupElement],
110 ) -> Vec<Self> {
111 scalars
112 .par_iter()
113 .map(|scalar| Self::vartime_multiscalar_mul_small(scalar, &bases[..scalar.len()]))
114 .collect::<Vec<_>>()
115 }
116}
117
118pub trait PairingGroup: DlogGroupExt {
121 type G2: DlogGroup<Scalar = Self::Scalar, Base = Self::Base>;
123
124 type GT: PartialEq + Eq;
126
127 fn pairing(p: &Self, q: &Self::G2) -> Self::GT;
129}
130
/// Implements the curve-level traits for the curve module `$name`, except
/// `DlogGroupExt` (use `impl_traits!` to also get that).
///
/// The expansion generates:
/// - `Group` for `$name::Point`;
/// - `CustomSerdeTrait` for `$name::Scalar`, `$name::Affine` and `$name::Point`
///   (their methods are only defined with the `evm` feature);
/// - `DlogGroup` for `$name::Point`;
/// - `PrimeFieldExt` for `$name::Scalar`;
/// - `TranscriptReprTrait` for `$name::Scalar` and `$name::Affine`.
///
/// Arguments:
/// - `$name`: a module exposing `Point`, `Affine`, `Scalar` and `Base`;
/// - `$name_curve`: the curve type whose `hash_to_curve` produces generators.
///   Those generators are normalized with `Curve::batch_normalize`;
/// - `$name_curve_affine`: the affine curve type, used for `identity()`;
/// - `$order_str`, `$base_str`: hexadecimal string literals parsed with
///   `BigInt::from_str_radix(.., 16)`. A malformed literal panics in `group_params`.
///
/// Some names are used unqualified and must be in scope wherever the macro is
/// invoked. Examples: `Group`, `DlogGroup`, `PrimeFieldExt`,
/// `TranscriptReprTrait`, `BigInt`, `Shake256`, `Curve`.
#[macro_export]
macro_rules! impl_traits_no_dlog_ext {
    (
        $name:ident,
        $name_curve:ident,
        $name_curve_affine:ident,
        $order_str:literal,
        $base_str:literal
    ) => {
        impl Group for $name::Point {
            type Base = $name::Base;
            type Scalar = $name::Scalar;

            // Returns the curve coefficients `a` and `b`, followed by `order` and `base`
            // parsed from the hexadecimal macro arguments.
            fn group_params() -> (Self::Base, Self::Base, BigInt, BigInt) {
                let a = $name::Point::a();
                let b = $name::Point::b();
                let order = BigInt::from_str_radix($order_str, 16).unwrap();
                let base = BigInt::from_str_radix($base_str, 16).unwrap();

                (a, b, order, base)
            }
        }

        impl $crate::traits::evm_serde::CustomSerdeTrait for $name::Scalar {
            #[cfg(feature = "evm")]
            fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
                use ff::PrimeField;
                use serde::Serialize;
                // Emit the canonical representation with its byte order reversed.
                // NOTE(review): presumably little-endian `to_repr` -> big-endian for EVM.
                let mut bytes = self.to_repr();
                bytes.as_mut().reverse();
                bytes.serialize(serializer)
            }

            #[cfg(feature = "evm")]
            fn deserialize<'de, D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
                use ff::PrimeField;
                use serde::de::Error;
                use serde::Deserialize;
                // Undo the byte-order reversal done by `serialize`, then reject
                // non-canonical encodings.
                let mut bytes = <[u8; 32]>::deserialize(deserializer)?;
                bytes.reverse();
                Option::from(Self::from_repr(bytes.into()))
                    .ok_or_else(|| D::Error::custom("deserialized bytes don't encode a valid field element"))
            }
        }

        impl $crate::traits::evm_serde::CustomSerdeTrait for $name::Affine {
            #[cfg(feature = "evm")]
            fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
                use serde::{Deserialize, Serialize};
                use serde_with::serde_as;
                use $crate::traits::evm_serde::EvmCompatSerde;

                // The point is written as the pair (x, y), each coordinate in EVM-compatible form.
                #[serde_as]
                #[derive(Deserialize, Serialize)]
                struct HelperAffine(
                    #[serde_as(as = "EvmCompatSerde")] $name::Base,
                    #[serde_as(as = "EvmCompatSerde")] $name::Base,
                );

                let affine = HelperAffine(self.x, self.y);
                affine.serialize(serializer)
            }

            #[cfg(feature = "evm")]
            fn deserialize<'de, D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
                use serde::de::Error;
                use serde::{Deserialize, Serialize};
                use serde_with::serde_as;
                use $crate::traits::evm_serde::EvmCompatSerde;

                #[serde_as]
                #[derive(Deserialize, Serialize)]
                struct HelperAffine(
                    #[serde_as(as = "EvmCompatSerde")] $name::Base,
                    #[serde_as(as = "EvmCompatSerde")] $name::Base,
                );

                let HelperAffine(x, y) = HelperAffine::deserialize(deserializer)?;
                // Reject coordinates that do not satisfy the curve equation instead of
                // building an invalid point. This mirrors the validity check done by
                // the scalar deserializer above.
                Option::from(<$name::Affine as halo2curves::CurveAffine>::from_xy(x, y))
                    .ok_or_else(|| D::Error::custom("deserialized coordinates don't encode a valid curve point"))
            }
        }

        impl $crate::traits::evm_serde::CustomSerdeTrait for $name::Point {
            // Points are (de)serialized through their affine form.
            #[cfg(feature = "evm")]
            fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
                use $crate::traits::evm_serde::CustomSerdeTrait;
                <$name::Affine as CustomSerdeTrait>::serialize(&self.to_affine(), serializer)
            }

            #[cfg(feature = "evm")]
            fn deserialize<'de, D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
                use $crate::traits::evm_serde::CustomSerdeTrait;
                Ok(Self::from(
                    <$name::Affine as CustomSerdeTrait>::deserialize(deserializer)?,
                ))
            }
        }

        impl DlogGroup for $name::Point {
            type AffineGroupElement = $name::Affine;

            fn affine(&self) -> Self::AffineGroupElement {
                self.to_affine()
            }

            fn group(p: &Self::AffineGroupElement) -> Self {
                $name::Point::from(*p)
            }

            fn from_label(label: &'static [u8], n: usize) -> Vec<Self::AffineGroupElement> {
                // Expand `label` with SHAKE256 into `n` consecutive 32-byte seeds.
                let mut shake = Shake256::default();
                shake.update(label);
                let mut reader = shake.finalize_xof();
                let mut uniform_bytes_vec = Vec::with_capacity(n);
                for _ in 0..n {
                    let mut uniform_bytes = [0u8; 32];
                    digest::XofReader::read(&mut reader, &mut uniform_bytes);
                    uniform_bytes_vec.push(uniform_bytes);
                }

                // Map every seed to a curve point in parallel. A fresh hasher is built per item.
                let gens_proj: Vec<$name_curve> = (0..n)
                    .into_par_iter()
                    .map(|i| {
                        let hash = $name_curve::hash_to_curve("from_uniform_bytes");
                        hash(&uniform_bytes_vec[i])
                    })
                    .collect();

                // Convert to affine form. With more points than threads, the points are split
                // into one contiguous chunk per thread. Each chunk is batch-normalized on its
                // own, and the per-chunk results are flattened back into a single vector.
                let num_threads = rayon::current_num_threads();
                if gens_proj.len() > num_threads {
                    let chunk = (gens_proj.len() as f64 / num_threads as f64).ceil() as usize;
                    (0..num_threads)
                        .into_par_iter()
                        .flat_map(|i| {
                            let start = i * chunk;
                            // The last thread always runs to the end, covering any remainder.
                            let end = if i == num_threads - 1 {
                                gens_proj.len()
                            } else {
                                core::cmp::min((i + 1) * chunk, gens_proj.len())
                            };
                            if end > start {
                                let mut gens = vec![$name_curve_affine::identity(); end - start];
                                <Self as Curve>::batch_normalize(&gens_proj[start..end], &mut gens);
                                gens
                            } else {
                                vec![]
                            }
                        })
                        .collect()
                } else {
                    let mut gens = vec![$name_curve_affine::identity(); n];
                    <Self as Curve>::batch_normalize(&gens_proj, &mut gens);
                    gens
                }
            }

            fn zero() -> Self {
                $name::Point::identity()
            }

            fn gen() -> Self {
                $name::Point::generator()
            }

            // Returns (x, y, false) for a finite point and (0, 0, true) for the identity.
            fn to_coordinates(&self) -> (Self::Base, Self::Base, bool) {
                // Normalize to affine once and reuse it for both checks.
                let affine = self.affine();
                let coordinates = affine.coordinates();
                if coordinates.is_some().unwrap_u8() == 1 && affine != $name_curve_affine::identity() {
                    let coordinates = coordinates.unwrap();
                    (*coordinates.x(), *coordinates.y(), false)
                } else {
                    (Self::Base::zero(), Self::Base::zero(), true)
                }
            }
        }

        impl PrimeFieldExt for $name::Scalar {
            // Panics unless `bytes` is exactly 64 bytes long.
            fn from_uniform(bytes: &[u8]) -> Self {
                let bytes_arr: [u8; 64] = bytes.try_into().unwrap();
                $name::Scalar::from_uniform_bytes(&bytes_arr)
            }
        }

        impl<G: Group> TranscriptReprTrait<G> for $name::Scalar {
            // `to_bytes()` output; byte order is reversed when the `evm` feature is on.
            fn to_transcript_bytes(&self) -> Vec<u8> {
                #[cfg(not(feature = "evm"))]
                {
                    self.to_bytes().into_iter().collect()
                }
                #[cfg(feature = "evm")]
                {
                    self.to_bytes().into_iter().rev().collect()
                }
            }
        }

        impl<G: DlogGroup> TranscriptReprTrait<G> for $name::Affine {
            // Concatenates the bytes of x and y. With the `evm` feature each
            // coordinate's bytes are reversed.
            fn to_transcript_bytes(&self) -> Vec<u8> {
                // NOTE(review): `coordinates()` presumably yields no value for the point at
                // infinity, so this `unwrap` would panic on the identity. Confirm whether
                // callers can ever pass it.
                let coords = self.coordinates().unwrap();
                let x_bytes = coords.x().to_bytes().into_iter();
                let y_bytes = coords.y().to_bytes().into_iter();
                #[cfg(not(feature = "evm"))]
                {
                    x_bytes.chain(y_bytes).collect()
                }
                #[cfg(feature = "evm")]
                {
                    x_bytes.rev().chain(y_bytes.rev()).collect()
                }
            }
        }
    };
}
346
/// Implements everything `impl_traits_no_dlog_ext!` does for the curve module
/// `$name`, plus `DlogGroupExt` for `$name::Point`.
///
/// The `DlogGroupExt` methods forward to `msm`, `msm_small` and
/// `msm_small_with_max_num_bits`. Those names are used unqualified and must be
/// in scope at the invocation site.
///
/// Takes the same arguments as `impl_traits_no_dlog_ext!`.
#[macro_export]
macro_rules! impl_traits {
    (
        $name:ident,
        $name_curve:ident,
        $name_curve_affine:ident,
        $order_str:literal,
        $base_str:literal
    ) => {
        $crate::impl_traits_no_dlog_ext! {
            $name,
            $name_curve,
            $name_curve_affine,
            $order_str,
            $base_str
        }

        impl DlogGroupExt for $name::Point {
            fn vartime_multiscalar_mul(
                scalars: &[Self::Scalar],
                bases: &[Self::AffineGroupElement],
            ) -> Self {
                msm(scalars, bases)
            }

            fn vartime_multiscalar_mul_small<T>(
                scalars: &[T],
                bases: &[Self::AffineGroupElement],
            ) -> Self
            where
                T: Integer + Into<u64> + Copy + Sync + ToPrimitive,
            {
                msm_small(scalars, bases)
            }

            fn vartime_multiscalar_mul_small_with_max_num_bits<T>(
                scalars: &[T],
                bases: &[Self::AffineGroupElement],
                max_num_bits: usize,
            ) -> Self
            where
                T: Integer + Into<u64> + Copy + Sync + ToPrimitive,
            {
                msm_small_with_max_num_bits(scalars, bases, max_num_bits)
            }
        }
    };
}