Skip to main content

ark_client/
key_provider.rs

1use crate::error::Error;
2use bitcoin::bip32::DerivationPath;
3use bitcoin::bip32::Xpriv;
4use bitcoin::key::Keypair;
5use bitcoin::secp256k1::Secp256k1;
6use std::sync::Arc;
7
/// Selects which keypair [`KeyProvider::get_next_keypair`] should hand out.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum KeypairIndex {
    /// Increments the index and returns a new keypair
    New,
    /// Returns a previously issued keypair that has not yet been marked as used.
    ///
    /// Providers that track usage may fall back to [`KeypairIndex::New`] when no such
    /// keypair exists; providers with a single key ignore the distinction.
    LastUnused,
}
14
15/// Provides keypairs for signing operations
16///
17/// This trait allows different key management strategies:
18/// - Static keypair (single key)
19/// - BIP32 HD wallet (hierarchical deterministic)
20/// - Hardware wallets (future)
21/// - Custom key derivation schemes
22pub trait KeyProvider: Send + Sync {
23    /// Get a keypair for receiving funds
24    ///
25    /// For static key providers, this always returns the same keypair regardless of the index.
26    /// For HD wallets, behavior depends on the `keypair_index` parameter.
27    ///
28    /// # Arguments
29    ///
30    /// * `keypair_index` - Controls which keypair to return:
31    ///   - `KeypairIndex::New`: Increments the internal index and returns a new keypair
32    ///   - `KeypairIndex::LastUnused`: Returns the last unused keypair without incrementing
33    ///
34    /// # Returns
35    ///
36    /// A keypair to use for receiving funds
37    fn get_next_keypair(&self, keypair_index: KeypairIndex) -> Result<Keypair, Error>;
38
39    /// Get a keypair for a specific BIP32 derivation path
40    ///
41    /// # Arguments
42    ///
43    /// * `path` - BIP32 derivation path as an array of child indexes
44    ///
45    /// # Returns
46    ///
47    /// A keypair derived at the specified path, or an error if derivation is not supported
48    fn get_keypair_for_path(&self, path: &[u32]) -> Result<Keypair, Error>;
49
50    /// Get a keypair for a specific public key
51    ///
52    /// This is essential for HD wallets where you need to find the correct keypair
53    /// for signing with a previously generated public key.
54    ///
55    /// # Arguments
56    ///
57    /// * `pk` - The X-only public key to find the keypair for
58    ///
59    /// # Returns
60    ///
61    /// The keypair corresponding to the public key, or an error if not found
62    fn get_keypair_for_pk(&self, pk: &bitcoin::XOnlyPublicKey) -> Result<Keypair, Error>;
63
64    /// Get all public keys that this provider currently knows about
65    ///
66    /// For static key providers, this returns the single keypair's public key.
67    /// For HD wallets, this returns all public keys that have been derived and cached
68    /// (i.e., keys generated via `get_next_keypair`).
69    ///
70    /// This is useful for determining which keys are available for signing operations
71    /// without having to search or derive new keys.
72    ///
73    /// # Returns
74    ///
75    /// A vector of X-only public keys known to this provider
76    fn get_cached_pks(&self) -> Result<Vec<bitcoin::XOnlyPublicKey>, Error>;
77
78    /// Returns true if this provider supports key discovery
79    ///
80    /// HD wallets return true since they can derive and discover previously used keys.
81    /// Static key providers return false (single key, nothing to discover).
82    fn supports_discovery(&self) -> bool {
83        false
84    }
85
86    /// Get the derivation index for a cached public key.
87    ///
88    /// Returns `None` if the provider doesn't support index-based derivation
89    /// or if the key is not in the cache.
90    fn get_derivation_index_for_pk(&self, _pk: &bitcoin::XOnlyPublicKey) -> Option<u32> {
91        None
92    }
93
94    /// Derive a keypair at a specific index without caching
95    ///
96    /// This is used during discovery to check keys without affecting the provider's state.
97    /// Returns `None` if the provider doesn't support index-based derivation.
98    ///
99    /// # Arguments
100    ///
101    /// * `index` - The derivation index (appended to base path for HD wallets)
102    fn derive_at_discovery_index(&self, _index: u32) -> Result<Option<Keypair>, Error> {
103        Ok(None)
104    }
105
106    /// Cache a discovered keypair at the given index
107    ///
108    /// This is called after discovery determines a key is "used" (has VTXOs).
109    /// Also updates next_index if index >= current next_index to avoid collisions.
110    ///
111    /// No-op for providers that don't support discovery.
112    ///
113    /// # Arguments
114    ///
115    /// * `index` - The derivation index
116    /// * `kp` - The keypair to cache
117    fn cache_discovered_keypair(&self, _index: u32, _kp: Keypair) -> Result<(), Error> {
118        Ok(())
119    }
120
121    fn mark_as_used(&self, _pk: &bitcoin::XOnlyPublicKey) -> Result<(), Error> {
122        Ok(())
123    }
124}
125
126/// A simple key provider that uses a static keypair
127///
128/// This is the simplest implementation and is backward compatible with
129/// the original single-keypair design.
130#[derive(Clone)]
131pub struct StaticKeyProvider {
132    kp: Keypair,
133}
134
135impl StaticKeyProvider {
136    /// Create a new static key provider
137    pub fn new(kp: Keypair) -> Self {
138        Self { kp }
139    }
140}
141
142impl KeyProvider for StaticKeyProvider {
143    fn get_next_keypair(&self, _: KeypairIndex) -> Result<Keypair, Error> {
144        // Static provider always returns the same keypair
145        Ok(self.kp)
146    }
147
148    fn get_keypair_for_path(&self, _path: &[u32]) -> Result<Keypair, Error> {
149        // Static provider always returns the same keypair
150        Ok(self.kp)
151    }
152
153    fn get_keypair_for_pk(&self, pk: &bitcoin::XOnlyPublicKey) -> Result<Keypair, Error> {
154        // Verify that the requested public key matches our keypair
155        let our_pk = self.kp.x_only_public_key().0;
156        if &our_pk == pk {
157            Ok(self.kp)
158        } else {
159            Err(Error::ad_hoc(format!(
160                "Public key mismatch: requested {pk}, but only have {our_pk}"
161            )))
162        }
163    }
164
165    fn get_cached_pks(&self) -> Result<Vec<bitcoin::XOnlyPublicKey>, Error> {
166        Ok(vec![self.kp.public_key().into()])
167    }
168}
169
170/// A BIP32 hierarchical deterministic key provider
171///
172/// This provider derives keypairs from a master extended private key
173/// using BIP32 derivation paths. It maintains an index counter for
174/// generating new receiving addresses.
175///
176/// ## Example
177///
178/// ```rust
179/// # use std::str::FromStr;
180/// # use bitcoin::bip32::{Xpriv, DerivationPath};
181/// # use bitcoin::Network;
182/// # use crate::ark_client::KeyProvider;
183/// # use ark_client::Bip32KeyProvider;
184/// # use ark_client::key_provider::KeypairIndex;
185///
186/// fn example() -> Result<(), Box<dyn std::error::Error>> {
187/// // Create from a master key with a base path (e.g., m/84'/0'/0'/0)
188/// let master_key = Xpriv::from_str("xprv...")?;
189/// let base_path = DerivationPath::from_str("m/84'/0'/0'/0")?;
190///
191/// // This will derive keys at m/84'/0'/0'/0/0, m/84'/0'/0'/0/1, etc.
192/// let provider = Bip32KeyProvider::new(master_key, base_path);
193///
194/// // Get the next receiving keypair (increments index)
195/// let kp1 = provider.get_next_keypair(KeypairIndex::New)?; // m/84'/0'/0'/0/0
196/// let kp2 = provider.get_next_keypair(KeypairIndex::New)?; // m/84'/0'/0'/0/1
197///
198/// // Or derive a specific keypair by path
199/// let custom_path = vec![84 + 0x8000_0000, 0x8000_0000, 0x8000_0000, 0, 5];
200/// let kp = provider.get_keypair_for_path(&custom_path)?;
201/// # Ok(())
202/// # }
203/// ```
204pub struct Bip32KeyProvider {
205    master_key: Xpriv,
206    base_path: DerivationPath,
207    // Using std::sync::Mutex for interior mutability across Send + Sync
208    next_index: Arc<std::sync::Mutex<u32>>,
209    // Cache of derived keys: pk -> (path_index, keypair, used)
210    // The `used` flag indicates whether this keypair has been used (has VTXOs)
211    key_cache:
212        Arc<std::sync::RwLock<std::collections::HashMap<bitcoin::XOnlyPublicKey, KeyCacheValue>>>,
213}
214
215#[derive(Clone, Copy)]
216pub struct KeyCacheValue {
217    path_index: u32,
218    kp: Keypair,
219    /// Indicates whether this keypair has been used (has VTXOs).
220    used: bool,
221}
222
223impl Bip32KeyProvider {
224    /// Create a new BIP32 key provider
225    ///
226    /// # Arguments
227    ///
228    /// * `master_key` - The master extended private key (xpriv)
229    /// * `base_path` - The base derivation path (e.g., m/84'/0'/0'/0). The provider will append
230    ///   index numbers to this path.
231    pub fn new(master_key: Xpriv, base_path: DerivationPath) -> Self {
232        Self {
233            master_key,
234            base_path,
235            next_index: Arc::new(std::sync::Mutex::new(0)),
236            key_cache: Arc::new(std::sync::RwLock::new(std::collections::HashMap::new())),
237        }
238    }
239
240    /// Create a new BIP32 key provider starting from a specific index
241    ///
242    /// # Arguments
243    ///
244    /// * `master_key` - The master extended private key (xpriv)
245    /// * `base_path` - The base derivation path
246    /// * `start_index` - The starting index for key derivation
247    pub fn new_with_index(master_key: Xpriv, base_path: DerivationPath, start_index: u32) -> Self {
248        Self {
249            master_key,
250            base_path,
251            next_index: Arc::new(std::sync::Mutex::new(start_index)),
252            key_cache: Arc::new(std::sync::RwLock::new(std::collections::HashMap::new())),
253        }
254    }
255
256    /// Derive a keypair at the specified path
257    fn derive_keypair(&self, path: &DerivationPath) -> Result<Keypair, Error> {
258        let secp = Secp256k1::new();
259        let derived_key = self
260            .master_key
261            .derive_priv(&secp, path)
262            .map_err(|e| Error::ad_hoc(format!("BIP32 derivation failed: {e}")))?;
263
264        Ok(derived_key.to_keypair(&secp))
265    }
266
267    /// Derive a keypair at base_path/index
268    fn derive_at_index(&self, index: u32) -> Result<Keypair, Error> {
269        use bitcoin::bip32::ChildNumber;
270
271        let path = self.base_path.clone();
272        let path = path.extend([ChildNumber::Normal { index }]);
273
274        self.derive_keypair(&path)
275    }
276}
277
278impl KeyProvider for Bip32KeyProvider {
279    fn get_next_keypair(&self, keypair_index: KeypairIndex) -> Result<Keypair, Error> {
280        match keypair_index {
281            KeypairIndex::New => {
282                // Get and increment the next index
283                let index = {
284                    let mut next_index = self
285                        .next_index
286                        .lock()
287                        .map_err(|e| Error::ad_hoc(format!("Failed to lock next_index: {e}")))?;
288                    let current = *next_index;
289                    *next_index = next_index
290                        .checked_add(1)
291                        .ok_or_else(|| Error::ad_hoc("Key derivation index overflow"))?;
292                    current
293                };
294
295                // Derive the keypair at this index
296                let kp = self.derive_at_index(index)?;
297
298                // Cache it for later lookup (marked as unused)
299                let pk = kp.x_only_public_key().0;
300                {
301                    let mut cache = self
302                        .key_cache
303                        .write()
304                        .map_err(|e| Error::ad_hoc(format!("Failed to lock key_cache: {e}")))?;
305                    cache.insert(
306                        pk,
307                        KeyCacheValue {
308                            path_index: index,
309                            kp,
310                            used: false,
311                        },
312                    );
313                }
314
315                Ok(kp)
316            }
317            KeypairIndex::LastUnused => {
318                // First, try to find an unused keypair in the cache
319                {
320                    let cache = self
321                        .key_cache
322                        .read()
323                        .map_err(|e| Error::ad_hoc(format!("Failed to lock key_cache: {e}")))?;
324
325                    // Find the unused keypair with the lowest index
326                    let unused = cache
327                        .values()
328                        .filter(|KeyCacheValue { used, .. }| !used)
329                        .min_by_key(|KeyCacheValue { path_index, .. }| *path_index);
330
331                    if let Some(KeyCacheValue { kp, .. }) = unused {
332                        return Ok(*kp);
333                    }
334                }
335
336                // No unused keypair found, derive a new one
337                self.get_next_keypair(KeypairIndex::New)
338            }
339        }
340    }
341
342    fn get_keypair_for_path(&self, path: &[u32]) -> Result<Keypair, Error> {
343        use bitcoin::bip32::ChildNumber;
344        let child_numbers: Vec<ChildNumber> = path
345            .iter()
346            .map(|&n| {
347                if n & 0x8000_0000 != 0 {
348                    ChildNumber::Hardened {
349                        index: n & 0x7FFF_FFFF,
350                    }
351                } else {
352                    ChildNumber::Normal { index: n }
353                }
354            })
355            .collect();
356        let derivation_path = DerivationPath::from(child_numbers);
357        self.derive_keypair(&derivation_path)
358    }
359
360    fn get_keypair_for_pk(&self, pk: &bitcoin::XOnlyPublicKey) -> Result<Keypair, Error> {
361        // First check the cache
362        {
363            let cache = self
364                .key_cache
365                .read()
366                .map_err(|e| Error::ad_hoc(format!("Failed to lock key_cache: {e}")))?;
367            if let Some(KeyCacheValue { kp, .. }) = cache.get(pk) {
368                return Ok(*kp);
369            }
370        }
371
372        // If not in cache, we need to search. For now, we'll search up to the current index
373        let current_index = {
374            let next_index = self
375                .next_index
376                .lock()
377                .map_err(|e| Error::ad_hoc(format!("Failed to lock next_index: {e}")))?;
378            *next_index
379        };
380
381        // Search through derived keys up to current index
382        for i in 0..current_index {
383            let kp = self.derive_at_index(i)?;
384            let derived_pk = kp.x_only_public_key().0;
385
386            if &derived_pk == pk {
387                // Cache it for next time (assume used since we're looking it up for signing)
388                let mut cache = self
389                    .key_cache
390                    .write()
391                    .map_err(|e| Error::ad_hoc(format!("Failed to lock key_cache: {e}")))?;
392                cache.insert(
393                    derived_pk,
394                    KeyCacheValue {
395                        path_index: i,
396                        kp,
397                        used: true,
398                    },
399                );
400                return Ok(kp);
401            }
402        }
403
404        Err(Error::ad_hoc(format!(
405            "Public key {pk} not found in HD wallet. \
406            Searched indices 0..{current_index}. \
407            The key may have been generated outside this provider."
408        )))
409    }
410
411    fn get_cached_pks(&self) -> Result<Vec<bitcoin::XOnlyPublicKey>, Error> {
412        let cache = self
413            .key_cache
414            .read()
415            .map_err(|e| Error::ad_hoc(format!("Failed to lock key_cache: {e}")))?;
416
417        Ok(cache.keys().copied().collect())
418    }
419
420    fn supports_discovery(&self) -> bool {
421        true
422    }
423
424    fn get_derivation_index_for_pk(&self, pk: &bitcoin::XOnlyPublicKey) -> Option<u32> {
425        let cache = self.key_cache.read().ok()?;
426        cache.get(pk).map(|v| v.path_index)
427    }
428
429    fn derive_at_discovery_index(&self, index: u32) -> Result<Option<Keypair>, Error> {
430        self.derive_at_index(index).map(Some)
431    }
432
433    fn cache_discovered_keypair(&self, index: u32, kp: Keypair) -> Result<(), Error> {
434        let pk = kp.x_only_public_key().0;
435
436        // Add to cache (marked as used since it was discovered with VTXOs)
437        {
438            let mut cache = self
439                .key_cache
440                .write()
441                .map_err(|e| Error::ad_hoc(format!("Failed to lock key_cache: {e}")))?;
442            cache.insert(
443                pk,
444                KeyCacheValue {
445                    path_index: index,
446                    kp,
447                    used: true,
448                },
449            );
450        }
451
452        // Update next_index if needed (set to index + 1 if >= current)
453        {
454            let mut next = self
455                .next_index
456                .lock()
457                .map_err(|e| Error::ad_hoc(format!("Failed to lock next_index: {e}")))?;
458            if index >= *next {
459                *next = index
460                    .checked_add(1)
461                    .ok_or_else(|| Error::ad_hoc("Key derivation index overflow"))?;
462            }
463        }
464
465        Ok(())
466    }
467
468    fn mark_as_used(&self, pk: &bitcoin::XOnlyPublicKey) -> Result<(), Error> {
469        // First check the cache
470        {
471            let maybe_kp = {
472                let cache = self
473                    .key_cache
474                    .read()
475                    .map_err(|e| Error::ad_hoc(format!("Failed to lock key_cache: {e}")))?;
476                cache.get(pk).copied()
477            };
478
479            match maybe_kp {
480                Some(KeyCacheValue {
481                    path_index,
482                    kp,
483                    used: false,
484                }) => {
485                    let mut cache = self
486                        .key_cache
487                        .write()
488                        .map_err(|e| Error::ad_hoc(format!("Failed to lock key_cache: {e}")))?;
489                    cache.insert(
490                        *pk,
491                        KeyCacheValue {
492                            path_index,
493                            kp,
494                            used: true,
495                        },
496                    );
497                    return Ok(());
498                }
499                Some(KeyCacheValue { used: true, .. }) => {
500                    // already marked as used
501                    return Ok(());
502                }
503                _ => {
504                    // no found
505                }
506            }
507        }
508
509        // If not in cache, we need to search. For now, we'll search up to the current index
510        let current_index = {
511            let next_index = self
512                .next_index
513                .lock()
514                .map_err(|e| Error::ad_hoc(format!("Failed to lock next_index: {e}")))?;
515            *next_index
516        };
517
518        // Search through derived keys up to current index
519        for i in 0..current_index {
520            let kp = self.derive_at_index(i)?;
521            let derived_pk = kp.x_only_public_key().0;
522
523            if &derived_pk == pk {
524                // Cache it for next time (assume used since we're looking it up for signing)
525                let mut cache = self
526                    .key_cache
527                    .write()
528                    .map_err(|e| Error::ad_hoc(format!("Failed to lock key_cache: {e}")))?;
529                cache.insert(
530                    derived_pk,
531                    KeyCacheValue {
532                        path_index: i,
533                        kp,
534                        used: true,
535                    },
536                );
537                return Ok(());
538            }
539        }
540
541        Err(Error::ad_hoc(format!(
542            "Public key {pk} not found in HD wallet. \
543            Searched indices 0..{current_index}. \
544            The key may have been generated outside this provider."
545        )))
546    }
547}
548
549// Implement KeyProvider for Arc<T> where T: KeyProvider
550impl<T: KeyProvider> KeyProvider for Arc<T> {
551    fn get_next_keypair(&self, keypair_index: KeypairIndex) -> Result<Keypair, Error> {
552        (**self).get_next_keypair(keypair_index)
553    }
554
555    fn get_keypair_for_path(&self, path: &[u32]) -> Result<Keypair, Error> {
556        (**self).get_keypair_for_path(path)
557    }
558
559    fn get_keypair_for_pk(&self, pk: &bitcoin::XOnlyPublicKey) -> Result<Keypair, Error> {
560        (**self).get_keypair_for_pk(pk)
561    }
562
563    fn get_cached_pks(&self) -> Result<Vec<bitcoin::XOnlyPublicKey>, Error> {
564        (**self).get_cached_pks()
565    }
566
567    fn supports_discovery(&self) -> bool {
568        (**self).supports_discovery()
569    }
570
571    fn get_derivation_index_for_pk(&self, pk: &bitcoin::XOnlyPublicKey) -> Option<u32> {
572        (**self).get_derivation_index_for_pk(pk)
573    }
574
575    fn derive_at_discovery_index(&self, index: u32) -> Result<Option<Keypair>, Error> {
576        (**self).derive_at_discovery_index(index)
577    }
578
579    fn cache_discovered_keypair(&self, index: u32, kp: Keypair) -> Result<(), Error> {
580        (**self).cache_discovered_keypair(index, kp)
581    }
582
583    fn mark_as_used(&self, pk: &bitcoin::XOnlyPublicKey) -> Result<(), Error> {
584        (**self).mark_as_used(pk)
585    }
586}