artimonist 1.6.0

A tool for generating mnemonics and wallets.
Documentation
use crate::bip39::Mnemonic;
use aes::cipher::{BlockDecrypt, BlockEncrypt, KeyInit, generic_array::GenericArray};
use rand::RngCore;
use unicode_normalization::UnicodeNormalization;

/// Fixed salt component: prepended to the caller's salt for scrypt, and appended to
/// `"mnemonic"` to form the PBKDF2 salt used when deriving the seed.
const DEFAULT_SALT: &str = "Thanks Satoshi!";
/// Derivation path of the address whose hash supplies the verify-word checksum.
const DERIVE_PATH: &str = "m/0'/0'";

/// Local shorthand for the parent module's error type.
type Bip38Err = super::Bip38Error;

trait Derivation {
    /// Derive a secret key from the passphrase and salt.
    fn derive_secret_key(passphrase: &str, salt: &[u8]) -> Result<[u8; 64], Bip38Err> {
        let pass: String = passphrase.nfc().collect();
        let argon_salt = {
            let scrypt_salt = [DEFAULT_SALT.as_bytes(), salt].concat();
            let params = scrypt::Params::new(20, 8, 8, 64)?;
            let mut result = [0u8; 64];
            scrypt::scrypt(pass.as_bytes(), &scrypt_salt, &params, &mut result)?;

            let (half1, half2) = result.split_at_mut(32);
            half1[..32].xor(&half2[..32]);
            half1[..32].to_vec()
        };
        let argon = argon2::Argon2::default();
        let mut secret_key = [0u8; 64];
        argon.hash_password_into(pass.as_bytes(), &argon_salt, &mut secret_key)?;
        Ok(secret_key)
    }

    /// Derive a Bitcoin address from the mnemonic and derivation path.
    fn derive_path_address(mnemonic: &Mnemonic, path: &str) -> Result<String, Bip38Err> {
        use bitcoin::bip32::{DerivationPath, Xpriv};
        use bitcoin::{Address, Network, secp256k1::Secp256k1};
        use pbkdf2::pbkdf2_hmac;

        let seed = {
            let mnemonic = mnemonic.to_string();
            let salt = format!("mnemonic{DEFAULT_SALT}").into_bytes();
            let mut seed = [0u8; 64];
            pbkdf2_hmac::<sha2::Sha512>(mnemonic.as_bytes(), &salt, u32::pow(2, 11), &mut seed);
            seed
        };
        let root = Xpriv::new_master(Network::Bitcoin, &seed)?;

        let address = {
            let derive_path: DerivationPath = path.parse()?;
            let xpriv = root.derive_priv(&Secp256k1::default(), &derive_path)?;
            let pub_key = xpriv.to_priv().public_key(&Secp256k1::default());
            Address::p2pkh(&pub_key, Network::Bitcoin).to_string()
        };
        Ok(address)
    }
}

/// Passphrase-based encryption of a mnemonic, with optional lengthening by salt bytes.
trait Encryption: Derivation + Sized {
    /// Encrypts the mnemonic with `passphrase` and `salt`,
    ///   returning the new mnemonic and a verify word.
    /// The salt bytes are appended to the encrypted entropy, so the new mnemonic
    ///   is longer than the original whenever `salt` is non-empty.
    /// The verify word records the original length and a checksum byte,
    ///   and is used to check the result of a later decryption.
    fn encrypt_extend(&self, passphrase: &str, salt: &[u8]) -> Result<(Self, String), Bip38Err>;

    /// Decrypts the mnemonic with `passphrase` and `verify`, returning the original mnemonic.
    /// If `verify` is empty, the mnemonic is treated as carrying no salt
    ///   and the checksum is not checked.
    /// Otherwise `verify` is either a word from the mnemonic's language (original length
    ///   plus checksum) or an original word count written as
    ///   "12", "15", "18", "21", or "24" (no checksum).
    fn decrypt_extend(&self, passphrase: &str, verify: &str) -> Result<Self, Bip38Err>;
}

