1use crate::error::CryptoError;
2use crate::field::{deserialize_fr, serialize_fr};
3use ark_bn254::Fr;
4use ark_ff::{BigInteger, Field, One, PrimeField, Zero};
5use rand::{CryptoRng, Rng};
6use serde::{Deserialize, Serialize};
7use std::str::FromStr;
8use zeroize::Zeroize;
9
10#[allow(clippy::expect_used)]
12pub(crate) static BJJ_A: std::sync::LazyLock<Fr> =
13 std::sync::LazyLock::new(|| Fr::from_str("168700").expect("BJJ parameter A is valid"));
14
15#[allow(clippy::expect_used)]
17pub(crate) static BJJ_D: std::sync::LazyLock<Fr> =
18 std::sync::LazyLock::new(|| Fr::from_str("168696").expect("BJJ parameter D is valid"));
19
/// Decimal x-coordinate of [`BASE8`], the base point that `SecretKey::public_key`
/// multiplies by the secret scalar.
pub const BASE8_X: &str =
    "5299619240641551281634865583518297030282874472190772894086521144482721001553";
/// Decimal y-coordinate of [`BASE8`].
pub const BASE8_Y: &str =
    "16950150798460657717958625567821834550301663161624707787222815936182638968203";

/// Decimal order of the subgroup used for key validation.
///
/// `PublicKey::from_coordinates` and `PublicKey::from_hex` reject any point `P` for
/// which `[SUBGROUP_ORDER] * P` is not the identity `(0, 1)`.
pub const SUBGROUP_ORDER: &str =
    "2736030358979909402780800718157159386076813972158567259200215660948447373041";
