proof_of_sql/base/scalar/mont_scalar.rs

1use crate::base::scalar::{Scalar, ScalarConversionError};
2use alloc::{
3    format,
4    string::{String, ToString},
5    vec::Vec,
6};
7use ark_ff::{AdditiveGroup, BigInteger, Field, Fp, Fp256, MontBackend, MontConfig, PrimeField};
8use ark_serialize::{CanonicalDeserialize, CanonicalSerialize};
9use bnum::types::U256;
10use bytemuck::TransparentWrapper;
11use core::{
12    cmp::Ordering,
13    fmt,
14    fmt::{Debug, Display, Formatter},
15    hash::{Hash, Hasher},
16    iter::{Product, Sum},
17    ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign},
18};
19use num_bigint::BigInt;
20use num_traits::{Signed, Zero};
21use serde::{Deserialize, Deserializer, Serialize, Serializer};
22#[derive(CanonicalSerialize, CanonicalDeserialize, TransparentWrapper)]
23/// A wrapper struct around a `Fp256<MontBackend<T, 4>>` that can easily implement the `Scalar` trait.
24///
25/// Using the `Scalar` trait rather than this type is encouraged to allow for easier switching of the underlying field.
26#[repr(transparent)]
27pub struct MontScalar<T: MontConfig<4>>(pub Fp256<MontBackend<T, 4>>);
28
29// --------------------------------------------------------------------------------
30// replacement for #[derive(Add, Sub, Mul, AddAssign, SubAssign, MulAssign, Neg,
31//  Sum, Product, Clone, Copy, PartialOrd, PartialEq, Default, Debug, Eq, Hash, Ord)]
32// --------------------------------------------------------------------------------
33impl<T: MontConfig<4>> Add for MontScalar<T> {
34    type Output = Self;
35    fn add(self, rhs: Self) -> Self::Output {
36        Self(self.0 + rhs.0)
37    }
38}
39impl<T: MontConfig<4>> Sub for MontScalar<T> {
40    type Output = Self;
41    fn sub(self, rhs: Self) -> Self::Output {
42        Self(self.0 - rhs.0)
43    }
44}
45impl<T: MontConfig<4>> Mul for MontScalar<T> {
46    type Output = Self;
47    fn mul(self, rhs: Self) -> Self::Output {
48        Self(self.0 * rhs.0)
49    }
50}
51impl<T: MontConfig<4>> AddAssign for MontScalar<T> {
52    fn add_assign(&mut self, rhs: Self) {
53        self.0 += rhs.0;
54    }
55}
56impl<T: MontConfig<4>> SubAssign for MontScalar<T> {
57    fn sub_assign(&mut self, rhs: Self) {
58        self.0 -= rhs.0;
59    }
60}
61impl<T: MontConfig<4>> MulAssign for MontScalar<T> {
62    fn mul_assign(&mut self, rhs: Self) {
63        self.0 *= rhs.0;
64    }
65}
66impl<T: MontConfig<4>> Neg for MontScalar<T> {
67    type Output = Self;
68    fn neg(self) -> Self::Output {
69        Self(-self.0)
70    }
71}
72impl<T: MontConfig<4>> Sum for MontScalar<T> {
73    fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
74        Self(iter.map(|x| x.0).sum())
75    }
76}
77impl<T: MontConfig<4>> Product for MontScalar<T> {
78    fn product<I: Iterator<Item = Self>>(iter: I) -> Self {
79        Self(iter.map(|x| x.0).product())
80    }
81}
82impl<T: MontConfig<4>> Clone for MontScalar<T> {
83    fn clone(&self) -> Self {
84        *self
85    }
86}
87impl<T: MontConfig<4>> Copy for MontScalar<T> {}
88impl<T: MontConfig<4>> PartialOrd for MontScalar<T> {
89    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
90        Some(self.cmp(other))
91    }
92}
93impl<T: MontConfig<4>> PartialEq for MontScalar<T> {
94    fn eq(&self, other: &Self) -> bool {
95        self.0 == other.0
96    }
97}
98impl<T: MontConfig<4>> Default for MontScalar<T> {
99    fn default() -> Self {
100        Self(Fp::default())
101    }
102}
103impl<T: MontConfig<4>> Debug for MontScalar<T> {
104    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
105        f.debug_tuple("MontScalar").field(&self.0).finish()
106    }
107}
108impl<T: MontConfig<4>> Eq for MontScalar<T> {}
109impl<T: MontConfig<4>> Hash for MontScalar<T> {
110    fn hash<H: Hasher>(&self, state: &mut H) {
111        self.0.hash(state);
112    }
113}
114impl<T: MontConfig<4>> Ord for MontScalar<T> {
115    fn cmp(&self, other: &Self) -> Ordering {
116        self.0.cmp(&other.0)
117    }
118}
119// --------------------------------------------------------------------------------
120// end replacement for #[derive(...)]
121// --------------------------------------------------------------------------------
122
/// Implements `From<$ty>` for `MontScalar<T>` by forwarding to the conversion that the
/// wrapped field element type `Fp256<MontBackend<T, 4>>` already provides (via `Into`).
///
/// Used for `bool` and the primitive integer types.
macro_rules! impl_from_for_mont_scalar_for_type_supported_by_from {
    ($ty:ty) => {
        impl<T: MontConfig<4>> From<$ty> for MontScalar<T> {
            fn from(x: $ty) -> Self {
                Self(x.into())
            }
        }
    };
}
133
134/// Implement `From<&[u8]>` for `MontScalar`
135impl<T: MontConfig<4>> From<&[u8]> for MontScalar<T> {
136    fn from(x: &[u8]) -> Self {
137        if x.is_empty() {
138            return Self::zero();
139        }
140
141        let hash = blake3::hash(x);
142        let mut bytes: [u8; 32] = hash.into();
143        bytes[31] &= 0b0000_1111_u8;
144
145        Self::from_le_bytes_mod_order(&bytes)
146    }
147}
148
/// Implements `From<$ty>` for `MontScalar<T>` for a string-like type by converting its UTF-8
/// bytes through `From<&[u8]>`. That conversion hashes non-empty input with BLAKE3 and maps
/// the empty string to zero.
macro_rules! impl_from_for_mont_scalar_for_string {
    ($ty:ty) => {
        impl<T: MontConfig<4>> From<$ty> for MontScalar<T> {
            fn from(x: $ty) -> Self {
                x.as_bytes().into()
            }
        }
    };
}
159
160impl_from_for_mont_scalar_for_type_supported_by_from!(bool);
161impl_from_for_mont_scalar_for_type_supported_by_from!(u8);
162impl_from_for_mont_scalar_for_type_supported_by_from!(u16);
163impl_from_for_mont_scalar_for_type_supported_by_from!(u32);
164impl_from_for_mont_scalar_for_type_supported_by_from!(u64);
165impl_from_for_mont_scalar_for_type_supported_by_from!(u128);
166impl_from_for_mont_scalar_for_type_supported_by_from!(i8);
167impl_from_for_mont_scalar_for_type_supported_by_from!(i16);
168impl_from_for_mont_scalar_for_type_supported_by_from!(i32);
169impl_from_for_mont_scalar_for_type_supported_by_from!(i64);
170impl_from_for_mont_scalar_for_type_supported_by_from!(i128);
171impl_from_for_mont_scalar_for_string!(&str);
172impl_from_for_mont_scalar_for_string!(String);
173
174impl<F: MontConfig<4>, T> From<&T> for MontScalar<F>
175where
176    T: Into<MontScalar<F>> + Clone,
177{
178    fn from(x: &T) -> Self {
179        x.clone().into()
180    }
181}
182
183impl<T: MontConfig<4>> MontScalar<T> {
184    /// Convenience function for creating a new `MontScalar<T>` from the underlying `Fp256<MontBackend<T, 4>>`. Should only be used in tests.
185    #[cfg(test)]
186    #[must_use]
187    pub fn new(value: Fp256<MontBackend<T, 4>>) -> Self {
188        Self(value)
189    }
190
191    /// Create a new `MontScalar<T>` from a `[u64, 4]`. The array is expected to be in non-montgomery form.
192    ///
193    /// # Panics
194    ///
195    /// This method will panic if the provided `[u64; 4]` cannot be converted into a valid `BigInt` due to an overflow or invalid input. The method unwraps the result of `Fp::from_bigint`, which will panic if the `BigInt` does not represent a valid field element ("Invalid input" refers to an integer that is outside the valid range [0,p-1] for the prime field or cannot be represented as a canonical field element. It can also occur due to overflow or issues in the conversion process.).
196    #[must_use]
197    pub fn from_bigint(vals: [u64; 4]) -> Self {
198        Self(Fp::from_bigint(ark_ff::BigInt(vals)).unwrap())
199    }
200    /// Create a new `MontScalar<T>` from a `[u8]` modulus the field order. The array is expected to be in non-montgomery form.
201    #[must_use]
202    pub fn from_le_bytes_mod_order(bytes: &[u8]) -> Self {
203        Self(Fp::from_le_bytes_mod_order(bytes))
204    }
205    /// Create a `Vec<u8>` from a `MontScalar<T>`. The array will be in non-montgomery form.
206    #[must_use]
207    pub fn to_bytes_le(&self) -> Vec<u8> {
208        self.0.into_bigint().to_bytes_le()
209    }
210    /// Convenience function for converting a slice of `ark_curve25519::Fr` into a vector of `Curve25519Scalar`. Should not be used outside of tests.
211    #[cfg(test)]
212    pub fn wrap_slice(slice: &[Fp256<MontBackend<T, 4>>]) -> Vec<Self> {
213        slice.iter().copied().map(Self).collect()
214    }
215    /// Convenience function for converting a slice of `Curve25519Scalar` into a vector of `ark_curve25519::Fr`. Should not be used outside of tests.
216    #[cfg(test)]
217    #[must_use]
218    pub fn unwrap_slice(slice: &[Self]) -> Vec<Fp256<MontBackend<T, 4>>> {
219        slice.iter().map(|x| x.0).collect()
220    }
221}
222
223impl<T> TryFrom<BigInt> for MontScalar<T>
224where
225    T: MontConfig<4>,
226    MontScalar<T>: Scalar,
227{
228    type Error = ScalarConversionError;
229
230    fn try_from(value: BigInt) -> Result<Self, Self::Error> {
231        if value.abs() > BigInt::from(<MontScalar<T>>::MAX_SIGNED) {
232            return Err(ScalarConversionError::Overflow {
233                error: "BigInt too large for Scalar".to_string(),
234            });
235        }
236
237        let (sign, digits) = value.to_u64_digits();
238        assert!(digits.len() <= 4); // This should not happen if the above check is correct
239        let mut limbs = [0u64; 4];
240        limbs[..digits.len()].copy_from_slice(&digits);
241        let result = Self::from(limbs);
242        Ok(match sign {
243            num_bigint::Sign::Minus => -result,
244            num_bigint::Sign::Plus | num_bigint::Sign::NoSign => result,
245        })
246    }
247}
248impl<T: MontConfig<4>> From<[u64; 4]> for MontScalar<T> {
249    fn from(value: [u64; 4]) -> Self {
250        Self(Fp::new(ark_ff::BigInt(value)))
251    }
252}
253
254impl<T: MontConfig<4>> ark_std::UniformRand for MontScalar<T> {
255    fn rand<R: ark_std::rand::Rng + ?Sized>(rng: &mut R) -> Self {
256        Self(ark_ff::UniformRand::rand(rng))
257    }
258}
259
260impl<'a, T: MontConfig<4>> Sum<&'a Self> for MontScalar<T> {
261    fn sum<I: Iterator<Item = &'a Self>>(iter: I) -> Self {
262        Self(iter.map(|x| x.0).sum())
263    }
264}
265impl<T: MontConfig<4>> num_traits::One for MontScalar<T> {
266    fn one() -> Self {
267        Self(Fp::one())
268    }
269}
270impl<T: MontConfig<4>> num_traits::Zero for MontScalar<T> {
271    fn zero() -> Self {
272        Self(Fp::zero())
273    }
274    fn is_zero(&self) -> bool {
275        self.0.is_zero()
276    }
277}
278impl<T: MontConfig<4>> num_traits::Inv for MontScalar<T> {
279    type Output = Option<Self>;
280    fn inv(self) -> Option<Self> {
281        self.0.inverse().map(Self)
282    }
283}
284impl<T: MontConfig<4>> Serialize for MontScalar<T> {
285    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
286        let mut limbs: [u64; 4] = self.into();
287        limbs.reverse();
288        limbs.serialize(serializer)
289    }
290}
291impl<'de, T: MontConfig<4>> Deserialize<'de> for MontScalar<T> {
292    fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
293        let mut limbs: [u64; 4] = Deserialize::deserialize(deserializer)?;
294        limbs.reverse();
295        Ok(limbs.into())
296    }
297}
298
299impl<T: MontConfig<4>> core::ops::Neg for &MontScalar<T> {
300    type Output = MontScalar<T>;
301    fn neg(self) -> Self::Output {
302        MontScalar(-self.0)
303    }
304}
305
306impl<T: MontConfig<4>> From<MontScalar<T>> for [u64; 4] {
307    fn from(value: MontScalar<T>) -> Self {
308        (&value).into()
309    }
310}
311
312impl<T: MontConfig<4>> From<&MontScalar<T>> for [u64; 4] {
313    fn from(value: &MontScalar<T>) -> Self {
314        value.0.into_bigint().0
315    }
316}
317
318impl<T: MontConfig<4>> Display for MontScalar<T> {
319    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
320        let sign = if f.sign_plus() {
321            let n = -self;
322            if self > &n {
323                Some(Some(n))
324            } else {
325                Some(None)
326            }
327        } else {
328            None
329        };
330        match (f.alternate(), sign) {
331            (false, None) => {
332                let data = self.0.into_bigint().0;
333                write!(
334                    f,
335                    "{:016X}{:016X}{:016X}{:016X}",
336                    data[3], data[2], data[1], data[0],
337                )
338            }
339            (false, Some(None)) => {
340                let data = self.0.into_bigint().0;
341                write!(
342                    f,
343                    "+{:016X}{:016X}{:016X}{:016X}",
344                    data[3], data[2], data[1], data[0],
345                )
346            }
347            (false, Some(Some(n))) => {
348                let data = n.0.into_bigint().0;
349                write!(
350                    f,
351                    "-{:016X}{:016X}{:016X}{:016X}",
352                    data[3], data[2], data[1], data[0],
353                )
354            }
355            (true, None) => {
356                let data = self.to_bytes_le();
357                write!(
358                    f,
359                    "0x{:02X}{:02X}...{:02X}{:02X}",
360                    data[31], data[30], data[1], data[0],
361                )
362            }
363            (true, Some(None)) => {
364                let data = self.to_bytes_le();
365                write!(
366                    f,
367                    "+0x{:02X}{:02X}...{:02X}{:02X}",
368                    data[31], data[30], data[1], data[0],
369                )
370            }
371            (true, Some(Some(n))) => {
372                let data = n.to_bytes_le();
373                write!(
374                    f,
375                    "-0x{:02X}{:02X}...{:02X}{:02X}",
376                    data[31], data[30], data[1], data[0],
377                )
378            }
379        }
380    }
381}
382
383impl<T> Scalar for MontScalar<T>
384where
385    T: MontConfig<4>,
386{
387    const MAX_SIGNED: Self = Self(Fp::new(T::MODULUS.divide_by_2_round_down()));
388    const ZERO: Self = Self(Fp::ZERO);
389    const ONE: Self = Self(Fp::ONE);
390    const TWO: Self = Self(Fp::new(ark_ff::BigInt([2, 0, 0, 0])));
391    const TEN: Self = Self(Fp::new(ark_ff::BigInt([10, 0, 0, 0])));
392    const TWO_POW_64: Self = Self(Fp::new(ark_ff::BigInt([0, 1, 0, 0])));
393    const CHALLENGE_MASK: U256 = {
394        assert!(
395            T::MODULUS.0[3].leading_zeros() < 64,
396            "modulus expected to be larger than 1 << (64*3)"
397        );
398        U256::from_digits([
399            u64::MAX,
400            u64::MAX,
401            u64::MAX,
402            u64::MAX >> (T::MODULUS.0[3].leading_zeros() + 1),
403        ])
404    };
405    #[expect(clippy::cast_possible_truncation)]
406    const MAX_BITS: u8 = {
407        assert!(
408            T::MODULUS.0[3].leading_zeros() < 64,
409            "modulus expected to be larger than 1 << (64*3)"
410        );
411        255 - T::MODULUS.0[3].leading_zeros() as u8
412    };
413}
414
415impl<T> TryFrom<MontScalar<T>> for bool
416where
417    T: MontConfig<4>,
418    MontScalar<T>: Scalar,
419{
420    type Error = ScalarConversionError;
421    fn try_from(value: MontScalar<T>) -> Result<Self, Self::Error> {
422        let (sign, abs): (i128, [u64; 4]) = if value > <MontScalar<T>>::MAX_SIGNED {
423            (-1, (-value).into())
424        } else {
425            (1, value.into())
426        };
427        if abs[1] != 0 || abs[2] != 0 || abs[3] != 0 {
428            return Err(ScalarConversionError::Overflow {
429                error: format!("{value} is too large to fit in an i8"),
430            });
431        }
432        let val: i128 = sign * i128::from(abs[0]);
433        match val {
434            0 => Ok(false),
435            1 => Ok(true),
436            _ => Err(ScalarConversionError::Overflow {
437                error: format!("{value} is too large to fit in a bool"),
438            }),
439        }
440    }
441}
442
443impl<T> TryFrom<MontScalar<T>> for u8
444where
445    T: MontConfig<4>,
446    MontScalar<T>: Scalar,
447{
448    type Error = ScalarConversionError;
449
450    fn try_from(value: MontScalar<T>) -> Result<Self, Self::Error> {
451        if value < MontScalar::<T>::ZERO {
452            return Err(ScalarConversionError::Overflow {
453                error: format!("{value} is negative and cannot fit in a u8"),
454            });
455        }
456
457        let abs: [u64; 4] = value.into();
458
459        if abs[1] != 0 || abs[2] != 0 || abs[3] != 0 {
460            return Err(ScalarConversionError::Overflow {
461                error: format!("{value} is too large to fit in a u8"),
462            });
463        }
464
465        abs[0]
466            .try_into()
467            .map_err(|_| ScalarConversionError::Overflow {
468                error: format!("{value} is too large to fit in a u8"),
469            })
470    }
471}
472
473impl<T> TryFrom<MontScalar<T>> for i8
474where
475    T: MontConfig<4>,
476    MontScalar<T>: Scalar,
477{
478    type Error = ScalarConversionError;
479    fn try_from(value: MontScalar<T>) -> Result<Self, Self::Error> {
480        let (sign, abs): (i128, [u64; 4]) = if value > <MontScalar<T>>::MAX_SIGNED {
481            (-1, (-value).into())
482        } else {
483            (1, value.into())
484        };
485        if abs[1] != 0 || abs[2] != 0 || abs[3] != 0 {
486            return Err(ScalarConversionError::Overflow {
487                error: format!("{value} is too large to fit in an i8"),
488            });
489        }
490        let val: i128 = sign * i128::from(abs[0]);
491        val.try_into().map_err(|_| ScalarConversionError::Overflow {
492            error: format!("{value} is too large to fit in an i8"),
493        })
494    }
495}
496
497impl<T> TryFrom<MontScalar<T>> for i16
498where
499    T: MontConfig<4>,
500    MontScalar<T>: Scalar,
501{
502    type Error = ScalarConversionError;
503    fn try_from(value: MontScalar<T>) -> Result<Self, Self::Error> {
504        let (sign, abs): (i128, [u64; 4]) = if value > <MontScalar<T>>::MAX_SIGNED {
505            (-1, (-value).into())
506        } else {
507            (1, value.into())
508        };
509        if abs[1] != 0 || abs[2] != 0 || abs[3] != 0 {
510            return Err(ScalarConversionError::Overflow {
511                error: format!("{value} is too large to fit in an i16"),
512            });
513        }
514        let val: i128 = sign * i128::from(abs[0]);
515        val.try_into().map_err(|_| ScalarConversionError::Overflow {
516            error: format!("{value} is too large to fit in an i16"),
517        })
518    }
519}
520
521impl<T> TryFrom<MontScalar<T>> for i32
522where
523    T: MontConfig<4>,
524    MontScalar<T>: Scalar,
525{
526    type Error = ScalarConversionError;
527    fn try_from(value: MontScalar<T>) -> Result<Self, Self::Error> {
528        let (sign, abs): (i128, [u64; 4]) = if value > <MontScalar<T>>::MAX_SIGNED {
529            (-1, (-value).into())
530        } else {
531            (1, value.into())
532        };
533        if abs[1] != 0 || abs[2] != 0 || abs[3] != 0 {
534            return Err(ScalarConversionError::Overflow {
535                error: format!("{value} is too large to fit in an i32"),
536            });
537        }
538        let val: i128 = sign * i128::from(abs[0]);
539        val.try_into().map_err(|_| ScalarConversionError::Overflow {
540            error: format!("{value} is too large to fit in an i32"),
541        })
542    }
543}
544
545impl<T> TryFrom<MontScalar<T>> for i64
546where
547    T: MontConfig<4>,
548    MontScalar<T>: Scalar,
549{
550    type Error = ScalarConversionError;
551    fn try_from(value: MontScalar<T>) -> Result<Self, Self::Error> {
552        let (sign, abs): (i128, [u64; 4]) = if value > <MontScalar<T>>::MAX_SIGNED {
553            (-1, (-value).into())
554        } else {
555            (1, value.into())
556        };
557        if abs[1] != 0 || abs[2] != 0 || abs[3] != 0 {
558            return Err(ScalarConversionError::Overflow {
559                error: format!("{value} is too large to fit in an i64"),
560            });
561        }
562        let val: i128 = sign * i128::from(abs[0]);
563        val.try_into().map_err(|_| ScalarConversionError::Overflow {
564            error: format!("{value} is too large to fit in an i64"),
565        })
566    }
567}
568
569impl<T> TryFrom<MontScalar<T>> for i128
570where
571    T: MontConfig<4>,
572    MontScalar<T>: Scalar,
573{
574    type Error = ScalarConversionError;
575
576    #[expect(clippy::cast_possible_wrap)]
577    fn try_from(value: MontScalar<T>) -> Result<Self, Self::Error> {
578        let (sign, abs): (i128, [u64; 4]) = if value > <MontScalar<T>>::MAX_SIGNED {
579            (-1, (-value).into())
580        } else {
581            (1, value.into())
582        };
583        if abs[2] != 0 || abs[3] != 0 {
584            return Err(ScalarConversionError::Overflow {
585                error: format!("{value} is too large to fit in an i128"),
586            });
587        }
588        let val: u128 = (u128::from(abs[1]) << 64) | (u128::from(abs[0]));
589        match (sign, val) {
590            (1, v) if v <= i128::MAX as u128 => Ok(v as i128),
591            (-1, v) if v <= i128::MAX as u128 => Ok(-(v as i128)),
592            (-1, v) if v == i128::MAX as u128 + 1 => Ok(i128::MIN),
593            _ => Err(ScalarConversionError::Overflow {
594                error: format!("{value} is too large to fit in an i128"),
595            }),
596        }
597    }
598}
599
600impl<T> From<MontScalar<T>> for BigInt
601where
602    T: MontConfig<4>,
603    MontScalar<T>: Scalar,
604{
605    fn from(value: MontScalar<T>) -> Self {
606        // Since we wrap around in finite fields anything greater than the max signed value is negative
607        let is_negative = value > <MontScalar<T>>::MAX_SIGNED;
608        let sign = if is_negative {
609            num_bigint::Sign::Minus
610        } else {
611            num_bigint::Sign::Plus
612        };
613        let value_abs: [u64; 4] = (if is_negative { -value } else { value }).into();
614        let bits: &[u8] = bytemuck::cast_slice(&value_abs);
615        BigInt::from_bytes_le(sign, bits)
616    }
617}