sentence2vec/word2vec.rs

#[cfg(feature = "loading")]
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
#[cfg(feature = "loading")]
use std::path::Path;

/// Word vector of dimension D.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "loading", derive(Serialize, Deserialize))]
pub struct WordVec<const D: usize> {
    vec: Vec<f32>,
}

impl<const D: usize> WordVec<D> {
    /// Create a new WordVec from an array of dimension D.
    pub fn new(vec: [f32; D]) -> Self {
        Self { vec: vec.to_vec() }
    }

    /// Get the vector as a fixed-size array reference.
    pub fn get_vec(&self) -> &[f32; D] {
        // This is safe because the inner vector always has length D.
        self.vec.as_slice().try_into().unwrap()
    }

    /// Calculate the cosine similarity of two vectors.
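    ///
    /// This is dot(a, b) / (||a|| * ||b||); it is 1.0 for identical directions, 0.0 for
    /// orthogonal vectors, and NaN if either vector has norm 0.
    ///
    /// # Example
    ///
    /// A minimal sketch (the import path is assumed and may need adjusting to the crate layout):
    ///
    /// ```ignore
    /// use sentence2vec::word2vec::WordVec;
    ///
    /// let a = WordVec::new([1.0, 0.0]);
    /// let b = WordVec::new([0.0, 1.0]);
    /// // Orthogonal vectors have a cosine similarity of 0.
    /// assert_eq!(a.cosine(&b), 0.0);
    /// // A vector compared with itself has a cosine similarity of 1.
    /// assert_eq!(a.cosine(&a), 1.0);
    /// ```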
    pub fn cosine(&self, vec: &Self) -> f32 {
        let mut dot = 0.0;
        let mut norm1 = 0.0;
        let mut norm2 = 0.0;

        for (v1, v2) in self.vec.iter().zip(vec.vec.iter()) {
            dot += v1 * v2;
            norm1 += v1 * v1;
            norm2 += v2 * v2;
        }

        if norm1 == 0.0 || norm2 == 0.0 {
            eprintln!("Warning: the norm of a vector is 0, the result will be NaN.");
        }

        // norm1 and norm2 hold the squared norms, so cos(theta) = dot / sqrt(norm1 * norm2).
        dot / (norm1 * norm2).sqrt()
    }
}

/// Common interface for looking up the vector of a word.
pub trait Word2VecTrait<const D: usize> {
    fn get_vec(&self, word: &str) -> Option<&WordVec<D>>;
}

/// Word2Vec model.
/// Contains a map from words to vectors.
/// D is the dimension of the vectors.
#[cfg_attr(feature = "loading", derive(Serialize, Deserialize))]
pub struct Word2Vec<const D: usize> {
    word_vecs: HashMap<String, WordVec<D>>,
}

impl<const D: usize> Word2Vec<D> {
    /// Load the word2vec model from a text file with lines of the form:
    /// word1 0.1 0.2 0.3 ... 0.1
    /// where word1 is the word and the remaining values are its vector of dimension D.
    /// The first line of the file is treated as a header and is skipped.
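    ///
    /// # Example
    ///
    /// A minimal sketch, assuming a file `tests/word2vec.txt` whose data lines hold
    /// 3-dimensional vectors (the path and crate layout are illustrative):
    ///
    /// ```ignore
    /// use sentence2vec::word2vec::Word2Vec;
    ///
    /// # async fn run() {
    /// // The const parameter must match the dimension of the vectors in the file.
    /// let model: Word2Vec<3> = Word2Vec::load_from_txt("tests/word2vec.txt").await.unwrap();
    /// # }
    /// ```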
    #[cfg(feature = "loading")]
    pub async fn load_from_txt<P>(path: P) -> Option<Self>
    where
        P: AsRef<Path>,
    {
        let mut word_vecs = HashMap::new();
        let lines = crate::file::read_lines(path);
        if let Ok(lines) = lines {
            for line in lines.skip(1).flatten() {
                let mut iter = line.split_whitespace();
                if let Some(word) = iter.next() {
                    let vec = iter.flat_map(|s| s.parse::<f32>()).collect::<Vec<_>>();
                    if vec.len() == D {
                        // This is safe because the vector has exactly D elements.
                        word_vecs.insert(word.to_string(), WordVec::new(vec.try_into().unwrap()));
                    } else {
                        eprintln!("The vector of {} is not of dimension {}, so it was not inserted.", word, D);
                    }
                }
            }
            Some(Self { word_vecs })
        } else {
            None
        }
    }

    /// Create a new Word2Vec from a map of words to vectors.
    pub async fn from_word_vecs(word_vecs: HashMap<String, WordVec<D>>) -> Self {
        Self { word_vecs }
    }

    /// Save the word2vec model to a binary file using bincode serialization.
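    ///
    /// # Example
    ///
    /// A round-trip sketch with `load_from_bytes` (paths and crate layout are illustrative):
    ///
    /// ```ignore
    /// use sentence2vec::word2vec::Word2Vec;
    ///
    /// # async fn run(model: Word2Vec<3>) {
    /// model.save_to_bytes("tests/word2vec.bin").await.unwrap();
    /// let reloaded: Word2Vec<3> = Word2Vec::load_from_bytes("tests/word2vec.bin").await.unwrap();
    /// # }
    /// ```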
    #[cfg(feature = "loading")]
    pub async fn save_to_bytes<P>(&self, path: P) -> Result<(), Box<dyn std::error::Error>>
    where
        P: AsRef<Path>,
    {
        let mut bytes = Vec::new();
        bincode::serialize_into(&mut bytes, &self)?;
        std::fs::write(path, bytes)?;
        Ok(())
    }

    /// Load the word2vec model from a bincode-serialized binary file, as written by `save_to_bytes`.
    #[cfg(feature = "loading")]
    pub async fn load_from_bytes<P>(path: P) -> Option<Self>
    where
        P: AsRef<Path>,
    {
        let bytes = std::fs::read(path).ok()?;
        bincode::deserialize(&bytes).ok()
    }

    /// Calculate the cosine similarity of two words.
    /// Returns None if either word is not in the model.
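    ///
    /// # Example
    ///
    /// A minimal sketch built from an in-memory map (the import path is illustrative):
    ///
    /// ```ignore
    /// use std::collections::HashMap;
    /// use sentence2vec::word2vec::{Word2Vec, WordVec};
    ///
    /// # async fn run() {
    /// let mut word_vecs = HashMap::new();
    /// word_vecs.insert("word1".to_string(), WordVec::new([1.0, 0.0]));
    /// word_vecs.insert("word2".to_string(), WordVec::new([0.0, 1.0]));
    /// let model = Word2Vec::from_word_vecs(word_vecs).await;
    ///
    /// assert_eq!(model.cosine("word1", "word2"), Some(0.0));
    /// assert_eq!(model.cosine("word1", "missing"), None);
    /// # }
    /// ```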
    pub fn cosine(&self, word1: &str, word2: &str) -> Option<f32> {
        let vec1 = self.get_vec(word1)?;
        let vec2 = self.get_vec(word2)?;

        Some(vec1.cosine(vec2))
    }

    #[cfg(feature = "partition")]
    /// Partition the word2vec model into `f` folders containing a total of `n` files.
    /// The words are sorted alphabetically and distributed evenly.
    /// Files and folders are named after the first word they contain.
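    ///
    /// For example (hypothetical words and sizes), partitioning a sorted vocabulary
    /// `["a", "b", "c", "d"]` with `n = 2` files and `f = 1` folder is expected to
    /// produce a layout like:
    ///
    /// ```text
    /// dist/word2vec/
    /// └── a/
    ///     ├── a.bin   (contains "a", "b")
    ///     └── c.bin   (contains "c", "d")
    /// ```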
    pub async fn partition<P>(&self, dist: P, n: usize, f: usize) -> Result<(), Box<dyn std::error::Error>>
    where
        P: AsRef<Path>,
    {
        use log::{info, trace};

        info!("Partitioning into {} folders and {} files", f, n);
        let mut dist = dist.as_ref().to_path_buf();
        dist.push("word2vec");
        std::fs::create_dir_all(&dist)?;

        // Sort the words alphabetically.
        let mut words = self.word_vecs.keys().collect::<Vec<_>>();

        info!("Sorting {} words", words.len());
        words.sort();
        info!("Done sorting");

        // Calculate the number of words per file and per folder.
        // The guard against 0 avoids a division by zero when n or f exceeds the word count.
        let words_per_file = (words.len() / n).max(1);
        let words_per_folder = (words.len() / f).max(1);

        // Create the folders and write the files.
        let mut current_map: HashMap<String, WordVec<D>> = HashMap::new();
        let mut current_folder = dist.clone();
        let mut first_word_of_file = "";
        for (i, word) in words.iter().enumerate() {
            if i % words_per_folder == 0 {
                current_folder = dist.clone();
                current_folder.push(words[i]);
                std::fs::create_dir_all(&current_folder)?;
                trace!("Created folder {}", current_folder.display());
            }

            // Remember the first word of the current chunk so the file can be named after it.
            if current_map.is_empty() {
                first_word_of_file = word.as_str();
            }

            if let Some(vec) = self.get_vec(word) {
                current_map.insert(word.to_string(), vec.clone());
            }

            // Flush a file once it holds words_per_file words, or at the last word.
            if (i + 1) % words_per_file == 0 || i == words.len() - 1 {
                let mut file = current_folder.clone();
                file.push(first_word_of_file);
                file.set_extension("bin");
                let mut bytes = Vec::new();
                bincode::serialize_into(&mut bytes, &current_map)?;
                std::fs::write(&file, bytes)?;
                current_map.clear();
                trace!("Created file {}", file.display());
            }
        }
        Ok(())
    }

    /// Get the subset of words from `words` that are in the model.
    /// Words that are not in the model are ignored.
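    ///
    /// # Example
    ///
    /// A minimal sketch (the import path is illustrative):
    ///
    /// ```ignore
    /// # use sentence2vec::word2vec::Word2Vec;
    /// # async fn run(model: Word2Vec<3>) {
    /// // Only the words that exist in `model` end up in the subset.
    /// let subset = model.get_subset(&["word1".to_string(), "missing".to_string()]).await;
    /// # }
    /// ```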
    pub async fn get_subset(&self, words: &[String]) -> Word2Vec<D> {
        let mut word_vecs = HashMap::new();
        for word in words {
            if let Some(vec) = self.get_vec(word) {
                word_vecs.insert(word.to_string(), vec.clone());
            }
        }
        Self { word_vecs }
    }

    /// Get the subset of words that are in the model from a word list file (one word per line).
    /// Words that are not in the model are ignored.
    #[cfg(feature = "loading")]
    pub async fn get_subset_from_wordlist<P>(&self, path: P) -> Result<Word2Vec<D>, Box<dyn std::error::Error>>
    where
        P: AsRef<Path>,
    {
        let mut word_vecs = HashMap::new();
        for word in crate::file::read_lines(path)?.flatten() {
            if let Some(vec) = self.get_vec(&word) {
                word_vecs.insert(word.to_string(), vec.clone());
            }
        }
        Ok(Self { word_vecs })
    }
}

impl<const D: usize> Word2VecTrait<D> for Word2Vec<D> {
    /// Get the vector of a word.
    fn get_vec(&self, word: &str) -> Option<&WordVec<D>> {
        self.word_vecs.get(word)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_wordvec() {
        let vec1 = WordVec::new([1.0, 0.0]);
        let vec2 = WordVec::new([1.0, 0.0]);
        let vec3 = WordVec::new([0.0, 1.0]);

        // Get the vector as an array reference.
        assert_eq!(vec1.get_vec(), &[1.0, 0.0]);

        // Calculate the cosine similarity of two vectors.
        assert_eq!(vec1.cosine(&vec2), 1.0);
        assert_eq!(vec1.cosine(&vec3), 0.0);

        // A zero vector has no defined cosine similarity, so the result is NaN.
        let vec1 = WordVec::new([0.0, 0.0]);
        assert!(vec1.cosine(&vec2).is_nan());
    }

    #[tokio::test]
    async fn test_word2vec() {
        let mut word_vecs = HashMap::new();
        word_vecs.insert("word1".to_string(), WordVec::new([1.0, 0.0]));
        word_vecs.insert("word2".to_string(), WordVec::new([0.0, 1.0]));
        let word2vec = Word2Vec::from_word_vecs(word_vecs).await;

        assert_eq!(word2vec.cosine("word1", "word2").unwrap(), 0.0);
        assert_eq!(word2vec.cosine("word1", "word1").unwrap(), 1.0);
    }

    #[tokio::test]
    async fn test_word2vec_subset() {
        let mut word_vecs = HashMap::new();
        word_vecs.insert("word1".to_string(), WordVec::new([1.0, 2.0, 3.0]));
        word_vecs.insert("word2".to_string(), WordVec::new([1.0, 2.0, 4.0]));
        let word2vec = Word2Vec::from_word_vecs(word_vecs).await;

        let subset = word2vec.get_subset(&["word1".to_string(), "word3".to_string()]).await;

        assert_eq!(subset.word_vecs.len(), 1);
        assert_eq!(subset.word_vecs.get("word1").unwrap().get_vec(), &[1.0, 2.0, 3.0]);
    }

    #[cfg(feature = "loading")]
    #[tokio::test]
    async fn test_word2vec_load() {
        let word2vec: Word2Vec<3> = Word2Vec::load_from_txt("tests/word2vec.txt").await.unwrap();

        // Loading a file that does not exist returns None.
        assert!(Word2Vec::<30>::load_from_txt("tests/word2v9+65ds6d5ec.txt").await.is_none());

        assert_eq!(word2vec.word_vecs.len(), 5);

        assert_eq!(word2vec.cosine("chien", "chat").unwrap(), 0.0);
    }

    #[cfg(feature = "loading")]
    #[tokio::test]
    async fn test_save_and_load_from_byte() {
        let word2vec: Word2Vec<3> = Word2Vec::load_from_txt("tests/word2vec.txt").await.unwrap();
        assert!(word2vec.save_to_bytes("tests/word2vec.bin").await.is_ok());
        test_load_from_byte().await;
    }

    #[cfg(feature = "loading")]
    async fn test_load_from_byte() {
        let word2vec: Word2Vec<3> = Word2Vec::load_from_txt("tests/word2vec.txt").await.unwrap();
        assert!(Word2Vec::<3>::load_from_bytes("tests/word2vec.bin").await.is_some());

        // Check that the two models are the same.
        let word2vec2: Word2Vec<3> = Word2Vec::load_from_bytes("tests/word2vec.bin").await.unwrap();
        assert_eq!(word2vec.word_vecs.len(), word2vec2.word_vecs.len());

        for (word, vec) in word2vec.word_vecs.iter() {
            assert_eq!(word2vec2.word_vecs.get(word).unwrap().get_vec(), vec.get_vec());
        }

        // Loading a file that is not valid bincode returns None.
        assert!(Word2Vec::<3>::load_from_bytes("tests/word2vec.txt").await.is_none());
    }
}