1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
use failure::{err_msg, Error};
use serde::Serialize;
/// Embedding model variants.
///
/// Values can be obtained from a numeric id via `ModelType::try_from` or
/// from a textual name via `ModelType::try_from_str`.
#[derive(Copy, Clone, Debug, Serialize)]
pub enum ModelType {
/// Selected by id `0` or by the name `"skipgram"`.
SkipGram,
/// Selected by id `1` or by the name `"structgram"`.
StructuredSkipGram,
}
impl ModelType {
    /// Maps a numeric model identifier onto a [`ModelType`].
    ///
    /// `0` yields [`ModelType::SkipGram`] and `1` yields
    /// [`ModelType::StructuredSkipGram`].
    ///
    /// # Errors
    ///
    /// Any other identifier produces an error whose message contains the
    /// rejected value.
    pub fn try_from(model: u8) -> Result<ModelType, Error> {
        let kind = match model {
            0 => ModelType::SkipGram,
            1 => ModelType::StructuredSkipGram,
            unknown => return Err(err_msg(format!("Unknown model type: {}", unknown))),
        };
        Ok(kind)
    }

    /// Maps a textual model name onto a [`ModelType`].
    ///
    /// The comparison is exact: `"skipgram"` yields [`ModelType::SkipGram`]
    /// and `"structgram"` yields [`ModelType::StructuredSkipGram`].
    ///
    /// # Errors
    ///
    /// Any other name produces an error whose message contains the rejected
    /// string.
    pub fn try_from_str(model: &str) -> Result<ModelType, Error> {
        let parsed = match model {
            "skipgram" => Some(ModelType::SkipGram),
            "structgram" => Some(ModelType::StructuredSkipGram),
            _ => None,
        };
        // Build the error lazily so the success path never formats a message.
        parsed.ok_or_else(|| err_msg(format!("Unknown model type: {}", model)))
    }
}
/// Loss function variants; a single variant is defined here.
#[derive(Copy, Clone, Debug, Serialize)]
pub enum LossType {
/// Selected by id `0` in `LossType::try_from`.
/// NOTE(review): presumably logistic loss with negative sampling, judging by the name — confirm.
LogisticNegativeSampling,
}
impl LossType {
    /// Converts a numeric loss identifier into a [`LossType`].
    ///
    /// `0` maps to [`LossType::LogisticNegativeSampling`], the only loss
    /// defined by this enum.
    ///
    /// # Errors
    ///
    /// Returns an error whose message contains the rejected value for any
    /// other identifier.
    pub fn try_from(model: u8) -> Result<LossType, Error> {
        match model {
            0 => Ok(LossType::LogisticNegativeSampling),
            // Name the loss type explicitly so this failure cannot be
            // mistaken for an unknown *model* identifier.
            _ => Err(err_msg(format!("Unknown loss type: {}", model))),
        }
    }
}
/// Hyperparameters grouped under the name "common".
///
/// NOTE(review): presumably shared by every model kind in this crate —
/// confirm against the call sites. Field meanings below are taken from the
/// field names and types only.
#[derive(Clone, Copy, Debug, Serialize)]
pub struct CommonConfig {
/// Which loss function to use.
pub loss: LossType,
/// presumably the embedding dimensionality — TODO confirm
pub dims: u32,
/// presumably the number of training epochs — TODO confirm
pub epochs: u32,
/// presumably the number of negative samples per positive example — TODO confirm
pub negative_samples: u32,
/// presumably the (initial) learning rate — TODO confirm
pub lr: f32,
/// presumably the exponent of a Zipfian distribution used for sampling — TODO confirm
pub zipf_exponent: f64,
}
/// Configuration serialized under the name `Depembeds`.
///
/// The serde attributes rename the container to `Depembeds` and request a
/// `type` tag, so the serialized form identifies itself as
/// `type = "Depembeds"` alongside the fields below.
///
/// NOTE(review): field meanings are inferred from their names only; confirm
/// against the code that consumes this struct.
#[derive(Clone, Copy, Debug, Serialize)]
#[serde(tag = "type")]
#[serde(rename = "Depembeds")]
pub struct DepembedsConfig {
/// presumably the maximum dependency-path depth for contexts — TODO confirm
pub depth: u32,
/// presumably whether the root is used as a context — TODO confirm
pub use_root: bool,
/// presumably whether tokens are normalized — TODO confirm
pub normalize: bool,
/// presumably whether dependency graphs are projectivized — TODO confirm
pub projectivize: bool,
/// presumably whether dependency labels are dropped from contexts — TODO confirm
pub untyped: bool,
}
/// Vocabulary configuration serialized under the name `SubwordVocab`.
///
/// The serde attributes rename the container to `SubwordVocab` and request a
/// `type` tag, so the serialized form identifies itself as
/// `type = "SubwordVocab"` alongside the fields below.
///
/// NOTE(review): field meanings are inferred from their names only; confirm
/// against the vocabulary builder.
#[derive(Clone, Copy, Debug, Serialize)]
#[serde(rename = "SubwordVocab")]
#[serde(tag = "type")]
pub struct SubwordVocabConfig {
/// presumably the minimum n-gram length — TODO confirm
pub min_n: u32,
/// presumably the maximum n-gram length — TODO confirm
pub max_n: u32,
/// presumably the exponent `e` giving `2^e` hash buckets — TODO confirm
pub buckets_exp: u32,
/// presumably the minimum token frequency to enter the vocabulary — TODO confirm
pub min_count: u32,
/// presumably the threshold for discarding frequent tokens — TODO confirm
pub discard_threshold: f32,
}
/// Vocabulary configuration serialized under the name `SimpleVocab`.
///
/// The serde attributes rename the container to `SimpleVocab` and request a
/// `type` tag, so the serialized form identifies itself as
/// `type = "SimpleVocab"` alongside the fields below.
///
/// NOTE(review): field meanings are inferred from their names only; confirm
/// against the vocabulary builder.
#[derive(Clone, Copy, Debug, Serialize)]
#[serde(rename = "SimpleVocab")]
#[serde(tag = "type")]
pub struct SimpleVocabConfig {
/// presumably the minimum token frequency to enter the vocabulary — TODO confirm
pub min_count: u32,
/// presumably the threshold for discarding frequent tokens — TODO confirm
pub discard_threshold: f32,
}
/// Model configuration serialized under the name `SkipGramLike`.
///
/// The serde attributes rename the container to `SkipGramLike` (not the Rust
/// type name) and request a `type` tag, so the serialized form identifies
/// itself as `type = "SkipGramLike"` alongside the fields below.
#[derive(Clone, Copy, Debug, Serialize)]
#[serde(tag = "type")]
#[serde(rename = "SkipGramLike")]
pub struct SkipGramConfig {
/// Which model variant to use.
pub model: ModelType,
/// presumably the number of context tokens on each side of the focus token — TODO confirm
pub context_size: u32,
}