1use serde::Deserialize;
3
4use crate::models::{
5 BertModelConfig, BloomModelConfig, GPT2ModelConfig, GPTJModelConfig, GPTNeoModelConfig,
6 LlamaModelConfig, ModelConfigTrait, ModelError, OPTModelConfig, T5ModelConfig,
7};
8
9#[derive(Clone, Debug, Deserialize)]
11pub enum ModelConfig {
12 Bert(BertModelConfig),
14 Bloom(BloomModelConfig),
16 Gpt2(GPT2ModelConfig),
18 GptJ(GPTJModelConfig),
20 GPTNeo(GPTNeoModelConfig),
22 Llama(LlamaModelConfig),
24 Opt(OPTModelConfig),
26 T5(T5ModelConfig),
28}
29
30impl ModelConfigTrait for ModelConfig {
32 fn hidden_size(&self) -> i32 {
33 match self {
34 ModelConfig::Bert(config) => config.hidden_size(),
35 ModelConfig::Bloom(config) => config.hidden_size(),
36 ModelConfig::Gpt2(config) => config.hidden_size(),
37 ModelConfig::GptJ(config) => config.hidden_size(),
38 ModelConfig::GPTNeo(config) => config.hidden_size(),
39 ModelConfig::Llama(config) => config.hidden_size(),
40 ModelConfig::Opt(config) => config.hidden_size(),
41 ModelConfig::T5(config) => config.hidden_size(),
42 }
43 }
44 fn intermediate_size(&self) -> i32 {
45 match self {
46 ModelConfig::Bert(config) => config.intermediate_size(),
47 ModelConfig::Bloom(config) => config.intermediate_size(),
48 ModelConfig::Gpt2(config) => config.intermediate_size(),
49 ModelConfig::GptJ(config) => config.intermediate_size(),
50 ModelConfig::GPTNeo(config) => config.intermediate_size(),
51 ModelConfig::Llama(config) => config.intermediate_size(),
52 ModelConfig::Opt(config) => config.intermediate_size(),
53 ModelConfig::T5(config) => config.intermediate_size(),
54 }
55 }
56 fn max_position_embeddings(&self) -> i32 {
57 match self {
58 ModelConfig::Bert(config) => config.max_position_embeddings(),
59 ModelConfig::Bloom(config) => config.max_position_embeddings(),
60 ModelConfig::Gpt2(config) => config.max_position_embeddings(),
61 ModelConfig::GptJ(config) => config.max_position_embeddings(),
62 ModelConfig::GPTNeo(config) => config.max_position_embeddings(),
63 ModelConfig::Llama(config) => config.max_position_embeddings(),
64 ModelConfig::Opt(config) => config.max_position_embeddings(),
65 ModelConfig::T5(config) => config.max_position_embeddings(),
66 }
67 }
68 fn num_attention_heads(&self) -> i32 {
69 match self {
70 ModelConfig::Bert(config) => config.num_attention_heads(),
71 ModelConfig::Bloom(config) => config.num_attention_heads(),
72 ModelConfig::Gpt2(config) => config.num_attention_heads(),
73 ModelConfig::GptJ(config) => config.num_attention_heads(),
74 ModelConfig::GPTNeo(config) => config.num_attention_heads(),
75 ModelConfig::Llama(config) => config.num_attention_heads(),
76 ModelConfig::Opt(config) => config.num_attention_heads(),
77 ModelConfig::T5(config) => config.num_attention_heads(),
78 }
79 }
80 fn num_hidden_layers(&self) -> i32 {
81 match self {
82 ModelConfig::Bert(config) => config.num_hidden_layers(),
83 ModelConfig::Bloom(config) => config.num_hidden_layers(),
84 ModelConfig::Gpt2(config) => config.num_hidden_layers(),
85 ModelConfig::GptJ(config) => config.num_hidden_layers(),
86 ModelConfig::GPTNeo(config) => config.num_hidden_layers(),
87 ModelConfig::Llama(config) => config.num_hidden_layers(),
88 ModelConfig::Opt(config) => config.num_hidden_layers(),
89 ModelConfig::T5(config) => config.num_hidden_layers(),
90 }
91 }
92 fn model_type(&self) -> &str {
93 match self {
94 ModelConfig::Bert(config) => config.model_type(),
95 ModelConfig::Bloom(config) => config.model_type(),
96 ModelConfig::Gpt2(config) => config.model_type(),
97 ModelConfig::GptJ(config) => config.model_type(),
98 ModelConfig::GPTNeo(config) => config.model_type(),
99 ModelConfig::Llama(config) => config.model_type(),
100 ModelConfig::Opt(config) => config.model_type(),
101 ModelConfig::T5(config) => config.model_type(),
102 }
103 }
104 fn available_libraries(&self) -> &[crate::ModelLibraries] {
105 match self {
106 ModelConfig::Bert(config) => config.available_libraries(),
107 ModelConfig::Bloom(config) => config.available_libraries(),
108 ModelConfig::Gpt2(config) => config.available_libraries(),
109 ModelConfig::GptJ(config) => config.available_libraries(),
110 ModelConfig::GPTNeo(config) => config.available_libraries(),
111 ModelConfig::Llama(config) => config.available_libraries(),
112 ModelConfig::Opt(config) => config.available_libraries(),
113 ModelConfig::T5(config) => config.available_libraries(),
114 }
115 }
116 fn from_json(value: serde_json::Value) -> Result<Self, ModelError>
117 where
118 Self: Sized,
119 {
120 let model_type = value["model_type"]
121 .as_str()
122 .ok_or(ModelError::MissingField("model_type".to_string()))?;
123 match model_type {
124 "bert" => Ok(ModelConfig::Bert(BertModelConfig::from_json(value)?)),
125 "bloom" => Ok(ModelConfig::Bloom(BloomModelConfig::from_json(value)?)),
126 "gpt2" => Ok(ModelConfig::Gpt2(GPT2ModelConfig::from_json(value)?)),
127 "gptj" => Ok(ModelConfig::GptJ(GPTJModelConfig::from_json(value)?)),
128 "gpt_neo" => Ok(ModelConfig::GPTNeo(GPTNeoModelConfig::from_json(value)?)),
129 "gpt_neox" => Ok(ModelConfig::GPTNeo(GPTNeoModelConfig::from_json(value)?)),
130 "llama" => Ok(ModelConfig::Llama(LlamaModelConfig::from_json(value)?)),
131 "opt" => Ok(ModelConfig::Opt(OPTModelConfig::from_json(value)?)),
132 "t5" => Ok(ModelConfig::T5(T5ModelConfig::from_json(value)?)),
133 _ => Err(ModelError::ModelNotImplemented(model_type.to_string())),
134 }
135 }
136}
137
138#[cfg(test)]
139mod tests {}