ark_client/key_provider.rs
1use crate::error::Error;
2use bitcoin::bip32::DerivationPath;
3use bitcoin::bip32::Xpriv;
4use bitcoin::key::Keypair;
5use bitcoin::secp256k1::Secp256k1;
6use std::sync::Arc;
7
/// Selects which receive keypair [`KeyProvider::get_next_keypair`] should return.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum KeypairIndex {
    /// Advances the provider's index and returns a freshly derived keypair.
    New,
    /// Returns a previously issued keypair that has not been marked as used.
    ///
    /// Providers may derive a new keypair when no unused one is available. Despite the name,
    /// [`Bip32KeyProvider`] picks the unused keypair with the *lowest* derivation index.
    LastUnused,
}
14
15/// Provides keypairs for signing operations
16///
17/// This trait allows different key management strategies:
18/// - Static keypair (single key)
19/// - BIP32 HD wallet (hierarchical deterministic)
20/// - Hardware wallets (future)
21/// - Custom key derivation schemes
22pub trait KeyProvider: Send + Sync {
23 /// Get a keypair for receiving funds
24 ///
25 /// For static key providers, this always returns the same keypair regardless of the index.
26 /// For HD wallets, behavior depends on the `keypair_index` parameter.
27 ///
28 /// # Arguments
29 ///
30 /// * `keypair_index` - Controls which keypair to return:
31 /// - `KeypairIndex::New`: Increments the internal index and returns a new keypair
32 /// - `KeypairIndex::LastUnused`: Returns the last unused keypair without incrementing
33 ///
34 /// # Returns
35 ///
36 /// A keypair to use for receiving funds
37 fn get_next_keypair(&self, keypair_index: KeypairIndex) -> Result<Keypair, Error>;
38
39 /// Get a keypair for a specific BIP32 derivation path
40 ///
41 /// # Arguments
42 ///
43 /// * `path` - BIP32 derivation path as an array of child indexes
44 ///
45 /// # Returns
46 ///
47 /// A keypair derived at the specified path, or an error if derivation is not supported
48 fn get_keypair_for_path(&self, path: &[u32]) -> Result<Keypair, Error>;
49
50 /// Get a keypair for a specific public key
51 ///
52 /// This is essential for HD wallets where you need to find the correct keypair
53 /// for signing with a previously generated public key.
54 ///
55 /// # Arguments
56 ///
57 /// * `pk` - The X-only public key to find the keypair for
58 ///
59 /// # Returns
60 ///
61 /// The keypair corresponding to the public key, or an error if not found
62 fn get_keypair_for_pk(&self, pk: &bitcoin::XOnlyPublicKey) -> Result<Keypair, Error>;
63
64 /// Get all public keys that this provider currently knows about
65 ///
66 /// For static key providers, this returns the single keypair's public key.
67 /// For HD wallets, this returns all public keys that have been derived and cached
68 /// (i.e., keys generated via `get_next_keypair`).
69 ///
70 /// This is useful for determining which keys are available for signing operations
71 /// without having to search or derive new keys.
72 ///
73 /// # Returns
74 ///
75 /// A vector of X-only public keys known to this provider
76 fn get_cached_pks(&self) -> Result<Vec<bitcoin::XOnlyPublicKey>, Error>;
77
78 /// Returns true if this provider supports key discovery
79 ///
80 /// HD wallets return true since they can derive and discover previously used keys.
81 /// Static key providers return false (single key, nothing to discover).
82 fn supports_discovery(&self) -> bool {
83 false
84 }
85
86 /// Derive a keypair at a specific index without caching
87 ///
88 /// This is used during discovery to check keys without affecting the provider's state.
89 /// Returns `None` if the provider doesn't support index-based derivation.
90 ///
91 /// # Arguments
92 ///
93 /// * `index` - The derivation index (appended to base path for HD wallets)
94 fn derive_at_discovery_index(&self, _index: u32) -> Result<Option<Keypair>, Error> {
95 Ok(None)
96 }
97
98 /// Cache a discovered keypair at the given index
99 ///
100 /// This is called after discovery determines a key is "used" (has VTXOs).
101 /// Also updates next_index if index >= current next_index to avoid collisions.
102 ///
103 /// No-op for providers that don't support discovery.
104 ///
105 /// # Arguments
106 ///
107 /// * `index` - The derivation index
108 /// * `kp` - The keypair to cache
109 fn cache_discovered_keypair(&self, _index: u32, _kp: Keypair) -> Result<(), Error> {
110 Ok(())
111 }
112
113 fn mark_as_used(&self, _pk: &bitcoin::XOnlyPublicKey) -> Result<(), Error> {
114 Ok(())
115 }
116}
117
118/// A simple key provider that uses a static keypair
119///
120/// This is the simplest implementation and is backward compatible with
121/// the original single-keypair design.
122#[derive(Clone)]
123pub struct StaticKeyProvider {
124 kp: Keypair,
125}
126
127impl StaticKeyProvider {
128 /// Create a new static key provider
129 pub fn new(kp: Keypair) -> Self {
130 Self { kp }
131 }
132}
133
134impl KeyProvider for StaticKeyProvider {
135 fn get_next_keypair(&self, _: KeypairIndex) -> Result<Keypair, Error> {
136 // Static provider always returns the same keypair
137 Ok(self.kp)
138 }
139
140 fn get_keypair_for_path(&self, _path: &[u32]) -> Result<Keypair, Error> {
141 // Static provider always returns the same keypair
142 Ok(self.kp)
143 }
144
145 fn get_keypair_for_pk(&self, pk: &bitcoin::XOnlyPublicKey) -> Result<Keypair, Error> {
146 // Verify that the requested public key matches our keypair
147 let our_pk = self.kp.x_only_public_key().0;
148 if &our_pk == pk {
149 Ok(self.kp)
150 } else {
151 Err(Error::ad_hoc(format!(
152 "Public key mismatch: requested {pk}, but only have {our_pk}"
153 )))
154 }
155 }
156
157 fn get_cached_pks(&self) -> Result<Vec<bitcoin::XOnlyPublicKey>, Error> {
158 Ok(vec![self.kp.public_key().into()])
159 }
160}
161
162/// A BIP32 hierarchical deterministic key provider
163///
164/// This provider derives keypairs from a master extended private key
165/// using BIP32 derivation paths. It maintains an index counter for
166/// generating new receiving addresses.
167///
168/// ## Example
169///
170/// ```rust
171/// # use std::str::FromStr;
172/// # use bitcoin::bip32::{Xpriv, DerivationPath};
173/// # use bitcoin::Network;
174/// # use crate::ark_client::KeyProvider;
175/// # use ark_client::Bip32KeyProvider;
176/// # use ark_client::key_provider::KeypairIndex;
177///
178/// fn example() -> Result<(), Box<dyn std::error::Error>> {
179/// // Create from a master key with a base path (e.g., m/84'/0'/0'/0)
180/// let master_key = Xpriv::from_str("xprv...")?;
181/// let base_path = DerivationPath::from_str("m/84'/0'/0'/0")?;
182///
183/// // This will derive keys at m/84'/0'/0'/0/0, m/84'/0'/0'/0/1, etc.
184/// let provider = Bip32KeyProvider::new(master_key, base_path);
185///
186/// // Get the next receiving keypair (increments index)
187/// let kp1 = provider.get_next_keypair(KeypairIndex::New)?; // m/84'/0'/0'/0/0
188/// let kp2 = provider.get_next_keypair(KeypairIndex::New)?; // m/84'/0'/0'/0/1
189///
190/// // Or derive a specific keypair by path
191/// let custom_path = vec![84 + 0x8000_0000, 0x8000_0000, 0x8000_0000, 0, 5];
192/// let kp = provider.get_keypair_for_path(&custom_path)?;
193/// # Ok(())
194/// # }
195/// ```
196pub struct Bip32KeyProvider {
197 master_key: Xpriv,
198 base_path: DerivationPath,
199 // Using std::sync::Mutex for interior mutability across Send + Sync
200 next_index: Arc<std::sync::Mutex<u32>>,
201 // Cache of derived keys: pk -> (path_index, keypair, used)
202 // The `used` flag indicates whether this keypair has been used (has VTXOs)
203 key_cache:
204 Arc<std::sync::RwLock<std::collections::HashMap<bitcoin::XOnlyPublicKey, KeyCacheValue>>>,
205}
206
207#[derive(Clone, Copy)]
208pub struct KeyCacheValue {
209 path_index: u32,
210 kp: Keypair,
211 /// Indicates whether this keypair has been used (has VTXOs).
212 used: bool,
213}
214
215impl Bip32KeyProvider {
216 /// Create a new BIP32 key provider
217 ///
218 /// # Arguments
219 ///
220 /// * `master_key` - The master extended private key (xpriv)
221 /// * `base_path` - The base derivation path (e.g., m/84'/0'/0'/0). The provider will append
222 /// index numbers to this path.
223 pub fn new(master_key: Xpriv, base_path: DerivationPath) -> Self {
224 Self {
225 master_key,
226 base_path,
227 next_index: Arc::new(std::sync::Mutex::new(0)),
228 key_cache: Arc::new(std::sync::RwLock::new(std::collections::HashMap::new())),
229 }
230 }
231
232 /// Create a new BIP32 key provider starting from a specific index
233 ///
234 /// # Arguments
235 ///
236 /// * `master_key` - The master extended private key (xpriv)
237 /// * `base_path` - The base derivation path
238 /// * `start_index` - The starting index for key derivation
239 pub fn new_with_index(master_key: Xpriv, base_path: DerivationPath, start_index: u32) -> Self {
240 Self {
241 master_key,
242 base_path,
243 next_index: Arc::new(std::sync::Mutex::new(start_index)),
244 key_cache: Arc::new(std::sync::RwLock::new(std::collections::HashMap::new())),
245 }
246 }
247
248 /// Derive a keypair at the specified path
249 fn derive_keypair(&self, path: &DerivationPath) -> Result<Keypair, Error> {
250 let secp = Secp256k1::new();
251 let derived_key = self
252 .master_key
253 .derive_priv(&secp, path)
254 .map_err(|e| Error::ad_hoc(format!("BIP32 derivation failed: {e}")))?;
255
256 Ok(derived_key.to_keypair(&secp))
257 }
258
259 /// Derive a keypair at base_path/index
260 fn derive_at_index(&self, index: u32) -> Result<Keypair, Error> {
261 use bitcoin::bip32::ChildNumber;
262
263 let path = self.base_path.clone();
264 let path = path.extend([ChildNumber::Normal { index }]);
265
266 self.derive_keypair(&path)
267 }
268}
269
270impl KeyProvider for Bip32KeyProvider {
271 fn get_next_keypair(&self, keypair_index: KeypairIndex) -> Result<Keypair, Error> {
272 match keypair_index {
273 KeypairIndex::New => {
274 // Get and increment the next index
275 let index = {
276 let mut next_index = self
277 .next_index
278 .lock()
279 .map_err(|e| Error::ad_hoc(format!("Failed to lock next_index: {e}")))?;
280 let current = *next_index;
281 *next_index = next_index
282 .checked_add(1)
283 .ok_or_else(|| Error::ad_hoc("Key derivation index overflow"))?;
284 current
285 };
286
287 // Derive the keypair at this index
288 let kp = self.derive_at_index(index)?;
289
290 // Cache it for later lookup (marked as unused)
291 let pk = kp.x_only_public_key().0;
292 {
293 let mut cache = self
294 .key_cache
295 .write()
296 .map_err(|e| Error::ad_hoc(format!("Failed to lock key_cache: {e}")))?;
297 cache.insert(
298 pk,
299 KeyCacheValue {
300 path_index: index,
301 kp,
302 used: false,
303 },
304 );
305 }
306
307 Ok(kp)
308 }
309 KeypairIndex::LastUnused => {
310 // First, try to find an unused keypair in the cache
311 {
312 let cache = self
313 .key_cache
314 .read()
315 .map_err(|e| Error::ad_hoc(format!("Failed to lock key_cache: {e}")))?;
316
317 // Find the unused keypair with the lowest index
318 let unused = cache
319 .values()
320 .filter(|KeyCacheValue { used, .. }| !used)
321 .min_by_key(|KeyCacheValue { path_index, .. }| *path_index);
322
323 if let Some(KeyCacheValue { kp, .. }) = unused {
324 return Ok(*kp);
325 }
326 }
327
328 // No unused keypair found, derive a new one
329 self.get_next_keypair(KeypairIndex::New)
330 }
331 }
332 }
333
334 fn get_keypair_for_path(&self, path: &[u32]) -> Result<Keypair, Error> {
335 use bitcoin::bip32::ChildNumber;
336 let child_numbers: Vec<ChildNumber> = path
337 .iter()
338 .map(|&n| {
339 if n & 0x8000_0000 != 0 {
340 ChildNumber::Hardened {
341 index: n & 0x7FFF_FFFF,
342 }
343 } else {
344 ChildNumber::Normal { index: n }
345 }
346 })
347 .collect();
348 let derivation_path = DerivationPath::from(child_numbers);
349 self.derive_keypair(&derivation_path)
350 }
351
352 fn get_keypair_for_pk(&self, pk: &bitcoin::XOnlyPublicKey) -> Result<Keypair, Error> {
353 // First check the cache
354 {
355 let cache = self
356 .key_cache
357 .read()
358 .map_err(|e| Error::ad_hoc(format!("Failed to lock key_cache: {e}")))?;
359 if let Some(KeyCacheValue { kp, .. }) = cache.get(pk) {
360 return Ok(*kp);
361 }
362 }
363
364 // If not in cache, we need to search. For now, we'll search up to the current index
365 let current_index = {
366 let next_index = self
367 .next_index
368 .lock()
369 .map_err(|e| Error::ad_hoc(format!("Failed to lock next_index: {e}")))?;
370 *next_index
371 };
372
373 // Search through derived keys up to current index
374 for i in 0..current_index {
375 let kp = self.derive_at_index(i)?;
376 let derived_pk = kp.x_only_public_key().0;
377
378 if &derived_pk == pk {
379 // Cache it for next time (assume used since we're looking it up for signing)
380 let mut cache = self
381 .key_cache
382 .write()
383 .map_err(|e| Error::ad_hoc(format!("Failed to lock key_cache: {e}")))?;
384 cache.insert(
385 derived_pk,
386 KeyCacheValue {
387 path_index: i,
388 kp,
389 used: true,
390 },
391 );
392 return Ok(kp);
393 }
394 }
395
396 Err(Error::ad_hoc(format!(
397 "Public key {pk} not found in HD wallet. \
398 Searched indices 0..{current_index}. \
399 The key may have been generated outside this provider."
400 )))
401 }
402
403 fn get_cached_pks(&self) -> Result<Vec<bitcoin::XOnlyPublicKey>, Error> {
404 let cache = self
405 .key_cache
406 .read()
407 .map_err(|e| Error::ad_hoc(format!("Failed to lock key_cache: {e}")))?;
408
409 Ok(cache.keys().copied().collect())
410 }
411
412 fn supports_discovery(&self) -> bool {
413 true
414 }
415
416 fn derive_at_discovery_index(&self, index: u32) -> Result<Option<Keypair>, Error> {
417 self.derive_at_index(index).map(Some)
418 }
419
420 fn cache_discovered_keypair(&self, index: u32, kp: Keypair) -> Result<(), Error> {
421 let pk = kp.x_only_public_key().0;
422
423 // Add to cache (marked as used since it was discovered with VTXOs)
424 {
425 let mut cache = self
426 .key_cache
427 .write()
428 .map_err(|e| Error::ad_hoc(format!("Failed to lock key_cache: {e}")))?;
429 cache.insert(
430 pk,
431 KeyCacheValue {
432 path_index: index,
433 kp,
434 used: true,
435 },
436 );
437 }
438
439 // Update next_index if needed (set to index + 1 if >= current)
440 {
441 let mut next = self
442 .next_index
443 .lock()
444 .map_err(|e| Error::ad_hoc(format!("Failed to lock next_index: {e}")))?;
445 if index >= *next {
446 *next = index
447 .checked_add(1)
448 .ok_or_else(|| Error::ad_hoc("Key derivation index overflow"))?;
449 }
450 }
451
452 Ok(())
453 }
454
455 fn mark_as_used(&self, pk: &bitcoin::XOnlyPublicKey) -> Result<(), Error> {
456 // First check the cache
457 {
458 let maybe_kp = {
459 let cache = self
460 .key_cache
461 .read()
462 .map_err(|e| Error::ad_hoc(format!("Failed to lock key_cache: {e}")))?;
463 cache.get(pk).copied()
464 };
465
466 match maybe_kp {
467 Some(KeyCacheValue {
468 path_index,
469 kp,
470 used: false,
471 }) => {
472 let mut cache = self
473 .key_cache
474 .write()
475 .map_err(|e| Error::ad_hoc(format!("Failed to lock key_cache: {e}")))?;
476 cache.insert(
477 *pk,
478 KeyCacheValue {
479 path_index,
480 kp,
481 used: true,
482 },
483 );
484 return Ok(());
485 }
486 Some(KeyCacheValue { used: true, .. }) => {
487 // already marked as used
488 return Ok(());
489 }
490 _ => {
491 // no found
492 }
493 }
494 }
495
496 // If not in cache, we need to search. For now, we'll search up to the current index
497 let current_index = {
498 let next_index = self
499 .next_index
500 .lock()
501 .map_err(|e| Error::ad_hoc(format!("Failed to lock next_index: {e}")))?;
502 *next_index
503 };
504
505 // Search through derived keys up to current index
506 for i in 0..current_index {
507 let kp = self.derive_at_index(i)?;
508 let derived_pk = kp.x_only_public_key().0;
509
510 if &derived_pk == pk {
511 // Cache it for next time (assume used since we're looking it up for signing)
512 let mut cache = self
513 .key_cache
514 .write()
515 .map_err(|e| Error::ad_hoc(format!("Failed to lock key_cache: {e}")))?;
516 cache.insert(
517 derived_pk,
518 KeyCacheValue {
519 path_index: i,
520 kp,
521 used: true,
522 },
523 );
524 return Ok(());
525 }
526 }
527
528 Err(Error::ad_hoc(format!(
529 "Public key {pk} not found in HD wallet. \
530 Searched indices 0..{current_index}. \
531 The key may have been generated outside this provider."
532 )))
533 }
534}
535
536// Implement KeyProvider for Arc<T> where T: KeyProvider
537impl<T: KeyProvider> KeyProvider for Arc<T> {
538 fn get_next_keypair(&self, keypair_index: KeypairIndex) -> Result<Keypair, Error> {
539 (**self).get_next_keypair(keypair_index)
540 }
541
542 fn get_keypair_for_path(&self, path: &[u32]) -> Result<Keypair, Error> {
543 (**self).get_keypair_for_path(path)
544 }
545
546 fn get_keypair_for_pk(&self, pk: &bitcoin::XOnlyPublicKey) -> Result<Keypair, Error> {
547 (**self).get_keypair_for_pk(pk)
548 }
549
550 fn get_cached_pks(&self) -> Result<Vec<bitcoin::XOnlyPublicKey>, Error> {
551 (**self).get_cached_pks()
552 }
553
554 fn supports_discovery(&self) -> bool {
555 (**self).supports_discovery()
556 }
557
558 fn derive_at_discovery_index(&self, index: u32) -> Result<Option<Keypair>, Error> {
559 (**self).derive_at_discovery_index(index)
560 }
561
562 fn cache_discovered_keypair(&self, index: u32, kp: Keypair) -> Result<(), Error> {
563 (**self).cache_discovered_keypair(index, kp)
564 }
565
566 fn mark_as_used(&self, pk: &bitcoin::XOnlyPublicKey) -> Result<(), Error> {
567 (**self).mark_as_used(pk)
568 }
569}