28
29#[allow(clippy::expect_used)]
30pub static BASE8: std::sync::LazyLock<PublicKey> = std::sync::LazyLock::new(|| PublicKey {
31 x: Fr::from_str(BASE8_X).expect("BASE8_X is valid"),
32 y: Fr::from_str(BASE8_Y).expect("BASE8_Y is valid"),
33});
34
35#[allow(clippy::expect_used)]
36static SUBGROUP_ORDER_LE: std::sync::LazyLock<Vec<u8>> = std::sync::LazyLock::new(|| {
37 let order = Fr::from_str(SUBGROUP_ORDER).expect("BJJ subgroup order is valid");
38 order.into_bigint().to_bytes_le()
39});
40
41#[allow(clippy::expect_used)]
42static HALF_MODULUS: std::sync::LazyLock<Fr> = std::sync::LazyLock::new(|| {
43 Fr::from_bigint(<Fr as PrimeField>::MODULUS_MINUS_ONE_DIV_TWO)
44 .expect("MODULUS_MINUS_ONE_DIV_TWO is valid")
45});
46
47#[must_use]
49pub fn u256_to_le_bytes(value: ethers_core::types::U256) -> [u8; 32] {
50 let mut bytes = [0u8; 32];
51 value.to_little_endian(&mut bytes);
52 bytes
53}
54
/// Affine point `(x, y)` on the Baby Jubjub twisted Edwards curve
/// `a*x^2 + y^2 = 1 + d*x^2*y^2`, with coordinates in the BN254 scalar field.
///
/// Validated constructors are `PublicKey::from_coordinates` and `PublicKey::from_hex`.
/// `PublicKey::new_unchecked` performs no checks.
///
/// NOTE(review): the derived `Deserialize` builds the point directly from the two
/// coordinates. It skips the on-curve and subgroup checks done by `from_coordinates`,
/// so deserialized keys from untrusted sources should be re-validated before use.
#[derive(Debug, Clone, PartialEq, Eq, Copy, Serialize, Deserialize)]
pub struct PublicKey {
    // Affine x-coordinate; (de)serialized via the crate's `serialize_fr` / `deserialize_fr`.
    #[serde(serialize_with = "serialize_fr", deserialize_with = "deserialize_fr")]
    x: Fr,
    // Affine y-coordinate; (de)serialized the same way as `x`.
    #[serde(serialize_with = "serialize_fr", deserialize_with = "deserialize_fr")]
    y: Fr,
}
62
63impl PublicKey {
64 #[inline]
65 #[must_use]
66 pub fn x(&self) -> Fr {
67 self.x
68 }
69
70 #[inline]
71 #[must_use]
72 pub fn y(&self) -> Fr {
73 self.y
74 }
75
76 #[inline]
78 #[must_use]
79 pub fn new_unchecked(x: Fr, y: Fr) -> Self {
80 Self { x, y }
81 }
82}
83
/// Secret scalar in the `ark_ed_on_bn254` scalar field.
///
/// The wrapped value is zeroized when the key is dropped, and `Debug` output is redacted.
/// The field is public, so code that copies the scalar out (it is `Copy`, as
/// `self.0.into_bigint()` through `&self` shows) holds copies that are NOT wiped.
///
/// NOTE(review): `PartialEq` is derived; presumably it is not constant-time. Avoid
/// comparing secret keys in timing-sensitive paths.
#[derive(Clone, PartialEq, Eq)]
pub struct SecretKey(pub ark_ed_on_bn254::Fr);
87
88impl std::fmt::Debug for SecretKey {
89 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
90 f.debug_tuple("SecretKey").field(&"[REDACTED]").finish()
91 }
92}
93
94impl SecretKey {
95 pub fn generate<R: CryptoRng + Rng>(rng: &mut R) -> Self {
96 use ark_std::UniformRand;
97 Self(ark_ed_on_bn254::Fr::rand(rng))
98 }
99
100 pub fn from_hex(hex_str: &str) -> Result<Self, CryptoError> {
101 let bytes = hex::decode(hex_str).map_err(|_| CryptoError::InvalidKey)?;
102 if bytes.len() != 32 {
103 return Err(CryptoError::InvalidKey);
104 }
105 let val = ark_ed_on_bn254::Fr::from_be_bytes_mod_order(&bytes);
106
107 let round_trip = val.into_bigint().to_bytes_be();
109 let mut padded = [0u8; 32];
110 padded.copy_from_slice(&bytes);
111 if round_trip != padded {
112 return Err(CryptoError::InvalidKey);
113 }
114
115 Ok(Self(val))
116 }
117
118 pub fn public_key(&self) -> Result<PublicKey, CryptoError> {
119 BASE8.mul_scalar(&self.0.into_bigint().to_bytes_le())
120 }
121
122 pub fn derive_shared_secret(&self, peer_pk: &PublicKey) -> Result<SharedSecret, CryptoError> {
123 let point = peer_pk.mul_scalar(&self.0.into_bigint().to_bytes_le())?;
124 Ok(SharedSecret {
125 x: point.x,
126 y: point.y,
127 })
128 }
129
130 #[must_use]
131 pub fn to_hex(&self) -> String {
132 let bytes = self.0.into_bigint().to_bytes_be();
133 hex::encode(bytes)
134 }
135}
136
137impl Zeroize for SecretKey {
138 fn zeroize(&mut self) {
139 self.0.zeroize();
140 }
141}
142
143impl Drop for SecretKey {
144 fn drop(&mut self) {
145 self.zeroize();
146 }
147}
148
/// ECDH shared point produced by `SecretKey::derive_shared_secret`.
///
/// The coordinates are private (read them with `x()` / `y()`). `Debug` output is
/// redacted, and each instance, including clones, zeroizes its coordinates on drop.
#[derive(Clone)]
pub struct SharedSecret {
    // Affine x-coordinate of the shared point.
    x: Fr,
    // Affine y-coordinate of the shared point.
    y: Fr,
}
154
155impl SharedSecret {
156 #[inline]
157 #[must_use]
158 pub fn x(&self) -> Fr {
159 self.x
160 }
161
162 #[inline]
163 #[must_use]
164 pub fn y(&self) -> Fr {
165 self.y
166 }
167}
168
169impl std::fmt::Debug for SharedSecret {
170 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
171 f.debug_struct("SharedSecret")
172 .field("x", &"[REDACTED]")
173 .field("y", &"[REDACTED]")
174 .finish()
175 }
176}
177
178impl SharedSecret {
179 #[must_use]
180 pub fn to_symmetric_key(&self) -> [u8; 32] {
181 let packed = PublicKey {
182 x: self.x,
183 y: self.y,
184 }
185 .to_bytes();
186 use sha2::{Digest, Sha256};
187 let mut hasher = Sha256::new();
188 hasher.update(packed);
189 hasher.finalize().into()
190 }
191}
192
193impl Zeroize for SharedSecret {
194 fn zeroize(&mut self) {
195 self.x.zeroize();
196 self.y.zeroize();
197 }
198}
199
200impl Drop for SharedSecret {
201 fn drop(&mut self) {
202 self.zeroize();
203 }
204}
205
206impl PublicKey {
207 pub fn from_coordinates(x: Fr, y: Fr) -> Result<Self, CryptoError> {
209 let point = Self { x, y };
210
211 if !point.is_on_curve() {
212 return Err(CryptoError::InvalidPoint);
213 }
214
215 let check = point.mul_scalar(&SUBGROUP_ORDER_LE)?;
216 if !check.x.is_zero() || !check.y.is_one() {
217 return Err(CryptoError::SubgroupCheckFailed);
218 }
219
220 Ok(point)
221 }
222
223 #[must_use]
225 pub fn is_on_curve(&self) -> bool {
226 let a = *BJJ_A;
227 let d = *BJJ_D;
228 let x2 = self.x.square();
229 let y2 = self.y.square();
230 let lhs = a * x2 + y2;
231 let rhs = Fr::one() + d * x2 * y2;
232 lhs == rhs
233 }
234
235 pub fn add(&self, other: &Self) -> Result<Self, CryptoError> {
237 let a = *BJJ_A;
238 let d = *BJJ_D;
239
240 let beta = self.x * other.y;
241 let gamma = self.y * other.x;
242 let tau = beta * gamma;
243 let dtau = d * tau;
244
245 let x3_denom_inv = (Fr::one() + dtau)
246 .inverse()
247 .ok_or(CryptoError::InvalidOperation)?;
248 let x3 = (beta + gamma) * x3_denom_inv;
249
250 let delta = (self.y - (a * self.x)) * (other.x + other.y);
251 let y3_denom_inv = (Fr::one() - dtau)
252 .inverse()
253 .ok_or(CryptoError::InvalidOperation)?;
254 let y3 = (delta + (a * beta) - gamma) * y3_denom_inv;
255
256 Ok(Self { x: x3, y: y3 })
257 }
258
259 pub fn mul_scalar(&self, scalar_le_bytes: &[u8]) -> Result<Self, CryptoError> {
261 use subtle::Choice;
262
263 let mut res = Self {
264 x: Fr::zero(),
265 y: Fr::one(),
266 };
267 let mut temp = *self;
268
269 for byte in scalar_le_bytes {
270 let mut b = *byte;
271 for _ in 0..8 {
272 let candidate = res.add(&temp)?;
273 let bit = Choice::from(b & 1);
274 res = Self::ct_select(&res, &candidate, bit);
275 temp = temp.add(&temp)?;
276 b >>= 1;
277 }
278 }
279 Ok(res)
280 }
281
282 fn ct_select(a: &Self, b: &Self, choice: subtle::Choice) -> Self {
284 use subtle::ConditionallySelectable;
285
286 let a_x_limbs = a.x.into_bigint().0;
287 let b_x_limbs = b.x.into_bigint().0;
288 let a_y_limbs = a.y.into_bigint().0;
289 let b_y_limbs = b.y.into_bigint().0;
290
291 let mut x_bytes = [0u8; 32];
292 let mut y_bytes = [0u8; 32];
293
294 for i in 0..4 {
295 let x_limb = u64::conditional_select(&a_x_limbs[i], &b_x_limbs[i], choice);
296 let y_limb = u64::conditional_select(&a_y_limbs[i], &b_y_limbs[i], choice);
297 x_bytes[i * 8..(i + 1) * 8].copy_from_slice(&x_limb.to_le_bytes());
298 y_bytes[i * 8..(i + 1) * 8].copy_from_slice(&y_limb.to_le_bytes());
299 }
300
301 Self {
302 x: Fr::from_le_bytes_mod_order(&x_bytes),
303 y: Fr::from_le_bytes_mod_order(&y_bytes),
304 }
305 }
306
307 #[must_use]
309 #[allow(clippy::expect_used)]
310 pub fn to_bytes(&self) -> [u8; 32] {
311 let mut y_bytes = self.y.into_bigint().to_bytes_le();
312 if self.x > *HALF_MODULUS {
313 y_bytes[31] |= 0x80;
314 }
315 y_bytes.try_into().expect("y_bytes is 32 bytes")
317 }
318
319 #[must_use]
320 pub fn to_hex(&self) -> String {
321 hex::encode(self.to_bytes())
322 }
323
324 pub fn from_hex(hex_str: &str) -> Result<Self, CryptoError> {
326 let bytes = hex::decode(hex_str).map_err(|_| CryptoError::InvalidKey)?;
327 if bytes.len() != 32 {
328 return Err(CryptoError::InvalidKey);
329 }
330
331 let mut y_bytes = bytes.clone();
332 let sign = (y_bytes[31] & 0x80) != 0;
333 y_bytes[31] &= 0x7F;
334
335 let y = Fr::from_le_bytes_mod_order(&y_bytes);
336 let a = *BJJ_A;
337 let d = *BJJ_D;
338 let y2 = y.square();
339
340 let x2 = (Fr::one() - y2) * (a - (d * y2)).inverse().ok_or(CryptoError::InvalidPoint)?;
341 let mut x = x2.sqrt().ok_or(CryptoError::InvalidPoint)?;
342
343 if (x > *HALF_MODULUS) != sign {
344 x = -x;
345 }
346
347 let point = Self { x, y };
348
349 let check = point.mul_scalar(&SUBGROUP_ORDER_LE)?;
350 if !check.x.is_zero() || !check.y.is_one() {
351 return Err(CryptoError::SubgroupCheckFailed);
352 }
353
354 Ok(point)
355 }
356}
357
#[cfg(test)]
mod tests {
    use super::*;

