1use std::sync::Arc;
10
11use bsv::primitives::private_key::PrivateKey;
12use bsv::wallet::cached_key_deriver::CachedKeyDeriver;
13use bsv::wallet::interfaces::{CreateActionArgs, CreateActionOutput, CreateActionResult};
14use bsv::wallet::types::{Counterparty, CounterpartyType, Protocol};
15
16use crate::error::{WalletError, WalletResult};
17use crate::monitor::Monitor;
18use crate::services::traits::WalletServices;
19use crate::storage::manager::WalletStorageManager;
20use crate::storage::StorageConfig;
21use crate::types::Chain;
22use crate::utility::script_template_brc29::ScriptTemplateBRC29;
23use crate::wallet::privileged::PrivilegedKeyManager;
24use crate::wallet::types::{KeyPair, WalletArgs};
25use crate::wallet::wallet::Wallet;
26
27pub struct SetupWallet {
37 pub wallet: Wallet,
39 pub chain: Chain,
41 pub key_deriver: Arc<CachedKeyDeriver>,
43 pub identity_key: String,
45 pub storage: WalletStorageManager,
47 pub services: Option<Arc<dyn WalletServices>>,
49 pub monitor: Option<Arc<Monitor>>,
51}
52
/// Storage backend selected on the builder, carrying the path or URL given to it.
enum StorageKind {
    /// SQLite database path; the literal `":memory:"` selects an in-memory database.
    Sqlite(String),
    /// MySQL connection URL.
    // NOTE(review): this variant is constructed by `WalletBuilder::with_mysql`,
    // so the lint suppression may no longer be required — confirm before removing.
    #[allow(dead_code)]
    Mysql(String),
    /// PostgreSQL connection URL.
    // NOTE(review): constructed by `WalletBuilder::with_postgres`; same remark as above.
    #[allow(dead_code)]
    Postgres(String),
}
68
69pub struct WalletBuilder {
101 chain: Option<Chain>,
102 root_key: Option<PrivateKey>,
103 storage_config: Option<StorageKind>,
104 storage_identity_key: Option<String>,
105 services: Option<Arc<dyn WalletServices>>,
106 use_default_services: bool,
107 monitor_enabled: bool,
108 privileged_key_manager: Option<Arc<dyn PrivilegedKeyManager>>,
109 pool_max_connections: Option<u32>,
110 pool_min_connections: Option<u32>,
111 pool_idle_timeout: Option<std::time::Duration>,
112 pool_connect_timeout: Option<std::time::Duration>,
113}
114
115impl WalletBuilder {
116 pub fn new() -> Self {
118 Self {
119 chain: None,
120 root_key: None,
121 storage_config: None,
122 storage_identity_key: None,
123 services: None,
124 use_default_services: false,
125 monitor_enabled: false,
126 privileged_key_manager: None,
127 pool_max_connections: None,
128 pool_min_connections: None,
129 pool_idle_timeout: None,
130 pool_connect_timeout: None,
131 }
132 }
133
134 pub fn chain(mut self, chain: Chain) -> Self {
136 self.chain = Some(chain);
137 self
138 }
139
140 pub fn root_key(mut self, key: PrivateKey) -> Self {
142 self.root_key = Some(key);
143 self
144 }
145
146 pub fn with_sqlite(mut self, path: &str) -> Self {
148 self.storage_config = Some(StorageKind::Sqlite(path.to_string()));
149 self
150 }
151
152 pub fn with_sqlite_memory(mut self) -> Self {
154 self.storage_config = Some(StorageKind::Sqlite(":memory:".to_string()));
155 self
156 }
157
158 pub fn with_mysql(mut self, url: &str) -> Self {
160 self.storage_config = Some(StorageKind::Mysql(url.to_string()));
161 self
162 }
163
164 pub fn with_postgres(mut self, url: &str) -> Self {
166 self.storage_config = Some(StorageKind::Postgres(url.to_string()));
167 self
168 }
169
170 pub fn with_default_services(mut self) -> Self {
175 self.use_default_services = true;
176 self
177 }
178
179 pub fn with_services(mut self, services: Arc<dyn WalletServices>) -> Self {
181 self.services = Some(services);
182 self
183 }
184
185 pub fn with_monitor(mut self) -> Self {
187 self.monitor_enabled = true;
188 self
189 }
190
191 pub fn with_storage_identity_key(mut self, key: String) -> Self {
196 self.storage_identity_key = Some(key);
197 self
198 }
199
200 pub fn with_privileged_key_manager(mut self, pkm: Arc<dyn PrivilegedKeyManager>) -> Self {
202 self.privileged_key_manager = Some(pkm);
203 self
204 }
205
206 pub fn with_max_connections(mut self, max: u32) -> Self {
211 self.pool_max_connections = Some(max);
212 self
213 }
214
215 pub fn with_min_connections(mut self, min: u32) -> Self {
219 self.pool_min_connections = Some(min);
220 self
221 }
222
223 pub fn with_pool_idle_timeout(mut self, timeout: std::time::Duration) -> Self {
227 self.pool_idle_timeout = Some(timeout);
228 self
229 }
230
231 pub fn with_pool_connect_timeout(mut self, timeout: std::time::Duration) -> Self {
235 self.pool_connect_timeout = Some(timeout);
236 self
237 }
238
239 pub async fn build(self) -> WalletResult<SetupWallet> {
246 let chain = self
248 .chain
249 .ok_or_else(|| WalletError::MissingParameter("chain".to_string()))?;
250 let root_key = self
251 .root_key
252 .ok_or_else(|| WalletError::MissingParameter("root_key".to_string()))?;
253 let storage_kind = self.storage_config.ok_or_else(|| {
254 WalletError::MissingParameter(
255 "storage (call with_sqlite, with_sqlite_memory, with_mysql, or with_postgres)"
256 .to_string(),
257 )
258 })?;
259
260 let key_deriver = Arc::new(CachedKeyDeriver::new(root_key, None));
262 let identity_key_hex = key_deriver.identity_key().to_der_hex();
263
264 let pool_max = self.pool_max_connections;
266 let pool_min = self.pool_min_connections;
267 let pool_idle = self.pool_idle_timeout;
268 let pool_connect = self.pool_connect_timeout;
269 let apply_pool_overrides = |config: &mut StorageConfig| {
270 if let Some(max) = pool_max {
271 config.max_connections = max;
272 }
273 if let Some(min) = pool_min {
274 config.min_connections = min;
275 }
276 if let Some(timeout) = pool_idle {
277 config.idle_timeout = timeout;
278 }
279 if let Some(timeout) = pool_connect {
280 config.connect_timeout = timeout;
281 }
282 };
283
284 use crate::storage::traits::wallet_provider::WalletStorageProvider;
289 let provider: Arc<dyn WalletStorageProvider> = match storage_kind {
290 StorageKind::Sqlite(path) => {
291 let url = if path == ":memory:" {
292 "sqlite::memory:".to_string()
293 } else {
294 format!("sqlite:{}", path)
295 };
296 let mut config = StorageConfig {
297 url,
298 ..StorageConfig::default()
299 };
300 apply_pool_overrides(&mut config);
301 #[cfg(feature = "sqlite")]
302 {
303 let storage =
304 crate::storage::sqlx_impl::SqliteStorage::new_sqlite(config, chain.clone())
305 .await?;
306 Arc::new(storage) as Arc<dyn WalletStorageProvider>
307 }
308 #[cfg(not(feature = "sqlite"))]
309 {
310 let _ = config;
311 return Err(WalletError::InvalidOperation(
312 "SQLite feature not enabled. Add `sqlite` feature to Cargo.toml."
313 .to_string(),
314 ));
315 }
316 }
317 StorageKind::Mysql(url) => {
318 let mut config = StorageConfig {
319 url,
320 ..StorageConfig::default()
321 };
322 apply_pool_overrides(&mut config);
323 #[cfg(feature = "mysql")]
324 {
325 let mut storage =
326 crate::storage::sqlx_impl::MysqlStorage::new_mysql(config, chain.clone())
327 .await?;
328 if let Some(ref sik) = self.storage_identity_key {
329 storage.storage_identity_key = sik.clone();
330 }
331 Arc::new(storage) as Arc<dyn WalletStorageProvider>
332 }
333 #[cfg(not(feature = "mysql"))]
334 {
335 let _ = config;
336 return Err(WalletError::InvalidOperation(
337 "MySQL feature not enabled. Add `mysql` feature to Cargo.toml.".to_string(),
338 ));
339 }
340 }
341 StorageKind::Postgres(url) => {
342 let mut config = StorageConfig {
343 url,
344 ..StorageConfig::default()
345 };
346 apply_pool_overrides(&mut config);
347 #[cfg(feature = "postgres")]
348 {
349 let storage =
350 crate::storage::sqlx_impl::PgStorage::new_postgres(config, chain.clone())
351 .await?;
352 Arc::new(storage) as Arc<dyn WalletStorageProvider>
353 }
354 #[cfg(not(feature = "postgres"))]
355 {
356 let _ = config;
357 return Err(WalletError::InvalidOperation(
358 "PostgreSQL feature not enabled. Add `postgres` feature to Cargo.toml."
359 .to_string(),
360 ));
361 }
362 }
363 };
364
365 provider.migrate("setup", "").await?;
367
368 let make_manager = |key: String, p: Arc<dyn WalletStorageProvider>| {
372 WalletStorageManager::new(key, Some(p), vec![])
373 };
374
375 let storage = make_manager(identity_key_hex.clone(), provider.clone());
377 storage.make_available().await?;
378
379 let services: Option<Arc<dyn WalletServices>> = if let Some(svc) = self.services {
381 Some(svc)
382 } else if self.use_default_services {
383 Some(Arc::new(crate::services::services::Services::from_chain(
384 chain.clone(),
385 )))
386 } else {
387 None
388 };
389
390 let wallet_storage = make_manager(identity_key_hex.clone(), provider.clone());
392 wallet_storage.make_available().await?;
393
394 let wallet_args = WalletArgs {
395 chain: chain.clone(),
396 key_deriver: key_deriver.clone(),
397 storage: wallet_storage,
398 services: services.clone(),
399 monitor: None, privileged_key_manager: self.privileged_key_manager,
401 settings_manager: None,
402 lookup_resolver: None,
403 };
404
405 let wallet = Wallet::new(wallet_args)?;
407
408 let monitor = if self.monitor_enabled {
410 if let Some(ref svc) = services {
411 let monitor_storage = make_manager(identity_key_hex.clone(), provider.clone());
412 monitor_storage.make_available().await?;
413 let monitor = crate::monitor::Monitor::builder()
414 .chain(chain.clone())
415 .storage(monitor_storage)
416 .services(svc.clone())
417 .default_tasks()
418 .build()?;
419 Some(Arc::new(monitor))
420 } else {
421 None
423 }
424 } else {
425 None
426 };
427
428 Ok(SetupWallet {
429 wallet,
430 chain,
431 key_deriver,
432 identity_key: identity_key_hex,
433 storage,
434 services,
435 monitor,
436 })
437 }
438}
439
440impl Default for WalletBuilder {
441 fn default() -> Self {
442 Self::new()
443 }
444}
445
446pub fn get_key_pair(
454 key_deriver: &CachedKeyDeriver,
455 protocol_id: &str,
456 key_id: &str,
457 counterparty: &str,
458) -> WalletResult<KeyPair> {
459 let protocol = parse_protocol(protocol_id)?;
460 let cp = parse_counterparty(counterparty)?;
461
462 let private_key = key_deriver
463 .derive_private_key(&protocol, key_id, &cp)
464 .map_err(|e| WalletError::Internal(format!("Key derivation failed: {}", e)))?;
465
466 let public_key = private_key.to_public_key();
467
468 Ok(KeyPair {
469 private_key: private_key.to_hex(),
470 public_key: public_key.to_der_hex(),
471 })
472}
473
474pub fn get_lock_p2pkh(
479 key_deriver: &CachedKeyDeriver,
480 protocol_id: &str,
481 key_id: &str,
482 counterparty: &str,
483) -> WalletResult<Vec<u8>> {
484 let protocol = parse_protocol(protocol_id)?;
485 let cp = parse_counterparty(counterparty)?;
486
487 let derived_pub = key_deriver
488 .derive_public_key(&protocol, key_id, &cp, false)
489 .map_err(|e| WalletError::Internal(format!("Public key derivation failed: {}", e)))?;
490
491 use bsv::script::templates::p2pkh::P2PKH;
492 use bsv::script::templates::ScriptTemplateLock;
493
494 let hash_vec = derived_pub.to_hash();
496 let mut hash = [0u8; 20];
497 hash.copy_from_slice(&hash_vec);
498
499 let p2pkh = P2PKH::from_public_key_hash(hash);
500 let locking_script = p2pkh
501 .lock()
502 .map_err(|e| WalletError::Internal(format!("P2PKH lock failed: {}", e)))?;
503 Ok(locking_script.to_binary())
504}
505
506pub fn create_p2pkh_outputs(
511 key_deriver: &CachedKeyDeriver,
512 count: usize,
513 satoshis: u64,
514) -> WalletResult<Vec<CreateActionOutput>> {
515 let mut outputs = Vec::with_capacity(count);
516 let root_key = key_deriver.root_key();
517 let identity_pub = key_deriver.identity_key();
518
519 for i in 0..count {
520 let derivation_prefix = random_hex_string();
522 let derivation_suffix = random_hex_string();
523
524 let tmpl = ScriptTemplateBRC29::new(derivation_prefix, derivation_suffix);
525 let locking_script = tmpl.lock(root_key, &identity_pub)?;
526
527 outputs.push(CreateActionOutput {
528 locking_script: Some(locking_script),
529 satoshis,
530 output_description: format!("p2pkh {}", i),
531 basket: None,
532 custom_instructions: None,
533 tags: vec![],
534 });
535 }
536
537 Ok(outputs)
538}
539
540pub async fn create_p2pkh_outputs_action(
545 wallet: &Wallet,
546 count: usize,
547 satoshis: u64,
548 description: &str,
549) -> WalletResult<CreateActionResult> {
550 let outputs = create_p2pkh_outputs(&wallet.key_deriver, count, satoshis)?;
551
552 use bsv::wallet::interfaces::WalletInterface;
553 let result = wallet
554 .create_action(
555 CreateActionArgs {
556 description: description.to_string(),
557 inputs: vec![],
558 outputs,
559 lock_time: None,
560 version: None,
561 labels: vec![],
562 options: None,
563 input_beef: None,
564 reference: None,
565 },
566 None,
567 )
568 .await
569 .map_err(|e| WalletError::Internal(format!("create_action failed: {}", e)))?;
570
571 Ok(result)
572}
573
574fn parse_protocol(protocol_id: &str) -> WalletResult<Protocol> {
580 if let Some((level_str, name)) = protocol_id.split_once('.') {
581 let security_level: u8 = level_str
582 .parse()
583 .map_err(|_| WalletError::InvalidParameter {
584 parameter: "protocol_id".to_string(),
585 must_be: "in format 'security_level.protocol_name' (e.g., '2.3241645161d8')"
586 .to_string(),
587 })?;
588 Ok(Protocol {
589 security_level,
590 protocol: name.to_string(),
591 })
592 } else {
593 Ok(Protocol {
595 security_level: 2,
596 protocol: protocol_id.to_string(),
597 })
598 }
599}
600
601fn parse_counterparty(counterparty: &str) -> WalletResult<Counterparty> {
603 match counterparty {
604 "self" => Ok(Counterparty {
605 counterparty_type: CounterpartyType::Self_,
606 public_key: None,
607 }),
608 "anyone" => Ok(Counterparty {
609 counterparty_type: CounterpartyType::Anyone,
610 public_key: None,
611 }),
612 hex_str => {
613 let pk = bsv::primitives::public_key::PublicKey::from_string(hex_str).map_err(|e| {
614 WalletError::InvalidParameter {
615 parameter: "counterparty".to_string(),
616 must_be: format!("'self', 'anyone', or a valid public key hex: {}", e),
617 }
618 })?;
619 Ok(Counterparty {
620 counterparty_type: CounterpartyType::Other,
621 public_key: Some(pk),
622 })
623 }
624 }
625}
626
/// Returns a 16-character lowercase hex string suitable as a derivation
/// prefix or suffix.
///
/// The value is a 64-bit hash of the current time and a process-wide call
/// counter, computed with a randomly keyed hasher from the standard library.
/// The counter makes back-to-back calls differ even when the clock has not
/// advanced, and the random key makes the output unpredictable from the
/// clock alone.
///
/// This is not a cryptographically secure generator; do not use it for
/// secret key material.
fn random_hex_string() -> String {
    use std::collections::hash_map::RandomState;
    use std::hash::{BuildHasher, Hasher};
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::time::{SystemTime, UNIX_EPOCH};

    // Distinguishes calls that land on the same clock tick.
    static CALLS: AtomicUsize = AtomicUsize::new(0);

    let nanos = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default()
        .as_nanos();
    let call_index = CALLS.fetch_add(1, Ordering::Relaxed);

    // `RandomState::new()` carries randomly seeded keys, so the hash of the
    // same (time, counter) pair differs between processes.
    let mut hasher = RandomState::new().build_hasher();
    hasher.write_u128(nanos);
    hasher.write_usize(call_index);

    format!("{:016x}", hasher.finish())
}
638
#[cfg(test)]
mod tests {
    use super::*;

