use crate::error::MnemonicError;
use crate::seed::Seed;
use crate::wordlist::Language;
use rand::RngCore;
use sha2::{Digest, Sha256};
use std::fmt;
use zeroize::{Zeroize, Zeroizing};
/// Supported mnemonic lengths, expressed as a number of words.
///
/// Each variant's discriminant equals the word count it represents, so
/// `WordCount::Words12 as usize == 12`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WordCount {
/// 12-word mnemonic.
Words12 = 12,
/// 15-word mnemonic.
Words15 = 15,
/// 18-word mnemonic.
Words18 = 18,
/// 21-word mnemonic.
Words21 = 21,
/// 24-word mnemonic.
Words24 = 24,
}
impl WordCount {
    /// Number of entropy bits encoded by a mnemonic of this length
    /// (128, 160, 192, 224 or 256).
    pub fn entropy_bits(&self) -> usize {
        self.entropy_bytes() * 8
    }

    /// Number of entropy bytes encoded by a mnemonic of this length
    /// (16, 20, 24, 28 or 32).
    pub fn entropy_bytes(&self) -> usize {
        // Every three words carry four bytes of entropy plus one checksum bit
        // (3 * 11 = 32 + 1), and every supported length is a multiple of three.
        let words = *self as usize;
        words * 4 / 3
    }

    /// Number of checksum bits appended to the entropy (4, 5, 6, 7 or 8).
    pub fn checksum_bits(&self) -> usize {
        // One checksum bit per 32 entropy bits, i.e. one per three words.
        let words = *self as usize;
        words / 3
    }