    /// The twisted Edwards identity element `(0, 1)`.
    fn identity() -> PublicKey {
        PublicKey {
            x: Fr::from(0),
            y: Fr::from(1),
        }
    }

    #[test]
    fn test_identity_element() {
        let scalar_zero = [0u8; 32];
        let result = BASE8
            .mul_scalar(&scalar_zero)
            .expect("mul by zero should succeed");
        assert_eq!(result, identity());
    }

    #[test]
    fn test_generator_mul_one() {
        let mut scalar_one = [0u8; 32];
        scalar_one[0] = 1;
        let result = BASE8
            .mul_scalar(&scalar_one)
            .expect("mul by one should succeed");
        assert_eq!(result.x, BASE8.x);
        assert_eq!(result.y, BASE8.y);
    }

    #[test]
    fn test_public_key_generation() {
        let mut rng = rand::thread_rng();
        let sk = SecretKey::generate(&mut rng);
        let pk = sk
            .public_key()
            .expect("public key derivation should succeed");
        assert_ne!(pk, identity());
    }

    #[test]
    fn test_public_key_serialization_roundtrip() {
        let mut rng = rand::thread_rng();
        let sk = SecretKey::generate(&mut rng);
        let pk = sk
            .public_key()
            .expect("public key derivation should succeed");
        let hex_str = pk.to_hex();
        let pk2 = PublicKey::from_hex(&hex_str).expect("from_hex should succeed");
        assert_eq!(pk, pk2);
    }

