use std::io::{BufRead, Write};
use std::mem;
use std::slice::from_raw_parts_mut;
use byteorder::{LittleEndian, WriteBytesExt};
use failure::{err_msg, Error};
use ndarray::{Array2, Axis};
use crate::embeddings::Embeddings;
use crate::storage::{NdArray, Storage};
use crate::util::l2_normalize;
use crate::vocab::{SimpleVocab, Vocab};
/// Construction of a value from data in the word2vec binary format.
///
/// The trait is generic over the reader type `R`, which must be buffered
/// (`BufRead`). Implementors must be `Sized` because the method returns
/// `Self` by value.
pub trait ReadWord2Vec<R: BufRead>: Sized {
    /// Reads a value in word2vec binary format from `reader`.
    ///
    /// `normalize` asks the implementation to normalize the embeddings that
    /// it reads; what normalization means is up to the implementation.
    ///
    /// # Errors
    ///
    /// Returns an error when the implementation fails to read or parse the
    /// data.
    fn read_word2vec_binary(reader: &mut R, normalize: bool) -> Result<Self, Error>;
}
impl<R> ReadWord2Vec<R> for Embeddings<SimpleVocab, NdArray>
where
R: BufRead,
{
fn read_word2vec_binary(reader: &mut R, normalize: bool) -> Result<Self, Error> {
let n_words = read_number(reader, b' ')?;
let embed_len = read_number(reader, b'\n')?;
let mut matrix = Array2::zeros((n_words, embed_len));
let mut words = Vec::with_capacity(n_words);
for idx in 0..n_words {
let word = read_string(reader, b' ')?;
let word = word.trim();
words.push(word.to_owned());
let mut embedding = matrix.index_axis_mut(Axis(0), idx);
{
let mut embedding_raw = match embedding.as_slice_mut() {
Some(s) => unsafe { typed_to_bytes(s) },
None => return Err(err_msg("Matrix not contiguous")),
};
reader.read_exact(&mut embedding_raw)?;
}
}
if normalize {
for mut embedding in matrix.outer_iter_mut() {
l2_normalize(embedding.view_mut());
}
}
Ok(Embeddings::new(
None,
SimpleVocab::new(words),
NdArray(matrix),
))
}
}
/// Reads one `delim`-terminated field from `reader` and parses it as `usize`.
///
/// The field is obtained through `read_string`, so the delimiter itself is
/// consumed from the reader and is not part of the parsed text.
///
/// # Errors
///
/// Propagates any error from `read_string` and fails when the field is not a
/// valid `usize`.
fn read_number(reader: &mut dyn BufRead, delim: u8) -> Result<usize, Error> {
    let field = read_string(reader, delim)?;
    let number: usize = field.parse()?;
    Ok(number)
}
fn read_string(reader: &mut BufRead, delim: u8) -> Result<String, Error> {
let mut buf = Vec::new();
reader.read_until(delim, &mut buf)?;
buf.pop();
Ok(String::from_utf8(buf)?)
}
/// Reinterprets a mutable slice of `T` as a mutable slice of its raw bytes.
///
/// The returned slice starts at the same address as `slice` and covers
/// `slice.len() * size_of::<T>()` bytes.
///
/// # Safety
///
/// The caller must make sure that any bytes written through the returned
/// slice leave every element a valid `T`, and that `T` contains no padding
/// bytes that would be exposed through the byte view.
unsafe fn typed_to_bytes<T>(slice: &mut [T]) -> &mut [u8] {
    let byte_len = mem::size_of_val(&*slice);
    let start = slice.as_mut_ptr() as *mut u8;
    from_raw_parts_mut(start, byte_len)
}
/// Serialization of a value to the word2vec binary format.
///
/// The trait is generic over the writer type `W`, which must implement
/// `Write`.
pub trait WriteWord2Vec<W: Write> {
    /// Writes `self` in word2vec binary format to `w`.
    ///
    /// # Errors
    ///
    /// Returns an error when the implementation fails to write to `w`.
    fn write_word2vec_binary(&self, w: &mut W) -> Result<(), Error>;
}
impl<W, V, S> WriteWord2Vec<W> for Embeddings<V, S>
where
    W: Write,
    V: Vocab,
    S: Storage,
{
    /// Writes the embeddings to `w` in the word2vec binary format.
    ///
    /// The output starts with an ASCII header line holding the vocabulary
    /// size and the embedding dimensionality, separated by a space. Each
    /// entry then consists of the word, a space, the embedding components as
    /// little-endian `f32` values, and a trailing newline byte.
    ///
    /// # Errors
    ///
    /// Fails as soon as any write to `w` fails.
    fn write_word2vec_binary(&self, w: &mut W) -> Result<(), Error>
    where
        W: Write,
    {
        let n_words = self.vocab().len();
        let n_dims = self.dims();
        writeln!(w, "{} {}", n_words, n_dims)?;

        for (word, embed) in self.iter() {
            write!(w, "{} ", word)?;

            // Components are always written little-endian, independent of
            // the byte order of the host.
            embed
                .as_view()
                .into_iter()
                .try_for_each(|component| w.write_f32::<LittleEndian>(*component))?;

            w.write_all(b"\n")?;
        }

        Ok(())
    }
}