1use std::collections::HashMap;
6
7use ethers::types::U256;
8use num_bigint::BigUint;
9use tracing::warn;
10
11use darkpool_crypto::SUBGROUP_ORDER;
12
13use crate::identity::DarkAccount;
14
15#[allow(clippy::expect_used)]
16static SUBGROUP_ORDER_BIGINT: std::sync::LazyLock<BigUint> = std::sync::LazyLock::new(|| {
17 BigUint::parse_bytes(SUBGROUP_ORDER.as_bytes(), 10)
18 .expect("SUBGROUP_ORDER is a compile-time constant")
19});
20
21pub const DEFAULT_LOOKAHEAD: u64 = 20;
22
23#[derive(Debug, Clone)]
24pub struct KeyRepository {
25 account: DarkAccount,
26 compliance_pk: (U256, U256),
27 ephemeral_index: u64,
28 incoming_index: u64,
29 next_ephemeral_nonce: u64,
30 ephemeral_key_map: HashMap<String, (U256, u64)>,
32 recipient_key_map: HashMap<String, (U256, u64)>,
34}
35
36impl KeyRepository {
37 #[must_use]
38 pub fn new(account: DarkAccount, compliance_pk: (U256, U256)) -> Self {
39 Self {
40 account,
41 compliance_pk,
42 ephemeral_index: 0,
43 incoming_index: 0,
44 next_ephemeral_nonce: 0,
45 ephemeral_key_map: HashMap::new(),
46 recipient_key_map: HashMap::new(),
47 }
48 }
49
50 #[must_use]
51 pub fn ephemeral_index(&self) -> u64 {
52 self.ephemeral_index
53 }
54
55 #[must_use]
56 pub fn incoming_index(&self) -> u64 {
57 self.incoming_index
58 }
59
60 pub fn get_public_incoming_key(
61 &mut self,
62 ) -> Result<(U256, U256), darkpool_crypto::CryptoError> {
63 self.account.get_public_incoming_key(self.incoming_index)
64 }
65
66 pub fn next_ephemeral_params(&mut self) -> (U256, U256) {
68 let idx = self.next_ephemeral_nonce;
69 self.next_ephemeral_nonce += 1;
70
71 let nonce = U256::from(idx);
72 let sk = self.account.get_ephemeral_outgoing_key(idx);
73 self.register_ephemeral_key(idx);
74 (sk, nonce)
75 }
76
77 pub fn advance_ephemeral_keys(&mut self, count: u64) {
78 for _ in 0..count {
79 self.register_ephemeral_key(self.ephemeral_index);
80 self.ephemeral_index += 1;
81 }
82 }
83
84 pub fn advance_incoming_keys(&mut self, count: u64) {
85 for _ in 0..count {
86 self.register_incoming_key(self.incoming_index);
87 self.incoming_index += 1;
88 }
89 }
90
91 #[must_use]
92 pub fn try_match_deposit(&self, epk_x: U256, epk_y: U256) -> Option<(U256, u64)> {
93 let key = Self::format_point_key(epk_x, epk_y);
94 self.ephemeral_key_map.get(&key).copied()
95 }
96
97 #[must_use]
98 pub fn try_match_transfer(&self, tag_px: U256) -> Option<(U256, u64)> {
99 let key = tag_px.to_string();
100 self.recipient_key_map.get(&key).copied()
101 }
102
103 #[must_use]
104 pub fn get_all_tags(&self) -> Vec<String> {
105 self.recipient_key_map.keys().cloned().collect()
106 }
107
108 fn register_ephemeral_key(&mut self, index: u64) -> bool {
109 let (epk_x, epk_y) = match self.account.get_public_ephemeral_key(index) {
110 Ok(pk) => pk,
111 Err(e) => {
112 warn!(index, "Failed to derive ephemeral public key: {}", e);
113 return false;
114 }
115 };
116 let lookup_key = Self::format_point_key(epk_x, epk_y);
117
118 if !self.ephemeral_key_map.contains_key(&lookup_key) {
119 let eph_sk = self.account.get_ephemeral_outgoing_key(index);
120 self.ephemeral_key_map.insert(lookup_key, (eph_sk, index));
121 }
122 true
123 }
124
125 fn register_incoming_key(&mut self, index: u64) -> bool {
126 let recipient_sk = self.account.get_incoming_viewing_key(index);
127 let recipient_sk_mod = Self::reduce_mod_subgroup(recipient_sk);
128 let p = match Self::scalar_mul_point(recipient_sk_mod, self.compliance_pk) {
129 Ok(point) => point,
130 Err(e) => {
131 warn!(
132 index,
133 "Failed to compute transfer tag for incoming key: {}", e
134 );
135 return false;
136 }
137 };
138
139 let tag_key = p.0.to_string();
140
141 self.recipient_key_map
142 .entry(tag_key)
143 .or_insert((recipient_sk_mod, index));
144 true
145 }
146
147 fn format_point_key(x: U256, y: U256) -> String {
148 format!("{x}_{y}")
149 }
150
151 fn reduce_mod_subgroup(value: U256) -> U256 {
152 let mut bytes = [0u8; 32];
153 value.to_big_endian(&mut bytes);
154 let bigint = BigUint::from_bytes_be(&bytes);
155 let reduced = bigint % &*SUBGROUP_ORDER_BIGINT;
156 let mut result_bytes = reduced.to_bytes_be();
157 while result_bytes.len() < 32 {
158 result_bytes.insert(0, 0);
159 }
160
161 U256::from_big_endian(&result_bytes)
162 }
163
164 fn scalar_mul_point(
165 scalar: U256,
166 point: (U256, U256),
167 ) -> Result<(U256, U256), darkpool_crypto::CryptoError> {
168 use ark_ff::{BigInteger, PrimeField};
169 use darkpool_crypto::PublicKey;
170
171 use crate::crypto_helpers::u256_to_fr;
172
173 let pk = PublicKey::from_coordinates(u256_to_fr(point.0), u256_to_fr(point.1))?;
174
175 let mut scalar_bytes = [0u8; 32];
176 scalar.to_big_endian(&mut scalar_bytes);
177 scalar_bytes.reverse(); let result = pk.mul_scalar(&scalar_bytes)?;
180 let x_bytes = result.x().into_bigint().to_bytes_be();
181 let y_bytes = result.y().into_bigint().to_bytes_be();
182
183 Ok((
184 U256::from_big_endian(&x_bytes),
185 U256::from_big_endian(&y_bytes),
186 ))
187 }
188}
189
#[cfg(test)]
mod tests {
    use super::*;

