aiha/hub/
config.rs

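//! Model configuration metadata: the [`ModelConfig`] enum, which wraps one
//! model-specific config per supported architecture, and its
//! [`ModelConfigTrait`] implementation.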
2use serde::Deserialize;
3
4use crate::models::{
5    BertModelConfig, BloomModelConfig, GPT2ModelConfig, GPTJModelConfig, GPTNeoModelConfig,
6    LlamaModelConfig, ModelConfigTrait, ModelError, OPTModelConfig, T5ModelConfig,
7};
8
9/// Enum all the possible model types
10#[derive(Clone, Debug, Deserialize)]
11pub enum ModelConfig {
12    /// Bert model config
13    Bert(BertModelConfig),
14    /// Bloom model config
15    Bloom(BloomModelConfig),
16    /// GPT2 model config
17    Gpt2(GPT2ModelConfig),
18    /// GPTJ model config
19    GptJ(GPTJModelConfig),
20    /// GPTNeo model config
21    GPTNeo(GPTNeoModelConfig),
22    /// Llama model config
23    Llama(LlamaModelConfig),
24    /// OPT model config
25    Opt(OPTModelConfig),
26    /// T5 model config
27    T5(T5ModelConfig),
28}
29
30/// Model config implementation
31impl ModelConfigTrait for ModelConfig {
32    fn hidden_size(&self) -> i32 {
33        match self {
34            ModelConfig::Bert(config) => config.hidden_size(),
35            ModelConfig::Bloom(config) => config.hidden_size(),
36            ModelConfig::Gpt2(config) => config.hidden_size(),
37            ModelConfig::GptJ(config) => config.hidden_size(),
38            ModelConfig::GPTNeo(config) => config.hidden_size(),
39            ModelConfig::Llama(config) => config.hidden_size(),
40            ModelConfig::Opt(config) => config.hidden_size(),
41            ModelConfig::T5(config) => config.hidden_size(),
42        }
43    }
44    fn intermediate_size(&self) -> i32 {
45        match self {
46            ModelConfig::Bert(config) => config.intermediate_size(),
47            ModelConfig::Bloom(config) => config.intermediate_size(),
48            ModelConfig::Gpt2(config) => config.intermediate_size(),
49            ModelConfig::GptJ(config) => config.intermediate_size(),
50            ModelConfig::GPTNeo(config) => config.intermediate_size(),
51            ModelConfig::Llama(config) => config.intermediate_size(),
52            ModelConfig::Opt(config) => config.intermediate_size(),
53            ModelConfig::T5(config) => config.intermediate_size(),
54        }
55    }
56    fn max_position_embeddings(&self) -> i32 {
57        match self {
58            ModelConfig::Bert(config) => config.max_position_embeddings(),
59            ModelConfig::Bloom(config) => config.max_position_embeddings(),
60            ModelConfig::Gpt2(config) => config.max_position_embeddings(),
61            ModelConfig::GptJ(config) => config.max_position_embeddings(),
62            ModelConfig::GPTNeo(config) => config.max_position_embeddings(),
63            ModelConfig::Llama(config) => config.max_position_embeddings(),
64            ModelConfig::Opt(config) => config.max_position_embeddings(),
65            ModelConfig::T5(config) => config.max_position_embeddings(),
66        }
67    }
68    fn num_attention_heads(&self) -> i32 {
69        match self {
70            ModelConfig::Bert(config) => config.num_attention_heads(),
71            ModelConfig::Bloom(config) => config.num_attention_heads(),
72            ModelConfig::Gpt2(config) => config.num_attention_heads(),
73            ModelConfig::GptJ(config) => config.num_attention_heads(),
74            ModelConfig::GPTNeo(config) => config.num_attention_heads(),
75            ModelConfig::Llama(config) => config.num_attention_heads(),
76            ModelConfig::Opt(config) => config.num_attention_heads(),
77            ModelConfig::T5(config) => config.num_attention_heads(),
78        }
79    }
80    fn num_hidden_layers(&self) -> i32 {
81        match self {
82            ModelConfig::Bert(config) => config.num_hidden_layers(),
83            ModelConfig::Bloom(config) => config.num_hidden_layers(),
84            ModelConfig::Gpt2(config) => config.num_hidden_layers(),
85            ModelConfig::GptJ(config) => config.num_hidden_layers(),
86            ModelConfig::GPTNeo(config) => config.num_hidden_layers(),
87            ModelConfig::Llama(config) => config.num_hidden_layers(),
88            ModelConfig::Opt(config) => config.num_hidden_layers(),
89            ModelConfig::T5(config) => config.num_hidden_layers(),
90        }
91    }
92    fn model_type(&self) -> &str {
93        match self {
94            ModelConfig::Bert(config) => config.model_type(),
95            ModelConfig::Bloom(config) => config.model_type(),
96            ModelConfig::Gpt2(config) => config.model_type(),
97            ModelConfig::GptJ(config) => config.model_type(),
98            ModelConfig::GPTNeo(config) => config.model_type(),
99            ModelConfig::Llama(config) => config.model_type(),
100            ModelConfig::Opt(config) => config.model_type(),
101            ModelConfig::T5(config) => config.model_type(),
102        }
103    }
104    fn available_libraries(&self) -> &[crate::ModelLibraries] {
105        match self {
106            ModelConfig::Bert(config) => config.available_libraries(),
107            ModelConfig::Bloom(config) => config.available_libraries(),
108            ModelConfig::Gpt2(config) => config.available_libraries(),
109            ModelConfig::GptJ(config) => config.available_libraries(),
110            ModelConfig::GPTNeo(config) => config.available_libraries(),
111            ModelConfig::Llama(config) => config.available_libraries(),
112            ModelConfig::Opt(config) => config.available_libraries(),
113            ModelConfig::T5(config) => config.available_libraries(),
114        }
115    }
116    fn from_json(value: serde_json::Value) -> Result<Self, ModelError>
117    where
118        Self: Sized,
119    {
120        let model_type = value["model_type"]
121            .as_str()
122            .ok_or(ModelError::MissingField("model_type".to_string()))?;
123        match model_type {
124            "bert" => Ok(ModelConfig::Bert(BertModelConfig::from_json(value)?)),
125            "bloom" => Ok(ModelConfig::Bloom(BloomModelConfig::from_json(value)?)),
126            "gpt2" => Ok(ModelConfig::Gpt2(GPT2ModelConfig::from_json(value)?)),
127            "gptj" => Ok(ModelConfig::GptJ(GPTJModelConfig::from_json(value)?)),
128            "gpt_neo" => Ok(ModelConfig::GPTNeo(GPTNeoModelConfig::from_json(value)?)),
129            "gpt_neox" => Ok(ModelConfig::GPTNeo(GPTNeoModelConfig::from_json(value)?)),
130            "llama" => Ok(ModelConfig::Llama(LlamaModelConfig::from_json(value)?)),
131            "opt" => Ok(ModelConfig::Opt(OPTModelConfig::from_json(value)?)),
132            "t5" => Ok(ModelConfig::T5(T5ModelConfig::from_json(value)?)),
133            _ => Err(ModelError::ModelNotImplemented(model_type.to_string())),
134        }
135    }
136}
137
#[cfg(test)]
mod tests {
    // No unit tests have been written for this module yet.
}