use super::bip39_wordlists;
use super::error::CompatError;
use crate::primitives::hash::{pbkdf2_hmac_sha512, sha256};
use crate::primitives::random::random_bytes;
/// Languages for which a 2048-word mnemonic wordlist is available.
///
/// The variant selects which wordlist is used to encode and decode a
/// mnemonic. `Japanese` is additionally special-cased when a phrase is
/// rendered: its words are joined with the ideographic space (U+3000)
/// instead of an ASCII space.
///
/// The enum is a plain field-less value type, so it derives `Eq` and `Hash`
/// in addition to `PartialEq`, which lets it be used as a map/set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Language {
    /// English wordlist.
    English,
    /// Japanese wordlist (phrases are joined with U+3000).
    Japanese,
    /// Korean wordlist.
    Korean,
    /// Spanish wordlist.
    Spanish,
    /// French wordlist.
    French,
    /// Italian wordlist.
    Italian,
    /// Czech wordlist.
    Czech,
    /// Simplified Chinese wordlist.
    ChineseSimplified,
    /// Traditional Chinese wordlist.
    ChineseTraditional,
}
/// Returns the static 2048-entry wordlist that belongs to `lang`.
///
/// The match is deliberately exhaustive (no `_` arm) so that adding a new
/// `Language` variant fails to compile until a wordlist is wired up here.
fn get_wordlist(lang: Language) -> &'static [&'static str; 2048] {
    match lang {
        // Chinese lists.
        Language::ChineseSimplified => bip39_wordlists::chinese_simplified::CHINESE_SIMPLIFIED,
        Language::ChineseTraditional => bip39_wordlists::chinese_traditional::CHINESE_TRADITIONAL,
        // Other Asian lists.
        Language::Japanese => bip39_wordlists::japanese::JAPANESE,
        Language::Korean => bip39_wordlists::korean::KOREAN,
        // European lists.
        Language::Czech => bip39_wordlists::czech::CZECH,
        Language::English => bip39_wordlists::english::ENGLISH,
        Language::French => bip39_wordlists::french::FRENCH,
        Language::Italian => bip39_wordlists::italian::ITALIAN,
        Language::Spanish => bip39_wordlists::spanish::SPANISH,
    }
}
/// A mnemonic sentence together with the entropy it encodes.
///
/// Instances are only created through the constructors on `Mnemonic`
/// (`from_entropy`, `from_random`, `from_string`), which keep `words` and
/// `entropy` consistent with each other for the chosen `language`.
///
/// NOTE(review): the derived `Debug` prints the words and the raw entropy —
/// confirm that this secret material is never logged.
#[derive(Debug, Clone)]
pub struct Mnemonic {
// The words of the sentence, in order, taken from the wordlist of `language`.
words: Vec<String>,
// The raw entropy bytes that `words` encode (without the checksum bits).
entropy: Vec<u8>,
// Wordlist language; also decides the separator used when rendering the phrase.
language: Language,
}
impl Mnemonic {
    /// Number of bits encoded by one mnemonic word (the wordlists hold 2^11 entries).
    const BITS_PER_WORD: usize = 11;