    /// Seed shared by every account built in these tests, so independently
    /// constructed accounts derive identical keys.
    const TEST_SEED: &[u8; 24] = b"test_key_repository_seed";

    /// Builds a repository over the test account with a compliance key derived
    /// from a fixed scalar.
    fn create_test_repo() -> KeyRepository {
        use crate::crypto_helpers::fr_to_u256;
        use darkpool_crypto::BASE8;

        let compliance_sk_bytes = [0x42u8; 32];
        let compliance_point = BASE8
            .mul_scalar(&compliance_sk_bytes)
            .expect("valid test key");
        let compliance_pk = (
            fr_to_u256(compliance_point.x()),
            fr_to_u256(compliance_point.y()),
        );

        KeyRepository::new(DarkAccount::from_seed(TEST_SEED), compliance_pk)
    }

    /// Derives the ephemeral public key at `index` from a fresh account built
    /// from the shared seed, independently of any repository under test.
    fn ephemeral_pk_at(index: u64) -> (U256, U256) {
        let mut account = DarkAccount::from_seed(TEST_SEED);
        account.get_public_ephemeral_key(index).unwrap()
    }

    #[test]
    fn test_key_repository_creation() {
        let repo = create_test_repo();

        assert_eq!(repo.ephemeral_index(), 0);
        assert_eq!(repo.incoming_index(), 0);
    }

    #[test]
    fn test_next_ephemeral_params() {
        let mut repo = create_test_repo();

        let issued: Vec<(U256, U256)> = (0..3).map(|_| repo.next_ephemeral_params()).collect();

        // Nonces count up from zero, one per call.
        assert_eq!(issued[0].1, U256::from(0));
        assert_eq!(issued[1].1, U256::from(1));
        assert_eq!(issued[2].1, U256::from(2));

        // Consecutive secret keys differ.
        for pair in issued.windows(2) {
            assert_ne!(pair[0].0, pair[1].0);
        }
    }

    #[test]
    fn test_advance_ephemeral_keys() {
        let mut repo = create_test_repo();

        repo.advance_ephemeral_keys(5);
        assert_eq!(repo.ephemeral_index(), 5);

        repo.advance_ephemeral_keys(3);
        assert_eq!(repo.ephemeral_index(), 8);
    }

    #[test]
    fn test_advance_incoming_keys() {
        let mut repo = create_test_repo();

        repo.advance_incoming_keys(5);
        assert_eq!(repo.incoming_index(), 5);

        repo.advance_incoming_keys(3);
        assert_eq!(repo.incoming_index(), 8);
    }

    #[test]
    fn test_try_match_deposit() {
        let mut repo = create_test_repo();
        repo.advance_ephemeral_keys(3);

        let (epk_x, epk_y) = ephemeral_pk_at(0);
        let (sk, idx) = repo.try_match_deposit(epk_x, epk_y).unwrap();
        assert_eq!(idx, 0);
        assert!(!sk.is_zero());

        // Coordinates that were never registered do not match.
        let unknown = repo.try_match_deposit(U256::from(999), U256::from(888));
        assert!(unknown.is_none());
    }

    #[test]
    fn test_try_match_transfer() {
        let mut repo = create_test_repo();
        repo.advance_incoming_keys(3);

        let tags = repo.get_all_tags();
        assert_eq!(tags.len(), 3);

        for tag in &tags {
            let tag_u256 = U256::from_dec_str(tag).unwrap();
            let result = repo.try_match_transfer(tag_u256);
            assert!(result.is_some(), "Tag {} should be matchable", tag);
        }

        assert!(repo.try_match_transfer(U256::from(12345)).is_none());
    }

    #[test]
    fn test_deterministic_key_registration() {
        let mut repo1 = create_test_repo();
        let mut repo2 = create_test_repo();
        repo1.advance_ephemeral_keys(5);
        repo2.advance_ephemeral_keys(5);

        let (epk_x, epk_y) = ephemeral_pk_at(2);

        let from_first = repo1.try_match_deposit(epk_x, epk_y);
        let from_second = repo2.try_match_deposit(epk_x, epk_y);
        assert_eq!(from_first, from_second);
    }
}