// darkpool_crypto/bjj.rs

1use crate::error::CryptoError;
2use crate::field::{deserialize_fr, serialize_fr};
3use ark_bn254::Fr;
4use ark_ff::{BigInteger, Field, One, PrimeField, Zero};
5use rand::{CryptoRng, Rng};
6use serde::{Deserialize, Serialize};
7use std::str::FromStr;
8use zeroize::Zeroize;
9
10/// `BabyJubJub` twisted Edwards parameter A = 168700.
11#[allow(clippy::expect_used)]
12pub(crate) static BJJ_A: std::sync::LazyLock<Fr> =
13    std::sync::LazyLock::new(|| Fr::from_str("168700").expect("BJJ parameter A is valid"));
14
15/// `BabyJubJub` twisted Edwards parameter D = 168696 (non-square, addition law is complete).
16#[allow(clippy::expect_used)]
17pub(crate) static BJJ_D: std::sync::LazyLock<Fr> =
18    std::sync::LazyLock::new(|| Fr::from_str("168696").expect("BJJ parameter D is valid"));
19
20pub const BASE8_X: &str =
21    "5299619240641551281634865583518297030282874472190772894086521144482721001553";
22pub const BASE8_Y: &str =
23    "16950150798460657717958625567821834550301663161624707787222815936182638968203";
24
25/// `BabyJubJub` subgroup order (decimal). Canonical source for all crates.
26pub const SUBGROUP_ORDER: &str =
27    "2736030358979909402780800718157159386076813972158567259200215660948447373041";
28
29#[allow(clippy::expect_used)]
30pub static BASE8: std::sync::LazyLock<PublicKey> = std::sync::LazyLock::new(|| PublicKey {
31    x: Fr::from_str(BASE8_X).expect("BASE8_X is valid"),
32    y: Fr::from_str(BASE8_Y).expect("BASE8_Y is valid"),
33});
34
35#[allow(clippy::expect_used)]
36static SUBGROUP_ORDER_LE: std::sync::LazyLock<Vec<u8>> = std::sync::LazyLock::new(|| {
37    let order = Fr::from_str(SUBGROUP_ORDER).expect("BJJ subgroup order is valid");
38    order.into_bigint().to_bytes_le()
39});
40
41#[allow(clippy::expect_used)]
42static HALF_MODULUS: std::sync::LazyLock<Fr> = std::sync::LazyLock::new(|| {
43    Fr::from_bigint(<Fr as PrimeField>::MODULUS_MINUS_ONE_DIV_TWO)
44        .expect("MODULUS_MINUS_ONE_DIV_TWO is valid")
45});
46
47/// Convert a U256 value to 32-byte little-endian representation for `mul_scalar`.
48#[must_use]
49pub fn u256_to_le_bytes(value: ethers_core::types::U256) -> [u8; 32] {
50    let mut bytes = [0u8; 32];
51    value.to_little_endian(&mut bytes);
52    bytes
53}
54
55#[derive(Debug, Clone, PartialEq, Eq, Copy, Serialize, Deserialize)]
56pub struct PublicKey {
57    #[serde(serialize_with = "serialize_fr", deserialize_with = "deserialize_fr")]
58    x: Fr,
59    #[serde(serialize_with = "serialize_fr", deserialize_with = "deserialize_fr")]
60    y: Fr,
61}
62
63impl PublicKey {
64    #[inline]
65    #[must_use]
66    pub fn x(&self) -> Fr {
67        self.x
68    }
69
70    #[inline]
71    #[must_use]
72    pub fn y(&self) -> Fr {
73        self.y
74    }
75
76    /// Caller must ensure the point is on the BJJ curve and in the prime-order subgroup.
77    #[inline]
78    #[must_use]
79    pub fn new_unchecked(x: Fr, y: Fr) -> Self {
80        Self { x, y }
81    }
82}
83
84/// `BabyJubJub` secret key. Zeroized on drop; no `Copy` to prevent silent duplication.
85#[derive(Clone, PartialEq, Eq)]
86pub struct SecretKey(pub ark_ed_on_bn254::Fr);
87
88impl std::fmt::Debug for SecretKey {
89    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
90        f.debug_tuple("SecretKey").field(&"[REDACTED]").finish()
91    }
92}
93
94impl SecretKey {
95    pub fn generate<R: CryptoRng + Rng>(rng: &mut R) -> Self {
96        use ark_std::UniformRand;
97        Self(ark_ed_on_bn254::Fr::rand(rng))
98    }
99
100    pub fn from_hex(hex_str: &str) -> Result<Self, CryptoError> {
101        let bytes = hex::decode(hex_str).map_err(|_| CryptoError::InvalidKey)?;
102        if bytes.len() != 32 {
103            return Err(CryptoError::InvalidKey);
104        }
105        let val = ark_ed_on_bn254::Fr::from_be_bytes_mod_order(&bytes);
106
107        // Reject over-modulus values that get silently reduced
108        let round_trip = val.into_bigint().to_bytes_be();
109        let mut padded = [0u8; 32];
110        padded.copy_from_slice(&bytes);
111        if round_trip != padded {
112            return Err(CryptoError::InvalidKey);
113        }
114
115        Ok(Self(val))
116    }
117
118    pub fn public_key(&self) -> Result<PublicKey, CryptoError> {
119        BASE8.mul_scalar(&self.0.into_bigint().to_bytes_le())
120    }
121
122    pub fn derive_shared_secret(&self, peer_pk: &PublicKey) -> Result<SharedSecret, CryptoError> {
123        let point = peer_pk.mul_scalar(&self.0.into_bigint().to_bytes_le())?;
124        Ok(SharedSecret {
125            x: point.x,
126            y: point.y,
127        })
128    }
129
130    #[must_use]
131    pub fn to_hex(&self) -> String {
132        let bytes = self.0.into_bigint().to_bytes_be();
133        hex::encode(bytes)
134    }
135}
136
137impl Zeroize for SecretKey {
138    fn zeroize(&mut self) {
139        self.0.zeroize();
140    }
141}
142
143impl Drop for SecretKey {
144    fn drop(&mut self) {
145        self.zeroize();
146    }
147}
148
149#[derive(Clone)]
150pub struct SharedSecret {
151    x: Fr,
152    y: Fr,
153}
154
155impl SharedSecret {
156    #[inline]
157    #[must_use]
158    pub fn x(&self) -> Fr {
159        self.x
160    }
161
162    #[inline]
163    #[must_use]
164    pub fn y(&self) -> Fr {
165        self.y
166    }
167}
168
169impl std::fmt::Debug for SharedSecret {
170    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
171        f.debug_struct("SharedSecret")
172            .field("x", &"[REDACTED]")
173            .field("y", &"[REDACTED]")
174            .finish()
175    }
176}
177
178impl SharedSecret {
179    #[must_use]
180    pub fn to_symmetric_key(&self) -> [u8; 32] {
181        let packed = PublicKey {
182            x: self.x,
183            y: self.y,
184        }
185        .to_bytes();
186        use sha2::{Digest, Sha256};
187        let mut hasher = Sha256::new();
188        hasher.update(packed);
189        hasher.finalize().into()
190    }
191}
192
193impl Zeroize for SharedSecret {
194    fn zeroize(&mut self) {
195        self.x.zeroize();
196        self.y.zeroize();
197    }
198}
199
200impl Drop for SharedSecret {
201    fn drop(&mut self) {
202        self.zeroize();
203    }
204}
205
206impl PublicKey {
207    /// Construct from coordinates with on-curve and subgroup membership validation.
208    pub fn from_coordinates(x: Fr, y: Fr) -> Result<Self, CryptoError> {
209        let point = Self { x, y };
210
211        if !point.is_on_curve() {
212            return Err(CryptoError::InvalidPoint);
213        }
214
215        let check = point.mul_scalar(&SUBGROUP_ORDER_LE)?;
216        if !check.x.is_zero() || !check.y.is_one() {
217            return Err(CryptoError::SubgroupCheckFailed);
218        }
219
220        Ok(point)
221    }
222
223    /// Check: `A*x^2 + y^2 == 1 + D*x^2*y^2`
224    #[must_use]
225    pub fn is_on_curve(&self) -> bool {
226        let a = *BJJ_A;
227        let d = *BJJ_D;
228        let x2 = self.x.square();
229        let y2 = self.y.square();
230        let lhs = a * x2 + y2;
231        let rhs = Fr::one() + d * x2 * y2;
232        lhs == rhs
233    }
234
235    /// Twisted Edwards point addition (complete -- no exceptional cases).
236    pub fn add(&self, other: &Self) -> Result<Self, CryptoError> {
237        let a = *BJJ_A;
238        let d = *BJJ_D;
239
240        let beta = self.x * other.y;
241        let gamma = self.y * other.x;
242        let tau = beta * gamma;
243        let dtau = d * tau;
244
245        let x3_denom_inv = (Fr::one() + dtau)
246            .inverse()
247            .ok_or(CryptoError::InvalidOperation)?;
248        let x3 = (beta + gamma) * x3_denom_inv;
249
250        let delta = (self.y - (a * self.x)) * (other.x + other.y);
251        let y3_denom_inv = (Fr::one() - dtau)
252            .inverse()
253            .ok_or(CryptoError::InvalidOperation)?;
254        let y3 = (delta + (a * beta) - gamma) * y3_denom_inv;
255
256        Ok(Self { x: x3, y: y3 })
257    }
258
259    /// Constant-time double-and-add scalar multiplication. Scalar is little-endian bytes.
260    pub fn mul_scalar(&self, scalar_le_bytes: &[u8]) -> Result<Self, CryptoError> {
261        use subtle::Choice;
262
263        let mut res = Self {
264            x: Fr::zero(),
265            y: Fr::one(),
266        };
267        let mut temp = *self;
268
269        for byte in scalar_le_bytes {
270            let mut b = *byte;
271            for _ in 0..8 {
272                let candidate = res.add(&temp)?;
273                let bit = Choice::from(b & 1);
274                res = Self::ct_select(&res, &candidate, bit);
275                temp = temp.add(&temp)?;
276                b >>= 1;
277            }
278        }
279        Ok(res)
280    }
281
282    /// Constant-time point selection using u64 limbs (4 per coordinate).
283    fn ct_select(a: &Self, b: &Self, choice: subtle::Choice) -> Self {
284        use subtle::ConditionallySelectable;
285
286        let a_x_limbs = a.x.into_bigint().0;
287        let b_x_limbs = b.x.into_bigint().0;
288        let a_y_limbs = a.y.into_bigint().0;
289        let b_y_limbs = b.y.into_bigint().0;
290
291        let mut x_bytes = [0u8; 32];
292        let mut y_bytes = [0u8; 32];
293
294        for i in 0..4 {
295            let x_limb = u64::conditional_select(&a_x_limbs[i], &b_x_limbs[i], choice);
296            let y_limb = u64::conditional_select(&a_y_limbs[i], &b_y_limbs[i], choice);
297            x_bytes[i * 8..(i + 1) * 8].copy_from_slice(&x_limb.to_le_bytes());
298            y_bytes[i * 8..(i + 1) * 8].copy_from_slice(&y_limb.to_le_bytes());
299        }
300
301        Self {
302            x: Fr::from_le_bytes_mod_order(&x_bytes),
303            y: Fr::from_le_bytes_mod_order(&y_bytes),
304        }
305    }
306
307    /// Compressed encoding: LE y-coordinate with sign bit in MSB.
308    #[must_use]
309    #[allow(clippy::expect_used)]
310    pub fn to_bytes(&self) -> [u8; 32] {
311        let mut y_bytes = self.y.into_bigint().to_bytes_le();
312        if self.x > *HALF_MODULUS {
313            y_bytes[31] |= 0x80;
314        }
315        // SAFETY: BN254 field elements are always 32 bytes in LE representation
316        y_bytes.try_into().expect("y_bytes is 32 bytes")
317    }
318
319    #[must_use]
320    pub fn to_hex(&self) -> String {
321        hex::encode(self.to_bytes())
322    }
323
324    /// Decompress from hex with full validation (curve check + subgroup membership).
325    pub fn from_hex(hex_str: &str) -> Result<Self, CryptoError> {
326        let bytes = hex::decode(hex_str).map_err(|_| CryptoError::InvalidKey)?;
327        if bytes.len() != 32 {
328            return Err(CryptoError::InvalidKey);
329        }
330
331        let mut y_bytes = bytes.clone();
332        let sign = (y_bytes[31] & 0x80) != 0;
333        y_bytes[31] &= 0x7F;
334
335        let y = Fr::from_le_bytes_mod_order(&y_bytes);
336        let a = *BJJ_A;
337        let d = *BJJ_D;
338        let y2 = y.square();
339
340        let x2 = (Fr::one() - y2) * (a - (d * y2)).inverse().ok_or(CryptoError::InvalidPoint)?;
341        let mut x = x2.sqrt().ok_or(CryptoError::InvalidPoint)?;
342
343        if (x > *HALF_MODULUS) != sign {
344            x = -x;
345        }
346
347        let point = Self { x, y };
348
349        let check = point.mul_scalar(&SUBGROUP_ORDER_LE)?;
350        if !check.x.is_zero() || !check.y.is_one() {
351            return Err(CryptoError::SubgroupCheckFailed);
352        }
353
354        Ok(point)
355    }
356}
357
#[cfg(test)]
mod tests {
    use super::*;

