// proof_of_sql/base/scalar/mont_scalar.rs

1use crate::base::scalar::{Scalar, ScalarConversionError, ScalarExt};
2use alloc::{
3    format,
4    string::{String, ToString},
5    vec::Vec,
6};
7use ark_ff::{AdditiveGroup, BigInteger, Field, Fp, Fp256, MontBackend, MontConfig, PrimeField};
8use ark_serialize::{CanonicalDeserialize, CanonicalSerialize};
9use bnum::types::U256;
10use bytemuck::TransparentWrapper;
11use core::{
12    cmp::Ordering,
13    fmt,
14    fmt::{Debug, Display, Formatter},
15    hash::{Hash, Hasher},
16    iter::{Product, Sum},
17    ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign},
18};
19use num_bigint::BigInt;
20use num_traits::Signed;
21use serde::{Deserialize, Deserializer, Serialize, Serializer};
22#[derive(CanonicalSerialize, CanonicalDeserialize, TransparentWrapper)]
23/// A wrapper struct around a `Fp256<MontBackend<T, 4>>` that can easily implement the `Scalar` trait.
24///
25/// Using the `Scalar` trait rather than this type is encouraged to allow for easier switching of the underlying field.
26#[repr(transparent)]
27pub struct MontScalar<T: MontConfig<4>>(pub Fp256<MontBackend<T, 4>>);
28
29// --------------------------------------------------------------------------------
30// replacement for #[derive(Add, Sub, Mul, AddAssign, SubAssign, MulAssign, Neg,
31//  Sum, Product, Clone, Copy, PartialOrd, PartialEq, Default, Debug, Eq, Hash, Ord)]
32// --------------------------------------------------------------------------------
33impl<T: MontConfig<4>> Add for MontScalar<T> {
34    type Output = Self;
35    fn add(self, rhs: Self) -> Self::Output {
36        Self(self.0 + rhs.0)
37    }
38}
39impl<T: MontConfig<4>> Sub for MontScalar<T> {
40    type Output = Self;
41    fn sub(self, rhs: Self) -> Self::Output {
42        Self(self.0 - rhs.0)
43    }
44}
45impl<T: MontConfig<4>> Mul for MontScalar<T> {
46    type Output = Self;
47    fn mul(self, rhs: Self) -> Self::Output {
48        Self(self.0 * rhs.0)
49    }
50}
51impl<T: MontConfig<4>> AddAssign for MontScalar<T> {
52    fn add_assign(&mut self, rhs: Self) {
53        self.0 += rhs.0;
54    }
55}
56impl<T: MontConfig<4>> SubAssign for MontScalar<T> {
57    fn sub_assign(&mut self, rhs: Self) {
58        self.0 -= rhs.0;
59    }
60}
61impl<T: MontConfig<4>> MulAssign for MontScalar<T> {
62    fn mul_assign(&mut self, rhs: Self) {
63        self.0 *= rhs.0;
64    }
65}
66impl<T: MontConfig<4>> Neg for MontScalar<T> {
67    type Output = Self;
68    fn neg(self) -> Self::Output {
69        Self(-self.0)
70    }
71}
72impl<T: MontConfig<4>> Sum for MontScalar<T> {
73    fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
74        Self(iter.map(|x| x.0).sum())
75    }
76}
77impl<T: MontConfig<4>> Product for MontScalar<T> {
78    fn product<I: Iterator<Item = Self>>(iter: I) -> Self {
79        Self(iter.map(|x| x.0).product())
80    }
81}
82impl<T: MontConfig<4>> Clone for MontScalar<T> {
83    fn clone(&self) -> Self {
84        *self
85    }
86}
87impl<T: MontConfig<4>> Copy for MontScalar<T> {}
88impl<T: MontConfig<4>> PartialOrd for MontScalar<T> {
89    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
90        Some(self.cmp(other))
91    }
92}
93impl<T: MontConfig<4>> PartialEq for MontScalar<T> {
94    fn eq(&self, other: &Self) -> bool {
95        self.0 == other.0
96    }
97}
98impl<T: MontConfig<4>> Default for MontScalar<T> {
99    fn default() -> Self {
100        Self(Fp::default())
101    }
102}
103impl<T: MontConfig<4>> Debug for MontScalar<T> {
104    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
105        f.debug_tuple("MontScalar").field(&self.0).finish()
106    }
107}
108impl<T: MontConfig<4>> Eq for MontScalar<T> {}
109impl<T: MontConfig<4>> Hash for MontScalar<T> {
110    fn hash<H: Hasher>(&self, state: &mut H) {
111        self.0.hash(state);
112    }
113}
114impl<T: MontConfig<4>> Ord for MontScalar<T> {
115    fn cmp(&self, other: &Self) -> Ordering {
116        self.0.cmp(&other.0)
117    }
118}
119// --------------------------------------------------------------------------------
120// end replacement for #[derive(...)]
121// --------------------------------------------------------------------------------
122
/// Implements `From<$source> for MontScalar<T>` for a type `$source` that the wrapped field
/// element `Fp256<MontBackend<T, 4>>` can already be built from via `Into`.
///
/// The generated impl converts with `Into` and wraps the resulting field element.
macro_rules! impl_from_for_mont_scalar_for_type_supported_by_from {
    ($source:ty) => {
        impl<T: MontConfig<4>> From<$source> for MontScalar<T> {
            fn from(x: $source) -> Self {
                Self(x.into())
            }
        }
    };
}
133
134/// Implement `From<&[u8]>` for `MontScalar`
135impl<T: MontConfig<4>> From<&[u8]> for MontScalar<T> {
136    fn from(x: &[u8]) -> Self {
137        ScalarExt::from_byte_slice_via_hash(x)
138    }
139}
140
/// Implements `From<$source> for MontScalar<T>` for a string-like type `$source` that has an
/// `as_bytes` method.
///
/// The string's bytes are passed to `From<&[u8]> for MontScalar<T>`, which hashes them into a
/// scalar; the result is therefore not a numeric parse of the string.
macro_rules! impl_from_for_mont_scalar_for_string {
    ($source:ty) => {
        impl<T: MontConfig<4>> From<$source> for MontScalar<T> {
            fn from(x: $source) -> Self {
                x.as_bytes().into()
            }
        }
    };
}
151
152impl_from_for_mont_scalar_for_type_supported_by_from!(bool);
153impl_from_for_mont_scalar_for_type_supported_by_from!(u8);
154impl_from_for_mont_scalar_for_type_supported_by_from!(u16);
155impl_from_for_mont_scalar_for_type_supported_by_from!(u32);
156impl_from_for_mont_scalar_for_type_supported_by_from!(u64);
157impl_from_for_mont_scalar_for_type_supported_by_from!(u128);
158impl_from_for_mont_scalar_for_type_supported_by_from!(i8);
159impl_from_for_mont_scalar_for_type_supported_by_from!(i16);
160impl_from_for_mont_scalar_for_type_supported_by_from!(i32);
161impl_from_for_mont_scalar_for_type_supported_by_from!(i64);
162impl_from_for_mont_scalar_for_type_supported_by_from!(i128);
163impl_from_for_mont_scalar_for_string!(&str);
164impl_from_for_mont_scalar_for_string!(String);
165
166impl<F: MontConfig<4>, T> From<&T> for MontScalar<F>
167where
168    T: Into<MontScalar<F>> + Clone,
169{
170    fn from(x: &T) -> Self {
171        x.clone().into()
172    }
173}
174
175impl<T: MontConfig<4>> MontScalar<T> {
176    /// Convenience function for creating a new `MontScalar<T>` from the underlying `Fp256<MontBackend<T, 4>>`. Should only be used in tests.
177    #[cfg(test)]
178    #[must_use]
179    pub fn new(value: Fp256<MontBackend<T, 4>>) -> Self {
180        Self(value)
181    }
182
183    /// Create a new `MontScalar<T>` from a `[u64, 4]`. The array is expected to be in non-montgomery form.
184    ///
185    /// # Panics
186    ///
187    /// This method will panic if the provided `[u64; 4]` cannot be converted into a valid `BigInt` due to an overflow or invalid input. The method unwraps the result of `Fp::from_bigint`, which will panic if the `BigInt` does not represent a valid field element ("Invalid input" refers to an integer that is outside the valid range [0,p-1] for the prime field or cannot be represented as a canonical field element. It can also occur due to overflow or issues in the conversion process.).
188    #[must_use]
189    pub fn from_bigint(vals: [u64; 4]) -> Self {
190        Self(Fp::from_bigint(ark_ff::BigInt(vals)).unwrap())
191    }
192    /// Create a new `MontScalar<T>` from a `[u8]` modulus the field order. The array is expected to be in non-montgomery form.
193    #[must_use]
194    pub fn from_le_bytes_mod_order(bytes: &[u8]) -> Self {
195        Self(Fp::from_le_bytes_mod_order(bytes))
196    }
197    /// Create a `Vec<u8>` from a `MontScalar<T>`. The array will be in non-montgomery form.
198    #[must_use]
199    pub fn to_bytes_le(&self) -> Vec<u8> {
200        self.0.into_bigint().to_bytes_le()
201    }
202    /// Convenience function for converting a slice of `ark_curve25519::Fr` into a vector of `Curve25519Scalar`. Should not be used outside of tests.
203    #[cfg(test)]
204    pub fn wrap_slice(slice: &[Fp256<MontBackend<T, 4>>]) -> Vec<Self> {
205        slice.iter().copied().map(Self).collect()
206    }
207    /// Convenience function for converting a slice of `Curve25519Scalar` into a vector of `ark_curve25519::Fr`. Should not be used outside of tests.
208    #[cfg(test)]
209    #[must_use]
210    pub fn unwrap_slice(slice: &[Self]) -> Vec<Fp256<MontBackend<T, 4>>> {
211        slice.iter().map(|x| x.0).collect()
212    }
213}
214
215impl<T> TryFrom<BigInt> for MontScalar<T>
216where
217    T: MontConfig<4>,
218    MontScalar<T>: Scalar,
219{
220    type Error = ScalarConversionError;
221
222    fn try_from(value: BigInt) -> Result<Self, Self::Error> {
223        if value.abs() > BigInt::from(<MontScalar<T>>::MAX_SIGNED) {
224            return Err(ScalarConversionError::Overflow {
225                error: "BigInt too large for Scalar".to_string(),
226            });
227        }
228
229        let (sign, digits) = value.to_u64_digits();
230        assert!(digits.len() <= 4); // This should not happen if the above check is correct
231        let mut limbs = [0u64; 4];
232        limbs[..digits.len()].copy_from_slice(&digits);
233        let result = Self::from(limbs);
234        Ok(match sign {
235            num_bigint::Sign::Minus => -result,
236            num_bigint::Sign::Plus | num_bigint::Sign::NoSign => result,
237        })
238    }
239}
240impl<T: MontConfig<4>> From<[u64; 4]> for MontScalar<T> {
241    fn from(value: [u64; 4]) -> Self {
242        Self(Fp::new(ark_ff::BigInt(value)))
243    }
244}
245
246impl<T: MontConfig<4>> ark_std::UniformRand for MontScalar<T> {
247    fn rand<R: ark_std::rand::Rng + ?Sized>(rng: &mut R) -> Self {
248        Self(ark_ff::UniformRand::rand(rng))
249    }
250}
251
252impl<'a, T: MontConfig<4>> Sum<&'a Self> for MontScalar<T> {
253    fn sum<I: Iterator<Item = &'a Self>>(iter: I) -> Self {
254        Self(iter.map(|x| x.0).sum())
255    }
256}
257impl<T: MontConfig<4>> num_traits::One for MontScalar<T> {
258    fn one() -> Self {
259        Self(Fp::one())
260    }
261}
262impl<T: MontConfig<4>> num_traits::Zero for MontScalar<T> {
263    fn zero() -> Self {
264        Self(Fp::zero())
265    }
266    fn is_zero(&self) -> bool {
267        self.0.is_zero()
268    }
269}
270impl<T: MontConfig<4>> num_traits::Inv for MontScalar<T> {
271    type Output = Option<Self>;
272    fn inv(self) -> Option<Self> {
273        self.0.inverse().map(Self)
274    }
275}
276impl<T: MontConfig<4>> Serialize for MontScalar<T> {
277    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
278        let mut limbs: [u64; 4] = self.into();
279        limbs.reverse();
280        limbs.serialize(serializer)
281    }
282}
283impl<'de, T: MontConfig<4>> Deserialize<'de> for MontScalar<T> {
284    fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
285        let mut limbs: [u64; 4] = Deserialize::deserialize(deserializer)?;
286        limbs.reverse();
287        Ok(limbs.into())
288    }
289}
290
291impl<T: MontConfig<4>> core::ops::Neg for &MontScalar<T> {
292    type Output = MontScalar<T>;
293    fn neg(self) -> Self::Output {
294        MontScalar(-self.0)
295    }
296}
297
298impl<T: MontConfig<4>> From<MontScalar<T>> for [u64; 4] {
299    fn from(value: MontScalar<T>) -> Self {
300        (&value).into()
301    }
302}
303
304impl<T: MontConfig<4>> From<&MontScalar<T>> for [u64; 4] {
305    fn from(value: &MontScalar<T>) -> Self {
306        value.0.into_bigint().0
307    }
308}
309
310impl<T: MontConfig<4>> Display for MontScalar<T> {
311    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
312        let sign = if f.sign_plus() {
313            let n = -self;
314            if self > &n {
315                Some(Some(n))
316            } else {
317                Some(None)
318            }
319        } else {
320            None
321        };
322        match (f.alternate(), sign) {
323            (false, None) => {
324                let data = self.0.into_bigint().0;
325                write!(
326                    f,
327                    "{:016X}{:016X}{:016X}{:016X}",
328                    data[3], data[2], data[1], data[0],
329                )
330            }
331            (false, Some(None)) => {
332                let data = self.0.into_bigint().0;
333                write!(
334                    f,
335                    "+{:016X}{:016X}{:016X}{:016X}",
336                    data[3], data[2], data[1], data[0],
337                )
338            }
339            (false, Some(Some(n))) => {
340                let data = n.0.into_bigint().0;
341                write!(
342                    f,
343                    "-{:016X}{:016X}{:016X}{:016X}",
344                    data[3], data[2], data[1], data[0],
345                )
346            }
347            (true, None) => {
348                let data = self.to_bytes_le();
349                write!(
350                    f,
351                    "0x{:02X}{:02X}...{:02X}{:02X}",
352                    data[31], data[30], data[1], data[0],
353                )
354            }
355            (true, Some(None)) => {
356                let data = self.to_bytes_le();
357                write!(
358                    f,
359                    "+0x{:02X}{:02X}...{:02X}{:02X}",
360                    data[31], data[30], data[1], data[0],
361                )
362            }
363            (true, Some(Some(n))) => {
364                let data = n.to_bytes_le();
365                write!(
366                    f,
367                    "-0x{:02X}{:02X}...{:02X}{:02X}",
368                    data[31], data[30], data[1], data[0],
369                )
370            }
371        }
372    }
373}
374
375impl<T> Scalar for MontScalar<T>
376where
377    T: MontConfig<4>,
378{
379    const MAX_SIGNED: Self = Self(Fp::new(T::MODULUS.divide_by_2_round_down()));
380    const ZERO: Self = Self(Fp::ZERO);
381    const ONE: Self = Self(Fp::ONE);
382    const TWO: Self = Self(Fp::new(ark_ff::BigInt([2, 0, 0, 0])));
383    const TEN: Self = Self(Fp::new(ark_ff::BigInt([10, 0, 0, 0])));
384    const TWO_POW_64: Self = Self(Fp::new(ark_ff::BigInt([0, 1, 0, 0])));
385    const CHALLENGE_MASK: U256 = {
386        assert!(
387            T::MODULUS.0[3].leading_zeros() < 64,
388            "modulus expected to be larger than 1 << (64*3)"
389        );
390        U256::from_digits([
391            u64::MAX,
392            u64::MAX,
393            u64::MAX,
394            u64::MAX >> (T::MODULUS.0[3].leading_zeros() + 1),
395        ])
396    };
397    #[expect(clippy::cast_possible_truncation)]
398    const MAX_BITS: u8 = {
399        assert!(
400            T::MODULUS.0[3].leading_zeros() < 64,
401            "modulus expected to be larger than 1 << (64*3)"
402        );
403        255 - T::MODULUS.0[3].leading_zeros() as u8
404    };
405    const MAX_SIGNED_U256: U256 = U256::from_digits(T::MODULUS.divide_by_2_round_down().0);
406}
407
408impl<T> TryFrom<MontScalar<T>> for bool
409where
410    T: MontConfig<4>,
411    MontScalar<T>: Scalar,
412{
413    type Error = ScalarConversionError;
414    fn try_from(value: MontScalar<T>) -> Result<Self, Self::Error> {
415        let (sign, abs): (i128, [u64; 4]) = if value > <MontScalar<T>>::MAX_SIGNED {
416            (-1, (-value).into())
417        } else {
418            (1, value.into())
419        };
420        if abs[1] != 0 || abs[2] != 0 || abs[3] != 0 {
421            return Err(ScalarConversionError::Overflow {
422                error: format!("{value} is too large to fit in an i8"),
423            });
424        }
425        let val: i128 = sign * i128::from(abs[0]);
426        match val {
427            0 => Ok(false),
428            1 => Ok(true),
429            _ => Err(ScalarConversionError::Overflow {
430                error: format!("{value} is too large to fit in a bool"),
431            }),
432        }
433    }
434}
435
436impl<T> TryFrom<MontScalar<T>> for u8
437where
438    T: MontConfig<4>,
439    MontScalar<T>: Scalar,
440{
441    type Error = ScalarConversionError;
442
443    fn try_from(value: MontScalar<T>) -> Result<Self, Self::Error> {
444        if value < MontScalar::<T>::ZERO {
445            return Err(ScalarConversionError::Overflow {
446                error: format!("{value} is negative and cannot fit in a u8"),
447            });
448        }
449
450        let abs: [u64; 4] = value.into();
451
452        if abs[1] != 0 || abs[2] != 0 || abs[3] != 0 {
453            return Err(ScalarConversionError::Overflow {
454                error: format!("{value} is too large to fit in a u8"),
455            });
456        }
457
458        abs[0]
459            .try_into()
460            .map_err(|_| ScalarConversionError::Overflow {
461                error: format!("{value} is too large to fit in a u8"),
462            })
463    }
464}
465
466impl<T> TryFrom<MontScalar<T>> for i8
467where
468    T: MontConfig<4>,
469    MontScalar<T>: Scalar,
470{
471    type Error = ScalarConversionError;
472    fn try_from(value: MontScalar<T>) -> Result<Self, Self::Error> {
473        let (sign, abs): (i128, [u64; 4]) = if value > <MontScalar<T>>::MAX_SIGNED {
474            (-1, (-value).into())
475        } else {
476            (1, value.into())
477        };
478        if abs[1] != 0 || abs[2] != 0 || abs[3] != 0 {
479            return Err(ScalarConversionError::Overflow {
480                error: format!("{value} is too large to fit in an i8"),
481            });
482        }
483        let val: i128 = sign * i128::from(abs[0]);
484        val.try_into().map_err(|_| ScalarConversionError::Overflow {
485            error: format!("{value} is too large to fit in an i8"),
486        })
487    }
488}
489
490impl<T> TryFrom<MontScalar<T>> for i16
491where
492    T: MontConfig<4>,
493    MontScalar<T>: Scalar,
494{
495    type Error = ScalarConversionError;
496    fn try_from(value: MontScalar<T>) -> Result<Self, Self::Error> {
497        let (sign, abs): (i128, [u64; 4]) = if value > <MontScalar<T>>::MAX_SIGNED {
498            (-1, (-value).into())
499        } else {
500            (1, value.into())
501        };
502        if abs[1] != 0 || abs[2] != 0 || abs[3] != 0 {
503            return Err(ScalarConversionError::Overflow {
504                error: format!("{value} is too large to fit in an i16"),
505            });
506        }
507        let val: i128 = sign * i128::from(abs[0]);
508        val.try_into().map_err(|_| ScalarConversionError::Overflow {
509            error: format!("{value} is too large to fit in an i16"),
510        })
511    }
512}
513
514impl<T> TryFrom<MontScalar<T>> for i32
515where
516    T: MontConfig<4>,
517    MontScalar<T>: Scalar,
518{
519    type Error = ScalarConversionError;
520    fn try_from(value: MontScalar<T>) -> Result<Self, Self::Error> {
521        let (sign, abs): (i128, [u64; 4]) = if value > <MontScalar<T>>::MAX_SIGNED {
522            (-1, (-value).into())
523        } else {
524            (1, value.into())
525        };
526        if abs[1] != 0 || abs[2] != 0 || abs[3] != 0 {
527            return Err(ScalarConversionError::Overflow {
528                error: format!("{value} is too large to fit in an i32"),
529            });
530        }
531        let val: i128 = sign * i128::from(abs[0]);
532        val.try_into().map_err(|_| ScalarConversionError::Overflow {
533            error: format!("{value} is too large to fit in an i32"),
534        })
535    }
536}
537
538impl<T> TryFrom<MontScalar<T>> for i64
539where
540    T: MontConfig<4>,
541    MontScalar<T>: Scalar,
542{
543    type Error = ScalarConversionError;
544    fn try_from(value: MontScalar<T>) -> Result<Self, Self::Error> {
545        let (sign, abs): (i128, [u64; 4]) = if value > <MontScalar<T>>::MAX_SIGNED {
546            (-1, (-value).into())
547        } else {
548            (1, value.into())
549        };
550        if abs[1] != 0 || abs[2] != 0 || abs[3] != 0 {
551            return Err(ScalarConversionError::Overflow {
552                error: format!("{value} is too large to fit in an i64"),
553            });
554        }
555        let val: i128 = sign * i128::from(abs[0]);
556        val.try_into().map_err(|_| ScalarConversionError::Overflow {
557            error: format!("{value} is too large to fit in an i64"),
558        })
559    }
560}
561
562impl<T> TryFrom<MontScalar<T>> for i128
563where
564    T: MontConfig<4>,
565    MontScalar<T>: Scalar,
566{
567    type Error = ScalarConversionError;
568
569    #[expect(clippy::cast_possible_wrap)]
570    fn try_from(value: MontScalar<T>) -> Result<Self, Self::Error> {
571        let (sign, abs): (i128, [u64; 4]) = if value > <MontScalar<T>>::MAX_SIGNED {
572            (-1, (-value).into())
573        } else {
574            (1, value.into())
575        };
576        if abs[2] != 0 || abs[3] != 0 {
577            return Err(ScalarConversionError::Overflow {
578                error: format!("{value} is too large to fit in an i128"),
579            });
580        }
581        let val: u128 = (u128::from(abs[1]) << 64) | (u128::from(abs[0]));
582        match (sign, val) {
583            (1, v) if v <= i128::MAX as u128 => Ok(v as i128),
584            (-1, v) if v <= i128::MAX as u128 => Ok(-(v as i128)),
585            (-1, v) if v == i128::MAX as u128 + 1 => Ok(i128::MIN),
586            _ => Err(ScalarConversionError::Overflow {
587                error: format!("{value} is too large to fit in an i128"),
588            }),
589        }
590    }
591}
592
593impl<T> From<MontScalar<T>> for BigInt
594where
595    T: MontConfig<4>,
596    MontScalar<T>: Scalar,
597{
598    fn from(value: MontScalar<T>) -> Self {
599        // Since we wrap around in finite fields anything greater than the max signed value is negative
600        let is_negative = value > <MontScalar<T>>::MAX_SIGNED;
601        let sign = if is_negative {
602            num_bigint::Sign::Minus
603        } else {
604            num_bigint::Sign::Plus
605        };
606        let value_abs: [u64; 4] = (if is_negative { -value } else { value }).into();
607        let bits: &[u8] = bytemuck::cast_slice(&value_abs);
608        BigInt::from_bytes_le(sign, bits)
609    }
610}