#[cfg(feature = "loading")]
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
#[cfg(feature = "loading")]
use std::path::Path;

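/// A single word embedding of fixed dimension `D`.
///
/// The values are stored in a `Vec<f32>` rather than a `[f32; D]`, presumably
/// to keep the serde derives simple; `new` guarantees the length is always `D`.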
#[derive(Debug, Clone)]
#[cfg_attr(feature = "loading", derive(Serialize, Deserialize))]
pub struct WordVec<const D: usize> {
    vec: Vec<f32>,
}

impl<const D: usize> WordVec<D> {
    pub fn new(vec: [f32; D]) -> Self {
        Self { vec: vec.to_vec() }
    }

    pub fn get_vec(&self) -> &[f32; D] {
        // `vec` is always built from a `[f32; D]` (see `new`), so this
        // conversion cannot fail.
        self.vec.as_slice().try_into().unwrap()
    }

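    /// Returns the cosine similarity between `self` and `vec`:
    /// `dot(a, b) / (|a| * |b|)`, a value in `[-1.0, 1.0]`.
    ///
    /// If either vector has a zero norm, a warning is printed and the
    /// result is `NaN`.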
    pub fn cosine(&self, vec: &Self) -> f32 {
        let mut dot = 0.0;
        let mut norm1 = 0.0;
        let mut norm2 = 0.0;

        for (v1, v2) in self.vec.iter().zip(vec.vec.iter()) {
            dot += v1 * v2;
            norm1 += v1 * v1;
            norm2 += v2 * v2;
        }

        if norm1 == 0.0 || norm2 == 0.0 {
            eprintln!("Warning: the norm of a vector is 0, so the cosine is NaN.");
        }

        // `norm1` and `norm2` hold the *squared* norms, hence the single sqrt.
        dot / (norm1 * norm2).sqrt()
    }
}

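/// Anything that can look up the embedding of a word.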
pub trait Word2VecTrait<const D: usize> {
    fn get_vec(&self, word: &str) -> Option<&WordVec<D>>;
}

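/// A word2vec model: a map from words to their embeddings of dimension `D`.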
#[cfg_attr(feature = "loading", derive(Serialize, Deserialize))]
pub struct Word2Vec<const D: usize> {
    word_vecs: HashMap<String, WordVec<D>>,
}

impl<const D: usize> Word2Vec<D> {
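    /// Loads a model from a word2vec text file: a header line followed by one
    /// `word v1 v2 ... vD` line per word. Returns `None` if the file cannot be
    /// read; words whose vectors are not of dimension `D` are skipped.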
    #[cfg(feature = "loading")]
    pub async fn load_from_txt<P>(path: P) -> Option<Self>
    where
        P: AsRef<Path>,
    {
        let mut word_vecs = HashMap::new();
        let lines = crate::file::read_lines(path);
        if let Ok(lines) = lines {
            // Skip the header line (vocabulary size and dimension).
            for line in lines.skip(1).flatten() {
                let mut iter = line.split_whitespace();
                if let Some(word) = iter.next() {
                    let vec = iter.filter_map(|s| s.parse::<f32>().ok()).collect::<Vec<_>>();
                    if vec.len() == D {
                        word_vecs.insert(word.to_string(), WordVec::new(vec.try_into().unwrap()));
                    } else {
                        eprintln!("The vector of {} is not of dimension {}, so it wasn't inserted.", word, D);
                    }
                }
            }
            Some(Self { word_vecs })
        } else {
            None
        }
    }

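    /// Builds a model directly from an existing word-to-vector map.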
    pub async fn from_word_vecs(word_vecs: HashMap<String, WordVec<D>>) -> Self {
        Self { word_vecs }
    }

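    /// Serializes the whole model with bincode and writes it to `path`.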
    #[cfg(feature = "loading")]
    pub async fn save_to_bytes<P>(&self, path: P) -> Result<(), Box<dyn std::error::Error>>
    where
        P: AsRef<Path>,
    {
        let mut bytes = Vec::new();
        bincode::serialize_into(&mut bytes, &self)?;
        std::fs::write(path, bytes)?;
        Ok(())
    }

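    /// Loads a model previously written by `save_to_bytes`. Returns `None`
    /// if the file cannot be read or deserialized.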
    #[cfg(feature = "loading")]
    pub async fn load_from_bytes<P>(path: P) -> Option<Self>
    where
        P: AsRef<Path>,
    {
        let bytes = std::fs::read(path).ok()?;
        bincode::deserialize(&bytes).ok()
    }

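    /// Returns the cosine similarity between the vectors of `word1` and
    /// `word2`, or `None` if either word is not in the model.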
    pub fn cosine(&self, word1: &str, word2: &str) -> Option<f32> {
        let vec1 = self.get_vec(word1)?;
        let vec2 = self.get_vec(word2)?;

        Some(vec1.cosine(vec2))
    }

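    /// Partitions the model on disk under `dist/word2vec`: the vocabulary is
    /// sorted, split into `f` folders (each named after its first word), and
    /// written as roughly `n` bincode files (each named after the last word
    /// it contains).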
    #[cfg(feature = "partition")]
    pub async fn partition<P>(&self, dist: P, n: usize, f: usize) -> Result<(), Box<dyn std::error::Error>>
    where
        P: AsRef<Path>,
    {
        use log::{info, trace};

        info!("Partitioning into {} folders and {} files", f, n);
        let mut dist = dist.as_ref().to_path_buf();
        dist.push("word2vec");
        std::fs::create_dir_all(&dist)?;

        let mut words = self.word_vecs.keys().collect::<Vec<_>>();

        info!("Sorting {} words", words.len());
        words.sort();
        info!("Done sorting");
        // Guard against a chunk size of 0 (and the `% 0` panics below) when
        // `n` or `f` exceeds the number of words.
        let words_per_file = (words.len() / n).max(1);
        let words_per_folder = (words.len() / f).max(1);

        let mut current_map: HashMap<String, WordVec<D>> = HashMap::new();
        let mut current_folder = dist.clone();
        for (i, word) in words.iter().enumerate() {
            // Start a new folder, named after its first word, every
            // `words_per_folder` words.
            if i % words_per_folder == 0 {
                current_folder = dist.clone();
                current_folder.push(words[i]);
                std::fs::create_dir_all(&current_folder)?;
                trace!("Created folder {}", current_folder.display());
            }

            if let Some(vec) = self.get_vec(word) {
                current_map.insert(word.to_string(), vec.clone());
            }

            // Flush a file, named after the last word it contains, once it is
            // full or once the last word has been reached.
            if (i + 1) % words_per_file == 0 || i == words.len() - 1 {
                let mut file = current_folder.clone();
                file.push(words[i]);
                file.set_extension("bin");
                let mut bytes = Vec::new();
                bincode::serialize_into(&mut bytes, &current_map)?;
                std::fs::write(&file, bytes)?;
                current_map.clear();
                trace!("Created file {}", file.display());
            }
        }
        Ok(())
    }

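    /// Returns a new model restricted to the given words; words that are not
    /// in the model are silently ignored.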
    pub async fn get_subset(&self, words: &[String]) -> Word2Vec<D> {
        let mut word_vecs = HashMap::new();
        for word in words {
            if let Some(vec) = self.get_vec(word) {
                word_vecs.insert(word.to_string(), vec.clone());
            }
        }
        Self { word_vecs }
    }

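    /// Like `get_subset`, but reads the word list from a file with one word
    /// per line.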
    #[cfg(feature = "loading")]
    pub async fn get_subset_from_wordlist<P>(&self, path: P) -> Result<Word2Vec<D>, Box<dyn std::error::Error>>
    where
        P: AsRef<Path>,
    {
        let mut word_vecs = HashMap::new();
        for word in crate::file::read_lines(path)?.flatten() {
            if let Some(vec) = self.get_vec(&word) {
                word_vecs.insert(word.to_string(), vec.clone());
            }
        }
        Ok(Self { word_vecs })
    }
}

impl<const D: usize> Word2VecTrait<D> for Word2Vec<D> {
    fn get_vec(&self, word: &str) -> Option<&WordVec<D>> {
        self.word_vecs.get(word)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_wordvec() {
        let vec1 = WordVec::new([1.0, 0.0]);
        let vec2 = WordVec::new([1.0, 0.0]);
        let vec3 = WordVec::new([0.0, 1.0]);

        assert_eq!(vec1.get_vec(), &[1.0, 0.0]);

        assert_eq!(vec1.cosine(&vec2), 1.0);
        assert_eq!(vec1.cosine(&vec3), 0.0);

        // A zero vector has no defined direction, so the cosine is NaN.
        let vec1 = WordVec::new([0.0, 0.0]);
        assert!(vec1.cosine(&vec2).is_nan());
    }

    #[tokio::test]
    async fn test_word2vec() {
        let mut word_vecs = HashMap::new();
        word_vecs.insert("word1".to_string(), WordVec::new([1.0, 0.0]));
        word_vecs.insert("word2".to_string(), WordVec::new([0.0, 1.0]));
        let word2vec = Word2Vec::from_word_vecs(word_vecs).await;

        assert_eq!(word2vec.cosine("word1", "word2").unwrap(), 0.0);
        assert_eq!(word2vec.cosine("word1", "word1").unwrap(), 1.0);
    }

    #[tokio::test]
    async fn test_word2vec_subset() {
        let mut word_vecs = HashMap::new();
        word_vecs.insert("word1".to_string(), WordVec::new([1.0, 2.0, 3.0]));
        word_vecs.insert("word2".to_string(), WordVec::new([1.0, 2.0, 4.0]));
        let word2vec = Word2Vec::from_word_vecs(word_vecs).await;

        // "word3" is not in the model, so only "word1" ends up in the subset.
        let subset = word2vec.get_subset(&["word1".to_string(), "word3".to_string()]).await;

        assert_eq!(subset.word_vecs.len(), 1);
        assert_eq!(subset.word_vecs.get("word1").unwrap().get_vec(), &[1.0, 2.0, 3.0]);
    }

    #[cfg(feature = "loading")]
    #[tokio::test]
    async fn test_word2vec_load() {
        let word2vec: Word2Vec<3> = Word2Vec::load_from_txt("tests/word2vec.txt").await.unwrap();

        assert!(Word2Vec::<30>::load_from_txt("tests/word2v9+65ds6d5ec.txt").await.is_none());

        assert_eq!(word2vec.word_vecs.len(), 5);

        assert_eq!(word2vec.cosine("chien", "chat").unwrap(), 0.0);
    }

    #[cfg(feature = "loading")]
    #[tokio::test]
    async fn test_save_and_load_from_byte() {
        let word2vec: Word2Vec<3> = Word2Vec::load_from_txt("tests/word2vec.txt").await.unwrap();
        assert!(word2vec.save_to_bytes("tests/word2vec.bin").await.is_ok());
        test_load_from_byte().await;
    }

    #[cfg(feature = "loading")]
    async fn test_load_from_byte() {
        let word2vec: Word2Vec<3> = Word2Vec::load_from_txt("tests/word2vec.txt").await.unwrap();
        assert!(Word2Vec::<3>::load_from_bytes("tests/word2vec.bin").await.is_some());

        let word2vec2: Word2Vec<3> = Word2Vec::load_from_bytes("tests/word2vec.bin").await.unwrap();
        assert_eq!(word2vec.word_vecs.len(), word2vec2.word_vecs.len());

        for (word, vec) in word2vec.word_vecs.iter() {
            assert_eq!(word2vec2.word_vecs.get(word).unwrap().get_vec(), vec.get_vec());
        }

        // A plain text file is not valid bincode, so loading it must fail.
        assert!(Word2Vec::<3>::load_from_bytes("tests/word2vec.txt").await.is_none());
    }
}