1use crate::base::scalar::{Scalar, ScalarConversionError};
2use alloc::{
3 format,
4 string::{String, ToString},
5 vec::Vec,
6};
7use ark_ff::{AdditiveGroup, BigInteger, Field, Fp, Fp256, MontBackend, MontConfig, PrimeField};
8use ark_serialize::{CanonicalDeserialize, CanonicalSerialize};
9use bnum::types::U256;
10use bytemuck::TransparentWrapper;
11use core::{
12 cmp::Ordering,
13 fmt,
14 fmt::{Debug, Display, Formatter},
15 hash::{Hash, Hasher},
16 iter::{Product, Sum},
17 ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign},
18};
19use num_bigint::BigInt;
20use num_traits::{Signed, Zero};
21use serde::{Deserialize, Deserializer, Serialize, Serializer};
/// A scalar that wraps an arkworks 256-bit prime-field element, `Fp256<MontBackend<T, 4>>`.
///
/// `T` is the `MontConfig<4>` (four 64-bit limbs) describing the field. The wrapper is
/// `#[repr(transparent)]` and derives `TransparentWrapper`, so values (and slices) of the inner
/// field type can be reinterpreted as `MontScalar<T>` without copying.
#[derive(CanonicalSerialize, CanonicalDeserialize, TransparentWrapper)]
#[repr(transparent)]
pub struct MontScalar<T: MontConfig<4>>(pub Fp256<MontBackend<T, 4>>);
28
29impl<T: MontConfig<4>> Add for MontScalar<T> {
34 type Output = Self;
35 fn add(self, rhs: Self) -> Self::Output {
36 Self(self.0 + rhs.0)
37 }
38}
39impl<T: MontConfig<4>> Sub for MontScalar<T> {
40 type Output = Self;
41 fn sub(self, rhs: Self) -> Self::Output {
42 Self(self.0 - rhs.0)
43 }
44}
45impl<T: MontConfig<4>> Mul for MontScalar<T> {
46 type Output = Self;
47 fn mul(self, rhs: Self) -> Self::Output {
48 Self(self.0 * rhs.0)
49 }
50}
51impl<T: MontConfig<4>> AddAssign for MontScalar<T> {
52 fn add_assign(&mut self, rhs: Self) {
53 self.0 += rhs.0;
54 }
55}
56impl<T: MontConfig<4>> SubAssign for MontScalar<T> {
57 fn sub_assign(&mut self, rhs: Self) {
58 self.0 -= rhs.0;
59 }
60}
61impl<T: MontConfig<4>> MulAssign for MontScalar<T> {
62 fn mul_assign(&mut self, rhs: Self) {
63 self.0 *= rhs.0;
64 }
65}
66impl<T: MontConfig<4>> Neg for MontScalar<T> {
67 type Output = Self;
68 fn neg(self) -> Self::Output {
69 Self(-self.0)
70 }
71}
72impl<T: MontConfig<4>> Sum for MontScalar<T> {
73 fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
74 Self(iter.map(|x| x.0).sum())
75 }
76}
77impl<T: MontConfig<4>> Product for MontScalar<T> {
78 fn product<I: Iterator<Item = Self>>(iter: I) -> Self {
79 Self(iter.map(|x| x.0).product())
80 }
81}
82impl<T: MontConfig<4>> Clone for MontScalar<T> {
83 fn clone(&self) -> Self {
84 *self
85 }
86}
87impl<T: MontConfig<4>> Copy for MontScalar<T> {}
88impl<T: MontConfig<4>> PartialOrd for MontScalar<T> {
89 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
90 Some(self.cmp(other))
91 }
92}
93impl<T: MontConfig<4>> PartialEq for MontScalar<T> {
94 fn eq(&self, other: &Self) -> bool {
95 self.0 == other.0
96 }
97}
98impl<T: MontConfig<4>> Default for MontScalar<T> {
99 fn default() -> Self {
100 Self(Fp::default())
101 }
102}
103impl<T: MontConfig<4>> Debug for MontScalar<T> {
104 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
105 f.debug_tuple("MontScalar").field(&self.0).finish()
106 }
107}
108impl<T: MontConfig<4>> Eq for MontScalar<T> {}
109impl<T: MontConfig<4>> Hash for MontScalar<T> {
110 fn hash<H: Hasher>(&self, state: &mut H) {
111 self.0.hash(state);
112 }
113}
114impl<T: MontConfig<4>> Ord for MontScalar<T> {
115 fn cmp(&self, other: &Self) -> Ordering {
116 self.0.cmp(&other.0)
117 }
118}
/// Implements `From<$source>` for `MontScalar<T>` by delegating to the wrapped field type's
/// own `From<$source>` conversion.
macro_rules! impl_from_for_mont_scalar_for_type_supported_by_from {
    ($source:ty) => {
        impl<T: MontConfig<4>> From<$source> for MontScalar<T> {
            fn from(x: $source) -> Self {
                Self(From::from(x))
            }
        }
    };
}
133
134impl<T: MontConfig<4>> From<&[u8]> for MontScalar<T> {
136 fn from(x: &[u8]) -> Self {
137 if x.is_empty() {
138 return Self::zero();
139 }
140
141 let hash = blake3::hash(x);
142 let mut bytes: [u8; 32] = hash.into();
143 bytes[31] &= 0b0000_1111_u8;
144
145 Self::from_le_bytes_mod_order(&bytes)
146 }
147}
148
/// Implements `From<$text>` for `MontScalar<T>` for string-like types by converting their UTF-8
/// bytes through the `From<&[u8]>` implementation.
macro_rules! impl_from_for_mont_scalar_for_string {
    ($text:ty) => {
        impl<T: MontConfig<4>> From<$text> for MontScalar<T> {
            fn from(x: $text) -> Self {
                Self::from(x.as_bytes())
            }
        }
    };
}
159
160impl_from_for_mont_scalar_for_type_supported_by_from!(bool);
161impl_from_for_mont_scalar_for_type_supported_by_from!(u8);
162impl_from_for_mont_scalar_for_type_supported_by_from!(u16);
163impl_from_for_mont_scalar_for_type_supported_by_from!(u32);
164impl_from_for_mont_scalar_for_type_supported_by_from!(u64);
165impl_from_for_mont_scalar_for_type_supported_by_from!(u128);
166impl_from_for_mont_scalar_for_type_supported_by_from!(i8);
167impl_from_for_mont_scalar_for_type_supported_by_from!(i16);
168impl_from_for_mont_scalar_for_type_supported_by_from!(i32);
169impl_from_for_mont_scalar_for_type_supported_by_from!(i64);
170impl_from_for_mont_scalar_for_type_supported_by_from!(i128);
171impl_from_for_mont_scalar_for_string!(&str);
172impl_from_for_mont_scalar_for_string!(String);
173
174impl<F: MontConfig<4>, T> From<&T> for MontScalar<F>
175where
176 T: Into<MontScalar<F>> + Clone,
177{
178 fn from(x: &T) -> Self {
179 x.clone().into()
180 }
181}
182
183impl<T: MontConfig<4>> MontScalar<T> {
184 #[cfg(test)]
186 #[must_use]
187 pub fn new(value: Fp256<MontBackend<T, 4>>) -> Self {
188 Self(value)
189 }
190
191 #[must_use]
197 pub fn from_bigint(vals: [u64; 4]) -> Self {
198 Self(Fp::from_bigint(ark_ff::BigInt(vals)).unwrap())
199 }
200 #[must_use]
202 pub fn from_le_bytes_mod_order(bytes: &[u8]) -> Self {
203 Self(Fp::from_le_bytes_mod_order(bytes))
204 }
205 #[must_use]
207 pub fn to_bytes_le(&self) -> Vec<u8> {
208 self.0.into_bigint().to_bytes_le()
209 }
210 #[cfg(test)]
212 pub fn wrap_slice(slice: &[Fp256<MontBackend<T, 4>>]) -> Vec<Self> {
213 slice.iter().copied().map(Self).collect()
214 }
215 #[cfg(test)]
217 #[must_use]
218 pub fn unwrap_slice(slice: &[Self]) -> Vec<Fp256<MontBackend<T, 4>>> {
219 slice.iter().map(|x| x.0).collect()
220 }
221}
222
223impl<T> TryFrom<BigInt> for MontScalar<T>
224where
225 T: MontConfig<4>,
226 MontScalar<T>: Scalar,
227{
228 type Error = ScalarConversionError;
229
230 fn try_from(value: BigInt) -> Result<Self, Self::Error> {
231 if value.abs() > BigInt::from(<MontScalar<T>>::MAX_SIGNED) {
232 return Err(ScalarConversionError::Overflow {
233 error: "BigInt too large for Scalar".to_string(),
234 });
235 }
236
237 let (sign, digits) = value.to_u64_digits();
238 assert!(digits.len() <= 4); let mut limbs = [0u64; 4];
240 limbs[..digits.len()].copy_from_slice(&digits);
241 let result = Self::from(limbs);
242 Ok(match sign {
243 num_bigint::Sign::Minus => -result,
244 num_bigint::Sign::Plus | num_bigint::Sign::NoSign => result,
245 })
246 }
247}
248impl<T: MontConfig<4>> From<[u64; 4]> for MontScalar<T> {
249 fn from(value: [u64; 4]) -> Self {
250 Self(Fp::new(ark_ff::BigInt(value)))
251 }
252}
253
254impl<T: MontConfig<4>> ark_std::UniformRand for MontScalar<T> {
255 fn rand<R: ark_std::rand::Rng + ?Sized>(rng: &mut R) -> Self {
256 Self(ark_ff::UniformRand::rand(rng))
257 }
258}
259
260impl<'a, T: MontConfig<4>> Sum<&'a Self> for MontScalar<T> {
261 fn sum<I: Iterator<Item = &'a Self>>(iter: I) -> Self {
262 Self(iter.map(|x| x.0).sum())
263 }
264}
265impl<T: MontConfig<4>> num_traits::One for MontScalar<T> {
266 fn one() -> Self {
267 Self(Fp::one())
268 }
269}
270impl<T: MontConfig<4>> num_traits::Zero for MontScalar<T> {
271 fn zero() -> Self {
272 Self(Fp::zero())
273 }
274 fn is_zero(&self) -> bool {
275 self.0.is_zero()
276 }
277}
278impl<T: MontConfig<4>> num_traits::Inv for MontScalar<T> {
279 type Output = Option<Self>;
280 fn inv(self) -> Option<Self> {
281 self.0.inverse().map(Self)
282 }
283}
284impl<T: MontConfig<4>> Serialize for MontScalar<T> {
285 fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
286 let mut limbs: [u64; 4] = self.into();
287 limbs.reverse();
288 limbs.serialize(serializer)
289 }
290}
291impl<'de, T: MontConfig<4>> Deserialize<'de> for MontScalar<T> {
292 fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
293 let mut limbs: [u64; 4] = Deserialize::deserialize(deserializer)?;
294 limbs.reverse();
295 Ok(limbs.into())
296 }
297}
298
299impl<T: MontConfig<4>> core::ops::Neg for &MontScalar<T> {
300 type Output = MontScalar<T>;
301 fn neg(self) -> Self::Output {
302 MontScalar(-self.0)
303 }
304}
305
306impl<T: MontConfig<4>> From<MontScalar<T>> for [u64; 4] {
307 fn from(value: MontScalar<T>) -> Self {
308 (&value).into()
309 }
310}
311
312impl<T: MontConfig<4>> From<&MontScalar<T>> for [u64; 4] {
313 fn from(value: &MontScalar<T>) -> Self {
314 value.0.into_bigint().0
315 }
316}
317
318impl<T: MontConfig<4>> Display for MontScalar<T> {
319 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
320 let sign = if f.sign_plus() {
321 let n = -self;
322 if self > &n {
323 Some(Some(n))
324 } else {
325 Some(None)
326 }
327 } else {
328 None
329 };
330 match (f.alternate(), sign) {
331 (false, None) => {
332 let data = self.0.into_bigint().0;
333 write!(
334 f,
335 "{:016X}{:016X}{:016X}{:016X}",
336 data[3], data[2], data[1], data[0],
337 )
338 }
339 (false, Some(None)) => {
340 let data = self.0.into_bigint().0;
341 write!(
342 f,
343 "+{:016X}{:016X}{:016X}{:016X}",
344 data[3], data[2], data[1], data[0],
345 )
346 }
347 (false, Some(Some(n))) => {
348 let data = n.0.into_bigint().0;
349 write!(
350 f,
351 "-{:016X}{:016X}{:016X}{:016X}",
352 data[3], data[2], data[1], data[0],
353 )
354 }
355 (true, None) => {
356 let data = self.to_bytes_le();
357 write!(
358 f,
359 "0x{:02X}{:02X}...{:02X}{:02X}",
360 data[31], data[30], data[1], data[0],
361 )
362 }
363 (true, Some(None)) => {
364 let data = self.to_bytes_le();
365 write!(
366 f,
367 "+0x{:02X}{:02X}...{:02X}{:02X}",
368 data[31], data[30], data[1], data[0],
369 )
370 }
371 (true, Some(Some(n))) => {
372 let data = n.to_bytes_le();
373 write!(
374 f,
375 "-0x{:02X}{:02X}...{:02X}{:02X}",
376 data[31], data[30], data[1], data[0],
377 )
378 }
379 }
380 }
381}
382
/// `Scalar` implementation shared by every 4-limb Montgomery field configuration `T`.
impl<T> Scalar for MontScalar<T>
where
    T: MontConfig<4>,
{
    // The modulus divided by two, rounded down. The signed conversions in this file treat any
    // element greater than this as negative.
    const MAX_SIGNED: Self = Self(Fp::new(T::MODULUS.divide_by_2_round_down()));
    const ZERO: Self = Self(Fp::ZERO);
    const ONE: Self = Self(Fp::ONE);
    // Constants built from little-endian limbs: [2, 0, 0, 0] is 2, [10, 0, 0, 0] is 10.
    const TWO: Self = Self(Fp::new(ark_ff::BigInt([2, 0, 0, 0])));
    const TEN: Self = Self(Fp::new(ark_ff::BigInt([10, 0, 0, 0])));
    // Limb 1 set to 1, i.e. the value 2^64.
    const TWO_POW_64: Self = Self(Fp::new(ark_ff::BigInt([0, 1, 0, 0])));
    // Bit mask with the low `MAX_BITS` bits set. The three low limbs are all ones; the top limb
    // keeps `63 - leading_zeros` bits, one fewer than the modulus's top limb occupies.
    // NOTE(review): the assert only rules out a zero top limb. A top limb of exactly 1
    // (leading_zeros == 63) would make the shift below 64 bits, which fails const evaluation.
    const CHALLENGE_MASK: U256 = {
        assert!(
            T::MODULUS.0[3].leading_zeros() < 64,
            "modulus expected to be larger than 1 << (64*3)"
        );
        U256::from_digits([
            u64::MAX,
            u64::MAX,
            u64::MAX,
            u64::MAX >> (T::MODULUS.0[3].leading_zeros() + 1),
        ])
    };
    // One less than the modulus's bit length (256 - leading_zeros), so every value with at most
    // `MAX_BITS` bits is below the modulus. The cast cannot truncate: the assert guarantees
    // `leading_zeros() < 64`.
    #[expect(clippy::cast_possible_truncation)]
    const MAX_BITS: u8 = {
        assert!(
            T::MODULUS.0[3].leading_zeros() < 64,
            "modulus expected to be larger than 1 << (64*3)"
        );
        255 - T::MODULUS.0[3].leading_zeros() as u8
    };
}
414
415impl<T> TryFrom<MontScalar<T>> for bool
416where
417 T: MontConfig<4>,
418 MontScalar<T>: Scalar,
419{
420 type Error = ScalarConversionError;
421 fn try_from(value: MontScalar<T>) -> Result<Self, Self::Error> {
422 let (sign, abs): (i128, [u64; 4]) = if value > <MontScalar<T>>::MAX_SIGNED {
423 (-1, (-value).into())
424 } else {
425 (1, value.into())
426 };
427 if abs[1] != 0 || abs[2] != 0 || abs[3] != 0 {
428 return Err(ScalarConversionError::Overflow {
429 error: format!("{value} is too large to fit in an i8"),
430 });
431 }
432 let val: i128 = sign * i128::from(abs[0]);
433 match val {
434 0 => Ok(false),
435 1 => Ok(true),
436 _ => Err(ScalarConversionError::Overflow {
437 error: format!("{value} is too large to fit in a bool"),
438 }),
439 }
440 }
441}
442
443impl<T> TryFrom<MontScalar<T>> for u8
444where
445 T: MontConfig<4>,
446 MontScalar<T>: Scalar,
447{
448 type Error = ScalarConversionError;
449
450 fn try_from(value: MontScalar<T>) -> Result<Self, Self::Error> {
451 if value < MontScalar::<T>::ZERO {
452 return Err(ScalarConversionError::Overflow {
453 error: format!("{value} is negative and cannot fit in a u8"),
454 });
455 }
456
457 let abs: [u64; 4] = value.into();
458
459 if abs[1] != 0 || abs[2] != 0 || abs[3] != 0 {
460 return Err(ScalarConversionError::Overflow {
461 error: format!("{value} is too large to fit in a u8"),
462 });
463 }
464
465 abs[0]
466 .try_into()
467 .map_err(|_| ScalarConversionError::Overflow {
468 error: format!("{value} is too large to fit in a u8"),
469 })
470 }
471}
472
473impl<T> TryFrom<MontScalar<T>> for i8
474where
475 T: MontConfig<4>,
476 MontScalar<T>: Scalar,
477{
478 type Error = ScalarConversionError;
479 fn try_from(value: MontScalar<T>) -> Result<Self, Self::Error> {
480 let (sign, abs): (i128, [u64; 4]) = if value > <MontScalar<T>>::MAX_SIGNED {
481 (-1, (-value).into())
482 } else {
483 (1, value.into())
484 };
485 if abs[1] != 0 || abs[2] != 0 || abs[3] != 0 {
486 return Err(ScalarConversionError::Overflow {
487 error: format!("{value} is too large to fit in an i8"),
488 });
489 }
490 let val: i128 = sign * i128::from(abs[0]);
491 val.try_into().map_err(|_| ScalarConversionError::Overflow {
492 error: format!("{value} is too large to fit in an i8"),
493 })
494 }
495}
496
497impl<T> TryFrom<MontScalar<T>> for i16
498where
499 T: MontConfig<4>,
500 MontScalar<T>: Scalar,
501{
502 type Error = ScalarConversionError;
503 fn try_from(value: MontScalar<T>) -> Result<Self, Self::Error> {
504 let (sign, abs): (i128, [u64; 4]) = if value > <MontScalar<T>>::MAX_SIGNED {
505 (-1, (-value).into())
506 } else {
507 (1, value.into())
508 };
509 if abs[1] != 0 || abs[2] != 0 || abs[3] != 0 {
510 return Err(ScalarConversionError::Overflow {
511 error: format!("{value} is too large to fit in an i16"),
512 });
513 }
514 let val: i128 = sign * i128::from(abs[0]);
515 val.try_into().map_err(|_| ScalarConversionError::Overflow {
516 error: format!("{value} is too large to fit in an i16"),
517 })
518 }
519}
520
521impl<T> TryFrom<MontScalar<T>> for i32
522where
523 T: MontConfig<4>,
524 MontScalar<T>: Scalar,
525{
526 type Error = ScalarConversionError;
527 fn try_from(value: MontScalar<T>) -> Result<Self, Self::Error> {
528 let (sign, abs): (i128, [u64; 4]) = if value > <MontScalar<T>>::MAX_SIGNED {
529 (-1, (-value).into())
530 } else {
531 (1, value.into())
532 };
533 if abs[1] != 0 || abs[2] != 0 || abs[3] != 0 {
534 return Err(ScalarConversionError::Overflow {
535 error: format!("{value} is too large to fit in an i32"),
536 });
537 }
538 let val: i128 = sign * i128::from(abs[0]);
539 val.try_into().map_err(|_| ScalarConversionError::Overflow {
540 error: format!("{value} is too large to fit in an i32"),
541 })
542 }
543}
544
545impl<T> TryFrom<MontScalar<T>> for i64
546where
547 T: MontConfig<4>,
548 MontScalar<T>: Scalar,
549{
550 type Error = ScalarConversionError;
551 fn try_from(value: MontScalar<T>) -> Result<Self, Self::Error> {
552 let (sign, abs): (i128, [u64; 4]) = if value > <MontScalar<T>>::MAX_SIGNED {
553 (-1, (-value).into())
554 } else {
555 (1, value.into())
556 };
557 if abs[1] != 0 || abs[2] != 0 || abs[3] != 0 {
558 return Err(ScalarConversionError::Overflow {
559 error: format!("{value} is too large to fit in an i64"),
560 });
561 }
562 let val: i128 = sign * i128::from(abs[0]);
563 val.try_into().map_err(|_| ScalarConversionError::Overflow {
564 error: format!("{value} is too large to fit in an i64"),
565 })
566 }
567}
568
569impl<T> TryFrom<MontScalar<T>> for i128
570where
571 T: MontConfig<4>,
572 MontScalar<T>: Scalar,
573{
574 type Error = ScalarConversionError;
575
576 #[expect(clippy::cast_possible_wrap)]
577 fn try_from(value: MontScalar<T>) -> Result<Self, Self::Error> {
578 let (sign, abs): (i128, [u64; 4]) = if value > <MontScalar<T>>::MAX_SIGNED {
579 (-1, (-value).into())
580 } else {
581 (1, value.into())
582 };
583 if abs[2] != 0 || abs[3] != 0 {
584 return Err(ScalarConversionError::Overflow {
585 error: format!("{value} is too large to fit in an i128"),
586 });
587 }
588 let val: u128 = (u128::from(abs[1]) << 64) | (u128::from(abs[0]));
589 match (sign, val) {
590 (1, v) if v <= i128::MAX as u128 => Ok(v as i128),
591 (-1, v) if v <= i128::MAX as u128 => Ok(-(v as i128)),
592 (-1, v) if v == i128::MAX as u128 + 1 => Ok(i128::MIN),
593 _ => Err(ScalarConversionError::Overflow {
594 error: format!("{value} is too large to fit in an i128"),
595 }),
596 }
597 }
598}
599
600impl<T> From<MontScalar<T>> for BigInt
601where
602 T: MontConfig<4>,
603 MontScalar<T>: Scalar,
604{
605 fn from(value: MontScalar<T>) -> Self {
606 let is_negative = value > <MontScalar<T>>::MAX_SIGNED;
608 let sign = if is_negative {
609 num_bigint::Sign::Minus
610 } else {
611 num_bigint::Sign::Plus
612 };
613 let value_abs: [u64; 4] = (if is_negative { -value } else { value }).into();
614 let bits: &[u8] = bytemuck::cast_slice(&value_abs);
615 BigInt::from_bytes_le(sign, bits)
616 }
617}