    /// Neutral element `(0, 1)` of the twisted Edwards group.
    fn identity() -> PublicKey {
        PublicKey {
            x: Fr::from(0),
            y: Fr::from(1),
        }
    }

    /// Little-endian 32-byte scalar holding `low` in its least significant byte.
    fn small_scalar(low: u8) -> [u8; 32] {
        let mut scalar = [0u8; 32];
        scalar[0] = low;
        scalar
    }

    /// `true` when `PublicKey::from_hex` refuses `input`.
    fn pk_hex_rejected(input: &str) -> bool {
        PublicKey::from_hex(input).is_err()
    }

    // ---- scalar multiplication ----

    #[test]
    fn test_identity_element() {
        let result = BASE8
            .mul_scalar(&[0u8; 32])
            .expect("mul by zero should succeed");
        assert_eq!(result, identity());
    }

    #[test]
    fn test_generator_mul_one() {
        let result = BASE8
            .mul_scalar(&small_scalar(1))
            .expect("mul by one should succeed");
        assert_eq!(result, *BASE8);
    }

    #[test]
    fn test_mul_scalar_returns_result() {
        assert!(BASE8.mul_scalar(&[0x42u8; 32]).is_ok());
    }

    #[test]
    fn test_mul_by_subgroup_order() {
        let result = BASE8
            .mul_scalar(&SUBGROUP_ORDER_LE)
            .expect("mul by subgroup order should succeed");
        assert_eq!(result, identity());
    }

