Skip to main content

ark_client/
key_provider.rs

1use crate::error::Error;
2use bitcoin::bip32::DerivationPath;
3use bitcoin::bip32::Xpriv;
4use bitcoin::key::Keypair;
5use bitcoin::secp256k1::Secp256k1;
6use std::sync::Arc;
7
/// Selects which receiving keypair `KeyProvider::get_next_keypair` should return.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum KeypairIndex {
    /// Advance the provider's derivation index and return a freshly derived keypair.
    New,
    /// Reuse a keypair that was handed out earlier but has not been marked as used.
    ///
    /// The exact choice is provider-specific: the BIP32 provider returns the unused cached
    /// keypair with the lowest index and derives a new one if none exists, while the static
    /// provider always returns its single keypair.
    LastUnused,
}
14
15/// Provides keypairs for signing operations
16///
17/// This trait allows different key management strategies:
18/// - Static keypair (single key)
19/// - BIP32 HD wallet (hierarchical deterministic)
20/// - Hardware wallets (future)
21/// - Custom key derivation schemes
22pub trait KeyProvider: Send + Sync {
23    /// Get a keypair for receiving funds
24    ///
25    /// For static key providers, this always returns the same keypair regardless of the index.
26    /// For HD wallets, behavior depends on the `keypair_index` parameter.
27    ///
28    /// # Arguments
29    ///
30    /// * `keypair_index` - Controls which keypair to return:
31    ///   - `KeypairIndex::New`: Increments the internal index and returns a new keypair
32    ///   - `KeypairIndex::LastUnused`: Returns the last unused keypair without incrementing
33    ///
34    /// # Returns
35    ///
36    /// A keypair to use for receiving funds
37    fn get_next_keypair(&self, keypair_index: KeypairIndex) -> Result<Keypair, Error>;
38
39    /// Get a keypair for a specific BIP32 derivation path
40    ///
41    /// # Arguments
42    ///
43    /// * `path` - BIP32 derivation path as an array of child indexes
44    ///
45    /// # Returns
46    ///
47    /// A keypair derived at the specified path, or an error if derivation is not supported
48    fn get_keypair_for_path(&self, path: &[u32]) -> Result<Keypair, Error>;
49
50    /// Get a keypair for a specific public key
51    ///
52    /// This is essential for HD wallets where you need to find the correct keypair
53    /// for signing with a previously generated public key.
54    ///
55    /// # Arguments
56    ///
57    /// * `pk` - The X-only public key to find the keypair for
58    ///
59    /// # Returns
60    ///
61    /// The keypair corresponding to the public key, or an error if not found
62    fn get_keypair_for_pk(&self, pk: &bitcoin::XOnlyPublicKey) -> Result<Keypair, Error>;
63
64    /// Get all public keys that this provider currently knows about
65    ///
66    /// For static key providers, this returns the single keypair's public key.
67    /// For HD wallets, this returns all public keys that have been derived and cached
68    /// (i.e., keys generated via `get_next_keypair`).
69    ///
70    /// This is useful for determining which keys are available for signing operations
71    /// without having to search or derive new keys.
72    ///
73    /// # Returns
74    ///
75    /// A vector of X-only public keys known to this provider
76    fn get_cached_pks(&self) -> Result<Vec<bitcoin::XOnlyPublicKey>, Error>;
77
78    /// Returns true if this provider supports key discovery
79    ///
80    /// HD wallets return true since they can derive and discover previously used keys.
81    /// Static key providers return false (single key, nothing to discover).
82    fn supports_discovery(&self) -> bool {
83        false
84    }
85
86    /// Derive a keypair at a specific index without caching
87    ///
88    /// This is used during discovery to check keys without affecting the provider's state.
89    /// Returns `None` if the provider doesn't support index-based derivation.
90    ///
91    /// # Arguments
92    ///
93    /// * `index` - The derivation index (appended to base path for HD wallets)
94    fn derive_at_discovery_index(&self, _index: u32) -> Result<Option<Keypair>, Error> {
95        Ok(None)
96    }
97
98    /// Cache a discovered keypair at the given index
99    ///
100    /// This is called after discovery determines a key is "used" (has VTXOs).
101    /// Also updates next_index if index >= current next_index to avoid collisions.
102    ///
103    /// No-op for providers that don't support discovery.
104    ///
105    /// # Arguments
106    ///
107    /// * `index` - The derivation index
108    /// * `kp` - The keypair to cache
109    fn cache_discovered_keypair(&self, _index: u32, _kp: Keypair) -> Result<(), Error> {
110        Ok(())
111    }
112
113    fn mark_as_used(&self, _pk: &bitcoin::XOnlyPublicKey) -> Result<(), Error> {
114        Ok(())
115    }
116}
117
118/// A simple key provider that uses a static keypair
119///
120/// This is the simplest implementation and is backward compatible with
121/// the original single-keypair design.
122#[derive(Clone)]
123pub struct StaticKeyProvider {
124    kp: Keypair,
125}
126
127impl StaticKeyProvider {
128    /// Create a new static key provider
129    pub fn new(kp: Keypair) -> Self {
130        Self { kp }
131    }
132}
133
134impl KeyProvider for StaticKeyProvider {
135    fn get_next_keypair(&self, _: KeypairIndex) -> Result<Keypair, Error> {
136        // Static provider always returns the same keypair
137        Ok(self.kp)
138    }
139
140    fn get_keypair_for_path(&self, _path: &[u32]) -> Result<Keypair, Error> {
141        // Static provider always returns the same keypair
142        Ok(self.kp)
143    }
144
145    fn get_keypair_for_pk(&self, pk: &bitcoin::XOnlyPublicKey) -> Result<Keypair, Error> {
146        // Verify that the requested public key matches our keypair
147        let our_pk = self.kp.x_only_public_key().0;
148        if &our_pk == pk {
149            Ok(self.kp)
150        } else {
151            Err(Error::ad_hoc(format!(
152                "Public key mismatch: requested {pk}, but only have {our_pk}"
153            )))
154        }
155    }
156
157    fn get_cached_pks(&self) -> Result<Vec<bitcoin::XOnlyPublicKey>, Error> {
158        Ok(vec![self.kp.public_key().into()])
159    }
160}
161
162/// A BIP32 hierarchical deterministic key provider
163///
164/// This provider derives keypairs from a master extended private key
165/// using BIP32 derivation paths. It maintains an index counter for
166/// generating new receiving addresses.
167///
168/// ## Example
169///
170/// ```rust
171/// # use std::str::FromStr;
172/// # use bitcoin::bip32::{Xpriv, DerivationPath};
173/// # use bitcoin::Network;
174/// # use crate::ark_client::KeyProvider;
175/// # use ark_client::Bip32KeyProvider;
176/// # use ark_client::key_provider::KeypairIndex;
177///
178/// fn example() -> Result<(), Box<dyn std::error::Error>> {
179/// // Create from a master key with a base path (e.g., m/84'/0'/0'/0)
180/// let master_key = Xpriv::from_str("xprv...")?;
181/// let base_path = DerivationPath::from_str("m/84'/0'/0'/0")?;
182///
183/// // This will derive keys at m/84'/0'/0'/0/0, m/84'/0'/0'/0/1, etc.
184/// let provider = Bip32KeyProvider::new(master_key, base_path);
185///
186/// // Get the next receiving keypair (increments index)
187/// let kp1 = provider.get_next_keypair(KeypairIndex::New)?; // m/84'/0'/0'/0/0
188/// let kp2 = provider.get_next_keypair(KeypairIndex::New)?; // m/84'/0'/0'/0/1
189///
190/// // Or derive a specific keypair by path
191/// let custom_path = vec![84 + 0x8000_0000, 0x8000_0000, 0x8000_0000, 0, 5];
192/// let kp = provider.get_keypair_for_path(&custom_path)?;
193/// # Ok(())
194/// # }
195/// ```
196pub struct Bip32KeyProvider {
197    master_key: Xpriv,
198    base_path: DerivationPath,
199    // Using std::sync::Mutex for interior mutability across Send + Sync
200    next_index: Arc<std::sync::Mutex<u32>>,
201    // Cache of derived keys: pk -> (path_index, keypair, used)
202    // The `used` flag indicates whether this keypair has been used (has VTXOs)
203    key_cache:
204        Arc<std::sync::RwLock<std::collections::HashMap<bitcoin::XOnlyPublicKey, KeyCacheValue>>>,
205}
206
207#[derive(Clone, Copy)]
208pub struct KeyCacheValue {
209    path_index: u32,
210    kp: Keypair,
211    /// Indicates whether this keypair has been used (has VTXOs).
212    used: bool,
213}
214
215impl Bip32KeyProvider {
216    /// Create a new BIP32 key provider
217    ///
218    /// # Arguments
219    ///
220    /// * `master_key` - The master extended private key (xpriv)
221    /// * `base_path` - The base derivation path (e.g., m/84'/0'/0'/0). The provider will append
222    ///   index numbers to this path.
223    pub fn new(master_key: Xpriv, base_path: DerivationPath) -> Self {
224        Self {
225            master_key,
226            base_path,
227            next_index: Arc::new(std::sync::Mutex::new(0)),
228            key_cache: Arc::new(std::sync::RwLock::new(std::collections::HashMap::new())),
229        }
230    }
231
232    /// Create a new BIP32 key provider starting from a specific index
233    ///
234    /// # Arguments
235    ///
236    /// * `master_key` - The master extended private key (xpriv)
237    /// * `base_path` - The base derivation path
238    /// * `start_index` - The starting index for key derivation
239    pub fn new_with_index(master_key: Xpriv, base_path: DerivationPath, start_index: u32) -> Self {
240        Self {
241            master_key,
242            base_path,
243            next_index: Arc::new(std::sync::Mutex::new(start_index)),
244            key_cache: Arc::new(std::sync::RwLock::new(std::collections::HashMap::new())),
245        }
246    }
247
248    /// Derive a keypair at the specified path
249    fn derive_keypair(&self, path: &DerivationPath) -> Result<Keypair, Error> {
250        let secp = Secp256k1::new();
251        let derived_key = self
252            .master_key
253            .derive_priv(&secp, path)
254            .map_err(|e| Error::ad_hoc(format!("BIP32 derivation failed: {e}")))?;
255
256        Ok(derived_key.to_keypair(&secp))
257    }
258
259    /// Derive a keypair at base_path/index
260    fn derive_at_index(&self, index: u32) -> Result<Keypair, Error> {
261        use bitcoin::bip32::ChildNumber;
262
263        let path = self.base_path.clone();
264        let path = path.extend([ChildNumber::Normal { index }]);
265
266        self.derive_keypair(&path)
267    }
268}
269
270impl KeyProvider for Bip32KeyProvider {
271    fn get_next_keypair(&self, keypair_index: KeypairIndex) -> Result<Keypair, Error> {
272        match keypair_index {
273            KeypairIndex::New => {
274                // Get and increment the next index
275                let index = {
276                    let mut next_index = self
277                        .next_index
278                        .lock()
279                        .map_err(|e| Error::ad_hoc(format!("Failed to lock next_index: {e}")))?;
280                    let current = *next_index;
281                    *next_index = next_index
282                        .checked_add(1)
283                        .ok_or_else(|| Error::ad_hoc("Key derivation index overflow"))?;
284                    current
285                };
286
287                // Derive the keypair at this index
288                let kp = self.derive_at_index(index)?;
289
290                // Cache it for later lookup (marked as unused)
291                let pk = kp.x_only_public_key().0;
292                {
293                    let mut cache = self
294                        .key_cache
295                        .write()
296                        .map_err(|e| Error::ad_hoc(format!("Failed to lock key_cache: {e}")))?;
297                    cache.insert(
298                        pk,
299                        KeyCacheValue {
300                            path_index: index,
301                            kp,
302                            used: false,
303                        },
304                    );
305                }
306
307                Ok(kp)
308            }
309            KeypairIndex::LastUnused => {
310                // First, try to find an unused keypair in the cache
311                {
312                    let cache = self
313                        .key_cache
314                        .read()
315                        .map_err(|e| Error::ad_hoc(format!("Failed to lock key_cache: {e}")))?;
316
317                    // Find the unused keypair with the lowest index
318                    let unused = cache
319                        .values()
320                        .filter(|KeyCacheValue { used, .. }| !used)
321                        .min_by_key(|KeyCacheValue { path_index, .. }| *path_index);
322
323                    if let Some(KeyCacheValue { kp, .. }) = unused {
324                        return Ok(*kp);
325                    }
326                }
327
328                // No unused keypair found, derive a new one
329                self.get_next_keypair(KeypairIndex::New)
330            }
331        }
332    }
333
334    fn get_keypair_for_path(&self, path: &[u32]) -> Result<Keypair, Error> {
335        use bitcoin::bip32::ChildNumber;
336        let child_numbers: Vec<ChildNumber> = path
337            .iter()
338            .map(|&n| {
339                if n & 0x8000_0000 != 0 {
340                    ChildNumber::Hardened {
341                        index: n & 0x7FFF_FFFF,
342                    }
343                } else {
344                    ChildNumber::Normal { index: n }
345                }
346            })
347            .collect();
348        let derivation_path = DerivationPath::from(child_numbers);
349        self.derive_keypair(&derivation_path)
350    }
351
352    fn get_keypair_for_pk(&self, pk: &bitcoin::XOnlyPublicKey) -> Result<Keypair, Error> {
353        // First check the cache
354        {
355            let cache = self
356                .key_cache
357                .read()
358                .map_err(|e| Error::ad_hoc(format!("Failed to lock key_cache: {e}")))?;
359            if let Some(KeyCacheValue { kp, .. }) = cache.get(pk) {
360                return Ok(*kp);
361            }
362        }
363
364        // If not in cache, we need to search. For now, we'll search up to the current index
365        let current_index = {
366            let next_index = self
367                .next_index
368                .lock()
369                .map_err(|e| Error::ad_hoc(format!("Failed to lock next_index: {e}")))?;
370            *next_index
371        };
372
373        // Search through derived keys up to current index
374        for i in 0..current_index {
375            let kp = self.derive_at_index(i)?;
376            let derived_pk = kp.x_only_public_key().0;
377
378            if &derived_pk == pk {
379                // Cache it for next time (assume used since we're looking it up for signing)
380                let mut cache = self
381                    .key_cache
382                    .write()
383                    .map_err(|e| Error::ad_hoc(format!("Failed to lock key_cache: {e}")))?;
384                cache.insert(
385                    derived_pk,
386                    KeyCacheValue {
387                        path_index: i,
388                        kp,
389                        used: true,
390                    },
391                );
392                return Ok(kp);
393            }
394        }
395
396        Err(Error::ad_hoc(format!(
397            "Public key {pk} not found in HD wallet. \
398            Searched indices 0..{current_index}. \
399            The key may have been generated outside this provider."
400        )))
401    }
402
403    fn get_cached_pks(&self) -> Result<Vec<bitcoin::XOnlyPublicKey>, Error> {
404        let cache = self
405            .key_cache
406            .read()
407            .map_err(|e| Error::ad_hoc(format!("Failed to lock key_cache: {e}")))?;
408
409        Ok(cache.keys().copied().collect())
410    }
411
412    fn supports_discovery(&self) -> bool {
413        true
414    }
415
416    fn derive_at_discovery_index(&self, index: u32) -> Result<Option<Keypair>, Error> {
417        self.derive_at_index(index).map(Some)
418    }
419
420    fn cache_discovered_keypair(&self, index: u32, kp: Keypair) -> Result<(), Error> {
421        let pk = kp.x_only_public_key().0;
422
423        // Add to cache (marked as used since it was discovered with VTXOs)
424        {
425            let mut cache = self
426                .key_cache
427                .write()
428                .map_err(|e| Error::ad_hoc(format!("Failed to lock key_cache: {e}")))?;
429            cache.insert(
430                pk,
431                KeyCacheValue {
432                    path_index: index,
433                    kp,
434                    used: true,
435                },
436            );
437        }
438
439        // Update next_index if needed (set to index + 1 if >= current)
440        {
441            let mut next = self
442                .next_index
443                .lock()
444                .map_err(|e| Error::ad_hoc(format!("Failed to lock next_index: {e}")))?;
445            if index >= *next {
446                *next = index
447                    .checked_add(1)
448                    .ok_or_else(|| Error::ad_hoc("Key derivation index overflow"))?;
449            }
450        }
451
452        Ok(())
453    }
454
455    fn mark_as_used(&self, pk: &bitcoin::XOnlyPublicKey) -> Result<(), Error> {
456        // First check the cache
457        {
458            let maybe_kp = {
459                let cache = self
460                    .key_cache
461                    .read()
462                    .map_err(|e| Error::ad_hoc(format!("Failed to lock key_cache: {e}")))?;
463                cache.get(pk).copied()
464            };
465
466            match maybe_kp {
467                Some(KeyCacheValue {
468                    path_index,
469                    kp,
470                    used: false,
471                }) => {
472                    let mut cache = self
473                        .key_cache
474                        .write()
475                        .map_err(|e| Error::ad_hoc(format!("Failed to lock key_cache: {e}")))?;
476                    cache.insert(
477                        *pk,
478                        KeyCacheValue {
479                            path_index,
480                            kp,
481                            used: true,
482                        },
483                    );
484                    return Ok(());
485                }
486                Some(KeyCacheValue { used: true, .. }) => {
487                    // already marked as used
488                    return Ok(());
489                }
490                _ => {
491                    // no found
492                }
493            }
494        }
495
496        // If not in cache, we need to search. For now, we'll search up to the current index
497        let current_index = {
498            let next_index = self
499                .next_index
500                .lock()
501                .map_err(|e| Error::ad_hoc(format!("Failed to lock next_index: {e}")))?;
502            *next_index
503        };
504
505        // Search through derived keys up to current index
506        for i in 0..current_index {
507            let kp = self.derive_at_index(i)?;
508            let derived_pk = kp.x_only_public_key().0;
509
510            if &derived_pk == pk {
511                // Cache it for next time (assume used since we're looking it up for signing)
512                let mut cache = self
513                    .key_cache
514                    .write()
515                    .map_err(|e| Error::ad_hoc(format!("Failed to lock key_cache: {e}")))?;
516                cache.insert(
517                    derived_pk,
518                    KeyCacheValue {
519                        path_index: i,
520                        kp,
521                        used: true,
522                    },
523                );
524                return Ok(());
525            }
526        }
527
528        Err(Error::ad_hoc(format!(
529            "Public key {pk} not found in HD wallet. \
530            Searched indices 0..{current_index}. \
531            The key may have been generated outside this provider."
532        )))
533    }
534}
535
536// Implement KeyProvider for Arc<T> where T: KeyProvider
537impl<T: KeyProvider> KeyProvider for Arc<T> {
538    fn get_next_keypair(&self, keypair_index: KeypairIndex) -> Result<Keypair, Error> {
539        (**self).get_next_keypair(keypair_index)
540    }
541
542    fn get_keypair_for_path(&self, path: &[u32]) -> Result<Keypair, Error> {
543        (**self).get_keypair_for_path(path)
544    }
545
546    fn get_keypair_for_pk(&self, pk: &bitcoin::XOnlyPublicKey) -> Result<Keypair, Error> {
547        (**self).get_keypair_for_pk(pk)
548    }
549
550    fn get_cached_pks(&self) -> Result<Vec<bitcoin::XOnlyPublicKey>, Error> {
551        (**self).get_cached_pks()
552    }
553
554    fn supports_discovery(&self) -> bool {
555        (**self).supports_discovery()
556    }
557
558    fn derive_at_discovery_index(&self, index: u32) -> Result<Option<Keypair>, Error> {
559        (**self).derive_at_discovery_index(index)
560    }
561
562    fn cache_discovered_keypair(&self, index: u32, kp: Keypair) -> Result<(), Error> {
563        (**self).cache_discovered_keypair(index, kp)
564    }
565
566    fn mark_as_used(&self, pk: &bitcoin::XOnlyPublicKey) -> Result<(), Error> {
567        (**self).mark_as_used(pk)
568    }
569}