1use std::io::{BufRead, Write};
24use std::mem;
25use std::slice::from_raw_parts_mut;
26
27use byteorder::{LittleEndian, WriteBytesExt};
28use failure::{err_msg, Error};
29use ndarray::{Array2, Axis};
30
31use crate::embeddings::Embeddings;
32use crate::storage::{NdArray, Storage};
33use crate::util::l2_normalize;
34use crate::vocab::{SimpleVocab, Vocab};
35
36pub trait ReadWord2Vec<R>
41where
42 Self: Sized,
43 R: BufRead,
44{
45 fn read_word2vec_binary(reader: &mut R, normalize: bool) -> Result<Self, Error>;
47}
48
49impl<R> ReadWord2Vec<R> for Embeddings<SimpleVocab, NdArray>
50where
51 R: BufRead,
52{
53 fn read_word2vec_binary(reader: &mut R, normalize: bool) -> Result<Self, Error> {
54 let n_words = read_number(reader, b' ')?;
55 let embed_len = read_number(reader, b'\n')?;
56
57 let mut matrix = Array2::zeros((n_words, embed_len));
58 let mut words = Vec::with_capacity(n_words);
59
60 for idx in 0..n_words {
61 let word = read_string(reader, b' ')?;
62 let word = word.trim();
63 words.push(word.to_owned());
64
65 let mut embedding = matrix.index_axis_mut(Axis(0), idx);
66
67 {
68 let mut embedding_raw = match embedding.as_slice_mut() {
69 Some(s) => unsafe { typed_to_bytes(s) },
70 None => return Err(err_msg("Matrix not contiguous")),
71 };
72 reader.read_exact(&mut embedding_raw)?;
73 }
74 }
75
76 if normalize {
77 for mut embedding in matrix.outer_iter_mut() {
78 l2_normalize(embedding.view_mut());
79 }
80 }
81
82 Ok(Embeddings::new(
83 None,
84 SimpleVocab::new(words),
85 NdArray(matrix),
86 ))
87 }
88}
89
90fn read_number(reader: &mut BufRead, delim: u8) -> Result<usize, Error> {
91 let field_str = read_string(reader, delim)?;
92 Ok(field_str.parse()?)
93}
94
95fn read_string(reader: &mut BufRead, delim: u8) -> Result<String, Error> {
96 let mut buf = Vec::new();
97 reader.read_until(delim, &mut buf)?;
98 buf.pop();
99 Ok(String::from_utf8(buf)?)
100}
101
/// Reinterprets a mutable slice of `T` as a mutable view of its raw bytes.
///
/// The returned slice aliases the same memory as `slice` and spans all of its
/// elements (`slice.len() * size_of::<T>()` bytes).
///
/// # Safety
///
/// Callers may write arbitrary bytes through the result. `T` must therefore
/// be a type without padding for which every bit pattern is a valid value
/// (e.g. `f32`, `u32`).
unsafe fn typed_to_bytes<T>(slice: &mut [T]) -> &mut [u8] {
    let byte_len = mem::size_of_val(&*slice);
    let data = slice.as_mut_ptr() as *mut u8;
    from_raw_parts_mut(data, byte_len)
}
108
109pub trait WriteWord2Vec<W>
114where
115 W: Write,
116{
117 fn write_word2vec_binary(&self, w: &mut W) -> Result<(), Error>;
119}
120
121impl<W, V, S> WriteWord2Vec<W> for Embeddings<V, S>
122where
123 W: Write,
124 V: Vocab,
125 S: Storage,
126{
127 fn write_word2vec_binary(&self, w: &mut W) -> Result<(), Error>
128 where
129 W: Write,
130 {
131 writeln!(w, "{} {}", self.vocab().len(), self.dims())?;
132
133 for (word, embed) in self.iter() {
134 write!(w, "{} ", word)?;
135
136 for v in embed.as_view() {
138 w.write_f32::<LittleEndian>(*v)?;
139 }
140
141 w.write_all(&[0x0a])?;
142 }
143
144 Ok(())
145 }
146}