1use std::{collections::HashSet, hash::BuildHasherDefault};
2
3use ahash::AHasher;
4use bip39::Mnemonic;
5
/// Hasher builder used for both `HashSet`s held by [`MnemFetcher`].
///
/// `BuildHasherDefault` creates every hasher through `AHasher::default()`.
/// NOTE(review): confirm that ahash's default keys are acceptable if any of
/// the hashed keys could ever be attacker-chosen.
pub type AHashBuilder = BuildHasherDefault<AHasher>;
7
/// Shortest mnemonic length, in words, that this module accepts.
const MIN_WORDS: usize = 12;
/// Longest mnemonic length, in words, that this module accepts.
const MAX_WORDS: usize = 24;

/// Reports whether `word_count` is *not* an accepted mnemonic length.
///
/// Accepted lengths are the multiples of three from [`MIN_WORDS`] through
/// [`MAX_WORDS`] inclusive, i.e. 12, 15, 18, 21 and 24.
fn is_invalid_word_count(word_count: usize) -> bool {
    let within_bounds = (MIN_WORDS..=MAX_WORDS).contains(&word_count);
    let divisible_by_three = word_count % 3 == 0;
    !(within_bounds && divisible_by_three)
}
14
15use thiserror::Error;
16
17#[derive(Error, Debug)]
18pub enum MnemFetchError {
19 #[error("Invalid word count: {0}")]
20 InvalidWordCount(usize),
21}
22
/// Collects BIP-39 mnemonics, either added directly or found by scanning a
/// sequence of words.
pub struct MnemFetcher<'a> {
    /// Every mnemonic collected so far by `add_one` and `add_from_words`.
    /// Being a set, a mnemonic found several times is stored once.
    pub gen_mnemonics: HashSet<bip39::Mnemonic, AHashBuilder>,
    /// The entries of `lang.word_list()` (filled in `new`); input words not
    /// in this set are discarded by `add_from_words` before scanning.
    wordlist: HashSet<&'a str, AHashBuilder>,
    /// Mnemonic lengths, in words, that `add_from_words` scans for.
    /// Starts as `[12, 24]` in `new`; changed through `set_word_ns`.
    word_ns: Vec<usize>,
    /// Language of the word list, also used when parsing candidate mnemonics.
    lang: bip39::Language,
}
41
42impl<'a> MnemFetcher<'a> {
43 pub fn new(lang: bip39::Language) -> Self {
44 let wordlist: HashSet<&'a str, AHashBuilder> =
45 lang.word_list().into_iter().map(|w| *w).collect();
46
47 MnemFetcher {
48 gen_mnemonics: HashSet::with_hasher(AHashBuilder::default()),
49 wordlist,
50 word_ns: vec![MIN_WORDS, MAX_WORDS],
51 lang,
52 }
53 }
54
55 pub fn add_one(&mut self, mnemonic: Mnemonic) {
57 self.gen_mnemonics.insert(mnemonic);
58 }
59
60 pub fn set_word_ns(&mut self, word_ns: Vec<usize>) -> Result<(), MnemFetchError> {
79 for wc in word_ns.iter() {
80 if is_invalid_word_count(*wc) {
81 return Err(MnemFetchError::InvalidWordCount(*wc));
82 }
83 }
84
85 self.word_ns = word_ns;
86
87 Ok(())
88 }
89
90 pub fn add_from_words(&mut self, words: &[&str]) -> usize {
114 let mut valid_words_str = String::with_capacity(words.len() * 10); let mut valid_words_start_ptr: Vec<usize> = Vec::with_capacity(words.len()); let mut valid_words_end_ptr: Vec<usize> = Vec::with_capacity(words.len()); self.gen_mnemonics.reserve(words.len() / 1000);
121
122 for w in words {
124 if !self.wordlist.contains(w) {
125 continue;
126 }
127
128 valid_words_start_ptr.push(valid_words_str.len());
129 valid_words_str.push_str(w);
130 valid_words_end_ptr.push(valid_words_str.len());
131 valid_words_str.push_str(" ");
132 }
133
134 let mut valid_mnemonics = 0;
135
136 for wc in self.word_ns.clone() {
137 if wc > valid_words_start_ptr.len() {
138 continue;
139 }
140 for start_at in 0..valid_words_start_ptr.len() - (wc - 1) {
141 if self.window_check(
142 &valid_words_str,
143 &valid_words_start_ptr,
144 &valid_words_end_ptr,
145 start_at,
146 wc,
147 ) {
148 valid_mnemonics += 1;
149 }
150 }
151 }
152
153 valid_mnemonics
154 }
155
156 fn window_check(
158 &mut self,
159 valid_words: &str,
160 valid_words_start_ptr: &[usize],
161 valid_words_end_ptr: &[usize],
162 start_at: usize,
163 wc: usize,
164 ) -> bool {
165 let start_index = valid_words_start_ptr[start_at];
166 let end_index = valid_words_end_ptr[start_at + wc - 1];
167
168 let mnemonic =
169 Mnemonic::parse_in_normalized(self.lang, &valid_words[start_index..end_index]);
170
171 if mnemonic.is_ok() {
172 self.gen_mnemonics.insert(mnemonic.unwrap());
173 return true;
174 }
175
176 return false;
177 }
178}
179
#[cfg(test)]
mod tests {
    use std::str::FromStr;

