1use crate::base::scalar::{Scalar, ScalarConversionError, ScalarExt};
2use alloc::{
3 format,
4 string::{String, ToString},
5 vec::Vec,
6};
7use ark_ff::{AdditiveGroup, BigInteger, Field, Fp, Fp256, MontBackend, MontConfig, PrimeField};
8use ark_serialize::{CanonicalDeserialize, CanonicalSerialize};
9use bnum::types::U256;
10use bytemuck::TransparentWrapper;
11use core::{
12 cmp::Ordering,
13 fmt,
14 fmt::{Debug, Display, Formatter},
15 hash::{Hash, Hasher},
16 iter::{Product, Sum},
17 ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign},
18};
19use num_bigint::BigInt;
20use num_traits::Signed;
21use serde::{Deserialize, Deserializer, Serialize, Serializer};
/// A scalar backed by a 256-bit (four 64-bit limb) `ark_ff` field element in
/// Montgomery form, parameterised by the field configuration `T`.
///
/// The wrapper is `#[repr(transparent)]` and derives `TransparentWrapper`, so it
/// is declared layout-compatible with the wrapped `Fp256<MontBackend<T, 4>>`.
#[derive(CanonicalSerialize, CanonicalDeserialize, TransparentWrapper)]
#[repr(transparent)]
pub struct MontScalar<T: MontConfig<4>>(pub Fp256<MontBackend<T, 4>>);
28
29impl<T: MontConfig<4>> Add for MontScalar<T> {
34 type Output = Self;
35 fn add(self, rhs: Self) -> Self::Output {
36 Self(self.0 + rhs.0)
37 }
38}
39impl<T: MontConfig<4>> Sub for MontScalar<T> {
40 type Output = Self;
41 fn sub(self, rhs: Self) -> Self::Output {
42 Self(self.0 - rhs.0)
43 }
44}
45impl<T: MontConfig<4>> Mul for MontScalar<T> {
46 type Output = Self;
47 fn mul(self, rhs: Self) -> Self::Output {
48 Self(self.0 * rhs.0)
49 }
50}
51impl<T: MontConfig<4>> AddAssign for MontScalar<T> {
52 fn add_assign(&mut self, rhs: Self) {
53 self.0 += rhs.0;
54 }
55}
56impl<T: MontConfig<4>> SubAssign for MontScalar<T> {
57 fn sub_assign(&mut self, rhs: Self) {
58 self.0 -= rhs.0;
59 }
60}
61impl<T: MontConfig<4>> MulAssign for MontScalar<T> {
62 fn mul_assign(&mut self, rhs: Self) {
63 self.0 *= rhs.0;
64 }
65}
66impl<T: MontConfig<4>> Neg for MontScalar<T> {
67 type Output = Self;
68 fn neg(self) -> Self::Output {
69 Self(-self.0)
70 }
71}
72impl<T: MontConfig<4>> Sum for MontScalar<T> {
73 fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
74 Self(iter.map(|x| x.0).sum())
75 }
76}
77impl<T: MontConfig<4>> Product for MontScalar<T> {
78 fn product<I: Iterator<Item = Self>>(iter: I) -> Self {
79 Self(iter.map(|x| x.0).product())
80 }
81}
// `Clone` and `Copy` are written by hand rather than derived; the bounds here
// are only `T: MontConfig<4>`.
impl<T: MontConfig<4>> Clone for MontScalar<T> {
    fn clone(&self) -> Self {
        // The type is `Copy` (see the impl directly below), so cloning is a plain copy.
        *self
    }
}
impl<T: MontConfig<4>> Copy for MontScalar<T> {}
/// Partial ordering that always succeeds by delegating to the `Ord` implementation.
impl<T: MontConfig<4>> PartialOrd for MontScalar<T> {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        // Never `None`: the result of `cmp` is always wrapped in `Some`.
        Some(self.cmp(other))
    }
}
93impl<T: MontConfig<4>> PartialEq for MontScalar<T> {
94 fn eq(&self, other: &Self) -> bool {
95 self.0 == other.0
96 }
97}
98impl<T: MontConfig<4>> Default for MontScalar<T> {
99 fn default() -> Self {
100 Self(Fp::default())
101 }
102}
103impl<T: MontConfig<4>> Debug for MontScalar<T> {
104 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
105 f.debug_tuple("MontScalar").field(&self.0).finish()
106 }
107}
108impl<T: MontConfig<4>> Eq for MontScalar<T> {}
109impl<T: MontConfig<4>> Hash for MontScalar<T> {
110 fn hash<H: Hasher>(&self, state: &mut H) {
111 self.0.hash(state);
112 }
113}
114impl<T: MontConfig<4>> Ord for MontScalar<T> {
115 fn cmp(&self, other: &Self) -> Ordering {
116 self.0.cmp(&other.0)
117 }
118}
// Generates `From<$primitive> for MontScalar<T>` for any type the underlying
// field element can already be built from via `Into`.
macro_rules! impl_from_for_mont_scalar_for_type_supported_by_from {
    ($primitive:ty) => {
        impl<T: MontConfig<4>> From<$primitive> for MontScalar<T> {
            fn from(x: $primitive) -> Self {
                let element: Fp256<MontBackend<T, 4>> = x.into();
                Self(element)
            }
        }
    };
}
133
134impl<T: MontConfig<4>> From<&[u8]> for MontScalar<T> {
136 fn from(x: &[u8]) -> Self {
137 ScalarExt::from_byte_slice_via_hash(x)
138 }
139}
140
// Generates `From<$text> for MontScalar<T>` for string-like types by routing
// their bytes (`as_bytes()`) through the `From<&[u8]>` conversion.
macro_rules! impl_from_for_mont_scalar_for_string {
    ($text:ty) => {
        impl<T: MontConfig<4>> From<$text> for MontScalar<T> {
            fn from(x: $text) -> Self {
                Self::from(x.as_bytes())
            }
        }
    };
}
151
// `From` impls for `bool` and the primitive integer types, generated by the
// macro that forwards to the field element's own conversion.
impl_from_for_mont_scalar_for_type_supported_by_from!(bool);
impl_from_for_mont_scalar_for_type_supported_by_from!(u8);
impl_from_for_mont_scalar_for_type_supported_by_from!(u16);
impl_from_for_mont_scalar_for_type_supported_by_from!(u32);
impl_from_for_mont_scalar_for_type_supported_by_from!(u64);
impl_from_for_mont_scalar_for_type_supported_by_from!(u128);
impl_from_for_mont_scalar_for_type_supported_by_from!(i8);
impl_from_for_mont_scalar_for_type_supported_by_from!(i16);
impl_from_for_mont_scalar_for_type_supported_by_from!(i32);
impl_from_for_mont_scalar_for_type_supported_by_from!(i64);
impl_from_for_mont_scalar_for_type_supported_by_from!(i128);
// `From` impls for string types, generated by the macro that converts via `as_bytes()`.
impl_from_for_mont_scalar_for_string!(&str);
impl_from_for_mont_scalar_for_string!(String);
165
166impl<F: MontConfig<4>, T> From<&T> for MontScalar<F>
167where
168 T: Into<MontScalar<F>> + Clone,
169{
170 fn from(x: &T) -> Self {
171 x.clone().into()
172 }
173}
174
175impl<T: MontConfig<4>> MontScalar<T> {
176 #[cfg(test)]
178 #[must_use]
179 pub fn new(value: Fp256<MontBackend<T, 4>>) -> Self {
180 Self(value)
181 }
182
183 #[must_use]
189 pub fn from_bigint(vals: [u64; 4]) -> Self {
190 Self(Fp::from_bigint(ark_ff::BigInt(vals)).unwrap())
191 }
192 #[must_use]
194 pub fn from_le_bytes_mod_order(bytes: &[u8]) -> Self {
195 Self(Fp::from_le_bytes_mod_order(bytes))
196 }
197 #[must_use]
199 pub fn to_bytes_le(&self) -> Vec<u8> {
200 self.0.into_bigint().to_bytes_le()
201 }
202 #[cfg(test)]
204 pub fn wrap_slice(slice: &[Fp256<MontBackend<T, 4>>]) -> Vec<Self> {
205 slice.iter().copied().map(Self).collect()
206 }
207 #[cfg(test)]
209 #[must_use]
210 pub fn unwrap_slice(slice: &[Self]) -> Vec<Fp256<MontBackend<T, 4>>> {
211 slice.iter().map(|x| x.0).collect()
212 }
213}
214
215impl<T> TryFrom<BigInt> for MontScalar<T>
216where
217 T: MontConfig<4>,
218 MontScalar<T>: Scalar,
219{
220 type Error = ScalarConversionError;
221
222 fn try_from(value: BigInt) -> Result<Self, Self::Error> {
223 if value.abs() > BigInt::from(<MontScalar<T>>::MAX_SIGNED) {
224 return Err(ScalarConversionError::Overflow {
225 error: "BigInt too large for Scalar".to_string(),
226 });
227 }
228
229 let (sign, digits) = value.to_u64_digits();
230 assert!(digits.len() <= 4); let mut limbs = [0u64; 4];
232 limbs[..digits.len()].copy_from_slice(&digits);
233 let result = Self::from(limbs);
234 Ok(match sign {
235 num_bigint::Sign::Minus => -result,
236 num_bigint::Sign::Plus | num_bigint::Sign::NoSign => result,
237 })
238 }
239}
240impl<T: MontConfig<4>> From<[u64; 4]> for MontScalar<T> {
241 fn from(value: [u64; 4]) -> Self {
242 Self(Fp::new(ark_ff::BigInt(value)))
243 }
244}
245
246impl<T: MontConfig<4>> ark_std::UniformRand for MontScalar<T> {
247 fn rand<R: ark_std::rand::Rng + ?Sized>(rng: &mut R) -> Self {
248 Self(ark_ff::UniformRand::rand(rng))
249 }
250}
251
252impl<'a, T: MontConfig<4>> Sum<&'a Self> for MontScalar<T> {
253 fn sum<I: Iterator<Item = &'a Self>>(iter: I) -> Self {
254 Self(iter.map(|x| x.0).sum())
255 }
256}
257impl<T: MontConfig<4>> num_traits::One for MontScalar<T> {
258 fn one() -> Self {
259 Self(Fp::one())
260 }
261}
262impl<T: MontConfig<4>> num_traits::Zero for MontScalar<T> {
263 fn zero() -> Self {
264 Self(Fp::zero())
265 }
266 fn is_zero(&self) -> bool {
267 self.0.is_zero()
268 }
269}
270impl<T: MontConfig<4>> num_traits::Inv for MontScalar<T> {
271 type Output = Option<Self>;
272 fn inv(self) -> Option<Self> {
273 self.0.inverse().map(Self)
274 }
275}
276impl<T: MontConfig<4>> Serialize for MontScalar<T> {
277 fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
278 let mut limbs: [u64; 4] = self.into();
279 limbs.reverse();
280 limbs.serialize(serializer)
281 }
282}
283impl<'de, T: MontConfig<4>> Deserialize<'de> for MontScalar<T> {
284 fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
285 let mut limbs: [u64; 4] = Deserialize::deserialize(deserializer)?;
286 limbs.reverse();
287 Ok(limbs.into())
288 }
289}
290
291impl<T: MontConfig<4>> core::ops::Neg for &MontScalar<T> {
292 type Output = MontScalar<T>;
293 fn neg(self) -> Self::Output {
294 MontScalar(-self.0)
295 }
296}
297
298impl<T: MontConfig<4>> From<MontScalar<T>> for [u64; 4] {
299 fn from(value: MontScalar<T>) -> Self {
300 (&value).into()
301 }
302}
303
304impl<T: MontConfig<4>> From<&MontScalar<T>> for [u64; 4] {
305 fn from(value: &MontScalar<T>) -> Self {
306 value.0.into_bigint().0
307 }
308}
309
310impl<T: MontConfig<4>> Display for MontScalar<T> {
311 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
312 let sign = if f.sign_plus() {
313 let n = -self;
314 if self > &n {
315 Some(Some(n))
316 } else {
317 Some(None)
318 }
319 } else {
320 None
321 };
322 match (f.alternate(), sign) {
323 (false, None) => {
324 let data = self.0.into_bigint().0;
325 write!(
326 f,
327 "{:016X}{:016X}{:016X}{:016X}",
328 data[3], data[2], data[1], data[0],
329 )
330 }
331 (false, Some(None)) => {
332 let data = self.0.into_bigint().0;
333 write!(
334 f,
335 "+{:016X}{:016X}{:016X}{:016X}",
336 data[3], data[2], data[1], data[0],
337 )
338 }
339 (false, Some(Some(n))) => {
340 let data = n.0.into_bigint().0;
341 write!(
342 f,
343 "-{:016X}{:016X}{:016X}{:016X}",
344 data[3], data[2], data[1], data[0],
345 )
346 }
347 (true, None) => {
348 let data = self.to_bytes_le();
349 write!(
350 f,
351 "0x{:02X}{:02X}...{:02X}{:02X}",
352 data[31], data[30], data[1], data[0],
353 )
354 }
355 (true, Some(None)) => {
356 let data = self.to_bytes_le();
357 write!(
358 f,
359 "+0x{:02X}{:02X}...{:02X}{:02X}",
360 data[31], data[30], data[1], data[0],
361 )
362 }
363 (true, Some(Some(n))) => {
364 let data = n.to_bytes_le();
365 write!(
366 f,
367 "-0x{:02X}{:02X}...{:02X}{:02X}",
368 data[31], data[30], data[1], data[0],
369 )
370 }
371 }
372 }
373}
374
375impl<T> Scalar for MontScalar<T>
376where
377 T: MontConfig<4>,
378{
379 const MAX_SIGNED: Self = Self(Fp::new(T::MODULUS.divide_by_2_round_down()));
380 const ZERO: Self = Self(Fp::ZERO);
381 const ONE: Self = Self(Fp::ONE);
382 const TWO: Self = Self(Fp::new(ark_ff::BigInt([2, 0, 0, 0])));
383 const TEN: Self = Self(Fp::new(ark_ff::BigInt([10, 0, 0, 0])));
384 const TWO_POW_64: Self = Self(Fp::new(ark_ff::BigInt([0, 1, 0, 0])));
385 const CHALLENGE_MASK: U256 = {
386 assert!(
387 T::MODULUS.0[3].leading_zeros() < 64,
388 "modulus expected to be larger than 1 << (64*3)"
389 );
390 U256::from_digits([
391 u64::MAX,
392 u64::MAX,
393 u64::MAX,
394 u64::MAX >> (T::MODULUS.0[3].leading_zeros() + 1),
395 ])
396 };
397 #[expect(clippy::cast_possible_truncation)]
398 const MAX_BITS: u8 = {
399 assert!(
400 T::MODULUS.0[3].leading_zeros() < 64,
401 "modulus expected to be larger than 1 << (64*3)"
402 );
403 255 - T::MODULUS.0[3].leading_zeros() as u8
404 };
405}
406
407impl<T> TryFrom<MontScalar<T>> for bool
408where
409 T: MontConfig<4>,
410 MontScalar<T>: Scalar,
411{
412 type Error = ScalarConversionError;
413 fn try_from(value: MontScalar<T>) -> Result<Self, Self::Error> {
414 let (sign, abs): (i128, [u64; 4]) = if value > <MontScalar<T>>::MAX_SIGNED {
415 (-1, (-value).into())
416 } else {
417 (1, value.into())
418 };
419 if abs[1] != 0 || abs[2] != 0 || abs[3] != 0 {
420 return Err(ScalarConversionError::Overflow {
421 error: format!("{value} is too large to fit in an i8"),
422 });
423 }
424 let val: i128 = sign * i128::from(abs[0]);
425 match val {
426 0 => Ok(false),
427 1 => Ok(true),
428 _ => Err(ScalarConversionError::Overflow {
429 error: format!("{value} is too large to fit in a bool"),
430 }),
431 }
432 }
433}
434
435impl<T> TryFrom<MontScalar<T>> for u8
436where
437 T: MontConfig<4>,
438 MontScalar<T>: Scalar,
439{
440 type Error = ScalarConversionError;
441
442 fn try_from(value: MontScalar<T>) -> Result<Self, Self::Error> {
443 if value < MontScalar::<T>::ZERO {
444 return Err(ScalarConversionError::Overflow {
445 error: format!("{value} is negative and cannot fit in a u8"),
446 });
447 }
448
449 let abs: [u64; 4] = value.into();
450
451 if abs[1] != 0 || abs[2] != 0 || abs[3] != 0 {
452 return Err(ScalarConversionError::Overflow {
453 error: format!("{value} is too large to fit in a u8"),
454 });
455 }
456
457 abs[0]
458 .try_into()
459 .map_err(|_| ScalarConversionError::Overflow {
460 error: format!("{value} is too large to fit in a u8"),
461 })
462 }
463}
464
465impl<T> TryFrom<MontScalar<T>> for i8
466where
467 T: MontConfig<4>,
468 MontScalar<T>: Scalar,
469{
470 type Error = ScalarConversionError;
471 fn try_from(value: MontScalar<T>) -> Result<Self, Self::Error> {
472 let (sign, abs): (i128, [u64; 4]) = if value > <MontScalar<T>>::MAX_SIGNED {
473 (-1, (-value).into())
474 } else {
475 (1, value.into())
476 };
477 if abs[1] != 0 || abs[2] != 0 || abs[3] != 0 {
478 return Err(ScalarConversionError::Overflow {
479 error: format!("{value} is too large to fit in an i8"),
480 });
481 }
482 let val: i128 = sign * i128::from(abs[0]);
483 val.try_into().map_err(|_| ScalarConversionError::Overflow {
484 error: format!("{value} is too large to fit in an i8"),
485 })
486 }
487}
488
489impl<T> TryFrom<MontScalar<T>> for i16
490where
491 T: MontConfig<4>,
492 MontScalar<T>: Scalar,
493{
494 type Error = ScalarConversionError;
495 fn try_from(value: MontScalar<T>) -> Result<Self, Self::Error> {
496 let (sign, abs): (i128, [u64; 4]) = if value > <MontScalar<T>>::MAX_SIGNED {
497 (-1, (-value).into())
498 } else {
499 (1, value.into())
500 };
501 if abs[1] != 0 || abs[2] != 0 || abs[3] != 0 {
502 return Err(ScalarConversionError::Overflow {
503 error: format!("{value} is too large to fit in an i16"),
504 });
505 }
506 let val: i128 = sign * i128::from(abs[0]);
507 val.try_into().map_err(|_| ScalarConversionError::Overflow {
508 error: format!("{value} is too large to fit in an i16"),
509 })
510 }
511}
512
513impl<T> TryFrom<MontScalar<T>> for i32
514where
515 T: MontConfig<4>,
516 MontScalar<T>: Scalar,
517{
518 type Error = ScalarConversionError;
519 fn try_from(value: MontScalar<T>) -> Result<Self, Self::Error> {
520 let (sign, abs): (i128, [u64; 4]) = if value > <MontScalar<T>>::MAX_SIGNED {
521 (-1, (-value).into())
522 } else {
523 (1, value.into())
524 };
525 if abs[1] != 0 || abs[2] != 0 || abs[3] != 0 {
526 return Err(ScalarConversionError::Overflow {
527 error: format!("{value} is too large to fit in an i32"),
528 });
529 }
530 let val: i128 = sign * i128::from(abs[0]);
531 val.try_into().map_err(|_| ScalarConversionError::Overflow {
532 error: format!("{value} is too large to fit in an i32"),
533 })
534 }
535}
536
537impl<T> TryFrom<MontScalar<T>> for i64
538where
539 T: MontConfig<4>,
540 MontScalar<T>: Scalar,
541{
542 type Error = ScalarConversionError;
543 fn try_from(value: MontScalar<T>) -> Result<Self, Self::Error> {
544 let (sign, abs): (i128, [u64; 4]) = if value > <MontScalar<T>>::MAX_SIGNED {
545 (-1, (-value).into())
546 } else {
547 (1, value.into())
548 };
549 if abs[1] != 0 || abs[2] != 0 || abs[3] != 0 {
550 return Err(ScalarConversionError::Overflow {
551 error: format!("{value} is too large to fit in an i64"),
552 });
553 }
554 let val: i128 = sign * i128::from(abs[0]);
555 val.try_into().map_err(|_| ScalarConversionError::Overflow {
556 error: format!("{value} is too large to fit in an i64"),
557 })
558 }
559}
560
561impl<T> TryFrom<MontScalar<T>> for i128
562where
563 T: MontConfig<4>,
564 MontScalar<T>: Scalar,
565{
566 type Error = ScalarConversionError;
567
568 #[expect(clippy::cast_possible_wrap)]
569 fn try_from(value: MontScalar<T>) -> Result<Self, Self::Error> {
570 let (sign, abs): (i128, [u64; 4]) = if value > <MontScalar<T>>::MAX_SIGNED {
571 (-1, (-value).into())
572 } else {
573 (1, value.into())
574 };
575 if abs[2] != 0 || abs[3] != 0 {
576 return Err(ScalarConversionError::Overflow {
577 error: format!("{value} is too large to fit in an i128"),
578 });
579 }
580 let val: u128 = (u128::from(abs[1]) << 64) | (u128::from(abs[0]));
581 match (sign, val) {
582 (1, v) if v <= i128::MAX as u128 => Ok(v as i128),
583 (-1, v) if v <= i128::MAX as u128 => Ok(-(v as i128)),
584 (-1, v) if v == i128::MAX as u128 + 1 => Ok(i128::MIN),
585 _ => Err(ScalarConversionError::Overflow {
586 error: format!("{value} is too large to fit in an i128"),
587 }),
588 }
589 }
590}
591
592impl<T> From<MontScalar<T>> for BigInt
593where
594 T: MontConfig<4>,
595 MontScalar<T>: Scalar,
596{
597 fn from(value: MontScalar<T>) -> Self {
598 let is_negative = value > <MontScalar<T>>::MAX_SIGNED;
600 let sign = if is_negative {
601 num_bigint::Sign::Minus
602 } else {
603 num_bigint::Sign::Plus
604 };
605 let value_abs: [u64; 4] = (if is_negative { -value } else { value }).into();
606 let bits: &[u8] = bytemuck::cast_slice(&value_abs);
607 BigInt::from_bytes_le(sign, bits)
608 }
609}