//! Skipgram trainer (finalfrontier: `skipgram_trainer.rs`).

use std::borrow::Borrow;
use std::hash::Hash;
use std::iter::FusedIterator;
use std::sync::Arc;
use std::{cmp, mem};

use anyhow::{bail, Result};
use rand::{Rng, SeedableRng};
use serde::Serialize;

use crate::idx::WordIdx;
use crate::sampling::{BandedRangeGenerator, ZipfRangeGenerator};
use crate::train_model::{NegativeSamples, TrainIterFrom, Trainer};
use crate::util::ReseedOnCloneRng;
use crate::{CommonConfig, ModelType, SkipGramConfig, Vocab};
16
17/// Skipgram Trainer
18///
19/// The `SkipgramTrainer` holds the information and logic necessary to transform a tokenized
20/// sentence into an iterator of focus and context tuples. The struct is cheap to clone because
21/// the vocabulary is shared between clones.
22#[derive(Clone)]
23pub struct SkipgramTrainer<R, V> {
24    vocab: Arc<V>,
25    rng: R,
26    range_gen: BandedRangeGenerator<R, ZipfRangeGenerator<R>>,
27    common_config: CommonConfig,
28    skipgram_config: SkipGramConfig,
29}
30
31impl<R, V> SkipgramTrainer<ReseedOnCloneRng<R>, V>
32where
33    R: Rng + Clone + SeedableRng,
34    V: Vocab,
35{
36    /// Constructs a new `SkipgramTrainer`.
37    pub fn new(
38        vocab: V,
39        rng: R,
40        common_config: CommonConfig,
41        skipgram_config: SkipGramConfig,
42    ) -> Self {
43        let vocab = Arc::new(vocab);
44        let rng = ReseedOnCloneRng(rng);
45        let band_size = match skipgram_config.model {
46            ModelType::SkipGram => 1,
47            ModelType::StructuredSkipGram => skipgram_config.context_size * 2,
48            ModelType::DirectionalSkipgram => 2,
49        };
50
51        let range_gen = BandedRangeGenerator::new(
52            rng.clone(),
53            ZipfRangeGenerator::new_with_exponent(
54                rng.clone(),
55                vocab.len(),
56                common_config.zipf_exponent,
57            ),
58            band_size as usize,
59        );
60        SkipgramTrainer {
61            vocab,
62            rng,
63            range_gen,
64            common_config,
65            skipgram_config,
66        }
67    }
68}
69
70impl<'a, S, R, V, I> TrainIterFrom<'a, [S]> for SkipgramTrainer<R, V>
71where
72    S: Hash + Eq,
73    R: Rng + Clone,
74    V: Vocab<IdxType = I>,
75    V::VocabType: Borrow<S>,
76    I: WordIdx,
77{
78    type Iter = SkipGramIter<R, I>;
79    type Focus = I;
80    type Contexts = Vec<usize>;
81
82    fn train_iter_from(&mut self, sequence: &[S]) -> Self::Iter {
83        let mut ids = Vec::new();
84        for t in sequence {
85            if let Some(idx) = self.vocab.idx(t) {
86                if self.rng.gen_range(0f32, 1f32) < self.vocab.discard(idx.word_idx() as usize) {
87                    ids.push(idx);
88                }
89            }
90        }
91        SkipGramIter::new(self.rng.clone(), ids, self.skipgram_config)
92    }
93}
94
95impl<R, V> NegativeSamples for SkipgramTrainer<R, V>
96where
97    R: Rng,
98{
99    fn negative_sample(&mut self, output: usize) -> usize {
100        loop {
101            let negative = self.range_gen.next().unwrap();
102            if negative != output {
103                return negative;
104            }
105        }
106    }
107}
108
109impl<R, V> Trainer for SkipgramTrainer<R, V>
110where
111    R: Rng + Clone,
112    V: Vocab,
113    V::Config: Serialize,
114{
115    type InputVocab = V;
116    type Metadata = SkipgramMetadata<V::Config>;
117
118    fn input_vocab(&self) -> &V {
119        &self.vocab
120    }
121
122    fn try_into_input_vocab(self) -> Result<V> {
123        match Arc::try_unwrap(self.vocab) {
124            Ok(vocab) => Ok(vocab),
125            Err(_) => bail!("Cannot unwrap input vocab."),
126        }
127    }
128
129    fn n_input_types(&self) -> usize {
130        self.input_vocab().n_input_types()
131    }
132
133    fn n_output_types(&self) -> usize {
134        match self.skipgram_config.model {
135            ModelType::StructuredSkipGram => {
136                self.vocab.len() * 2 * self.skipgram_config.context_size as usize
137            }
138            ModelType::SkipGram => self.vocab.len(),
139            ModelType::DirectionalSkipgram => self.vocab.len() * 2,
140        }
141    }
142
143    fn config(&self) -> &CommonConfig {
144        &self.common_config
145    }
146
147    fn to_metadata(&self) -> SkipgramMetadata<V::Config> {
148        SkipgramMetadata {
149            common_config: self.common_config,
150            skipgram_config: self.skipgram_config,
151            vocab_config: self.vocab.config(),
152        }
153    }
154}
155
156/// Iterator over focus identifier and associated context identifiers in a sentence.
157pub struct SkipGramIter<R, I> {
158    ids: Vec<I>,
159    rng: R,
160    i: usize,
161    model_type: ModelType,
162    ctx_size: usize,
163}
164
165impl<R, I> SkipGramIter<R, I>
166where
167    R: Rng + Clone,
168    I: WordIdx,
169{
170    /// Constructs a new `SkipGramIter`.
171    ///
172    /// The `rng` is used to determine the window size for each focus token.
173    pub fn new(rng: R, ids: Vec<I>, skip_config: SkipGramConfig) -> Self {
174        SkipGramIter {
175            ids,
176            rng,
177            i: 0,
178            model_type: skip_config.model,
179            ctx_size: skip_config.context_size as usize,
180        }
181    }
182
183    fn output_(&self, token: usize, focus_idx: usize, offset_idx: usize) -> usize {
184        match self.model_type {
185            ModelType::StructuredSkipGram => {
186                let offset = if offset_idx < focus_idx {
187                    (offset_idx + self.ctx_size) - focus_idx
188                } else {
189                    (offset_idx - focus_idx - 1) + self.ctx_size
190                };
191
192                (token * self.ctx_size * 2) + offset
193            }
194            ModelType::SkipGram => token,
195            ModelType::DirectionalSkipgram => {
196                let offset = if offset_idx < focus_idx { 0 } else { 1 };
197
198                (token * 2) + offset
199            }
200        }
201    }
202}
203
204impl<R, I> Iterator for SkipGramIter<R, I>
205where
206    R: Rng + Clone,
207    I: WordIdx,
208{
209    type Item = (I, Vec<usize>);
210
211    fn next(&mut self) -> Option<Self::Item> {
212        if self.i < self.ids.len() {
213            // Bojanowski, et al., 2017 uniformly sample the context size between 1 and c.
214            let context_size = self.rng.gen_range(1, self.ctx_size + 1) as usize;
215            let left = self.i - cmp::min(self.i, context_size);
216            let right = cmp::min(self.i + context_size + 1, self.ids.len());
217            let contexts = (left..right)
218                .filter(|&idx| idx != self.i)
219                .map(|idx| self.output_(self.ids[idx].word_idx() as usize, self.i, idx))
220                .fold(Vec::with_capacity(right - left), |mut contexts, idx| {
221                    contexts.push(idx);
222                    contexts
223                });
224
225            // swap the representation possibly containing multiple indices with one that only
226            // contains the distinct word index since we need the word index for context lookups.
227            let mut word_idx = WordIdx::from_word_idx(self.ids[self.i].word_idx());
228            mem::swap(&mut self.ids[self.i], &mut word_idx);
229            self.i += 1;
230            return Some((word_idx, contexts));
231        }
232        None
233    }
234}
235
236impl<R, I> FusedIterator for SkipGramIter<R, I>
237where
238    R: Rng + Clone,
239    I: WordIdx,
240{
241}
242
243/// Metadata for Skipgramlike training algorithms.
244#[derive(Clone, Copy, Debug, Serialize)]
245pub struct SkipgramMetadata<V> {
246    common_config: CommonConfig,
247    #[serde(rename = "model_config")]
248    skipgram_config: SkipGramConfig,
249    vocab_config: V,
250}