// proof_of_sql/base/arrow/scalar_and_i256_conversions.rs

use crate::base::{math, scalar::Scalar};
use arrow::datatypes::i256;

4const MIN_SUPPORTED_I256: i256 = i256::from_parts(
5    326_411_208_032_252_286_695_448_638_536_326_387_210,
6    -10_633_823_966_279_326_983_230_456_482_242_756_609,
7);
8const MAX_SUPPORTED_I256: i256 = i256::from_parts(
9    13_871_158_888_686_176_767_925_968_895_441_824_246,
10    10_633_823_966_279_326_983_230_456_482_242_756_608,
11);
12
13/// Converts a type implementing [Scalar] into an arrow i256
14pub fn convert_scalar_to_i256<S: Scalar>(val: &S) -> i256 {
15    let is_negative = val > &S::MAX_SIGNED;
16    let abs_scalar = if is_negative { -*val } else { *val };
17    let limbs: [u64; 4] = abs_scalar.into();
18
19    let low = u128::from(limbs[0]) | (u128::from(limbs[1]) << 64);
20    let high = i128::from(limbs[2]) | (i128::from(limbs[3]) << 64);
21
22    let abs_i256 = i256::from_parts(low, high);
23    if is_negative {
24        i256::wrapping_neg(abs_i256)
25    } else {
26        abs_i256
27    }
28}
29
30#[expect(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
31/// Converts an arrow i256 into limbed representation and then
32/// into a type implementing [Scalar]
33#[must_use]
34pub fn convert_i256_to_scalar<S: Scalar>(value: &i256) -> Option<S> {
35    // Check if value is within the bounds
36    if value < &MIN_SUPPORTED_I256 || value > &MAX_SUPPORTED_I256 {
37        None
38    } else {
39        // Prepare the absolute value for conversion
40        let abs_value = if value.is_negative() { -*value } else { *value };
41        let (low, high) = abs_value.to_parts();
42        let limbs = [
43            low as u64,
44            (low >> 64) as u64,
45            high as u64,
46            (high >> 64) as u64,
47        ];
48
49        // Convert limbs to Scalar and adjust for sign
50        let scalar: S = limbs.into();
51        Some(if value.is_negative() { -scalar } else { scalar })
52    }
53}
54
55#[expect(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
56impl From<i256> for math::i256::I256 {
57    fn from(value: i256) -> Self {
58        let (low, high) = value.to_parts();
59        Self::new([
60            low as u64,
61            (low >> 64) as u64,
62            high as u64,
63            (high >> 64) as u64,
64        ])
65    }
66}
67
#[cfg(test)]
mod tests {
    // Unit tests for the scalar <-> arrow `i256` conversions, using `TestScalar`.

    use super::*;
    use crate::base::scalar::{test_scalar::TestScalar, Scalar};
    use num_traits::Zero;
    use rand::RngCore;

    /// Generate a random i256 within a supported range. Values generated by this function will
    /// fit into the i256 but will not exceed 252 bits of width.
    fn random_i256<R: RngCore + ?Sized>(rng: &mut R) -> i256 {
        use rand::Rng;
        let max_signed_as_parts: (u128, i128) =
            convert_scalar_to_i256(&TestScalar::MAX_SIGNED).to_parts();

        // Generate a random high part
        let high: i128 = rng.gen_range(-max_signed_as_parts.1..=max_signed_as_parts.1);

        // Generate a random low part, adjusted based on the high part
        let low: u128 = if high < max_signed_as_parts.1 {
            rng.gen()
        } else {
            rng.gen_range(0..=max_signed_as_parts.0)
        };

        i256::from_parts(low, high)
    }

    impl TryFrom<i256> for TestScalar {
        type Error = ();

        // Must fit inside 252 bits and so requires fallible
        fn try_from(value: i256) -> Result<Self, ()> {
            convert_i256_to_scalar(&value).ok_or(())
        }
    }

    impl From<TestScalar> for i256 {
        fn from(value: TestScalar) -> Self {
            convert_scalar_to_i256(&value)
        }
    }

    #[test]
    fn test_testscalar_to_i256_conversion() {
        let positive_scalar = TestScalar::from(12345);
        let expected_i256 = i256::from(12345);
        assert_eq!(i256::from(positive_scalar), expected_i256);

        let negative_scalar = TestScalar::from(-12345);
        let expected_i256 = i256::from(-12345);
        assert_eq!(i256::from(negative_scalar), expected_i256);

        // Compare against the independently written bounds; comparing against
        // `i256::from(...)` of the same scalar would be tautological.
        let max_scalar = TestScalar::MAX_SIGNED;
        assert_eq!(i256::from(max_scalar), MAX_SUPPORTED_I256);

        let min_scalar = -TestScalar::MAX_SIGNED;
        assert_eq!(i256::from(min_scalar), MIN_SUPPORTED_I256);

        let zero_scalar = TestScalar::zero();
        assert_eq!(i256::from(zero_scalar), i256::ZERO);
    }

