use std::convert::TryFrom;
use anyhow::{bail, Error, Result};
use serde::Serialize;
use crate::io::EmbeddingFormat;
use crate::vocab::Cutoff;
/// Model type selector.
///
/// A value can be obtained from a numeric code via `TryFrom<u8>` or from a
/// short name via `TryFrom<&str>`.
#[derive(Clone, Copy, Debug, Serialize)]
pub enum ModelType {
    /// Skip-gram (code `0`, name `"skipgram"`).
    SkipGram,
    /// Structured skip-gram (code `1`, name `"structgram"`).
    StructuredSkipGram,
    /// Directional skip-gram (code `2`, name `"dirgram"`).
    ///
    /// The lowercase `g` differs from the other variants. It is kept because
    /// the variant name is part of the public API and of the serialized value.
    DirectionalSkipgram,
}
impl TryFrom<u8> for ModelType {
    type Error = Error;

    /// Decode a numeric model code.
    ///
    /// `0` maps to skip-gram, `1` to structured skip-gram and `2` to
    /// directional skip-gram.
    ///
    /// # Errors
    ///
    /// Returns an error for any other code.
    fn try_from(code: u8) -> Result<ModelType> {
        let model_type = match code {
            0 => ModelType::SkipGram,
            1 => ModelType::StructuredSkipGram,
            2 => ModelType::DirectionalSkipgram,
            unknown => bail!("Unknown model type: {}", unknown),
        };
        Ok(model_type)
    }
}
impl TryFrom<&str> for ModelType {
    type Error = Error;

    /// Parse a model name: `"skipgram"`, `"structgram"` or `"dirgram"`.
    ///
    /// Matching is exact and case-sensitive.
    ///
    /// # Errors
    ///
    /// Returns an error for any other name.
    fn try_from(name: &str) -> Result<ModelType> {
        Ok(match name {
            "skipgram" => ModelType::SkipGram,
            "structgram" => ModelType::StructuredSkipGram,
            "dirgram" => ModelType::DirectionalSkipgram,
            unknown => bail!("Unknown model type: {}", unknown),
        })
    }
}
/// Loss function selector.
///
/// Only one loss is currently defined. It corresponds to code `0` in the
/// `TryFrom<u8>` conversion.
#[derive(Clone, Copy, Debug, Serialize)]
pub enum LossType {
    /// Logistic loss with negative sampling.
    LogisticNegativeSampling,
}
impl TryFrom<u8> for LossType {
    type Error = Error;

    /// Decode a numeric loss code.
    ///
    /// Only `0` (logistic loss with negative sampling) is valid.
    ///
    /// # Errors
    ///
    /// Returns an error for any other code.
    fn try_from(loss: u8) -> Result<LossType> {
        match loss {
            0 => Ok(LossType::LogisticNegativeSampling),
            // Name the loss type (not the model type) so a bad setting is
            // reported against the option that actually holds it.
            _ => bail!("Unknown loss type: {}", loss),
        }
    }
}
/// Bucket indexer selector.
///
/// Parsed from `"finalfusion"` or `"fasttext"` via `TryFrom<&str>`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize)]
pub enum BucketIndexerType {
    /// Selected by the name `"finalfusion"`.
    Finalfusion,
    /// Selected by the name `"fasttext"`.
    FastText,
}
impl TryFrom<&str> for BucketIndexerType {
    type Error = Error;