    /// Separator used when rendering a phrase for display in `language`.
    ///
    /// Japanese phrases use the ideographic space (U+3000); every other
    /// language uses a single ASCII space.
    fn separator(language: Language) -> &'static str {
        if language == Language::Japanese {
            "\u{3000}"
        } else {
            " "
        }
    }

    /// Computes the wordlist indices that encode `entropy`.
    ///
    /// The bit stream is the entropy followed by the first `ENT / 32` bits
    /// of `sha256(entropy)`, read most-significant-bit first and cut into
    /// 11-bit groups. Callers must pass entropy whose length has already
    /// been validated (or that comes from an existing `Mnemonic`).
    fn word_indices(entropy: &[u8]) -> Vec<usize> {
        let ent_bits = entropy.len() * 8;
        let checksum_bits = ent_bits / 32;
        let checksum = sha256(entropy);
        let word_count = (ent_bits + checksum_bits) / Self::BITS_PER_WORD;

        // Bit `pos` of the concatenated entropy || checksum stream (MSB first).
        let bit_at = |pos: usize| -> usize {
            let byte = if pos < ent_bits {
                entropy[pos / 8] >> (7 - (pos % 8))
            } else {
                let cs_pos = pos - ent_bits;
                checksum[cs_pos / 8] >> (7 - (cs_pos % 8))
            };
            usize::from(byte & 1)
        };

        (0..word_count)
            .map(|word| {
                (0..Self::BITS_PER_WORD).fold(0usize, |index, bit| {
                    (index << 1) | bit_at(word * Self::BITS_PER_WORD + bit)
                })
            })
            .collect()
    }

    /// Builds a mnemonic from raw entropy.
    ///
    /// # Errors
    ///
    /// Returns `CompatError::InvalidEntropy` unless `entropy` is 128, 160,
    /// 192, 224 or 256 bits long.
    pub fn from_entropy(entropy: &[u8], language: Language) -> Result<Self, CompatError> {
        let ent_bits = entropy.len() * 8;
        if !(128..=256).contains(&ent_bits) || !ent_bits.is_multiple_of(32) {
            return Err(CompatError::InvalidEntropy(format!(
                "entropy must be 128-256 bits in 32-bit increments, got {} bits",
                ent_bits
            )));
        }

        let wordlist = get_wordlist(language);
        let words = Self::word_indices(entropy)
            .into_iter()
            .map(|index| wordlist[index].to_string())
            .collect();

        Ok(Mnemonic {
            words,
            entropy: entropy.to_vec(),
            language,
        })
    }

    /// Generates a mnemonic from `bits` bits of fresh randomness obtained
    /// from `random_bytes`.
    ///
    /// # Errors
    ///
    /// Returns `CompatError::InvalidEntropy` unless `bits` is 128, 160, 192,
    /// 224 or 256.
    pub fn from_random(bits: usize, language: Language) -> Result<Self, CompatError> {
        if !(128..=256).contains(&bits) || !bits.is_multiple_of(32) {
            return Err(CompatError::InvalidEntropy(format!(
                "bits must be 128-256 in 32-bit increments, got {}",
                bits
            )));
        }
        let entropy = random_bytes(bits / 8);
        Self::from_entropy(&entropy, language)
    }

    /// Parses and validates a mnemonic phrase.
    ///
    /// Words may be separated by any run of Unicode whitespace, so both the
    /// ASCII space and the ideographic space (U+3000) are accepted for every
    /// language, and leading/trailing whitespace (for example a trailing
    /// newline) is ignored.
    ///
    /// # Errors
    ///
    /// Returns `CompatError::InvalidMnemonic` when the word count is not one
    /// of 12, 15, 18, 21 or 24, when a word is not in the wordlist of
    /// `language`, or when the checksum bits do not match the entropy.
    pub fn from_string(mnemonic: &str, language: Language) -> Result<Self, CompatError> {
        let word_strs: Vec<&str> = mnemonic.split_whitespace().collect();
        let word_count = word_strs.len();
        if !(12..=24).contains(&word_count) || !word_count.is_multiple_of(3) {
            return Err(CompatError::InvalidMnemonic(format!(
                "invalid word count: {} (must be 12, 15, 18, 21, or 24)",
                word_count
            )));
        }

        let wordlist = get_wordlist(language);
        let mut indices: Vec<usize> = Vec::with_capacity(word_count);
        for word in &word_strs {
            match wordlist.iter().position(|w| w == word) {
                Some(idx) => indices.push(idx),
                None => {
                    return Err(CompatError::InvalidMnemonic(format!(
                        "word not in wordlist: {}",
                        word
                    )));
                }
            }
        }

        // Every 33 bits of the sentence carry 32 bits of entropy and 1 checksum bit.
        let total_bits = word_count * Self::BITS_PER_WORD;
        let ent_bits = (total_bits * 32) / 33;

        // Unpack the leading `ent_bits` of the 11-bit groups (MSB first) into bytes.
        let mut entropy = vec![0u8; ent_bits / 8];
        let bit_stream = indices.iter().flat_map(|&idx| {
            (0..Self::BITS_PER_WORD)
                .rev()
                // Masked to a single bit, so the narrowing cast is lossless.
                .map(move |shift| ((idx >> shift) & 1) as u8)
        });
        for (pos, bit) in bit_stream.take(ent_bits).enumerate() {
            entropy[pos / 8] |= bit << (7 - (pos % 8));
        }

        // Re-encoding the recovered entropy reproduces the entropy bits by
        // construction, so any difference can only be in the checksum bits.
        if Self::word_indices(&entropy) != indices {
            return Err(CompatError::InvalidMnemonic(
                "checksum mismatch".to_string(),
            ));
        }

        Ok(Mnemonic {
            words: word_strs.iter().map(|s| s.to_string()).collect(),
            entropy,
            language,
        })
    }

    /// Returns `true` when the stored words are exactly the encoding of the
    /// stored entropy (including checksum) in the stored language.
    pub fn check(&self) -> bool {
        let wordlist = get_wordlist(self.language);
        let expected = Self::word_indices(&self.entropy);
        expected.len() == self.words.len()
            && expected
                .iter()
                .zip(&self.words)
                .all(|(&index, word)| word.as_str() == wordlist[index])
    }

    /// Derives the 64-byte seed with PBKDF2-HMAC-SHA512 (2048 iterations),
    /// using the mnemonic sentence as password and `"mnemonic" + passphrase`
    /// as salt.
    ///
    /// The sentence is always joined with an ASCII space here, including for
    /// Japanese: BIP-39 hashes the NFKD form of the sentence, and NFKD maps
    /// the ideographic space U+3000 (used for display) to U+0020.
    ///
    /// NOTE(review): no Unicode normalization is applied to the words or to
    /// `passphrase`; this assumes the wordlists are stored in NFKD form and
    /// that callers pass an NFKD-normalized passphrase — confirm.
    pub fn to_seed(&self, passphrase: &str) -> Vec<u8> {
        let sentence = self.words.join(" ");
        let salt = format!("mnemonic{}", passphrase);
        pbkdf2_hmac_sha512(sentence.as_bytes(), salt.as_bytes(), 2048, 64)
    }

    /// Renders the words as a single phrase for display, using the
    /// language-specific separator (U+3000 for Japanese, ASCII space
    /// otherwise).
    pub fn to_phrase(&self) -> String {
        self.words.join(Self::separator(self.language))
    }

    /// The individual words of the mnemonic, in order.
    pub fn words(&self) -> &[String] {
        &self.words
    }

    /// The raw entropy bytes encoded by the mnemonic.
    pub fn entropy(&self) -> &[u8] {
        &self.entropy
    }
}
/// Displays a mnemonic as the phrase produced by `Mnemonic::to_phrase`.
impl std::fmt::Display for Mnemonic {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let phrase = self.to_phrase();
        f.write_str(&phrase)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Twelve-word English phrase that the tests expect to parse successfully.
    const ABANDON_ABOUT: &str = "abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about";

