use super::{Bip38Error, Result};
use crate::bip39::Mnemonic;
use aes::cipher::{BlockDecrypt, BlockEncrypt, KeyInit, generic_array::GenericArray};
use rand::RngCore;
use unicode_normalization::UnicodeNormalization;
const DEFAULT_SALT: &str = "Thanks Satoshi!";
const DERIVE_PATH: &str = "m/0'/0'";
#[derive(Debug)]
struct MnemonicEx {
pub mnemonic: Mnemonic,
pub verify: Verify,
}
#[derive(Debug)]
enum Verify {
Word(usize),
Count(u8),
}
impl MnemonicEx {
pub fn desired_size(&self) -> usize {
match self.verify {
Verify::Word(i) => (8 - (i >> 8)) * 3,
Verify::Count(n) => n as usize,
}
}
pub fn verify_sum(&self) -> Option<u8> {
match self.verify {
Verify::Word(i) => Some((i & 0xff) as u8),
Verify::Count(_) => None,
}
}
}
impl std::ops::Deref for MnemonicEx {
type Target = Mnemonic;
fn deref(&self) -> &Self::Target {
&self.mnemonic
}
}
impl From<Mnemonic> for MnemonicEx {
fn from(mnemonic: Mnemonic) -> Self {
let verify = Verify::Count(mnemonic.size() as u8);
Self { mnemonic, verify }
}
}
impl std::fmt::Display for MnemonicEx {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.mnemonic)?;
match self.verify {
Verify::Word(i) => {
if let Some(w) = self.language().word_at(i) {
write!(f, "; {w}")?;
}
}
Verify::Count(n) => {
if n as usize != self.size() {
write!(f, "; {n}")?;
}
}
}
Ok(())
}
}
impl std::str::FromStr for MnemonicEx {
type Err = Bip38Error;
fn from_str(s: &str) -> Result<Self> {
let count = s.split_whitespace().count();
if Mnemonic::VALID_SIZES.contains(&count) {
return Ok(s.parse::<Mnemonic>()?.into());
}
let Some((mnemonic_str, verify_str)) = s.rsplit_once(' ') else {
return Err(Bip38Error::InvalidKey);
};
let mnemonic: Mnemonic = mnemonic_str.trim_end_matches(';').parse()?;
let verify = if let Some(i) = mnemonic.language().index_of(verify_str)
&& i >> 8 < 5
{
Verify::Word(i)
} else if let Ok(n) = verify_str.parse::<u8>()
&& Mnemonic::VALID_SIZES.contains(&(n as usize))
{
Verify::Count(n)
} else {
return Err(Bip38Error::InvalidKey);
};
Ok(Self { mnemonic, verify })
}
}
trait Derivation {
fn derive_secret_key(passphrase: &str, salt: &[u8]) -> Result<[u8; 64]> {
let pass: String = passphrase.nfc().collect();
let argon_salt = {
let scrypt_salt = [DEFAULT_SALT.as_bytes(), salt].concat();
let params = scrypt::Params::new(20, 8, 8, 64)?;
let mut result = [0u8; 64];
scrypt::scrypt(pass.as_bytes(), &scrypt_salt, ¶ms, &mut result)?;
let (half1, half2) = result.split_at_mut(32);
half1[..32].xor(&half2[..32]);
half1[..32].to_vec()
};
let argon = argon2::Argon2::default();
let mut secret_key = [0u8; 64];
argon.hash_password_into(pass.as_bytes(), &argon_salt, &mut secret_key)?;
Ok(secret_key)
}
fn derive_path_address(mnemonic: &Mnemonic, path: &str) -> Result<String> {
use bitcoin::bip32::{DerivationPath, Xpriv};
use bitcoin::{Address, Network, secp256k1::Secp256k1};
use pbkdf2::pbkdf2_hmac;
let seed = {
let mnemonic = mnemonic.to_string();
let salt = format!("mnemonic{DEFAULT_SALT}").into_bytes();
let mut seed = [0u8; 64];
pbkdf2_hmac::<sha2::Sha512>(mnemonic.as_bytes(), &salt, u32::pow(2, 11), &mut seed);
seed
};
let root = Xpriv::new_master(Network::Bitcoin, &seed)?;
let address = {
let derive_path: DerivationPath = path.parse()?;
let xpriv = root.derive_priv(&Secp256k1::default(), &derive_path)?;
let pub_key = xpriv.to_priv().public_key(&Secp256k1::default());
Address::p2pkh(pub_key, Network::Bitcoin).to_string()
};
Ok(address)
}
}
trait Encryption: Derivation + Sized {
fn encrypt_extend(&self, passphrase: &str, salt: &[u8]) -> Result<Self>;
fn decrypt_extend(&self, passphrase: &str) -> Result<Self>;
}
impl Derivation for MnemonicEx {}
impl Encryption for MnemonicEx {
fn encrypt_extend(&self, passphrase: &str, salt: &[u8]) -> Result<Self> {
let result_bytes = self.size() / 3 * 4 + salt.len();
assert!(matches!(result_bytes, 16 | 20 | 24 | 28 | 32));
let secret_key = Self::derive_secret_key(passphrase, salt)?;
let (mask, aes_key) = secret_key.split_at(32);
let entropy = &mut self.entropy();
{
entropy.resize(32, 0);
entropy[..32].xor(&mask[..32]);
let (part1, part2) = entropy.split_at_mut(16);
let cipher = aes::Aes256::new(GenericArray::from_slice(aes_key));
cipher.encrypt_block(GenericArray::from_mut_slice(part1));
if self.size() == 24 {
cipher.encrypt_block(GenericArray::from_mut_slice(part2));
}
entropy.resize(self.size() / 3 * 4, 0);
entropy.extend_from_slice(salt);
}
let mnemonic = Mnemonic::new(entropy, self.language())?;
let verify = {
let address = Self::derive_path_address(self, DERIVE_PATH)?;
let checksum: u16 = address.as_bytes().sha256_n(2)[0] as u16;
let size_flag: u16 = 8 - (self.size() as u16 / 3); assert!(size_flag < 5);
let index = (size_flag << 8 | checksum) as usize;
Verify::Word(index)
};
Ok(MnemonicEx { mnemonic, verify })
}
fn decrypt_extend(&self, passphrase: &str) -> Result<Self> {
let result_bytes = self.desired_size() / 3 * 4;
assert!(matches!(result_bytes, 16 | 20 | 24 | 28 | 32));
let entropy = &mut self.entropy();
{
let salt: Vec<_> = entropy.drain(result_bytes..).collect();
let secret_key = Self::derive_secret_key(passphrase, &salt)?;
let (mask, aes_key) = secret_key.split_at(32);
entropy.resize(32, 0);
let (part1, part2) = entropy.split_at_mut(16);
let cipher = aes::Aes256::new(GenericArray::from_slice(aes_key));
cipher.decrypt_block(GenericArray::from_mut_slice(part1));
if result_bytes == 32 {
cipher.decrypt_block(GenericArray::from_mut_slice(part2));
}
entropy[..32].xor(&mask[..32]);
entropy.resize(result_bytes, 0);
}
let original = Mnemonic::new(entropy, self.language())?;
if let Some(checksum) = self.verify_sum() {
let address = Self::derive_path_address(&original, DERIVE_PATH)?;
if checksum != address.as_bytes().sha256_n(2)[0] {
return Err(Bip38Error::InvalidPass);
}
}
Ok(original.into())
}
}
pub trait MnemonicEncryption {
fn mnemonic_encrypt(&self, passphrase: &str) -> Result<String>;
fn mnemonic_decrypt(&self, passphrase: &str) -> Result<String>;
}
impl MnemonicEncryption for str {
fn mnemonic_encrypt(&self, passphrase: &str) -> Result<String> {
let original: MnemonicEx = self.parse()?;
if original.desired_size() < original.size() {
return Err(Bip38Error::InvalidSize);
}
let salt = &mut vec![0u8; (original.desired_size() - original.size()) / 3 * 4];
if !salt.is_empty() {
rand::thread_rng().fill_bytes(salt);
}
let mnemonic = original.encrypt_extend(passphrase, salt)?;
Ok(mnemonic.to_string())
}
fn mnemonic_decrypt(&self, passphrase: &str) -> Result<String> {
let mnemonic: MnemonicEx = self.parse()?;
if mnemonic.desired_size() > mnemonic.size() {
return Err(Bip38Error::InvalidSize);
}
let original = mnemonic.decrypt_extend(passphrase)?;
Ok(original.to_string())
}
}
trait ByteOperation {
fn sha256_n(&self, n: usize) -> [u8; 32];
fn xor(&mut self, other: &Self);
}
impl ByteOperation for [u8] {
#[inline(always)]
fn sha256_n(&self, n: usize) -> [u8; 32] {
use bitcoin::{hashes::Hash, hashes::sha256};
assert!(n > 0, "Cannot hash zero times");
let mut hash = sha256::Hash::hash(self).to_byte_array();
for _ in 1..n {
hash = sha256::Hash::hash(&hash).to_byte_array();
}
hash
}
#[inline(always)]
fn xor(&mut self, other: &Self) {
assert!(self.len() == other.len());
(0..self.len()).for_each(|i| self[i] ^= other[i]);
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_mnemonic_encrypt() -> Result {
const TEST_DATA: &[&str] = &[
"派 贤 博 如 恐 臂 诺 职 畜 给 压 钱 牲 案 隔",
"坏 火 发 恐 晒 为 陕 伪 镜 锻 略 越 力 秦 音; 胞",
];
for data in TEST_DATA.chunks(2) {
assert_eq!(data[0].mnemonic_encrypt("123456")?, data[1]);
assert_eq!(data[1].mnemonic_decrypt("123456")?, data[0]);
let mnemonic = data[1].rsplit_once(';').unwrap().0;
assert_eq!(mnemonic.mnemonic_decrypt("123456")?, data[0]);
let mnemonic = data[1].replace(';', "");
assert_eq!(mnemonic.mnemonic_decrypt("123456")?, data[0]);
}
Ok(())
}
#[test]
fn test_mnemonic_extend() -> Result {
let data = "派 贤 博 如 恐 臂 诺 职 畜 给 压 钱 牲 案 隔";
let encrypted = format!("{data}; 24").mnemonic_encrypt("123456")?;
assert_eq!(encrypted.mnemonic_decrypt("123456")?, data);
let mnemonic = format!("{}; 15", encrypted.rsplit_once(';').unwrap().0);
assert_eq!(mnemonic.mnemonic_decrypt("123456")?, data);
println!("Encrypted: {encrypted}");
Ok(())
}
#[test]
fn test_mnemonic_full() -> Result {
let original = "生 别 斑 票 纤 费 普 描 比 销 柯 委 敲 普 伍 慰 思 人 曲 燥 恢 校 由 因";
let encrypted = original.mnemonic_encrypt("123456")?;
assert_eq!(encrypted.mnemonic_decrypt("123456")?, original);
println!("Encrypted: {encrypted}");
Ok(())
}
}