use crate::error::Error;
use bitcoin::bip32::DerivationPath;
use bitcoin::bip32::Xpriv;
use bitcoin::key::Keypair;
use bitcoin::secp256k1::Secp256k1;
use std::sync::Arc;
/// Hint passed to [`KeyProvider::get_next_keypair`] describing which keypair is wanted.
///
/// Providers may ignore the hint (a provider holding a single key has nothing to choose from).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum KeypairIndex {
    /// Request a freshly issued keypair.
    New,
    /// Request a previously issued keypair that has not yet been marked as used,
    /// if the provider tracks usage; otherwise behave like [`KeypairIndex::New`].
    LastUnused,
}
/// Source of signing keypairs.
///
/// The `Send + Sync` supertraits allow a provider to be shared across threads.
/// The last five methods (`supports_discovery` .. `mark_as_used`) are optional hooks
/// with no-op defaults; only providers that track derivation indices need to override them.
pub trait KeyProvider: Send + Sync {
    /// Returns a keypair; `keypair_index` says whether a fresh key or a previously
    /// issued, not-yet-used key is preferred.
    /// NOTE(review): implementations may ignore the hint — confirm callers do not rely on it.
    fn get_next_keypair(&self, keypair_index: KeypairIndex) -> Result<Keypair, Error>;
    /// Returns the keypair for `path`, given as raw `u32` child numbers.
    /// NOTE(review): presumably a set high bit marks a hardened child (the BIP32
    /// implementation in this file decodes it that way); other providers may ignore `path`.
    fn get_keypair_for_path(&self, path: &[u32]) -> Result<Keypair, Error>;
    /// Returns the keypair whose x-only public key equals `pk`, or an error if this
    /// provider cannot produce it.
    fn get_keypair_for_pk(&self, pk: &bitcoin::XOnlyPublicKey) -> Result<Keypair, Error>;
    /// Returns the x-only public keys the provider currently holds in memory.
    fn get_cached_pks(&self) -> Result<Vec<bitcoin::XOnlyPublicKey>, Error>;
    /// Whether the discovery hooks below are implemented. Defaults to `false`.
    fn supports_discovery(&self) -> bool {
        false
    }
    /// Returns the derivation index of `pk` if the provider knows it. Defaults to `None`.
    fn get_derivation_index_for_pk(&self, _pk: &bitcoin::XOnlyPublicKey) -> Option<u32> {
        None
    }
    /// Derives the keypair at discovery `index`.
    /// The default returns `Ok(None)`, presumably meaning "derivation by index is not
    /// supported" — TODO confirm against the discovery caller.
    fn derive_at_discovery_index(&self, _index: u32) -> Result<Option<Keypair>, Error> {
        Ok(None)
    }
    /// Records `_kp`, found during discovery at `_index`. The default does nothing.
    fn cache_discovered_keypair(&self, _index: u32, _kp: Keypair) -> Result<(), Error> {
        Ok(())
    }
    /// Marks `_pk` as used (for usage-tracking providers this stops it being returned
    /// for `KeypairIndex::LastUnused`). The default does nothing.
    fn mark_as_used(&self, _pk: &bitcoin::XOnlyPublicKey) -> Result<(), Error> {
        Ok(())
    }
}
/// A [`KeyProvider`] that owns exactly one keypair and returns it for every request.
#[derive(Clone)]
pub struct StaticKeyProvider {
    /// The single keypair handed out by this provider.
    kp: Keypair,
}

