1use super::bip39_wordlists;
12use super::error::CompatError;
13use crate::primitives::hash::{pbkdf2_hmac_sha512, sha256};
14use crate::primitives::random::random_bytes;
15
/// Wordlist language used when generating or parsing a mnemonic.
///
/// Each variant selects one 2048-word list from `bip39_wordlists`.
/// The type is a plain field-less enum, so it is `Copy` and may be used as a
/// key in hashed collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Language {
    English,
    Japanese,
    Korean,
    Spanish,
    French,
    Italian,
    Czech,
    ChineseSimplified,
    ChineseTraditional,
}
29
30fn get_wordlist(lang: Language) -> &'static [&'static str; 2048] {
32 match lang {
33 Language::English => bip39_wordlists::english::ENGLISH,
34 Language::Japanese => bip39_wordlists::japanese::JAPANESE,
35 Language::Korean => bip39_wordlists::korean::KOREAN,
36 Language::Spanish => bip39_wordlists::spanish::SPANISH,
37 Language::French => bip39_wordlists::french::FRENCH,
38 Language::Italian => bip39_wordlists::italian::ITALIAN,
39 Language::Czech => bip39_wordlists::czech::CZECH,
40 Language::ChineseSimplified => bip39_wordlists::chinese_simplified::CHINESE_SIMPLIFIED,
41 Language::ChineseTraditional => bip39_wordlists::chinese_traditional::CHINESE_TRADITIONAL,
42 }
43}
44
45#[derive(Debug, Clone)]
47pub struct Mnemonic {
48 words: Vec<String>,
49 entropy: Vec<u8>,
50 language: Language,
51}
52
53impl Mnemonic {
54 pub fn from_entropy(entropy: &[u8], language: Language) -> Result<Self, CompatError> {
59 let ent_bits = entropy.len() * 8;
60 if !(128..=256).contains(&ent_bits) || !ent_bits.is_multiple_of(32) {
61 return Err(CompatError::InvalidEntropy(format!(
62 "entropy must be 128-256 bits in 32-bit increments, got {} bits",
63 ent_bits
64 )));
65 }
66
67 let checksum_bits = ent_bits / 32;
68 let checksum = sha256(entropy);
69
70 let total_bits = ent_bits + checksum_bits;
72 let wordlist = get_wordlist(language);
73 let mut words = Vec::with_capacity(total_bits / 11);
74
75 for i in 0..(total_bits / 11) {
76 let mut index: u32 = 0;
77 for j in 0..11 {
78 let bit_pos = i * 11 + j;
79 let bit = if bit_pos < ent_bits {
80 (entropy[bit_pos / 8] >> (7 - (bit_pos % 8))) & 1
82 } else {
83 let cs_pos = bit_pos - ent_bits;
85 (checksum[cs_pos / 8] >> (7 - (cs_pos % 8))) & 1
86 };
87 index = (index << 1) | bit as u32;
88 }
89 words.push(wordlist[index as usize].to_string());
90 }
91
92 Ok(Mnemonic {
93 words,
94 entropy: entropy.to_vec(),
95 language,
96 })
97 }
98
99 pub fn from_random(bits: usize, language: Language) -> Result<Self, CompatError> {
103 if !(128..=256).contains(&bits) || !bits.is_multiple_of(32) {
104 return Err(CompatError::InvalidEntropy(format!(
105 "bits must be 128-256 in 32-bit increments, got {}",
106 bits
107 )));
108 }
109 let entropy = random_bytes(bits / 8);
110 Self::from_entropy(&entropy, language)
111 }
112
113 pub fn from_string(mnemonic: &str, language: Language) -> Result<Self, CompatError> {
118 let separator = if language == Language::Japanese {
119 "\u{3000}"
120 } else {
121 " "
122 };
123
124 let word_strs: Vec<&str> = mnemonic.split(separator).collect();
125 let word_count = word_strs.len();
126
127 if !(12..=24).contains(&word_count) || !word_count.is_multiple_of(3) {
129 return Err(CompatError::InvalidMnemonic(format!(
130 "invalid word count: {} (must be 12, 15, 18, 21, or 24)",
131 word_count
132 )));
133 }
134
135 let wordlist = get_wordlist(language);
136
137 let mut indices = Vec::with_capacity(word_count);
139 for word in &word_strs {
140 match wordlist.iter().position(|w| w == word) {
141 Some(idx) => indices.push(idx as u32),
142 None => {
143 return Err(CompatError::InvalidMnemonic(format!(
144 "word not in wordlist: {}",
145 word
146 )));
147 }
148 }
149 }
150
151 let total_bits = word_count * 11;
153 let ent_bits = (total_bits * 32) / 33; let checksum_bits = ent_bits / 32;
155 let ent_bytes = ent_bits / 8;
156
157 let mut bits_vec: Vec<u8> = Vec::with_capacity(total_bits);
159 for idx in &indices {
160 for j in (0..11).rev() {
161 bits_vec.push(((idx >> j) & 1) as u8);
162 }
163 }
164
165 let mut entropy = vec![0u8; ent_bytes];
167 for i in 0..ent_bits {
168 if bits_vec[i] == 1 {
169 entropy[i / 8] |= 1 << (7 - (i % 8));
170 }
171 }
172
173 let checksum = sha256(&entropy);
175 for i in 0..checksum_bits {
176 let expected_bit = (checksum[i / 8] >> (7 - (i % 8))) & 1;
177 let actual_bit = bits_vec[ent_bits + i];
178 if expected_bit != actual_bit {
179 return Err(CompatError::InvalidMnemonic(
180 "checksum mismatch".to_string(),
181 ));
182 }
183 }
184
185 Ok(Mnemonic {
186 words: word_strs.iter().map(|s| s.to_string()).collect(),
187 entropy,
188 language,
189 })
190 }
191
192 pub fn check(&self) -> bool {
194 let ent_bits = self.entropy.len() * 8;
195 let checksum_bits = ent_bits / 32;
196 let checksum = sha256(&self.entropy);
197
198 let wordlist = get_wordlist(self.language);
200 let total_bits = ent_bits + checksum_bits;
201
202 for i in 0..(total_bits / 11) {
203 let mut index: u32 = 0;
204 for j in 0..11 {
205 let bit_pos = i * 11 + j;
206 let bit = if bit_pos < ent_bits {
207 (self.entropy[bit_pos / 8] >> (7 - (bit_pos % 8))) & 1
208 } else {
209 let cs_pos = bit_pos - ent_bits;
210 (checksum[cs_pos / 8] >> (7 - (cs_pos % 8))) & 1
211 };
212 index = (index << 1) | bit as u32;
213 }
214 if self.words[i] != wordlist[index as usize] {
215 return false;
216 }
217 }
218
219 true
220 }
221
222 pub fn to_seed(&self, passphrase: &str) -> Vec<u8> {
227 let mnemonic_str = self.to_phrase();
228 let salt = format!("mnemonic{}", passphrase);
229 pbkdf2_hmac_sha512(mnemonic_str.as_bytes(), salt.as_bytes(), 2048, 64)
230 }
231
232 pub fn to_phrase(&self) -> String {
236 let separator = if self.language == Language::Japanese {
237 "\u{3000}"
238 } else {
239 " "
240 };
241 self.words.join(separator)
242 }
243
244 pub fn words(&self) -> &[String] {
246 &self.words
247 }
248
249 pub fn entropy(&self) -> &[u8] {
251 &self.entropy
252 }
253}
254
255impl std::fmt::Display for Mnemonic {
256 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
257 write!(f, "{}", self.to_phrase())
258 }
259}
260
#[cfg(test)]
mod tests {
    use super::*;

