use crate::primitives::hash::{pbkdf2_sha512, sha256};
use crate::{Error, Result};
use super::wordlists;
/// Supported BIP-39 phrase lengths.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WordCount {
    Words12,
    Words15,
    Words18,
    Words21,
    Words24,
}

impl WordCount {
    /// Initial entropy length in bytes for this phrase length
    /// (BIP-39: 16, 20, 24, 28, or 32 bytes).
    pub fn entropy_bytes(self) -> usize {
        match self {
            WordCount::Words12 => 16,
            WordCount::Words15 => 20,
            WordCount::Words18 => 24,
            WordCount::Words21 => 28,
            WordCount::Words24 => 32,
        }
    }

    /// Number of words in the phrase, derived from the entropy size:
    /// (entropy bits + checksum bits) / 11, since each word encodes 11 bits.
    pub fn word_count(self) -> usize {
        let entropy_bits = self.entropy_bytes() * 8;
        (entropy_bits + entropy_bits / 32) / 11
    }

    /// Checksum length in bits: one bit per 32 bits of entropy.
    pub fn checksum_bits(self) -> usize {
        // entropy_bytes is always a multiple of 4, so this equals bits / 32.
        self.entropy_bytes() / 4
    }
}
/// Wordlist language for a BIP-39 mnemonic. Defaults to English.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum Language {
    ChineseSimplified,
    ChineseTraditional,
    Czech,
    #[default]
    English,
    French,
    Italian,
    Japanese,
    Korean,
    Spanish,
}

impl Language {
    /// Returns this language's 2048-entry BIP-39 wordlist.
    fn wordlist(&self) -> &'static [&'static str; 2048] {
        use Language::*;
        match self {
            ChineseSimplified => &wordlists::CHINESE_SIMPLIFIED,
            ChineseTraditional => &wordlists::CHINESE_TRADITIONAL,
            Czech => &wordlists::CZECH,
            English => &wordlists::ENGLISH,
            French => &wordlists::FRENCH,
            Italian => &wordlists::ITALIAN,
            Japanese => &wordlists::JAPANESE,
            Korean => &wordlists::KOREAN,
            Spanish => &wordlists::SPANISH,
        }
    }

    /// Separator used when joining words into a phrase: the ideographic
    /// space (U+3000) for Japanese, an ASCII space for everything else.
    fn separator(&self) -> &'static str {
        if matches!(self, Language::Japanese) {
            "\u{3000}"
        } else {
            " "
        }
    }

    /// Linear scan for `word` in this language's wordlist, returning its
    /// 11-bit index. Linear search is used because not every wordlist is
    /// sorted by byte order.
    fn word_index(&self, word: &str) -> Option<usize> {
        self.wordlist().iter().position(|&entry| entry == word)
    }
}
/// A BIP-39 mnemonic: an ordered list of wordlist words plus the language
/// whose wordlist they come from.
#[derive(Debug, Clone)]
pub struct Mnemonic {
// Words exactly as they appear in the language's wordlist.
words: Vec<String>,
// Language used for word lookup and for the phrase separator.
language: Language,
}
impl Mnemonic {
    /// Generates a fresh English mnemonic of the requested length from the
    /// OS entropy source.
    ///
    /// # Errors
    /// Returns `Error::CryptoError` if the system RNG fails.
    pub fn new(word_count: WordCount) -> Result<Self> {
        let entropy_len = word_count.entropy_bytes();
        let mut entropy = vec![0u8; entropy_len];
        getrandom::getrandom(&mut entropy)
            .map_err(|e| Error::CryptoError(format!("failed to generate entropy: {}", e)))?;
        Self::from_entropy(&entropy)
    }

    /// Encodes raw entropy as an English mnemonic (see
    /// [`Self::from_entropy_with_language`]).
    pub fn from_entropy(entropy: &[u8]) -> Result<Self> {
        Self::from_entropy_with_language(entropy, Language::English)
    }

