jub_jub/
fp.rs

1use crate::error::Error;
2use bls_12_381::Fr;
3use serde::{Deserialize, Serialize};
4use zkstd::arithmetic::bits_256::*;
5use zkstd::common::*;
6use zkstd::macros::field::*;
7
/// Field modulus as four little-endian `u64` limbs (limb 0 is the least significant), i.e.
/// `0x0e7db4ea6533afa906673b0101343b00a6682093ccc81082d0970e5ed6f72cb7`.
const MODULUS: [u64; 4] = [
    0xd0970e5ed6f72cb7,
    0xa6682093ccc81082,
    0x06673b0101343b00,
    0x0e7db4ea6533afa9,
];

/// Generator parameter handed to `fft_field_operation!`: the raw limbs of the integer 2.
const GENERATOR: [u64; 4] = [2, 0, 0, 0];

/// Multiplicative generator handed to `fft_field_operation!`.
// NOTE(review): unlike `ROOT_OF_UNITY` (whose limbs are `MODULUS - R`, i.e. -1 in Montgomery
// form), these limbs are NOT in Montgomery form. If `Fp` arithmetic interprets limbs as
// Montgomery representations (as `from_bytes` does by multiplying by `R2`), this value encodes
// 2·R^{-1} mod MODULUS rather than 2. Confirm that the macro expects this representation and
// that 2 is really a multiplicative generator of this field.
const MULTIPLICATIVE_GENERATOR: Fp = Fp([2, 0, 0, 0]);
19
/// Montgomery radix: R = 2^256 mod MODULUS, as little-endian limbs.
const R: [u64; 4] = [
    0x25f80bb3b99607d9,
    0xf315d62f66b6e750,
    0x932514eeeb8814f4,
    0x09a6fc6f479155c6,
];

/// R^2 = 2^512 mod MODULUS, as little-endian limbs.
///
/// Montgomery-multiplying canonical limbs by `R2` converts them into Montgomery form
/// (this is how `from_bytes`, `from_hex` and `to_mont_form` use it).
const R2: [u64; 4] = [
    0x67719aa495e57731,
    0x51b0cef09ce3fc26,
    0x69dab7fac026e9a5,
    0x04f6547b8d127688,
];

/// R^3 = 2^768 mod MODULUS, as little-endian limbs.
///
/// Used by `from_hash` to give the upper 256 bits of a 512-bit input the weight 2^256.
const R3: [u64; 4] = [
    0xe0d6c6563d830544,
    0x323e3883598d0f85,
    0xf0fea3004c2e2ba8,
    0x05874f84946737ec,
];
43
/// Montgomery constant passed to `mont` / `mul` / `to_mont_form`; presumably
/// `-MODULUS^{-1} mod 2^64`, i.e. `MODULUS[0] * INV ≡ -1 (mod 2^64)` — confirm.
const INV: u64 = 0x1ba3a358ef788ef9;

/// 2-adicity of the field: `MODULUS - 1 = 2^S * t` with `t` odd. `MODULUS[0]` ends in
/// `...cb7`, so `MODULUS - 1` is divisible by 2 but not by 4.
const S: usize = 1;