    #[test]
    fn test_base8_hex_roundtrip() {
        assert_eq!(PublicKey::from_hex(&BASE8.to_hex()), Ok(*BASE8));
    }

    #[test]
    fn test_secret_key_roundtrip() {
        let mut rng = rand::thread_rng();
        let sk = SecretKey::generate(&mut rng);
        let hex_str = sk.to_hex();
        let sk2 = SecretKey::from_hex(&hex_str).expect("from_hex should succeed");
        assert_eq!(sk, sk2);
    }

    #[test]
    fn test_secret_key_debug_is_redacted() {
        let mut rng = rand::thread_rng();
        let sk = SecretKey::generate(&mut rng);
        let rendered = format!("{sk:?}");
        assert!(rendered.contains("[REDACTED]"));
        assert!(!rendered.contains(&sk.to_hex()));
    }

    #[test]
    fn test_shared_secret_commutativity() {
        let mut rng = rand::thread_rng();
        let alice_sk = SecretKey::generate(&mut rng);
        let bob_sk = SecretKey::generate(&mut rng);

        let alice_pk = alice_sk.public_key().expect("alice pk");
        let bob_pk = bob_sk.public_key().expect("bob pk");

        let ss_alice = alice_sk.derive_shared_secret(&bob_pk).expect("alice ECDH");
        let ss_bob = bob_sk.derive_shared_secret(&alice_pk).expect("bob ECDH");

        assert_eq!(ss_alice.x, ss_bob.x);
        assert_eq!(ss_alice.y, ss_bob.y);
    }

