use std::borrow::Borrow;
use std::hash::Hash;
use std::iter::FusedIterator;
use std::sync::Arc;
use std::{cmp, mem};
use anyhow::{bail, Result};
use rand::{Rng, SeedableRng};
use serde::Serialize;
use crate::idx::WordIdx;
use crate::sampling::{BandedRangeGenerator, ZipfRangeGenerator};
use crate::train_model::{NegativeSamples, TrainIterFrom, Trainer};
use crate::util::ReseedOnCloneRng;
use crate::{CommonConfig, ModelType, SkipGramConfig, Vocab};
/// Trainer state for the skip-gram family of models.
///
/// `R` is the random number generator type and `V` the vocabulary type.
#[derive(Clone)]
pub struct SkipgramTrainer<R, V> {
    /// Vocabulary, held behind an `Arc`.
    vocab: Arc<V>,
    /// Generator used for the per-token subsampling draw; a clone of it is
    /// handed to every training iterator that is created.
    rng: R,
    /// Source of negative-sample candidates.
    range_gen: BandedRangeGenerator<R, ZipfRangeGenerator<R>>,
    /// Configuration shared by all model types.
    common_config: CommonConfig,
    /// Skip-gram specific configuration (model variant and context size).
    skipgram_config: SkipGramConfig,
}
impl<R, V> SkipgramTrainer<ReseedOnCloneRng<R>, V>
where
    R: Rng + Clone + SeedableRng,
    V: Vocab,
{
    /// Build a trainer from a vocabulary, a random number generator and the
    /// two configuration structs.
    ///
    /// The generator is wrapped in `ReseedOnCloneRng` before it is cloned
    /// into the negative-sampling machinery. Negative samples come from a
    /// `BandedRangeGenerator` layered over a `ZipfRangeGenerator` that spans
    /// `vocab.len()` entries and uses `common_config.zipf_exponent`.
    pub fn new(
        vocab: V,
        rng: R,
        common_config: CommonConfig,
        skipgram_config: SkipGramConfig,
    ) -> Self {
        let vocab = Arc::new(vocab);
        let rng = ReseedOnCloneRng(rng);

        // Band width is the same per-word factor that `n_output_types`
        // multiplies the vocabulary size by for each model variant.
        let band_size = match skipgram_config.model {
            ModelType::StructuredSkipGram => skipgram_config.context_size * 2,
            ModelType::DirectionalSkipgram => 2,
            ModelType::SkipGram => 1,
        };

        // The banded generator's clone is taken before the Zipf generator's,
        // which is the order in which the two clones were always made.
        let banded_rng = rng.clone();
        let zipf_gen = ZipfRangeGenerator::new_with_exponent(
            rng.clone(),
            vocab.len(),
            common_config.zipf_exponent,
        );
        let range_gen = BandedRangeGenerator::new(banded_rng, zipf_gen, band_size as usize);

        SkipgramTrainer {
            vocab,
            rng,
            range_gen,
            common_config,
            skipgram_config,
        }
    }
}
impl<'a, S, R, V, I> TrainIterFrom<'a, [S]> for SkipgramTrainer<R, V>
where
    S: Hash + Eq,
    R: Rng + Clone,
    V: Vocab<IdxType = I>,
    V::VocabType: Borrow<S>,
    I: WordIdx,
{
    type Iter = SkipGramIter<R, I>;
    type Focus = I;
    type Contexts = Vec<usize>;

    /// Turn a token sequence into an iterator over training examples.
    ///
    /// Tokens without a vocabulary index are dropped. For every remaining
    /// token one uniform value in `[0, 1)` is drawn, and the token is kept
    /// only when that value is below `vocab.discard(..)` for its word index.
    /// The surviving indices are passed to a new `SkipGramIter` together
    /// with a clone of the trainer's generator.
    fn train_iter_from(&mut self, sequence: &[S]) -> Self::Iter {
        // Borrow the two fields separately so both closures can use them.
        let vocab = &self.vocab;
        let rng = &mut self.rng;

        let ids: Vec<I> = sequence
            .iter()
            .filter_map(|token| vocab.idx(token))
            .filter(|idx| rng.gen_range(0f32, 1f32) < vocab.discard(idx.word_idx() as usize))
            .collect();

        SkipGramIter::new(self.rng.clone(), ids, self.skipgram_config)
    }
}
impl<R, V> NegativeSamples for SkipgramTrainer<R, V>
where
    R: Rng,
{
    /// Draw an output index that differs from `output`.
    ///
    /// Candidates are taken from the range generator one at a time and
    /// redrawn for as long as they equal `output`.
    ///
    /// # Panics
    ///
    /// Panics if the range generator yields `None`.
    fn negative_sample(&mut self, output: usize) -> usize {
        let mut candidate = self.range_gen.next().unwrap();
        while candidate == output {
            candidate = self.range_gen.next().unwrap();
        }
        candidate
    }
}
impl<R, V> Trainer for SkipgramTrainer<R, V>
where
    R: Rng + Clone,
    V: Vocab,
    V::Config: Serialize,
{
    type InputVocab = V;
    type Metadata = SkipgramMetadata<V::Config>;

    /// Borrow the input vocabulary.
    fn input_vocab(&self) -> &V {
        self.vocab.as_ref()
    }

    /// Take the vocabulary out of the trainer.
    ///
    /// # Errors
    ///
    /// Fails when the `Arc` holding the vocabulary cannot be unwrapped,
    /// which `Arc::try_unwrap` reports when other strong references exist.
    fn try_into_input_vocab(self) -> Result<V> {
        if let Ok(vocab) = Arc::try_unwrap(self.vocab) {
            return Ok(vocab);
        }
        bail!("Cannot unwrap input vocab.")
    }

    /// Number of input types, as reported by the vocabulary.
    fn n_input_types(&self) -> usize {
        self.vocab.n_input_types()
    }

    /// Number of output types for the configured model variant.
    ///
    /// This is the vocabulary length for plain skip-gram, twice that for
    /// directional skip-gram, and `2 * context_size` times that for
    /// structured skip-gram.
    fn n_output_types(&self) -> usize {
        let vocab_len = self.vocab.len();
        match self.skipgram_config.model {
            ModelType::SkipGram => vocab_len,
            ModelType::DirectionalSkipgram => vocab_len * 2,
            ModelType::StructuredSkipGram => {
                vocab_len * 2 * self.skipgram_config.context_size as usize
            }
        }
    }

    /// Borrow the common configuration.
    fn config(&self) -> &CommonConfig {
        &self.common_config
    }

    /// Bundle both configurations and the vocabulary's configuration into
    /// a metadata value.
    fn to_metadata(&self) -> SkipgramMetadata<V::Config> {
        let vocab_config = self.vocab.config();
        SkipgramMetadata {
            common_config: self.common_config,
            skipgram_config: self.skipgram_config,
            vocab_config,
        }
    }
}
/// Iterator over `(focus, contexts)` training examples of one sequence.
pub struct SkipGramIter<R, I> {
    /// Indices of the sequence's tokens, in order.
    ids: Vec<I>,
    /// Generator used to draw the window size for each focus position.
    rng: R,
    /// Position in `ids` of the next focus to yield.
    i: usize,
    /// Model variant; selects how context tokens map to output indices.
    model_type: ModelType,
    /// Upper bound (inclusive) of the window size drawn per focus.
    ctx_size: usize,
}
impl<R, I> SkipGramIter<R, I>
where
    R: Rng + Clone,
    I: WordIdx,
{
    /// Create an iterator over `ids`, starting at the first position.
    ///
    /// The model variant and the context size are read from `skip_config`.
    pub fn new(rng: R, ids: Vec<I>, skip_config: SkipGramConfig) -> Self {
        let model_type = skip_config.model;
        let ctx_size = skip_config.context_size as usize;
        SkipGramIter {
            i: 0,
            ids,
            rng,
            model_type,
            ctx_size,
        }
    }

    /// Map a context token to its output index.
    ///
    /// `token` is the word index of the context token, `focus_idx` the
    /// position of the focus and `offset_idx` the position of the context
    /// token within the sequence.
    ///
    /// * Plain skip-gram uses `token` unchanged.
    /// * Directional skip-gram reserves two slots per token: slot 0 for a
    ///   context before the focus, slot 1 otherwise.
    /// * Structured skip-gram reserves `2 * ctx_size` slots per token, one
    ///   per relative position; positions before the focus fill the lower
    ///   half and positions after it the upper half.
    fn output_(&self, token: usize, focus_idx: usize, offset_idx: usize) -> usize {
        let precedes_focus = offset_idx < focus_idx;
        match self.model_type {
            ModelType::SkipGram => token,
            ModelType::DirectionalSkipgram => {
                let side = if precedes_focus { 0 } else { 1 };
                (token * 2) + side
            }
            ModelType::StructuredSkipGram => {
                let slot = if precedes_focus {
                    (offset_idx + self.ctx_size) - focus_idx
                } else {
                    (offset_idx - focus_idx - 1) + self.ctx_size
                };
                (token * self.ctx_size * 2) + slot
            }
        }
    }
}
impl<R, I> Iterator for SkipGramIter<R, I>
where
    R: Rng + Clone,
    I: WordIdx,
{
    type Item = (I, Vec<usize>);

    /// Yield the next focus together with the output indices of its
    /// context tokens, or `None` once every position has been visited.
    fn next(&mut self) -> Option<Self::Item> {
        let focus_pos = self.i;
        if focus_pos >= self.ids.len() {
            return None;
        }

        // Window size for this focus, drawn from `1..=ctx_size`.
        let window: usize = self.rng.gen_range(1, self.ctx_size + 1);

        // Clamp the window to the bounds of the sequence.
        let left = focus_pos.saturating_sub(window);
        let right = cmp::min(focus_pos + window + 1, self.ids.len());

        let mut contexts = Vec::with_capacity(right - left);
        contexts.extend(
            (left..right)
                .filter(|&pos| pos != focus_pos)
                .map(|pos| self.output_(self.ids[pos].word_idx() as usize, focus_pos, pos)),
        );

        // Move the stored index out to the caller and leave behind one that
        // is rebuilt from the word index alone; later windows read only
        // `word_idx()` from this slot.
        let word_only = I::from_word_idx(self.ids[focus_pos].word_idx());
        let focus = mem::replace(&mut self.ids[focus_pos], word_only);

        self.i += 1;
        Some((focus, contexts))
    }
}
// `next` returns `None` whenever `i` has reached `ids.len()`, and `i` is
// only ever incremented, so the iterator keeps returning `None` once it is
// exhausted.
impl<R, I> FusedIterator for SkipGramIter<R, I>
where
    R: Rng + Clone,
    I: WordIdx,
{
}
/// Serializable metadata describing a skip-gram training run.
///
/// `V` is the type of the vocabulary configuration.
#[derive(Clone, Copy, Debug, Serialize)]
pub struct SkipgramMetadata<V> {
    /// Configuration shared by all model types.
    common_config: CommonConfig,
    /// Skip-gram specific configuration; serialized under the key
    /// `model_config`.
    #[serde(rename = "model_config")]
    skipgram_config: SkipGramConfig,
    /// Configuration of the vocabulary.
    vocab_config: V,
}