Skip to main content

darkpool_crypto/
ecdh.rs

1use ark_bn254::Fr;
2use ark_ff::{BigInteger, Field, PrimeField};
3use ethers_core::types::U256;
4
5use crate::bjj::{u256_to_le_bytes, PublicKey, BASE8, BJJ_A, BJJ_D};
6use crate::error::CryptoError;
7use crate::field::{fr_to_u256, u256_to_fr};
8
9/// `PK = sk * Base8`
10pub fn derive_public_key_from_sk(sk: U256) -> Result<(U256, U256), CryptoError> {
11    let sk_bytes = u256_to_le_bytes(sk);
12    let result = BASE8.mul_scalar(&sk_bytes)?;
13
14    let x_bytes = result.x().into_bigint().to_bytes_be();
15    let y_bytes = result.y().into_bigint().to_bytes_be();
16
17    Ok((
18        U256::from_big_endian(&x_bytes),
19        U256::from_big_endian(&y_bytes),
20    ))
21}
22
23/// ECDH shared secret derivation. Returns X coordinate of `sk * pk`.
24///
25/// Uses `mul_scalar` directly (not `SecretKey::from_hex`) because the scalar
26/// may exceed the BJJ subgroup order when sourced from KDF values.
27pub fn derive_shared_secret_bjj(
28    ephemeral_sk: U256,
29    compliance_pk: (U256, U256),
30) -> Result<U256, CryptoError> {
31    let pk = PublicKey::from_coordinates(u256_to_fr(compliance_pk.0), u256_to_fr(compliance_pk.1))?;
32
33    let sk_bytes = u256_to_le_bytes(ephemeral_sk);
34    let shared_point = pk.mul_scalar(&sk_bytes)?;
35
36    Ok(fr_to_u256(shared_point.x()))
37}
38
39/// Check `a*x^2 + y^2 == 1 + d*x^2*y^2`
40#[must_use]
41pub fn bjj_is_on_curve(x: Fr, y: Fr) -> bool {
42    let a = *BJJ_A;
43    let d = *BJJ_D;
44
45    let x2 = x.square();
46    let y2 = y.square();
47
48    let lhs = a * x2 + y2;
49    let rhs = Fr::from(1u64) + d * x2 * y2;
50
51    lhs == rhs
52}
53
54/// `scalar * point` with full validation (on-curve + subgroup check).
55pub fn bjj_scalar_mul(scalar: U256, point: (U256, U256)) -> Result<(U256, U256), CryptoError> {
56    let x = u256_to_fr(point.0);
57    let y = u256_to_fr(point.1);
58
59    let pk = PublicKey::from_coordinates(x, y)?;
60
61    let scalar_bytes = u256_to_le_bytes(scalar);
62    let result = pk.mul_scalar(&scalar_bytes)?;
63
64    let x_bytes = result.x().into_bigint().to_bytes_be();
65    let y_bytes = result.y().into_bigint().to_bytes_be();
66
67    Ok((
68        U256::from_big_endian(&x_bytes),
69        U256::from_big_endian(&y_bytes),
70    ))
71}
72
#[cfg(test)]
mod tests {
    use super::*;
    use crate::bjj::SecretKey;

    #[test]
    fn test_derive_shared_secret_bjj_symmetry() {
        let alice_sk = U256::from(12345u64);
        let bob_sk = U256::from(67890u64);

        let alice_sk_bytes = u256_to_le_bytes(alice_sk);
        let bob_sk_bytes = u256_to_le_bytes(bob_sk);

        let alice_pk = BASE8.mul_scalar(&alice_sk_bytes).expect("valid test key");
        let bob_pk = BASE8.mul_scalar(&bob_sk_bytes).expect("valid test key");

        let alice_pk_tuple = (fr_to_u256(alice_pk.x()), fr_to_u256(alice_pk.y()));
        let bob_pk_tuple = (fr_to_u256(bob_pk.x()), fr_to_u256(bob_pk.y()));

        let ss_from_alice = derive_shared_secret_bjj(alice_sk, bob_pk_tuple).unwrap();
        let ss_from_bob = derive_shared_secret_bjj(bob_sk, alice_pk_tuple).unwrap();

        assert_eq!(ss_from_alice, ss_from_bob);
    }

    #[test]
    fn test_derive_pk_from_zero_sk() {
        let (x, y) = derive_public_key_from_sk(U256::zero()).expect("zero sk");
        assert!(x.is_zero());
        assert_eq!(y, U256::from(1));
    }

