Skip to main content

darkpool_crypto/
field.rs

1use ark_bn254::Fr;
2use ark_ff::{BigInteger, PrimeField};
3use ethers_core::types::{Address, U256};
4
5use crate::error::CryptoError;
6use crate::poseidon::{IPoseidonHasher, NoxHasher};
7
8/// Convert ethers U256 to `ark_bn254::Fr`. Values >= modulus are silently reduced.
9#[allow(clippy::must_use_candidate)]
10pub fn u256_to_fr(value: U256) -> Fr {
11    let mut bytes = [0u8; 32];
12    value.to_big_endian(&mut bytes);
13    Fr::from_be_bytes_mod_order(&bytes)
14}
15
16/// Convert `ark_bn254::Fr` to ethers U256.
17#[allow(clippy::must_use_candidate)]
18pub fn fr_to_u256(fr: Fr) -> U256 {
19    let bigint = fr.into_bigint();
20    let bytes = bigint.to_bytes_be();
21    U256::from_big_endian(&bytes)
22}
23
24/// Convert U256 to Noir-compatible hex string (0x-prefixed, 64 chars, lowercase).
25#[allow(clippy::must_use_candidate)]
26pub fn to_noir_hex(value: U256) -> String {
27    format!("0x{value:064x}")
28}
29
30/// Convert U256 to Noir decimal string.
31#[allow(clippy::must_use_candidate)]
32pub fn to_noir_decimal(value: U256) -> String {
33    value.to_string()
34}
35
36/// Parse a Noir hex string (0x-prefixed or raw) back to U256.
37pub fn from_noir_hex(hex_str: &str) -> Result<U256, CryptoError> {
38    let clean = hex_str.trim().trim_start_matches("0x");
39    let padded = format!("{clean:0>64}");
40    let bytes = hex::decode(&padded).map_err(|_| CryptoError::FieldConversion)?;
41    Ok(U256::from_big_endian(&bytes))
42}
43
44/// Poseidon2 hash over U256 values. Output matches `std::hash::poseidon2` in Noir.
45#[must_use]
46pub fn poseidon_hash(inputs: &[U256]) -> U256 {
47    let hasher = NoxHasher::new();
48
49    let fr_inputs: Vec<_> = inputs.iter().map(|u| u256_to_fr(*u)).collect();
50    let result_fr = hasher.hash(&fr_inputs);
51    fr_to_u256(result_fr)
52}
53
54/// Poseidon hash on `Fr` directly, avoiding U256 roundtrips.
55#[must_use]
56pub fn poseidon_hash_fr(inputs: &[Fr]) -> Fr {
57    let hasher = NoxHasher::new();
58    hasher.hash(inputs)
59}
60
61/// Convert string to field element: left-pad to 32 bytes, interpret as Fr, then Poseidon hash.
62///
63/// Must match TypeScript `stringToFr` exactly for KDF domain separation.
64pub fn string_to_fr(text: &str) -> Result<U256, CryptoError> {
65    let bytes = text.as_bytes();
66    if bytes.len() > 32 {
67        return Err(CryptoError::InputTooLong {
68            max: 32,
69            got: bytes.len(),
70        });
71    }
72
73    // Left-pad to 32 bytes
74    let mut padded = [0u8; 32];
75    let start = 32 - bytes.len();
76    padded[start..].copy_from_slice(bytes);
77
78    let fr = Fr::from_be_bytes_mod_order(&padded);
79    let field_from_bytes = fr_to_u256(fr);
80    Ok(poseidon_hash(&[field_from_bytes]))
81}
82
83/// Convert an Ethereum address to a field element.
84#[allow(clippy::must_use_candidate)]
85pub fn address_to_field(addr: Address) -> U256 {
86    U256::from_big_endian(addr.as_bytes())
87}
88
89/// Convert a field element back to an Ethereum address (last 20 bytes).
90#[allow(clippy::must_use_candidate)]
91pub fn field_to_address(field: U256) -> Address {
92    let mut bytes = [0u8; 32];
93    field.to_big_endian(&mut bytes);
94    Address::from_slice(&bytes[12..32])
95}
96
97/// Generate a random BN254 scalar field element.
98#[must_use]
99pub fn random_field() -> U256 {
100    use rand::RngCore;
101    let mut rng = rand::rngs::OsRng;
102    let mut bytes = [0u8; 32];
103    rng.fill_bytes(&mut bytes);
104
105    let fr = Fr::from_be_bytes_mod_order(&bytes);
106    fr_to_u256(fr)
107}
108
109/// Random BJJ scalar (mod subgroup order L, ~2^251). Required for Noir circuits
110/// where `ScalarField::<63>` needs values < 2^252; `random_field()` would fail ~67% of the time.
111#[must_use]
112pub fn random_bjj_scalar() -> U256 {
113    use num_bigint::BigUint;
114    use rand::RngCore;
115
116    #[allow(clippy::expect_used)]
117    static SUBGROUP_ORDER_BIGINT: std::sync::LazyLock<BigUint> = std::sync::LazyLock::new(|| {
118        BigUint::parse_bytes(crate::SUBGROUP_ORDER.as_bytes(), 10)
119            .expect("SUBGROUP_ORDER is a compile-time decimal constant")
120    });
121
122    let mut rng = rand::rngs::OsRng;
123    let mut bytes = [0u8; 32];
124    rng.fill_bytes(&mut bytes);
125
126    let val = BigUint::from_bytes_be(&bytes);
127    let reduced = val % &*SUBGROUP_ORDER_BIGINT;
128    let be_bytes = reduced.to_bytes_be();
129    let mut padded = [0u8; 32];
130    let start = 32_usize.saturating_sub(be_bytes.len());
131    let copy_len = be_bytes.len().min(32);
132    padded[start..start + copy_len].copy_from_slice(&be_bytes[..copy_len]);
133    U256::from_big_endian(&padded)
134}
135
136/// Serialize `Fr` as `"0x"` + 64 lowercase hex chars (big-endian).
137pub fn serialize_fr<S>(field: &ark_bn254::Fr, serializer: S) -> Result<S::Ok, S::Error>
138where
139    S: serde::Serializer,
140{
141    let bytes = field.into_bigint().to_bytes_be();
142    serializer.serialize_str(&format!("0x{}", hex::encode(bytes)))
143}
144
145/// Deserialize `Fr` from hex string. Rejects over-modulus values via round-trip check.
146pub fn deserialize_fr<'de, D>(deserializer: D) -> Result<ark_bn254::Fr, D::Error>
147where
148    D: serde::Deserializer<'de>,
149{
150    use serde::Deserialize;
151    let s = String::deserialize(deserializer)?;
152    let clean_s = s.trim_start_matches("0x");
153    let bytes = hex::decode(clean_s).map_err(serde::de::Error::custom)?;
154
155    if bytes.len() > 32 {
156        return Err(serde::de::Error::custom("Field element exceeds 32 bytes"));
157    }
158
159    let mut padded = [0u8; 32];
160    padded[32 - bytes.len()..].copy_from_slice(&bytes);
161
162    let val = ark_bn254::Fr::from_be_bytes_mod_order(&padded);
163
164    let round_trip = val.into_bigint().to_bytes_be();
165    if round_trip != padded {
166        return Err(serde::de::Error::custom("Value exceeds field modulus"));
167    }
168
169    Ok(val)
170}
171
#[cfg(test)]
mod tests {
    use super::*;