    /// Twelve-word English phrase with a valid checksum, shared by several tests.
    const ABANDON_ABOUT: &str = "abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about";

    /// Decodes a hex string, two digits per byte. Panics on malformed input.
    fn hex_to_bytes(hex: &str) -> Vec<u8> {
        let mut bytes = Vec::with_capacity(hex.len() / 2);
        let mut start = 0;
        while start < hex.len() {
            let pair = &hex[start..start + 2];
            bytes.push(u8::from_str_radix(pair, 16).unwrap());
            start += 2;
        }
        bytes
    }

    /// Encodes bytes as lowercase hex, two digits per byte.
    fn bytes_to_hex(bytes: &[u8]) -> String {
        let mut hex = String::with_capacity(bytes.len() * 2);
        for byte in bytes {
            hex.push_str(&format!("{:02x}", byte));
        }
        hex
    }

    /// One entry of the JSON test-vector file.
    #[derive(serde::Deserialize)]
    struct TestVector {
        entropy: String,
        mnemonic: String,
        passphrase: String,
        seed: String,
    }

    /// Top-level shape of the JSON test-vector file.
    #[derive(serde::Deserialize)]
    struct TestVectors {
        vectors: Vec<TestVector>,
    }

    /// Parses the vector file that is embedded at compile time.
    fn load_vectors() -> TestVectors {
        let json = include_str!("../../test-vectors/bip39_vectors.json");
        serde_json::from_str(json).expect("failed to parse BIP39 test vectors")
    }