    #[test]
    fn test_derive_pk_from_one() {
        let (x, y) = derive_public_key_from_sk(U256::from(1)).expect("sk=1");
        let base8_x = fr_to_u256(BASE8.x());
        let base8_y = fr_to_u256(BASE8.y());
        assert_eq!(x, base8_x);
        assert_eq!(y, base8_y);
    }

    #[test]
    fn test_derive_pk_matches_secret_key() {
        let mut rng = rand::thread_rng();
        let sk = SecretKey::generate(&mut rng);
        let pk = sk.public_key().expect("pk");

        let sk_bytes = sk.0.into_bigint().to_bytes_be();
        let sk_u256 = U256::from_big_endian(&sk_bytes);

        let (x, y) = derive_public_key_from_sk(sk_u256).expect("derive pk");
        assert_eq!(x, fr_to_u256(pk.x()));
        assert_eq!(y, fr_to_u256(pk.y()));
    }

    #[test]
    fn test_bjj_is_on_curve_base8() {
        assert!(bjj_is_on_curve(BASE8.x(), BASE8.y()));
    }

    #[test]
    fn test_bjj_is_on_curve_identity() {
        assert!(bjj_is_on_curve(Fr::from(0), Fr::from(1)));
    }

    #[test]
    fn test_bjj_is_on_curve_random_point() {
        assert!(!bjj_is_on_curve(Fr::from(42), Fr::from(99)));
    }

    #[test]
    fn test_bjj_is_on_curve_derived() {
        let mut rng = rand::thread_rng();
        let sk = SecretKey::generate(&mut rng);
        let pk = sk.public_key().expect("pk");
        assert!(bjj_is_on_curve(pk.x(), pk.y()));
    }

    #[test]
    fn test_bjj_scalar_mul_identity() {
        let base = (fr_to_u256(BASE8.x()), fr_to_u256(BASE8.y()));
        let result = bjj_scalar_mul(U256::from(1), base).expect("scalar mul");
        assert_eq!(result, base);
    }

    #[test]
    fn test_bjj_scalar_mul_zero() {
        let base = (fr_to_u256(BASE8.x()), fr_to_u256(BASE8.y()));
        let (x, y) = bjj_scalar_mul(U256::zero(), base).expect("scalar mul 0");
        assert!(x.is_zero());
        assert_eq!(y, U256::from(1));
    }

    #[test]
    fn test_bjj_scalar_mul_off_curve() {
        let bad_point = (U256::from(1), U256::from(2));
        let result = bjj_scalar_mul(U256::from(5), bad_point);
        assert!(result.is_err());
    }

    #[test]
    fn test_derive_shared_secret_rejects_off_curve() {
        let result = derive_shared_secret_bjj(U256::from(42), (U256::from(1), U256::from(2)));
        assert!(result.is_err());
    }

    /// The ECDH secret must be exactly the X coordinate of the generic
    /// validated scalar multiplication on the same inputs.
    #[test]
    fn test_shared_secret_matches_scalar_mul_x() {
        let sk = U256::from(424_242u64);
        let peer = derive_public_key_from_sk(U256::from(777u64)).expect("peer pk");

        let shared = derive_shared_secret_bjj(sk, peer).expect("shared secret");
        let (x, _) = bjj_scalar_mul(sk, peer).expect("scalar mul");

        assert_eq!(shared, x);
    }

    /// `b * (a * Base8)` must equal `(a * b) * Base8`.
    #[test]
    fn test_bjj_scalar_mul_composes() {
        let a = U256::from(12345u64);
        let b = U256::from(67890u64);

        let a_pk = derive_public_key_from_sk(a).expect("a * Base8");
        let b_times_a_pk = bjj_scalar_mul(b, a_pk).expect("b * (a * Base8)");
        let ab_pk = derive_public_key_from_sk(a * b).expect("(a * b) * Base8");

        assert_eq!(b_times_a_pk, ab_pk);
    }

    /// Keys produced by `derive_public_key_from_sk` must satisfy the curve
    /// equation checked by `bjj_is_on_curve`.
    #[test]
    fn test_derived_pk_is_on_curve() {
        let (x, y) = derive_public_key_from_sk(U256::from(98765u64)).expect("derive pk");
        assert!(bjj_is_on_curve(u256_to_fr(x), u256_to_fr(y)));
    }
}