bit_digger/
mnem_fetch.rs

1use std::{collections::HashSet, str::FromStr};
2
3use bip39::Mnemonic;
4
/// Smallest number of words a mnemonic may consist of.
const MIN_WORDS: usize = 12;
/// Largest number of words a mnemonic may consist of.
const MAX_WORDS: usize = 24;

/// Returns `true` when `word_count` is NOT an accepted mnemonic length.
///
/// A length is accepted only if it lies within `MIN_WORDS..=MAX_WORDS` and is a
/// multiple of three.
fn is_invalid_word_count(word_count: usize) -> bool {
    let within_bounds = (MIN_WORDS..=MAX_WORDS).contains(&word_count);
    let multiple_of_three = word_count % 3 == 0;

    !(within_bounds && multiple_of_three)
}
11
12use thiserror::Error;
13
14#[derive(Error, Debug)]
15pub enum MnemFetchError {
16    #[error("Invalid word count: {0}")]
17    InvalidWordCount(usize),
18}
19
20/// Struct used to fetch mnemonics from different sources
21///
22/// ## Fields
23/// `gen_mnemonics`: A vector of mnemonics that have been generated
24/// `wordlist`: A HashSet of all the words in the wordlist used to discover the mnemonics
25/// `word_ns`: A vector of the valid mnemonic lengths
26///
27/// ## Methods
28/// `new(lang: bip39::Language) -> Self`: Creates a new MnemFetcher with the given language
29/// `add_one(mnemonic: Mnemonic)`: Adds a single mnemonic to the internal collection
30/// `set_word_ns(word_ns: Vec<usize>)`: Sets the valid mnemonic lengths
31/// `add_from_words(words: &[&str]) -> &[Mnemonic]`: Creates mnemonics from the given words and adds them to the internal collection
32pub struct MnemFetcher<'a> {
33    pub gen_mnemonics: Vec<Mnemonic>,
34    wordlist: HashSet<&'a str>,
35    word_ns: Vec<usize>,
36}
37
38impl<'a> MnemFetcher<'a> {
39    pub fn new(lang: bip39::Language) -> Self {
40        MnemFetcher {
41            gen_mnemonics: Vec::new(),
42            wordlist: lang.word_list().into_iter().map(|w| *w).collect(),
43            word_ns: vec![MIN_WORDS, MAX_WORDS],
44        }
45    }
46
47    /// Just add one already created mnemonic
48    pub fn add_one(&mut self, mnemonic: Mnemonic) {
49        self.gen_mnemonics.push(mnemonic);
50    }
51
52    /// Set word_ns
53    ///
54    /// # Description
55    /// Sets the valid mnemonic lengths
56    ///
57    /// # Arguments
58    /// - `word_ns`: A vector of the valid mnemonic lengths
59    ///
60    /// # Returns
61    /// - Error with the first invalid word count
62    /// - Ok if all word counts are valid
63    ///
64    /// # Example
65    /// ```rust
66    /// use bit_digger::mnem_fetch::MnemFetcher;
67    /// let mut mf = MnemFetcher::new(bip39::Language::English);
68    /// mf.set_word_ns(vec![12, 15, 18, 21, 24]).unwrap();
69    /// ```
70    pub fn set_word_ns(&mut self, word_ns: Vec<usize>) -> Result<(), MnemFetchError> {
71        for wc in word_ns.iter() {
72            if is_invalid_word_count(*wc) {
73                return Err(MnemFetchError::InvalidWordCount(*wc));
74            }
75        }
76
77        self.word_ns = word_ns;
78
79        Ok(())
80    }
81
82    /// Create mnemonics from `words` and add them to internal collection
83    ///
84    /// # Description
85    /// Tries to create mnemonics from the given words.
86    /// - It will take a sequence of n `words` that can generate a valid mnemonic, where n is any of the valid mnemonic lengths, in the given order.
87    ///
88    /// # Arguments
89    /// - `words`: A list of words that should be used to generate the mnemonics.
90    ///
91    /// # Returns
92    /// - It will return a reference to the added mnemonics slice, if no mnemonic was generated the slice will be empty.
93    ///
94    /// # Example
95    /// ```rust
96    /// use bit_digger::mnem_fetch::MnemFetcher;
97    /// let mut mf = MnemFetcher::new(bip39::Language::English);
98    /// let invalid_words = vec![
99    ///    "abandon", "ability", "able", "about", "above", "absent", "absorb", "abstract",
100    ///   "absurd", "abuse", "access", "accident",
101    /// ];
102    /// let mnemonics = mf.add_from_words(&invalid_words);
103    /// assert_eq!(mnemonics.len(), 0);
104    /// ```
105    pub fn add_from_words(&mut self, words: &[&str]) -> &[Mnemonic] {
106        // let words = words.into_iter().filter(|w| self.wordlist.contains(w));
107
108        let valid_words = words
109            .iter()
110            .filter(|w| self.wordlist.contains(**w))
111            .map(|w| *w)
112            .collect::<Vec<&str>>();
113
114        let mut valid_words_str = String::new(); // String that contains all the valid words
115        let mut valid_words_ptr: Vec<usize> = vec![0; valid_words.len()]; // Pointer to the start of each (word )
116        //                                                                                                (^    )
117
118        // Construct the words String along with the pointers
119        for (i, w) in valid_words.iter().enumerate() {
120            valid_words_ptr[i] = valid_words_str.len();
121            valid_words_str.push_str(w);
122            valid_words_str.push_str(" ");
123        }
124        assert_eq!(valid_words.len(), valid_words_ptr.len());
125
126        let mut valid_mnemonics = vec![];
127
128        for wc in self.word_ns.iter() {
129            if *wc > valid_words_ptr.len() {
130                continue;
131            }
132            for start_at in 0..valid_words_ptr.len() - (wc - 1) {
133                MnemFetcher::window_check(
134                    &valid_words_str,
135                    &valid_words_ptr,
136                    start_at,
137                    *wc,
138                    &mut valid_mnemonics,
139                );
140            }
141        }
142
143        // Only keep unique mnemonics
144        valid_mnemonics.sort();
145        valid_mnemonics.dedup();
146
147        let vml = valid_mnemonics.len();
148
149        self.gen_mnemonics.extend(valid_mnemonics);
150
151        &self.gen_mnemonics[self.gen_mnemonics.len() - vml..]
152    }
153
154    /// Internal function to check wether a &str slice contains a valid mnemonic of `wc` words
155    fn window_check(
156        valid_words: &str,
157        valid_words_ptr: &[usize],
158        start_at: usize,
159        wc: usize,
160        valid_mnemonics: &mut Vec<Mnemonic>,
161    ) {
162        let start_index = valid_words_ptr[start_at];
163        let end_index = valid_words_ptr[start_at + wc - 1]
164            + valid_words[valid_words_ptr[start_at + wc - 1]..]
165                .find(" ")
166                .unwrap();
167
168        let mnemonic = Mnemonic::from_str(&valid_words[start_index..end_index]);
169
170        if mnemonic.is_ok() {
171            valid_mnemonics.push(mnemonic.unwrap());
172        }
173    }
174}
175
#[cfg(test)]
mod tests {
    use super::*;