    /// Drive [`deserialize_fr`] with a bare string via serde's built-in
    /// `StrDeserializer`, so these tests need no data-format crate.
    fn deserialize_fr_str(s: &str) -> Result<Fr, serde::de::value::Error> {
        use serde::de::IntoDeserializer;
        let de: serde::de::value::StrDeserializer<'_, serde::de::value::Error> =
            s.into_deserializer();
        deserialize_fr(de)
    }

    #[test]
    fn test_u256_fr_roundtrip() {
        let original = U256::from(12345u64);
        let fr = u256_to_fr(original);
        let back = fr_to_u256(fr);
        assert_eq!(original, back);
    }

    #[test]
    fn test_u256_fr_large_value() {
        let bytes = [0xffu8; 32];
        let original = U256::from_big_endian(&bytes);
        let fr = u256_to_fr(original);
        let back = fr_to_u256(fr);
        assert!(back < original);
    }

    #[test]
    fn test_poseidon_hash_deterministic() {
        let inputs = [U256::from(1), U256::from(2), U256::from(3)];
        let hash1 = poseidon_hash(&inputs);
        let hash2 = poseidon_hash(&inputs);
        assert_eq!(hash1, hash2);
        assert!(!hash1.is_zero());
    }

    #[test]
    fn test_random_field_valid() {
        let a = random_field();
        let b = random_field();
        assert_ne!(a, b);

        let fr = u256_to_fr(a);
        let back = fr_to_u256(fr);
        assert_eq!(a, back);
    }