    /// Encodes raw entropy as a mnemonic in the given language.
    ///
    /// Per BIP-39, a checksum of `entropy_bits / 32` bits (the leading bits
    /// of SHA-256(entropy)) is appended, and the combined bit string is
    /// split into 11-bit wordlist indices.
    ///
    /// # Errors
    /// Returns `Error::InvalidEntropyLength` unless the entropy is exactly
    /// 128, 160, 192, 224, or 256 bits.
    pub fn from_entropy_with_language(entropy: &[u8], language: Language) -> Result<Self> {
        let entropy_bits = entropy.len() * 8;
        #[allow(unknown_lints, clippy::manual_is_multiple_of)]
        if !(128..=256).contains(&entropy_bits) || entropy_bits % 32 != 0 {
            return Err(Error::InvalidEntropyLength {
                expected: "128, 160, 192, 224, or 256 bits".to_string(),
                actual: entropy_bits,
            });
        }
        let hash = sha256(entropy);
        let checksum_bits = entropy_bits / 32;
        let total_bits = entropy_bits + checksum_bits;
        let word_count = total_bits / 11;
        let wordlist = language.wordlist();
        let mut words = Vec::with_capacity(word_count);
        for i in 0..word_count {
            let bit_offset = i * 11;
            let mut index: u16 = 0;
            for bit in 0..11 {
                let pos = bit_offset + bit;
                let byte_idx = pos / 8;
                let bit_idx = 7 - (pos % 8);
                // Bits past the entropy come from the checksum (SHA-256 prefix).
                let byte = if byte_idx < entropy.len() {
                    entropy[byte_idx]
                } else {
                    hash[byte_idx - entropy.len()]
                };
                if (byte >> bit_idx) & 1 == 1 {
                    index |= 1 << (10 - bit);
                }
            }
            words.push(wordlist[index as usize].to_string());
        }
        Ok(Self { words, language })
    }

    /// Parses an English phrase and validates its checksum (see
    /// [`Self::from_phrase_with_language`]).
    pub fn from_phrase(phrase: &str) -> Result<Self> {
        Self::from_phrase_with_language(phrase, Language::English)
    }

    /// Parses a phrase against the given language's wordlist and validates
    /// its checksum. Words are lowercased and may be separated by any
    /// whitespace (including the Japanese ideographic space).
    ///
    /// # Errors
    /// - `Error::InvalidMnemonic` for a bad word count or checksum.
    /// - `Error::InvalidMnemonicWord` for a word not in the wordlist.
    pub fn from_phrase_with_language(phrase: &str, language: Language) -> Result<Self> {
        let words: Vec<String> = phrase
            .split_whitespace()
            .map(|s| s.to_lowercase())
            .collect();
        let word_count = words.len();
        if ![12, 15, 18, 21, 24].contains(&word_count) {
            return Err(Error::InvalidMnemonic(format!(
                "invalid word count: {}, expected 12, 15, 18, 21, or 24",
                word_count
            )));
        }
        let mut indices = Vec::with_capacity(word_count);
        for word in &words {
            match language.word_index(word) {
                Some(idx) => indices.push(idx),
                None => return Err(Error::InvalidMnemonicWord(word.clone())),
            }
        }
        let mnemonic = Self { words, language };
        if !mnemonic.validate_checksum(&indices)? {
            return Err(Error::InvalidMnemonic("invalid checksum".to_string()));
        }
        Ok(mnemonic)
    }

    /// Looks up the wordlist index of every word. Unknown words map to 0;
    /// mnemonics built through the public constructors never contain
    /// unknown words, so this is only a defensive fallback.
    fn word_indices(&self) -> Vec<usize> {
        self.words
            .iter()
            .map(|w| self.language.word_index(w).unwrap_or(0))
            .collect()
    }

    /// Expands 11-bit wordlist indices into a flat MSB-first bit vector.
    fn indices_to_bits(indices: &[usize]) -> Vec<bool> {
        let mut bits = vec![false; indices.len() * 11];
        for (i, &index) in indices.iter().enumerate() {
            for bit in 0..11 {
                bits[i * 11 + bit] = (index >> (10 - bit)) & 1 == 1;
            }
        }
        bits
    }

