proof_of_sql/base/scalar/scalar_ext.rs

1use super::Scalar;
2use bnum::types::U256;
3use core::cmp::Ordering;
4use tiny_keccak::Hasher;
5
6/// Extension trait for blanket implementations for `Scalar` types.
7/// This trait is primarily to avoid cluttering the core `Scalar` implementation with default implementations
8/// and provides helper methods for `Scalar`.
9pub trait ScalarExt: Scalar {
10    /// Compute 10^exponent for the Scalar. Note that we do not check for overflow.
11    fn pow10(exponent: u8) -> Self {
12        itertools::repeat_n(Self::TEN, exponent as usize).product()
13    }
14    /// Compare two `Scalar`s as signed numbers.
15    fn signed_cmp(&self, other: &Self) -> Ordering {
16        match *self - *other {
17            x if x.is_zero() => Ordering::Equal,
18            x if x > Self::MAX_SIGNED => Ordering::Less,
19            _ => Ordering::Greater,
20        }
21    }
22
23    #[must_use]
24    /// Converts a U256 to Scalar, wrapping as needed
25    fn from_wrapping(value: U256) -> Self {
26        let value_as_limbs: [u64; 4] = value.into();
27        Self::from(value_as_limbs)
28    }
29
30    /// Converts a Scalar to U256. Note that any values above `MAX_SIGNED` shall remain positive, even if they are representative of negative values.
31    fn into_u256_wrapping(self) -> U256 {
32        U256::from(Into::<[u64; 4]>::into(self))
33    }
34
35    /// Converts a byte slice to a Scalar using a hash function, preventing collisions.
36    /// WARNING: Only up to 31 bytes (2^248 bits) are supported by `PoSQL` cryptographic
37    /// objects. This function masks off the last byte of the hash to ensure the result
38    /// fits in this range.
39    #[must_use]
40    fn from_byte_slice_via_hash(bytes: &[u8]) -> Self {
41        if bytes.is_empty() {
42            return Self::zero();
43        }
44
45        let mut hasher = tiny_keccak::Keccak::v256();
46        hasher.update(bytes);
47        let mut hashed_bytes = [0u8; 32];
48        hasher.finalize(&mut hashed_bytes);
49        let hashed_val =
50            U256::from_le_slice(&hashed_bytes).expect("32 bytes => guaranteed to parse as U256");
51        let masked_val = hashed_val & Self::CHALLENGE_MASK;
52        Self::from_wrapping(masked_val)
53    }
54}
55
56impl<S: Scalar> ScalarExt for S {}
57
/// Asserts that the associated constants of a `Scalar` implementation are consistent.
///
/// Generic over `S` and `pub(crate)`, so the tests of each concrete `Scalar` type can reuse it.
#[cfg(test)]
pub(crate) fn test_scalar_constants<S: Scalar>() {
    // The small named constants must agree with the integer conversions.
    assert_eq!(S::from(0), S::ZERO);
    assert_eq!(S::from(1), S::ONE);
    assert_eq!(S::from(2), S::TWO);
    assert_eq!(S::from(10), S::TEN);

    // MAX_SIGNED is the least upper bound of the non-negative half, i.e. -1/2 in the field.
    let one_half = S::TWO.inv().unwrap();
    assert_eq!(-one_half, S::MAX_SIGNED);

    // CHALLENGE_MASK must be a contiguous run of low one-bits (2^k - 1).
    let mask = S::CHALLENGE_MASK;
    let unused_high_bits = mask.leading_zeros();
    assert_eq!(mask, U256::MAX >> unused_high_bits);

    // The mask covers every non-negative value but stays below -1 (the largest element).
    assert!(S::MAX_SIGNED.into_u256_wrapping() < S::CHALLENGE_MASK);
    assert!((-S::ONE).into_u256_wrapping() > S::CHALLENGE_MASK);
}
75
#[cfg(test)]
mod tests {
    use super::*;
    use crate::base::scalar::{test_scalar::TestScalar, MontScalar};
    use bytemuck::cast;

    #[test]
    fn we_can_get_zero_from_zero_bytes() {
        assert_eq!(TestScalar::from_byte_slice_via_hash(&[]), TestScalar::ZERO);
    }

    #[test]
    fn we_can_get_scalar_from_hashed_bytes() {
        // Little-endian Keccak-256 digest of "abc" after `TestScalar::CHALLENGE_MASK` is
        // applied. The mask clears the most significant bits (here the high bits of the last
        // byte); the preceding bytes are the raw digest.
        let expected: [u8; 32] = [
            0x4e, 0x03, 0x65, 0x7a, 0xea, 0x45, 0xa9, 0x4f, 0xc7, 0xd4, 0x7b, 0xa8, 0x26, 0xc8,
            0xd6, 0x67, 0xc0, 0xd1, 0xe6, 0xe3, 0x3a, 0x64, 0xa0, 0x36, 0xec, 0x44, 0xf5, 0x8f,
            0xa1, 0x2d, 0x6c, 0x05,
        ];

        let scalar_from_bytes: TestScalar = TestScalar::from_byte_slice_via_hash(b"abc");

        // `cast` reinterprets the bytes using the target's native endianness; `u64::from_le`
        // then reads each limb as little-endian, so the reference is also correct on
        // big-endian targets. (A `from_le_bytes(x.to_le_bytes())` round-trip is a no-op.)
        let limbs_native: [u64; 4] = cast(expected);
        let limbs_le = limbs_native.map(u64::from_le);
        let scalar_from_ref = TestScalar::from(limbs_le);

        assert_eq!(
            scalar_from_bytes, scalar_from_ref,
            "The masked keccak v256 of 'abc' must match"
        );
    }

    #[test]
    fn we_can_compute_powers_of_10() {
        for i in 0..=u128::MAX.ilog10() {
            assert_eq!(
                TestScalar::pow10(u8::try_from(i).unwrap()),
                TestScalar::from(u128::pow(10, i))
            );
        }
        assert_eq!(
            TestScalar::pow10(76),
            MontScalar(ark_ff::MontFp!(
                "10000000000000000000000000000000000000000000000000000000000000000000000000000"
            ))
        );
    }

    #[test]
    fn scalar_comparison_works() {
        let zero = TestScalar::ZERO;
        let one = TestScalar::ONE;
        let two = TestScalar::TWO;
        let max = TestScalar::MAX_SIGNED;
        let min = max + one;
        assert_eq!(max.signed_cmp(&one), Ordering::Greater);
        assert_eq!(one.signed_cmp(&zero), Ordering::Greater);
        assert_eq!(zero.signed_cmp(&one), Ordering::Less);
        assert_eq!(one.signed_cmp(&one), Ordering::Equal);
        assert_eq!(min.signed_cmp(&zero), Ordering::Less);
        assert_eq!((two * max).signed_cmp(&zero), Ordering::Less);
        assert_eq!(two * max + one, zero);
    }
}