    /// Maps a raw word count onto the matching [`WordCount`] variant.
    ///
    /// # Errors
    ///
    /// Returns [`MnemonicError::InvalidWordCount`] carrying `count` when it is
    /// not one of 12, 15, 18, 21 or 24.
    pub fn from_word_count(count: usize) -> Result<Self, MnemonicError> {
        const SUPPORTED: [WordCount; 5] = [
            WordCount::Words12,
            WordCount::Words15,
            WordCount::Words18,
            WordCount::Words21,
            WordCount::Words24,
        ];
        SUPPORTED
            .iter()
            .copied()
            .find(|candidate| *candidate as usize == count)
            .ok_or(MnemonicError::InvalidWordCount(count))
    }
}
/// A mnemonic phrase together with the entropy it encodes and the wordlist
/// language its words belong to.
pub struct Mnemonic {
// The individual words of the phrase, in order.
words: Vec<String>,
// Raw entropy bytes the words encode; held in `Zeroizing` so the buffer is
// wiped when it is dropped.
entropy: Zeroizing<Vec<u8>>,
// Wordlist language used to produce / look up `words`.
language: Language,
}
impl Mnemonic {
pub fn generate(word_count: WordCount) -> Self {
Self::generate_in(word_count, Language::default())
}
pub fn generate_in(word_count: WordCount, language: Language) -> Self {
Self::generate_with_language(word_count, language)
}
pub fn generate_with_language(word_count: WordCount, language: Language) -> Self {
let entropy_bytes = word_count.entropy_bytes();
let mut entropy = vec![0u8; entropy_bytes];
rand::rngs::OsRng.fill_bytes(&mut entropy);
Self::from_entropy_internal(&entropy, language)
.expect("Generated entropy should always be valid")
}
fn from_entropy_internal(entropy: &[u8], language: Language) -> Result<Self, MnemonicError> {
let entropy_bits = entropy.len() * 8;
let word_count = match entropy_bits {
128 => WordCount::Words12,
160 => WordCount::Words15,
192 => WordCount::Words18,
224 => WordCount::Words21,
256 => WordCount::Words24,
_ => {
return Err(MnemonicError::InvalidEntropyLength {
expected: 16, actual: entropy.len(),
})
}
};
let hash = Sha256::digest(entropy);
let checksum_bits = word_count.checksum_bits();
let mut bits = Vec::with_capacity(entropy_bits + checksum_bits);
for byte in entropy {
for i in (0..8).rev() {
bits.push((byte >> i) & 1 == 1);
}
}
for i in (0..checksum_bits).rev() {
let byte_idx = (checksum_bits - 1 - i) / 8;
let bit_idx = 7 - ((checksum_bits - 1 - i) % 8);
bits.push((hash[byte_idx] >> bit_idx) & 1 == 1);
}
let wordlist = language.wordlist();
let mut words = Vec::with_capacity(word_count as usize);
for chunk in bits.chunks(11) {
let mut index = 0usize;
for (i, &bit) in chunk.iter().enumerate() {
if bit {
index |= 1 << (10 - i);
}
}
words.push(wordlist[index].to_string());
}
Ok(Self {
words,
entropy: Zeroizing::new(entropy.to_vec()),
language,
})
}
pub fn from_phrase(phrase: &str) -> Result<Self, MnemonicError> {
Self::from_phrase_in(phrase, Language::default())
}
pub fn from_phrase_in(phrase: &str, language: Language) -> Result<Self, MnemonicError> {
let phrase = phrase.trim();
if phrase.is_empty() {
return Err(MnemonicError::EmptyPhrase);
}
let words: Vec<String> = phrase
.split_whitespace()
.map(|w| w.to_lowercase())
.collect();
let word_count = WordCount::from_word_count(words.len())?;
let mut indices = Vec::with_capacity(words.len());
for word in &words {
match language.get_index(word) {
Some(idx) => indices.push(idx),
None => return Err(MnemonicError::InvalidWord(word.clone())),
}
}
let total_bits = words.len() * 11;
let mut bits = Vec::with_capacity(total_bits);
for idx in &indices {
for i in (0..11).rev() {
bits.push((idx >> i) & 1 == 1);
}
}
let entropy_bits = word_count.entropy_bits();
let checksum_bits = word_count.checksum_bits();
let mut entropy = vec![0u8; entropy_bits / 8];
for (i, bit) in bits[..entropy_bits].iter().enumerate() {
if *bit {
entropy[i / 8] |= 1 << (7 - (i % 8));
}
}
let hash = Sha256::digest(&entropy);
for i in 0..checksum_bits {
let expected_bit = (hash[i / 8] >> (7 - (i % 8))) & 1 == 1;
let actual_bit = bits[entropy_bits + i];
if expected_bit != actual_bit {
return Err(MnemonicError::InvalidChecksum);
}
}
Ok(Self {
words,
entropy: Zeroizing::new(entropy),
language,
})
}
pub fn parse_auto_detect(phrase: &str) -> Result<Self, MnemonicError> {
let phrase = phrase.trim();
if phrase.is_empty() {
return Err(MnemonicError::EmptyPhrase);
}
match Language::detect_from_phrase(phrase) {
Some(language) => Self::from_phrase_in(phrase, language),
None => {
let words: Vec<&str> = phrase.split_whitespace().collect();
for word in &words {
let word_lower = word.to_lowercase();
if Language::detect_from_word(&word_lower).is_none() {
return Err(MnemonicError::InvalidWord(word_lower));
}
}
Err(MnemonicError::LanguageDetectionFailed)
}
}
}
pub fn validate(&self) -> Result<(), MnemonicError> {
Self::from_phrase_in(&self.to_phrase(), self.language)?;
Ok(())
}
pub fn to_seed(&self, passphrase: &str) -> Seed {
Seed::new(self, passphrase)
}
pub fn to_seed_normalized(&self) -> Seed {
self.to_seed("")
}
pub fn to_private_key(
&self,
passphrase: &str,
) -> Result<rustywallet_keys::private_key::PrivateKey, MnemonicError> {
self.to_seed(passphrase).to_private_key()
}
pub fn words(&self) -> &[String] {
&self.words
}
pub fn word_count(&self) -> WordCount {
WordCount::from_word_count(self.words.len()).expect("Mnemonic always has valid word count")
}
pub fn language(&self) -> Language {
self.language
}
pub fn to_phrase(&self) -> String {
self.words.join(" ")
}
pub fn entropy(&self) -> &[u8] {
&self.entropy
}
}
/// Displays the mnemonic as its space-separated phrase.
///
/// Note that this prints the secret words in clear text.
impl fmt::Display for Mnemonic {
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        let phrase = self.to_phrase();
        formatter.write_str(&phrase)
    }
}
/// Masked `Debug` output: always prints the fixed placeholder
/// `Mnemonic(****)` and never any field of the value.
impl fmt::Debug for Mnemonic {
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        formatter.write_str("Mnemonic(****)")
    }
}
/// Wipes every stored word when the mnemonic is dropped.
impl Drop for Mnemonic {
    fn drop(&mut self) {
        self.words.iter_mut().for_each(|word| word.zeroize());
    }
}
/// Deep copy: the clone owns its own word list and its own `Zeroizing`
/// entropy buffer.
impl Clone for Mnemonic {
    fn clone(&self) -> Self {
        // Borrow the fields; moving out of a `Drop` type is not allowed.
        let Self {
            words,
            entropy,
            language,
        } = self;
        let entropy_copy = entropy.as_slice().to_vec();
        Self {
            words: words.to_vec(),
            entropy: Zeroizing::new(entropy_copy),
            language: *language,
        }
    }
}
impl std::str::FromStr for Mnemonic {
type Err = MnemonicError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
Self::from_phrase(s)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A 12-word English phrase that the parser is expected to accept.
    const VALID_PHRASE: &str = "abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about";

