use std::io::{BufRead, Write};
use failure::{ensure, err_msg, Error, ResultExt};
use itertools::Itertools;
use ndarray::Array2;
use crate::embeddings::Embeddings;
use crate::storage::{NdArray, Storage};
use crate::util::l2_normalize;
use crate::vocab::{SimpleVocab, Vocab};
/// Construct a value by reading embeddings in text format.
///
/// The implementation in this file expects one embedding per line: the
/// word, followed by its whitespace-separated vector components. No shape
/// header is read; the matrix shape is derived from the data.
pub trait ReadText<R: BufRead>: Sized {
    /// Read embeddings from `reader`.
    ///
    /// `normalize` is forwarded to the implementation; the implementation
    /// in this file l2-normalizes every embedding when it is `true`.
    fn read_text(reader: &mut R, normalize: bool) -> Result<Self, Error>;
}
impl<R: BufRead> ReadText<R> for Embeddings<SimpleVocab, NdArray> {
    /// Read header-less text embeddings from `reader`.
    ///
    /// Delegates to `read_embeds` without an expected shape, so the shape
    /// is determined from the lines that are read.
    fn read_text(reader: &mut R, normalize: bool) -> Result<Self, Error> {
        // No header is available in this format, hence no expected shape.
        let expected_shape = None;
        let embeddings = read_embeds(reader, expected_shape, normalize)?;
        Ok(embeddings)
    }
}
/// Construct a value by reading embeddings in text format with a shape
/// header.
///
/// The implementation in this file expects the first line to contain the
/// vocabulary size and the embedding dimensionality, followed by one
/// embedding per line (word, then whitespace-separated components).
pub trait ReadTextDims<R: BufRead>: Sized {
    /// Read embeddings, including the leading shape line, from `reader`.
    ///
    /// `normalize` is forwarded to the implementation; the implementation
    /// in this file l2-normalizes every embedding when it is `true`.
    fn read_text_dims(reader: &mut R, normalize: bool) -> Result<Self, Error>;
}
impl<R> ReadTextDims<R> for Embeddings<SimpleVocab, NdArray>
where
R: BufRead,
{
fn read_text_dims(reader: &mut R, normalize: bool) -> Result<Self, Error> {
let mut dims = String::new();
reader.read_line(&mut dims)?;
let mut dims_iter = dims.split_whitespace();
let vocab_len = dims_iter
.next()
.ok_or_else(|| failure::err_msg("Missing vocabulary size"))?
.parse::<usize>()
.context("Cannot parse vocabulary size")?;
let embed_len = dims_iter
.next()
.ok_or_else(|| failure::err_msg("Missing vocabulary size"))?
.parse::<usize>()
.context("Cannot parse vocabulary size")?;
read_embeds(reader, Some((vocab_len, embed_len)), normalize)
}
}
/// Read word embeddings in text format from `reader`.
///
/// Every line must consist of a word followed by the whitespace-separated
/// components of its embedding.
///
/// * `shape` - the expected `(number of words, dimensionality)`, if known.
///   When given, every line must have exactly that many components and the
///   total number of lines must match the number of words. When `None`, the
///   dimensionality is taken from the first line and every other line must
///   agree with it.
/// * `normalize` - when `true`, each embedding is passed through
///   `l2_normalize` after reading.
///
/// Input without any lines yields an empty embedding matrix (with zero
/// columns when no `shape` is given).
///
/// # Errors
///
/// Fails when reading from `reader` fails, a line is empty, a component
/// cannot be parsed, the number of components on a line differs from the
/// expected dimensionality, or the number of words differs from the
/// expected one.
fn read_embeds<R>(
    reader: &mut R,
    shape: Option<(usize, usize)>,
    normalize: bool,
) -> Result<Embeddings<SimpleVocab, NdArray>, Error>
where
    R: BufRead,
{
    let (mut words, mut data) = match shape {
        Some((n_words, dims)) => {
            // The shape typically comes from a file header, so do not trust
            // the product to fit in a usize.
            let n_components = n_words
                .checked_mul(dims)
                .ok_or_else(|| err_msg("Embedding matrix size overflows"))?;
            (
                Vec::with_capacity(n_words),
                Vec::with_capacity(n_components),
            )
        }
        None => (Vec::new(), Vec::new()),
    };

    let header_dims = shape.map(|(_, dims)| dims);
    // Dimensionality of the first line; only used when there is no header.
    let mut inferred_dims: Option<usize> = None;

    for line in reader.lines() {
        let line = line?;
        let mut parts = line.split_whitespace();

        let word = parts.next().ok_or_else(|| err_msg("Empty line"))?;
        words.push(word.to_owned());

        let row_start = data.len();
        for part in parts {
            data.push(part.parse()?);
        }
        let row_dims = data.len() - row_start;

        // Validate each row on its own: checking only the total number of
        // components would accept ragged rows whose lengths happen to add
        // up, silently assigning components to the wrong words.
        if let Some(dims) = header_dims {
            ensure!(
                row_dims == dims,
                "Expected {} dimensions, got: {}",
                dims,
                row_dims
            );
        } else if let Some(dims) = inferred_dims {
            ensure!(
                row_dims == dims,
                "Number of dimensions per vector is not constant"
            );
        } else {
            inferred_dims = Some(row_dims);
        }
    }

    let shape = match shape {
        Some((n_words, dims)) => {
            ensure!(
                words.len() == n_words,
                "Expected {} words, got: {}",
                n_words,
                words.len()
            );
            (n_words, dims)
        }
        // Without any lines there is no dimensionality to infer; fall back
        // to zero columns instead of dividing by the (zero) word count.
        None => (words.len(), inferred_dims.unwrap_or(0)),
    };

    let mut matrix = Array2::from_shape_vec(shape, data)?;

    if normalize {
        for mut embedding in matrix.outer_iter_mut() {
            l2_normalize(embedding.view_mut());
        }
    }

    Ok(Embeddings::new(
        None,
        SimpleVocab::new(words),
        NdArray(matrix),
    ))
}
/// Serialize embeddings in text format.
///
/// The implementation in this file writes one line per word: the word,
/// followed by its embedding components separated by single spaces. No
/// shape header is written.
pub trait WriteText<W: Write> {
    /// Write the embeddings to `writer`.
    fn write_text(&self, writer: &mut W) -> Result<(), Error>;
}
impl<W, V, S> WriteText<W> for Embeddings<V, S>
where
    W: Write,
    V: Vocab,
    S: Storage,
{
    /// Write every `(word, embedding)` pair of `self` as one text line.
    ///
    /// Each line holds the word, a space, and the embedding components
    /// joined by single spaces. The first write error is returned.
    fn write_text(&self, write: &mut W) -> Result<(), Error> {
        for (word, embed) in self.iter() {
            // Render all components first, then emit the line in one go.
            let rendered = embed
                .as_view()
                .iter()
                .map(|component| component.to_string())
                .collect::<Vec<_>>()
                .join(" ");

            writeln!(write, "{} {}", word, rendered)?;
        }

        Ok(())
    }
}
/// Serialize embeddings in text format, preceded by a shape header.
///
/// The implementation in this file first writes a line with the vocabulary
/// size and the embedding dimensionality, and then the embeddings
/// themselves, one per line.
pub trait WriteTextDims<W: Write> {
    /// Write the shape line and the embeddings to `writer`.
    fn write_text_dims(&self, writer: &mut W) -> Result<(), Error>;
}
impl<W, V, S> WriteTextDims<W> for Embeddings<V, S>
where
    W: Write,
    V: Vocab,
    S: Storage,
{
    /// Write the shape line, then the embeddings via `write_text`.
    ///
    /// The shape line consists of the vocabulary length and the embedding
    /// dimensionality, separated by a space.
    fn write_text_dims(&self, write: &mut W) -> Result<(), Error> {
        let n_words = self.vocab().len();
        let n_dims = self.dims();
        writeln!(write, "{} {}", n_words, n_dims)?;

        // The body is identical to the header-less format.
        self.write_text(write)
    }
}
#[cfg(test)]
mod tests {
    use std::fs::File;
    use std::io::{BufReader, Read, Seek, SeekFrom};

