use std::io::{Seek, Write};
use std::sync::Arc;
use failure::{err_msg, Error};
use finalfusion::vocab::VocabWrap;
use finalfusion::{
embeddings::Embeddings, io::WriteEmbeddings, metadata::Metadata, storage::NdArray,
};
use hogwild::HogwildArray2;
use ndarray::{Array1, Array2, ArrayView1, ArrayView2, ArrayViewMut1, Axis};
use ndarray_rand::RandomExt;
use rand::distributions::Uniform;
use serde::Serialize;
use toml::Value;
use crate::vec_simd::{l2_normalize, scale, scaled_add};
use crate::{CommonConfig, Vocab, WriteModelBinary};
/// Training model: a trainer plus the input and output embedding matrices.
///
/// Both matrices are `HogwildArray2`s; `into_parts` unwraps an `Arc`, so
/// clones of this model appear to share the same underlying storage and can
/// update it concurrently without locking (Hogwild-style) — confirmed by the
/// `Arc::try_unwrap` in `into_parts`.
#[derive(Clone)]
pub struct TrainModel<T> {
    // Trainer that drives parameter updates and owns the vocabulary/config.
    trainer: T,
    // Input embeddings, shape: n_input_types x dims (see `From<T>` below).
    input: HogwildArray2<f32>,
    // Output embeddings, shape: n_output_types x dims.
    output: HogwildArray2<f32>,
}
impl<T> From<T> for TrainModel<T>
where
    T: Trainer,
{
    /// Construct a model from a trainer.
    ///
    /// Both embedding matrices are initialized uniformly at random in
    /// `[-1/dims, 1/dims]`.
    fn from(trainer: T) -> TrainModel<T> {
        let dims = trainer.config().dims as usize;

        // Initialization bound shrinks with the embedding dimensionality.
        let bound = 1.0 / trainer.config().dims as f32;
        let dist = Uniform::new_inclusive(-bound, bound);

        let input_shape = (trainer.n_input_types(), dims);
        let output_shape = (trainer.n_output_types(), dims);

        TrainModel {
            input: Array2::random(input_shape, dist).into(),
            output: Array2::random(output_shape, dist).into(),
            trainer,
        }
    }
}
impl<T> TrainModel<T>
where
    T: Trainer,
{
    /// The common (hyperparameter) configuration of the trainer.
    pub fn config(&self) -> &CommonConfig {
        // `Trainer::config` already returns `&CommonConfig`; the previous
        // `&self.trainer.config()` created a needless `&&CommonConfig` that
        // only compiled through deref coercion (clippy: needless_borrow).
        self.trainer.config()
    }
}
impl<V, T> TrainModel<T>
where
T: Trainer<InputVocab = V>,
V: Vocab,
{
pub fn input_vocab(&self) -> &V {
self.trainer.input_vocab()
}
}
impl<T> TrainModel<T> {
pub fn trainer(&mut self) -> &mut T {
&mut self.trainer
}
pub(crate) fn mean_input_embedding(&self, indices: &[u64]) -> Array1<f32> {
Self::mean_embedding(self.input.view(), indices)
}
fn mean_embedding(embeds: ArrayView2<f32>, indices: &[u64]) -> Array1<f32> {
let mut embed = Array1::zeros((embeds.cols(),));
for &idx in indices.iter() {
scaled_add(
embed.view_mut(),
embeds.index_axis(Axis(0), idx as usize),
1.0,
);
}
scale(embed.view_mut(), 1.0 / indices.len() as f32);
embed
}
#[allow(dead_code)]
#[inline]
pub(crate) fn input_embedding(&self, idx: usize) -> ArrayView1<f32> {
self.input.subview(Axis(0), idx)
}
#[inline]
pub(crate) fn input_embedding_mut(&mut self, idx: usize) -> ArrayViewMut1<f32> {
self.input.subview_mut(Axis(0), idx)
}
pub(crate) fn into_parts(self) -> Result<(T, Array2<f32>), Error> {
let input = match Arc::try_unwrap(self.input.into_inner()) {
Ok(input) => input.into_inner(),
Err(_) => return Err(err_msg("Cannot unwrap input matrix.")),
};
Ok((self.trainer, input))
}
#[inline]
pub(crate) fn output_embedding(&self, idx: usize) -> ArrayView1<f32> {
self.output.subview(Axis(0), idx)
}
#[inline]
pub(crate) fn output_embedding_mut(&mut self, idx: usize) -> ArrayViewMut1<f32> {
self.output.subview_mut(Axis(0), idx)
}
}
impl<W, T, V, M> WriteModelBinary<W> for TrainModel<T>
where
    W: Seek + Write,
    T: Trainer<InputVocab = V, Metadata = M>,
    V: Vocab + Into<VocabWrap>,
    M: Serialize,
{
    /// Write the trained model to `write` as a finalfusion embeddings file.
    ///
    /// For every vocabulary item, the stored embedding is the mean of the
    /// rows returned by `Trainer::input_indices` (presumably the word and
    /// its subword rows — confirm against the trainer), l2-normalized.
    fn write_model_binary(self, write: &mut W) -> Result<(), Error> {
        let (trainer, mut input_matrix) = self.into_parts()?;
        // Serialize trainer metadata into the file's TOML metadata chunk.
        let metadata = Metadata(Value::try_from(trainer.to_metadata())?);
        // NOTE(review): `norms` is filled here but never written to the
        // output; `l2_normalize` is effectively called for its in-place
        // side effect on `embed`. Verify whether norms should be stored.
        let mut norms = vec![0f32; trainer.input_vocab().len()];
        for (i, norm) in norms
            .iter_mut()
            .enumerate()
            .take(trainer.input_vocab().len())
        {
            let input = trainer.input_indices(i);
            // Mean is taken over the (still un-normalized for rows >= i)
            // matrix; only row `i` is overwritten in this iteration.
            let mut embed = Self::mean_embedding(input_matrix.view(), &input);
            *norm = l2_normalize(embed.view_mut());
            input_matrix.index_axis_mut(Axis(0), i).assign(&embed);
        }
        let vocab: VocabWrap = trainer.try_into_input_vocab()?.into();
        let storage = NdArray(input_matrix);
        Embeddings::new(Some(metadata), vocab, storage).write_embeddings(write)
    }
}
/// Trait for trainers that drive a `TrainModel`.
pub trait Trainer {
    /// Vocabulary type over input tokens.
    type InputVocab: Vocab;
    /// Metadata type serialized into the embeddings file.
    type Metadata;

