use crate::bip39::Mnemonic;
use aes::cipher::{BlockDecrypt, BlockEncrypt, KeyInit, generic_array::GenericArray};
use rand::RngCore;
use unicode_normalization::UnicodeNormalization;
/// Fixed salt prefix: prepended to the per-mnemonic salt for scrypt in
/// `derive_secret_key`, and appended to "mnemonic" as the PBKDF2 salt in
/// `derive_path_address`.
const DEFAULT_SALT: &str = "Thanks Satoshi!";
/// BIP32 path of the address whose double-SHA-256 feeds the verify-word checksum.
const DERIVE_PATH: &str = "m/0'/0'";
/// Local shorthand for the parent module's error type.
type Bip38Err = super::Bip38Error;
trait Derivation {
fn derive_secret_key(passphrase: &str, salt: &[u8]) -> Result<[u8; 64], Bip38Err> {
let pass: String = passphrase.nfc().collect();
let argon_salt = {
let scrypt_salt = [DEFAULT_SALT.as_bytes(), salt].concat();
let params = scrypt::Params::new(20, 8, 8, 64)?;
let mut result = [0u8; 64];
scrypt::scrypt(pass.as_bytes(), &scrypt_salt, ¶ms, &mut result)?;
let (half1, half2) = result.split_at_mut(32);
half1[..32].xor(&half2[..32]);
half1[..32].to_vec()
};
let argon = argon2::Argon2::default();
let mut secret_key = [0u8; 64];
argon.hash_password_into(pass.as_bytes(), &argon_salt, &mut secret_key)?;
Ok(secret_key)
}
fn derive_path_address(mnemonic: &Mnemonic, path: &str) -> Result<String, Bip38Err> {
use bitcoin::bip32::{DerivationPath, Xpriv};
use bitcoin::{Address, Network, secp256k1::Secp256k1};
use pbkdf2::pbkdf2_hmac;
let seed = {
let mnemonic = mnemonic.to_string();
let salt = format!("mnemonic{DEFAULT_SALT}").into_bytes();
let mut seed = [0u8; 64];
pbkdf2_hmac::<sha2::Sha512>(mnemonic.as_bytes(), &salt, u32::pow(2, 11), &mut seed);
seed
};
let root = Xpriv::new_master(Network::Bitcoin, &seed)?;
let address = {
let derive_path: DerivationPath = path.parse()?;
let xpriv = root.derive_priv(&Secp256k1::default(), &derive_path)?;
let pub_key = xpriv.to_priv().public_key(&Secp256k1::default());
Address::p2pkh(&pub_key, Network::Bitcoin).to_string()
};
Ok(address)
}
}
/// Passphrase encryption of a mnemonic, with optional extension by salt words.
trait Encryption: Derivation + Sized {
/// Encrypts `self` under `passphrase`, appending `salt` as extra entropy.
/// Returns the new mnemonic and a verify word.
fn encrypt_extend(&self, passphrase: &str, salt: &[u8]) -> Result<(Self, String), Bip38Err>;
/// Reverses `encrypt_extend`. `verify` may be empty, a verify word, or the
/// original word count as decimal text.
fn decrypt_extend(&self, passphrase: &str, verify: &str) -> Result<Self, Bip38Err>;
}
// Uses the default `Derivation` bodies; no overrides needed.
impl Derivation for Mnemonic {}
impl Encryption for Mnemonic {
    /// Encrypts this mnemonic's entropy under `passphrase` and appends `salt`.
    ///
    /// Steps:
    /// - The entropy is zero-padded to 32 bytes and XOR-masked with the first
    ///   32 bytes of the derived secret.
    /// - The first 16-byte block is AES-256 encrypted with the last 32 bytes of
    ///   the secret; the second block is encrypted too only for 24-word input.
    /// - The buffer is cut back to the original entropy length and `salt` is
    ///   appended.
    ///
    /// The returned verify word has index `count_flag << 8 | checksum`:
    /// - `count_flag` is `8 - words / 3` (4..=0 for 12..=24 words).
    /// - `checksum` is the first byte of the double SHA-256 of the address
    ///   derived from the ORIGINAL mnemonic at `DERIVE_PATH`.
    fn encrypt_extend(&self, passphrase: &str, salt: &[u8]) -> Result<(Self, String), Bip38Err> {
        let original_bytes = self.count() / 3 * 4;
        debug_assert!(matches!(original_bytes + salt.len(), 16 | 20 | 24 | 28 | 32));
        let secret_key = Self::derive_secret_key(passphrase, salt)?;
        let (mask, aes_key) = secret_key.split_at(32);
        let entropy = &mut self.entropy();
        {
            entropy.resize(32, 0);
            entropy[..32].xor(mask);
            let (part1, part2) = entropy.split_at_mut(16);
            let cipher = aes::Aes256::new(GenericArray::from_slice(aes_key));
            cipher.encrypt_block(GenericArray::from_mut_slice(part1));
            // Only a 24-word mnemonic fills the second AES block with real entropy.
            if original_bytes == 32 {
                cipher.encrypt_block(GenericArray::from_mut_slice(part2));
            }
            entropy.truncate(original_bytes);
            entropy.extend_from_slice(salt);
        }
        let new_mnemonic = Mnemonic::new(entropy, self.language())?;
        let verify_word = {
            let address = Self::derive_path_address(self, DERIVE_PATH)?;
            let checksum = u16::from(address.as_bytes().sha256_n(2)[0]);
            // 12/15/18/21/24 words -> flag 4/3/2/1/0, so the index stays below 5 * 256.
            let count_flag: u16 = 8 - self.count() as u16 / 3;
            let verify_idx = usize::from(count_flag << 8 | checksum);
            debug_assert!(verify_idx < 2048);
            self.language()
                .word_at(verify_idx)
                .expect("verify index is below 1280, always inside a 2048-word list")
                .to_string()
        };
        Ok((new_mnemonic, verify_word))
    }