    /// Generates a default-language mnemonic and checks length and validity.
    fn check_generated(word_count: WordCount, expected_words: usize) {
        let mnemonic = Mnemonic::generate(word_count);
        assert_eq!(mnemonic.words().len(), expected_words);
        assert!(mnemonic.validate().is_ok());
    }

    /// Generates a 12-word mnemonic in `language`, checks length, language
    /// and validity, then hands it back for extra assertions.
    fn generate_12_in(language: Language) -> Mnemonic {
        let mnemonic = Mnemonic::generate_with_language(WordCount::Words12, language);
        assert_eq!(mnemonic.words().len(), 12);
        assert_eq!(mnemonic.language(), language);
        assert!(mnemonic.validate().is_ok());
        mnemonic
    }

    /// Generates a 12-word phrase in `language`, parses it back with
    /// auto-detection and checks the detected language. Returns the phrase
    /// and the parsed mnemonic.
    fn auto_detect_generated(language: Language) -> (String, Mnemonic) {
        let original = Mnemonic::generate_with_language(WordCount::Words12, language);
        let phrase = original.to_phrase();
        let parsed = Mnemonic::parse_auto_detect(&phrase).unwrap();
        assert_eq!(parsed.language(), language);
        (phrase, parsed)
    }

    #[test]
    fn test_generate_12_words() {
        check_generated(WordCount::Words12, 12);
    }

    #[test]
    fn test_generate_24_words() {
        check_generated(WordCount::Words24, 24);
    }

    #[test]
    fn test_parse_valid_mnemonic() {
        let mnemonic = Mnemonic::from_phrase(VALID_PHRASE).unwrap();
        assert_eq!(mnemonic.words().len(), 12);
        assert_eq!(mnemonic.to_phrase(), VALID_PHRASE);
    }

    #[test]
    fn test_parse_invalid_word() {
        let outcome = Mnemonic::from_phrase(
            "abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon invalid",
        );
        assert!(matches!(outcome, Err(MnemonicError::InvalidWord(_))));
    }

    #[test]
    fn test_parse_invalid_checksum() {
        let outcome = Mnemonic::from_phrase(
            "abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon",
        );
        assert!(matches!(outcome, Err(MnemonicError::InvalidChecksum)));
    }

    #[test]
    fn test_parse_invalid_word_count() {
        let outcome = Mnemonic::from_phrase("abandon abandon abandon");
        assert!(matches!(outcome, Err(MnemonicError::InvalidWordCount(_))));
    }

    #[test]
    fn test_case_insensitive() {
        let mixed_case = "ABANDON Abandon ABANDON abandon ABANDON abandon ABANDON abandon ABANDON abandon ABANDON About";
        let mnemonic = Mnemonic::from_phrase(mixed_case).unwrap();
        assert_eq!(mnemonic.words().len(), 12);
    }

    #[test]
    fn test_word_count_entropy() {
        let expected = [
            (WordCount::Words12, 128),
            (WordCount::Words15, 160),
            (WordCount::Words18, 192),
            (WordCount::Words21, 224),
            (WordCount::Words24, 256),
        ];
        for (word_count, bits) in expected.iter() {
            assert_eq!(word_count.entropy_bits(), *bits);
        }
    }

    #[test]
    fn test_debug_masked() {
        let rendered = format!("{:?}", Mnemonic::generate(WordCount::Words12));
        assert_eq!(rendered, "Mnemonic(****)");
        assert!(!rendered.contains("abandon"));
    }

    #[test]
    fn test_generate_with_language_japanese() {
        let mnemonic = generate_12_in(Language::Japanese);
        assert!(mnemonic
            .words()
            .iter()
            .all(|word| Language::Japanese.contains(word)));
    }

    #[test]
    fn test_generate_with_language_spanish() {
        generate_12_in(Language::Spanish);
    }

    #[test]
    fn test_generate_with_language_chinese() {
        generate_12_in(Language::ChineseSimplified);
    }

    #[test]
    fn test_generate_with_language_korean() {
        generate_12_in(Language::Korean);
    }