    /// Input-matrix row indices associated with vocabulary item `idx`.
    fn input_indices(&self, idx: usize) -> Vec<u64>;
    /// The trainer's input vocabulary.
    fn input_vocab(&self) -> &Self::InputVocab;
    /// Consume the trainer, yielding its input vocabulary.
    fn try_into_input_vocab(self) -> Result<Self::InputVocab, Error>;
    /// Number of rows of the input embedding matrix.
    fn n_input_types(&self) -> usize;
    /// Number of rows of the output embedding matrix.
    fn n_output_types(&self) -> usize;
    /// Common (hyperparameter) configuration.
    fn config(&self) -> &CommonConfig;
    /// Produce the metadata describing this trainer's configuration.
    fn to_metadata(&self) -> Self::Metadata;
}
/// Conversion of a sequence into an iterator of training pairs.
pub trait TrainIterFrom<S>
where
    S: ?Sized,
{
    /// Iterator over (focus index, contexts) pairs.
    type Iter: Iterator<Item = (usize, Self::Contexts)>;
    /// Collection of context indices for one focus item.
    type Contexts: Sized + IntoIterator<Item = usize>;

    /// Construct a training iterator from `sequence`.
    ///
    /// Takes `&mut self` — implementations may update internal state
    /// (e.g. counts or RNG) while building the iterator.
    fn train_iter_from(&mut self, sequence: &S) -> Self::Iter;
}
/// Drawing of negative samples for an output index.
pub trait NegativeSamples {
    /// Draw a negative sample for the given `output` index.
    ///
    /// Takes `&mut self` since sampling typically advances an internal RNG.
    fn negative_sample(&mut self, output: usize) -> usize;
}
#[cfg(test)]
mod tests {
    use ndarray::Array2;
    use rand::FromEntropy;
    use rand_xorshift::XorShiftRng;