    /// Decrypts a mnemonic produced by `encrypt_extend`.
    ///
    /// `verify` selects how many leading entropy bytes belong to the original:
    /// - empty: all of them, with no salt and no checksum check;
    /// - a word whose index is below 5 * 256: its high bits give the original
    ///   word count and its low byte is checked against the recovered address;
    /// - decimal text "12".."24": the original word count, with no checksum.
    ///
    /// # Errors
    /// - `InvalidKey`: `verify` is unrecognised, or claims more original words
    ///   than this mnemonic contains.
    /// - `InvalidPassphrase`: the checksum does not match.
    /// - Derivation and mnemonic-construction errors are propagated.
    fn decrypt_extend(&self, passphrase: &str, verify: &str) -> Result<Self, Bip38Err> {
        let (result_bytes, checksum) = if verify.is_empty() {
            (self.count() / 3 * 4, None)
        } else if let Some(i) = self.language().index_of(verify)
            && i >> 8 < 5
        {
            ((8 - (i >> 8)) * 4, Some((i & 0xff) as u8))
        } else if let Ok(n) = verify.parse::<u16>()
            && matches!(n, 12 | 15 | 18 | 21 | 24)
            && (n as usize) <= self.count()
        {
            (n as usize / 3 * 4, None)
        } else {
            return Err(Bip38Err::InvalidKey);
        };
        debug_assert!(matches!(result_bytes, 16 | 20 | 24 | 28 | 32));
        let entropy = &mut self.entropy();
        // A mistyped or foreign verify word can claim more original bytes than
        // this mnemonic holds; `drain` below would panic, so reject it here.
        if result_bytes > entropy.len() {
            return Err(Bip38Err::InvalidKey);
        }
        {
            // Everything past the original entropy is the salt appended on encryption.
            let salt: Vec<u8> = entropy.drain(result_bytes..).collect();
            let secret_key = Self::derive_secret_key(passphrase, &salt)?;
            let (mask, aes_key) = secret_key.split_at(32);
            entropy.resize(32, 0);
            let (part1, part2) = entropy.split_at_mut(16);
            let cipher = aes::Aes256::new(GenericArray::from_slice(aes_key));
            cipher.decrypt_block(GenericArray::from_mut_slice(part1));
            if result_bytes == 32 {
                cipher.decrypt_block(GenericArray::from_mut_slice(part2));
            }
            entropy[..32].xor(mask);
            entropy.truncate(result_bytes);
        }
        let original = Mnemonic::new(entropy, self.language())?;
        if let Some(expected) = checksum {
            let address = Self::derive_path_address(&original, DERIVE_PATH)?;
            if address.as_bytes().sha256_n(2)[0] != expected {
                return Err(Bip38Err::InvalidPassphrase);
            }
        }
        Ok(original)
    }
}
/// Public entry points for passphrase encryption of mnemonic text.
pub trait MnemonicEncryption {
/// Encrypts the mnemonic, optionally extending it to `n` words (`0` keeps the
/// original count). Returns `"<mnemonic>; <verify word>"`.
fn mnemonic_encrypt(&self, passphrase: &str, n: usize) -> Result<String, Bip38Err>;
/// Decrypts text produced by `mnemonic_encrypt`. The trailing verify word (or
/// original word count) is optional.
fn mnemonic_decrypt(&self, passphrase: &str) -> Result<String, Bip38Err>;
}
impl MnemonicEncryption for str {
    /// Encrypts the mnemonic in `self` and optionally extends it to `n` words.
    ///
    /// `n == 0` keeps the original word count. Otherwise `n` must be one of
    /// 12/15/18/21/24 and not below the original count. Every 3 extra words
    /// carry 4 random salt bytes. The output is `"<mnemonic>; <verify word>"`.
    ///
    /// # Errors
    /// Returns `InvalidWordCount` for an unsupported or too-small word count,
    /// and propagates parse and encryption errors.
    fn mnemonic_encrypt(&self, passphrase: &str, n: usize) -> Result<String, Bip38Err> {
        let original: Mnemonic = self.parse()?;
        let count = if n == 0 { original.count() } else { n };
        if !matches!(count, 12 | 15 | 18 | 21 | 24) || count < original.count() {
            return Err(Bip38Err::InvalidWordCount(count));
        }
        let mut salt = vec![0u8; (count - original.count()) / 3 * 4];
        // Skip touching the thread RNG entirely when no extension is requested.
        if !salt.is_empty() {
            rand::thread_rng().fill_bytes(&mut salt);
        }
        let (mnemonic, verify) = original.encrypt_extend(passphrase, &salt)?;
        Ok(format!("{mnemonic}; {verify}"))
    }