    /// Parse an indexer name: `"finalfusion"` or `"fasttext"`.
    ///
    /// Matching is exact and case-sensitive.
    ///
    /// # Errors
    ///
    /// Returns an error for any other name.
    fn try_from(value: &str) -> Result<Self> {
        if value == "finalfusion" {
            Ok(Self::Finalfusion)
        } else if value == "fasttext" {
            Ok(Self::FastText)
        } else {
            bail!("Unknown indexer type: {}", value)
        }
    }
}
/// Common training settings.
///
/// Every field except `format` is serialized. Field declaration order is the
/// serialized field order, so keep it stable.
#[derive(Copy, Clone, Debug, Serialize)]
pub struct CommonConfig {
    /// Loss function selector.
    pub loss: LossType,
    /// Embedding dimensionality.
    pub dims: u32,
    /// Number of training epochs.
    pub epochs: u32,
    /// Embedding format (presumably of the written output); excluded from
    /// serialization.
    #[serde(skip)]
    pub format: EmbeddingFormat,
    /// Number of negative samples (presumably per positive example — confirm
    /// with the trainer).
    pub negative_samples: u32,
    /// Learning rate (presumably the initial value — confirm with the trainer).
    pub lr: f32,
    /// Zipf distribution exponent (presumably used when drawing negative
    /// samples — confirm with the sampler).
    pub zipf_exponent: f64,
}
/// Settings for the `Depembeds` model.
///
/// The name suggests dependency-based embeddings; confirm against the trainer.
/// When serialized, an extra `"type": "Depembeds"` field is emitted.
#[derive(Copy, Clone, Debug, Serialize)]
#[serde(rename = "Depembeds", tag = "type")]
pub struct DepembedsConfig {
    /// Depth setting (presumably the maximum dependency depth considered).
    pub depth: u32,
    // The boolean switches below are passed through unchanged. Their exact
    // effect is defined by whatever consumes this configuration.
    pub use_root: bool,
    pub normalize: bool,
    pub projectivize: bool,
    pub untyped: bool,
}
/// Subword vocabulary settings, generic over the indexer configuration `V`.
///
/// When serialized, an extra `"type": "SubwordVocab"` field is emitted.
#[derive(Copy, Clone, Debug, Serialize)]
#[serde(rename = "SubwordVocab", tag = "type")]
pub struct SubwordVocabConfig<V> {
    /// Vocabulary cutoff (see `Cutoff`).
    pub cutoff: Cutoff,
    /// Discard threshold (presumably for frequent-word subsampling).
    pub discard_threshold: f32,
    /// Smallest subword n-gram length (presumably inclusive).
    pub min_n: u32,
    /// Largest subword n-gram length (presumably inclusive).
    pub max_n: u32,
    /// Indexer configuration, e.g. a bucket or n-gram config — confirm with
    /// the instantiating code.
    pub indexer: V,
}
/// Bucket-based subword indexer settings.
///
/// When serialized, an extra `"type": "Buckets"` field is emitted.
#[derive(Copy, Clone, Debug, Serialize)]
#[serde(rename = "Buckets", tag = "type")]
pub struct BucketConfig {
    /// Bucket count exponent (presumably the count is `2^buckets_exp`).
    pub buckets_exp: u32,
    /// Which bucket indexer implementation to use.
    pub indexer_type: BucketIndexerType,
}
/// Explicit n-gram indexer settings.
///
/// When serialized, an extra `"type": "NGrams"` field is emitted.
#[derive(Copy, Clone, Debug, Serialize)]
#[serde(rename = "NGrams", tag = "type")]
pub struct NGramConfig {
    /// Cutoff applied to n-grams (see `Cutoff`).
    pub cutoff: Cutoff,
}
/// Settings for a plain vocabulary without subword units.
///
/// When serialized, an extra `"type": "SimpleVocab"` field is emitted.
#[derive(Copy, Clone, Debug, Serialize)]
#[serde(rename = "SimpleVocab", tag = "type")]
pub struct SimpleVocabConfig {
    /// Vocabulary cutoff (see `Cutoff`).
    pub cutoff: Cutoff,
    /// Discard threshold (presumably for frequent-word subsampling).
    pub discard_threshold: f32,
}
/// Settings for skip-gram-like models.
///
/// `model` selects the concrete variant. When serialized, an extra
/// `"type": "SkipGramLike"` field is emitted.
#[derive(Copy, Clone, Debug, Serialize)]
#[serde(rename = "SkipGramLike", tag = "type")]
pub struct SkipGramConfig {
    /// Concrete skip-gram variant.
    pub model: ModelType,
    /// Context size (presumably the window size on each side of the focus
    /// word — confirm with the trainer).
    pub context_size: u32,
}