    /// Builds an English mnemonic from the hex entropy of `vector`.
    fn english_from_vector_entropy(vector: &TestVector) -> Mnemonic {
        let entropy = hex_to_bytes(&vector.entropy);
        Mnemonic::from_entropy(&entropy, Language::English).unwrap()
    }

    #[test]
    fn test_from_entropy_128bit() {
        let vectors = load_vectors();
        let first = &vectors.vectors[0];
        let mnemonic = english_from_vector_entropy(first);
        assert_eq!(mnemonic.to_string(), first.mnemonic);
        assert_eq!(mnemonic.words().len(), 12);
    }

    #[test]
    fn test_from_entropy_256bit() {
        let vectors = load_vectors();
        let ninth = &vectors.vectors[8];
        let mnemonic = english_from_vector_entropy(ninth);
        assert_eq!(mnemonic.to_string(), ninth.mnemonic);
        assert_eq!(mnemonic.words().len(), 24);
    }

    #[test]
    fn test_to_seed_with_trezor_passphrase() {
        let vectors = load_vectors();
        let first = &vectors.vectors[0];
        let mnemonic = Mnemonic::from_string(&first.mnemonic, Language::English).unwrap();
        let seed_hex = bytes_to_hex(&mnemonic.to_seed(&first.passphrase));
        assert_eq!(seed_hex, first.seed);
    }

    #[test]
    fn test_to_seed_empty_passphrase() {
        let mnemonic = Mnemonic::from_string(ABANDON_ABOUT, Language::English).unwrap();
        let seed = mnemonic.to_seed("");
        assert_eq!(seed.len(), 64);
        // A different passphrase must lead to a different seed.
        let trezor_seed = mnemonic.to_seed("TREZOR");
        assert_ne!(seed, trezor_seed);
    }

    #[test]
    fn test_check_valid() {
        let mnemonic = Mnemonic::from_string(ABANDON_ABOUT, Language::English).unwrap();
        assert!(mnemonic.check());
    }

    #[test]
    fn test_check_invalid_checksum() {
        // Twelve times "abandon" has a wrong final word for its checksum.
        let result = Mnemonic::from_string(
            "abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon",
            Language::English,
        );
        assert!(result.is_err());
    }

    #[test]
    fn test_from_random_128() {
        let mnemonic = Mnemonic::from_random(128, Language::English).unwrap();
        assert_eq!(mnemonic.words().len(), 12);
        assert!(mnemonic.check());
    }

    #[test]
    fn test_from_random_256() {
        let mnemonic = Mnemonic::from_random(256, Language::English).unwrap();
        assert_eq!(mnemonic.words().len(), 24);
        assert!(mnemonic.check());
    }

    #[test]
    fn test_from_string_roundtrip() {
        let mnemonic_str =
            "legal winner thank year wave sausage worth useful legal winner thank yellow";
        let parsed = Mnemonic::from_string(mnemonic_str, Language::English).unwrap();
        assert_eq!(parsed.to_string(), mnemonic_str);
    }

    #[test]
    fn test_all_vectors_entropy_to_mnemonic() {
        let vectors = load_vectors();
        for (i, vector) in vectors.vectors.iter().enumerate() {
            let mnemonic = english_from_vector_entropy(vector);
            assert_eq!(
                mnemonic.to_string(),
                vector.mnemonic,
                "Vector {} mnemonic mismatch",
                i
            );
        }
    }

    #[test]
    fn test_all_vectors_seed_derivation() {
        let vectors = load_vectors();
        for (i, vector) in vectors.vectors.iter().enumerate() {
            let mnemonic = Mnemonic::from_string(&vector.mnemonic, Language::English).unwrap();
            let seed_hex = bytes_to_hex(&mnemonic.to_seed(&vector.passphrase));
            assert_eq!(seed_hex, vector.seed, "Vector {} seed mismatch", i);
        }
    }
}