// `Mnemonic` uses the default derivation methods of `Derivation` unchanged.
impl Derivation for Mnemonic {}
impl Encryption for Mnemonic {
    /// Encrypts this mnemonic's entropy under `passphrase` and appends `salt` to it.
    ///
    /// The entropy is zero-padded to 32 bytes, XOR-ed with the first half of the
    /// derived secret key and AES-256 encrypted with the second half: the first
    /// 16-byte block always, the second block only for 24-word mnemonics. It is then
    /// cut back to its original length and `salt` is appended.
    ///
    /// Returns the encrypted mnemonic and a verify word. The verify word's index in
    /// the word list packs `8 - count / 3` above bit 8 and, in the low 8 bits, the
    /// first byte of the double SHA-256 of the original mnemonic's address at
    /// `DERIVE_PATH`.
    ///
    /// # Errors
    /// Propagates failures from key derivation, mnemonic construction and address
    /// derivation.
    fn encrypt_extend(&self, passphrase: &str, salt: &[u8]) -> Result<(Self, String), Bip38Err> {
        let result_bytes = self.count() / 3 * 4 + salt.len();
        debug_assert!(matches!(result_bytes, 16 | 20 | 24 | 28 | 32));

        let secret_key = Self::derive_secret_key(passphrase, salt)?;
        let (mask, aes_key) = secret_key.split_at(32);

        let entropy = &mut self.entropy();
        {
            // Work on a full 32-byte buffer so both AES blocks exist regardless of length.
            entropy.resize(32, 0);
            entropy[..32].xor(&mask[..32]);
            let (part1, part2) = entropy.split_at_mut(16);

            let cipher = aes::Aes256::new(GenericArray::from_slice(aes_key));
            cipher.encrypt_block(GenericArray::from_mut_slice(part1));
            // Only a 24-word mnemonic fills the second block completely; for shorter
            // ones the bytes past the first block stay XOR-masked only.
            if self.count() == 24 {
                cipher.encrypt_block(GenericArray::from_mut_slice(part2));
            }

            // Drop the padding again, then lengthen the entropy with the salt.
            entropy.resize(self.count() / 3 * 4, 0);
            entropy.extend_from_slice(salt);
        }

        let new_mnemonic = Mnemonic::new(entropy, self.language())?;
        let verify_word = {
            let address = Self::derive_path_address(self, DERIVE_PATH)?;
            let checksum: u16 = address.as_bytes().sha256_n(2)[0] as u16;
            let count_flag: u16 = 8 - self.count() as u16 / 3; // 4 | 3 | 2 | 1 | 0
            let verify_idx = (count_flag << 8 | checksum) as usize;
            debug_assert!(verify_idx < 2048);
            self.language()
                .word_at(verify_idx)
                .expect("verify index is below the word list size")
                .to_string()
        };
        Ok((new_mnemonic, verify_word))
    }

    /// Reverses [`Encryption::encrypt_extend`] and returns the original mnemonic.
    ///
    /// `verify` selects how many leading entropy bytes belong to the original:
    /// * empty — all of them; no checksum is checked;
    /// * a word of this mnemonic's language whose index has a count flag below 5 —
    ///   the flagged length, checked against the checksum byte in the low 8 bits;
    /// * `"12"`, `"15"`, `"18"`, `"21"` or `"24"`, not larger than the current word
    ///   count — that length; no checksum is checked.
    ///
    /// The remaining entropy bytes are the salt used for key derivation.
    ///
    /// # Errors
    /// * `Bip38Err::InvalidKey` if `verify` matches none of the forms above, or if it
    ///   claims more entropy than this mnemonic holds.
    /// * `Bip38Err::InvalidPassphrase` if the checksum does not match.
    /// * Propagated key-derivation, mnemonic-construction and address-derivation errors.
    fn decrypt_extend(&self, passphrase: &str, verify: &str) -> Result<Self, Bip38Err> {
        let (result_bytes, checksum) = if verify.is_empty() {
            (self.count() / 3 * 4, None)
        } else if let Some(i) = self.language().index_of(verify)
            && i >> 8 < 5
        {
            ((8 - (i >> 8)) * 4, Some((i & 0xff) as u8))
        } else if let Ok(n) = u16::from_str_radix(verify, 10)
            && matches!(n, 12 | 15 | 18 | 21 | 24)
            && (n as usize) <= self.count()
        {
            (n as usize / 3 * 4, None)
        } else {
            return Err(Bip38Err::InvalidKey);
        };
        debug_assert!(matches!(result_bytes, 16 | 20 | 24 | 28 | 32));

        let entropy = &mut self.entropy();
        // A verify word may flag an original longer than this mnemonic; `drain` below
        // would panic on such a range, so reject it as an invalid key instead.
        if result_bytes > entropy.len() {
            return Err(Bip38Err::InvalidKey);
        }
        {
            // Everything past the original length is the salt added during encryption.
            let salt: Vec<_> = entropy.drain(result_bytes..).collect();
            let secret_key = Self::derive_secret_key(passphrase, &salt)?;
            let (mask, aes_key) = secret_key.split_at(32);

            entropy.resize(32, 0);
            let (part1, part2) = entropy.split_at_mut(16);

            let cipher = aes::Aes256::new(GenericArray::from_slice(aes_key));
            cipher.decrypt_block(GenericArray::from_mut_slice(part1));
            // Mirrors encryption: the second block was encrypted only for 32-byte entropy.
            if result_bytes == 32 {
                cipher.decrypt_block(GenericArray::from_mut_slice(part2));
            }
            entropy[..32].xor(&mask[..32]);
            entropy.resize(result_bytes, 0);
        }

        let original = Mnemonic::new(entropy, self.language())?;
        if let Some(expected) = checksum {
            let address = Self::derive_path_address(&original, DERIVE_PATH)?;
            if expected != address.as_bytes().sha256_n(2)[0] {
                return Err(Bip38Err::InvalidPassphrase);
            }
        }
        Ok(original)
    }
}

