ark_client/key_provider.rs
1use crate::error::Error;
2use bitcoin::bip32::DerivationPath;
3use bitcoin::bip32::Xpriv;
4use bitcoin::key::Keypair;
5use bitcoin::secp256k1::Secp256k1;
6use std::sync::Arc;
7
/// Selects which receiving keypair [`KeyProvider::get_next_keypair`] hands out.
///
/// Static providers ignore this value and always return their single keypair.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum KeypairIndex {
    /// Advance the provider's derivation index and return a freshly derived keypair.
    New,
    /// Reuse a keypair that was already handed out but has not been marked as used.
    ///
    /// The BIP32 provider returns the cached, not-yet-used keypair with the lowest
    /// derivation index, and falls back to [`KeypairIndex::New`] when there is none.
    LastUnused,
}
14
15/// Provides keypairs for signing operations
16///
17/// This trait allows different key management strategies:
18/// - Static keypair (single key)
19/// - BIP32 HD wallet (hierarchical deterministic)
20/// - Hardware wallets (future)
21/// - Custom key derivation schemes
22pub trait KeyProvider: Send + Sync {
23 /// Get a keypair for receiving funds
24 ///
25 /// For static key providers, this always returns the same keypair regardless of the index.
26 /// For HD wallets, behavior depends on the `keypair_index` parameter.
27 ///
28 /// # Arguments
29 ///
30 /// * `keypair_index` - Controls which keypair to return:
31 /// - `KeypairIndex::New`: Increments the internal index and returns a new keypair
32 /// - `KeypairIndex::LastUnused`: Returns the last unused keypair without incrementing
33 ///
34 /// # Returns
35 ///
36 /// A keypair to use for receiving funds
37 fn get_next_keypair(&self, keypair_index: KeypairIndex) -> Result<Keypair, Error>;
38
39 /// Get a keypair for a specific BIP32 derivation path
40 ///
41 /// # Arguments
42 ///
43 /// * `path` - BIP32 derivation path as an array of child indexes
44 ///
45 /// # Returns
46 ///
47 /// A keypair derived at the specified path, or an error if derivation is not supported
48 fn get_keypair_for_path(&self, path: &[u32]) -> Result<Keypair, Error>;
49
50 /// Get a keypair for a specific public key
51 ///
52 /// This is essential for HD wallets where you need to find the correct keypair
53 /// for signing with a previously generated public key.
54 ///
55 /// # Arguments
56 ///
57 /// * `pk` - The X-only public key to find the keypair for
58 ///
59 /// # Returns
60 ///
61 /// The keypair corresponding to the public key, or an error if not found
62 fn get_keypair_for_pk(&self, pk: &bitcoin::XOnlyPublicKey) -> Result<Keypair, Error>;
63
64 /// Get all public keys that this provider currently knows about
65 ///
66 /// For static key providers, this returns the single keypair's public key.
67 /// For HD wallets, this returns all public keys that have been derived and cached
68 /// (i.e., keys generated via `get_next_keypair`).
69 ///
70 /// This is useful for determining which keys are available for signing operations
71 /// without having to search or derive new keys.
72 ///
73 /// # Returns
74 ///
75 /// A vector of X-only public keys known to this provider
76 fn get_cached_pks(&self) -> Result<Vec<bitcoin::XOnlyPublicKey>, Error>;
77
78 /// Returns true if this provider supports key discovery
79 ///
80 /// HD wallets return true since they can derive and discover previously used keys.
81 /// Static key providers return false (single key, nothing to discover).
82 fn supports_discovery(&self) -> bool {
83 false
84 }
85
86 /// Get the derivation index for a cached public key.
87 ///
88 /// Returns `None` if the provider doesn't support index-based derivation
89 /// or if the key is not in the cache.
90 fn get_derivation_index_for_pk(&self, _pk: &bitcoin::XOnlyPublicKey) -> Option<u32> {
91 None
92 }
93
94 /// Derive a keypair at a specific index without caching
95 ///
96 /// This is used during discovery to check keys without affecting the provider's state.
97 /// Returns `None` if the provider doesn't support index-based derivation.
98 ///
99 /// # Arguments
100 ///
101 /// * `index` - The derivation index (appended to base path for HD wallets)
102 fn derive_at_discovery_index(&self, _index: u32) -> Result<Option<Keypair>, Error> {
103 Ok(None)
104 }
105
106 /// Cache a discovered keypair at the given index
107 ///
108 /// This is called after discovery determines a key is "used" (has VTXOs).
109 /// Also updates next_index if index >= current next_index to avoid collisions.
110 ///
111 /// No-op for providers that don't support discovery.
112 ///
113 /// # Arguments
114 ///
115 /// * `index` - The derivation index
116 /// * `kp` - The keypair to cache
117 fn cache_discovered_keypair(&self, _index: u32, _kp: Keypair) -> Result<(), Error> {
118 Ok(())
119 }
120
121 fn mark_as_used(&self, _pk: &bitcoin::XOnlyPublicKey) -> Result<(), Error> {
122 Ok(())
123 }
124}
125
126/// A simple key provider that uses a static keypair
127///
128/// This is the simplest implementation and is backward compatible with
129/// the original single-keypair design.
130#[derive(Clone)]
131pub struct StaticKeyProvider {
132 kp: Keypair,
133}
134
135impl StaticKeyProvider {
136 /// Create a new static key provider
137 pub fn new(kp: Keypair) -> Self {
138 Self { kp }
139 }
140}
141
142impl KeyProvider for StaticKeyProvider {
143 fn get_next_keypair(&self, _: KeypairIndex) -> Result<Keypair, Error> {
144 // Static provider always returns the same keypair
145 Ok(self.kp)
146 }
147
148 fn get_keypair_for_path(&self, _path: &[u32]) -> Result<Keypair, Error> {
149 // Static provider always returns the same keypair
150 Ok(self.kp)
151 }
152
153 fn get_keypair_for_pk(&self, pk: &bitcoin::XOnlyPublicKey) -> Result<Keypair, Error> {
154 // Verify that the requested public key matches our keypair
155 let our_pk = self.kp.x_only_public_key().0;
156 if &our_pk == pk {
157 Ok(self.kp)
158 } else {
159 Err(Error::ad_hoc(format!(
160 "Public key mismatch: requested {pk}, but only have {our_pk}"
161 )))
162 }
163 }
164
165 fn get_cached_pks(&self) -> Result<Vec<bitcoin::XOnlyPublicKey>, Error> {
166 Ok(vec![self.kp.public_key().into()])
167 }
168}
169
170/// A BIP32 hierarchical deterministic key provider
171///
172/// This provider derives keypairs from a master extended private key
173/// using BIP32 derivation paths. It maintains an index counter for
174/// generating new receiving addresses.
175///
176/// ## Example
177///
178/// ```rust
179/// # use std::str::FromStr;
180/// # use bitcoin::bip32::{Xpriv, DerivationPath};
181/// # use bitcoin::Network;
182/// # use crate::ark_client::KeyProvider;
183/// # use ark_client::Bip32KeyProvider;
184/// # use ark_client::key_provider::KeypairIndex;
185///
186/// fn example() -> Result<(), Box<dyn std::error::Error>> {
187/// // Create from a master key with a base path (e.g., m/84'/0'/0'/0)
188/// let master_key = Xpriv::from_str("xprv...")?;
189/// let base_path = DerivationPath::from_str("m/84'/0'/0'/0")?;
190///
191/// // This will derive keys at m/84'/0'/0'/0/0, m/84'/0'/0'/0/1, etc.
192/// let provider = Bip32KeyProvider::new(master_key, base_path);
193///
194/// // Get the next receiving keypair (increments index)
195/// let kp1 = provider.get_next_keypair(KeypairIndex::New)?; // m/84'/0'/0'/0/0
196/// let kp2 = provider.get_next_keypair(KeypairIndex::New)?; // m/84'/0'/0'/0/1
197///
198/// // Or derive a specific keypair by path
199/// let custom_path = vec![84 + 0x8000_0000, 0x8000_0000, 0x8000_0000, 0, 5];
200/// let kp = provider.get_keypair_for_path(&custom_path)?;
201/// # Ok(())
202/// # }
203/// ```
204pub struct Bip32KeyProvider {
205 master_key: Xpriv,
206 base_path: DerivationPath,
207 // Using std::sync::Mutex for interior mutability across Send + Sync
208 next_index: Arc<std::sync::Mutex<u32>>,
209 // Cache of derived keys: pk -> (path_index, keypair, used)
210 // The `used` flag indicates whether this keypair has been used (has VTXOs)
211 key_cache:
212 Arc<std::sync::RwLock<std::collections::HashMap<bitcoin::XOnlyPublicKey, KeyCacheValue>>>,
213}
214
215#[derive(Clone, Copy)]
216pub struct KeyCacheValue {
217 path_index: u32,
218 kp: Keypair,
219 /// Indicates whether this keypair has been used (has VTXOs).
220 used: bool,
221}
222
223impl Bip32KeyProvider {
224 /// Create a new BIP32 key provider
225 ///
226 /// # Arguments
227 ///
228 /// * `master_key` - The master extended private key (xpriv)
229 /// * `base_path` - The base derivation path (e.g., m/84'/0'/0'/0). The provider will append
230 /// index numbers to this path.
231 pub fn new(master_key: Xpriv, base_path: DerivationPath) -> Self {
232 Self {
233 master_key,
234 base_path,
235 next_index: Arc::new(std::sync::Mutex::new(0)),
236 key_cache: Arc::new(std::sync::RwLock::new(std::collections::HashMap::new())),
237 }
238 }
239
240 /// Create a new BIP32 key provider starting from a specific index
241 ///
242 /// # Arguments
243 ///
244 /// * `master_key` - The master extended private key (xpriv)
245 /// * `base_path` - The base derivation path
246 /// * `start_index` - The starting index for key derivation
247 pub fn new_with_index(master_key: Xpriv, base_path: DerivationPath, start_index: u32) -> Self {
248 Self {
249 master_key,
250 base_path,
251 next_index: Arc::new(std::sync::Mutex::new(start_index)),
252 key_cache: Arc::new(std::sync::RwLock::new(std::collections::HashMap::new())),
253 }
254 }
255
256 /// Derive a keypair at the specified path
257 fn derive_keypair(&self, path: &DerivationPath) -> Result<Keypair, Error> {
258 let secp = Secp256k1::new();
259 let derived_key = self
260 .master_key
261 .derive_priv(&secp, path)
262 .map_err(|e| Error::ad_hoc(format!("BIP32 derivation failed: {e}")))?;
263
264 Ok(derived_key.to_keypair(&secp))
265 }
266
267 /// Derive a keypair at base_path/index
268 fn derive_at_index(&self, index: u32) -> Result<Keypair, Error> {
269 use bitcoin::bip32::ChildNumber;
270
271 let path = self.base_path.clone();
272 let path = path.extend([ChildNumber::Normal { index }]);
273
274 self.derive_keypair(&path)
275 }
276}
277
278impl KeyProvider for Bip32KeyProvider {
279 fn get_next_keypair(&self, keypair_index: KeypairIndex) -> Result<Keypair, Error> {
280 match keypair_index {
281 KeypairIndex::New => {
282 // Get and increment the next index
283 let index = {
284 let mut next_index = self
285 .next_index
286 .lock()
287 .map_err(|e| Error::ad_hoc(format!("Failed to lock next_index: {e}")))?;
288 let current = *next_index;
289 *next_index = next_index
290 .checked_add(1)
291 .ok_or_else(|| Error::ad_hoc("Key derivation index overflow"))?;
292 current
293 };
294
295 // Derive the keypair at this index
296 let kp = self.derive_at_index(index)?;
297
298 // Cache it for later lookup (marked as unused)
299 let pk = kp.x_only_public_key().0;
300 {
301 let mut cache = self
302 .key_cache
303 .write()
304 .map_err(|e| Error::ad_hoc(format!("Failed to lock key_cache: {e}")))?;
305 cache.insert(
306 pk,
307 KeyCacheValue {
308 path_index: index,
309 kp,
310 used: false,
311 },
312 );
313 }
314
315 Ok(kp)
316 }
317 KeypairIndex::LastUnused => {
318 // First, try to find an unused keypair in the cache
319 {
320 let cache = self
321 .key_cache
322 .read()
323 .map_err(|e| Error::ad_hoc(format!("Failed to lock key_cache: {e}")))?;
324
325 // Find the unused keypair with the lowest index
326 let unused = cache
327 .values()
328 .filter(|KeyCacheValue { used, .. }| !used)
329 .min_by_key(|KeyCacheValue { path_index, .. }| *path_index);
330
331 if let Some(KeyCacheValue { kp, .. }) = unused {
332 return Ok(*kp);
333 }
334 }
335
336 // No unused keypair found, derive a new one
337 self.get_next_keypair(KeypairIndex::New)
338 }
339 }
340 }
341
342 fn get_keypair_for_path(&self, path: &[u32]) -> Result<Keypair, Error> {
343 use bitcoin::bip32::ChildNumber;
344 let child_numbers: Vec<ChildNumber> = path
345 .iter()
346 .map(|&n| {
347 if n & 0x8000_0000 != 0 {
348 ChildNumber::Hardened {
349 index: n & 0x7FFF_FFFF,
350 }
351 } else {
352 ChildNumber::Normal { index: n }
353 }
354 })
355 .collect();
356 let derivation_path = DerivationPath::from(child_numbers);
357 self.derive_keypair(&derivation_path)
358 }
359
360 fn get_keypair_for_pk(&self, pk: &bitcoin::XOnlyPublicKey) -> Result<Keypair, Error> {
361 // First check the cache
362 {
363 let cache = self
364 .key_cache
365 .read()
366 .map_err(|e| Error::ad_hoc(format!("Failed to lock key_cache: {e}")))?;
367 if let Some(KeyCacheValue { kp, .. }) = cache.get(pk) {
368 return Ok(*kp);
369 }
370 }
371
372 // If not in cache, we need to search. For now, we'll search up to the current index
373 let current_index = {
374 let next_index = self
375 .next_index
376 .lock()
377 .map_err(|e| Error::ad_hoc(format!("Failed to lock next_index: {e}")))?;
378 *next_index
379 };
380
381 // Search through derived keys up to current index
382 for i in 0..current_index {
383 let kp = self.derive_at_index(i)?;
384 let derived_pk = kp.x_only_public_key().0;
385
386 if &derived_pk == pk {
387 // Cache it for next time (assume used since we're looking it up for signing)
388 let mut cache = self
389 .key_cache
390 .write()
391 .map_err(|e| Error::ad_hoc(format!("Failed to lock key_cache: {e}")))?;
392 cache.insert(
393 derived_pk,
394 KeyCacheValue {
395 path_index: i,
396 kp,
397 used: true,
398 },
399 );
400 return Ok(kp);
401 }
402 }
403
404 Err(Error::ad_hoc(format!(
405 "Public key {pk} not found in HD wallet. \
406 Searched indices 0..{current_index}. \
407 The key may have been generated outside this provider."
408 )))
409 }
410
411 fn get_cached_pks(&self) -> Result<Vec<bitcoin::XOnlyPublicKey>, Error> {
412 let cache = self
413 .key_cache
414 .read()
415 .map_err(|e| Error::ad_hoc(format!("Failed to lock key_cache: {e}")))?;
416
417 Ok(cache.keys().copied().collect())
418 }
419
420 fn supports_discovery(&self) -> bool {
421 true
422 }
423
424 fn get_derivation_index_for_pk(&self, pk: &bitcoin::XOnlyPublicKey) -> Option<u32> {
425 let cache = self.key_cache.read().ok()?;
426 cache.get(pk).map(|v| v.path_index)
427 }
428
429 fn derive_at_discovery_index(&self, index: u32) -> Result<Option<Keypair>, Error> {
430 self.derive_at_index(index).map(Some)
431 }
432
433 fn cache_discovered_keypair(&self, index: u32, kp: Keypair) -> Result<(), Error> {
434 let pk = kp.x_only_public_key().0;
435
436 // Add to cache (marked as used since it was discovered with VTXOs)
437 {
438 let mut cache = self
439 .key_cache
440 .write()
441 .map_err(|e| Error::ad_hoc(format!("Failed to lock key_cache: {e}")))?;
442 cache.insert(
443 pk,
444 KeyCacheValue {
445 path_index: index,
446 kp,
447 used: true,
448 },
449 );
450 }
451
452 // Update next_index if needed (set to index + 1 if >= current)
453 {
454 let mut next = self
455 .next_index
456 .lock()
457 .map_err(|e| Error::ad_hoc(format!("Failed to lock next_index: {e}")))?;
458 if index >= *next {
459 *next = index
460 .checked_add(1)
461 .ok_or_else(|| Error::ad_hoc("Key derivation index overflow"))?;
462 }
463 }
464
465 Ok(())
466 }
467
468 fn mark_as_used(&self, pk: &bitcoin::XOnlyPublicKey) -> Result<(), Error> {
469 // First check the cache
470 {
471 let maybe_kp = {
472 let cache = self
473 .key_cache
474 .read()
475 .map_err(|e| Error::ad_hoc(format!("Failed to lock key_cache: {e}")))?;
476 cache.get(pk).copied()
477 };
478
479 match maybe_kp {
480 Some(KeyCacheValue {
481 path_index,
482 kp,
483 used: false,
484 }) => {
485 let mut cache = self
486 .key_cache
487 .write()
488 .map_err(|e| Error::ad_hoc(format!("Failed to lock key_cache: {e}")))?;
489 cache.insert(
490 *pk,
491 KeyCacheValue {
492 path_index,
493 kp,
494 used: true,
495 },
496 );
497 return Ok(());
498 }
499 Some(KeyCacheValue { used: true, .. }) => {
500 // already marked as used
501 return Ok(());
502 }
503 _ => {
504 // no found
505 }
506 }
507 }
508
509 // If not in cache, we need to search. For now, we'll search up to the current index
510 let current_index = {
511 let next_index = self
512 .next_index
513 .lock()
514 .map_err(|e| Error::ad_hoc(format!("Failed to lock next_index: {e}")))?;
515 *next_index
516 };
517
518 // Search through derived keys up to current index
519 for i in 0..current_index {
520 let kp = self.derive_at_index(i)?;
521 let derived_pk = kp.x_only_public_key().0;
522
523 if &derived_pk == pk {
524 // Cache it for next time (assume used since we're looking it up for signing)
525 let mut cache = self
526 .key_cache
527 .write()
528 .map_err(|e| Error::ad_hoc(format!("Failed to lock key_cache: {e}")))?;
529 cache.insert(
530 derived_pk,
531 KeyCacheValue {
532 path_index: i,
533 kp,
534 used: true,
535 },
536 );
537 return Ok(());
538 }
539 }
540
541 Err(Error::ad_hoc(format!(
542 "Public key {pk} not found in HD wallet. \
543 Searched indices 0..{current_index}. \
544 The key may have been generated outside this provider."
545 )))
546 }
547}
548
549// Implement KeyProvider for Arc<T> where T: KeyProvider
550impl<T: KeyProvider> KeyProvider for Arc<T> {
551 fn get_next_keypair(&self, keypair_index: KeypairIndex) -> Result<Keypair, Error> {
552 (**self).get_next_keypair(keypair_index)
553 }
554
555 fn get_keypair_for_path(&self, path: &[u32]) -> Result<Keypair, Error> {
556 (**self).get_keypair_for_path(path)
557 }
558
559 fn get_keypair_for_pk(&self, pk: &bitcoin::XOnlyPublicKey) -> Result<Keypair, Error> {
560 (**self).get_keypair_for_pk(pk)
561 }
562
563 fn get_cached_pks(&self) -> Result<Vec<bitcoin::XOnlyPublicKey>, Error> {
564 (**self).get_cached_pks()
565 }
566
567 fn supports_discovery(&self) -> bool {
568 (**self).supports_discovery()
569 }
570
571 fn get_derivation_index_for_pk(&self, pk: &bitcoin::XOnlyPublicKey) -> Option<u32> {
572 (**self).get_derivation_index_for_pk(pk)
573 }
574
575 fn derive_at_discovery_index(&self, index: u32) -> Result<Option<Keypair>, Error> {
576 (**self).derive_at_discovery_index(index)
577 }
578
579 fn cache_discovered_keypair(&self, index: u32, kp: Keypair) -> Result<(), Error> {
580 (**self).cache_discovered_keypair(index, kp)
581 }
582
583 fn mark_as_used(&self, pk: &bitcoin::XOnlyPublicKey) -> Result<(), Error> {
584 (**self).mark_as_used(pk)
585 }
586}