    /// Key deriver over the fixed test root key `"aa"`.
    fn test_key_deriver() -> CachedKeyDeriver {
        let root = PrivateKey::from_hex("aa").unwrap();
        CachedKeyDeriver::new(root, None)
    }

    /// Runs `builder.build()` on a current-thread runtime and asserts that it
    /// fails with an error whose message mentions `what`.
    fn assert_build_fails_mentioning(builder: WalletBuilder, what: &str) {
        let rt = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap();
        match rt.block_on(builder.build()) {
            Ok(_) => panic!("Expected error for missing {}", what),
            Err(e) => {
                let err = e.to_string();
                assert!(err.contains(what), "Expected {} error, got: {}", what, err);
            }
        }
    }

    #[test]
    fn test_parse_protocol_with_level() {
        let parsed = parse_protocol("2.3241645161d8").unwrap();
        assert_eq!(parsed.security_level, 2);
        assert_eq!(parsed.protocol, "3241645161d8");
    }

    #[test]
    fn test_parse_protocol_without_level() {
        let parsed = parse_protocol("3241645161d8").unwrap();
        assert_eq!(parsed.security_level, 2);
        assert_eq!(parsed.protocol, "3241645161d8");
    }

    #[test]
    fn test_parse_counterparty_self() {
        let parsed = parse_counterparty("self").unwrap();
        assert_eq!(parsed.counterparty_type, CounterpartyType::Self_);
        assert!(parsed.public_key.is_none());
    }

