1use std::borrow::Borrow;
2use std::hash::Hash;
3use std::iter::FusedIterator;
4use std::sync::Arc;
5use std::{cmp, mem};
6
7use anyhow::{bail, Result};
8use rand::{Rng, SeedableRng};
9use serde::Serialize;
10
11use crate::idx::WordIdx;
12use crate::sampling::{BandedRangeGenerator, ZipfRangeGenerator};
13use crate::train_model::{NegativeSamples, TrainIterFrom, Trainer};
14use crate::util::ReseedOnCloneRng;
15use crate::{CommonConfig, ModelType, SkipGramConfig, Vocab};
16
/// Skip-gram trainer over a shared vocabulary.
///
/// The trainer is `Clone` so that each training thread can hold its own
/// copy; the vocabulary is shared through the `Arc`, while the RNG's own
/// `Clone` semantics decide how the random stream is split (see
/// `ReseedOnCloneRng`, used by `SkipgramTrainer::new`).
#[derive(Clone)]
pub struct SkipgramTrainer<R, V> {
    // Shared, read-only vocabulary.
    vocab: Arc<V>,
    // RNG used for frequent-word subsampling in `train_iter_from`.
    rng: R,
    // Negative-sample generator: Zipf-distributed word indices, banded
    // per output type (band size chosen from the model in `new`).
    range_gen: BandedRangeGenerator<R, ZipfRangeGenerator<R>>,
    common_config: CommonConfig,
    skipgram_config: SkipGramConfig,
}
30
31impl<R, V> SkipgramTrainer<ReseedOnCloneRng<R>, V>
32where
33 R: Rng + Clone + SeedableRng,
34 V: Vocab,
35{
36 pub fn new(
38 vocab: V,
39 rng: R,
40 common_config: CommonConfig,
41 skipgram_config: SkipGramConfig,
42 ) -> Self {
43 let vocab = Arc::new(vocab);
44 let rng = ReseedOnCloneRng(rng);
45 let band_size = match skipgram_config.model {
46 ModelType::SkipGram => 1,
47 ModelType::StructuredSkipGram => skipgram_config.context_size * 2,
48 ModelType::DirectionalSkipgram => 2,
49 };
50
51 let range_gen = BandedRangeGenerator::new(
52 rng.clone(),
53 ZipfRangeGenerator::new_with_exponent(
54 rng.clone(),
55 vocab.len(),
56 common_config.zipf_exponent,
57 ),
58 band_size as usize,
59 );
60 SkipgramTrainer {
61 vocab,
62 rng,
63 range_gen,
64 common_config,
65 skipgram_config,
66 }
67 }
68}
69
70impl<'a, S, R, V, I> TrainIterFrom<'a, [S]> for SkipgramTrainer<R, V>
71where
72 S: Hash + Eq,
73 R: Rng + Clone,
74 V: Vocab<IdxType = I>,
75 V::VocabType: Borrow<S>,
76 I: WordIdx,
77{
78 type Iter = SkipGramIter<R, I>;
79 type Focus = I;
80 type Contexts = Vec<usize>;
81
82 fn train_iter_from(&mut self, sequence: &[S]) -> Self::Iter {
83 let mut ids = Vec::new();
84 for t in sequence {
85 if let Some(idx) = self.vocab.idx(t) {
86 if self.rng.gen_range(0f32, 1f32) < self.vocab.discard(idx.word_idx() as usize) {
87 ids.push(idx);
88 }
89 }
90 }
91 SkipGramIter::new(self.rng.clone(), ids, self.skipgram_config)
92 }
93}
94
95impl<R, V> NegativeSamples for SkipgramTrainer<R, V>
96where
97 R: Rng,
98{
99 fn negative_sample(&mut self, output: usize) -> usize {
100 loop {
101 let negative = self.range_gen.next().unwrap();
102 if negative != output {
103 return negative;
104 }
105 }
106 }
107}
108
109impl<R, V> Trainer for SkipgramTrainer<R, V>
110where
111 R: Rng + Clone,
112 V: Vocab,
113 V::Config: Serialize,
114{
115 type InputVocab = V;
116 type Metadata = SkipgramMetadata<V::Config>;
117
118 fn input_vocab(&self) -> &V {
119 &self.vocab
120 }
121
122 fn try_into_input_vocab(self) -> Result<V> {
123 match Arc::try_unwrap(self.vocab) {
124 Ok(vocab) => Ok(vocab),
125 Err(_) => bail!("Cannot unwrap input vocab."),
126 }
127 }
128
129 fn n_input_types(&self) -> usize {
130 self.input_vocab().n_input_types()
131 }
132
133 fn n_output_types(&self) -> usize {
134 match self.skipgram_config.model {
135 ModelType::StructuredSkipGram => {
136 self.vocab.len() * 2 * self.skipgram_config.context_size as usize
137 }
138 ModelType::SkipGram => self.vocab.len(),
139 ModelType::DirectionalSkipgram => self.vocab.len() * 2,
140 }
141 }
142
143 fn config(&self) -> &CommonConfig {
144 &self.common_config
145 }
146
147 fn to_metadata(&self) -> SkipgramMetadata<V::Config> {
148 SkipgramMetadata {
149 common_config: self.common_config,
150 skipgram_config: self.skipgram_config,
151 vocab_config: self.vocab.config(),
152 }
153 }
154}
155
/// Iterator over the (focus index, context outputs) pairs of a single
/// subsampled sentence.
pub struct SkipGramIter<R, I> {
    // Vocabulary indices of the retained tokens, in sentence order.
    ids: Vec<I>,
    // RNG for sampling the per-focus effective window size.
    rng: R,
    // Position of the next focus token.
    i: usize,
    model_type: ModelType,
    // Maximum context window size (tokens on each side of the focus).
    ctx_size: usize,
}
164
165impl<R, I> SkipGramIter<R, I>
166where
167 R: Rng + Clone,
168 I: WordIdx,
169{
170 pub fn new(rng: R, ids: Vec<I>, skip_config: SkipGramConfig) -> Self {
174 SkipGramIter {
175 ids,
176 rng,
177 i: 0,
178 model_type: skip_config.model,
179 ctx_size: skip_config.context_size as usize,
180 }
181 }
182
183 fn output_(&self, token: usize, focus_idx: usize, offset_idx: usize) -> usize {
184 match self.model_type {
185 ModelType::StructuredSkipGram => {
186 let offset = if offset_idx < focus_idx {
187 (offset_idx + self.ctx_size) - focus_idx
188 } else {
189 (offset_idx - focus_idx - 1) + self.ctx_size
190 };
191
192 (token * self.ctx_size * 2) + offset
193 }
194 ModelType::SkipGram => token,
195 ModelType::DirectionalSkipgram => {
196 let offset = if offset_idx < focus_idx { 0 } else { 1 };
197
198 (token * 2) + offset
199 }
200 }
201 }
202}
203
204impl<R, I> Iterator for SkipGramIter<R, I>
205where
206 R: Rng + Clone,
207 I: WordIdx,
208{
209 type Item = (I, Vec<usize>);
210
211 fn next(&mut self) -> Option<Self::Item> {
212 if self.i < self.ids.len() {
213 let context_size = self.rng.gen_range(1, self.ctx_size + 1) as usize;
215 let left = self.i - cmp::min(self.i, context_size);
216 let right = cmp::min(self.i + context_size + 1, self.ids.len());
217 let contexts = (left..right)
218 .filter(|&idx| idx != self.i)
219 .map(|idx| self.output_(self.ids[idx].word_idx() as usize, self.i, idx))
220 .fold(Vec::with_capacity(right - left), |mut contexts, idx| {
221 contexts.push(idx);
222 contexts
223 });
224
225 let mut word_idx = WordIdx::from_word_idx(self.ids[self.i].word_idx());
228 mem::swap(&mut self.ids[self.i], &mut word_idx);
229 self.i += 1;
230 return Some((word_idx, contexts));
231 }
232 None
233 }
234}
235
// `next` returns `None` only once `i` has reached `ids.len()`, and `i`
// never decreases, so the iterator keeps returning `None` afterwards —
// the `FusedIterator` contract holds trivially.
impl<R, I> FusedIterator for SkipGramIter<R, I>
where
    R: Rng + Clone,
    I: WordIdx,
{
}
242
/// Serializable snapshot of the configuration a model was trained with —
/// presumably persisted alongside the trained model (see
/// `Trainer::to_metadata`).
#[derive(Clone, Copy, Debug, Serialize)]
pub struct SkipgramMetadata<V> {
    common_config: CommonConfig,
    // Serialized under the model-agnostic key "model_config".
    #[serde(rename = "model_config")]
    skipgram_config: SkipGramConfig,
    vocab_config: V,
}