finalfrontier/
dep_trainer.rs

1use std::borrow::Borrow;
2use std::sync::Arc;
3
4use anyhow::{bail, Result};
5use conllu::graph::Sentence;
6use rand::{Rng, SeedableRng};
7use serde::Serialize;
8
9use crate::idx::WordIdx;
10use crate::sampling::ZipfRangeGenerator;
11use crate::train_model::{NegativeSamples, TrainIterFrom};
12use crate::util::ReseedOnCloneRng;
13use crate::{
14    CommonConfig, DepembedsConfig, Dependency, DependencyIterator, SimpleVocab, SimpleVocabConfig,
15    Trainer, Vocab,
16};
17
18/// Dependency embeddings Trainer.
19///
20/// The `DepembedsTrainer` holds the information and logic necessary to transform a
21/// `conllu::Sentence` into an iterator of focus and context tuples. The struct is cheap to clone
22/// because the vocabulary is shared between clones.
23#[derive(Clone)]
24pub struct DepembedsTrainer<R, V> {
25    dep_config: DepembedsConfig,
26    common_config: CommonConfig,
27    input_vocab: Arc<V>,
28    output_vocab: Arc<SimpleVocab<Dependency>>,
29    range_gen: ZipfRangeGenerator<R>,
30    rng: R,
31}
32
33impl<R, V> DepembedsTrainer<R, V> {
34    pub fn dep_config(&self) -> DepembedsConfig {
35        self.dep_config
36    }
37}
38
39impl<R, V> DepembedsTrainer<ReseedOnCloneRng<R>, V>
40where
41    R: Rng + Clone + SeedableRng,
42{
43    /// Constructs a new `DepTrainer`.
44    pub fn new(
45        input_vocab: V,
46        output_vocab: SimpleVocab<Dependency>,
47        common_config: CommonConfig,
48        dep_config: DepembedsConfig,
49        rng: R,
50    ) -> Self {
51        let rng = ReseedOnCloneRng(rng);
52        let range_gen = ZipfRangeGenerator::new_with_exponent(
53            rng.clone(),
54            output_vocab.len(),
55            common_config.zipf_exponent,
56        );
57        DepembedsTrainer {
58            common_config,
59            dep_config,
60            input_vocab: Arc::new(input_vocab),
61            output_vocab: Arc::new(output_vocab),
62            range_gen,
63            rng,
64        }
65    }
66}
67
68impl<R, V> NegativeSamples for DepembedsTrainer<R, V>
69where
70    R: Rng,
71{
72    fn negative_sample(&mut self, output: usize) -> usize {
73        loop {
74            let negative = self.range_gen.next().unwrap();
75            if negative != output {
76                return negative;
77            }
78        }
79    }
80}
81
82impl<'a, R, V> TrainIterFrom<'a, Sentence> for DepembedsTrainer<R, V>
83where
84    R: Rng,
85    V: Vocab,
86    V::VocabType: Borrow<str>,
87    V::IdxType: WordIdx + 'a,
88{
89    type Iter = Box<dyn Iterator<Item = (Self::Focus, Vec<usize>)> + 'a>;
90    type Focus = V::IdxType;
91    type Contexts = Vec<usize>;
92
93    fn train_iter_from(&mut self, sentence: &Sentence) -> Self::Iter {
94        let invalid_idx = self.input_vocab.len() as u64;
95        let mut tokens = vec![WordIdx::from_word_idx(invalid_idx); sentence.len() - 1];
96        for (idx, token) in sentence.iter().filter_map(|node| node.token()).enumerate() {
97            if let Some(vocab_idx) = self.input_vocab.idx(token.form()) {
98                if self.rng.gen_range(0f32, 1f32)
99                    < self.input_vocab.discard(vocab_idx.word_idx() as usize)
100                {
101                    tokens[idx] = vocab_idx
102                }
103            }
104        }
105
106        let mut contexts = vec![Vec::new(); sentence.len() - 1];
107        let graph = sentence.dep_graph();
108        for (focus, dep) in DependencyIterator::new_from_config(&graph, self.dep_config)
109            .filter(|(focus, _dep)| tokens[*focus].word_idx() != invalid_idx)
110        {
111            if let Some(dep_id) = self.output_vocab.idx(&dep) {
112                if self.rng.gen_range(0f32, 1f32) < self.output_vocab.discard(dep_id.idx() as usize)
113                {
114                    contexts[focus].push(dep_id.idx() as usize)
115                }
116            }
117        }
118        Box::new(
119            tokens
120                .into_iter()
121                .zip(contexts.into_iter())
122                .filter(move |(focus, _)| focus.word_idx() != invalid_idx),
123        )
124    }
125}
126
127impl<R, V> Trainer for DepembedsTrainer<R, V>
128where
129    R: Rng,
130    V: Vocab,
131    V::Config: Serialize,
132{
133    type InputVocab = V;
134    type Metadata = DepembedsMetadata<V::Config, SimpleVocabConfig>;
135
136    fn input_vocab(&self) -> &Self::InputVocab {
137        &self.input_vocab
138    }
139
140    fn try_into_input_vocab(self) -> Result<Self::InputVocab> {
141        match Arc::try_unwrap(self.input_vocab) {
142            Ok(vocab) => Ok(vocab),
143            Err(_) => bail!("Cannot unwrap input vocab."),
144        }
145    }
146
147    fn n_input_types(&self) -> usize {
148        self.input_vocab.n_input_types()
149    }
150
151    fn n_output_types(&self) -> usize {
152        self.output_vocab.len()
153    }
154
155    fn config(&self) -> &CommonConfig {
156        &self.common_config
157    }
158
159    fn to_metadata(&self) -> Self::Metadata {
160        DepembedsMetadata {
161            common_config: self.common_config,
162            dep_config: self.dep_config,
163            input_vocab_config: self.input_vocab.config(),
164            output_vocab_config: self.output_vocab.config(),
165        }
166    }
167}
168
169/// Metadata for dependency embeddings.
170#[derive(Clone, Copy, Debug, Serialize)]
171pub struct DepembedsMetadata<IC, OC> {
172    common_config: CommonConfig,
173    #[serde(rename = "model_config")]
174    dep_config: DepembedsConfig,
175    input_vocab_config: IC,
176    output_vocab_config: OC,
177}