#![allow(missing_docs)]
use std::{
collections::HashMap,
fs::File,
io::Write,
path::{Path, PathBuf},
};
use primitive_types::H160;
use rayon::prelude::*;
use serde::{Deserialize, Serialize};
use crate::{
neo_builder::{Transaction, TransactionBuilder, Witness},
neo_clients::{APITrait, JsonRpcProvider, ProviderError, RpcClient},
neo_config::NeoConstants,
neo_crypto::{CryptoError, HashableForVec, KeyPair, Secp256r1Signature},
neo_protocol::{Account, AccountTrait, UnclaimedGas},
neo_types::{
script_hash::ScriptHashExtension,
serde_with_utils::{
deserialize_hash_map_h160_account, deserialize_script_hash,
serialize_hash_map_h160_account, serialize_script_hash,
},
ScryptParamsDef,
},
neo_wallets::{NEP6Account, Nep6Wallet, WalletError, WalletTrait},
};
use scrypt::Params;
/// An in-memory Neo wallet: accounts keyed by script hash plus the metadata
/// (name, version, scrypt parameters, extra fields) carried by NEP-6 files.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Wallet {
    /// Human-readable wallet name.
    pub name: String,
    /// Wallet format version string.
    pub version: String,
    /// Scrypt parameters the wallet uses when encrypting/decrypting account keys.
    pub scrypt_params: ScryptParamsDef,
    /// Every account held by the wallet, indexed by its script hash.
    #[serde(
        serialize_with = "serialize_hash_map_h160_account",
        deserialize_with = "deserialize_hash_map_h160_account"
    )]
    pub accounts: HashMap<H160, Account>,
    /// Script hash of the default account. This hash is not guaranteed to be a
    /// key of `accounts` (e.g. after `set_default_account` with an unknown hash).
    #[serde(
        serialize_with = "serialize_script_hash",
        deserialize_with = "deserialize_script_hash"
    )]
    pub(crate) default_account: H160,
    /// Free-form string metadata; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub extra: Option<HashMap<String, String>>,
}
impl Default for Wallet {
fn default() -> Self {
Self {
name: Wallet::DEFAULT_WALLET_NAME.to_string(),
version: Wallet::CURRENT_VERSION.to_string(),
scrypt_params: ScryptParamsDef::default(),
accounts: HashMap::new(),
default_account: H160::default(),
extra: None,
}
}
}
impl WalletTrait for Wallet {
    type Account = Account;

    /// Returns the wallet name.
    fn name(&self) -> &String {
        &self.name
    }

    /// Returns the wallet format version.
    fn version(&self) -> &String {
        &self.version
    }

    /// Returns the configured scrypt parameters.
    fn scrypt_params(&self) -> &ScryptParamsDef {
        &self.scrypt_params
    }

    /// Returns owned copies of every account (in HashMap iteration order).
    fn accounts(&self) -> Vec<Self::Account> {
        self.accounts.values().map(Clone::clone).collect()
    }

    /// Returns the default account, or `None` if the stored default hash
    /// does not match any account.
    fn default_account(&self) -> Option<&Account> {
        let hash = &self.default_account;
        self.accounts.get(hash)
    }

    /// Replaces the wallet name.
    fn set_name(&mut self, name: String) {
        self.name = name;
    }

    /// Replaces the wallet format version.
    fn set_version(&mut self, version: String) {
        self.version = version;
    }

    /// Replaces the scrypt parameters used for key encryption.
    fn set_scrypt_params(&mut self, params: ScryptParamsDef) {
        self.scrypt_params = params;
    }

    /// Records `default_account` as the default hash and refreshes every
    /// account's `is_default` flag. The hash is not required to exist.
    fn set_default_account(&mut self, default_account: H160) {
        self.default_account = default_account;
        self.sync_default_account_flags();
    }

    /// Inserts (or replaces) `account` under its script hash.
    ///
    /// The first account added to an empty wallet becomes the default.
    fn add_account(&mut self, account: Self::Account) {
        let hash = account.get_script_hash();
        let is_first = self.accounts.is_empty();
        self.accounts.insert(hash, account);
        if is_first {
            self.default_account = hash;
        }
        self.sync_default_account_flags();
    }