    /// Packs MSB-first bits into bytes. `bits.len()` need not be a multiple
    /// of 8; trailing bits land in the high bits of the final byte.
    fn bits_to_bytes(bits: &[bool]) -> Vec<u8> {
        let mut bytes = vec![0u8; bits.len().div_ceil(8)];
        for (i, bit) in bits.iter().enumerate() {
            if *bit {
                bytes[i / 8] |= 1 << (7 - (i % 8));
            }
        }
        bytes
    }

    /// Returns `Ok(true)` when the trailing checksum bits equal the leading
    /// bits of SHA-256 over the decoded entropy. Never returns `Err` today;
    /// the `Result` is kept for interface stability.
    fn validate_checksum(&self, indices: &[usize]) -> Result<bool> {
        let word_count = self.words.len();
        let checksum_bits = word_count / 3;
        // entropy_bits is always a multiple of 8 for valid word counts.
        let entropy_bits = word_count * 11 - checksum_bits;
        let bits = Self::indices_to_bits(indices);
        let entropy = Self::bits_to_bytes(&bits[..entropy_bits]);
        let hash = sha256(&entropy);
        let matches = bits[entropy_bits..]
            .iter()
            .enumerate()
            .all(|(i, &bit)| bit == ((hash[i / 8] >> (7 - (i % 8))) & 1 == 1));
        Ok(matches)
    }

    /// Joins the words with the language's separator (U+3000 for Japanese).
    pub fn phrase(&self) -> String {
        self.words.join(self.language.separator())
    }

    /// Borrows the words in order.
    pub fn words(&self) -> &[String] {
        &self.words
    }

    /// The wordlist language of this mnemonic.
    pub fn language(&self) -> Language {
        self.language
    }

    /// Derives the 64-byte BIP-39 seed: PBKDF2-HMAC-SHA512 with 2048
    /// iterations over the phrase, salted with `"mnemonic" + passphrase`.
    ///
    /// NOTE(review): `nfkd_normalize` is currently an identity transform
    /// (see the `NfkdNormalize` impls in this file), not true Unicode NFKD
    /// as BIP-39 mandates. This only matters for non-ASCII phrases or
    /// passphrases — confirm wordlists/passphrases are already NFKD.
    pub fn to_seed(&self, passphrase: &str) -> [u8; 64] {
        let mnemonic_normalized = self.phrase().nfkd_normalize();
        let passphrase_normalized = passphrase.nfkd_normalize();
        let salt = format!("mnemonic{}", passphrase_normalized);
        let seed_vec = pbkdf2_sha512(mnemonic_normalized.as_bytes(), salt.as_bytes(), 2048, 64);
        let mut seed = [0u8; 64];
        seed.copy_from_slice(&seed_vec);
        seed
    }

    /// Convenience wrapper for [`Self::to_seed`] with an empty passphrase.
    pub fn to_seed_normalized(&self) -> [u8; 64] {
        self.to_seed("")
    }

    /// Decodes the original entropy bytes (without the checksum).
    pub fn entropy(&self) -> Vec<u8> {
        self.extract_entropy_internal().0
    }

    /// Decodes entropy plus checksum, packed into `ceil(total_bits / 8)`
    /// bytes with trailing checksum bits in the final byte's high bits.
    pub fn entropy_with_checksum(&self) -> Vec<u8> {
        Self::bits_to_bytes(&Self::indices_to_bits(&self.word_indices()))
    }

    /// Shared decode path: returns (entropy bytes, checksum bit count).
    fn extract_entropy_internal(&self) -> (Vec<u8>, usize) {
        let word_count = self.words.len();
        let checksum_bits = word_count / 3;
        let entropy_bits = word_count * 11 - checksum_bits;
        let bits = Self::indices_to_bits(&self.word_indices());
        (Self::bits_to_bytes(&bits[..entropy_bits]), checksum_bits)
    }

    /// Checks every word is in the wordlist and the checksum matches.
    pub fn is_valid(&self) -> bool {
        let indices: Vec<usize> = self
            .words
            .iter()
            .filter_map(|w| self.language.word_index(w))
            .collect();
        if indices.len() != self.words.len() {
            return false;
        }
        self.validate_checksum(&indices).unwrap_or(false)
    }