/// Mnemonic encryption and decryption with a passphrase.
/// # Reference:
///   <https://github.com/artimonist/disguise/blob/main/docs/mnemonic_encrypt.mmd>
pub trait MnemonicEncryption {
    /// Encrypt the mnemonic with a passphrase and desired word count `n`.
    /// `n` must be one of 12, 15, 18, 21, or 24 and not smaller than the mnemonic's
    /// own word count; `n == 0` keeps the original word count.
    /// When `n` is larger than the original count, random salt bytes are added
    /// so that the encrypted mnemonic has `n` words.
    /// Returns one string of the form `"<encrypted mnemonic>; <verify word>"`.
    fn mnemonic_encrypt(&self, passphrase: &str, n: usize) -> Result<String, Bip38Err>;

    /// Decrypt the mnemonic with a passphrase, returning the original mnemonic as a string.
    /// The input is either a bare encrypted mnemonic of 12, 15, 18, 21, or 24 words,
    /// or an encrypted mnemonic followed by a verify word or the original word count.
    /// If a verify word is given, it is used to verify the decryption.
    fn mnemonic_decrypt(&self, passphrase: &str) -> Result<String, Bip38Err>;
}

impl MnemonicEncryption for str {
    /// Encrypts this mnemonic string with `passphrase`, producing `n` words.
    ///
    /// `n == 0` keeps the original word count. Returns
    /// `"<encrypted mnemonic>; <verify word>"`.
    ///
    /// # Errors
    /// * `Bip38Err::InvalidWordCount` if the resulting count is not one of
    ///   12, 15, 18, 21, 24 or is smaller than the original word count.
    /// * Propagated parse and encryption errors.
    fn mnemonic_encrypt(&self, passphrase: &str, n: usize) -> Result<String, Bip38Err> {
        let original: Mnemonic = self.parse()?;

        // Validate the desired word count.
        let count = if n == 0 { original.count() } else { n };
        if !matches!(count, 12 | 15 | 18 | 21 | 24) || count < original.count() {
            return Err(Bip38Err::InvalidWordCount(count));
        }

        // Every 3 extra words carry 4 extra entropy bytes; those bytes are a random
        // salt. No salt is needed when the word count stays the same.
        let salt = &mut vec![0u8; (count - original.count()) / 3 * 4];
        if !salt.is_empty() {
            rand::thread_rng().fill_bytes(salt);
        }

        let (mnemonic, verify) = original.encrypt_extend(passphrase, salt)?;
        Ok(format!("{mnemonic}; {verify}"))
    }

    /// Decrypts this encrypted mnemonic string with `passphrase`.
    ///
    /// A string of exactly 12, 15, 18, 21, or 24 words is decrypted without a verify
    /// token. Otherwise the last whitespace-separated token is taken as the verify
    /// word (or original word count), and an optional `;` after the mnemonic is ignored.
    ///
    /// # Errors
    /// * `Bip38Err::InvalidKey` if there is no token to split off.
    /// * Propagated parse and decryption errors.
    fn mnemonic_decrypt(&self, passphrase: &str) -> Result<String, Bip38Err> {
        let word_count = self.split_whitespace().count();
        if matches!(word_count, 12 | 15 | 18 | 21 | 24) {
            // none verify word
            let mnemonic: Mnemonic = self.parse()?;
            let original = mnemonic.decrypt_extend(passphrase, "")?;
            return Ok(original.to_string());
        }

        // has verify word or desired count
        // Split on any whitespace, consistent with the token count above, so that
        // trailing whitespace, newlines or doubled spaces do not hide the verify token.
        let Some((mnemonic_str, verify)) = self.trim().rsplit_once(char::is_whitespace) else {
            return Err(Bip38Err::InvalidKey);
        };
        let mnemonic: Mnemonic = mnemonic_str.trim_end().trim_end_matches(';').parse()?;
        let original = mnemonic.decrypt_extend(passphrase, verify)?;
        Ok(original.to_string())
    }
}