    /// Twelve-word English phrase that the tests expect to be rejected.
    const ABANDON_ONLY: &str = "abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon";

    /// Decodes a hex string into bytes, two characters per byte.
    /// Panics on malformed input, which is acceptable for test fixtures.
    fn hex_to_bytes(hex: &str) -> Vec<u8> {
        let mut decoded = Vec::with_capacity(hex.len() / 2);
        let mut pos = 0;
        while pos < hex.len() {
            let pair = &hex[pos..pos + 2];
            decoded.push(u8::from_str_radix(pair, 16).unwrap());
            pos += 2;
        }
        decoded
    }

    /// Encodes bytes as a lowercase hex string.
    fn bytes_to_hex(bytes: &[u8]) -> String {
        use std::fmt::Write;

        let mut encoded = String::with_capacity(bytes.len() * 2);
        for byte in bytes {
            write!(encoded, "{:02x}", byte).expect("writing to a String cannot fail");
        }
        encoded
    }

    /// One entry of the JSON test-vector file.
    #[derive(serde::Deserialize)]
    struct TestVector {
        entropy: String,
        mnemonic: String,
        passphrase: String,
        seed: String,
    }

    /// Top-level shape of the JSON test-vector file.
    #[derive(serde::Deserialize)]
    struct TestVectors {
        vectors: Vec<TestVector>,
    }

