finalfrontier/
dep_trainer.rs1use std::borrow::Borrow;
2use std::sync::Arc;
3
4use anyhow::{bail, Result};
5use conllu::graph::Sentence;
6use rand::{Rng, SeedableRng};
7use serde::Serialize;
8
9use crate::idx::WordIdx;
10use crate::sampling::ZipfRangeGenerator;
11use crate::train_model::{NegativeSamples, TrainIterFrom};
12use crate::util::ReseedOnCloneRng;
13use crate::{
14 CommonConfig, DepembedsConfig, Dependency, DependencyIterator, SimpleVocab, SimpleVocabConfig,
15 Trainer, Vocab,
16};
17
18#[derive(Clone)]
24pub struct DepembedsTrainer<R, V> {
25 dep_config: DepembedsConfig,
26 common_config: CommonConfig,
27 input_vocab: Arc<V>,
28 output_vocab: Arc<SimpleVocab<Dependency>>,
29 range_gen: ZipfRangeGenerator<R>,
30 rng: R,
31}
32
33impl<R, V> DepembedsTrainer<R, V> {
34 pub fn dep_config(&self) -> DepembedsConfig {
35 self.dep_config
36 }
37}
38
39impl<R, V> DepembedsTrainer<ReseedOnCloneRng<R>, V>
40where
41 R: Rng + Clone + SeedableRng,
42{
43 pub fn new(
45 input_vocab: V,
46 output_vocab: SimpleVocab<Dependency>,
47 common_config: CommonConfig,
48 dep_config: DepembedsConfig,
49 rng: R,
50 ) -> Self {
51 let rng = ReseedOnCloneRng(rng);
52 let range_gen = ZipfRangeGenerator::new_with_exponent(
53 rng.clone(),
54 output_vocab.len(),
55 common_config.zipf_exponent,
56 );
57 DepembedsTrainer {
58 common_config,
59 dep_config,
60 input_vocab: Arc::new(input_vocab),
61 output_vocab: Arc::new(output_vocab),
62 range_gen,
63 rng,
64 }
65 }
66}
67
68impl<R, V> NegativeSamples for DepembedsTrainer<R, V>
69where
70 R: Rng,
71{
72 fn negative_sample(&mut self, output: usize) -> usize {
73 loop {
74 let negative = self.range_gen.next().unwrap();
75 if negative != output {
76 return negative;
77 }
78 }
79 }
80}
81
82impl<'a, R, V> TrainIterFrom<'a, Sentence> for DepembedsTrainer<R, V>
83where
84 R: Rng,
85 V: Vocab,
86 V::VocabType: Borrow<str>,
87 V::IdxType: WordIdx + 'a,
88{
89 type Iter = Box<dyn Iterator<Item = (Self::Focus, Vec<usize>)> + 'a>;
90 type Focus = V::IdxType;
91 type Contexts = Vec<usize>;
92
93 fn train_iter_from(&mut self, sentence: &Sentence) -> Self::Iter {
94 let invalid_idx = self.input_vocab.len() as u64;
95 let mut tokens = vec![WordIdx::from_word_idx(invalid_idx); sentence.len() - 1];
96 for (idx, token) in sentence.iter().filter_map(|node| node.token()).enumerate() {
97 if let Some(vocab_idx) = self.input_vocab.idx(token.form()) {
98 if self.rng.gen_range(0f32, 1f32)
99 < self.input_vocab.discard(vocab_idx.word_idx() as usize)
100 {
101 tokens[idx] = vocab_idx
102 }
103 }
104 }
105
106 let mut contexts = vec![Vec::new(); sentence.len() - 1];
107 let graph = sentence.dep_graph();
108 for (focus, dep) in DependencyIterator::new_from_config(&graph, self.dep_config)
109 .filter(|(focus, _dep)| tokens[*focus].word_idx() != invalid_idx)
110 {
111 if let Some(dep_id) = self.output_vocab.idx(&dep) {
112 if self.rng.gen_range(0f32, 1f32) < self.output_vocab.discard(dep_id.idx() as usize)
113 {
114 contexts[focus].push(dep_id.idx() as usize)
115 }
116 }
117 }
118 Box::new(
119 tokens
120 .into_iter()
121 .zip(contexts.into_iter())
122 .filter(move |(focus, _)| focus.word_idx() != invalid_idx),
123 )
124 }
125}
126
127impl<R, V> Trainer for DepembedsTrainer<R, V>
128where
129 R: Rng,
130 V: Vocab,
131 V::Config: Serialize,
132{
133 type InputVocab = V;
134 type Metadata = DepembedsMetadata<V::Config, SimpleVocabConfig>;
135
136 fn input_vocab(&self) -> &Self::InputVocab {
137 &self.input_vocab
138 }
139
140 fn try_into_input_vocab(self) -> Result<Self::InputVocab> {
141 match Arc::try_unwrap(self.input_vocab) {
142 Ok(vocab) => Ok(vocab),
143 Err(_) => bail!("Cannot unwrap input vocab."),
144 }
145 }
146
147 fn n_input_types(&self) -> usize {
148 self.input_vocab.n_input_types()
149 }
150
151 fn n_output_types(&self) -> usize {
152 self.output_vocab.len()
153 }
154
155 fn config(&self) -> &CommonConfig {
156 &self.common_config
157 }
158
159 fn to_metadata(&self) -> Self::Metadata {
160 DepembedsMetadata {
161 common_config: self.common_config,
162 dep_config: self.dep_config,
163 input_vocab_config: self.input_vocab.config(),
164 output_vocab_config: self.output_vocab.config(),
165 }
166 }
167}
168
169#[derive(Clone, Copy, Debug, Serialize)]
171pub struct DepembedsMetadata<IC, OC> {
172 common_config: CommonConfig,
173 #[serde(rename = "model_config")]
174 dep_config: DepembedsConfig,
175 input_vocab_config: IC,
176 output_vocab_config: OC,
177}