    #[test]
    fn test_shared_secret_debug_is_redacted() {
        let mut rng = rand::thread_rng();
        let alice = SecretKey::generate(&mut rng);
        let bob_pk = SecretKey::generate(&mut rng).public_key().expect("bob pk");
        let ss = alice.derive_shared_secret(&bob_pk).expect("ECDH");
        let rendered = format!("{ss:?}");
        assert!(rendered.starts_with("SharedSecret"));
        assert!(rendered.contains("[REDACTED]"));
    }

    #[test]
    fn test_mul_scalar_returns_result() {
        let scalar = [0x42u8; 32];
        let result = BASE8
            .mul_scalar(&scalar)
            .expect("mul_scalar on the base point should succeed");
        assert!(result.is_on_curve());
    }

    #[test]
    fn test_from_hex_rejects_over_modulus() {
        // Subgroup order + 1, big-endian: one past the largest canonical scalar.
        let over_modulus = "060c89ce5c263405370a08b6d0302b0bab3eedb83920ee0a677297dc392126f2";
        let result = SecretKey::from_hex(over_modulus);
        assert!(result.is_err());
    }

    #[test]
    fn test_from_hex_rejects_wrong_length() {
        let short = "00".repeat(31);
        assert!(SecretKey::from_hex(&short).is_err());

        let long = "00".repeat(33);
        assert!(SecretKey::from_hex(&long).is_err());
    }

    #[test]
    fn test_point_addition_identity() {
        let result = BASE8
            .add(&identity())
            .expect("adding identity should succeed");
        assert_eq!(result.x, BASE8.x);
        assert_eq!(result.y, BASE8.y);
    }

    #[test]
    fn test_from_coordinates_valid_point() {
        assert_eq!(PublicKey::from_coordinates(BASE8.x, BASE8.y), Ok(*BASE8));
    }

    #[test]
    fn test_from_coordinates_off_curve() {
        let result = PublicKey::from_coordinates(Fr::from(1), Fr::from(2));
        assert_eq!(result, Err(CryptoError::InvalidPoint));
    }

    #[test]
    fn test_from_coordinates_zero_zero() {
        let result = PublicKey::from_coordinates(Fr::from(0), Fr::from(0));
        assert_eq!(result, Err(CryptoError::InvalidPoint));
    }

    #[test]
    fn test_from_coordinates_identity() {
        let result = PublicKey::from_coordinates(Fr::from(0), Fr::from(1));
        assert_eq!(result, Ok(identity()));
    }

    #[test]
    fn test_is_on_curve_base8() {
        assert!(BASE8.is_on_curve());
    }

    #[test]
    fn test_is_on_curve_identity() {
        assert!(identity().is_on_curve());
    }

    #[test]
    fn test_is_on_curve_invalid() {
        let bad = PublicKey {
            x: Fr::from(42),
            y: Fr::from(99),
        };
        assert!(!bad.is_on_curve());
    }

    #[test]
    fn test_derived_pk_is_on_curve() {
        let mut rng = rand::thread_rng();
        let sk = SecretKey::generate(&mut rng);
        let pk = sk.public_key().expect("pk derivation");
        assert!(pk.is_on_curve());
    }

    #[test]
    fn test_point_negation() {
        let neg_base8 = PublicKey {
            x: -BASE8.x,
            y: BASE8.y,
        };
        let result = BASE8.add(&neg_base8).expect("P + (-P) should succeed");
        assert_eq!(result, identity());
    }

    #[test]
    fn test_mul_by_subgroup_order() {
        let result = BASE8
            .mul_scalar(&SUBGROUP_ORDER_LE)
            .expect("mul by subgroup order should succeed");
        assert_eq!(result, identity());
    }

    #[test]
    fn test_point_addition_commutative() {
        let mut rng = rand::thread_rng();
        let sk1 = SecretKey::generate(&mut rng);
        let sk2 = SecretKey::generate(&mut rng);
        let p = sk1.public_key().expect("pk1");
        let q = sk2.public_key().expect("pk2");

        let pq = p.add(&q).expect("P + Q");
        let qp = q.add(&p).expect("Q + P");
        assert_eq!(pq, qp);
    }