    use crate::embeddings::Embeddings;
    use crate::storage::{NdArray, StorageView};
    use crate::vocab::{SimpleVocab, Vocab};
    use crate::word2vec::ReadWord2Vec;

    use super::{ReadText, ReadTextDims, WriteText, WriteTextDims};

    /// Open a test data file behind a buffered reader, panicking when the
    /// file cannot be opened.
    fn open(path: &str) -> BufReader<File> {
        BufReader::new(File::open(path).unwrap())
    }

    /// Read the complete contents of `reader` into a string and rewind the
    /// reader to the start, so that the same data can be parsed afterwards.
    fn slurp_and_rewind(reader: &mut BufReader<File>) -> String {
        let mut contents = String::new();
        reader.read_to_string(&mut contents).unwrap();
        reader.seek(SeekFrom::Start(0)).unwrap();
        contents
    }

    /// Assert that two sets of embeddings have the same words and the same
    /// embedding matrix.
    fn assert_same_embeddings(
        actual: &Embeddings<SimpleVocab, NdArray>,
        expected: &Embeddings<SimpleVocab, NdArray>,
    ) {
        assert_eq!(actual.vocab().words(), expected.vocab().words());
        assert_eq!(actual.storage().view(), expected.storage().view());
    }

    /// Reference embeddings, read from the word2vec binary test file.
    fn read_word2vec() -> Embeddings<SimpleVocab, NdArray> {
        let mut reader = open("testdata/similarity.bin");
        Embeddings::read_word2vec_binary(&mut reader, false).unwrap()
    }

    // The header-less text file must give the same embeddings as the
    // binary reference file.
    #[test]
    fn read_text() {
        let mut reader = open("testdata/similarity.nodims");
        let text_embeddings = Embeddings::read_text(&mut reader, false).unwrap();

        let embeddings = read_word2vec();

        assert_same_embeddings(&text_embeddings, &embeddings);
    }

    // The text file with a shape header must give the same embeddings as
    // the binary reference file.
    #[test]
    fn read_text_dims() {
        let mut reader = open("testdata/similarity.txt");
        let text_embeddings = Embeddings::read_text_dims(&mut reader, false).unwrap();

        let embeddings = read_word2vec();

        assert_same_embeddings(&text_embeddings, &embeddings);
    }

    // Reading and writing the header-less format must reproduce the file.
    #[test]
    fn test_word2vec_text_roundtrip() {
        let mut reader = open("testdata/similarity.nodims");
        let check = slurp_and_rewind(&mut reader);

        let embeddings = Embeddings::read_text(&mut reader, false).unwrap();

        let mut output = Vec::new();
        embeddings.write_text(&mut output).unwrap();

        assert_eq!(check, String::from_utf8_lossy(&output));
    }

    // Reading and writing the format with a shape header must reproduce
    // the file.
    #[test]
    fn test_word2vec_text_dims_roundtrip() {
        let mut reader = open("testdata/similarity.txt");
        let check = slurp_and_rewind(&mut reader);

        let embeddings = Embeddings::read_text_dims(&mut reader, false).unwrap();

        let mut output = Vec::new();
        embeddings.write_text_dims(&mut output).unwrap();

        assert_eq!(check, String::from_utf8_lossy(&output));
    }
}