    /// 15-word mnemonic that the tests below expect to parse successfully.
    const VALID_MNEMONIC: &str = "aware such neglect occur kick large parade crazy ceiling rain afraid mad canyon taxi group";

    /// The words of `VALID_MNEMONIC`, in order.
    fn valid_mnemonic_words() -> Vec<&'static str> {
        VALID_MNEMONIC.split_whitespace().collect()
    }

    #[test]
    fn test_is_invalid_word_count() {
        // Too short, not a multiple of three, too long.
        assert!(is_invalid_word_count(11));
        assert!(is_invalid_word_count(13));
        assert!(is_invalid_word_count(25));

        // Lower bound, an inner value, upper bound.
        assert!(!is_invalid_word_count(12));
        assert!(!is_invalid_word_count(15));
        assert!(!is_invalid_word_count(24));
    }

    #[test]
    fn test_mnem_fetch_add_one() {
        let mut mf = MnemFetcher::new(bip39::Language::English);

        mf.add_one(Mnemonic::from_str(VALID_MNEMONIC).unwrap());

        assert_eq!(mf.gen_mnemonics.len(), 1);
    }

    #[test]
    fn test_mnem_fetch_add_from_words() {
        let mut mf = MnemFetcher::new(bip39::Language::English);
        mf.set_word_ns(vec![12, 15, 18, 21, 24]).unwrap();

        // The bare mnemonic.
        let plain = valid_mnemonic_words();
        let mnemonics1 = mf.add_from_words(&plain)[0].clone();

        // The mnemonic followed by six words that are not in the wordlist.
        let mut trailing_junk = plain;
        trailing_junk.extend(std::iter::repeat("aaaa").take(6));
        let mnemonics2 = mf.add_from_words(&trailing_junk)[0].clone();

        // Additionally preceded by four words that are not in the wordlist.
        let surrounded: Vec<&str> = std::iter::repeat("aaaa")
            .take(4)
            .chain(trailing_junk)
            .collect();
        let mnemonics3 = mf.add_from_words(&surrounded)[0].clone();

        assert_eq!(mnemonics1, mnemonics2);
        assert_eq!(mnemonics2, mnemonics3);
    }

    #[test]
    fn test_mnem_fetch_add_from_words_invalid_mnemonic() {
        let mut mf = MnemFetcher::new(bip39::Language::English);

        // Twelve wordlist words that are expected not to form a mnemonic.
        let words = vec![
            "abandon", "ability", "able", "about", "above", "absent", "absorb", "abstract",
            "absurd", "abuse", "access", "accident",
        ];
        assert_eq!(mf.add_from_words(&words).len(), 0);

        // A mnemonic with a non-wordlist word in its middle. That word is
        // filtered out before the search, so with the default lengths (12 and
        // 24) this expects that no 12-word window of the 15 words parses.
        let mut words = valid_mnemonic_words();
        assert!(words.len() < 24); // Test makes no sense if we have 24 words
        let middle = words.len() / 2;
        words.insert(middle, "aaaaaa");

        assert_eq!(mf.add_from_words(&words).len(), 0);
    }

    #[test]
    fn test_mnem_fetch_add_from_words_bulk() {
        let mut mf = MnemFetcher::new(bip39::Language::English);

        // Feed the whole English wordlist, in its own order, through the search.
        mf.add_from_words(bip39::Language::English.word_list());

        assert_eq!(mf.gen_mnemonics.len(), 137);
    }
}