1use std::{collections::HashSet, str::FromStr};
2
3use bip39::Mnemonic;
4
/// Smallest word count accepted for a mnemonic.
const MIN_WORDS: usize = 12;
/// Largest word count accepted for a mnemonic.
const MAX_WORDS: usize = 24;

/// Reports whether `word_count` is NOT an accepted mnemonic length.
///
/// Accepted lengths lie in `MIN_WORDS..=MAX_WORDS` and are multiples of three
/// (12, 15, 18, 21, 24); every other value yields `true`.
fn is_invalid_word_count(word_count: usize) -> bool {
    let within_bounds = (MIN_WORDS..=MAX_WORDS).contains(&word_count);
    let multiple_of_three = word_count % 3 == 0;
    !(within_bounds && multiple_of_three)
}
11
12use thiserror::Error;
13
14#[derive(Error, Debug)]
15pub enum MnemFetchError {
16 #[error("Invalid word count: {0}")]
17 InvalidWordCount(usize),
18}
19
20pub struct MnemFetcher<'a> {
33 pub gen_mnemonics: Vec<Mnemonic>,
34 wordlist: HashSet<&'a str>,
35 word_ns: Vec<usize>,
36}
37
38impl<'a> MnemFetcher<'a> {
39 pub fn new(lang: bip39::Language) -> Self {
40 MnemFetcher {
41 gen_mnemonics: Vec::new(),
42 wordlist: lang.word_list().into_iter().map(|w| *w).collect(),
43 word_ns: vec![MIN_WORDS, MAX_WORDS],
44 }
45 }
46
47 pub fn add_one(&mut self, mnemonic: Mnemonic) {
49 self.gen_mnemonics.push(mnemonic);
50 }
51
52 pub fn set_word_ns(&mut self, word_ns: Vec<usize>) -> Result<(), MnemFetchError> {
71 for wc in word_ns.iter() {
72 if is_invalid_word_count(*wc) {
73 return Err(MnemFetchError::InvalidWordCount(*wc));
74 }
75 }
76
77 self.word_ns = word_ns;
78
79 Ok(())
80 }
81
82 pub fn add_from_words(&mut self, words: &[&str]) -> &[Mnemonic] {
106 let valid_words = words
109 .iter()
110 .filter(|w| self.wordlist.contains(**w))
111 .map(|w| *w)
112 .collect::<Vec<&str>>();
113
114 let mut valid_words_str = String::new(); let mut valid_words_ptr: Vec<usize> = vec![0; valid_words.len()]; for (i, w) in valid_words.iter().enumerate() {
120 valid_words_ptr[i] = valid_words_str.len();
121 valid_words_str.push_str(w);
122 valid_words_str.push_str(" ");
123 }
124 assert_eq!(valid_words.len(), valid_words_ptr.len());
125
126 let mut valid_mnemonics = vec![];
127
128 for wc in self.word_ns.iter() {
129 if *wc > valid_words_ptr.len() {
130 continue;
131 }
132 for start_at in 0..valid_words_ptr.len() - (wc - 1) {
133 MnemFetcher::window_check(
134 &valid_words_str,
135 &valid_words_ptr,
136 start_at,
137 *wc,
138 &mut valid_mnemonics,
139 );
140 }
141 }
142
143 valid_mnemonics.sort();
145 valid_mnemonics.dedup();
146
147 let vml = valid_mnemonics.len();
148
149 self.gen_mnemonics.extend(valid_mnemonics);
150
151 &self.gen_mnemonics[self.gen_mnemonics.len() - vml..]
152 }
153
154 fn window_check(
156 valid_words: &str,
157 valid_words_ptr: &[usize],
158 start_at: usize,
159 wc: usize,
160 valid_mnemonics: &mut Vec<Mnemonic>,
161 ) {
162 let start_index = valid_words_ptr[start_at];
163 let end_index = valid_words_ptr[start_at + wc - 1]
164 + valid_words[valid_words_ptr[start_at + wc - 1]..]
165 .find(" ")
166 .unwrap();
167
168 let mnemonic = Mnemonic::from_str(&valid_words[start_index..end_index]);
169
170 if mnemonic.is_ok() {
171 valid_mnemonics.push(mnemonic.unwrap());
172 }
173 }
174}
175
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_is_invalid_word_count() {
        assert!(is_invalid_word_count(11));
        assert!(is_invalid_word_count(13));
        assert!(is_invalid_word_count(25));
        assert!(!is_invalid_word_count(12));
        assert!(!is_invalid_word_count(15));
        assert!(!is_invalid_word_count(24));
    }

    /// A valid 15-word English mnemonic.
    const VALID_MNEMONIC: &str = "aware such neglect occur kick large parade crazy ceiling rain afraid mad canyon taxi group";

    #[test]
    fn test_mnem_fetch_add_one() {
        let mut mf = MnemFetcher::new(bip39::Language::English);

        let mnemonic = Mnemonic::from_str(VALID_MNEMONIC).unwrap();
        mf.add_one(mnemonic);

        assert_eq!(mf.gen_mnemonics.len(), 1);
    }

    #[test]
    fn test_mnem_fetch_add_from_words() {
        let mut mf = MnemFetcher::new(bip39::Language::English);
        mf.set_word_ns(vec![12, 15, 18, 21, 24]).unwrap();

        let mut words: Vec<&str> = VALID_MNEMONIC.split_whitespace().collect();

        let mnemonics1 = mf.add_from_words(&words)[0].clone();

        // Trailing tokens that are not in the wordlist must not change the result.
        words.extend(std::iter::repeat("aaaa").take(6));

        let mnemonics2 = mf.add_from_words(&words)[0].clone();

        // Neither must leading ones.
        let words: Vec<&str> = std::iter::repeat("aaaa").take(4).chain(words).collect();

        let mnemonics3 = mf.add_from_words(&words)[0].clone();

        assert_eq!(mnemonics1, mnemonics2);
        assert_eq!(mnemonics2, mnemonics3);
    }

    #[test]
    fn test_mnem_fetch_add_from_words_invalid_mnemonic() {
        let mut mf = MnemFetcher::new(bip39::Language::English);

        let words = [
            "abandon", "ability", "able", "about", "above", "absent", "absorb", "abstract",
            "absurd", "abuse", "access", "accident",
        ];

        let mnemonics = mf.add_from_words(&words);

        assert!(mnemonics.is_empty());

        let mut words: Vec<&str> = VALID_MNEMONIC.split_whitespace().collect();
        assert!(words.len() < 24);
        // NOTE(review): "aaaaaa" is not in the wordlist, so `add_from_words`
        // drops it before windowing and it does not split the phrase. The
        // empty result here instead comes from the default word counts
        // (12 and 24): the 15-word phrase is never tried whole.
        words.insert(words.len() / 2, "aaaaaa");

        let mnemonics = mf.add_from_words(&words);
        assert!(mnemonics.is_empty());
    }

    #[test]
    fn test_mnem_fetch_add_from_words_bulk() {
        let mut mf = MnemFetcher::new(bip39::Language::English);

        // Every 12- and 24-word window of the in-order English wordlist is
        // tried; 137 is the pinned number of distinct windows that parse.
        let _mnemonics = mf.add_from_words(bip39::Language::English.word_list());
        assert_eq!(mf.gen_mnemonics.len(), 137);
    }
}