    #[test]
    fn test_parse_auto_detect_english() {
        let mnemonic = Mnemonic::parse_auto_detect(VALID_PHRASE).unwrap();
        assert_eq!(mnemonic.language(), Language::English);
        assert_eq!(mnemonic.words().len(), 12);
    }

    #[test]
    fn test_parse_auto_detect_japanese() {
        let (phrase, parsed) = auto_detect_generated(Language::Japanese);
        assert_eq!(parsed.to_phrase(), phrase);
    }

    #[test]
    fn test_parse_auto_detect_spanish() {
        auto_detect_generated(Language::Spanish);
    }

    #[test]
    fn test_parse_auto_detect_chinese() {
        auto_detect_generated(Language::ChineseSimplified);
    }

    #[test]
    fn test_parse_auto_detect_korean() {
        auto_detect_generated(Language::Korean);
    }

    #[test]
    fn test_parse_auto_detect_invalid_word() {
        let outcome = Mnemonic::parse_auto_detect(
            "abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon notaword",
        );
        assert!(matches!(outcome, Err(MnemonicError::InvalidWord(_))));
    }

    #[test]
    fn test_parse_auto_detect_empty() {
        let outcome = Mnemonic::parse_auto_detect("");
        assert!(matches!(outcome, Err(MnemonicError::EmptyPhrase)));
    }
}
#[cfg(test)]
mod proptests {
    use super::*;
    use proptest::prelude::*;

    /// Fixed, valid 12-word phrase used by the seed-derivation properties.
    const KNOWN_PHRASE: &str = "abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about";

    /// Strategy yielding every supported word count.
    fn any_word_count() -> impl Strategy<Value = WordCount> {
        prop_oneof![
            Just(WordCount::Words12),
            Just(WordCount::Words15),
            Just(WordCount::Words18),
            Just(WordCount::Words21),
            Just(WordCount::Words24),
        ]
    }

    /// Strategy yielding only the shortest and the longest word count.
    fn shortest_or_longest_word_count() -> impl Strategy<Value = WordCount> {
        prop_oneof![Just(WordCount::Words12), Just(WordCount::Words24)]
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(100))]

        /// A generated mnemonic validates and has the requested length.
        #[test]
        fn prop_generated_mnemonic_is_valid(word_count in any_word_count()) {
            let generated = Mnemonic::generate(word_count);
            prop_assert!(generated.validate().is_ok());
            prop_assert_eq!(generated.words().len(), word_count as usize);
        }

        /// Phrase -> parse -> phrase is the identity for generated mnemonics.
        #[test]
        fn prop_mnemonic_roundtrip(word_count in any_word_count()) {
            let generated = Mnemonic::generate(word_count);
            let text = generated.to_phrase();
            let reparsed = Mnemonic::from_phrase(&text).unwrap();
            prop_assert_eq!(generated.to_phrase(), reparsed.to_phrase());
        }

        /// The same mnemonic and passphrase always give the same seed bytes.
        #[test]
        fn prop_seed_derivation_deterministic(passphrase in "[a-zA-Z0-9]{0,20}") {
            let mnemonic = Mnemonic::from_phrase(KNOWN_PHRASE).unwrap();
            let first = mnemonic.to_seed(&passphrase);
            let second = mnemonic.to_seed(&passphrase);
            prop_assert_eq!(first.as_bytes(), second.as_bytes());
        }

        /// Distinct passphrases give distinct seed bytes.
        #[test]
        fn prop_different_passphrase_different_seed(
            pass1 in "[a-zA-Z]{1,10}",
            pass2 in "[a-zA-Z]{1,10}"
        ) {
            prop_assume!(pass1 != pass2);
            let mnemonic = Mnemonic::from_phrase(KNOWN_PHRASE).unwrap();
            let first = mnemonic.to_seed(&pass1);
            let second = mnemonic.to_seed(&pass2);
            prop_assert_ne!(first.as_bytes(), second.as_bytes());
        }

        /// Splitting the rendered phrase yields exactly the requested number of words.
        #[test]
        fn prop_word_count_consistency(word_count in any_word_count()) {
            let text = Mnemonic::generate(word_count).to_phrase();
            let pieces: Vec<&str> = text.split_whitespace().collect();
            prop_assert_eq!(pieces.len(), word_count as usize);
        }

        /// Private-key derivation with an empty passphrase succeeds.
        #[test]
        fn prop_private_key_derivation_valid(word_count in shortest_or_longest_word_count()) {
            let generated = Mnemonic::generate(word_count);
            let derived = generated.to_private_key("");
            prop_assert!(derived.is_ok());
        }
    }
}