    use super::TrainModel;
    use crate::skipgram_trainer::SkipgramTrainer;
    use crate::util::all_close;
    use crate::{
        CommonConfig, LossType, ModelType, SkipGramConfig, SubwordVocab, SubwordVocabConfig,
        VocabBuilder,
    };

    // Shared hyperparameters for the tests below; dims = 3 matches the
    // 2x3 fixture matrices in `model_embed_methods`.
    const TEST_COMMON_CONFIG: CommonConfig = CommonConfig {
        dims: 3,
        epochs: 5,
        loss: LossType::LogisticNegativeSampling,
        lr: 0.05,
        negative_samples: 5,
        zipf_exponent: 0.5,
    };

    const TEST_SKIP_CONFIG: SkipGramConfig = SkipGramConfig {
        context_size: 5,
        model: ModelType::SkipGram,
    };

    const VOCAB_CONF: SubwordVocabConfig = SubwordVocabConfig {
        buckets_exp: 21,
        discard_threshold: 1e-4,
        min_count: 2,
        max_n: 6,
        min_n: 3,
    };

    // Exercises the (mutable and immutable) embedding accessors and the
    // mean-embedding computation of `TrainModel` on fixed matrices.
    #[test]
    pub fn model_embed_methods() {
        let mut vocab_config = VOCAB_CONF.clone();
        // min_count = 1 so the single counted token stays in the vocabulary.
        vocab_config.min_count = 1;
        let common_config = TEST_COMMON_CONFIG.clone();
        let skipgram_config = TEST_SKIP_CONFIG.clone();
        let mut builder: VocabBuilder<SubwordVocabConfig, String> = VocabBuilder::new(vocab_config);
        builder.count("bla".to_string());
        let vocab: SubwordVocab = builder.into();

        // Fixture matrices with easily recognizable rows.
        let input = Array2::from_shape_vec((2, 3), vec![1., 2., 3., 4., 5., 6.])
            .unwrap()
            .into();
        let output = Array2::from_shape_vec((2, 3), vec![-1., -2., -3., -4., -5., -6.])
            .unwrap()
            .into();

        let mut model = TrainModel {
            trainer: SkipgramTrainer::new(
                vocab,
                XorShiftRng::from_entropy(),
                common_config,
                skipgram_config,
            ),
            input,
            output,
        };

        // Input embedding views return the corresponding matrix rows.
        assert!(all_close(
            model.input_embedding(0).as_slice().unwrap(),
            &[1., 2., 3.],
            1e-5
        ));
        assert!(all_close(
            model.input_embedding(1).as_slice().unwrap(),
            &[4., 5., 6.],
            1e-5
        ));
        assert!(all_close(
            model.input_embedding_mut(0).as_slice().unwrap(),
            &[1., 2., 3.],
            1e-5
        ));
        assert!(all_close(
            model.input_embedding_mut(1).as_slice().unwrap(),
            &[4., 5., 6.],
            1e-5
        ));

        // Output embedding views return the corresponding matrix rows.
        assert!(all_close(
            model.output_embedding(0).as_slice().unwrap(),
            &[-1., -2., -3.],
            1e-5
        ));
        assert!(all_close(
            model.output_embedding(1).as_slice().unwrap(),
            &[-4., -5., -6.],
            1e-5
        ));
        assert!(all_close(
            model.output_embedding_mut(0).as_slice().unwrap(),
            &[-1., -2., -3.],
            1e-5
        ));
        assert!(all_close(
            model.output_embedding_mut(1).as_slice().unwrap(),
            &[-4., -5., -6.],
            1e-5
        ));

        // Mean of rows [1,2,3] and [4,5,6] is [2.5, 3.5, 4.5].
        assert!(all_close(
            model.mean_input_embedding(&[0, 1]).as_slice().unwrap(),
            &[2.5, 3.5, 4.5],
            1e-5
        ));
    }
}