1use std::convert::TryFrom;
2
3use anyhow::{bail, Error, Result};
4use serde::Serialize;
5
6use crate::io::EmbeddingFormat;
7use crate::vocab::Cutoff;
8
9#[derive(Copy, Clone, Debug, Serialize)]
11pub enum ModelType {
12 SkipGram,
14
15 StructuredSkipGram,
17
18 DirectionalSkipgram,
20}
21
22impl TryFrom<u8> for ModelType {
23 type Error = Error;
24
25 fn try_from(model: u8) -> Result<ModelType> {
26 match model {
27 0 => Ok(ModelType::SkipGram),
28 1 => Ok(ModelType::StructuredSkipGram),
29 2 => Ok(ModelType::DirectionalSkipgram),
30 _ => bail!("Unknown model type: {}", model),
31 }
32 }
33}
34
35impl TryFrom<&str> for ModelType {
36 type Error = Error;
37
38 fn try_from(model: &str) -> Result<ModelType> {
39 match model {
40 "skipgram" => Ok(ModelType::SkipGram),
41 "structgram" => Ok(ModelType::StructuredSkipGram),
42 "dirgram" => Ok(ModelType::DirectionalSkipgram),
43 _ => bail!("Unknown model type: {}", model),
44 }
45 }
46}
47
48#[derive(Copy, Clone, Debug, Serialize)]
50pub enum LossType {
51 LogisticNegativeSampling,
53}
54
55impl TryFrom<u8> for LossType {
56 type Error = Error;
57
58 fn try_from(model: u8) -> Result<LossType> {
59 match model {
60 0 => Ok(LossType::LogisticNegativeSampling),
61 _ => bail!("Unknown model type: {}", model),
62 }
63 }
64}
65
66#[derive(Copy, Clone, Debug, Eq, PartialEq, Serialize)]
68pub enum BucketIndexerType {
69 Finalfusion,
71 FastText,
73}
74
75impl TryFrom<&str> for BucketIndexerType {
76 type Error = Error;
77
78 fn try_from(value: &str) -> Result<Self> {
79 match value {
80 "finalfusion" => Ok(BucketIndexerType::Finalfusion),
81 "fasttext" => Ok(BucketIndexerType::FastText),
82 v => bail!("Unknown indexer type: {}", v),
83 }
84 }
85}
86
87#[derive(Clone, Copy, Debug, Serialize)]
89pub struct CommonConfig {
90 pub loss: LossType,
92
93 pub dims: u32,
95
96 pub epochs: u32,
98
99 #[serde(skip)]
101 pub format: EmbeddingFormat,
102
103 pub negative_samples: u32,
105
106 pub lr: f32,
108
109 pub zipf_exponent: f64,
113}
114
115#[derive(Clone, Copy, Debug, Serialize)]
117#[serde(tag = "type")]
118#[serde(rename = "Depembeds")]
119pub struct DepembedsConfig {
120 pub depth: u32,
122
123 pub use_root: bool,
125
126 pub normalize: bool,
128
129 pub projectivize: bool,
131
132 pub untyped: bool,
136}
137
138#[derive(Clone, Copy, Debug, Serialize)]
140#[serde(rename = "SubwordVocab")]
141#[serde(tag = "type")]
142pub struct SubwordVocabConfig<V> {
143 pub cutoff: Cutoff,
148
149 pub discard_threshold: f32,
155
156 pub min_n: u32,
158
159 pub max_n: u32,
161
162 pub indexer: V,
164}
165
166#[derive(Clone, Copy, Debug, Serialize)]
168#[serde(rename = "Buckets")]
169#[serde(tag = "type")]
170pub struct BucketConfig {
171 pub buckets_exp: u32,
176
177 pub indexer_type: BucketIndexerType,
178}
179
180#[derive(Clone, Copy, Debug, Serialize)]
182#[serde(rename = "NGrams")]
183#[serde(tag = "type")]
184pub struct NGramConfig {
185 pub cutoff: Cutoff,
189}
190
191#[derive(Clone, Copy, Debug, Serialize)]
193#[serde(rename = "SimpleVocab")]
194#[serde(tag = "type")]
195pub struct SimpleVocabConfig {
196 pub cutoff: Cutoff,
201
202 pub discard_threshold: f32,
208}
209
210#[derive(Clone, Copy, Debug, Serialize)]
212#[serde(tag = "type")]
213#[serde(rename = "SkipGramLike")]
214pub struct SkipGramConfig {
215 pub model: ModelType,
217
218 pub context_size: u32,
224}