    /// Decrypts text produced by `mnemonic_encrypt`.
    ///
    /// Accepted forms:
    /// - a bare mnemonic of 12/15/18/21/24 words (no checksum check);
    /// - a mnemonic followed by a verify word or the original word count,
    ///   optionally separated by `;`.
    ///
    /// Surrounding whitespace and any Unicode whitespace separator are tolerated.
    ///
    /// # Errors
    /// Returns `InvalidKey` when no trailing token can be split off, and
    /// propagates parse and decryption errors.
    fn mnemonic_decrypt(&self, passphrase: &str) -> Result<String, Bip38Err> {
        // Trim first so trailing whitespace cannot produce an empty verify token.
        let text = self.trim();
        let word_count = text.split_whitespace().count();
        if matches!(word_count, 12 | 15 | 18 | 21 | 24) {
            let mnemonic: Mnemonic = text.parse()?;
            let original = mnemonic.decrypt_extend(passphrase, "")?;
            return Ok(original.to_string());
        }
        let Some((mnemonic_str, verify)) = text.rsplit_once(char::is_whitespace) else {
            return Err(Bip38Err::InvalidKey);
        };
        let mnemonic: Mnemonic = mnemonic_str
            .trim_end()
            .trim_end_matches(';')
            .trim_end()
            .parse()?;
        let original = mnemonic.decrypt_extend(passphrase, verify)?;
        Ok(original.to_string())
    }
}
/// Small byte-slice helpers used by the encryption code.
trait ByteOperation {
/// Applies SHA-256 `n` times (the first round over `self`) and returns the last digest.
fn sha256_n(&self, n: usize) -> [u8; 32];
/// XORs `other` into `self` in place.
fn xor(&mut self, other: &Self);
}
impl ByteOperation for [u8] {
    /// Hashes `self` with SHA-256, then re-hashes the 32-byte digest until `n`
    /// rounds have been applied in total (`n == 2` is double SHA-256).
    ///
    /// # Panics
    /// Panics if `n == 0`.
    fn sha256_n(&self, n: usize) -> [u8; 32] {
        use bitcoin::{hashes::Hash, hashes::sha256};
        assert!(n > 0, "Cannot hash zero times");
        let first = sha256::Hash::hash(self).to_byte_array();
        (1..n).fold(first, |digest, _| sha256::Hash::hash(&digest).to_byte_array())
    }