    #[test]
    fn test_string_to_fr_deterministic() {
        let result1 = string_to_fr("hisoka.enc_key").unwrap();
        let result2 = string_to_fr("hisoka.enc_key").unwrap();
        assert_eq!(result1, result2);
    }

    #[test]
    fn test_string_to_fr_different_inputs() {
        let key = string_to_fr("hisoka.enc_key").unwrap();
        let iv = string_to_fr("hisoka.enc_iv").unwrap();
        assert_ne!(key, iv);
    }

    #[test]
    fn test_string_to_fr_too_long() {
        let long = "a".repeat(33);
        let result = string_to_fr(&long);
        assert!(result.is_err());
        assert_eq!(
            result.unwrap_err(),
            CryptoError::InputTooLong { max: 32, got: 33 }
        );
    }

    #[test]
    fn test_string_to_fr_exactly_32() {
        let exact = "a".repeat(32);
        let result = string_to_fr(&exact);
        assert!(result.is_ok());
    }

    #[test]
    fn test_string_to_fr_empty() {
        let result = string_to_fr("");
        assert!(result.is_ok());
        assert!(!result.unwrap().is_zero());
    }

    #[test]
    fn test_to_noir_hex_format() {
        let val = U256::from(255u64);
        let hex = to_noir_hex(val);
        assert!(hex.starts_with("0x"));
        assert_eq!(hex.len(), 66);
        assert!(hex.ends_with("ff"));
    }

    #[test]
    fn test_to_noir_hex_zero() {
        let hex = to_noir_hex(U256::zero());
        assert_eq!(hex, format!("0x{}", "0".repeat(64)));
    }

    #[test]
    fn test_noir_hex_roundtrip() {
        let values = [
            U256::zero(),
            U256::from(1u64),
            U256::from(u64::MAX),
            U256::from(42u64),
        ];
        for val in values {
            let hex = to_noir_hex(val);
            let recovered = from_noir_hex(&hex).expect("roundtrip");
            assert_eq!(val, recovered, "roundtrip mismatch for {val}");
        }
    }

