1use ark_bn254::Fr;
2use ark_ff::{BigInteger, Field, PrimeField};
3use ethers_core::types::U256;
4
5use crate::bjj::{u256_to_le_bytes, PublicKey, BASE8, BJJ_A, BJJ_D};
6use crate::error::CryptoError;
7use crate::field::{fr_to_u256, u256_to_fr};
8
9pub fn derive_public_key_from_sk(sk: U256) -> Result<(U256, U256), CryptoError> {
11 let sk_bytes = u256_to_le_bytes(sk);
12 let result = BASE8.mul_scalar(&sk_bytes)?;
13
14 let x_bytes = result.x().into_bigint().to_bytes_be();
15 let y_bytes = result.y().into_bigint().to_bytes_be();
16
17 Ok((
18 U256::from_big_endian(&x_bytes),
19 U256::from_big_endian(&y_bytes),
20 ))
21}
22
23pub fn derive_shared_secret_bjj(
28 ephemeral_sk: U256,
29 compliance_pk: (U256, U256),
30) -> Result<U256, CryptoError> {
31 let pk = PublicKey::from_coordinates(u256_to_fr(compliance_pk.0), u256_to_fr(compliance_pk.1))?;
32
33 let sk_bytes = u256_to_le_bytes(ephemeral_sk);
34 let shared_point = pk.mul_scalar(&sk_bytes)?;
35
36 Ok(fr_to_u256(shared_point.x()))
37}
38
39#[must_use]
41pub fn bjj_is_on_curve(x: Fr, y: Fr) -> bool {
42 let a = *BJJ_A;
43 let d = *BJJ_D;
44
45 let x2 = x.square();
46 let y2 = y.square();
47
48 let lhs = a * x2 + y2;
49 let rhs = Fr::from(1u64) + d * x2 * y2;
50
51 lhs == rhs
52}
53
54pub fn bjj_scalar_mul(scalar: U256, point: (U256, U256)) -> Result<(U256, U256), CryptoError> {
56 let x = u256_to_fr(point.0);
57 let y = u256_to_fr(point.1);
58
59 let pk = PublicKey::from_coordinates(x, y)?;
60
61 let scalar_bytes = u256_to_le_bytes(scalar);
62 let result = pk.mul_scalar(&scalar_bytes)?;
63
64 let x_bytes = result.x().into_bigint().to_bytes_be();
65 let y_bytes = result.y().into_bigint().to_bytes_be();
66
67 Ok((
68 U256::from_big_endian(&x_bytes),
69 U256::from_big_endian(&y_bytes),
70 ))
71}
72
#[cfg(test)]
mod tests {
    use super::*;
    use crate::bjj::SecretKey;

    /// `BASE8` as an `(x, y)` pair of `U256` coordinates.
    fn base8_coords() -> (U256, U256) {
        (fr_to_u256(BASE8.x()), fr_to_u256(BASE8.y()))
    }

    /// Coordinates of `BASE8` multiplied by `sk`, computed without going
    /// through the functions under test.
    fn public_coords(sk: U256) -> (U256, U256) {
        let pk = BASE8
            .mul_scalar(&u256_to_le_bytes(sk))
            .expect("valid test key");
        (fr_to_u256(pk.x()), fr_to_u256(pk.y()))
    }

    // Both parties must arrive at the same secret from their own secret key
    // and the other side's public key.
    #[test]
    fn test_derive_shared_secret_bjj_symmetry() {
        let alice_sk = U256::from(12345u64);
        let bob_sk = U256::from(67890u64);

        let alice_pk = public_coords(alice_sk);
        let bob_pk = public_coords(bob_sk);

        let from_alice = derive_shared_secret_bjj(alice_sk, bob_pk).unwrap();
        let from_bob = derive_shared_secret_bjj(bob_sk, alice_pk).unwrap();

        assert_eq!(from_alice, from_bob);
    }

    // A zero secret key is expected to map to the point (0, 1).
    #[test]
    fn test_derive_pk_from_zero_sk() {
        let (x, y) = derive_public_key_from_sk(U256::zero()).expect("zero sk");

        assert_eq!(y, U256::from(1));
        assert!(x.is_zero());
    }

    // A secret key of one must reproduce BASE8 itself.
    #[test]
    fn test_derive_pk_from_one() {
        let (expected_x, expected_y) = base8_coords();

        let (x, y) = derive_public_key_from_sk(U256::from(1)).expect("sk=1");

        assert_eq!(x, expected_x);
        assert_eq!(y, expected_y);
    }

    // The U256-based derivation must agree with `SecretKey::public_key`.
    #[test]
    fn test_derive_pk_matches_secret_key() {
        let mut rng = rand::thread_rng();
        let sk = SecretKey::generate(&mut rng);
        let pk = sk.public_key().expect("pk");

        let sk_u256 = U256::from_big_endian(&sk.0.into_bigint().to_bytes_be());
        let (x, y) = derive_public_key_from_sk(sk_u256).expect("derive pk");

        assert_eq!(x, fr_to_u256(pk.x()));
        assert_eq!(y, fr_to_u256(pk.y()));
    }

    #[test]
    fn test_bjj_is_on_curve_base8() {
        let on_curve = bjj_is_on_curve(BASE8.x(), BASE8.y());
        assert!(on_curve);
    }

    #[test]
    fn test_bjj_is_on_curve_identity() {
        let on_curve = bjj_is_on_curve(Fr::from(0), Fr::from(1));
        assert!(on_curve);
    }

    // Arbitrary coordinates (42, 99) are expected to fail the curve check.
    #[test]
    fn test_bjj_is_on_curve_random_point() {
        let on_curve = bjj_is_on_curve(Fr::from(42), Fr::from(99));
        assert!(!on_curve);
    }

    #[test]
    fn test_bjj_is_on_curve_derived() {
        let mut rng = rand::thread_rng();
        let pk = SecretKey::generate(&mut rng).public_key().expect("pk");

        assert!(bjj_is_on_curve(pk.x(), pk.y()));
    }

    // Multiplying by one must return the input point unchanged.
    #[test]
    fn test_bjj_scalar_mul_identity() {
        let base = base8_coords();

        let product = bjj_scalar_mul(U256::from(1), base).expect("scalar mul");

        assert_eq!(product, base);
    }

    // Multiplying by zero is expected to yield the point (0, 1).
    #[test]
    fn test_bjj_scalar_mul_zero() {
        let (x, y) = bjj_scalar_mul(U256::zero(), base8_coords()).expect("scalar mul 0");

        assert_eq!(y, U256::from(1));
        assert!(x.is_zero());
    }

    #[test]
    fn test_bjj_scalar_mul_off_curve() {
        let off_curve = (U256::from(1), U256::from(2));

        assert!(bjj_scalar_mul(U256::from(5), off_curve).is_err());
    }

    #[test]
    fn test_derive_shared_secret_rejects_off_curve() {
        let off_curve = (U256::from(1), U256::from(2));

        assert!(derive_shared_secret_bjj(U256::from(42), off_curve).is_err());
    }
}