    #[test]
    fn test_testscalar_i256_overflow_and_underflow() {
        // i256::MAX (2^255 - 1) overflows
        assert!(TestScalar::try_from(i256::MAX).is_err());

        // MAX_SIGNED + 1 overflows
        assert!(TestScalar::try_from(MAX_SUPPORTED_I256 + i256::from(1)).is_err());

        // -2^255 underflows
        assert!(i256::MIN < -(i256::from(TestScalar::MAX_SIGNED)));
        assert!(TestScalar::try_from(i256::MIN).is_err());

        // -MAX_SIGNED - 1 underflows
        assert!(TestScalar::try_from(MIN_SUPPORTED_I256 - i256::from(1)).is_err());
    }

    #[test]
    fn test_i256_testscalar_bounds_round_trip() {
        // Both inclusive bounds must convert and come back unchanged.
        for bound in [MIN_SUPPORTED_I256, MAX_SUPPORTED_I256] {
            let scalar = TestScalar::try_from(bound).expect("bound must be convertible");
            assert_eq!(i256::from(scalar), bound);
        }
    }

    #[test]
    fn test_i256_testscalar_negative() {
        // Test conversion from i256(-1) to TestScalar
        let neg_one_i256_testscalar = TestScalar::try_from(i256::from(-1));
        assert!(neg_one_i256_testscalar.is_ok());
        let neg_one_testscalar = TestScalar::from(-1);
        assert_eq!(neg_one_i256_testscalar.unwrap(), neg_one_testscalar);
    }

    #[test]
    fn test_i256_testscalar_zero() {
        // Test conversion from i256(0) to TestScalar
        let zero_i256_testscalar = TestScalar::try_from(i256::from(0));
        assert!(zero_i256_testscalar.is_ok());
        let zero_testscalar = TestScalar::zero();
        assert_eq!(zero_i256_testscalar.unwrap(), zero_testscalar);
    }

    #[test]
    fn test_i256_testscalar_positive() {
        // Test conversion from i256(42) to TestScalar
        let forty_two_i256_testscalar = TestScalar::try_from(i256::from(42));
        let forty_two_testscalar = TestScalar::from(42);
        assert_eq!(forty_two_i256_testscalar.unwrap(), forty_two_testscalar);
    }

    #[test]
    fn test_i256_testscalar_max_signed() {
        let max_signed = MAX_SUPPORTED_I256;
        // max signed value
        let max_signed_scalar = TestScalar::MAX_SIGNED;
        // Test conversion from i256 to TestScalar
        let i256_testscalar = TestScalar::try_from(max_signed);
        assert!(i256_testscalar.is_ok());
        assert_eq!(i256_testscalar.unwrap(), max_signed_scalar);
    }

    #[test]
    fn test_i256_testscalar_min_signed() {
        let min_signed = MIN_SUPPORTED_I256;
        let i256_testscalar = TestScalar::try_from(min_signed);
        // -MAX_SIGNED is ok; in the field it is represented as MAX_SIGNED + 1
        assert!(i256_testscalar.is_ok());
        assert_eq!(
            i256_testscalar.unwrap(),
            TestScalar::MAX_SIGNED + TestScalar::from(1)
        );
    }

    #[test]
    fn test_i256_testscalar_random() {
        let mut rng = rand::thread_rng();
        for _ in 0..1000 {
            let i256_value = random_i256(&mut rng);
            let scalar = TestScalar::try_from(i256_value).expect("Conversion failed");
            let back_to_i256 = i256::from(scalar);
            assert_eq!(i256_value, back_to_i256, "Round-trip conversion failed");
        }
    }

    #[expect(clippy::cast_sign_loss)]
    #[test]
    fn test_arrow_i256_to_posql_i256_conversion() {
        // Test zero
        assert_eq!(
            math::i256::I256::from(i256::ZERO),
            math::i256::I256::new([0, 0, 0, 0])
        );

        // Test positive values
        assert_eq!(
            math::i256::I256::from(i256::from(1)),
            math::i256::I256::new([1, 0, 0, 0])
        );
        assert_eq!(
            math::i256::I256::from(i256::from(2)),
            math::i256::I256::new([2, 0, 0, 0])
        );

        // Test negative values
        assert_eq!(
            math::i256::I256::from(i256::from(-1)),
            math::i256::I256::new([u64::MAX; 4])
        );
        assert_eq!(
            math::i256::I256::from(i256::from(-2)),
            math::i256::I256::new([u64::MAX - 1, u64::MAX, u64::MAX, u64::MAX])
        );

        // Test some boundary values
        assert_eq!(
            math::i256::I256::from(i256::MAX),
            math::i256::I256::new([u64::MAX, u64::MAX, u64::MAX, i64::MAX as u64])
        );
        assert_eq!(
            math::i256::I256::from(i256::MIN),
            math::i256::I256::new([0, 0, 0, i64::MIN as u64])
        );

        // Test other values
        assert_eq!(
            math::i256::I256::from(i256::from_parts(40, 20)),
            math::i256::I256::new([40, 0, 20, 0])
        );
        assert_eq!(
            math::i256::I256::from(i256::from_parts(20, -20)),
            math::i256::I256::new([20, 0, u64::MAX - 19, u64::MAX])
        );
    }
}