    #[test]
    fn test_scalar_mul_double() {
        let doubled = BASE8.mul_scalar(&small_scalar(2)).expect("[2]*G");
        let added = BASE8.add(&*BASE8).expect("G + G");
        assert_eq!(doubled, added);
    }

    // ---- point addition ----

    #[test]
    fn test_point_addition_identity() {
        let result = BASE8
            .add(&identity())
            .expect("adding identity should succeed");
        assert_eq!(result, *BASE8);
    }

    #[test]
    fn test_point_negation() {
        let neg_base8 = PublicKey {
            x: -BASE8.x,
            y: BASE8.y,
        };
        let sum = BASE8.add(&neg_base8).expect("P + (-P) should succeed");
        assert_eq!(sum, identity());
    }

    #[test]
    fn test_point_addition_commutative() {
        let mut rng = rand::thread_rng();
        let p = SecretKey::generate(&mut rng).public_key().expect("pk1");
        let q = SecretKey::generate(&mut rng).public_key().expect("pk2");
        assert_eq!(p.add(&q).expect("P + Q"), q.add(&p).expect("Q + P"));
    }

    // ---- curve membership ----

    #[test]
    fn test_is_on_curve_base8() {
        assert!(BASE8.is_on_curve());
    }

    #[test]
    fn test_is_on_curve_identity() {
        assert!(identity().is_on_curve());
    }