    #[test]
    fn test_scalar_mul_double() {
        let mut scalar_two = [0u8; 32];
        scalar_two[0] = 2;
        let doubled = BASE8.mul_scalar(&scalar_two).expect("[2]*G");
        let added = BASE8.add(&*BASE8).expect("G + G");
        assert_eq!(doubled, added);
    }

    #[test]
    fn test_from_hex_empty() {
        let result = PublicKey::from_hex("");
        assert!(result.is_err());
    }

    #[test]
    fn test_from_hex_wrong_length() {
        let result = PublicKey::from_hex("aabb");
        assert!(result.is_err());
    }

    #[test]
    fn test_from_hex_invalid_chars() {
        let result = PublicKey::from_hex(&"zz".repeat(32));
        assert!(result.is_err());
    }

    #[test]
    fn test_from_hex_all_zeros() {
        let result = PublicKey::from_hex(&"00".repeat(32));
        assert!(result.is_err());
    }

    #[test]
    fn test_from_hex_all_ff() {
        let result = PublicKey::from_hex(&"ff".repeat(32));
        assert!(result.is_err());
    }

    #[test]
    fn test_public_key_roundtrip_stress() {
        let mut rng = rand::thread_rng();
        for _ in 0..10 {
            let sk = SecretKey::generate(&mut rng);
            let pk = sk.public_key().expect("pk");
            let hex = pk.to_hex();
            let recovered = PublicKey::from_hex(&hex).expect("roundtrip");
            assert_eq!(pk, recovered);
        }
    }

    #[test]
    fn test_secret_key_zero() {
        let zero = "00".repeat(32);
        let sk = SecretKey::from_hex(&zero);
        assert!(sk.is_ok());
    }

    #[test]
    fn test_secret_key_one() {
        let mut hex = "00".repeat(31);
        hex.push_str("01");
        let sk = SecretKey::from_hex(&hex).expect("one should be valid");
        let pk = sk.public_key().expect("pk from sk=1");
        assert_eq!(pk, *BASE8);
    }

    #[test]
    fn test_secret_key_invalid_hex() {
        let result = SecretKey::from_hex("not_valid_hex_at_all_!");
        assert!(result.is_err());
    }

    #[test]
    fn test_shared_secret_to_symmetric_key() {
        let mut rng = rand::thread_rng();
        let alice = SecretKey::generate(&mut rng);
        let bob = SecretKey::generate(&mut rng);
        let bob_pk = bob.public_key().expect("bob pk");

        let ss = alice.derive_shared_secret(&bob_pk).expect("ECDH");
        let key1 = ss.to_symmetric_key();
        let key2 = ss.to_symmetric_key();

        assert_eq!(key1.len(), 32);
        assert_eq!(key1, key2);
    }

    #[test]
    fn test_shared_secret_different_keys() {
        let mut rng = rand::thread_rng();
        let alice = SecretKey::generate(&mut rng);
        let bob = SecretKey::generate(&mut rng);
        let carol = SecretKey::generate(&mut rng);

        let bob_pk = bob.public_key().expect("bob pk");
        let carol_pk = carol.public_key().expect("carol pk");

        let ss_bob = alice.derive_shared_secret(&bob_pk).expect("ECDH bob");
        let ss_carol = alice.derive_shared_secret(&carol_pk).expect("ECDH carol");

        assert_ne!(ss_bob.to_symmetric_key(), ss_carol.to_symmetric_key());
    }

    #[test]
    fn test_u256_to_le_bytes_small() {
        let val = ethers_core::types::U256::from(0x0102u64);
        let le = u256_to_le_bytes(val);
        assert_eq!(le[0], 0x02);
        assert_eq!(le[1], 0x01);
        assert!(le[2..].iter().all(|&b| b == 0));
    }

    #[test]
    fn test_u256_to_le_bytes_length() {
        let val = ethers_core::types::U256::MAX;
        let le = u256_to_le_bytes(val);
        assert_eq!(le.len(), 32);
        assert!(le.iter().all(|&b| b == 0xff));
    }
}