    #[test]
    fn test_parse_counterparty_anyone() {
        let parsed = parse_counterparty("anyone").unwrap();
        assert_eq!(parsed.counterparty_type, CounterpartyType::Anyone);
        assert!(parsed.public_key.is_none());
    }

    #[test]
    fn test_wallet_builder_validates_chain() {
        assert_build_fails_mentioning(WalletBuilder::new(), "chain");
    }

    #[test]
    fn test_wallet_builder_validates_root_key() {
        assert_build_fails_mentioning(WalletBuilder::new().chain(Chain::Test), "root_key");
    }

    #[test]
    fn test_wallet_builder_validates_storage() {
        let root_key = PrivateKey::from_hex("aa").unwrap();
        let builder = WalletBuilder::new().chain(Chain::Test).root_key(root_key);
        assert_build_fails_mentioning(builder, "storage");
    }

    #[test]
    fn test_get_key_pair_self() {
        let key_deriver = test_key_deriver();
        let pair = get_key_pair(&key_deriver, "2.3241645161d8", "test_key", "self").unwrap();
        assert!(!pair.private_key.is_empty());
        assert!(!pair.public_key.is_empty());
        // The DER hex public key is expected to be 66 characters long.
        assert_eq!(pair.public_key.len(), 66);
    }

    #[test]
    fn test_get_lock_p2pkh_produces_25_byte_script() {
        let key_deriver = test_key_deriver();
        let script = get_lock_p2pkh(&key_deriver, "2.3241645161d8", "test_key", "self").unwrap();
        assert_eq!(script.len(), 25);
    }

    #[test]
    fn test_create_p2pkh_outputs_count() {
        let key_deriver = test_key_deriver();
        let outputs = create_p2pkh_outputs(&key_deriver, 3, 1000).unwrap();
        assert_eq!(outputs.len(), 3);
        for (index, output) in outputs.iter().enumerate() {
            assert_eq!(output.satoshis, 1000);
            assert!(output.locking_script.is_some());
            assert_eq!(output.output_description, format!("p2pkh {}", index));
        }
    }

    #[test]
    fn test_random_hex_string_length() {
        assert_eq!(random_hex_string().len(), 16);
    }
}