    #[test]
    fn test_is_on_curve_invalid() {
        let bad = PublicKey {
            x: Fr::from(42),
            y: Fr::from(99),
        };
        assert!(!bad.is_on_curve());
    }

    #[test]
    fn test_derived_pk_is_on_curve() {
        let mut rng = rand::thread_rng();
        let pk = SecretKey::generate(&mut rng)
            .public_key()
            .expect("pk derivation");
        assert!(pk.is_on_curve());
    }

    // ---- PublicKey::from_coordinates ----

    #[test]
    fn test_from_coordinates_valid_point() {
        assert_eq!(PublicKey::from_coordinates(BASE8.x, BASE8.y), Ok(*BASE8));
    }

    #[test]
    fn test_from_coordinates_off_curve() {
        assert_eq!(
            PublicKey::from_coordinates(Fr::from(1), Fr::from(2)),
            Err(CryptoError::InvalidPoint)
        );
    }

    #[test]
    fn test_from_coordinates_zero_zero() {
        assert_eq!(
            PublicKey::from_coordinates(Fr::from(0), Fr::from(0)),
            Err(CryptoError::InvalidPoint)
        );
    }

    #[test]
    fn test_from_coordinates_identity() {
        assert!(PublicKey::from_coordinates(Fr::from(0), Fr::from(1)).is_ok());
    }

