//! finalfrontier configuration types (config.rs).

use std::convert::TryFrom;

use anyhow::{bail, Error, Result};
use serde::Serialize;

use crate::io::EmbeddingFormat;
use crate::vocab::Cutoff;

9/// Model types.
10#[derive(Copy, Clone, Debug, Serialize)]
11pub enum ModelType {
12    // The skip-gram model (Mikolov, 2013).
13    SkipGram,
14
15    // The structured skip-gram model (Ling et al., 2015).
16    StructuredSkipGram,
17
18    // The directional skip-gram model (Song et al., 2018).
19    DirectionalSkipgram,
20}
21
22impl TryFrom<u8> for ModelType {
23    type Error = Error;
24
25    fn try_from(model: u8) -> Result<ModelType> {
26        match model {
27            0 => Ok(ModelType::SkipGram),
28            1 => Ok(ModelType::StructuredSkipGram),
29            2 => Ok(ModelType::DirectionalSkipgram),
30            _ => bail!("Unknown model type: {}", model),
31        }
32    }
33}
34
35impl TryFrom<&str> for ModelType {
36    type Error = Error;
37
38    fn try_from(model: &str) -> Result<ModelType> {
39        match model {
40            "skipgram" => Ok(ModelType::SkipGram),
41            "structgram" => Ok(ModelType::StructuredSkipGram),
42            "dirgram" => Ok(ModelType::DirectionalSkipgram),
43            _ => bail!("Unknown model type: {}", model),
44        }
45    }
46}
47
48/// Losses.
49#[derive(Copy, Clone, Debug, Serialize)]
50pub enum LossType {
51    /// Logistic regression with negative sampling.
52    LogisticNegativeSampling,
53}
54
55impl TryFrom<u8> for LossType {
56    type Error = Error;
57
58    fn try_from(model: u8) -> Result<LossType> {
59        match model {
60            0 => Ok(LossType::LogisticNegativeSampling),
61            _ => bail!("Unknown model type: {}", model),
62        }
63    }
64}
65
/// Bucket Indexer Types
///
/// Selects which bucket indexer implementation is used.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Serialize)]
pub enum BucketIndexerType {
    /// Use `FinalfusionBucketIndexer`.
    Finalfusion,
    /// Use `FastTextIndexer`.
    FastText,
}

75impl TryFrom<&str> for BucketIndexerType {
76    type Error = Error;
77
78    fn try_from(value: &str) -> Result<Self> {
79        match value {
80            "finalfusion" => Ok(BucketIndexerType::Finalfusion),
81            "fasttext" => Ok(BucketIndexerType::FastText),
82            v => bail!("Unknown indexer type: {}", v),
83        }
84    }
85}
86
/// Common embedding model hyperparameters.
#[derive(Clone, Copy, Debug, Serialize)]
pub struct CommonConfig {
    /// The loss function used for the model.
    pub loss: LossType,

    /// Word embedding dimensionality.
    pub dims: u32,

    /// The number of training epochs.
    pub epochs: u32,

    /// The output format.
    ///
    /// Skipped during serialization (`#[serde(skip)]`).
    #[serde(skip)]
    pub format: EmbeddingFormat,

    /// Number of negative samples to use for each context word.
    pub negative_samples: u32,

    /// The initial learning rate.
    pub lr: f32,

    /// Exponent in zipfian distribution.
    ///
    /// This is s in *f(k) = 1 / (k^s H_{N, s})*.
    pub zipf_exponent: f64,
}

/// Hyperparameters for Dependency Embeddings.
#[derive(Clone, Copy, Debug, Serialize)]
#[serde(tag = "type")]
#[serde(rename = "Depembeds")]
pub struct DepembedsConfig {
    /// Maximum depth to extract dependency contexts from.
    pub depth: u32,

    /// Include the ROOT as dependency context.
    pub use_root: bool,

    /// Lowercase all tokens when used as context.
    pub normalize: bool,

    /// Projectivize dependency graphs before training.
    pub projectivize: bool,

    /// Extract untyped dependency contexts.
    ///
    /// Only takes the attached word-form into account, i.e. the dependency
    /// relation label is not part of the context.
    pub untyped: bool,
}

/// Hyperparameters for Subword vocabs.
///
/// The type parameter `V` carries the indexer-specific configuration
/// (stored in the `indexer` field).
#[derive(Clone, Copy, Debug, Serialize)]
#[serde(rename = "SubwordVocab")]
#[serde(tag = "type")]
pub struct SubwordVocabConfig<V> {
    /// Token cutoff.
    ///
    /// No word-specific embeddings will be trained for tokens excluded by the
    /// cutoff.
    pub cutoff: Cutoff,

    /// Discard threshold.
    ///
    /// The discard threshold is used to compute the discard probability of
    /// a token. E.g. with a threshold of 0.00001 tokens with approximately
    /// that probability will never be discarded.
    pub discard_threshold: f32,

    /// Minimum n-gram length for subword units (inclusive).
    pub min_n: u32,

    /// Maximum n-gram length for subword units (inclusive).
    pub max_n: u32,

    /// Indexer specific parameters.
    pub indexer: V,
}

/// Hyperparameters for bucket-vocabs.
#[derive(Clone, Copy, Debug, Serialize)]
#[serde(rename = "Buckets")]
#[serde(tag = "type")]
pub struct BucketConfig {
    /// Bucket exponent. The model will use 2^buckets_exp buckets.
    ///
    /// A typical value for this parameter is 21, which gives roughly 2M
    /// buckets.
    pub buckets_exp: u32,

    /// Which bucket indexer implementation to use.
    pub indexer_type: BucketIndexerType,
}

/// Hyperparameters for ngram-vocabs.
#[derive(Clone, Copy, Debug, Serialize)]
#[serde(rename = "NGrams")]
#[serde(tag = "type")]
pub struct NGramConfig {
    /// N-gram cutoff.
    ///
    /// N-grams excluded by the cutoff will be ignored during training.
    pub cutoff: Cutoff,
}

/// Hyperparameters for simple vocabs.
#[derive(Clone, Copy, Debug, Serialize)]
#[serde(rename = "SimpleVocab")]
#[serde(tag = "type")]
pub struct SimpleVocabConfig {
    /// Token cutoff.
    ///
    /// No word-specific embeddings will be trained for tokens excluded by the
    /// cutoff.
    pub cutoff: Cutoff,

    /// Discard threshold.
    ///
    /// The discard threshold is used to compute the discard probability of
    /// a token. E.g. with a threshold of 0.00001 tokens with approximately
    /// that probability will never be discarded.
    pub discard_threshold: f32,
}

/// Hyperparameters for SkipGram-like models.
#[derive(Clone, Copy, Debug, Serialize)]
#[serde(tag = "type")]
#[serde(rename = "SkipGramLike")]
pub struct SkipGramConfig {
    /// The model type.
    pub model: ModelType,

    /// The number of preceding and succeeding tokens that will be considered
    /// as context during training.
    ///
    /// For example, a context size of 5 will consider the 5 tokens preceding
    /// and the 5 tokens succeeding the focus token.
    pub context_size: u32,
}