use std::io::{Seek, Write};
use std::sync::Arc;
use anyhow::{anyhow, bail, Result};
use finalfusion::compat::fasttext::WriteFastText;
use finalfusion::compat::text::{WriteText, WriteTextDims};
use finalfusion::compat::word2vec::WriteWord2Vec;
use finalfusion::io::WriteEmbeddings;
use finalfusion::metadata::Metadata;
use finalfusion::norms::NdNorms;
use finalfusion::prelude::{Embeddings, VocabWrap};
use finalfusion::storage::NdArray;
use ndarray::{Array1, Array2, ArrayView1, ArrayView2, ArrayViewMut1, Axis};
use ndarray_rand::rand_distr::Uniform;
use ndarray_rand::RandomExt;
use serde::Serialize;
use toml::Value;
use crate::hogwild::HogwildArray2;
use crate::idx::WordIdx;
use crate::io::{EmbeddingFormat, TrainInfo};
use crate::util::VersionInfo;
use crate::vec_simd::{l2_normalize, scale, scaled_add};
use crate::{CommonConfig, Vocab, WriteModelBinary};
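/// Embedding model under training.
///
/// A `TrainModel` pairs a trainer with the input and output embedding
/// matrices. Clones of the model share the matrices through
/// `HogwildArray2`, so multiple training threads can update them without
/// locking (hogwild-style training).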
#[derive(Clone)]
pub struct TrainModel<T> {
trainer: T,
input: HogwildArray2<f32>,
output: HogwildArray2<f32>,
}
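/// Construct a model from a trainer.
///
/// The input and output matrices are randomly initialized from a uniform
/// distribution over `[-1/dims, 1/dims]`.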
impl<T> From<T> for TrainModel<T>
where
T: Trainer,
{
fn from(trainer: T) -> TrainModel<T> {
let config = *trainer.config();
let init_bound = 1.0 / config.dims as f32;
let distribution = Uniform::new_inclusive(-init_bound, init_bound);
let input = Array2::random(
(trainer.input_vocab().n_input_types(), config.dims as usize),
distribution,
)
.into();
let output = Array2::random(
(trainer.n_output_types(), config.dims as usize),
distribution,
)
.into();
TrainModel {
trainer,
input,
output,
}
}
}
impl<T> TrainModel<T>
where
T: Trainer,
{
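    /// Get the common training configuration.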
    pub fn config(&self) -> &CommonConfig {
        self.trainer.config()
    }
}
impl<V, T> TrainModel<T>
where
T: Trainer<InputVocab = V>,
V: Vocab,
{
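    /// Get the input vocabulary.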
pub fn input_vocab(&self) -> &V {
self.trainer.input_vocab()
}
}
impl<T> TrainModel<T> {
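    /// Get a mutable reference to the trainer.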
pub fn trainer(&mut self) -> &mut T {
&mut self.trainer
}
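    /// Get the mean input embedding of the given word indices.
    ///
    /// For a single index, this is the corresponding row of the input
    /// matrix; for multiple indices (e.g. a word and its subwords), it is
    /// the mean of the corresponding rows.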
pub(crate) fn mean_input_embedding<'a, I>(&self, idx: &'a I) -> Array1<f32>
where
I: WordIdx,
&'a I: IntoIterator<Item = u64>,
{
if idx.len() == 1 {
self.input
.view()
.row(idx.into_iter().next().unwrap() as usize)
.to_owned()
} else {
Self::mean_embedding(self.input.view(), idx)
}
}
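    /// Compute the mean of the rows of `embeds` selected by `indices`.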
fn mean_embedding<'a, I>(embeds: ArrayView2<f32>, indices: &'a I) -> Array1<f32>
where
I: WordIdx,
&'a I: IntoIterator<Item = u64>,
{
        let mut embed = Array1::zeros(embeds.ncols());
let len = indices.len();
for idx in indices {
scaled_add(
embed.view_mut(),
embeds.index_axis(Axis(0), idx as usize),
1.0,
);
}
scale(embed.view_mut(), 1.0 / len as f32);
embed
}
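    /// Get the input embedding with the given index.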
#[allow(dead_code)]
#[inline]
pub(crate) fn input_embedding(&self, idx: usize) -> ArrayView1<f32> {
self.input.subview(Axis(0), idx)
}
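    /// Get the input embedding with the given index mutably.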
#[inline]
pub(crate) fn input_embedding_mut(&mut self, idx: usize) -> ArrayViewMut1<f32> {
self.input.subview_mut(Axis(0), idx)
}
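    /// Decompose the model into its trainer and input matrix.
    ///
    /// Fails when the input matrix is still shared with other clones of
    /// the model (e.g. training threads that have not finished).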
pub(crate) fn into_parts(self) -> Result<(T, Array2<f32>)> {
let input = match Arc::try_unwrap(self.input.into_inner()) {
Ok(input) => input.into_inner(),
Err(_) => bail!("Cannot unwrap input matrix."),
};
Ok((self.trainer, input))
}
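    /// Get the output embedding with the given index.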
#[inline]
pub(crate) fn output_embedding(&self, idx: usize) -> ArrayView1<f32> {
self.output.subview(Axis(0), idx)
}
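    /// Get the output embedding with the given index mutably.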
#[inline]
pub(crate) fn output_embedding_mut(&mut self, idx: usize) -> ArrayViewMut1<f32> {
self.output.subview_mut(Axis(0), idx)
}
}
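// Write a trained model: attach version and training information to the
// model metadata, normalize the in-vocabulary embeddings (storing their
// norms), and serialize in the requested embedding format.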
impl<W, T, V, M> WriteModelBinary<W> for TrainModel<T>
where
W: Seek + Write,
T: Trainer<InputVocab = V, Metadata = M>,
V: Vocab + Into<VocabWrap>,
V::VocabType: ToString,
for<'a> &'a V::IdxType: IntoIterator<Item = u64>,
M: Serialize,
{
fn write_model_binary(
self,
write: &mut W,
mut train_info: TrainInfo,
format: EmbeddingFormat,
) -> Result<()> {
let (trainer, mut input_matrix) = self.into_parts()?;
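        // Attach version and training information to the model metadata.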
let mut metadata = Value::try_from(trainer.to_metadata())?;
let build_info = Value::try_from(VersionInfo::new())?;
let metadata_table = metadata
.as_table_mut()
.ok_or_else(|| anyhow!("Metadata has to be 'Table'."))?;
metadata_table.insert("version_info".to_string(), build_info);
train_info.set_end();
let train_info = Value::try_from(train_info)?;
metadata_table.insert("training_info".to_string(), train_info);
let mut norms = vec![0f32; trainer.input_vocab().len()];
        for (i, (norm, word)) in norms
            .iter_mut()
            .zip(trainer.input_vocab().types())
            .enumerate()
{
let input = trainer.input_vocab().idx(word.label()).unwrap();
let mut embed = Self::mean_embedding(input_matrix.view(), &input);
*norm = l2_normalize(embed.view_mut());
input_matrix.index_axis_mut(Axis(0), i).assign(&embed);
}
let vocab: VocabWrap = trainer.try_into_input_vocab()?.into();
let storage = NdArray::new(input_matrix);
let norms = NdNorms::new(Array1::from(norms));
use self::EmbeddingFormat::*;
match format {
FastText => {
let vocab = match vocab {
VocabWrap::FastTextSubwordVocab(vocab) => vocab,
_ => bail!("Only fastText vocabularies can be written to fastText files"),
};
Embeddings::new(Some(Metadata::new(metadata)), vocab, storage, norms)
.write_fasttext(write)?
}
FinalFusion => Embeddings::new(Some(Metadata::new(metadata)), vocab, storage, norms)
.write_embeddings(write)?,
Word2Vec => Embeddings::new(Some(Metadata::new(metadata)), vocab, storage, norms)
.write_word2vec_binary(write, true)?,
Text => Embeddings::new(Some(Metadata::new(metadata)), vocab, storage, norms)
.write_text(write, true)?,
TextDims => Embeddings::new(Some(Metadata::new(metadata)), vocab, storage, norms)
.write_text_dims(write, true)?,
};
Ok(())
}
}
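/// Trait for model trainers.
///
/// A trainer owns the input vocabulary and the training configuration,
/// determines the dimensions of the input and output matrices, and
/// produces model-specific metadata for serialization.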
pub trait Trainer {
    /// Type of the input vocabulary.
    type InputVocab: Vocab;
    /// Type of the model-specific metadata.
    type Metadata;
    /// Get the input vocabulary.
    fn input_vocab(&self) -> &Self::InputVocab;
    /// Consume the trainer, returning its input vocabulary.
    fn try_into_input_vocab(self) -> Result<Self::InputVocab>;
    /// Get the number of input types (rows of the input matrix).
    fn n_input_types(&self) -> usize;
    /// Get the number of output types (rows of the output matrix).
    fn n_output_types(&self) -> usize;
    /// Get the common training configuration.
    fn config(&self) -> &CommonConfig;
    /// Construct the model metadata.
    fn to_metadata(&self) -> Self::Metadata;
}
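/// Trait for constructing training iterators from a sequence.
///
/// An implementation turns a sequence (e.g. a sentence of tokens) into an
/// iterator over focus items and the output contexts that should be
/// predicted for them.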
pub trait TrainIterFrom<'a, S>
where
    S: ?Sized,
{
    /// Iterator over focus items and their contexts.
    type Iter: Iterator<Item = (Self::Focus, Self::Contexts)>;
    /// Type of a focus item.
    type Focus;
    /// Type of the contexts of a focus item.
    type Contexts: IntoIterator<Item = usize>;
    /// Create a training iterator from a sequence.
    fn train_iter_from(&mut self, sequence: &S) -> Self::Iter;
}
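/// Trait for negative sampling.
///
/// `negative_sample` returns the index of a sampled negative output,
/// which must differ from the positive `output`. A minimal sketch of an
/// implementer, assuming a hypothetical `UniformSampler` that draws
/// uniformly over `n_outputs` (the trainers in this crate instead sample
/// from a Zipf-like distribution controlled by
/// `CommonConfig::zipf_exponent`):
///
/// ```ignore
/// use rand::Rng;
///
/// // Hypothetical implementer, for illustration only.
/// struct UniformSampler<R> {
///     rng: R,
///     n_outputs: usize,
/// }
///
/// impl<R: Rng> NegativeSamples for UniformSampler<R> {
///     fn negative_sample(&mut self, output: usize) -> usize {
///         loop {
///             // Rejection sampling: redraw until the sample differs
///             // from the positive output.
///             let negative = self.rng.gen_range(0, self.n_outputs);
///             if negative != output {
///                 return negative;
///             }
///         }
///     }
/// }
/// ```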
pub trait NegativeSamples {
    /// Draw a negative sample that differs from `output`.
    fn negative_sample(&mut self, output: usize) -> usize;
}
#[cfg(test)]
mod tests {
use finalfusion::subword::FinalfusionHashIndexer;
use ndarray::Array2;
use rand::SeedableRng;
use rand_xorshift::XorShiftRng;
use super::TrainModel;
use crate::config::BucketIndexerType::Finalfusion;
use crate::config::SubwordVocabConfig;
use crate::idx::WordWithSubwordsIdx;
use crate::io::EmbeddingFormat;
use crate::skipgram_trainer::SkipgramTrainer;
use crate::util::all_close;
use crate::{
BucketConfig, CommonConfig, Cutoff, LossType, ModelType, SkipGramConfig, SubwordVocab,
VocabBuilder,
};
const TEST_COMMON_CONFIG: CommonConfig = CommonConfig {
dims: 3,
epochs: 5,
format: EmbeddingFormat::FinalFusion,
loss: LossType::LogisticNegativeSampling,
lr: 0.05,
negative_samples: 5,
zipf_exponent: 0.5,
};
const TEST_SKIP_CONFIG: SkipGramConfig = SkipGramConfig {
context_size: 5,
model: ModelType::SkipGram,
};
const VOCAB_CONF: SubwordVocabConfig<BucketConfig> = SubwordVocabConfig {
discard_threshold: 1e-4,
cutoff: Cutoff::MinCount(2),
max_n: 6,
min_n: 3,
indexer: BucketConfig {
buckets_exp: 21,
indexer_type: Finalfusion,
},
};
#[test]
pub fn model_embed_methods() {
        let mut vocab_config = VOCAB_CONF;
        vocab_config.cutoff = Cutoff::MinCount(1);
        let common_config = TEST_COMMON_CONFIG;
        let skipgram_config = TEST_SKIP_CONFIG;
let mut builder: VocabBuilder<_, String> = VocabBuilder::new(vocab_config);
builder.count("bla".to_string());
let vocab: SubwordVocab<_, FinalfusionHashIndexer> = builder.into();
let input = Array2::from_shape_vec((2, 3), vec![1., 2., 3., 4., 5., 6.])
.unwrap()
.into();
let output = Array2::from_shape_vec((2, 3), vec![-1., -2., -3., -4., -5., -6.])
.unwrap()
.into();
let mut model = TrainModel {
trainer: SkipgramTrainer::new(
vocab,
XorShiftRng::from_entropy(),
common_config,
skipgram_config,
),
input,
output,
};
assert!(all_close(
model.input_embedding(0).as_slice().unwrap(),
&[1., 2., 3.],
1e-5
));
assert!(all_close(
model.input_embedding(1).as_slice().unwrap(),
&[4., 5., 6.],
1e-5
));
assert!(all_close(
model.input_embedding_mut(0).as_slice().unwrap(),
&[1., 2., 3.],
1e-5
));
assert!(all_close(
model.input_embedding_mut(1).as_slice().unwrap(),
&[4., 5., 6.],
1e-5
));
assert!(all_close(
model.output_embedding(0).as_slice().unwrap(),
&[-1., -2., -3.],
1e-5
));
assert!(all_close(
model.output_embedding(1).as_slice().unwrap(),
&[-4., -5., -6.],
1e-5
));
assert!(all_close(
model.output_embedding_mut(0).as_slice().unwrap(),
&[-1., -2., -3.],
1e-5
));
assert!(all_close(
model.output_embedding_mut(1).as_slice().unwrap(),
&[-4., -5., -6.],
1e-5
));
assert!(all_close(
model
.mean_input_embedding(&WordWithSubwordsIdx::new(0, vec![1]))
.as_slice()
.unwrap(),
&[2.5, 3.5, 4.5],
1e-5
));
}
}