    // ---- key generation and serialization ----

    #[test]
    fn test_public_key_generation() {
        let mut rng = rand::thread_rng();
        let pk = SecretKey::generate(&mut rng)
            .public_key()
            .expect("public key derivation should succeed");
        // A freshly derived public key should not be the identity element.
        assert_ne!(pk, identity());
    }

    #[test]
    fn test_public_key_serialization_roundtrip() {
        let mut rng = rand::thread_rng();
        let pk = SecretKey::generate(&mut rng)
            .public_key()
            .expect("public key derivation should succeed");
        let decoded = PublicKey::from_hex(&pk.to_hex()).expect("from_hex should succeed");
        assert_eq!(pk, decoded);
    }

    #[test]
    fn test_public_key_roundtrip_stress() {
        let mut rng = rand::thread_rng();
        (0..10).for_each(|_| {
            let pk = SecretKey::generate(&mut rng).public_key().expect("pk");
            let recovered = PublicKey::from_hex(&pk.to_hex()).expect("roundtrip");
            assert_eq!(pk, recovered);
        });
    }

    #[test]
    fn test_secret_key_roundtrip() {
        let mut rng = rand::thread_rng();
        let sk = SecretKey::generate(&mut rng);
        let decoded = SecretKey::from_hex(&sk.to_hex()).expect("from_hex should succeed");
        assert_eq!(sk, decoded);
    }

    #[test]
    fn test_from_hex_rejects_over_modulus() {
        let over_modulus = "060c89ce5c263405370a08b6d0302b0bab3eedb83920ee0a677297dc392126f2";
        assert!(SecretKey::from_hex(over_modulus).is_err());
    }

