1use crate::base::scalar::{Scalar, ScalarConversionError, ScalarExt};
2use alloc::{
3 format,
4 string::{String, ToString},
5 vec::Vec,
6};
7use ark_ff::{AdditiveGroup, BigInteger, Field, Fp, Fp256, MontBackend, MontConfig, PrimeField};
8use ark_serialize::{CanonicalDeserialize, CanonicalSerialize};
9use bnum::types::U256;
10use bytemuck::TransparentWrapper;
11use core::{
12 cmp::Ordering,
13 fmt,
14 fmt::{Debug, Display, Formatter},
15 hash::{Hash, Hasher},
16 iter::{Product, Sum},
17 ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign},
18};
19use num_bigint::BigInt;
20use num_traits::Signed;
21use serde::{Deserialize, Deserializer, Serialize, Serializer};
/// A prime-field element: a newtype over arkworks' `Fp256<MontBackend<T, 4>>`,
/// parameterised by the Montgomery configuration `T`.
///
/// `#[repr(transparent)]` together with the `TransparentWrapper` derive gives the
/// wrapper the same layout as the inner field element.
#[derive(CanonicalSerialize, CanonicalDeserialize, TransparentWrapper)]
#[repr(transparent)]
pub struct MontScalar<T: MontConfig<4>>(pub Fp256<MontBackend<T, 4>>);
28
29impl<T: MontConfig<4>> Add for MontScalar<T> {
34 type Output = Self;
35 fn add(self, rhs: Self) -> Self::Output {
36 Self(self.0 + rhs.0)
37 }
38}
39impl<T: MontConfig<4>> Sub for MontScalar<T> {
40 type Output = Self;
41 fn sub(self, rhs: Self) -> Self::Output {
42 Self(self.0 - rhs.0)
43 }
44}
45impl<T: MontConfig<4>> Mul for MontScalar<T> {
46 type Output = Self;
47 fn mul(self, rhs: Self) -> Self::Output {
48 Self(self.0 * rhs.0)
49 }
50}
51impl<T: MontConfig<4>> AddAssign for MontScalar<T> {
52 fn add_assign(&mut self, rhs: Self) {
53 self.0 += rhs.0;
54 }
55}
56impl<T: MontConfig<4>> SubAssign for MontScalar<T> {
57 fn sub_assign(&mut self, rhs: Self) {
58 self.0 -= rhs.0;
59 }
60}
61impl<T: MontConfig<4>> MulAssign for MontScalar<T> {
62 fn mul_assign(&mut self, rhs: Self) {
63 self.0 *= rhs.0;
64 }
65}
66impl<T: MontConfig<4>> Neg for MontScalar<T> {
67 type Output = Self;
68 fn neg(self) -> Self::Output {
69 Self(-self.0)
70 }
71}
72impl<T: MontConfig<4>> Sum for MontScalar<T> {
73 fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
74 Self(iter.map(|x| x.0).sum())
75 }
76}
77impl<T: MontConfig<4>> Product for MontScalar<T> {
78 fn product<I: Iterator<Item = Self>>(iter: I) -> Self {
79 Self(iter.map(|x| x.0).product())
80 }
81}
// `Clone` and `Copy` are written by hand with only a `MontConfig<4>` bound on `T`.
// NOTE(review): presumably a derive would also demand `T: Clone`/`T: Copy` — confirm.
impl<T: MontConfig<4>> Clone for MontScalar<T> {
    /// Bitwise copy; relies on the `Copy` impl below.
    fn clone(&self) -> Self {
        *self
    }
}
impl<T: MontConfig<4>> Copy for MontScalar<T> {}
88impl<T: MontConfig<4>> PartialOrd for MontScalar<T> {
89 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
90 Some(self.cmp(other))
91 }
92}
93impl<T: MontConfig<4>> PartialEq for MontScalar<T> {
94 fn eq(&self, other: &Self) -> bool {
95 self.0 == other.0
96 }
97}
98impl<T: MontConfig<4>> Default for MontScalar<T> {
99 fn default() -> Self {
100 Self(Fp::default())
101 }
102}
103impl<T: MontConfig<4>> Debug for MontScalar<T> {
104 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
105 f.debug_tuple("MontScalar").field(&self.0).finish()
106 }
107}
// Marker impl: equality (see `PartialEq`) compares the wrapped field elements.
impl<T: MontConfig<4>> Eq for MontScalar<T> {}
109impl<T: MontConfig<4>> Hash for MontScalar<T> {
110 fn hash<H: Hasher>(&self, state: &mut H) {
111 self.0.hash(state);
112 }
113}
114impl<T: MontConfig<4>> Ord for MontScalar<T> {
115 fn cmp(&self, other: &Self) -> Ordering {
116 self.0.cmp(&other.0)
117 }
118}
// Generates `From<$source> for MontScalar<T>` by delegating to the conversion
// that the wrapped field element already offers for `$source`.
macro_rules! impl_from_for_mont_scalar_for_type_supported_by_from {
    ($source:ty) => {
        impl<T: MontConfig<4>> From<$source> for MontScalar<T> {
            fn from(value: $source) -> Self {
                let inner: Fp256<MontBackend<T, 4>> = value.into();
                Self(inner)
            }
        }
    };
}
133
134impl<T: MontConfig<4>> From<&[u8]> for MontScalar<T> {
136 fn from(x: &[u8]) -> Self {
137 ScalarExt::from_byte_slice_via_hash(x)
138 }
139}
140
// Generates `From<$text> for MontScalar<T>` for string-like types: the value's
// bytes are forwarded to the `From<&[u8]>` conversion.
macro_rules! impl_from_for_mont_scalar_for_string {
    ($text:ty) => {
        impl<T: MontConfig<4>> From<$text> for MontScalar<T> {
            fn from(value: $text) -> Self {
                Self::from(value.as_bytes())
            }
        }
    };
}
151
// Boolean and integer sources: converted by the wrapped field element itself.
impl_from_for_mont_scalar_for_type_supported_by_from!(bool);
impl_from_for_mont_scalar_for_type_supported_by_from!(u8);
impl_from_for_mont_scalar_for_type_supported_by_from!(u16);
impl_from_for_mont_scalar_for_type_supported_by_from!(u32);
impl_from_for_mont_scalar_for_type_supported_by_from!(u64);
impl_from_for_mont_scalar_for_type_supported_by_from!(u128);
impl_from_for_mont_scalar_for_type_supported_by_from!(i8);
impl_from_for_mont_scalar_for_type_supported_by_from!(i16);
impl_from_for_mont_scalar_for_type_supported_by_from!(i32);
impl_from_for_mont_scalar_for_type_supported_by_from!(i64);
impl_from_for_mont_scalar_for_type_supported_by_from!(i128);
// Text sources: their bytes are routed through the `From<&[u8]>` conversion.
impl_from_for_mont_scalar_for_string!(&str);
impl_from_for_mont_scalar_for_string!(String);
165
166impl<F: MontConfig<4>, T> From<&T> for MontScalar<F>
167where
168 T: Into<MontScalar<F>> + Clone,
169{
170 fn from(x: &T) -> Self {
171 x.clone().into()
172 }
173}
174
175impl<T: MontConfig<4>> MontScalar<T> {
176 #[cfg(test)]
178 #[must_use]
179 pub fn new(value: Fp256<MontBackend<T, 4>>) -> Self {
180 Self(value)
181 }
182
183 #[must_use]
189 pub fn from_bigint(vals: [u64; 4]) -> Self {
190 Self(Fp::from_bigint(ark_ff::BigInt(vals)).unwrap())
191 }
192 #[must_use]
194 pub fn from_le_bytes_mod_order(bytes: &[u8]) -> Self {
195 Self(Fp::from_le_bytes_mod_order(bytes))
196 }
197 #[must_use]
199 pub fn to_bytes_le(&self) -> Vec<u8> {
200 self.0.into_bigint().to_bytes_le()
201 }
202 #[cfg(test)]
204 pub fn wrap_slice(slice: &[Fp256<MontBackend<T, 4>>]) -> Vec<Self> {
205 slice.iter().copied().map(Self).collect()
206 }
207 #[cfg(test)]
209 #[must_use]
210 pub fn unwrap_slice(slice: &[Self]) -> Vec<Fp256<MontBackend<T, 4>>> {
211 slice.iter().map(|x| x.0).collect()
212 }
213}
214
215impl<T> TryFrom<BigInt> for MontScalar<T>
216where
217 T: MontConfig<4>,
218 MontScalar<T>: Scalar,
219{
220 type Error = ScalarConversionError;
221
222 fn try_from(value: BigInt) -> Result<Self, Self::Error> {
223 if value.abs() > BigInt::from(<MontScalar<T>>::MAX_SIGNED) {
224 return Err(ScalarConversionError::Overflow {
225 error: "BigInt too large for Scalar".to_string(),
226 });
227 }
228
229 let (sign, digits) = value.to_u64_digits();
230 assert!(digits.len() <= 4); let mut limbs = [0u64; 4];
232 limbs[..digits.len()].copy_from_slice(&digits);
233 let result = Self::from(limbs);
234 Ok(match sign {
235 num_bigint::Sign::Minus => -result,
236 num_bigint::Sign::Plus | num_bigint::Sign::NoSign => result,
237 })
238 }
239}
240impl<T: MontConfig<4>> From<[u64; 4]> for MontScalar<T> {
241 fn from(value: [u64; 4]) -> Self {
242 Self(Fp::new(ark_ff::BigInt(value)))
243 }
244}
245
246impl<T: MontConfig<4>> ark_std::UniformRand for MontScalar<T> {
247 fn rand<R: ark_std::rand::Rng + ?Sized>(rng: &mut R) -> Self {
248 Self(ark_ff::UniformRand::rand(rng))
249 }
250}
251
252impl<'a, T: MontConfig<4>> Sum<&'a Self> for MontScalar<T> {
253 fn sum<I: Iterator<Item = &'a Self>>(iter: I) -> Self {
254 Self(iter.map(|x| x.0).sum())
255 }
256}
257impl<T: MontConfig<4>> num_traits::One for MontScalar<T> {
258 fn one() -> Self {
259 Self(Fp::one())
260 }
261}
262impl<T: MontConfig<4>> num_traits::Zero for MontScalar<T> {
263 fn zero() -> Self {
264 Self(Fp::zero())
265 }
266 fn is_zero(&self) -> bool {
267 self.0.is_zero()
268 }
269}
270impl<T: MontConfig<4>> num_traits::Inv for MontScalar<T> {
271 type Output = Option<Self>;
272 fn inv(self) -> Option<Self> {
273 self.0.inverse().map(Self)
274 }
275}
276impl<T: MontConfig<4>> Serialize for MontScalar<T> {
277 fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
278 let mut limbs: [u64; 4] = self.into();
279 limbs.reverse();
280 limbs.serialize(serializer)
281 }
282}
283impl<'de, T: MontConfig<4>> Deserialize<'de> for MontScalar<T> {
284 fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
285 let mut limbs: [u64; 4] = Deserialize::deserialize(deserializer)?;
286 limbs.reverse();
287 Ok(limbs.into())
288 }
289}
290
291impl<T: MontConfig<4>> core::ops::Neg for &MontScalar<T> {
292 type Output = MontScalar<T>;
293 fn neg(self) -> Self::Output {
294 MontScalar(-self.0)
295 }
296}
297
298impl<T: MontConfig<4>> From<MontScalar<T>> for [u64; 4] {
299 fn from(value: MontScalar<T>) -> Self {
300 (&value).into()
301 }
302}
303
304impl<T: MontConfig<4>> From<&MontScalar<T>> for [u64; 4] {
305 fn from(value: &MontScalar<T>) -> Self {
306 value.0.into_bigint().0
307 }
308}
309
310impl<T: MontConfig<4>> Display for MontScalar<T> {
311 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
312 let sign = if f.sign_plus() {
313 let n = -self;
314 if self > &n {
315 Some(Some(n))
316 } else {
317 Some(None)
318 }
319 } else {
320 None
321 };
322 match (f.alternate(), sign) {
323 (false, None) => {
324 let data = self.0.into_bigint().0;
325 write!(
326 f,
327 "{:016X}{:016X}{:016X}{:016X}",
328 data[3], data[2], data[1], data[0],
329 )
330 }
331 (false, Some(None)) => {
332 let data = self.0.into_bigint().0;
333 write!(
334 f,
335 "+{:016X}{:016X}{:016X}{:016X}",
336 data[3], data[2], data[1], data[0],
337 )
338 }
339 (false, Some(Some(n))) => {
340 let data = n.0.into_bigint().0;
341 write!(
342 f,
343 "-{:016X}{:016X}{:016X}{:016X}",
344 data[3], data[2], data[1], data[0],
345 )
346 }
347 (true, None) => {
348 let data = self.to_bytes_le();
349 write!(
350 f,
351 "0x{:02X}{:02X}...{:02X}{:02X}",
352 data[31], data[30], data[1], data[0],
353 )
354 }
355 (true, Some(None)) => {
356 let data = self.to_bytes_le();
357 write!(
358 f,
359 "+0x{:02X}{:02X}...{:02X}{:02X}",
360 data[31], data[30], data[1], data[0],
361 )
362 }
363 (true, Some(Some(n))) => {
364 let data = n.to_bytes_le();
365 write!(
366 f,
367 "-0x{:02X}{:02X}...{:02X}{:02X}",
368 data[31], data[30], data[1], data[0],
369 )
370 }
371 }
372 }
373}
374
375impl<T> Scalar for MontScalar<T>
376where
377 T: MontConfig<4>,
378{
379 const MAX_SIGNED: Self = Self(Fp::new(T::MODULUS.divide_by_2_round_down()));
380 const ZERO: Self = Self(Fp::ZERO);
381 const ONE: Self = Self(Fp::ONE);
382 const TWO: Self = Self(Fp::new(ark_ff::BigInt([2, 0, 0, 0])));
383 const TEN: Self = Self(Fp::new(ark_ff::BigInt([10, 0, 0, 0])));
384 const TWO_POW_64: Self = Self(Fp::new(ark_ff::BigInt([0, 1, 0, 0])));
385 const CHALLENGE_MASK: U256 = {
386 assert!(
387 T::MODULUS.0[3].leading_zeros() < 64,
388 "modulus expected to be larger than 1 << (64*3)"
389 );
390 U256::from_digits([
391 u64::MAX,
392 u64::MAX,
393 u64::MAX,
394 u64::MAX >> (T::MODULUS.0[3].leading_zeros() + 1),
395 ])
396 };
397 #[expect(clippy::cast_possible_truncation)]
398 const MAX_BITS: u8 = {
399 assert!(
400 T::MODULUS.0[3].leading_zeros() < 64,
401 "modulus expected to be larger than 1 << (64*3)"
402 );
403 255 - T::MODULUS.0[3].leading_zeros() as u8
404 };
405 const MAX_SIGNED_U256: U256 = U256::from_digits(T::MODULUS.divide_by_2_round_down().0);
406}
407
408impl<T> TryFrom<MontScalar<T>> for bool
409where
410 T: MontConfig<4>,
411 MontScalar<T>: Scalar,
412{
413 type Error = ScalarConversionError;
414 fn try_from(value: MontScalar<T>) -> Result<Self, Self::Error> {
415 let (sign, abs): (i128, [u64; 4]) = if value > <MontScalar<T>>::MAX_SIGNED {
416 (-1, (-value).into())
417 } else {
418 (1, value.into())
419 };
420 if abs[1] != 0 || abs[2] != 0 || abs[3] != 0 {
421 return Err(ScalarConversionError::Overflow {
422 error: format!("{value} is too large to fit in an i8"),
423 });
424 }
425 let val: i128 = sign * i128::from(abs[0]);
426 match val {
427 0 => Ok(false),
428 1 => Ok(true),
429 _ => Err(ScalarConversionError::Overflow {
430 error: format!("{value} is too large to fit in a bool"),
431 }),
432 }
433 }
434}
435
436impl<T> TryFrom<MontScalar<T>> for u8
437where
438 T: MontConfig<4>,
439 MontScalar<T>: Scalar,
440{
441 type Error = ScalarConversionError;
442
443 fn try_from(value: MontScalar<T>) -> Result<Self, Self::Error> {
444 if value < MontScalar::<T>::ZERO {
445 return Err(ScalarConversionError::Overflow {
446 error: format!("{value} is negative and cannot fit in a u8"),
447 });
448 }
449
450 let abs: [u64; 4] = value.into();
451
452 if abs[1] != 0 || abs[2] != 0 || abs[3] != 0 {
453 return Err(ScalarConversionError::Overflow {
454 error: format!("{value} is too large to fit in a u8"),
455 });
456 }
457
458 abs[0]
459 .try_into()
460 .map_err(|_| ScalarConversionError::Overflow {
461 error: format!("{value} is too large to fit in a u8"),
462 })
463 }
464}
465
466impl<T> TryFrom<MontScalar<T>> for i8
467where
468 T: MontConfig<4>,
469 MontScalar<T>: Scalar,
470{
471 type Error = ScalarConversionError;
472 fn try_from(value: MontScalar<T>) -> Result<Self, Self::Error> {
473 let (sign, abs): (i128, [u64; 4]) = if value > <MontScalar<T>>::MAX_SIGNED {
474 (-1, (-value).into())
475 } else {
476 (1, value.into())
477 };
478 if abs[1] != 0 || abs[2] != 0 || abs[3] != 0 {
479 return Err(ScalarConversionError::Overflow {
480 error: format!("{value} is too large to fit in an i8"),
481 });
482 }
483 let val: i128 = sign * i128::from(abs[0]);
484 val.try_into().map_err(|_| ScalarConversionError::Overflow {
485 error: format!("{value} is too large to fit in an i8"),
486 })
487 }
488}
489
490impl<T> TryFrom<MontScalar<T>> for i16
491where
492 T: MontConfig<4>,
493 MontScalar<T>: Scalar,
494{
495 type Error = ScalarConversionError;
496 fn try_from(value: MontScalar<T>) -> Result<Self, Self::Error> {
497 let (sign, abs): (i128, [u64; 4]) = if value > <MontScalar<T>>::MAX_SIGNED {
498 (-1, (-value).into())
499 } else {
500 (1, value.into())
501 };
502 if abs[1] != 0 || abs[2] != 0 || abs[3] != 0 {
503 return Err(ScalarConversionError::Overflow {
504 error: format!("{value} is too large to fit in an i16"),
505 });
506 }
507 let val: i128 = sign * i128::from(abs[0]);
508 val.try_into().map_err(|_| ScalarConversionError::Overflow {
509 error: format!("{value} is too large to fit in an i16"),
510 })
511 }
512}
513
514impl<T> TryFrom<MontScalar<T>> for i32
515where
516 T: MontConfig<4>,
517 MontScalar<T>: Scalar,
518{
519 type Error = ScalarConversionError;
520 fn try_from(value: MontScalar<T>) -> Result<Self, Self::Error> {
521 let (sign, abs): (i128, [u64; 4]) = if value > <MontScalar<T>>::MAX_SIGNED {
522 (-1, (-value).into())
523 } else {
524 (1, value.into())
525 };
526 if abs[1] != 0 || abs[2] != 0 || abs[3] != 0 {
527 return Err(ScalarConversionError::Overflow {
528 error: format!("{value} is too large to fit in an i32"),
529 });
530 }
531 let val: i128 = sign * i128::from(abs[0]);
532 val.try_into().map_err(|_| ScalarConversionError::Overflow {
533 error: format!("{value} is too large to fit in an i32"),
534 })
535 }
536}
537
538impl<T> TryFrom<MontScalar<T>> for i64
539where
540 T: MontConfig<4>,
541 MontScalar<T>: Scalar,
542{
543 type Error = ScalarConversionError;
544 fn try_from(value: MontScalar<T>) -> Result<Self, Self::Error> {
545 let (sign, abs): (i128, [u64; 4]) = if value > <MontScalar<T>>::MAX_SIGNED {
546 (-1, (-value).into())
547 } else {
548 (1, value.into())
549 };
550 if abs[1] != 0 || abs[2] != 0 || abs[3] != 0 {
551 return Err(ScalarConversionError::Overflow {
552 error: format!("{value} is too large to fit in an i64"),
553 });
554 }
555 let val: i128 = sign * i128::from(abs[0]);
556 val.try_into().map_err(|_| ScalarConversionError::Overflow {
557 error: format!("{value} is too large to fit in an i64"),
558 })
559 }
560}
561
562impl<T> TryFrom<MontScalar<T>> for i128
563where
564 T: MontConfig<4>,
565 MontScalar<T>: Scalar,
566{
567 type Error = ScalarConversionError;
568
569 #[expect(clippy::cast_possible_wrap)]
570 fn try_from(value: MontScalar<T>) -> Result<Self, Self::Error> {
571 let (sign, abs): (i128, [u64; 4]) = if value > <MontScalar<T>>::MAX_SIGNED {
572 (-1, (-value).into())
573 } else {
574 (1, value.into())
575 };
576 if abs[2] != 0 || abs[3] != 0 {
577 return Err(ScalarConversionError::Overflow {
578 error: format!("{value} is too large to fit in an i128"),
579 });
580 }
581 let val: u128 = (u128::from(abs[1]) << 64) | (u128::from(abs[0]));
582 match (sign, val) {
583 (1, v) if v <= i128::MAX as u128 => Ok(v as i128),
584 (-1, v) if v <= i128::MAX as u128 => Ok(-(v as i128)),
585 (-1, v) if v == i128::MAX as u128 + 1 => Ok(i128::MIN),
586 _ => Err(ScalarConversionError::Overflow {
587 error: format!("{value} is too large to fit in an i128"),
588 }),
589 }
590 }
591}
592
593impl<T> From<MontScalar<T>> for BigInt
594where
595 T: MontConfig<4>,
596 MontScalar<T>: Scalar,
597{
598 fn from(value: MontScalar<T>) -> Self {
599 let is_negative = value > <MontScalar<T>>::MAX_SIGNED;
601 let sign = if is_negative {
602 num_bigint::Sign::Minus
603 } else {
604 num_bigint::Sign::Plus
605 };
606 let value_abs: [u64; 4] = (if is_negative { -value } else { value }).into();
607 let bits: &[u8] = bytemuck::cast_slice(&value_abs);
608 BigInt::from_bytes_le(sign, bits)
609 }
610}