/// Primitive `2^S`-th root of unity. With `S = 1` this is -1, stored in Montgomery form:
/// its limbs equal `MODULUS - R`.
const ROOT_OF_UNITY: Fp = Fp([
    0xaa9f02ab1d6124de,
    0xb3524a6466112932,
    0x7342261215ac260b,
    0x4d6b87b1da259e2,
]);
54
/// Element of the prime field with modulus `MODULUS`
/// (`0x0e7db4ea6533afa906673b0101343b00a6682093ccc81082d0970e5ed6f72cb7`), used with the
/// Twisted Edwards curve Jubjub.
///
/// NOTE(review): this modulus appears to be the order of Jubjub's prime-order subgroup, i.e.
/// Jubjub's *scalar* field. Jubjub's *base* field is the BLS12-381 scalar field. The previous
/// description called this the "base field" — confirm which is intended.
///
/// The four `u64` limbs are little-endian (limb 0 is least significant). They normally hold
/// the value in Montgomery form: `to_bytes` Montgomery-reduces before serializing, and
/// `from_bytes` multiplies by `R2` after parsing. Some helpers (`reduce`, `From<i8>`) instead
/// produce raw, non-Montgomery limbs.
#[derive(Clone, Copy, Decode, Encode, Serialize, Deserialize)]
pub struct Fp(pub(crate) [u64; 4]);
58
59impl SigUtils<32> for Fp {
60    fn to_bytes(self) -> [u8; Self::LENGTH] {
61        let tmp = self.montgomery_reduce();
62
63        let mut res = [0; Self::LENGTH];
64        res[0..8].copy_from_slice(&tmp[0].to_le_bytes());
65        res[8..16].copy_from_slice(&tmp[1].to_le_bytes());
66        res[16..24].copy_from_slice(&tmp[2].to_le_bytes());
67        res[24..32].copy_from_slice(&tmp[3].to_le_bytes());
68
69        res
70    }
71
72    fn from_bytes(bytes: [u8; Self::LENGTH]) -> Option<Self> {
73        // SBP-M1 review: apply proper error handling instead of `unwrap`
74        let l0 = u64::from_le_bytes(bytes[0..8].try_into().unwrap());
75        let l1 = u64::from_le_bytes(bytes[8..16].try_into().unwrap());
76        let l2 = u64::from_le_bytes(bytes[16..24].try_into().unwrap());
77        let l3 = u64::from_le_bytes(bytes[24..32].try_into().unwrap());
78
79        let (_, borrow) = sbb(l0, MODULUS[0], 0);
80        let (_, borrow) = sbb(l1, MODULUS[1], borrow);
81        let (_, borrow) = sbb(l2, MODULUS[2], borrow);
82        let (_, borrow) = sbb(l3, MODULUS[3], borrow);
83
84        if borrow & 1 == 1 {
85            Some(Self([l0, l1, l2, l3]) * Self(R2))
86        } else {
87            None
88        }
89    }
90}
91
92impl Fp {
93    pub const fn to_mont_form(val: [u64; 4]) -> Self {
94        Self(to_mont_form(val, R2, MODULUS, INV))
95    }
96
97    pub(crate) const fn montgomery_reduce(self) -> [u64; 4] {
98        mont(
99            [self.0[0], self.0[1], self.0[2], self.0[3], 0, 0, 0, 0],
100            MODULUS,
101            INV,
102        )
103    }
104
105    pub fn from_hex(hex: &str) -> Result<Fp, Error> {
106        let max_len = 64;
107        let hex = hex.strip_prefix("0x").unwrap_or(hex);
108        let length = hex.len();
109        if length > max_len {
110            return Err(Error::HexStringTooLong);
111        }
112        let hex_bytes = hex.as_bytes();
113
114        let mut hex: [[u8; 16]; 4] = [[0; 16]; 4];
115        for i in 0..max_len {
116            hex[i / 16][i % 16] = if i >= length {
117                0
118            } else {
119                match hex_bytes[length - i - 1] {
120                    48..=57 => hex_bytes[length - i - 1] - 48,
121                    65..=70 => hex_bytes[length - i - 1] - 55,
122                    97..=102 => hex_bytes[length - i - 1] - 87,
123                    _ => return Err(Error::HexStringInvalid),
124                }
125            };
126        }
127        let mut limbs: [u64; 4] = [0; 4];
128        for i in 0..hex.len() {
129            limbs[i] = Fp::bytes_to_u64(&hex[i]).unwrap();
130        }
131        Ok(Fp(mul(limbs, R2, MODULUS, INV)))
132    }
133
134    fn bytes_to_u64(bytes: &[u8; 16]) -> Result<u64, Error> {
135        let mut res: u64 = 0;
136        for (i, byte) in bytes.iter().enumerate() {
137            res += match byte {
138                0..=15 => 16u64.pow(i as u32) * (*byte as u64),
139                _ => return Err(Error::BytesInvalid),
140            }
141        }
142        Ok(res)
143    }
144
145    pub fn reduce(&self) -> Self {
146        Self(self.montgomery_reduce())
147    }
148
149    pub fn is_even(&self) -> bool {
150        self.0[0] % 2 == 0
151    }
152
153    pub fn from_hash(hash: &[u8; 64]) -> Self {
154        let d0 = Fp([
155            u64::from_le_bytes(hash[0..8].try_into().unwrap()),
156            u64::from_le_bytes(hash[8..16].try_into().unwrap()),
157            u64::from_le_bytes(hash[16..24].try_into().unwrap()),
158            u64::from_le_bytes(hash[24..32].try_into().unwrap()),
159        ]);
160        let d1 = Fp([
161            u64::from_le_bytes(hash[32..40].try_into().unwrap()),
162            u64::from_le_bytes(hash[40..48].try_into().unwrap()),
163            u64::from_le_bytes(hash[48..56].try_into().unwrap()),
164            u64::from_le_bytes(hash[56..64].try_into().unwrap()),
165        ]);
166        d0 * Fp(R2) + d1 * Fp(R3)
167    }
168
169    /// Compute the result from `Scalar (mod 2^k)`.
170    ///
171    /// # Panics
172    ///
173    /// If the given k is > 32 (5 bits) as the value gets
174    /// greater than the limb.
175    pub fn mod_2_pow_k(&self, k: u8) -> u8 {
176        (self.0[0] & ((1 << k) - 1)) as u8
177    }
178
179    /// Compute the result from `Scalar (mods k)`.
180    ///
181    /// # Panics
182    ///
183    /// If the given `k > 32 (5 bits)` || `k == 0` as the value gets
184    /// greater than the limb.
185    pub fn mods_2_pow_k(&self, w: u8) -> i8 {
186        assert!(w < 32u8);
187        let modulus = self.mod_2_pow_k(w) as i8;
188        let two_pow_w_minus_one = 1i8 << (w - 1);
189
190        match modulus >= two_pow_w_minus_one {
191            false => modulus,
192            true => modulus - ((1u8 << w) as i8),
193        }
194    }
195}
196
197impl From<i8> for Fp {
198    fn from(val: i8) -> Fp {
199        match (val >= 0, val < 0) {
200            (true, false) => Fp([val.unsigned_abs() as u64, 0u64, 0u64, 0u64]),
201            (false, true) => -Fp([val.unsigned_abs() as u64, 0u64, 0u64, 0u64]),
202            (_, _) => unreachable!(),
203        }
204    }
205}
206
207impl From<Fr> for Fp {
208    fn from(scalar: Fr) -> Fp {
209        let bls_scalar = Fp::from_bytes(scalar.to_bytes());
210
211        assert!(
212            bls_scalar.is_some(),
213            "Failed to convert a Scalar from Bls to Jubjub"
214        );
215
216        bls_scalar.unwrap()
217    }
218}
219
220impl From<Fp> for Fr {
221    fn from(scalar: Fp) -> Fr {
222        let bls_scalar = Fr::from_bytes(scalar.to_bytes());
223
224        assert!(
225            bls_scalar.is_some(),
226            "Failed to convert a Scalar from JubJub to BLS"
227        );
228
229        bls_scalar.unwrap()
230    }
231}
232
233/// wNAF expression computation over field
234pub fn compute_windowed_naf<F: FftField>(scalar: F, width: u8) -> [i8; 256] {
235    let mut k = scalar.reduce();
236    let mut i = 0;
237    let one = F::one().reduce();
238    let mut res = [0i8; 256];
239
240    while k >= one {
241        if !k.is_even() {
242            let ki = k.mods_2_pow_k(width);
243            res[i] = ki;
244            let k_ = match (ki >= 0, ki < 0) {
245                (true, false) => F::from([ki.unsigned_abs() as u64, 0u64, 0u64, 0u64]),
246                (false, true) => -F::from([ki.unsigned_abs() as u64, 0u64, 0u64, 0u64]),
247                (_, _) => unreachable!(),
248            };
249            k -= k_;
250        } else {
251            res[i] = 0i8;
252        };
253
254        k.divn(1u32);
255        i += 1;
256    }
257    res
258}
259
// Instantiate the field implementation for `Fp` from the constants defined in this file via
// zkstd's `fft_field_operation!` macro.
// NOTE(review): the macro is defined outside this file. Presumably it provides the `+`, `-`,
// `*`, negation, comparison and field-trait (e.g. `FftField`) impls that the code in this
// module relies on — confirm in `zkstd::macros::field`.
fft_field_operation!(
    Fp,
    MODULUS,
    GENERATOR,
    MULTIPLICATIVE_GENERATOR,
    INV,
    ROOT_OF_UNITY,
    R,
    R2,
    R3,
    S
);
272
#[cfg(test)]
mod tests {
    use super::*;
    // `paste` and `OsRng` are not referenced directly below; presumably `field_test!`
    // expands to code that needs them — confirm before removing.
    use paste::paste;
    use rand_core::OsRng;

    // Generic field-law tests from zkstd for `Fp`; `1000` is presumably the number of
    // iterations.
    field_test!(fp_field, Fp, 1000);

    /// `from_hex` on a 64-digit value larger than `MODULUS` must reduce it and return the
    /// Montgomery-form limbs given below.
    #[test]
    fn test_from_hex() {
        let a = Fp::from_hex("0x64774b84f38512bf6730d2a0f6b0f6241eabfffeb153ffffb9feffffffffaaab")
            .unwrap();
        assert_eq!(
            a,
            Fp([
                0x4ddc8f91e171cd75,
                0x9b925835a7d203fb,
                0x0cdb538ead47e463,
                0x01a19f85f00d79b8,
            ])
        )
    }
}