impl StaticKeyProvider {
    /// Wraps `kp` so it can be used wherever a [`KeyProvider`] is expected.
    pub fn new(kp: Keypair) -> Self {
        StaticKeyProvider { kp }
    }
}
impl KeyProvider for StaticKeyProvider {
fn get_next_keypair(&self, _: KeypairIndex) -> Result<Keypair, Error> {
Ok(self.kp)
}
fn get_keypair_for_path(&self, _path: &[u32]) -> Result<Keypair, Error> {
Ok(self.kp)
}
fn get_keypair_for_pk(&self, pk: &bitcoin::XOnlyPublicKey) -> Result<Keypair, Error> {
let our_pk = self.kp.x_only_public_key().0;
if &our_pk == pk {
Ok(self.kp)
} else {
Err(Error::ad_hoc(format!(
"Public key mismatch: requested {pk}, but only have {our_pk}"
)))
}
}
fn get_cached_pks(&self) -> Result<Vec<bitcoin::XOnlyPublicKey>, Error> {
Ok(vec![self.kp.public_key().into()])
}
}
/// HD-wallet [`KeyProvider`] that derives keypairs as non-hardened children of
/// `base_path` under a BIP32 master key.
///
/// Every keypair it derives is cached by x-only public key, so later lookups do not
/// need to re-derive it.
pub struct Bip32KeyProvider {
    /// Root extended private key all keypairs are derived from.
    master_key: Xpriv,
    /// Parent path; keypairs live at `base_path/<index>`.
    base_path: DerivationPath,
    /// Next child index a `KeypairIndex::New` request will derive; also the exclusive
    /// upper bound of the public-key scan fallback.
    /// NOTE(review): the struct is not `Clone` here, so sharing via `Arc` looks unused — confirm.
    next_index: Arc<std::sync::Mutex<u32>>,
    /// Derived keypairs keyed by x-only public key, with their index and used flag.
    key_cache:
        Arc<std::sync::RwLock<std::collections::HashMap<bitcoin::XOnlyPublicKey, KeyCacheValue>>>,
}
/// One cached keypair together with its derivation index and usage state.
#[derive(Clone, Copy)]
pub struct KeyCacheValue {
    /// Child index under the provider's base path.
    path_index: u32,
    /// The derived keypair.
    kp: Keypair,
    /// `false` for keys freshly issued via `KeypairIndex::New`; `true` once the key
    /// is marked as used, found by a public-key scan, or recorded via discovery.
    used: bool,
}
impl Bip32KeyProvider {
pub fn new(master_key: Xpriv, base_path: DerivationPath) -> Self {
Self {
master_key,
base_path,
next_index: Arc::new(std::sync::Mutex::new(0)),
key_cache: Arc::new(std::sync::RwLock::new(std::collections::HashMap::new())),
}
}
pub fn new_with_index(master_key: Xpriv, base_path: DerivationPath, start_index: u32) -> Self {
Self {
master_key,
base_path,
next_index: Arc::new(std::sync::Mutex::new(start_index)),
key_cache: Arc::new(std::sync::RwLock::new(std::collections::HashMap::new())),
}
}
fn derive_keypair(&self, path: &DerivationPath) -> Result<Keypair, Error> {
let secp = Secp256k1::new();
let derived_key = self
.master_key
.derive_priv(&secp, path)
.map_err(|e| Error::ad_hoc(format!("BIP32 derivation failed: {e}")))?;
Ok(derived_key.to_keypair(&secp))
}
fn derive_at_index(&self, index: u32) -> Result<Keypair, Error> {
use bitcoin::bip32::ChildNumber;
let path = self.base_path.clone();
let path = path.extend([ChildNumber::Normal { index }]);
self.derive_keypair(&path)
}
}
impl KeyProvider for Bip32KeyProvider {
fn get_next_keypair(&self, keypair_index: KeypairIndex) -> Result<Keypair, Error> {
match keypair_index {
KeypairIndex::New => {
let index = {
let mut next_index = self
.next_index
.lock()
.map_err(|e| Error::ad_hoc(format!("Failed to lock next_index: {e}")))?;
let current = *next_index;
*next_index = next_index
.checked_add(1)
.ok_or_else(|| Error::ad_hoc("Key derivation index overflow"))?;
current
};
let kp = self.derive_at_index(index)?;
let pk = kp.x_only_public_key().0;
{
let mut cache = self
.key_cache
.write()
.map_err(|e| Error::ad_hoc(format!("Failed to lock key_cache: {e}")))?;
cache.insert(
pk,
KeyCacheValue {
path_index: index,
kp,
used: false,
},
);
}
Ok(kp)
}
KeypairIndex::LastUnused => {
{
let cache = self
.key_cache
.read()
.map_err(|e| Error::ad_hoc(format!("Failed to lock key_cache: {e}")))?;
let unused = cache
.values()
.filter(|KeyCacheValue { used, .. }| !used)
.min_by_key(|KeyCacheValue { path_index, .. }| *path_index);
if let Some(KeyCacheValue { kp, .. }) = unused {
return Ok(*kp);
}
}
self.get_next_keypair(KeypairIndex::New)
}
}
}
fn get_keypair_for_path(&self, path: &[u32]) -> Result<Keypair, Error> {
use bitcoin::bip32::ChildNumber;
let child_numbers: Vec<ChildNumber> = path
.iter()
.map(|&n| {
if n & 0x8000_0000 != 0 {
ChildNumber::Hardened {
index: n & 0x7FFF_FFFF,
}
} else {
ChildNumber::Normal { index: n }
}
})
.collect();
let derivation_path = DerivationPath::from(child_numbers);
self.derive_keypair(&derivation_path)
}
fn get_keypair_for_pk(&self, pk: &bitcoin::XOnlyPublicKey) -> Result<Keypair, Error> {
{
let cache = self
.key_cache
.read()
.map_err(|e| Error::ad_hoc(format!("Failed to lock key_cache: {e}")))?;
if let Some(KeyCacheValue { kp, .. }) = cache.get(pk) {
return Ok(*kp);
}
}
let current_index = {
let next_index = self
.next_index
.lock()
.map_err(|e| Error::ad_hoc(format!("Failed to lock next_index: {e}")))?;
*next_index
};
for i in 0..current_index {
let kp = self.derive_at_index(i)?;
let derived_pk = kp.x_only_public_key().0;
if &derived_pk == pk {
let mut cache = self
.key_cache
.write()
.map_err(|e| Error::ad_hoc(format!("Failed to lock key_cache: {e}")))?;
cache.insert(
derived_pk,
KeyCacheValue {
path_index: i,
kp,
used: true,
},
);
return Ok(kp);
}
}
Err(Error::ad_hoc(format!(
"Public key {pk} not found in HD wallet. \
Searched indices 0..{current_index}. \
The key may have been generated outside this provider."
)))
}
fn get_cached_pks(&self) -> Result<Vec<bitcoin::XOnlyPublicKey>, Error> {
let cache = self
.key_cache
.read()
.map_err(|e| Error::ad_hoc(format!("Failed to lock key_cache: {e}")))?;
Ok(cache.keys().copied().collect())
}
fn supports_discovery(&self) -> bool {
true
}
fn get_derivation_index_for_pk(&self, pk: &bitcoin::XOnlyPublicKey) -> Option<u32> {
let cache = self.key_cache.read().ok()?;
cache.get(pk).map(|v| v.path_index)
}
fn derive_at_discovery_index(&self, index: u32) -> Result<Option<Keypair>, Error> {
self.derive_at_index(index).map(Some)
}
fn cache_discovered_keypair(&self, index: u32, kp: Keypair) -> Result<(), Error> {
let pk = kp.x_only_public_key().0;
{
let mut cache = self
.key_cache
.write()
.map_err(|e| Error::ad_hoc(format!("Failed to lock key_cache: {e}")))?;
cache.insert(
pk,
KeyCacheValue {
path_index: index,
kp,
used: true,
},
);
}
{
let mut next = self
.next_index
.lock()
.map_err(|e| Error::ad_hoc(format!("Failed to lock next_index: {e}")))?;
if index >= *next {
*next = index
.checked_add(1)
.ok_or_else(|| Error::ad_hoc("Key derivation index overflow"))?;
}
}
Ok(())
}
fn mark_as_used(&self, pk: &bitcoin::XOnlyPublicKey) -> Result<(), Error> {
{
let maybe_kp = {
let cache = self
.key_cache
.read()
.map_err(|e| Error::ad_hoc(format!("Failed to lock key_cache: {e}")))?;
cache.get(pk).copied()
};
match maybe_kp {
Some(KeyCacheValue {
path_index,
kp,
used: false,
}) => {
let mut cache = self
.key_cache
.write()
.map_err(|e| Error::ad_hoc(format!("Failed to lock key_cache: {e}")))?;
cache.insert(
*pk,
KeyCacheValue {
path_index,
kp,
used: true,
},
);
return Ok(());
}
Some(KeyCacheValue { used: true, .. }) => {
return Ok(());
}
_ => {
}
}
}
let current_index = {
let next_index = self
.next_index
.lock()
.map_err(|e| Error::ad_hoc(format!("Failed to lock next_index: {e}")))?;
*next_index
};
for i in 0..current_index {
let kp = self.derive_at_index(i)?;
let derived_pk = kp.x_only_public_key().0;
if &derived_pk == pk {
let mut cache = self
.key_cache
.write()
.map_err(|e| Error::ad_hoc(format!("Failed to lock key_cache: {e}")))?;
cache.insert(
derived_pk,
KeyCacheValue {
path_index: i,
kp,
used: true,
},
);
return Ok(());
}
}
Err(Error::ad_hoc(format!(
"Public key {pk} not found in HD wallet. \
Searched indices 0..{current_index}. \
The key may have been generated outside this provider."
)))
}
}
/// Forwards every [`KeyProvider`] method to the shared inner provider.
///
/// `T: ?Sized` lets trait objects such as `Arc<dyn KeyProvider>` be used wherever a
/// `KeyProvider` is expected, not just `Arc`s of concrete provider types.
impl<T: KeyProvider + ?Sized> KeyProvider for Arc<T> {
    fn get_next_keypair(&self, keypair_index: KeypairIndex) -> Result<Keypair, Error> {
        (**self).get_next_keypair(keypair_index)
    }
    fn get_keypair_for_path(&self, path: &[u32]) -> Result<Keypair, Error> {
        (**self).get_keypair_for_path(path)
    }
    fn get_keypair_for_pk(&self, pk: &bitcoin::XOnlyPublicKey) -> Result<Keypair, Error> {
        (**self).get_keypair_for_pk(pk)
    }
    fn get_cached_pks(&self) -> Result<Vec<bitcoin::XOnlyPublicKey>, Error> {
        (**self).get_cached_pks()
    }
    fn supports_discovery(&self) -> bool {
        (**self).supports_discovery()
    }
    fn get_derivation_index_for_pk(&self, pk: &bitcoin::XOnlyPublicKey) -> Option<u32> {
        (**self).get_derivation_index_for_pk(pk)
    }
    fn derive_at_discovery_index(&self, index: u32) -> Result<Option<Keypair>, Error> {
        (**self).derive_at_discovery_index(index)
    }
    fn cache_discovered_keypair(&self, index: u32, kp: Keypair) -> Result<(), Error> {
        (**self).cache_discovered_keypair(index, kp)
    }
    fn mark_as_used(&self, pk: &bitcoin::XOnlyPublicKey) -> Result<(), Error> {
        (**self).mark_as_used(pk)
    }
}