/// Small byte-buffer helpers used by the encryption code.
trait ByteOperation {
    /// Returns the SHA-256 digest of `self` hashed `n` times in a row
    /// (each round hashes the previous digest).
    fn sha256_n(&self, n: usize) -> [u8; 32];
    /// XORs `other` into `self` in place, byte by byte.
    fn xor(&mut self, other: &Self);
}

impl ByteOperation for [u8] {
    /// Applies SHA-256 `n` times: once to `self`, then `n - 1` times to the digest.
    ///
    /// # Panics
    /// Panics if `n` is zero.
    #[inline(always)]
    fn sha256_n(&self, n: usize) -> [u8; 32] {
        use bitcoin::{hashes::Hash, hashes::sha256};
        assert!(n > 0, "Cannot hash zero times");

        let first_round = sha256::Hash::hash(self).to_byte_array();
        (1..n).fold(first_round, |digest, _| {
            sha256::Hash::hash(&digest).to_byte_array()
        })
    }

    /// XORs each byte of `other` into the byte of `self` at the same index.
    ///
    /// Both slices are expected to have the same length (checked in debug builds).
    #[inline(always)]
    fn xor(&mut self, other: &Self) {
        debug_assert!(self.len() == other.len());
        for (idx, byte) in self.iter_mut().enumerate() {
            *byte ^= other[idx];
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Passphrase shared by every test case.
    const PASSPHRASE: &str = "123456";

    /// Known-answer round trip at unchanged word count, with and without the verify word.
    #[test]
    fn test_mnemonic_encrypt() {
        // (original mnemonic, expected "<encrypted mnemonic>; <verify word>")
        const VECTORS: &[(&str, &str)] = &[(
            "派 贤 博 如 恐 臂 诺 职 畜 给 压 钱 牲 案 隔",
            "坏 火 发 恐 晒 为 陕 伪 镜 锻 略 越 力 秦 音; 胞",
        )];
        for &(plain, cipher) in VECTORS {
            assert_eq!(plain.mnemonic_encrypt(PASSPHRASE, 0).unwrap(), cipher);
            assert_eq!(cipher.mnemonic_decrypt(PASSPHRASE).unwrap(), plain);

            // Encrypted mnemonic alone, verify word dropped.
            let (bare, _) = cipher.rsplit_once(';').unwrap();
            assert_eq!(bare.mnemonic_decrypt(PASSPHRASE).unwrap(), plain);

            // Verify word kept, separating semicolon removed.
            let unpunctuated = cipher.replace(';', "");
            assert_eq!(unpunctuated.mnemonic_decrypt(PASSPHRASE).unwrap(), plain);
        }
    }

    /// Round trip when a 15-word mnemonic is extended to 24 words.
    #[test]
    fn test_mnemonic_extend() {
        let plain = "派 贤 博 如 恐 臂 诺 职 畜 给 压 钱 牲 案 隔";
        let encrypted = plain.mnemonic_encrypt(PASSPHRASE, 24).unwrap();
        assert_eq!(encrypted.mnemonic_decrypt(PASSPHRASE).unwrap(), plain);

        // The original word count may stand in for the verify word.
        let (cipher_words, _) = encrypted.rsplit_once(';').unwrap();
        let with_count = format!("{cipher_words}; 15");
        assert_eq!(with_count.mnemonic_decrypt(PASSPHRASE).unwrap(), plain);
        println!("Encrypted: {encrypted}");
    }

    /// Round trip of a full 24-word mnemonic.
    #[test]
    fn test_mnemonic_full() {
        let plain = "生 别 斑 票 纤 费 普 描 比 销 柯 委 敲 普 伍 慰 思 人 曲 燥 恢 校 由 因";
        let encrypted = plain.mnemonic_encrypt(PASSPHRASE, 0).unwrap();
        assert_eq!(encrypted.mnemonic_decrypt(PASSPHRASE).unwrap(), plain);
        println!("Encrypted: {encrypted}");
    }
}