    /// XORs `other` into `self` byte by byte.
    ///
    /// # Panics
    /// Panics if the slices differ in length. Silently XOR-ing only a prefix
    /// would corrupt key material without any signal, and the previous
    /// `debug_assert!` vanished in release builds.
    fn xor(&mut self, other: &Self) {
        assert_eq!(self.len(), other.len(), "xor operands must have equal length");
        self.iter_mut().zip(other).for_each(|(a, b)| *a ^= b);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Passphrase shared by every round-trip test below.
    const PASSPHRASE: &str = "123456";

    #[test]
    fn test_mnemonic_encrypt() {
        // (plain mnemonic, expected "encrypted; verify-word" output) pairs.
        let cases: [(&str, &str); 1] = [(
            "派 贤 博 如 恐 臂 诺 职 畜 给 压 钱 牲 案 隔",
            "坏 火 发 恐 晒 为 陕 伪 镜 锻 略 越 力 秦 音; 胞",
        )];
        for (plain, encrypted) in cases {
            assert_eq!(plain.mnemonic_encrypt(PASSPHRASE, 0).unwrap(), encrypted);
            assert_eq!(encrypted.mnemonic_decrypt(PASSPHRASE).unwrap(), plain);

            // Dropping the verify word still decrypts (no checksum is checked).
            let (without_verify, _) = encrypted.rsplit_once(';').unwrap();
            assert_eq!(without_verify.mnemonic_decrypt(PASSPHRASE).unwrap(), plain);

            // Omitting the ';' separator is tolerated too.
            let without_separator = encrypted.replace(';', "");
            assert_eq!(without_separator.mnemonic_decrypt(PASSPHRASE).unwrap(), plain);
        }
    }

    #[test]
    fn test_mnemonic_extend() {
        let plain = "派 贤 博 如 恐 臂 诺 职 畜 给 压 钱 牲 案 隔";
        let encrypted = plain.mnemonic_encrypt(PASSPHRASE, 24).unwrap();
        assert_eq!(encrypted.mnemonic_decrypt(PASSPHRASE).unwrap(), plain);

        // Replace the verify word with the original word count instead.
        let (body, _) = encrypted.rsplit_once(';').unwrap();
        let with_count = format!("{body}; 15");
        assert_eq!(with_count.mnemonic_decrypt(PASSPHRASE).unwrap(), plain);
        println!("Encrypted: {encrypted}");
    }

    #[test]
    fn test_mnemonic_full() {
        let plain = "生 别 斑 票 纤 费 普 描 比 销 柯 委 敲 普 伍 慰 思 人 曲 燥 恢 校 由 因";
        let encrypted = plain.mnemonic_encrypt(PASSPHRASE, 0).unwrap();
        let decrypted = encrypted.mnemonic_decrypt(PASSPHRASE).unwrap();
        assert_eq!(decrypted, plain);
        println!("Encrypted: {encrypted}");
    }
}