    /// Removes and returns the account stored under `hash`.
    ///
    /// Removing the default account promotes a deterministic replacement
    /// (the smallest remaining script hash).
    fn remove_account(&mut self, hash: &H160) -> Option<Self::Account> {
        let removed = self.accounts.remove(hash)?;
        if *hash == self.default_account {
            self.promote_default_account();
        } else {
            self.sync_default_account_flags();
        }
        Some(removed)
    }
}
impl Wallet {
/// Name assigned to newly constructed wallets.
pub const DEFAULT_WALLET_NAME: &str = "NeoWallet";
/// Version string assigned to newly constructed wallets.
pub const CURRENT_VERSION: &str = "1.0";
/// Sets `is_default` to true on the account whose hash equals the stored
/// default hash, and to false on every other account.
fn sync_default_account_flags(&mut self) {
    // Copy the hash first so the closure does not borrow `self` while
    // `accounts` is mutably borrowed.
    let current = self.default_account;
    self.accounts
        .iter_mut()
        .for_each(|(hash, account)| account.is_default = *hash == current);
}
/// Chooses the remaining account with the smallest script hash (by byte
/// order) as the new default. If no accounts remain, it uses the zero hash.
/// It then refreshes the per-account `is_default` flags.
fn promote_default_account(&mut self) {
    let successor = self
        .accounts
        .keys()
        .min_by_key(|hash| hash.to_fixed_bytes())
        .copied();
    self.default_account = successor.unwrap_or_default();
    self.sync_default_account_flags();
}
/// Converts the wallet's scrypt settings into `scrypt::Params` (32-byte output).
///
/// Fallback order:
/// 1. If the configured values are rejected, it logs a warning and uses the Neo
///    constants.
/// 2. If those are also rejected, it logs an error and uses
///    `Params::recommended()`.
fn effective_scrypt_params(&self) -> Params {
    let requested = &self.scrypt_params;
    match Params::new(requested.log_n, requested.r, requested.p, 32) {
        Ok(params) => params,
        Err(e) => {
            tracing::warn!(
                error = %e,
                log_n = requested.log_n,
                r = requested.r,
                p = requested.p,
                "Invalid scrypt params; falling back to Neo defaults"
            );
            match Params::new(
                NeoConstants::SCRYPT_LOG_N,
                NeoConstants::SCRYPT_R,
                NeoConstants::SCRYPT_P,
                32,
            ) {
                Ok(neo_defaults) => neo_defaults,
                Err(e) => {
                    tracing::error!(
                        error = %e,
                        "Neo default scrypt parameters are invalid; falling back to scrypt recommended params"
                    );
                    Params::recommended()
                },
            }
        },
    }
}
/// Creates a wallet holding one freshly generated default account.
///
/// # Panics
/// Panics if account generation fails; use [`Wallet::try_new`] to handle
/// that case without panicking.
pub fn new() -> Self {
    match Self::try_new_with_account_factory(Account::create) {
        Ok(wallet) => wallet,
        Err(e) => panic!(
            "failed to create default wallet; use Wallet::try_new for fallible handling: {e}"
        ),
    }
}
pub fn try_new() -> Result<Self, WalletError> {
Self::try_new_with_account_factory(Account::create)
}
/// Builds a wallet whose single account is produced by `create_account`.
///
/// The account is marked as default. All other fields take their
/// [`Default`] values. The factory parameter lets tests inject failures.
///
/// # Errors
/// Wraps a factory failure in [`WalletError::ProviderError`].
fn try_new_with_account_factory<F>(create_account: F) -> Result<Self, WalletError>
where
    F: FnOnce() -> Result<Account, ProviderError>,
{
    let mut first = create_account().map_err(WalletError::ProviderError)?;
    first.is_default = true;
    let hash = first.address_or_scripthash.script_hash();
    Ok(Self {
        accounts: HashMap::from([(hash, first)]),
        default_account: hash,
        ..Self::default()
    })
}
/// Converts the wallet into its NEP-6 representation.
///
/// Name, version, scrypt parameters and `extra` metadata are copied. Each
/// account goes through `NEP6Account::from_account`. `extra` is preserved so
/// that data such as the network set by [`Wallet::with_network`] survives a
/// save/load round trip through [`Wallet::from_nep6`].
///
/// # Errors
/// Propagates the first account-conversion error (for example an account
/// whose key is not encrypted).
pub fn to_nep6(&self) -> Result<Nep6Wallet, WalletError> {
    let accounts = self
        .accounts
        .values()
        .map(NEP6Account::from_account)
        .collect::<Result<Vec<NEP6Account>, _>>()?;
    Ok(Nep6Wallet {
        name: self.name.clone(),
        version: self.version.clone(),
        scrypt: self.scrypt_params.clone(),
        accounts,
        // Previously hard-coded to `None`, which silently dropped metadata.
        extra: self.extra.clone(),
    })
}
pub fn from_nep6(nep6: Nep6Wallet) -> Result<Self, WalletError> {
let accounts = nep6
.accounts()
.iter()
.map(NEP6Account::to_account)
.collect::<Result<Vec<_>, _>>()?;
if accounts.is_empty() {
tracing::warn!("No accounts found in NEP6 wallet");
return Err(WalletError::NoAccounts);
}
let default_account =
if let Some(account) = accounts.iter().find(|account| account.is_default) {
account.get_script_hash()
} else {
tracing::warn!("No default account found, using first account");
accounts[0].get_script_hash()
};
let mut wallet = Self {
name: nep6.name().clone(),
version: nep6.version().clone(),
scrypt_params: nep6.scrypt().clone(),
accounts: accounts
.into_iter()
.map(|account| (account.get_script_hash(), account))
.collect(),
default_account,
extra: nep6.extra.clone(),
};
wallet.sync_default_account_flags();
Ok(wallet)
}
pub fn from_account(account: &Account) -> Result<Wallet, WalletError> {
let mut wallet: Wallet = Wallet::default();
wallet.add_account(account.clone());
Ok(wallet)
}
/// Creates a wallet from `accounts`, making the first one the default.
///
/// # Errors
/// Returns [`WalletError::NoAccounts`] if `accounts` is empty.
pub fn from_accounts(accounts: Vec<Account>) -> Result<Wallet, WalletError> {
    let first_hash = match accounts.first() {
        Some(first) => first.get_script_hash(),
        None => return Err(WalletError::NoAccounts),
    };
    let mut wallet = Wallet::default();
    for account in accounts {
        wallet.add_account(account);
    }
    wallet.set_default_account(first_hash);
    Ok(wallet)
}
/// Serializes the wallet as NEP-6 JSON and writes it to `path`.
///
/// The JSON is produced before the file is created, so a conversion or
/// serialization failure never touches the filesystem.
///
/// # Errors
/// * Any error from [`Wallet::to_nep6`].
/// * [`WalletError::AccountState`] if JSON serialization fails.
/// * [`WalletError::FileError`] if the file cannot be created or written.
pub fn save_to_file(&self, path: PathBuf) -> Result<(), WalletError> {
    let nep6 = self.to_nep6()?;
    let payload = match serde_json::to_string(&nep6) {
        Ok(json) => json,
        Err(e) => {
            return Err(WalletError::AccountState(format!(
                "Failed to serialize wallet to JSON: {e}"
            )))
        },
    };
    File::create(path)
        .map_err(|e| WalletError::FileError(format!("Failed to create wallet file: {e}")))?
        .write_all(payload.as_bytes())
        .map_err(|e| WalletError::FileError(format!("Failed to write wallet file: {e}")))
}
/// Looks up the account stored under `script_hash`.
pub fn get_account(&self, script_hash: &H160) -> Option<&Account> {
    let accounts = &self.accounts;
    accounts.get(script_hash)
}
pub fn remove_account(&mut self, script_hash: &H160) -> bool {
WalletTrait::remove_account(self, script_hash).is_some()
}
pub fn encrypt_accounts(&mut self, password: &str) {
let params = self.effective_scrypt_params();
for account in self.accounts.values_mut() {
if account.key_pair().is_some() {
if let Err(e) = account.encrypt_private_key_with_params(password, ¶ms) {
tracing::warn!(
address = %account.get_address(),
error = %e,
"Failed to encrypt private key for account"
);
}
}
}
}
pub fn encrypt_accounts_parallel(&mut self, password: &str) {
let params = self.effective_scrypt_params();
let errors: Vec<(String, String)> = self
.accounts
.par_iter_mut()
.filter_map(|(_, account)| {
if account.key_pair().is_some() {
match account.encrypt_private_key_with_params(password, ¶ms) {
Err(e) => Some((account.get_address(), e.to_string())),
Ok(_) => None,
}
} else {
None
}
})
.collect();
for (address, error) in errors {
tracing::warn!(address = %address, error = %error, "Failed to encrypt private key");
}
}
/// Runs [`Wallet::encrypt_accounts_parallel`] inside a dedicated rayon pool
/// of `num_threads` threads.
///
/// If the pool cannot be built, it logs a warning and uses the default
/// global pool instead.
pub fn encrypt_accounts_parallel_with_threads(&mut self, password: &str, num_threads: usize) {
    match rayon::ThreadPoolBuilder::new().num_threads(num_threads).build() {
        Ok(pool) => pool.install(|| self.encrypt_accounts_parallel(password)),
        Err(err) => {
            tracing::warn!(
                threads = num_threads,
                error = %err,
                "Failed to build custom rayon thread pool; falling back to default pool"
            );
            self.encrypt_accounts_parallel(password);
        },
    }
}
pub fn encrypt_accounts_batch_parallel(&mut self, password: &str, batch_size: usize) {
use std::sync::{Arc, Mutex};
let params = self.effective_scrypt_params();
let accounts_to_encrypt: Vec<(H160, Account)> = self
.accounts
.iter()
.filter(|(_, account)| account.key_pair().is_some())
.map(|(hash, account)| (*hash, account.clone()))
.collect();
let results: Arc<Mutex<Vec<(H160, Result<Account, String>)>>> =
Arc::new(Mutex::new(Vec::new()));
accounts_to_encrypt.par_chunks(batch_size).for_each(|batch| {
let batch_results: Vec<(H160, Result<Account, String>)> = batch
.iter()
.map(|(hash, account)| {
let mut account_clone = account.clone();
match account_clone.encrypt_private_key_with_params(password, ¶ms) {
Ok(_) => (*hash, Ok(account_clone)),
Err(e) => (*hash, Err(format!("{}: {}", account.get_address(), e))),
}
})
.collect();
results.lock().unwrap_or_else(|e| e.into_inner()).extend(batch_results);
});
let results = match Arc::try_unwrap(results) {
Ok(mutex) => mutex.into_inner().unwrap_or_else(|e| e.into_inner()),
Err(arc) => arc.lock().unwrap_or_else(|e| e.into_inner()).clone(),
};
for (hash, result) in results {
match result {
Ok(encrypted_account) => {
self.accounts.insert(hash, encrypted_account);
},
Err(error_msg) => {
tracing::warn!("Failed to encrypt private key for account {}", error_msg);
},
}
}
}
#[deprecated(since = "0.1.0", note = "Please use `create_wallet` instead")]
pub fn create(path: &Path, password: &str) -> Result<Self, WalletError> {
Self::create_wallet(path, password)
}
#[deprecated(since = "0.1.0", note = "Please use `open_wallet` instead")]
pub fn open(path: &Path, password: &str) -> Result<Self, WalletError> {
Self::open_wallet(path, password)
}
/// Returns references to all accounts (in HashMap iteration order).
pub fn get_accounts(&self) -> Vec<&Account> {
    self.accounts.iter().map(|(_, account)| account).collect()
}
/// Generates a new account, adds it to the wallet and returns a reference
/// to the stored copy.
///
/// # Errors
/// Propagates `Account::create` failures (converted via `?`). Returns
/// [`WalletError::AccountState`] if the account cannot be found after
/// insertion.
pub fn create_account(&mut self) -> Result<&Account, WalletError> {
    let fresh = Account::create()?;
    let hash = fresh.get_script_hash();
    self.add_account(fresh);
    self.get_account(&hash).ok_or_else(|| {
        WalletError::AccountState("Account was added but could not be retrieved".to_string())
    })
}
/// Imports a WIF-encoded private key as a new account and returns the
/// stored copy.
///
/// # Errors
/// * [`WalletError::AccountState`] if the WIF cannot be decoded, or if the
///   account cannot be found after insertion.
/// * [`WalletError::ProviderError`] if building the account fails.
pub fn import_private_key(&mut self, wif: &str) -> Result<&Account, WalletError> {
    let key_pair = match KeyPair::from_wif(wif) {
        Ok(kp) => kp,
        Err(e) => {
            return Err(WalletError::AccountState(format!("Failed to import private key: {e}")))
        },
    };
    let imported =
        Account::from_key_pair(key_pair, None, None).map_err(WalletError::ProviderError)?;
    let hash = imported.get_script_hash();
    self.add_account(imported);
    self.get_account(&hash).ok_or_else(|| {
        WalletError::AccountState("Account was added but could not be retrieved".to_string())
    })
}
pub fn verify_password(&self, password: &str) -> bool {
if self.accounts.is_empty() {
return false;
}
let params = self.effective_scrypt_params();
for account in self.accounts.values() {
if account.encrypted_private_key().is_none() {
continue;
}
if account.key_pair().is_some() {
continue;
}
let mut account_clone = account.clone();
match account_clone.decrypt_private_key_with_params(password, ¶ms) {
Ok(_) => return true, Err(_) => continue, }
}
false
}
pub fn change_password(
&mut self,
current_password: &str,
new_password: &str,
) -> Result<(), WalletError> {
if !self.verify_password(current_password) {
return Err(WalletError::AccountState("Invalid password".to_string()));
}
let params = self.effective_scrypt_params();
for account in self.accounts.values_mut() {
if account.encrypted_private_key().is_some() && account.key_pair().is_none() {
if let Err(e) = account.decrypt_private_key_with_params(current_password, ¶ms) {
return Err(WalletError::DecryptionError(format!(
"Failed to decrypt account {}: {}",
account.get_address(),
e
)));
}
}
}
self.encrypt_accounts(new_password);
Ok(())
}
/// Parallel counterpart of `change_password`.
///
/// Encrypted-only accounts are decrypted on copies in parallel. If any copy
/// fails, the first failure (in collection order) is returned and the wallet
/// is left untouched. Otherwise every decrypted copy is committed, and all
/// keys are re-encrypted in parallel under `new_password`.
///
/// # Errors
/// * [`WalletError::AccountState`] if `current_password` fails
///   [`Wallet::verify_password`].
/// * [`WalletError::DecryptionError`] if any account fails to decrypt.
pub fn change_password_parallel(
    &mut self,
    current_password: &str,
    new_password: &str,
) -> Result<(), WalletError> {
    if !self.verify_password(current_password) {
        return Err(WalletError::AccountState("Invalid password".to_string()));
    }
    let params = self.effective_scrypt_params();
    let pending: Vec<(H160, Account)> = self
        .accounts
        .iter()
        .filter(|(_, account)| {
            account.encrypted_private_key().is_some() && account.key_pair().is_none()
        })
        .map(|(hash, account)| (*hash, account.clone()))
        .collect();
    let outcomes: Vec<(H160, Result<Account, String>)> = pending
        .into_par_iter()
        .map(|(hash, original)| {
            let mut candidate = original.clone();
            let outcome = candidate
                .decrypt_private_key_with_params(current_password, &params)
                .map(|_| candidate)
                .map_err(|e| format!("{}: {}", original.get_address(), e));
            (hash, outcome)
        })
        .collect();
    if let Some(error_msg) = outcomes.iter().find_map(|(_, outcome)| outcome.as_ref().err()) {
        return Err(WalletError::DecryptionError(format!(
            "Failed to decrypt account {}",
            error_msg
        )));
    }
    self.accounts.extend(
        outcomes
            .into_iter()
            .filter_map(|(hash, outcome)| outcome.ok().map(|account| (hash, account))),
    );
    self.encrypt_accounts_parallel(new_password);
    Ok(())
}
/// Sums the unclaimed GAS of every account by querying `rpc_client` once per
/// account, sequentially.
///
/// # Errors
/// Returns the first RPC failure as [`WalletError::ProviderError`].
pub async fn get_unclaimed_gas<P>(&self, rpc_client: &P) -> Result<UnclaimedGas, WalletError>
where
    P: JsonRpcProvider + APITrait + 'static,
    <P as APITrait>::Error: Into<ProviderError>,
{
    let mut sum = UnclaimedGas::default();
    for account in self.get_accounts() {
        let per_account = rpc_client
            .get_unclaimed_gas(account.get_script_hash())
            .await
            .map_err(|err| WalletError::ProviderError(err.into()))?;
        sum += per_account;
    }
    Ok(sum)
}
}
impl Wallet {
/// Signs `message` with the default account's private key.
///
/// The message is hashed with `hash256`, and the digest is passed to the
/// private key's `sign_tx`.
///
/// # Errors
/// * [`WalletError::NoDefaultAccount`] if no default account is set.
/// * [`WalletError::NoKeyPair`] if the default account holds no decrypted
///   key pair.
/// * [`WalletError::SigningError`] if the signing operation itself fails.
///   This was previously misreported as `NoKeyPair`.
pub async fn sign_message<S: Send + Sync + AsRef<[u8]>>(
    &self,
    message: S,
) -> Result<Secp256r1Signature, WalletError> {
    let digest = message.as_ref().hash256();
    // Borrow the key pair rather than cloning secret key material.
    let key_pair = self
        .default_account()
        .ok_or(WalletError::NoDefaultAccount)?
        .key_pair()
        .as_ref()
        .ok_or(WalletError::NoKeyPair)?;
    key_pair
        .private_key()
        .sign_tx(digest.as_slice())
        .map_err(|e| WalletError::SigningError(format!("Failed to sign message: {e}")))
}
/// Creates a witness for `tx`, signed with the default account's key pair.
///
/// # Errors
/// * [`WalletError::NoDefaultAccount`] if no default account is set.
/// * [`WalletError::NoKeyPair`] if the default account holds no decrypted
///   key pair.
/// * Any error from `tx.get_hash_data()`, converted via `?`.
/// * [`WalletError::SigningError`] if witness creation fails. This was
///   previously misreported as `NoKeyPair`.
pub async fn get_witness<'a, P: JsonRpcProvider + 'static>(
    &self,
    tx: &Transaction<'a, P>,
) -> Result<Witness, WalletError> {
    let account = self.default_account().ok_or(WalletError::NoDefaultAccount)?;
    let key_pair = account.key_pair.as_ref().ok_or(WalletError::NoKeyPair)?;
    let hash_data = tx.get_hash_data().await?;
    Witness::create(hash_data, key_pair)
        .map_err(|e| WalletError::SigningError(format!("Failed to create witness: {e}")))
}
/// Builds the unsigned transaction from `tx_builder` and signs it with the
/// account at `account_address`.
///
/// If that account holds no decrypted key pair, a copy is decrypted with
/// `password`. Decryption uses the wallet's own scrypt parameters, the same
/// ones [`Wallet::encrypt_accounts`] and [`Wallet::verify_password`] use.
/// Previously the default-parameter `decrypt_private_key` was used, which
/// fails for wallets with non-default scrypt settings. The stored account is
/// never modified.
///
/// # Errors
/// * [`WalletError::AccountState`] if the address is invalid or unknown.
/// * [`WalletError::DecryptionError`] if the key cannot be decrypted.
/// * [`WalletError::NoKeyPair`] if decryption yields no key pair.
/// * Errors from building the transaction or its hash data, via `?`.
/// * [`WalletError::SigningError`] if witness creation fails.
pub async fn sign_transaction<'a, P>(
    &self,
    tx_builder: &'a mut TransactionBuilder<'a, P>,
    account_address: &str,
    password: &str,
) -> Result<Transaction<'a, P>, WalletError>
where
    P: JsonRpcProvider + 'static,
{
    let script_hash = H160::from_address(account_address)
        .map_err(|e| WalletError::AccountState(format!("Invalid address: {e}")))?;
    let account = self.get_account(&script_hash).ok_or_else(|| {
        WalletError::AccountState(format!("Account not found: {account_address}"))
    })?;
    let key_pair = match account.key_pair() {
        Some(kp) => kp.clone(),
        None => {
            let params = self.effective_scrypt_params();
            let mut account_clone = account.clone();
            account_clone.decrypt_private_key_with_params(password, &params).map_err(|e| {
                WalletError::DecryptionError(format!("Failed to decrypt account: {e}"))
            })?;
            match account_clone.key_pair() {
                Some(kp) => kp.clone(),
                None => return Err(WalletError::NoKeyPair),
            }
        },
    };
    let mut tx = tx_builder.get_unsigned_tx().await?;
    let witness = Witness::create(tx.get_hash_data().await?, &key_pair)
        .map_err(|e| WalletError::SigningError(format!("Failed to create witness: {e}")))?;
    tx.add_witness(witness);
    Ok(tx)
}
// Not called anywhere in this file; allowed so the unused helper does not warn.
#[allow(dead_code)]
/// Returns the default account's address. If the default hash matches no
/// account, it returns the address of the zero script hash.
fn address(&self) -> String {
    self.get_account(&self.default_account)
        .map(|account| account.address_or_scripthash.address())
        .unwrap_or_else(|| H160::default().to_address())
}
/// Creates a wallet with one new account, encrypts it under `password` and
/// saves it to `path` as NEP-6 JSON.
///
/// # Errors
/// * [`WalletError::ProviderError`] if account generation fails.
/// * Any error from [`Wallet::save_to_file`]. Because encryption is
///   best-effort, an account that failed to encrypt surfaces here as a
///   conversion error.
pub fn create_wallet(path: &Path, password: &str) -> Result<Self, WalletError> {
    let mut fresh_wallet = Wallet::default();
    let first_account = Account::create().map_err(WalletError::ProviderError)?;
    fresh_wallet.add_account(first_account);
    fresh_wallet.encrypt_accounts(password);
    fresh_wallet.save_to_file(path.to_owned())?;
    Ok(fresh_wallet)
}
pub fn open_wallet(path: &Path, password: &str) -> Result<Self, WalletError> {
let wallet_json = std::fs::read_to_string(path)
.map_err(|e| WalletError::FileError(format!("Failed to read wallet file: {e}")))?;
let nep6_wallet: Nep6Wallet = serde_json::from_str(&wallet_json).map_err(|e| {
WalletError::DeserializationError(format!("Failed to parse wallet JSON: {e}"))
})?;
let wallet = Wallet::from_nep6(nep6_wallet)?;
let can_decrypt = wallet.verify_password(password);
if !can_decrypt {
return Err(WalletError::CryptoError(CryptoError::InvalidPassphrase(
"Invalid password".to_string(),
)));
}
Ok(wallet)
}
/// Returns references to all accounts. Same contents as `get_accounts`.
pub fn get_all_accounts(&self) -> Vec<&Account> {
    let mut all = Vec::with_capacity(self.accounts.len());
    all.extend(self.accounts.values());
    all
}
/// Generates a new account, adds it and returns the stored copy.
///
/// # Errors
/// * [`WalletError::ProviderError`] if account generation fails.
/// * [`WalletError::AccountState`] if the account cannot be found after
///   insertion.
pub fn create_new_account(&mut self) -> Result<&Account, WalletError> {
    let new_account = Account::create().map_err(WalletError::ProviderError)?;
    let key = new_account.address_or_scripthash.script_hash();
    self.add_account(new_account);
    self.accounts.get(&key).ok_or_else(|| {
        WalletError::AccountState("Account was added but could not be retrieved".to_string())
    })
}
/// Imports a WIF-encoded private key as a new account and returns the
/// stored copy.
///
/// # Errors
/// * [`WalletError::CryptoError`] if the WIF cannot be decoded.
/// * [`WalletError::AccountState`] if building the account fails, or if the
///   account cannot be found after insertion.
pub fn import_from_wif(&mut self, private_key: &str) -> Result<&Account, WalletError> {
    let decoded = KeyPair::from_wif(private_key).map_err(WalletError::CryptoError)?;
    let imported = match Account::from_key_pair(decoded, None, None) {
        Ok(account) => account,
        Err(e) => {
            return Err(WalletError::AccountState(format!("Failed to create account: {e}")))
        },
    };
    let key = imported.address_or_scripthash.script_hash();
    self.add_account(imported);
    self.accounts.get(&key).ok_or_else(|| {
        WalletError::AccountState("Account was added but could not be retrieved".to_string())
    })
}
pub async fn get_unclaimed_gas_as_float<P>(
&self,
rpc_client: &RpcClient<P>,
) -> Result<f64, WalletError>
where
P: JsonRpcProvider + 'static,
{
let mut total_unclaimed = 0.0;
for account in self.accounts.values() {
let script_hash = account.address_or_scripthash.script_hash();
let unclaimed = rpc_client
.get_unclaimed_gas(script_hash)
.await
.map_err(WalletError::ProviderError)?;
total_unclaimed += unclaimed.unclaimed.parse::<f64>().unwrap_or(0.0);
}
Ok(total_unclaimed)
}
// Not called anywhere in this file; allowed so the unused helper does not warn.
#[allow(dead_code)]
/// Returns the network magic stored under `extra["network"]`.
///
/// Falls back to `NeoConstants::MAGIC_NUMBER_MAINNET` when `extra` is
/// absent, the key is missing, or the value is not a valid `u32`.
fn network(&self) -> u32 {
    self.extra
        .as_ref()
        .and_then(|extra| extra.get("network"))
        .and_then(|raw| raw.parse::<u32>().ok())
        .unwrap_or(NeoConstants::MAGIC_NUMBER_MAINNET)
}
pub fn with_network(mut self, network: u32) -> Self {
let mut extra = self.extra.unwrap_or_default();
extra.insert("network".to_string(), network.to_string());
self.extra = Some(extra);
self
}
}
// Unit tests for `Wallet`. Most tests use freshly generated random accounts
// (`Account::create`), so expected values are derived from those accounts
// rather than hard-coded.
#[cfg(test)]
mod tests {
use crate::{
neo_clients::ProviderError,
neo_config::TestConstants,
neo_protocol::{Account, AccountTrait},
neo_wallets::{NEP6Account, Nep6Wallet, Wallet, WalletError, WalletTrait},
ScryptParamsDef,
};
use primitive_types::H160;
use tempfile::tempdir;
// Lowers the scrypt cost (log_n = 10, r = 8, p = 1) so that tests which
// encrypt or decrypt keys run quickly.
fn apply_fast_scrypt(wallet: &mut Wallet) {
wallet.set_scrypt_params(ScryptParamsDef { log_n: 10, r: 8, p: 1 });
}
// Setting the default account flips `is_default` on the stored copy only;
// the caller's clone is unaffected.
#[test]
fn test_is_default() {
let account = Account::from_address(TestConstants::DEFAULT_ACCOUNT_ADDRESS)
.expect("Should be able to create account from valid address in test");
let mut wallet: Wallet = Wallet::new();
wallet.add_account(account.clone());
assert!(!account.is_default);
let hash = account.address_or_scripthash.script_hash();
wallet.set_default_account(hash);
assert!(wallet.get_account(&hash).expect("Account should exist in wallet").is_default);
}
// `Wallet::default()` is empty and uses the default name and version.
#[test]
fn test_create_default_wallet() {
let wallet: Wallet = Wallet::default();
assert_eq!(&wallet.name, "NeoWallet");
assert_eq!(&wallet.version, Wallet::CURRENT_VERSION);
assert_eq!(wallet.accounts.len(), 0usize);
}
// `try_new` produces exactly one account, flagged as default.
#[test]
fn test_try_new_creates_single_default_account() {
let wallet = Wallet::try_new().expect("Should create wallet with default account");
assert_eq!(wallet.accounts.len(), 1);
let account = wallet.default_account().expect("Wallet should have a default account");
assert!(account.is_default);
}
// NOTE(review): this test does not call `Wallet::new`. It reproduces
// `new`'s panic message around `try_new_with_account_factory` with a failing
// factory, so it only pins the message format, not `new` itself.
#[test]
#[should_panic(expected = "failed to create default wallet; use Wallet::try_new for fallible handling")]
fn test_new_panics_when_account_creation_fails() {
let _ = Wallet::try_new_with_account_factory(|| {
Err(ProviderError::CustomError("boom".to_string()))
})
.unwrap_or_else(|e| {
panic!("failed to create default wallet; use Wallet::try_new for fallible handling: {e}")
});
}
// `from_accounts` keeps every account and makes the first one the default.
#[test]
fn test_create_wallet_with_accounts() {
let account1 = Account::create().expect("Should be able to create account in test");
let account2 = Account::create().expect("Should be able to create account in test");
let wallet = Wallet::from_accounts(vec![account1.clone(), account2.clone()])
.expect("Should be able to create wallet from accounts in test");
assert_eq!(wallet.default_account(), Some(&account1));
assert_eq!(wallet.accounts.len(), 2);
assert!(wallet
.accounts
.values()
.any(|a| a.get_script_hash() == account1.address_or_scripthash.script_hash()));
assert!(wallet
.accounts
.values()
.any(|a| a.get_script_hash() == account2.address_or_scripthash.script_hash()));
}
// `from_account` does not add any generated account besides the supplied one.
#[test]
fn test_from_account_keeps_only_supplied_account() {
let account = Account::create().expect("Should be able to create account in test");
let wallet =
Wallet::from_account(&account).expect("Should be able to create wallet from account");
assert_eq!(wallet.accounts.len(), 1);
assert_eq!(wallet.default_account(), Some(&account));
assert_eq!(wallet.get_account(&account.get_script_hash()), Some(&account));
}
// Adding the first account to an empty wallet makes it the default.
#[test]
fn test_add_account_to_empty_wallet_sets_default_account() {
let account = Account::create().expect("Should be able to create account in test");
let mut wallet = Wallet::default();
wallet.add_account(account.clone());
assert_eq!(wallet.default_account(), Some(&account));
assert!(
wallet
.get_account(&account.get_script_hash())
.expect("Account should exist")
.is_default
);
}
// An unknown default hash is accepted; afterwards no account is flagged as
// default and `default_account()` returns `None`.
#[test]
fn test_set_default_account_with_unknown_hash_leaves_no_default() {
let account1 = Account::create().expect("Should be able to create account in test");
let account2 = Account::create().expect("Should be able to create account in test");
let mut wallet = Wallet::from_accounts(vec![account1.clone(), account2.clone()])
.expect("Should be able to create wallet from accounts in test");
wallet.set_default_account(H160::repeat_byte(0xff));
assert_eq!(wallet.default_account(), None);
assert!(!wallet.accounts.values().any(|account| account.is_default));
}
// The stored default hash equals the single account's script hash.
#[test]
fn test_is_default_account() {
let account = Account::create().expect("Should be able to create account in test");
let wallet = Wallet::from_accounts(vec![account.clone()])
.expect("Should be able to create wallet from accounts in test");
assert_eq!(wallet.default_account, account.get_script_hash());
}
// `Wallet::new()` already holds one account, so adding another yields two.
#[test]
fn test_add_account() {
let account = Account::create().expect("Should be able to create account in test");
let mut wallet: Wallet = Wallet::new();
wallet.add_account(account.clone());
assert_eq!(wallet.accounts.len(), 2);
assert_eq!(
wallet.get_account(&account.address_or_scripthash.script_hash()),
Some(&account)
);
}
// Sequential encryption clears the plaintext key pair of both accounts.
// Indexing `accounts()` relies on HashMap order, but both entries are checked.
#[test]
fn test_encrypt_wallet() {
let mut wallet: Wallet = Wallet::new();
apply_fast_scrypt(&mut wallet);
wallet.add_account(Account::create().expect("Should be able to create account in test"));
assert!(wallet.accounts()[0].key_pair().is_some());
assert!(wallet.accounts()[1].key_pair().is_some());
wallet.encrypt_accounts("pw");
assert!(wallet.accounts()[0].key_pair().is_none());
assert!(wallet.accounts()[1].key_pair().is_none());
}
// Parallel encryption leaves every account encrypted-only.
#[test]
fn test_encrypt_wallet_parallel() {
let mut wallet: Wallet = Wallet::new();
apply_fast_scrypt(&mut wallet);
for _ in 0..5 {
wallet
.add_account(Account::create().expect("Should be able to create account in test"));
}
for account in wallet.accounts() {
assert!(account.key_pair().is_some());
}
wallet.encrypt_accounts_parallel("parallel_password");
for account in wallet.accounts() {
assert!(account.key_pair().is_none());
assert!(account.encrypted_private_key().is_some());
}
}
// Batched parallel encryption (11 accounts, batches of 3) encrypts every
// account.
#[test]
fn test_encrypt_wallet_batch_parallel() {
let mut wallet: Wallet = Wallet::new();
apply_fast_scrypt(&mut wallet);
for _ in 0..10 {
wallet
.add_account(Account::create().expect("Should be able to create account in test"));
}
for account in wallet.accounts() {
assert!(account.key_pair().is_some());
}
wallet.encrypt_accounts_batch_parallel("batch_password", 3);
for account in wallet.accounts() {
assert!(account.key_pair().is_none());
assert!(account.encrypted_private_key().is_some());
}
}
// After a parallel password change, only the new password verifies.
#[test]
fn test_change_password_parallel() {
let mut wallet = Wallet::new();
apply_fast_scrypt(&mut wallet);
for _ in 0..5 {
wallet
.add_account(Account::create().expect("Should be able to create account in test"));
}
let old_password = "old_password";
let new_password = "new_password";
wallet.encrypt_accounts(old_password);
assert!(wallet.verify_password(old_password));
assert!(!wallet.verify_password(new_password));
wallet
.change_password_parallel(old_password, new_password)
.expect("Password change should succeed");
assert!(!wallet.verify_password(old_password));
assert!(wallet.verify_password(new_password));
}
// NEP-6 conversion must fail, not silently skip, when a key is unencrypted.
#[test]
fn test_to_nep6_rejects_unencrypted_accounts_instead_of_dropping_them() {
let wallet = Wallet::try_new().expect("Should create wallet with default account");
let result = wallet.to_nep6();
assert!(matches!(
result,
Err(WalletError::AccountState(message))
if message.contains("not encrypted")
));
}
// Saving an unencrypted wallet fails before any file is created.
#[test]
fn test_save_to_file_rejects_unencrypted_wallet() {
let temp_dir = tempdir().expect("Should create temp dir");
let path = temp_dir.path().join("wallet.json");
let wallet = Wallet::try_new().expect("Should create wallet with default account");
let result = wallet.save_to_file(path.clone());
assert!(matches!(
result,
Err(WalletError::AccountState(message))
if message.contains("not encrypted")
));
assert!(!path.exists());
}
// A NEP-6 document with no accounts is rejected with `NoAccounts`.
#[test]
fn test_from_nep6_rejects_empty_wallet() {
let nep6_wallet = Nep6Wallet::new(
"Empty".to_string(),
Wallet::CURRENT_VERSION.to_string(),
ScryptParamsDef::default(),
vec![],
None,
);
let err = Wallet::from_nep6(nep6_wallet).unwrap_err();
assert!(matches!(err, WalletError::NoAccounts));
}
// Account-conversion errors from `NEP6Account::to_account` propagate.
#[test]
fn test_from_nep6_surfaces_invalid_account_errors() {
let nep6_wallet = Nep6Wallet::new(
"Invalid".to_string(),
Wallet::CURRENT_VERSION.to_string(),
ScryptParamsDef::default(),
vec![NEP6Account::new(String::new(), None, true, false, None, None, None)],
None,
);
let err = Wallet::from_nep6(nep6_wallet).unwrap_err();
assert!(matches!(
err,
WalletError::AccountState(message)
if message.contains("missing both address and verification script")
));
}
// `create_wallet` writes a file and returns one encrypted default account.
// It uses the default (slow) scrypt parameters.
#[test]
fn test_create_wallet_creates_single_encrypted_default_account() {
let temp_dir = tempdir().expect("Should create temp dir");
let path = temp_dir.path().join("wallet.json");
let wallet =
Wallet::create_wallet(&path, "password123").expect("Should create wallet on disk");
assert!(path.exists());
assert_eq!(wallet.accounts.len(), 1);
let account = wallet.default_account().expect("Wallet should have a default account");
assert!(account.is_default);
assert!(account.key_pair().is_none());
assert!(account.encrypted_private_key().is_some());
}
// `verify_password` is false before encryption (no encrypted-only accounts),
// then accepts only the correct password.
#[test]
fn test_verify_password() {
let mut wallet = Wallet::new();
apply_fast_scrypt(&mut wallet);
let account = Account::create().unwrap();
wallet.add_account(account.clone());
assert!(!wallet.verify_password("password123"));
wallet.encrypt_accounts("password123");
assert!(wallet.verify_password("password123"));
assert!(!wallet.verify_password("wrong_password"));
}
// Removing the default promotes the remaining account with the smallest
// script hash.
#[test]
fn test_remove_default_account_promotes_deterministic_remaining_account() {
let account1 = Account::create().expect("Should be able to create account in test");
let account2 = Account::create().expect("Should be able to create account in test");
let account3 = Account::create().expect("Should be able to create account in test");
let mut wallet =
Wallet::from_accounts(vec![account1.clone(), account2.clone(), account3.clone()])
.expect("Should be able to create wallet from accounts in test");
let expected_hash = [account2.get_script_hash(), account3.get_script_hash()]
.into_iter()
.min_by_key(|hash| hash.to_fixed_bytes())
.expect("remaining accounts should not be empty");
assert!(wallet.remove_account(&account1.get_script_hash()));
assert_eq!(
wallet.default_account().map(|account| account.get_script_hash()),
Some(expected_hash)
);
assert!(wallet.get_account(&expected_hash).expect("Account should exist").is_default);
}
// With two accounts, removing the default promotes the other one.
#[test]
fn test_remove_default_account_promotes_remaining_account() {
let account1 = Account::create().expect("Should be able to create account in test");
let account2 = Account::create().expect("Should be able to create account in test");
let mut wallet = Wallet::from_accounts(vec![account1.clone(), account2.clone()])
.expect("Should be able to create wallet from accounts in test");
assert!(wallet.remove_account(&account1.get_script_hash()));
assert_eq!(wallet.accounts.len(), 1);
assert_eq!(wallet.default_account(), Some(&account2));
assert!(
wallet
.get_account(&account2.get_script_hash())
.expect("Account should exist")
.is_default
);
}
}