    #[test]
    fn test_from_noir_hex_no_prefix() {
        let result = from_noir_hex("ff");
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), U256::from(255u64));
    }

    #[test]
    fn test_from_noir_hex_invalid() {
        let result = from_noir_hex("0xZZZZ");
        assert!(result.is_err());
    }

    #[test]
    fn test_to_noir_decimal() {
        assert_eq!(to_noir_decimal(U256::from(42u64)), "42");
        assert_eq!(to_noir_decimal(U256::zero()), "0");
    }

    #[test]
    fn test_address_roundtrip() {
        let addr = Address::from_slice(&[0xABu8; 20]);
        let field = address_to_field(addr);
        let recovered = field_to_address(field);
        assert_eq!(addr, recovered);
    }

    #[test]
    fn test_address_zero() {
        let addr = Address::zero();
        let field = address_to_field(addr);
        assert!(field.is_zero());
        let recovered = field_to_address(field);
        assert_eq!(addr, recovered);
    }

    #[test]
    fn test_address_known_value() {
        let addr_bytes = hex::decode("d8dA6BF26964aF9D7eEd9e03E53415D37aA96045").unwrap();
        let addr = Address::from_slice(&addr_bytes);
        let field = address_to_field(addr);
        let recovered = field_to_address(field);
        assert_eq!(addr, recovered);
    }

    #[test]
    fn test_serialize_deserialize_fr_roundtrip() {
        let values = [
            U256::zero(),
            U256::from(1u64),
            U256::from(999999u64),
            U256::from(u64::MAX),
        ];
        for val in values {
            let fr = u256_to_fr(val);

            // Same encoding `serialize_fr` emits: "0x" + canonical big-endian hex.
            let hex_str = format!("0x{}", hex::encode(fr.into_bigint().to_bytes_be()));
            assert!(hex_str.starts_with("0x"));
            assert_eq!(hex_str.len(), 66);

            let recovered =
                deserialize_fr_str(&hex_str).expect("canonical encoding must decode");
            assert_eq!(fr, recovered, "Fr roundtrip mismatch for U256={val}");
        }
    }

    #[test]
    fn test_deserialize_fr_short_and_unprefixed() {
        let one = u256_to_fr(U256::from(1u64));
        assert_eq!(deserialize_fr_str("0x01").expect("short hex"), one);
        assert_eq!(deserialize_fr_str("01").expect("unprefixed hex"), one);
    }

    #[test]
    fn test_deserialize_fr_over_modulus() {
        // 2^256 - 1 and the modulus r itself both reduce mod r, so they must be rejected.
        assert!(deserialize_fr_str(&format!("0x{}", "ff".repeat(32))).is_err());
        assert!(deserialize_fr_str(
            "0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001"
        )
        .is_err());

        // r - 1 is the largest canonical value and must be accepted unchanged.
        let max = deserialize_fr_str(
            "0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000000",
        )
        .expect("r - 1 is canonical");
        let expected = U256::from_dec_str(
            "21888242871839275222246405745257275088548364400416034343698204186575808495616",
        )
        .unwrap();
        assert_eq!(fr_to_u256(max), expected);
    }

    #[test]
    fn test_deserialize_fr_too_long() {
        assert!(deserialize_fr_str(&format!("0x{}", "00".repeat(33))).is_err());
    }

    #[test]
    fn test_deserialize_fr_invalid_hex() {
        assert!(deserialize_fr_str("0xZZ").is_err());
    }

    #[test]
    fn test_u256_to_fr_zero() {
        let fr = u256_to_fr(U256::zero());
        let back = fr_to_u256(fr);
        assert!(back.is_zero());
    }

    #[test]
    fn test_u256_to_fr_max_valid() {
        let modulus_minus_one = U256::from_dec_str(
            "21888242871839275222246405745257275088548364400416034343698204186575808495616",
        )
        .unwrap();
        let fr = u256_to_fr(modulus_minus_one);
        let back = fr_to_u256(fr);
        assert_eq!(modulus_minus_one, back);
    }

    #[test]
    fn test_poseidon_hash_single_input() {
        let hash = poseidon_hash(&[U256::from(42)]);
        assert!(!hash.is_zero());
    }

    #[test]
    fn test_poseidon_hash_collision_resistance() {
        let h1 = poseidon_hash(&[U256::from(1)]);
        let h2 = poseidon_hash(&[U256::from(2)]);
        assert_ne!(h1, h2);
    }

    #[test]
    fn test_poseidon_hash_multi_block() {
        let inputs: Vec<_> = (0..10).map(U256::from).collect();
        let hash = poseidon_hash(&inputs);
        assert!(!hash.is_zero());

        let mut reversed = inputs.clone();
        reversed.reverse();
        let hash2 = poseidon_hash(&reversed);
        assert_ne!(hash, hash2);
    }

    #[test]
    fn test_random_bjj_scalar_in_subgroup_order() {
        use num_bigint::BigUint;

        let subgroup_order = BigUint::parse_bytes(
            b"2736030358979909402780800718157159386076813972158567259200215660948447373041",
            10,
        )
        .expect("valid decimal constant");

        for _ in 0..1000 {
            let scalar = random_bjj_scalar();
            let mut bytes = [0u8; 32];
            scalar.to_big_endian(&mut bytes);
            let val = BigUint::from_bytes_be(&bytes);
            assert!(val < subgroup_order, "got {val} >= subgroup order");
        }
    }

    #[test]
    fn test_random_field_valid_multiple() {
        let mut seen = std::collections::HashSet::new();
        for _ in 0..20 {
            let val = random_field();
            let fr = u256_to_fr(val);
            let back = fr_to_u256(fr);
            assert_eq!(val, back);
            seen.insert(val);
        }
        assert!(
            seen.len() >= 15,
            "only {} unique values in 20 draws",
            seen.len()
        );
    }
}