use std::sync::Arc;
use bsv::primitives::private_key::PrivateKey;
use bsv::wallet::cached_key_deriver::CachedKeyDeriver;
use bsv::wallet::interfaces::{CreateActionArgs, CreateActionOutput, CreateActionResult};
use bsv::wallet::types::{Counterparty, CounterpartyType, Protocol};
use crate::error::{WalletError, WalletResult};
use crate::monitor::Monitor;
use crate::services::traits::WalletServices;
use crate::storage::manager::WalletStorageManager;
use crate::storage::StorageConfig;
use crate::types::Chain;
use crate::utility::script_template_brc29::ScriptTemplateBRC29;
use crate::wallet::privileged::PrivilegedKeyManager;
use crate::wallet::types::{KeyPair, WalletArgs};
use crate::wallet::wallet::Wallet;
/// Everything produced by `WalletBuilder::build`: the wallet itself plus the
/// components it was assembled from, exposed so callers can reuse them.
pub struct SetupWallet {
    /// The constructed wallet.
    pub wallet: Wallet,
    /// Chain the wallet was built for.
    pub chain: Chain,
    /// Key deriver created from the builder's root key; the same `Arc` is
    /// handed to the wallet.
    pub key_deriver: Arc<CachedKeyDeriver>,
    /// The key deriver's identity public key, rendered with `to_der_hex`.
    pub identity_key: String,
    /// Storage manager wrapping the configured storage provider.
    pub storage: Arc<WalletStorageManager>,
    /// Services given to the wallet, if any were configured.
    pub services: Option<Arc<dyn WalletServices>>,
    /// Monitor, present only when it was requested and services were
    /// available. It is returned here but not passed into `wallet`.
    pub monitor: Option<Arc<Monitor>>,
}
/// Storage backend selected on a `WalletBuilder`, together with its location.
enum StorageKind {
    /// SQLite database path, or `":memory:"` for an in-memory database.
    Sqlite(String),
    // NOTE(review): `with_mysql` / `with_postgres` construct these variants and
    // `build` reads them, so these `dead_code` allowances may be stale — confirm
    // before removing.
    /// MySQL connection URL.
    #[allow(dead_code)]
    Mysql(String),
    /// PostgreSQL connection URL.
    #[allow(dead_code)]
    Postgres(String),
}
/// Builder that assembles a `Wallet` together with its storage, services and
/// optional monitor; finish with `WalletBuilder::build`.
pub struct WalletBuilder {
    /// Required: chain to build for.
    chain: Option<Chain>,
    /// Required: root key from which the key deriver is created.
    root_key: Option<PrivateKey>,
    /// Required: selected storage backend (last `with_*` storage call wins).
    storage_config: Option<StorageKind>,
    /// Optional override for the storage provider's identity key (see `build`
    /// for which backends apply it).
    storage_identity_key: Option<String>,
    /// Explicit services; take precedence over `use_default_services`.
    services: Option<Arc<dyn WalletServices>>,
    /// When true and `services` is unset, `build` creates default services for
    /// the chain.
    use_default_services: bool,
    /// When true, `build` creates a monitor if services are available.
    monitor_enabled: bool,
    /// Passed through unchanged to `WalletArgs`.
    privileged_key_manager: Option<Arc<dyn PrivilegedKeyManager>>,
    // Connection-pool overrides for `StorageConfig`; `None` keeps the
    // `StorageConfig::default()` value.
    pool_max_connections: Option<u32>,
    pool_min_connections: Option<u32>,
    pool_idle_timeout: Option<std::time::Duration>,
    pool_connect_timeout: Option<std::time::Duration>,
}
impl WalletBuilder {
pub fn new() -> Self {
Self {
chain: None,
root_key: None,
storage_config: None,
storage_identity_key: None,
services: None,
use_default_services: false,
monitor_enabled: false,
privileged_key_manager: None,
pool_max_connections: None,
pool_min_connections: None,
pool_idle_timeout: None,
pool_connect_timeout: None,
}
}
pub fn chain(mut self, chain: Chain) -> Self {
self.chain = Some(chain);
self
}
pub fn root_key(mut self, key: PrivateKey) -> Self {
self.root_key = Some(key);
self
}
pub fn with_sqlite(mut self, path: &str) -> Self {
self.storage_config = Some(StorageKind::Sqlite(path.to_string()));
self
}
pub fn with_sqlite_memory(mut self) -> Self {
self.storage_config = Some(StorageKind::Sqlite(":memory:".to_string()));
self
}
pub fn with_mysql(mut self, url: &str) -> Self {
self.storage_config = Some(StorageKind::Mysql(url.to_string()));
self
}
pub fn with_postgres(mut self, url: &str) -> Self {
self.storage_config = Some(StorageKind::Postgres(url.to_string()));
self
}
pub fn with_default_services(mut self) -> Self {
self.use_default_services = true;
self
}
pub fn with_services(mut self, services: Arc<dyn WalletServices>) -> Self {
self.services = Some(services);
self
}
pub fn with_monitor(mut self) -> Self {
self.monitor_enabled = true;
self
}
pub fn with_storage_identity_key(mut self, key: String) -> Self {
self.storage_identity_key = Some(key);
self
}
pub fn with_privileged_key_manager(mut self, pkm: Arc<dyn PrivilegedKeyManager>) -> Self {
self.privileged_key_manager = Some(pkm);
self
}
pub fn with_max_connections(mut self, max: u32) -> Self {
self.pool_max_connections = Some(max);
self
}
pub fn with_min_connections(mut self, min: u32) -> Self {
self.pool_min_connections = Some(min);
self
}
pub fn with_pool_idle_timeout(mut self, timeout: std::time::Duration) -> Self {
self.pool_idle_timeout = Some(timeout);
self
}
pub fn with_pool_connect_timeout(mut self, timeout: std::time::Duration) -> Self {
self.pool_connect_timeout = Some(timeout);
self
}
pub async fn build(self) -> WalletResult<SetupWallet> {
let chain = self
.chain
.ok_or_else(|| WalletError::MissingParameter("chain".to_string()))?;
let root_key = self
.root_key
.ok_or_else(|| WalletError::MissingParameter("root_key".to_string()))?;
let storage_kind = self.storage_config.ok_or_else(|| {
WalletError::MissingParameter(
"storage (call with_sqlite, with_sqlite_memory, with_mysql, or with_postgres)"
.to_string(),
)
})?;
let key_deriver = Arc::new(CachedKeyDeriver::new(root_key, None));
let identity_key_hex = key_deriver.identity_key().to_der_hex();
let pool_max = self.pool_max_connections;
let pool_min = self.pool_min_connections;
let pool_idle = self.pool_idle_timeout;
let pool_connect = self.pool_connect_timeout;
let apply_pool_overrides = |config: &mut StorageConfig| {
if let Some(max) = pool_max {
config.max_connections = max;
}
if let Some(min) = pool_min {
config.min_connections = min;
}
if let Some(timeout) = pool_idle {
config.idle_timeout = timeout;
}
if let Some(timeout) = pool_connect {
config.connect_timeout = timeout;
}
};
use crate::storage::traits::wallet_provider::WalletStorageProvider;
let provider: Arc<dyn WalletStorageProvider> = match storage_kind {
StorageKind::Sqlite(path) => {
let url = if path == ":memory:" {
"sqlite::memory:".to_string()
} else {
format!("sqlite:{}", path)
};
let mut config = StorageConfig {
url,
..StorageConfig::default()
};
apply_pool_overrides(&mut config);
#[cfg(feature = "sqlite")]
{
let storage =
crate::storage::sqlx_impl::SqliteStorage::new_sqlite(config, chain.clone())
.await?;
Arc::new(storage) as Arc<dyn WalletStorageProvider>
}
#[cfg(not(feature = "sqlite"))]
{
let _ = config;
return Err(WalletError::InvalidOperation(
"SQLite feature not enabled. Add `sqlite` feature to Cargo.toml."
.to_string(),
));
}
}
StorageKind::Mysql(url) => {
let mut config = StorageConfig {
url,
..StorageConfig::default()
};
apply_pool_overrides(&mut config);
#[cfg(feature = "mysql")]
{
let mut storage =
crate::storage::sqlx_impl::MysqlStorage::new_mysql(config, chain.clone())
.await?;
if let Some(ref sik) = self.storage_identity_key {
storage.storage_identity_key = sik.clone();
}
Arc::new(storage) as Arc<dyn WalletStorageProvider>
}
#[cfg(not(feature = "mysql"))]
{
let _ = config;
return Err(WalletError::InvalidOperation(
"MySQL feature not enabled. Add `mysql` feature to Cargo.toml.".to_string(),
));
}
}
StorageKind::Postgres(url) => {
let mut config = StorageConfig {
url,
..StorageConfig::default()
};
apply_pool_overrides(&mut config);
#[cfg(feature = "postgres")]
{
let storage =
crate::storage::sqlx_impl::PgStorage::new_postgres(config, chain.clone())
.await?;
Arc::new(storage) as Arc<dyn WalletStorageProvider>
}
#[cfg(not(feature = "postgres"))]
{
let _ = config;
return Err(WalletError::InvalidOperation(
"PostgreSQL feature not enabled. Add `postgres` feature to Cargo.toml."
.to_string(),
));
}
}
};
provider.migrate("setup", "").await?;
let storage = Arc::new(WalletStorageManager::new(
identity_key_hex.clone(),
Some(provider.clone()),
vec![],
));
storage.make_available().await?;
let services: Option<Arc<dyn WalletServices>> = if let Some(svc) = self.services {
Some(svc)
} else if self.use_default_services {
Some(Arc::new(crate::services::services::Services::from_chain(
chain.clone(),
)))
} else {
None
};
let wallet_args = WalletArgs {
chain: chain.clone(),
key_deriver: key_deriver.clone(),
storage: storage.clone(),
services: services.clone(),
monitor: None, privileged_key_manager: self.privileged_key_manager,
settings_manager: None,
lookup_resolver: None,
};
let wallet = Wallet::new(wallet_args)?;
let monitor = if self.monitor_enabled {
if let Some(ref svc) = services {
let monitor = crate::monitor::Monitor::builder()
.chain(chain.clone())
.storage(storage.clone())
.services(svc.clone())
.default_tasks()
.build()?;
Some(Arc::new(monitor))
} else {
None
}
} else {
None
};
Ok(SetupWallet {
wallet,
chain,
key_deriver,
identity_key: identity_key_hex,
storage,
services,
monitor,
})
}
}
impl Default for WalletBuilder {
fn default() -> Self {
Self::new()
}
}
/// Derives the key pair for `protocol_id` / `key_id` / `counterparty` and
/// returns both halves as strings.
///
/// Input formats:
/// - `protocol_id`: `"<security_level>.<name>"`, or a bare name, which uses
///   security level 2.
/// - `counterparty`: `"self"`, `"anyone"`, or a public key string.
///
/// Output encodings: the private key is rendered with `to_hex`, and the public
/// key with `to_der_hex`.
///
/// # Errors
/// - An `InvalidParameter` error if either identifier fails to parse.
/// - `WalletError::Internal` if key derivation fails.
pub fn get_key_pair(
    key_deriver: &CachedKeyDeriver,
    protocol_id: &str,
    key_id: &str,
    counterparty: &str,
) -> WalletResult<KeyPair> {
    let protocol = parse_protocol(protocol_id)?;
    let other_party = parse_counterparty(counterparty)?;
    let derived = match key_deriver.derive_private_key(&protocol, key_id, &other_party) {
        Ok(key) => key,
        Err(e) => return Err(WalletError::Internal(format!("Key derivation failed: {}", e))),
    };
    Ok(KeyPair {
        public_key: derived.to_public_key().to_der_hex(),
        private_key: derived.to_hex(),
    })
}
/// Builds a serialized P2PKH locking script for a key derived from
/// `key_deriver`.
///
/// How the script is built:
/// 1. The public key for `protocol_id` / `key_id` / `counterparty` is derived
///    with a trailing `false` flag. This presumably mirrors `forSelf = false`
///    in the TypeScript SDK — confirm.
/// 2. The key is hashed with `to_hash`, presumably HASH160.
/// 3. The hash is wrapped in the standard P2PKH template.
///
/// The input formats are the same as for `get_key_pair`.
///
/// # Errors
/// - An `InvalidParameter` error if either identifier fails to parse.
/// - `WalletError::Internal` if derivation fails, if the key hash is not
///   20 bytes, or if building the script fails.
pub fn get_lock_p2pkh(
    key_deriver: &CachedKeyDeriver,
    protocol_id: &str,
    key_id: &str,
    counterparty: &str,
) -> WalletResult<Vec<u8>> {
    use bsv::script::templates::p2pkh::P2PKH;
    use bsv::script::templates::ScriptTemplateLock;

    let protocol = parse_protocol(protocol_id)?;
    let cp = parse_counterparty(counterparty)?;
    let derived_pub = key_deriver
        .derive_public_key(&protocol, key_id, &cp, false)
        .map_err(|e| WalletError::Internal(format!("Public key derivation failed: {}", e)))?;

    let hash_vec = derived_pub.to_hash();
    // P2PKH commits to exactly 20 bytes. Report a length mismatch as an error
    // instead of letting `copy_from_slice` panic inside a library call.
    if hash_vec.len() != 20 {
        return Err(WalletError::Internal(format!(
            "Expected a 20-byte public key hash, got {} bytes",
            hash_vec.len()
        )));
    }
    let mut hash = [0u8; 20];
    hash.copy_from_slice(&hash_vec);

    let locking_script = P2PKH::from_public_key_hash(hash)
        .lock()
        .map_err(|e| WalletError::Internal(format!("P2PKH lock failed: {}", e)))?;
    Ok(locking_script.to_binary())
}
/// Builds `count` outputs of `satoshis` each, all locked with BRC-29 templates
/// to the key deriver's own identity key.
///
/// Per-output details:
/// - Each output gets a fresh random derivation prefix and suffix from
///   `random_hex_string`.
/// - Each output is described as `"p2pkh <index>"`.
/// - Basket, custom instructions, and tags are left empty, so the derivation
///   values are not recorded on the outputs.
///
/// # Errors
/// Returns the first error produced by `ScriptTemplateBRC29::lock`.
pub fn create_p2pkh_outputs(
    key_deriver: &CachedKeyDeriver,
    count: usize,
    satoshis: u64,
) -> WalletResult<Vec<CreateActionOutput>> {
    let root_key = key_deriver.root_key();
    let identity_pub = key_deriver.identity_key();
    (0..count)
        .map(|index| -> WalletResult<CreateActionOutput> {
            let prefix = random_hex_string();
            let suffix = random_hex_string();
            let template = ScriptTemplateBRC29::new(prefix, suffix);
            let script = template.lock(root_key, &identity_pub)?;
            Ok(CreateActionOutput {
                locking_script: Some(script),
                satoshis,
                output_description: format!("p2pkh {}", index),
                basket: None,
                custom_instructions: None,
                tags: vec![],
            })
        })
        .collect()
}
/// Creates an action whose outputs are `count` BRC-29 outputs of `satoshis`
/// each (built by `create_p2pkh_outputs` from the wallet's key deriver).
///
/// The action has no explicit inputs, labels, options, lock time, version,
/// input BEEF, or reference; `description` is used as the action description.
///
/// # Errors
/// - Propagates output-construction errors.
/// - Maps any `create_action` failure to `WalletError::Internal`.
pub async fn create_p2pkh_outputs_action(
    wallet: &Wallet,
    count: usize,
    satoshis: u64,
    description: &str,
) -> WalletResult<CreateActionResult> {
    use bsv::wallet::interfaces::WalletInterface;

    let outputs = create_p2pkh_outputs(&wallet.key_deriver, count, satoshis)?;
    let args = CreateActionArgs {
        description: description.to_string(),
        inputs: vec![],
        outputs,
        lock_time: None,
        version: None,
        labels: vec![],
        options: None,
        input_beef: None,
        reference: None,
    };
    wallet
        .create_action(args, None)
        .await
        .map_err(|e| WalletError::Internal(format!("create_action failed: {}", e)))
}
/// Parses a protocol identifier string into a [`Protocol`].
///
/// Accepted forms:
/// - `"<security_level>.<protocol_name>"`, e.g. `"2.3241645161d8"`. Only the
///   first `.` separates the parts, and the level must be 0, 1, or 2.
/// - `"<protocol_name>"` with no `.`, which defaults to security level 2.
///
/// # Errors
/// `WalletError::InvalidParameter` for `protocol_id` when the text before the
/// first `.` is not an integer, or is an integer outside `0..=2`.
fn parse_protocol(protocol_id: &str) -> WalletResult<Protocol> {
    // BRC-43 defines exactly three security levels: 0, 1 and 2. Larger values
    // used to be accepted silently.
    const MAX_SECURITY_LEVEL: u8 = 2;

    let (level_str, name) = match protocol_id.split_once('.') {
        Some(parts) => parts,
        // No level prefix: treat the whole string as the name at level 2.
        None => {
            return Ok(Protocol {
                security_level: 2,
                protocol: protocol_id.to_string(),
            })
        }
    };
    let security_level: u8 = level_str
        .parse()
        .map_err(|_| WalletError::InvalidParameter {
            parameter: "protocol_id".to_string(),
            must_be: "in format 'security_level.protocol_name' (e.g., '2.3241645161d8')"
                .to_string(),
        })?;
    if security_level > MAX_SECURITY_LEVEL {
        return Err(WalletError::InvalidParameter {
            parameter: "protocol_id".to_string(),
            must_be: "prefixed with a security level of 0, 1, or 2 (e.g., '2.3241645161d8')"
                .to_string(),
        });
    }
    Ok(Protocol {
        security_level,
        protocol: name.to_string(),
    })
}
/// Parses a counterparty string into a [`Counterparty`].
///
/// Mapping:
/// - `"self"` → `CounterpartyType::Self_`, with no public key.
/// - `"anyone"` → `CounterpartyType::Anyone`, with no public key.
/// - Any other string is parsed with `PublicKey::from_string` and becomes
///   `CounterpartyType::Other` carrying that key.
///
/// # Errors
/// `WalletError::InvalidParameter` for `counterparty` when the string is
/// neither keyword and cannot be parsed as a public key.
fn parse_counterparty(counterparty: &str) -> WalletResult<Counterparty> {
    let keyword = match counterparty {
        "self" => Some(CounterpartyType::Self_),
        "anyone" => Some(CounterpartyType::Anyone),
        _ => None,
    };
    if let Some(counterparty_type) = keyword {
        return Ok(Counterparty {
            counterparty_type,
            public_key: None,
        });
    }
    match bsv::primitives::public_key::PublicKey::from_string(counterparty) {
        Ok(pk) => Ok(Counterparty {
            counterparty_type: CounterpartyType::Other,
            public_key: Some(pk),
        }),
        Err(e) => Err(WalletError::InvalidParameter {
            parameter: "counterparty".to_string(),
            must_be: format!("'self', 'anyone', or a valid public key hex: {}", e),
        }),
    }
}
/// Returns 8 random bytes from `rand::thread_rng`, encoded with the standard
/// padded base64 alphabet, so the result is always 12 characters long.
///
/// Despite the name, the output is base64, not hex. It serves as a BRC-29
/// derivation prefix/suffix in `create_p2pkh_outputs`.
fn random_hex_string() -> String {
    use base64::Engine as _;
    use rand::RngCore as _;

    let mut entropy = [0u8; 8];
    let mut rng = rand::thread_rng();
    rng.fill_bytes(&mut entropy);
    base64::engine::general_purpose::STANDARD.encode(entropy)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Root key shared by the tests; it only needs to be a valid private key,
    /// not a secure one.
    fn fixed_root_key() -> PrivateKey {
        PrivateKey::from_hex("aa").unwrap()
    }

    /// Key deriver built from `fixed_root_key`.
    fn fixed_key_deriver() -> CachedKeyDeriver {
        CachedKeyDeriver::new(fixed_root_key(), None)
    }

    /// Runs `builder.build()` on a fresh current-thread runtime and asserts
    /// that it fails with an error whose message mentions `missing`.
    fn assert_build_rejects(builder: WalletBuilder, missing: &str) {
        let runtime = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap();
        // `SetupWallet` is not `Debug`, so `expect_err` is unavailable here.
        match runtime.block_on(builder.build()) {
            Ok(_) => panic!("Expected error for missing {}", missing),
            Err(e) => {
                let err = e.to_string();
                assert!(
                    err.contains(missing),
                    "Expected {} error, got: {}",
                    missing,
                    err
                );
            }
        }
    }

    #[test]
    fn test_parse_protocol_with_level() {
        let parsed = parse_protocol("2.3241645161d8").unwrap();
        assert_eq!(parsed.security_level, 2);
        assert_eq!(parsed.protocol, "3241645161d8");
    }

    #[test]
    fn test_parse_protocol_without_level() {
        let parsed = parse_protocol("3241645161d8").unwrap();
        assert_eq!(parsed.security_level, 2);
        assert_eq!(parsed.protocol, "3241645161d8");
    }

    #[test]
    fn test_parse_counterparty_self() {
        let parsed = parse_counterparty("self").unwrap();
        assert_eq!(parsed.counterparty_type, CounterpartyType::Self_);
        assert!(parsed.public_key.is_none());
    }

    #[test]
    fn test_parse_counterparty_anyone() {
        let parsed = parse_counterparty("anyone").unwrap();
        assert_eq!(parsed.counterparty_type, CounterpartyType::Anyone);
        assert!(parsed.public_key.is_none());
    }

    #[test]
    fn test_wallet_builder_validates_chain() {
        assert_build_rejects(WalletBuilder::new(), "chain");
    }

    #[test]
    fn test_wallet_builder_validates_root_key() {
        assert_build_rejects(WalletBuilder::new().chain(Chain::Test), "root_key");
    }

    #[test]
    fn test_wallet_builder_validates_storage() {
        let builder = WalletBuilder::new()
            .chain(Chain::Test)
            .root_key(fixed_root_key());
        assert_build_rejects(builder, "storage");
    }

    #[test]
    fn test_get_key_pair_self() {
        let deriver = fixed_key_deriver();
        let pair = get_key_pair(&deriver, "2.3241645161d8", "test_key", "self").unwrap();
        assert!(!pair.private_key.is_empty());
        assert!(!pair.public_key.is_empty());
        assert_eq!(pair.public_key.len(), 66);
    }

    #[test]
    fn test_get_lock_p2pkh_produces_25_byte_script() {
        let deriver = fixed_key_deriver();
        let script = get_lock_p2pkh(&deriver, "2.3241645161d8", "test_key", "self").unwrap();
        assert_eq!(script.len(), 25);
    }

    #[test]
    fn test_create_p2pkh_outputs_count() {
        let deriver = fixed_key_deriver();
        let outputs = create_p2pkh_outputs(&deriver, 3, 1000).unwrap();
        assert_eq!(outputs.len(), 3);
        for (index, output) in outputs.iter().enumerate() {
            assert_eq!(output.satoshis, 1000);
            assert!(output.locking_script.is_some());
            assert_eq!(output.output_description, format!("p2pkh {}", index));
        }
    }

    #[test]
    fn test_random_hex_string_length() {
        use base64::Engine;
        let encoded = random_hex_string();
        assert_eq!(encoded.len(), 12);
        assert!(base64::engine::general_purpose::STANDARD
            .decode(&encoded)
            .is_ok());
    }
}