1use crate::{registry::Architecture, source::ResolvedModelSource};
4use ferrum_types::{
5 Activation, AttentionConfig, FerrumError, ModelInfo, ModelType, NormType, Result, RopeScaling,
6};
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::path::Path;
10use tracing::{debug, warn};
11
/// Normalized model configuration parsed from a Hugging Face-style `config.json`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelDefinition {
    /// Detected model architecture (e.g. `Architecture::Llama`).
    pub architecture: Architecture,
    /// Transformer hidden (embedding) dimension.
    pub hidden_size: usize,
    /// Feed-forward (MLP) inner dimension.
    pub intermediate_size: usize,
    /// Token vocabulary size.
    pub vocab_size: usize,
    /// Number of transformer layers.
    pub num_hidden_layers: usize,
    /// Number of attention (query) heads.
    pub num_attention_heads: usize,
    /// Number of key/value heads; `None` means same as `num_attention_heads` (no GQA).
    pub num_key_value_heads: Option<usize>,
    /// Maximum sequence length the model was configured for.
    pub max_position_embeddings: usize,
    /// RoPE base frequency, when the config provides one.
    pub rope_theta: Option<f64>,
    /// Optional RoPE scaling configuration.
    pub rope_scaling: Option<RopeScaling>,
    /// Normalization layer kind (RMSNorm vs LayerNorm).
    pub norm_type: NormType,
    /// Epsilon used by the normalization layers.
    pub norm_eps: f64,
    /// Attention-specific options (bias, sliding window).
    pub attention_config: AttentionConfig,
    /// Activation function used in the MLP blocks.
    pub activation: Activation,
    /// All remaining `config.json` fields, preserved verbatim for downstream consumers.
    #[serde(flatten)]
    pub extra_params: serde_json::Value,
}
47
48impl Default for ModelDefinition {
49 fn default() -> Self {
50 Self {
51 architecture: Architecture::Llama,
52 hidden_size: 4096,
53 intermediate_size: 11008,
54 vocab_size: 32000,
55 num_hidden_layers: 32,
56 num_attention_heads: 32,
57 num_key_value_heads: None,
58 max_position_embeddings: 2048,
59 rope_theta: Some(10000.0),
60 rope_scaling: None,
61 norm_type: NormType::RMSNorm,
62 norm_eps: 1e-6,
63 attention_config: AttentionConfig {
64 attention_bias: false,
65 sliding_window: None,
66 },
67 activation: Activation::SiLU,
68 extra_params: serde_json::Value::Object(serde_json::Map::new()),
69 }
70 }
71}
72
73impl ModelDefinition {
74 pub fn to_model_info(&self, model_id: impl Into<String>) -> ModelInfo {
76 use ferrum_types::{DataType, Device};
77
78 let model_id_str = model_id.into();
79
80 let params = self.estimate_parameters();
82
83 ModelInfo {
84 model_id: ferrum_types::ModelId::new(model_id_str.clone()),
85 model_type: ModelType::Custom(format!("{:?}", self.architecture)),
86 num_parameters: params as u64,
87 hidden_size: self.hidden_size,
88 num_layers: self.num_hidden_layers,
89 num_heads: self.num_attention_heads,
90 num_kv_heads: self.num_key_value_heads.unwrap_or(self.num_attention_heads),
91 vocab_size: self.vocab_size,
92 max_sequence_length: self.max_position_embeddings,
93 dtype: DataType::FP16, device: Device::CPU, version: None,
96 license: None,
97 metadata: HashMap::new(),
98 }
99 }
100
101 fn estimate_parameters(&self) -> usize {
103 let embedding_params = self.vocab_size * self.hidden_size;
105 let layer_params = self.num_hidden_layers
106 * (
107 4 * self.hidden_size * self.hidden_size +
109 3 * self.hidden_size * self.intermediate_size +
111 2 * self.hidden_size
113 );
114 let lm_head_params = self.vocab_size * self.hidden_size;
115
116 embedding_params + layer_params + lm_head_params
117 }
118}
119
/// Loads Hugging Face-style `config.json` files and normalizes them into
/// [`ModelDefinition`]s.
#[derive(Debug, Default)]
pub struct ConfigManager {
    // Reserved for caching parsed definitions by model id; not read or
    // written anywhere yet (hence the leading underscore).
    _cache: HashMap<String, ModelDefinition>,
}
125
126impl ConfigManager {
127 pub fn new() -> Self {
128 Self {
129 _cache: HashMap::new(),
130 }
131 }
132
133 pub async fn load_from_source(
135 &mut self,
136 source: &ResolvedModelSource,
137 ) -> Result<ModelDefinition> {
138 self.load_from_path(&source.local_path).await
139 }
140
141 pub async fn load_from_path(&mut self, path: &Path) -> Result<ModelDefinition> {
143 let config_path = path.join("config.json");
144
145 if !config_path.exists() {
146 return Err(FerrumError::model(format!(
147 "config.json not found in model directory: {:?}",
148 path
149 )));
150 }
151
152 debug!("Loading model config from: {:?}", config_path);
153
154 let content = tokio::fs::read_to_string(&config_path)
155 .await
156 .map_err(|e| FerrumError::io(format!("Failed to read config.json: {}", e)))?;
157
158 let raw_config: serde_json::Value = serde_json::from_str(&content)
159 .map_err(|e| FerrumError::model(format!("Failed to parse config.json: {}", e)))?;
160
161 self.parse_config(&raw_config)
162 }
163
164 fn parse_config(&mut self, raw: &serde_json::Value) -> Result<ModelDefinition> {
166 let obj = raw
167 .as_object()
168 .ok_or_else(|| FerrumError::model("config.json root is not an object"))?;
169
170 let architecture = self.detect_architecture(raw)?;
172
173 let hidden_size = obj
175 .get("hidden_size")
176 .and_then(|v| v.as_u64())
177 .unwrap_or(4096) as usize;
178
179 let intermediate_size = obj
180 .get("intermediate_size")
181 .and_then(|v| v.as_u64())
182 .or_else(|| obj.get("ffn_dim").and_then(|v| v.as_u64()))
183 .unwrap_or(11008) as usize;
184
185 let vocab_size = obj
186 .get("vocab_size")
187 .and_then(|v| v.as_u64())
188 .ok_or_else(|| FerrumError::model("vocab_size not found in config"))?
189 as usize;
190
191 let num_hidden_layers = obj
192 .get("num_hidden_layers")
193 .and_then(|v| v.as_u64())
194 .or_else(|| obj.get("n_layer").and_then(|v| v.as_u64()))
195 .unwrap_or(32) as usize;
196
197 let num_attention_heads = obj
198 .get("num_attention_heads")
199 .and_then(|v| v.as_u64())
200 .or_else(|| obj.get("n_head").and_then(|v| v.as_u64()))
201 .unwrap_or(32) as usize;
202
203 let num_key_value_heads = obj
204 .get("num_key_value_heads")
205 .and_then(|v| v.as_u64())
206 .map(|v| v as usize);
207
208 let max_position_embeddings = obj
209 .get("max_position_embeddings")
210 .and_then(|v| v.as_u64())
211 .or_else(|| obj.get("n_positions").and_then(|v| v.as_u64()))
212 .unwrap_or(2048) as usize;
213
214 let rope_theta = obj
215 .get("rope_theta")
216 .and_then(|v| v.as_f64())
217 .or_else(|| obj.get("rotary_emb_base").and_then(|v| v.as_f64()));
218
219 let rope_scaling = obj
221 .get("rope_scaling")
222 .and_then(|v| serde_json::from_value(v.clone()).ok());
223
224 let norm_type = if obj.get("rms_norm_eps").is_some() {
226 NormType::RMSNorm
227 } else {
228 NormType::LayerNorm
229 };
230
231 let norm_eps = obj
232 .get("rms_norm_eps")
233 .or_else(|| obj.get("layer_norm_eps"))
234 .or_else(|| obj.get("layer_norm_epsilon"))
235 .and_then(|v| v.as_f64())
236 .unwrap_or(1e-6);
237
238 let attention_bias = obj
240 .get("attention_bias")
241 .and_then(|v| v.as_bool())
242 .unwrap_or(false);
243
244 let sliding_window = obj
245 .get("sliding_window")
246 .and_then(|v| v.as_u64())
247 .map(|v| v as usize);
248
249 let activation = obj
251 .get("hidden_act")
252 .and_then(|v| v.as_str())
253 .map(|s| match s {
254 "gelu" | "gelu_new" => Activation::GELU,
255 "silu" => Activation::SiLU,
256 "relu" => Activation::ReLU,
257 "swish" => Activation::Swish,
258 _ => {
259 warn!("Unknown activation function: {}, defaulting to SiLU", s);
260 Activation::SiLU
261 }
262 })
263 .unwrap_or(Activation::SiLU);
264
265 Ok(ModelDefinition {
266 architecture,
267 hidden_size,
268 intermediate_size,
269 vocab_size,
270 num_hidden_layers,
271 num_attention_heads,
272 num_key_value_heads,
273 max_position_embeddings,
274 rope_theta,
275 rope_scaling,
276 norm_type,
277 norm_eps,
278 attention_config: AttentionConfig {
279 attention_bias,
280 sliding_window,
281 },
282 activation,
283 extra_params: raw.clone(),
284 })
285 }
286
287 fn detect_architecture(&self, config: &serde_json::Value) -> Result<Architecture> {
289 let obj = config
290 .as_object()
291 .ok_or_else(|| FerrumError::model("config.json root is not an object"))?;
292
293 if let Some(model_type) = obj.get("model_type").and_then(|v| v.as_str()) {
295 return Ok(Architecture::from_str(model_type));
296 }
297
298 if let Some(architectures) = obj.get("architectures").and_then(|v| v.as_array()) {
300 if let Some(arch) = architectures.first().and_then(|v| v.as_str()) {
301 return Ok(Architecture::from_str(arch));
302 }
303 }
304
305 warn!("Could not detect architecture, using default (Llama)");
306 Ok(Architecture::Llama)
307 }
308
309 pub fn infer_model_type(&self, definition: &ModelDefinition) -> ModelType {
311 ModelType::Custom(format!("{:?}", definition.architecture))
312 }
313}