rust2vec/
word2vec.rs

1//! Reader and writer for the word2vec binary format.
2//!
3//! Embeddings in the word2vec binary format are
4//! read as follows:
5//!
6//! ```
7//! use std::fs::File;
8//! use std::io::BufReader;
9//!
10//! use rust2vec::prelude::*;
11//!
12//! let mut reader = BufReader::new(File::open("testdata/similarity.bin").unwrap());
13//!
14//! // Read the embeddings. The second argument specifies whether
15//! // the embeddings should be normalized to unit vectors.
16//! let embeddings = Embeddings::read_word2vec_binary(&mut reader, true)
17//!     .unwrap();
18//!
19//! // Look up an embedding.
20//! let embedding = embeddings.embedding("Berlin");
21//! ```
22
23use std::io::{BufRead, Write};
24use std::mem;
25use std::slice::from_raw_parts_mut;
26
27use byteorder::{LittleEndian, WriteBytesExt};
28use failure::{err_msg, Error};
29use ndarray::{Array2, Axis};
30
31use crate::embeddings::Embeddings;
32use crate::storage::{NdArray, Storage};
33use crate::util::l2_normalize;
34use crate::vocab::{SimpleVocab, Vocab};
35
36/// Method to construct `Embeddings` from a word2vec binary file.
37///
38/// This trait defines an extension to `Embeddings` to read the word embeddings
39/// from a file in word2vec binary format.
40pub trait ReadWord2Vec<R>
41where
42    Self: Sized,
43    R: BufRead,
44{
45    /// Read the embeddings from the given buffered reader.
46    fn read_word2vec_binary(reader: &mut R, normalize: bool) -> Result<Self, Error>;
47}
48
49impl<R> ReadWord2Vec<R> for Embeddings<SimpleVocab, NdArray>
50where
51    R: BufRead,
52{
53    fn read_word2vec_binary(reader: &mut R, normalize: bool) -> Result<Self, Error> {
54        let n_words = read_number(reader, b' ')?;
55        let embed_len = read_number(reader, b'\n')?;
56
57        let mut matrix = Array2::zeros((n_words, embed_len));
58        let mut words = Vec::with_capacity(n_words);
59
60        for idx in 0..n_words {
61            let word = read_string(reader, b' ')?;
62            let word = word.trim();
63            words.push(word.to_owned());
64
65            let mut embedding = matrix.index_axis_mut(Axis(0), idx);
66
67            {
68                let mut embedding_raw = match embedding.as_slice_mut() {
69                    Some(s) => unsafe { typed_to_bytes(s) },
70                    None => return Err(err_msg("Matrix not contiguous")),
71                };
72                reader.read_exact(&mut embedding_raw)?;
73            }
74        }
75
76        if normalize {
77            for mut embedding in matrix.outer_iter_mut() {
78                l2_normalize(embedding.view_mut());
79            }
80        }
81
82        Ok(Embeddings::new(
83            None,
84            SimpleVocab::new(words),
85            NdArray(matrix),
86        ))
87    }
88}
89
90fn read_number(reader: &mut BufRead, delim: u8) -> Result<usize, Error> {
91    let field_str = read_string(reader, delim)?;
92    Ok(field_str.parse()?)
93}
94
95fn read_string(reader: &mut BufRead, delim: u8) -> Result<String, Error> {
96    let mut buf = Vec::new();
97    reader.read_until(delim, &mut buf)?;
98    buf.pop();
99    Ok(String::from_utf8(buf)?)
100}
101
/// Reinterpret a mutable slice of `T` as a mutable slice of its raw bytes.
///
/// The returned slice starts at the same address as `slice` and covers
/// exactly the memory occupied by its elements.
///
/// # Safety
///
/// Bytes written through the returned slice become the in-memory
/// representation of the `T` values. The caller must make sure that whatever
/// is written leaves every element a valid `T`.
unsafe fn typed_to_bytes<T>(slice: &mut [T]) -> &mut [u8] {
    let byte_len = mem::size_of_val(&*slice);
    let start = slice.as_mut_ptr() as *mut u8;
    from_raw_parts_mut(start, byte_len)
}
108
109/// Method to write `Embeddings` to a word2vec binary file.
110///
111/// This trait defines an extension to `Embeddings` to write the word embeddings
112/// to a file in word2vec binary format.
113pub trait WriteWord2Vec<W>
114where
115    W: Write,
116{
117    /// Write the embeddings from the given writer.
118    fn write_word2vec_binary(&self, w: &mut W) -> Result<(), Error>;
119}
120
121impl<W, V, S> WriteWord2Vec<W> for Embeddings<V, S>
122where
123    W: Write,
124    V: Vocab,
125    S: Storage,
126{
127    fn write_word2vec_binary(&self, w: &mut W) -> Result<(), Error>
128    where
129        W: Write,
130    {
131        writeln!(w, "{} {}", self.vocab().len(), self.dims())?;
132
133        for (word, embed) in self.iter() {
134            write!(w, "{} ", word)?;
135
136            // Write embedding to a vector with little-endian encoding.
137            for v in embed.as_view() {
138                w.write_f32::<LittleEndian>(*v)?;
139            }
140
141            w.write_all(&[0x0a])?;
142        }
143
144        Ok(())
145    }
146}