1use ark_bn254::Fr;
2use ark_ff::{BigInteger, PrimeField};
3use ethers_core::types::{Address, U256};
4
5use crate::error::CryptoError;
6use crate::poseidon::{IPoseidonHasher, NoxHasher};
7
8#[allow(clippy::must_use_candidate)]
10pub fn u256_to_fr(value: U256) -> Fr {
11 let mut bytes = [0u8; 32];
12 value.to_big_endian(&mut bytes);
13 Fr::from_be_bytes_mod_order(&bytes)
14}
15
16#[allow(clippy::must_use_candidate)]
18pub fn fr_to_u256(fr: Fr) -> U256 {
19 let bigint = fr.into_bigint();
20 let bytes = bigint.to_bytes_be();
21 U256::from_big_endian(&bytes)
22}
23
24#[allow(clippy::must_use_candidate)]
26pub fn to_noir_hex(value: U256) -> String {
27 format!("0x{value:064x}")
28}
29
30#[allow(clippy::must_use_candidate)]
32pub fn to_noir_decimal(value: U256) -> String {
33 value.to_string()
34}
35
36pub fn from_noir_hex(hex_str: &str) -> Result<U256, CryptoError> {
38 let clean = hex_str.trim().trim_start_matches("0x");
39 let padded = format!("{clean:0>64}");
40 let bytes = hex::decode(&padded).map_err(|_| CryptoError::FieldConversion)?;
41 Ok(U256::from_big_endian(&bytes))
42}
43
44#[must_use]
46pub fn poseidon_hash(inputs: &[U256]) -> U256 {
47 let hasher = NoxHasher::new();
48
49 let fr_inputs: Vec<_> = inputs.iter().map(|u| u256_to_fr(*u)).collect();
50 let result_fr = hasher.hash(&fr_inputs);
51 fr_to_u256(result_fr)
52}
53
54#[must_use]
56pub fn poseidon_hash_fr(inputs: &[Fr]) -> Fr {
57 let hasher = NoxHasher::new();
58 hasher.hash(inputs)
59}
60
61pub fn string_to_fr(text: &str) -> Result<U256, CryptoError> {
65 let bytes = text.as_bytes();
66 if bytes.len() > 32 {
67 return Err(CryptoError::InputTooLong {
68 max: 32,
69 got: bytes.len(),
70 });
71 }
72
73 let mut padded = [0u8; 32];
75 let start = 32 - bytes.len();
76 padded[start..].copy_from_slice(bytes);
77
78 let fr = Fr::from_be_bytes_mod_order(&padded);
79 let field_from_bytes = fr_to_u256(fr);
80 Ok(poseidon_hash(&[field_from_bytes]))
81}
82
83#[allow(clippy::must_use_candidate)]
85pub fn address_to_field(addr: Address) -> U256 {
86 U256::from_big_endian(addr.as_bytes())
87}
88
89#[allow(clippy::must_use_candidate)]
91pub fn field_to_address(field: U256) -> Address {
92 let mut bytes = [0u8; 32];
93 field.to_big_endian(&mut bytes);
94 Address::from_slice(&bytes[12..32])
95}
96
97#[must_use]
99pub fn random_field() -> U256 {
100 use rand::RngCore;
101 let mut rng = rand::rngs::OsRng;
102 let mut bytes = [0u8; 32];
103 rng.fill_bytes(&mut bytes);
104
105 let fr = Fr::from_be_bytes_mod_order(&bytes);
106 fr_to_u256(fr)
107}
108
109#[must_use]
112pub fn random_bjj_scalar() -> U256 {
113 use num_bigint::BigUint;
114 use rand::RngCore;
115
116 #[allow(clippy::expect_used)]
117 static SUBGROUP_ORDER_BIGINT: std::sync::LazyLock<BigUint> = std::sync::LazyLock::new(|| {
118 BigUint::parse_bytes(crate::SUBGROUP_ORDER.as_bytes(), 10)
119 .expect("SUBGROUP_ORDER is a compile-time decimal constant")
120 });
121
122 let mut rng = rand::rngs::OsRng;
123 let mut bytes = [0u8; 32];
124 rng.fill_bytes(&mut bytes);
125
126 let val = BigUint::from_bytes_be(&bytes);
127 let reduced = val % &*SUBGROUP_ORDER_BIGINT;
128 let be_bytes = reduced.to_bytes_be();
129 let mut padded = [0u8; 32];
130 let start = 32_usize.saturating_sub(be_bytes.len());
131 let copy_len = be_bytes.len().min(32);
132 padded[start..start + copy_len].copy_from_slice(&be_bytes[..copy_len]);
133 U256::from_big_endian(&padded)
134}
135
136pub fn serialize_fr<S>(field: &ark_bn254::Fr, serializer: S) -> Result<S::Ok, S::Error>
138where
139 S: serde::Serializer,
140{
141 let bytes = field.into_bigint().to_bytes_be();
142 serializer.serialize_str(&format!("0x{}", hex::encode(bytes)))
143}
144
145pub fn deserialize_fr<'de, D>(deserializer: D) -> Result<ark_bn254::Fr, D::Error>
147where
148 D: serde::Deserializer<'de>,
149{
150 use serde::Deserialize;
151 let s = String::deserialize(deserializer)?;
152 let clean_s = s.trim_start_matches("0x");
153 let bytes = hex::decode(clean_s).map_err(serde::de::Error::custom)?;
154
155 if bytes.len() > 32 {
156 return Err(serde::de::Error::custom("Field element exceeds 32 bytes"));
157 }
158
159 let mut padded = [0u8; 32];
160 padded[32 - bytes.len()..].copy_from_slice(&bytes);
161
162 let val = ark_bn254::Fr::from_be_bytes_mod_order(&padded);
163
164 let round_trip = val.into_bigint().to_bytes_be();
165 if round_trip != padded {
166 return Err(serde::de::Error::custom("Value exceeds field modulus"));
167 }
168
169 Ok(val)
170}
171
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs the real `deserialize_fr` through serde's in-memory string deserializer,
    /// so the tests exercise the function itself rather than a copy of its logic.
    fn deserialize_hex(s: &str) -> Result<Fr, serde::de::value::Error> {
        use serde::de::value::{Error, StrDeserializer};
        use serde::de::IntoDeserializer;

        let deserializer: StrDeserializer<'_, Error> = s.into_deserializer();
        deserialize_fr(deserializer)
    }

    #[test]
    fn test_u256_fr_roundtrip() {
        let original = U256::from(12345u64);
        let fr = u256_to_fr(original);
        let back = fr_to_u256(fr);
        assert_eq!(original, back);
    }

    #[test]
    fn test_u256_fr_large_value() {
        let bytes = [0xffu8; 32];
        let original = U256::from_big_endian(&bytes);
        let fr = u256_to_fr(original);
        let back = fr_to_u256(fr);
        assert!(back < original);
    }

    #[test]
    fn test_poseidon_hash_deterministic() {
        let inputs = [U256::from(1), U256::from(2), U256::from(3)];
        let hash1 = poseidon_hash(&inputs);
        let hash2 = poseidon_hash(&inputs);
        assert_eq!(hash1, hash2);
        assert!(!hash1.is_zero());
    }

    #[test]
    fn test_random_field_valid() {
        let a = random_field();
        let b = random_field();
        assert_ne!(a, b);

        let fr = u256_to_fr(a);
        let back = fr_to_u256(fr);
        assert_eq!(a, back);
    }

    #[test]
    fn test_string_to_fr_deterministic() {
        let result1 = string_to_fr("hisoka.enc_key").unwrap();
        let result2 = string_to_fr("hisoka.enc_key").unwrap();
        assert_eq!(result1, result2);
    }

    #[test]
    fn test_string_to_fr_different_inputs() {
        let key = string_to_fr("hisoka.enc_key").unwrap();
        let iv = string_to_fr("hisoka.enc_iv").unwrap();
        assert_ne!(key, iv);
    }

    #[test]
    fn test_string_to_fr_too_long() {
        let long = "a".repeat(33);
        let result = string_to_fr(&long);
        assert!(result.is_err());
        assert_eq!(
            result.unwrap_err(),
            CryptoError::InputTooLong { max: 32, got: 33 }
        );
    }

    #[test]
    fn test_string_to_fr_exactly_32() {
        let exact = "a".repeat(32);
        let result = string_to_fr(&exact);
        assert!(result.is_ok());
    }

    #[test]
    fn test_string_to_fr_empty() {
        let result = string_to_fr("");
        assert!(result.is_ok());
        assert!(!result.unwrap().is_zero());
    }

    #[test]
    fn test_to_noir_hex_format() {
        let val = U256::from(255u64);
        let hex = to_noir_hex(val);
        assert!(hex.starts_with("0x"));
        assert_eq!(hex.len(), 66);
        assert!(hex.ends_with("ff"));
    }

    #[test]
    fn test_to_noir_hex_zero() {
        let hex = to_noir_hex(U256::zero());
        assert_eq!(hex, format!("0x{}", "0".repeat(64)));
    }

    #[test]
    fn test_noir_hex_roundtrip() {
        let values = [
            U256::zero(),
            U256::from(1u64),
            U256::from(u64::MAX),
            U256::from(42u64),
        ];
        for val in values {
            let hex = to_noir_hex(val);
            let recovered = from_noir_hex(&hex).expect("roundtrip");
            assert_eq!(val, recovered, "roundtrip mismatch for {val}");
        }
    }

    #[test]
    fn test_from_noir_hex_no_prefix() {
        let result = from_noir_hex("ff");
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), U256::from(255u64));
    }

    #[test]
    fn test_from_noir_hex_invalid() {
        let result = from_noir_hex("0xZZZZ");
        assert!(result.is_err());
    }

    #[test]
    fn test_to_noir_decimal() {
        assert_eq!(to_noir_decimal(U256::from(42u64)), "42");
        assert_eq!(to_noir_decimal(U256::zero()), "0");
    }

    #[test]
    fn test_address_roundtrip() {
        let addr = Address::from_slice(&[0xABu8; 20]);
        let field = address_to_field(addr);
        let recovered = field_to_address(field);
        assert_eq!(addr, recovered);
    }

    #[test]
    fn test_address_zero() {
        let addr = Address::zero();
        let field = address_to_field(addr);
        assert!(field.is_zero());
        let recovered = field_to_address(field);
        assert_eq!(addr, recovered);
    }

    #[test]
    fn test_address_known_value() {
        let addr_bytes = hex::decode("d8dA6BF26964aF9D7eEd9e03E53415D37aA96045").unwrap();
        let addr = Address::from_slice(&addr_bytes);
        let field = address_to_field(addr);
        let recovered = field_to_address(field);
        assert_eq!(addr, recovered);
    }

    #[test]
    fn test_serialize_deserialize_fr_roundtrip() {
        let values = [
            U256::zero(),
            U256::from(1u64),
            U256::from(999999u64),
            U256::from(u64::MAX),
        ];
        for val in values {
            let fr = u256_to_fr(val);

            // The wire format `serialize_fr` produces: "0x" + canonical 32-byte BE hex.
            let hex_str = format!("0x{}", hex::encode(fr.into_bigint().to_bytes_be()));
            assert!(hex_str.starts_with("0x"));
            assert_eq!(hex_str.len(), 66);

            let recovered = deserialize_hex(&hex_str)
                .unwrap_or_else(|e| panic!("deserialize failed for U256={val}: {e:?}"));
            assert_eq!(fr, recovered, "Fr roundtrip mismatch for U256={val}");
        }
    }

    #[test]
    fn test_deserialize_fr_short_and_unprefixed() {
        let one = u256_to_fr(U256::from(1u64));
        assert_eq!(deserialize_hex("0x01").expect("short hex"), one);
        assert_eq!(deserialize_hex("01").expect("unprefixed hex"), one);
    }

    #[test]
    fn test_deserialize_fr_over_modulus() {
        // All-ones is far above the modulus and must be rejected, not reduced.
        assert!(deserialize_hex(&format!("0x{}", "ff".repeat(32))).is_err());

        // The BN254 scalar modulus itself is non-canonical.
        let modulus = "0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001";
        assert!(deserialize_hex(modulus).is_err());

        // modulus - 1 is the largest canonical element.
        let modulus_minus_one =
            "0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000000";
        let expected = u256_to_fr(
            U256::from_dec_str(
                "21888242871839275222246405745257275088548364400416034343698204186575808495616",
            )
            .unwrap(),
        );
        assert_eq!(
            deserialize_hex(modulus_minus_one).expect("p - 1 is canonical"),
            expected
        );
    }

    #[test]
    fn test_deserialize_fr_rejects_malformed() {
        // 33 bytes: too wide for a field element.
        assert!(deserialize_hex(&format!("0x{}", "00".repeat(33))).is_err());
        // Non-hex characters.
        assert!(deserialize_hex("0xZZ").is_err());
        // Odd number of hex digits.
        assert!(deserialize_hex("0x123").is_err());
    }

    #[test]
    fn test_u256_to_fr_zero() {
        let fr = u256_to_fr(U256::zero());
        let back = fr_to_u256(fr);
        assert!(back.is_zero());
    }

    #[test]
    fn test_u256_to_fr_max_valid() {
        let modulus_minus_one = U256::from_dec_str(
            "21888242871839275222246405745257275088548364400416034343698204186575808495616",
        )
        .unwrap();
        let fr = u256_to_fr(modulus_minus_one);
        let back = fr_to_u256(fr);
        assert_eq!(modulus_minus_one, back);
    }

    #[test]
    fn test_poseidon_hash_single_input() {
        let hash = poseidon_hash(&[U256::from(42)]);
        assert!(!hash.is_zero());
    }

    #[test]
    fn test_poseidon_hash_collision_resistance() {
        let h1 = poseidon_hash(&[U256::from(1)]);
        let h2 = poseidon_hash(&[U256::from(2)]);
        assert_ne!(h1, h2);
    }

    #[test]
    fn test_poseidon_hash_multi_block() {
        let inputs: Vec<_> = (0..10).map(U256::from).collect();
        let hash = poseidon_hash(&inputs);
        assert!(!hash.is_zero());

        let mut reversed = inputs.clone();
        reversed.reverse();
        let hash2 = poseidon_hash(&reversed);
        assert_ne!(hash, hash2);
    }

    #[test]
    fn test_random_bjj_scalar_in_subgroup_order() {
        use num_bigint::BigUint;

        let subgroup_order = BigUint::parse_bytes(
            b"2736030358979909402780800718157159386076813972158567259200215660948447373041",
            10,
        )
        .expect("valid decimal constant");

        for _ in 0..1000 {
            let scalar = random_bjj_scalar();
            let mut bytes = [0u8; 32];
            scalar.to_big_endian(&mut bytes);
            let val = BigUint::from_bytes_be(&bytes);
            assert!(val < subgroup_order, "got {val} >= subgroup order");
        }
    }

    #[test]
    fn test_random_field_valid_multiple() {
        let mut seen = std::collections::HashSet::new();
        for _ in 0..20 {
            let val = random_field();
            let fr = u256_to_fr(val);
            let back = fr_to_u256(fr);
            assert_eq!(val, back);
            seen.insert(val);
        }
        assert!(
            seen.len() >= 15,
            "only {} unique values in 20 draws",
            seen.len()
        );
    }
}