    #[test]
    fn test_from_hex_rejects_wrong_length() {
        for len in [31usize, 33] {
            assert!(SecretKey::from_hex(&"00".repeat(len)).is_err());
        }
    }

    #[test]
    fn test_secret_key_zero() {
        assert!(SecretKey::from_hex(&"00".repeat(32)).is_ok());
    }

    #[test]
    fn test_secret_key_one() {
        let hex = format!("{}01", "00".repeat(31));
        let sk = SecretKey::from_hex(&hex).expect("one should be valid");
        let pk = sk.public_key().expect("pk from sk=1");
        assert_eq!(pk, *BASE8);
    }

    #[test]
    fn test_secret_key_invalid_hex() {
        assert!(SecretKey::from_hex("not_valid_hex_at_all_!").is_err());
    }

    // ---- PublicKey::from_hex rejection ----

    #[test]
    fn test_from_hex_empty() {
        assert!(pk_hex_rejected(""));
    }

    #[test]
    fn test_from_hex_wrong_length() {
        assert!(pk_hex_rejected("aabb"));
    }

    #[test]
    fn test_from_hex_invalid_chars() {
        assert!(pk_hex_rejected(&"zz".repeat(32)));
    }

    #[test]
    fn test_from_hex_all_zeros() {
        assert!(pk_hex_rejected(&"00".repeat(32)));
    }

    #[test]
    fn test_from_hex_all_ff() {
        assert!(pk_hex_rejected(&"ff".repeat(32)));
    }

    // ---- ECDH ----

    #[test]
    fn test_shared_secret_commutativity() {
        let mut rng = rand::thread_rng();
        let alice_sk = SecretKey::generate(&mut rng);
        let bob_sk = SecretKey::generate(&mut rng);

        let alice_pk = alice_sk.public_key().expect("alice pk");
        let bob_pk = bob_sk.public_key().expect("bob pk");

        let ss_alice = alice_sk.derive_shared_secret(&bob_pk).expect("alice ECDH");
        let ss_bob = bob_sk.derive_shared_secret(&alice_pk).expect("bob ECDH");

        assert_eq!((ss_alice.x, ss_alice.y), (ss_bob.x, ss_bob.y));
    }

    #[test]
    fn test_shared_secret_to_symmetric_key() {
        let mut rng = rand::thread_rng();
        let alice = SecretKey::generate(&mut rng);
        let bob = SecretKey::generate(&mut rng);
        let bob_pk = bob.public_key().expect("bob pk");

        let ss = alice.derive_shared_secret(&bob_pk).expect("ECDH");
        let (key1, key2) = (ss.to_symmetric_key(), ss.to_symmetric_key());

        assert_eq!(key1.len(), 32);
        assert_eq!(key1, key2);
    }

    #[test]
    fn test_shared_secret_different_keys() {
        let mut rng = rand::thread_rng();
        let alice = SecretKey::generate(&mut rng);
        let bob = SecretKey::generate(&mut rng);
        let carol = SecretKey::generate(&mut rng);

        let bob_pk = bob.public_key().expect("bob pk");
        let carol_pk = carol.public_key().expect("carol pk");

        let ss_bob = alice.derive_shared_secret(&bob_pk).expect("ECDH bob");
        let ss_carol = alice.derive_shared_secret(&carol_pk).expect("ECDH carol");

        assert_ne!(ss_bob.to_symmetric_key(), ss_carol.to_symmetric_key());
    }

    // ---- U256 conversion ----

    #[test]
    fn test_u256_to_le_bytes_small() {
        let le = u256_to_le_bytes(ethers_core::types::U256::from(0x0102u64));
        assert_eq!((le[0], le[1]), (0x02, 0x01));
        assert!(le[2..].iter().all(|&b| b == 0));
    }

    #[test]
    fn test_u256_to_le_bytes_length() {
        let le = u256_to_le_bytes(ethers_core::types::U256::MAX);
        assert_eq!(le.len(), 32);
        assert_eq!(le, [0xffu8; 32]);
    }
}