    use super::*;

    #[test]
    fn test_is_invalid_word_count() {
        assert!(is_invalid_word_count(11));
        assert!(is_invalid_word_count(13));
        assert!(is_invalid_word_count(25));
        assert!(!is_invalid_word_count(12));
        assert!(!is_invalid_word_count(15));
        assert!(!is_invalid_word_count(24));
    }

    const VALID_MNEMONIC: &str = "aware such neglect occur kick large parade crazy ceiling rain afraid mad canyon taxi group";

    #[test]
    fn test_mnem_fetch_set_word_ns_rejects_invalid_count() {
        let mut mf = MnemFetcher::new(bip39::Language::English);

        // The first unsupported length is the one reported ...
        assert!(matches!(
            mf.set_word_ns(vec![12, 13, 11]),
            Err(MnemFetchError::InvalidWordCount(13))
        ));
        // ... and the default lengths are left in place.
        assert_eq!(mf.word_ns, vec![MIN_WORDS, MAX_WORDS]);
    }

    #[test]
    fn test_mnem_fetch_add_one() {
        let mut mf = MnemFetcher::new(bip39::Language::English);

        let mnemonic = Mnemonic::from_str(VALID_MNEMONIC).unwrap();
        mf.add_one(mnemonic.clone());
        assert_eq!(mf.gen_mnemonics.len(), 1);
        assert!(mf.gen_mnemonics.contains(&mnemonic));

        // Adding the same mnemonic again does not create a second entry.
        mf.add_one(mnemonic);
        assert_eq!(mf.gen_mnemonics.len(), 1);
    }

    #[test]
    fn test_mnem_fetch_add_from_words() {
        let mut mf = MnemFetcher::new(bip39::Language::English);
        mf.set_word_ns(vec![12, 15, 18, 21, 24]).unwrap();

        let mut words: Vec<&str> = VALID_MNEMONIC.split_whitespace().collect();

        let mnemonics1 = mf.add_from_words(&words);
        // The full 15-word phrase is itself one of the scanned windows.
        assert!(mnemonics1 >= 1);
        assert!(mf
            .gen_mnemonics
            .contains(&Mnemonic::from_str(VALID_MNEMONIC).unwrap()));

        // Words outside the word list are dropped before scanning, so
        // trailing and leading junk must not change the result.
        words.extend(std::iter::repeat("aaaa").take(6));
        let mnemonics2 = mf.add_from_words(&words);

        words = std::iter::repeat("aaaa").take(4).chain(words).collect();
        let mnemonics3 = mf.add_from_words(&words);

        assert_eq!(mnemonics1, mnemonics2);
        assert_eq!(mnemonics2, mnemonics3);
    }

    #[test]
    fn test_mnem_fetch_add_from_words_invalid_mnemonic() {
        let mut mf = MnemFetcher::new(bip39::Language::English);

        // Twelve consecutive word-list words that do not form a valid mnemonic.
        let words = [
            "abandon", "ability", "able", "about", "above", "absent", "absorb", "abstract",
            "absurd", "abuse", "access", "accident",
        ];

        let mnemonics = mf.add_from_words(&words);
        assert_eq!(mnemonics, 0);

        let mut words: Vec<&str> = VALID_MNEMONIC.split_whitespace().collect();
        assert!(words.len() < 24);
        words.insert(words.len() / 2, "aaaaaa");

        // NOTE(review): `add_from_words` drops words outside the word list,
        // so "aaaaaa" does not split the phrase. The zero count relies on the
        // default lengths (12 and 24) not including 15: only 12-word
        // sub-windows of the phrase are parsed here.
        let mnemonics = mf.add_from_words(&words);
        assert_eq!(mnemonics, 0);
    }

    #[test]
    fn test_mnem_fetch_add_from_words_bulk() {
        let mut mf = MnemFetcher::new(bip39::Language::English);

        let _mnemonics = mf.add_from_words(bip39::Language::English.word_list());
        assert_eq!(mf.gen_mnemonics.len(), 137);
    }
}
271}