    /// Parses the test vectors that are embedded at compile time.
    fn load_vectors() -> TestVectors {
        let raw = include_str!("../../test-vectors/bip39_vectors.json");
        serde_json::from_str(raw).expect("failed to parse BIP39 test vectors")
    }

    #[test]
    fn test_from_entropy_128bit() {
        let all = load_vectors();
        let vector = &all.vectors[0];
        let entropy = hex_to_bytes(&vector.entropy);

        let mnemonic = Mnemonic::from_entropy(&entropy, Language::English).unwrap();

        assert_eq!(mnemonic.to_string(), vector.mnemonic);
        assert_eq!(mnemonic.words().len(), 12);
    }

    #[test]
    fn test_from_entropy_256bit() {
        let all = load_vectors();
        let vector = &all.vectors[8];
        let entropy = hex_to_bytes(&vector.entropy);

        let mnemonic = Mnemonic::from_entropy(&entropy, Language::English).unwrap();

        assert_eq!(mnemonic.to_string(), vector.mnemonic);
        assert_eq!(mnemonic.words().len(), 24);
    }

    #[test]
    fn test_to_seed_with_trezor_passphrase() {
        let all = load_vectors();
        let vector = &all.vectors[0];

        let mnemonic = Mnemonic::from_string(&vector.mnemonic, Language::English).unwrap();
        let seed = mnemonic.to_seed(&vector.passphrase);

        assert_eq!(bytes_to_hex(&seed), vector.seed);
    }

    #[test]
    fn test_to_seed_empty_passphrase() {
        let mnemonic = Mnemonic::from_string(ABANDON_ABOUT, Language::English).unwrap();

        let plain_seed = mnemonic.to_seed("");
        assert_eq!(plain_seed.len(), 64);

        // A different passphrase must change the derived seed.
        let trezor_seed = mnemonic.to_seed("TREZOR");
        assert_ne!(plain_seed, trezor_seed);
    }

    #[test]
    fn test_check_valid() {
        let mnemonic = Mnemonic::from_string(ABANDON_ABOUT, Language::English).unwrap();
        assert!(mnemonic.check());
    }

    #[test]
    fn test_check_invalid_checksum() {
        let parsed = Mnemonic::from_string(ABANDON_ONLY, Language::English);
        assert!(parsed.is_err());
    }

    #[test]
    fn test_from_random_128() {
        let mnemonic = Mnemonic::from_random(128, Language::English).unwrap();
        assert_eq!(mnemonic.words().len(), 12);
        assert!(mnemonic.check());
    }

    #[test]
    fn test_from_random_256() {
        let mnemonic = Mnemonic::from_random(256, Language::English).unwrap();
        assert_eq!(mnemonic.words().len(), 24);
        assert!(mnemonic.check());
    }

    #[test]
    fn test_from_string_roundtrip() {
        let phrase = "legal winner thank year wave sausage worth useful legal winner thank yellow";
        let mnemonic = Mnemonic::from_string(phrase, Language::English).unwrap();
        assert_eq!(mnemonic.to_string(), phrase);
    }

    #[test]
    fn test_all_vectors_entropy_to_mnemonic() {
        let all = load_vectors();
        for (i, vector) in all.vectors.iter().enumerate() {
            let entropy = hex_to_bytes(&vector.entropy);
            let mnemonic = Mnemonic::from_entropy(&entropy, Language::English).unwrap();
            assert_eq!(
                mnemonic.to_string(),
                vector.mnemonic,
                "Vector {} mnemonic mismatch",
                i
            );
        }
    }

    #[test]
    fn test_all_vectors_seed_derivation() {
        let all = load_vectors();
        for (i, vector) in all.vectors.iter().enumerate() {
            let mnemonic = Mnemonic::from_string(&vector.mnemonic, Language::English).unwrap();
            let seed = mnemonic.to_seed(&vector.passphrase);
            assert_eq!(bytes_to_hex(&seed), vector.seed, "Vector {} seed mismatch", i);
        }
    }
}