    /// Serializes the mnemonic as the UTF-8 bytes of its phrase.
    pub fn to_binary(&self) -> Vec<u8> {
        self.phrase().into_bytes()
    }

    /// Deserializes a mnemonic from UTF-8 phrase bytes, re-validating it.
    ///
    /// # Errors
    /// `Error::InvalidMnemonic` for invalid UTF-8 or an invalid phrase.
    pub fn from_binary(data: &[u8]) -> Result<Self> {
        let phrase = std::str::from_utf8(data)
            .map_err(|e| Error::InvalidMnemonic(format!("invalid UTF-8: {}", e)))?;
        Self::from_phrase(phrase)
    }
}
impl std::fmt::Display for Mnemonic {
    /// Displays the mnemonic as its separator-joined phrase.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.phrase())
    }
}
// Internal shim intended to provide Unicode NFKD normalization for seed
// derivation in `to_seed`.
trait NfkdNormalize {
fn nfkd_normalize(&self) -> String;
}
// NOTE(review): this is an identity transform, NOT real NFKD normalization.
// BIP-39 requires NFKD-normalizing both the mnemonic and the passphrase
// before PBKDF2. For pure-ASCII input the identity is equivalent, but a
// non-ASCII passphrase (or wordlist text stored in precomposed/NFC form)
// would yield seeds that differ from other BIP-39 implementations. A real
// fix needs a Unicode normalization dependency (e.g. the
// `unicode-normalization` crate). Presumably the bundled wordlists are
// already stored in NFKD form — verify before relying on non-English seeds.
impl NfkdNormalize for str {
fn nfkd_normalize(&self) -> String {
self.to_string()
}
}
impl NfkdNormalize for String {
fn nfkd_normalize(&self) -> String {
self.as_str().nfkd_normalize()
}
}
#[cfg(test)]
mod tests {
use super::*;
// --- WordCount accessors -------------------------------------------------
#[test]
fn test_word_count_entropy_bytes() {
assert_eq!(WordCount::Words12.entropy_bytes(), 16);
assert_eq!(WordCount::Words15.entropy_bytes(), 20);
assert_eq!(WordCount::Words18.entropy_bytes(), 24);
assert_eq!(WordCount::Words21.entropy_bytes(), 28);
assert_eq!(WordCount::Words24.entropy_bytes(), 32);
}
#[test]
fn test_word_count_word_count() {
assert_eq!(WordCount::Words12.word_count(), 12);
assert_eq!(WordCount::Words15.word_count(), 15);
assert_eq!(WordCount::Words18.word_count(), 18);
assert_eq!(WordCount::Words21.word_count(), 21);
assert_eq!(WordCount::Words24.word_count(), 24);
}
// --- BIP-39 official test vectors (entropy -> phrase) --------------------
// All-zero entropy is the canonical "abandon ... about" vector.
#[test]
fn test_from_entropy_all_zeros_12_words() {
let entropy = [0u8; 16];
let mnemonic = Mnemonic::from_entropy(&entropy).unwrap();
assert_eq!(
mnemonic.phrase(),
"abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about"
);
}
#[test]
fn test_from_entropy_all_ones_12_words() {
let entropy = [0xffu8; 16];
let mnemonic = Mnemonic::from_entropy(&entropy).unwrap();
assert_eq!(
mnemonic.phrase(),
"zoo zoo zoo zoo zoo zoo zoo zoo zoo zoo zoo wrong"
);
}
// --- Phrase parsing ------------------------------------------------------
#[test]
fn test_from_phrase_valid() {
let phrase = "abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about";
let mnemonic = Mnemonic::from_phrase(phrase).unwrap();
assert!(mnemonic.is_valid());
assert_eq!(mnemonic.words().len(), 12);
}
#[test]
fn test_from_phrase_invalid_word() {
let phrase = "abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon notaword";
let result = Mnemonic::from_phrase(phrase);
assert!(result.is_err());
}
// Twelve "abandon"s have a correct wordlist but a wrong checksum word.
#[test]
fn test_from_phrase_invalid_checksum() {
let phrase = "abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon";
let result = Mnemonic::from_phrase(phrase);
assert!(result.is_err());
}
#[test]
fn test_entropy_roundtrip() {
let original = hex::decode("7f7f7f7f7f7f7f7f7f7f7f7f7f7f7f7f").unwrap();
let mnemonic = Mnemonic::from_entropy(&original).unwrap();
let extracted = mnemonic.entropy();
assert_eq!(original, extracted);
}
// --- Seed derivation (Trezor reference vector, passphrase "TREZOR") ------
#[test]
fn test_to_seed_trezor() {
let phrase = "abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about";
let mnemonic = Mnemonic::from_phrase(phrase).unwrap();
let seed = mnemonic.to_seed("TREZOR");
assert_eq!(
hex::encode(seed),
"c55257c360c07c72029aebc1b53c05ed0362ada38ead3e3e9efa3708e53495531f09a6987599d18264c1e1c92f2cf141630c7a3c4ab7c81b2f001698e7463b04"
);
}
#[test]
fn test_new_generates_valid_mnemonic() {
let mnemonic = Mnemonic::new(WordCount::Words12).unwrap();
assert!(mnemonic.is_valid());
assert_eq!(mnemonic.words().len(), 12);
let mnemonic = Mnemonic::new(WordCount::Words24).unwrap();
assert!(mnemonic.is_valid());
assert_eq!(mnemonic.words().len(), 24);
}
// 15 and 33 bytes bracket the valid 16..=32-byte range without being valid.
#[test]
fn test_invalid_entropy_length() {
let result = Mnemonic::from_entropy(&[0u8; 15]); assert!(result.is_err());
let result = Mnemonic::from_entropy(&[0u8; 33]); assert!(result.is_err());
}
// --- Binary (UTF-8 phrase bytes) serialization ---------------------------
#[test]
fn test_to_binary_from_binary_roundtrip() {
let entropy = [0u8; 16];
let mnemonic = Mnemonic::from_entropy(&entropy).unwrap();
let binary = mnemonic.to_binary();
let restored = Mnemonic::from_binary(&binary).unwrap();
assert_eq!(mnemonic.phrase(), restored.phrase());
}
#[test]
fn test_to_binary_format() {
let entropy = [0u8; 16];
let mnemonic = Mnemonic::from_entropy(&entropy).unwrap();
let binary = mnemonic.to_binary();
let phrase = mnemonic.phrase();
assert_eq!(binary, phrase.as_bytes());
}
#[test]
fn test_from_binary_valid() {
let phrase = b"abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about";
let mnemonic = Mnemonic::from_binary(phrase).unwrap();
assert!(mnemonic.is_valid());
assert_eq!(mnemonic.words().len(), 12);
}
#[test]
fn test_from_binary_invalid_utf8() {
let invalid_utf8 = [0xff, 0xfe, 0xfd];
let result = Mnemonic::from_binary(&invalid_utf8);
assert!(result.is_err());
}
#[test]
fn test_from_binary_invalid_phrase() {
let invalid_phrase = b"not a valid mnemonic phrase";
let result = Mnemonic::from_binary(invalid_phrase);
assert!(result.is_err());
}
#[test]
fn test_binary_roundtrip_all_word_counts() {
for word_count in [
WordCount::Words12,
WordCount::Words15,
WordCount::Words18,
WordCount::Words21,
WordCount::Words24,
] {
let mnemonic = Mnemonic::new(word_count).unwrap();
let binary = mnemonic.to_binary();
let restored = Mnemonic::from_binary(&binary).unwrap();
assert_eq!(mnemonic.phrase(), restored.phrase());
assert_eq!(
mnemonic.words().len(),
restored.words().len(),
"Word count mismatch for {:?}",
word_count
);
}
}
// --- Multi-language behavior ---------------------------------------------
#[test]
fn test_generate_mnemonic_each_language() {
let languages = [
Language::ChineseSimplified,
Language::ChineseTraditional,
Language::Czech,
Language::English,
Language::French,
Language::Italian,
Language::Japanese,
Language::Korean,
Language::Spanish,
];
for lang in &languages {
let entropy = [0u8; 16]; let mnemonic = Mnemonic::from_entropy_with_language(&entropy, *lang).unwrap();
assert_eq!(
mnemonic.words().len(),
12,
"Expected 12 words for {:?}",
lang
);
assert_eq!(mnemonic.language(), *lang);
assert!(!mnemonic.phrase().is_empty(), "Empty phrase for {:?}", lang);
}
}
// Seeds must be deterministic and must vary with the passphrase.
#[test]
fn test_roundtrip_seed_non_english() {
let languages = [
Language::ChineseSimplified,
Language::ChineseTraditional,
Language::Czech,
Language::French,
Language::Italian,
Language::Japanese,
Language::Korean,
Language::Spanish,
];
for lang in &languages {
let entropy = [0xABu8; 16]; let mnemonic = Mnemonic::from_entropy_with_language(&entropy, *lang).unwrap();
let seed1 = mnemonic.to_seed("");
assert_eq!(seed1.len(), 64, "Seed length wrong for {:?}", lang);
let seed2 = mnemonic.to_seed("test passphrase");
assert_eq!(seed2.len(), 64);
assert_ne!(
seed1, seed2,
"Seeds should differ with different passphrases for {:?}",
lang
);
let seed3 = mnemonic.to_seed("");
assert_eq!(seed1, seed3, "Seeds should be deterministic for {:?}", lang);
}
}
// Japanese phrases must join with U+3000 (ideographic space), never ASCII space.
#[test]
fn test_japanese_mnemonic_with_ideographic_space() {
let entropy = [0u8; 16]; let mnemonic = Mnemonic::from_entropy_with_language(&entropy, Language::Japanese).unwrap();
let phrase = mnemonic.phrase();
assert!(
phrase.contains('\u{3000}'),
"Japanese mnemonic should use ideographic space separator"
);
assert!(
!phrase.contains(' '),
"Japanese mnemonic should not contain regular spaces"
);
let words: Vec<&str> = phrase.split('\u{3000}').collect();
assert_eq!(words.len(), 12);
let wordlist = &wordlists::JAPANESE;
for word in &words {
assert!(
wordlist.contains(word),
"Word '{}' not found in Japanese wordlist",
word
);
}
}
// Every BIP-39 wordlist must contain exactly 2048 entries (11 bits per word).
#[test]
fn test_language_word_count() {
assert_eq!(wordlists::CHINESE_SIMPLIFIED.len(), 2048);
assert_eq!(wordlists::CHINESE_TRADITIONAL.len(), 2048);
assert_eq!(wordlists::CZECH.len(), 2048);
assert_eq!(wordlists::ENGLISH.len(), 2048);
assert_eq!(wordlists::FRENCH.len(), 2048);
assert_eq!(wordlists::ITALIAN.len(), 2048);
assert_eq!(wordlists::JAPANESE.len(), 2048);
assert_eq!(wordlists::KOREAN.len(), 2048);
assert_eq!(wordlists::SPANISH.len(), 2048);
}
#[test]
fn test_language_default_is_english() {
let default_lang = Language::default();
assert_eq!(default_lang, Language::English);
}
#[test]
fn test_non_english_from_phrase_roundtrip() {
let languages = [
Language::ChineseSimplified,
Language::ChineseTraditional,
Language::Czech,
Language::French,
Language::Italian,
Language::Japanese,
Language::Korean,
Language::Spanish,
];
for lang in &languages {
let entropy = [0x42u8; 16];
let mnemonic = Mnemonic::from_entropy_with_language(&entropy, *lang).unwrap();
let phrase = mnemonic.phrase();
let restored = Mnemonic::from_phrase_with_language(&phrase, *lang).unwrap();
assert_eq!(
mnemonic.entropy(),
restored.entropy(),
"Entropy mismatch for {:?}",
lang
);
}
}
}